Compare commits

...

37 Commits

Author SHA1 Message Date
Greyson Lalonde
515ce8f55f Merge branch 'gl/feat/async-crew-support' into gl/feat/async-flow-kickoff 2025-12-02 19:07:03 -05:00
Greyson Lalonde
1d40f5d83c Merge branch 'gl/feat/async-task-support' into gl/feat/async-crew-support 2025-12-02 19:06:26 -05:00
Greyson Lalonde
3afac2a696 Merge branch 'gl/feat/async-knowledge-support' into gl/feat/async-task-support 2025-12-02 19:05:51 -05:00
Greyson Lalonde
5fab437b7f Merge branch 'gl/feat/async-memory-support' into gl/feat/async-knowledge-support 2025-12-02 19:05:18 -05:00
Greyson Lalonde
30684f387e Merge branch 'gl/feat/async-agent-executor-support' into gl/feat/async-memory-support 2025-12-02 19:04:43 -05:00
Greyson Lalonde
f2b4efe7fa Merge branch 'gl/feat/async-crew-support' into gl/feat/async-flow-kickoff 2025-12-02 18:06:07 -05:00
Greyson Lalonde
4f175fdd6f Merge branch 'gl/feat/async-task-support' into gl/feat/async-crew-support 2025-12-02 18:05:38 -05:00
Greyson LaLonde
d72b79f932 Merge branch 'main' into gl/feat/async-flow-kickoff 2025-12-02 17:53:50 -05:00
Greyson LaLonde
e8638d318d Merge branch 'main' into gl/feat/async-crew-support 2025-12-02 17:53:34 -05:00
Greyson Lalonde
d2c880c6b3 chore: dry out duplicate logic 2025-12-02 17:52:17 -05:00
Greyson Lalonde
087f6d25a9 feat: add akickoff alias to flow 2025-12-02 17:22:51 -05:00
Greyson Lalonde
c57e325482 feat: add native async crew support 2025-12-02 16:47:53 -05:00
Greyson LaLonde
fdb7047780 Merge branch 'main' into gl/feat/async-task-support 2025-12-02 16:43:13 -05:00
Greyson LaLonde
adb485f7f7 Merge branch 'main' into gl/feat/async-knowledge-support 2025-12-02 16:43:06 -05:00
Greyson LaLonde
ee64bd426e Merge branch 'main' into gl/feat/async-memory-support 2025-12-02 16:42:52 -05:00
Greyson LaLonde
37b80ee937 Merge branch 'main' into gl/feat/async-agent-executor-support 2025-12-02 16:40:14 -05:00
Greyson LaLonde
09f1ba6956 feat: native async tool support
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
- add async support for tools
- add async tool tests
- improve tool decorator typing
- fix _run backward compatibility
- update docs and improve readability of docstrings
2025-12-02 16:39:58 -05:00
Greyson Lalonde
bf9ccd418a feat: add async task support 2025-12-02 16:33:20 -05:00
Greyson Lalonde
bd95356ec5 feat: async knowledge support; add tests 2025-12-02 14:59:43 -05:00
Greyson Lalonde
441591d592 feat: add async ops to memory feat; create tests 2025-12-02 13:09:52 -05:00
Greyson Lalonde
132b6b224a feat: add aiosqlite dep; regenerate lockfile 2025-12-02 12:13:42 -05:00
Greyson Lalonde
4e2916d71a chore: add tests 2025-12-02 09:46:38 -05:00
Greyson Lalonde
0c4a0e1fda feat: add async execution support to agent executor 2025-12-02 09:30:56 -05:00
Greyson Lalonde
9c4126e0d8 chore: make docstrings a little more readable 2025-12-02 09:06:36 -05:00
Greyson Lalonde
5156fc4792 chore: update docs 2025-12-02 08:57:04 -05:00
Greyson Lalonde
c600b26ca6 fix: ensure _run backward compat 2025-12-02 08:36:03 -05:00
Greyson Lalonde
162a106002 chore: improve tool decorator typing 2025-12-02 00:32:10 -05:00
Greyson Lalonde
be33c8e3e5 feat: add async support for tools, add async tool tests 2025-12-02 00:03:28 -05:00
Greyson LaLonde
20704742e2 feat: async llm support
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
feat: introduce async contract to BaseLLM

feat: add async call support for:

Azure provider

Anthropic provider

OpenAI provider

Gemini provider

Bedrock provider

LiteLLM provider

chore: expand scrubbed header fields (conftest, anthropic, bedrock)

chore: update docs to cover async functionality

chore: update and harden tests to support acall; re-add uri for cassette compatibility

chore: generate missing cassette

fix: ensure acall is non-abstract and set supports_tools = true for supported Anthropic models

chore: improve Bedrock async docstring and general test robustness
2025-12-01 18:56:56 -05:00
Greyson LaLonde
59180e9c9f fix: ensure supports_tools is true for all supported anthropic models
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
2025-12-01 07:21:09 -05:00
Greyson LaLonde
3ce019b07b chore: pin dependencies in crewai, crewai-tools, devtools
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2025-11-30 19:51:20 -05:00
Greyson LaLonde
2355ec0733 feat: create sys event types and handler
feat: add system event types and handler

chore: add tests and improve signal-related error logging
2025-11-30 17:44:40 -05:00
Greyson LaLonde
c925d2d519 chore: restructure test env, cassettes, and conftest; fix flaky tests
Some checks failed
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Consolidates pytest config, standardizes env handling, reorganizes cassette layout, removes outdated VCR configs, improves sync with threading.Condition, updates event-waiting logic, ensures cleanup, regenerates Gemini cassettes, and reverts unintended test changes.
2025-11-29 16:55:24 -05:00
Lorenze Jay
bc4e6a3127 feat: bump versions to 1.6.1 (#3993)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
* feat: bump versions to 1.6.1

* chore: update crewAI dependency version to 1.6.1 in project templates
2025-11-28 17:57:15 -08:00
Vidit Ostwal
37526c693b Fixing ChatCompletionsClinet call (#3910)
* Fixing ChatCompletionsClinet call

* Moving from json-object -> JsonSchemaFormat

* Regex handling

* Adding additionalProperties explicitly

* fix: ensure additionalProperties is recursive

---------

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
2025-11-28 17:33:53 -08:00
Greyson LaLonde
c59173a762 fix: ensure async methods are executable for annotations 2025-11-28 19:54:40 -05:00
Lorenze Jay
4d8eec96e8 refactor: enhance model validation and provider inference in LLM class (#3976)
* refactor: enhance model validation and provider inference in LLM class

- Updated the model validation logic to support pattern matching for new models and "latest" versions, improving flexibility for various providers.
- Refactored the `_validate_model_in_constants` method to first check hardcoded constants and then fall back to pattern matching.
- Introduced `_matches_provider_pattern` to streamline provider-specific model checks.
- Enhanced the `_infer_provider_from_model` method to utilize pattern matching for better provider inference.

This refactor aims to improve the extensibility of the LLM class, allowing it to accommodate new models without requiring constant updates to the hardcoded lists.

* feat: add new Anthropic model versions to constants

- Introduced "claude-opus-4-5-20251101" and "claude-opus-4-5" to the AnthropicModels and ANTHROPIC_MODELS lists for enhanced model support.
- Added "anthropic.claude-opus-4-5-20251101-v1:0" to BedrockModels and BEDROCK_MODELS to ensure compatibility with the latest model offerings.
- Updated test cases to ensure proper environment variable handling for model validation, improving robustness in testing scenarios.

* dont infer this way - dropped
2025-11-28 13:54:40 -08:00
324 changed files with 18696 additions and 4022 deletions

161
.env.test Normal file
View File

@@ -0,0 +1,161 @@
# =============================================================================
# Test Environment Variables
# =============================================================================
# This file contains all environment variables needed to run tests locally
# in a way that mimics the GitHub Actions CI environment.
# =============================================================================
# -----------------------------------------------------------------------------
# LLM Provider API Keys
# -----------------------------------------------------------------------------
OPENAI_API_KEY=fake-api-key
ANTHROPIC_API_KEY=fake-anthropic-key
GEMINI_API_KEY=fake-gemini-key
AZURE_API_KEY=fake-azure-key
OPENROUTER_API_KEY=fake-openrouter-key
# -----------------------------------------------------------------------------
# AWS Credentials
# -----------------------------------------------------------------------------
AWS_ACCESS_KEY_ID=fake-aws-access-key
AWS_SECRET_ACCESS_KEY=fake-aws-secret-key
AWS_DEFAULT_REGION=us-east-1
AWS_REGION_NAME=us-east-1
# -----------------------------------------------------------------------------
# Azure OpenAI Configuration
# -----------------------------------------------------------------------------
AZURE_ENDPOINT=https://fake-azure-endpoint.openai.azure.com
AZURE_OPENAI_ENDPOINT=https://fake-azure-endpoint.openai.azure.com
AZURE_OPENAI_API_KEY=fake-azure-openai-key
AZURE_API_VERSION=2024-02-15-preview
OPENAI_API_VERSION=2024-02-15-preview
# -----------------------------------------------------------------------------
# Google Cloud Configuration
# -----------------------------------------------------------------------------
#GOOGLE_CLOUD_PROJECT=fake-gcp-project
#GOOGLE_CLOUD_LOCATION=us-central1
# -----------------------------------------------------------------------------
# OpenAI Configuration
# -----------------------------------------------------------------------------
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_API_BASE=https://api.openai.com/v1
# -----------------------------------------------------------------------------
# Search & Scraping Tool API Keys
# -----------------------------------------------------------------------------
SERPER_API_KEY=fake-serper-key
EXA_API_KEY=fake-exa-key
BRAVE_API_KEY=fake-brave-key
FIRECRAWL_API_KEY=fake-firecrawl-key
TAVILY_API_KEY=fake-tavily-key
SERPAPI_API_KEY=fake-serpapi-key
SERPLY_API_KEY=fake-serply-key
LINKUP_API_KEY=fake-linkup-key
PARALLEL_API_KEY=fake-parallel-key
# -----------------------------------------------------------------------------
# Exa Configuration
# -----------------------------------------------------------------------------
EXA_BASE_URL=https://api.exa.ai
# -----------------------------------------------------------------------------
# Web Scraping & Automation
# -----------------------------------------------------------------------------
BRIGHT_DATA_API_KEY=fake-brightdata-key
BRIGHT_DATA_ZONE=fake-zone
BRIGHTDATA_API_URL=https://api.brightdata.com
BRIGHTDATA_DEFAULT_TIMEOUT=600
BRIGHTDATA_DEFAULT_POLLING_INTERVAL=1
OXYLABS_USERNAME=fake-oxylabs-user
OXYLABS_PASSWORD=fake-oxylabs-pass
SCRAPFLY_API_KEY=fake-scrapfly-key
SCRAPEGRAPH_API_KEY=fake-scrapegraph-key
BROWSERBASE_API_KEY=fake-browserbase-key
BROWSERBASE_PROJECT_ID=fake-browserbase-project
HYPERBROWSER_API_KEY=fake-hyperbrowser-key
MULTION_API_KEY=fake-multion-key
APIFY_API_TOKEN=fake-apify-token
# -----------------------------------------------------------------------------
# Database & Vector Store Credentials
# -----------------------------------------------------------------------------
SINGLESTOREDB_URL=mysql://fake:fake@localhost:3306/fake
SINGLESTOREDB_HOST=localhost
SINGLESTOREDB_PORT=3306
SINGLESTOREDB_USER=fake-user
SINGLESTOREDB_PASSWORD=fake-password
SINGLESTOREDB_DATABASE=fake-database
SINGLESTOREDB_CONNECT_TIMEOUT=30
SNOWFLAKE_USER=fake-snowflake-user
SNOWFLAKE_PASSWORD=fake-snowflake-password
SNOWFLAKE_ACCOUNT=fake-snowflake-account
SNOWFLAKE_WAREHOUSE=fake-snowflake-warehouse
SNOWFLAKE_DATABASE=fake-snowflake-database
SNOWFLAKE_SCHEMA=fake-snowflake-schema
WEAVIATE_URL=http://localhost:8080
WEAVIATE_API_KEY=fake-weaviate-key
EMBEDCHAIN_DB_URI=sqlite:///test.db
# Databricks Credentials
DATABRICKS_HOST=https://fake-databricks.cloud.databricks.com
DATABRICKS_TOKEN=fake-databricks-token
DATABRICKS_CONFIG_PROFILE=fake-profile
# MongoDB Credentials
MONGODB_URI=mongodb://fake:fake@localhost:27017/fake
# -----------------------------------------------------------------------------
# CrewAI Platform & Enterprise
# -----------------------------------------------------------------------------
# setting CREWAI_PLATFORM_INTEGRATION_TOKEN causes these test to fail:
#=========================== short test summary info ============================
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_platform_context_manager_basic_usage - AssertionError: assert 'fake-platform-token' is None
# + where 'fake-platform-token' = get_platform_integration_token()
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_context_var_isolation_between_tests - AssertionError: assert 'fake-platform-token' is None
# + where 'fake-platform-token' = get_platform_integration_token()
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_multiple_sequential_context_managers - AssertionError: assert 'fake-platform-token' is None
# + where 'fake-platform-token' = get_platform_integration_token()
#CREWAI_PLATFORM_INTEGRATION_TOKEN=fake-platform-token
CREWAI_PERSONAL_ACCESS_TOKEN=fake-personal-token
CREWAI_PLUS_URL=https://fake.crewai.com
# -----------------------------------------------------------------------------
# Other Service API Keys
# -----------------------------------------------------------------------------
ZAPIER_API_KEY=fake-zapier-key
PATRONUS_API_KEY=fake-patronus-key
MINDS_API_KEY=fake-minds-key
HF_TOKEN=fake-hf-token
# -----------------------------------------------------------------------------
# Feature Flags/Testing Modes
# -----------------------------------------------------------------------------
CREWAI_DISABLE_TELEMETRY=true
OTEL_SDK_DISABLED=true
CREWAI_TESTING=true
CREWAI_TRACING_ENABLED=false
# -----------------------------------------------------------------------------
# Testing/CI Configuration
# -----------------------------------------------------------------------------
# VCR recording mode: "none" (default), "new_episodes", "all", "once"
PYTEST_VCR_RECORD_MODE=none
# Set to "true" by GitHub when running in GitHub Actions
# GITHUB_ACTIONS=false
# -----------------------------------------------------------------------------
# Python Configuration
# -----------------------------------------------------------------------------
PYTHONUNBUFFERED=1

View File

@@ -5,18 +5,6 @@ on: [pull_request]
permissions:
contents: read
env:
OPENAI_API_KEY: fake-api-key
PYTHONUNBUFFERED: 1
BRAVE_API_KEY: fake-brave-key
SNOWFLAKE_USER: fake-snowflake-user
SNOWFLAKE_PASSWORD: fake-snowflake-password
SNOWFLAKE_ACCOUNT: fake-snowflake-account
SNOWFLAKE_WAREHOUSE: fake-snowflake-warehouse
SNOWFLAKE_DATABASE: fake-snowflake-database
SNOWFLAKE_SCHEMA: fake-snowflake-schema
EMBEDCHAIN_DB_URI: sqlite:///test.db
jobs:
tests:
name: tests (${{ matrix.python-version }})
@@ -84,26 +72,20 @@ jobs:
# fi
cd lib/crewai && uv run pytest \
--block-network \
--timeout=30 \
-vv \
--splits 8 \
--group ${{ matrix.group }} \
$DURATIONS_ARG \
--durations=10 \
-n auto \
--maxfail=3
- name: Run tool tests (group ${{ matrix.group }} of 8)
run: |
cd lib/crewai-tools && uv run pytest \
--block-network \
--timeout=30 \
-vv \
--splits 8 \
--group ${{ matrix.group }} \
--durations=10 \
-n auto \
--maxfail=3

193
conftest.py Normal file
View File

@@ -0,0 +1,193 @@
"""Pytest configuration for crewAI workspace."""
from collections.abc import Generator
import os
from pathlib import Path
import tempfile
from typing import Any
from dotenv import load_dotenv
import pytest
from vcr.request import Request # type: ignore[import-untyped]
env_test_path = Path(__file__).parent / ".env.test"
load_dotenv(env_test_path, override=True)
load_dotenv(override=True)
@pytest.fixture(autouse=True, scope="function")
def cleanup_event_handlers() -> Generator[None, Any, None]:
"""Clean up event bus handlers after each test to prevent test pollution."""
yield
try:
from crewai.events.event_bus import crewai_event_bus
with crewai_event_bus._rwlock.w_locked():
crewai_event_bus._sync_handlers.clear()
crewai_event_bus._async_handlers.clear()
except Exception: # noqa: S110
pass
@pytest.fixture(autouse=True, scope="function")
def setup_test_environment() -> Generator[None, Any, None]:
"""Setup test environment for crewAI workspace."""
with tempfile.TemporaryDirectory() as temp_dir:
storage_dir = Path(temp_dir) / "crewai_test_storage"
storage_dir.mkdir(parents=True, exist_ok=True)
if not storage_dir.exists() or not storage_dir.is_dir():
raise RuntimeError(
f"Failed to create test storage directory: {storage_dir}"
)
try:
test_file = storage_dir / ".permissions_test"
test_file.touch()
test_file.unlink()
except (OSError, IOError) as e:
raise RuntimeError(
f"Test storage directory {storage_dir} is not writable: {e}"
) from e
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
os.environ["CREWAI_TESTING"] = "true"
try:
yield
finally:
os.environ.pop("CREWAI_TESTING", "true")
os.environ.pop("CREWAI_STORAGE_DIR", None)
os.environ.pop("CREWAI_DISABLE_TELEMETRY", "true")
os.environ.pop("OTEL_SDK_DISABLED", "true")
os.environ.pop("OPENAI_BASE_URL", "https://api.openai.com/v1")
os.environ.pop("OPENAI_API_BASE", "https://api.openai.com/v1")
HEADERS_TO_FILTER = {
"authorization": "AUTHORIZATION-XXX",
"content-security-policy": "CSP-FILTERED",
"cookie": "COOKIE-XXX",
"set-cookie": "SET-COOKIE-XXX",
"permissions-policy": "PERMISSIONS-POLICY-XXX",
"referrer-policy": "REFERRER-POLICY-XXX",
"strict-transport-security": "STS-XXX",
"x-content-type-options": "X-CONTENT-TYPE-XXX",
"x-frame-options": "X-FRAME-OPTIONS-XXX",
"x-permitted-cross-domain-policies": "X-PERMITTED-XXX",
"x-request-id": "X-REQUEST-ID-XXX",
"x-runtime": "X-RUNTIME-XXX",
"x-xss-protection": "X-XSS-PROTECTION-XXX",
"x-stainless-arch": "X-STAINLESS-ARCH-XXX",
"x-stainless-os": "X-STAINLESS-OS-XXX",
"x-stainless-read-timeout": "X-STAINLESS-READ-TIMEOUT-XXX",
"cf-ray": "CF-RAY-XXX",
"etag": "ETAG-XXX",
"Strict-Transport-Security": "STS-XXX",
"access-control-expose-headers": "ACCESS-CONTROL-XXX",
"openai-organization": "OPENAI-ORG-XXX",
"openai-project": "OPENAI-PROJECT-XXX",
"x-ratelimit-limit-requests": "X-RATELIMIT-LIMIT-REQUESTS-XXX",
"x-ratelimit-limit-tokens": "X-RATELIMIT-LIMIT-TOKENS-XXX",
"x-ratelimit-remaining-requests": "X-RATELIMIT-REMAINING-REQUESTS-XXX",
"x-ratelimit-remaining-tokens": "X-RATELIMIT-REMAINING-TOKENS-XXX",
"x-ratelimit-reset-requests": "X-RATELIMIT-RESET-REQUESTS-XXX",
"x-ratelimit-reset-tokens": "X-RATELIMIT-RESET-TOKENS-XXX",
"x-goog-api-key": "X-GOOG-API-KEY-XXX",
"api-key": "X-API-KEY-XXX",
"User-Agent": "X-USER-AGENT-XXX",
"apim-request-id:": "X-API-CLIENT-REQUEST-ID-XXX",
"azureml-model-session": "AZUREML-MODEL-SESSION-XXX",
"x-ms-client-request-id": "X-MS-CLIENT-REQUEST-ID-XXX",
"x-ms-region": "X-MS-REGION-XXX",
"apim-request-id": "APIM-REQUEST-ID-XXX",
"x-api-key": "X-API-KEY-XXX",
"anthropic-organization-id": "ANTHROPIC-ORGANIZATION-ID-XXX",
"request-id": "REQUEST-ID-XXX",
"anthropic-ratelimit-input-tokens-limit": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-LIMIT-XXX",
"anthropic-ratelimit-input-tokens-remaining": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-REMAINING-XXX",
"anthropic-ratelimit-input-tokens-reset": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-RESET-XXX",
"anthropic-ratelimit-output-tokens-limit": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-LIMIT-XXX",
"anthropic-ratelimit-output-tokens-remaining": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-REMAINING-XXX",
"anthropic-ratelimit-output-tokens-reset": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-RESET-XXX",
"anthropic-ratelimit-tokens-limit": "ANTHROPIC-RATELIMIT-TOKENS-LIMIT-XXX",
"anthropic-ratelimit-tokens-remaining": "ANTHROPIC-RATELIMIT-TOKENS-REMAINING-XXX",
"anthropic-ratelimit-tokens-reset": "ANTHROPIC-RATELIMIT-TOKENS-RESET-XXX",
"x-amz-date": "X-AMZ-DATE-XXX",
"amz-sdk-invocation-id": "AMZ-SDK-INVOCATION-ID-XXX",
"accept-encoding": "ACCEPT-ENCODING-XXX",
"x-amzn-requestid": "X-AMZN-REQUESTID-XXX",
"x-amzn-RequestId": "X-AMZN-REQUESTID-XXX",
}
def _filter_request_headers(request: Request) -> Request: # type: ignore[no-any-unimported]
"""Filter sensitive headers from request before recording."""
for header_name, replacement in HEADERS_TO_FILTER.items():
for variant in [header_name, header_name.upper(), header_name.title()]:
if variant in request.headers:
request.headers[variant] = [replacement]
request.method = request.method.upper()
return request
def _filter_response_headers(response: dict[str, Any]) -> dict[str, Any]:
"""Filter sensitive headers from response before recording."""
for header_name, replacement in HEADERS_TO_FILTER.items():
for variant in [header_name, header_name.upper(), header_name.title()]:
if variant in response["headers"]:
response["headers"][variant] = [replacement]
return response
@pytest.fixture(scope="module")
def vcr_cassette_dir(request: Any) -> str:
"""Generate cassette directory path based on test module location.
Organizes cassettes to mirror test directory structure within each package:
lib/crewai/tests/llms/google/test_google.py -> lib/crewai/tests/cassettes/llms/google/
lib/crewai-tools/tests/tools/test_search.py -> lib/crewai-tools/tests/cassettes/tools/
"""
test_file = Path(request.fspath)
for parent in test_file.parents:
if parent.name in ("crewai", "crewai-tools") and parent.parent.name == "lib":
package_root = parent
break
else:
package_root = test_file.parent
tests_root = package_root / "tests"
test_dir = test_file.parent
if test_dir != tests_root:
relative_path = test_dir.relative_to(tests_root)
cassette_dir = tests_root / "cassettes" / relative_path
else:
cassette_dir = tests_root / "cassettes"
cassette_dir.mkdir(parents=True, exist_ok=True)
return str(cassette_dir)
@pytest.fixture(scope="module")
def vcr_config(vcr_cassette_dir: str) -> dict[str, Any]:
"""Configure VCR with organized cassette storage."""
config = {
"cassette_library_dir": vcr_cassette_dir,
"record_mode": os.getenv("PYTEST_VCR_RECORD_MODE", "once"),
"filter_headers": [(k, v) for k, v in HEADERS_TO_FILTER.items()],
"before_record_request": _filter_request_headers,
"before_record_response": _filter_response_headers,
"filter_query_parameters": ["key"],
"match_on": ["method", "scheme", "host", "port", "path"],
}
if os.getenv("GITHUB_ACTIONS") == "true":
config["record_mode"] = "none"
return config

View File

@@ -1089,6 +1089,50 @@ CrewAI supports streaming responses from LLMs, allowing your application to rece
</Tab>
</Tabs>
## Async LLM Calls
CrewAI supports asynchronous LLM calls for improved performance and concurrency in your AI workflows. Async calls allow you to run multiple LLM requests concurrently without blocking, making them ideal for high-throughput applications and parallel agent operations.
<Tabs>
<Tab title="Basic Usage">
Use the `acall` method for asynchronous LLM requests:
```python
import asyncio
from crewai import LLM
async def main():
llm = LLM(model="openai/gpt-4o")
# Single async call
response = await llm.acall("What is the capital of France?")
print(response)
asyncio.run(main())
```
The `acall` method supports all the same parameters as the synchronous `call` method, including messages, tools, and callbacks.
</Tab>
<Tab title="With Streaming">
Combine async calls with streaming for real-time concurrent responses:
```python
import asyncio
from crewai import LLM
async def stream_async():
llm = LLM(model="openai/gpt-4o", stream=True)
response = await llm.acall("Write a short story about AI")
print(response)
asyncio.run(stream_async())
```
</Tab>
</Tabs>
## Structured LLM Calls
CrewAI supports structured responses from LLM calls by allowing you to define a `response_format` using a Pydantic model. This enables the framework to automatically parse and validate the output, making it easier to integrate the response into your application without manual post-processing.

View File

@@ -66,5 +66,55 @@ def my_cache_strategy(arguments: dict, result: str) -> bool:
cached_tool.cache_function = my_cache_strategy
```
### Creating Async Tools
CrewAI supports async tools for non-blocking I/O operations. This is useful when your tool needs to make HTTP requests, database queries, or other I/O-bound operations.
#### Using the `@tool` Decorator with Async Functions
The simplest way to create an async tool is using the `@tool` decorator with an async function:
```python Code
import aiohttp
from crewai.tools import tool
@tool("Async Web Fetcher")
async def fetch_webpage(url: str) -> str:
"""Fetch content from a webpage asynchronously."""
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
return await response.text()
```
#### Subclassing `BaseTool` with Async Support
For more control, subclass `BaseTool` and implement both `_run` (sync) and `_arun` (async) methods:
```python Code
import requests
import aiohttp
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
class WebFetcherInput(BaseModel):
"""Input schema for WebFetcher."""
url: str = Field(..., description="The URL to fetch")
class WebFetcherTool(BaseTool):
name: str = "Web Fetcher"
description: str = "Fetches content from a URL"
args_schema: type[BaseModel] = WebFetcherInput
def _run(self, url: str) -> str:
"""Synchronous implementation."""
return requests.get(url).text
async def _arun(self, url: str) -> str:
"""Asynchronous implementation for non-blocking I/O."""
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
return await response.text()
```
By adhering to these guidelines and incorporating new functionalities and collaboration tools into your tool creation and management processes,
you can leverage the full capabilities of the CrewAI framework, enhancing both the development experience and the efficiency of your AI agents.

View File

@@ -63,5 +63,55 @@ def my_cache_strategy(arguments: dict, result: str) -> bool:
cached_tool.cache_function = my_cache_strategy
```
### 비동기 도구 생성하기
CrewAI는 논블로킹 I/O 작업을 위한 비동기 도구를 지원합니다. 이는 HTTP 요청, 데이터베이스 쿼리 또는 기타 I/O 바운드 작업이 필요한 경우에 유용합니다.
#### `@tool` 데코레이터와 비동기 함수 사용하기
비동기 도구를 만드는 가장 간단한 방법은 `@tool` 데코레이터와 async 함수를 사용하는 것입니다:
```python Code
import aiohttp
from crewai.tools import tool
@tool("Async Web Fetcher")
async def fetch_webpage(url: str) -> str:
"""Fetch content from a webpage asynchronously."""
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
return await response.text()
```
#### 비동기 지원으로 `BaseTool` 서브클래싱하기
더 많은 제어를 위해 `BaseTool`을 상속하고 `_run`(동기) 및 `_arun`(비동기) 메서드를 모두 구현할 수 있습니다:
```python Code
import requests
import aiohttp
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
class WebFetcherInput(BaseModel):
"""Input schema for WebFetcher."""
url: str = Field(..., description="The URL to fetch")
class WebFetcherTool(BaseTool):
name: str = "Web Fetcher"
description: str = "Fetches content from a URL"
args_schema: type[BaseModel] = WebFetcherInput
def _run(self, url: str) -> str:
"""Synchronous implementation."""
return requests.get(url).text
async def _arun(self, url: str) -> str:
"""Asynchronous implementation for non-blocking I/O."""
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
return await response.text()
```
이 가이드라인을 준수하고 새로운 기능과 협업 도구를 도구 생성 및 관리 프로세스에 통합함으로써,
CrewAI 프레임워크의 모든 기능을 활용할 수 있으며, AI agent의 개발 경험과 효율성을 모두 높일 수 있습니다.
CrewAI 프레임워크의 모든 기능을 활용할 수 있으며, AI agent의 개발 경험과 효율성을 모두 높일 수 있습니다.

View File

@@ -66,5 +66,55 @@ def my_cache_strategy(arguments: dict, result: str) -> bool:
cached_tool.cache_function = my_cache_strategy
```
### Criando Ferramentas Assíncronas
O CrewAI suporta ferramentas assíncronas para operações de I/O não bloqueantes. Isso é útil quando sua ferramenta precisa fazer requisições HTTP, consultas a banco de dados ou outras operações de I/O.
#### Usando o Decorador `@tool` com Funções Assíncronas
A maneira mais simples de criar uma ferramenta assíncrona é usando o decorador `@tool` com uma função async:
```python Code
import aiohttp
from crewai.tools import tool
@tool("Async Web Fetcher")
async def fetch_webpage(url: str) -> str:
"""Fetch content from a webpage asynchronously."""
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
return await response.text()
```
#### Subclassificando `BaseTool` com Suporte Assíncrono
Para maior controle, herde de `BaseTool` e implemente os métodos `_run` (síncrono) e `_arun` (assíncrono):
```python Code
import requests
import aiohttp
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
class WebFetcherInput(BaseModel):
"""Input schema for WebFetcher."""
url: str = Field(..., description="The URL to fetch")
class WebFetcherTool(BaseTool):
name: str = "Web Fetcher"
description: str = "Fetches content from a URL"
args_schema: type[BaseModel] = WebFetcherInput
def _run(self, url: str) -> str:
"""Synchronous implementation."""
return requests.get(url).text
async def _arun(self, url: str) -> str:
"""Asynchronous implementation for non-blocking I/O."""
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
return await response.text()
```
Seguindo essas orientações e incorporando novas funcionalidades e ferramentas de colaboração nos seus processos de criação e gerenciamento de ferramentas,
você pode aproveitar ao máximo as capacidades do framework CrewAI, aprimorando tanto a experiência de desenvolvimento quanto a eficiência dos seus agentes de IA.
você pode aproveitar ao máximo as capacidades do framework CrewAI, aprimorando tanto a experiência de desenvolvimento quanto a eficiência dos seus agentes de IA.

View File

@@ -218,7 +218,7 @@ Update the root `README.md` only if the tool introduces a new category or notabl
## Discovery and specs
Our internal tooling discovers classes whose names end with `Tool`. Keep your class exported from the module path under `crewai_tools/tools/...` to be picked up by scripts like `generate_tool_specs.py`.
Our internal tooling discovers classes whose names end with `Tool`. Keep your class exported from the module path under `crewai_tools/tools/...` to be picked up by scripts like `crewai_tools.generate_tool_specs.py`.
---

View File

@@ -8,17 +8,17 @@ authors = [
]
requires-python = ">=3.10, <3.14"
dependencies = [
"lancedb>=0.5.4",
"pytube>=15.0.0",
"requests>=2.32.5",
"docker>=7.1.0",
"crewai==1.6.0",
"lancedb>=0.5.4",
"tiktoken>=0.8.0",
"beautifulsoup4>=4.13.4",
"python-docx>=1.2.0",
"youtube-transcript-api>=1.2.2",
"pymupdf>=1.26.6",
"lancedb~=0.5.4",
"pytube~=15.0.0",
"requests~=2.32.5",
"docker~=7.1.0",
"crewai==1.6.1",
"lancedb~=0.5.4",
"tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4",
"python-docx~=1.2.0",
"youtube-transcript-api~=1.2.2",
"pymupdf~=1.26.6",
]

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.6.0"
__version__ = "1.6.1"

View File

@@ -4,17 +4,20 @@ from collections.abc import Mapping
import inspect
import json
from pathlib import Path
from typing import Any, cast
from typing import Any
from crewai.tools.base_tool import BaseTool, EnvVar
from crewai_tools import tools
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema
from pydantic_core import PydanticOmit
from crewai_tools import tools
class SchemaGenerator(GenerateJsonSchema):
def handle_invalid_for_json_schema(self, schema, error_info):
def handle_invalid_for_json_schema(
self, schema: Any, error_info: Any
) -> dict[str, Any]:
raise PydanticOmit
@@ -73,7 +76,7 @@ class ToolSpecExtractor:
@staticmethod
def _extract_field_default(
field: dict | None, fallback: str | list[Any] = ""
field: dict[str, Any] | None, fallback: str | list[Any] = ""
) -> str | list[Any] | int:
if not field:
return fallback
@@ -83,7 +86,7 @@ class ToolSpecExtractor:
return default if isinstance(default, (list, str, int)) else fallback
@staticmethod
def _extract_params(args_schema_field: dict | None) -> dict[str, Any]:
def _extract_params(args_schema_field: dict[str, Any] | None) -> dict[str, Any]:
if not args_schema_field:
return {}
@@ -94,15 +97,15 @@ class ToolSpecExtractor:
):
return {}
# Cast to type[BaseModel] after runtime check
schema_class = cast(type[BaseModel], args_schema_class)
try:
return schema_class.model_json_schema(schema_generator=SchemaGenerator)
return args_schema_class.model_json_schema(schema_generator=SchemaGenerator)
except Exception:
return {}
@staticmethod
def _extract_env_vars(env_vars_field: dict | None) -> list[dict[str, Any]]:
def _extract_env_vars(
env_vars_field: dict[str, Any] | None,
) -> list[dict[str, Any]]:
if not env_vars_field:
return []

View File

@@ -1,21 +0,0 @@
import pytest
def pytest_configure(config):
"""Register custom markers."""
config.addinivalue_line("markers", "integration: mark test as an integration test")
config.addinivalue_line("markers", "asyncio: mark test as an async test")
# Set the asyncio loop scope through ini configuration
config.inicfg["asyncio_mode"] = "auto"
@pytest.fixture(scope="function")
def event_loop():
"""Create an instance of the default event loop for each test case."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()

View File

@@ -2,7 +2,7 @@ import json
from unittest import mock
from crewai.tools.base_tool import BaseTool, EnvVar
from generate_tool_specs import ToolSpecExtractor
from crewai_tools.generate_tool_specs import ToolSpecExtractor
from pydantic import BaseModel, Field
import pytest
@@ -61,8 +61,8 @@ def test_unwrap_schema(extractor):
@pytest.fixture
def mock_tool_extractor(extractor):
with (
mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
mock.patch("generate_tool_specs.getattr", return_value=MockTool),
mock.patch("crewai_tools.generate_tool_specs.dir", return_value=["MockTool"]),
mock.patch("crewai_tools.generate_tool_specs.getattr", return_value=MockTool),
):
extractor.extract_all_tools()
assert len(extractor.tools_spec) == 1

View File

@@ -4,7 +4,7 @@ from crewai_tools.tools.firecrawl_crawl_website_tool.firecrawl_crawl_website_too
FirecrawlCrawlWebsiteTool,
)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_firecrawl_crawl_tool_integration():
tool = FirecrawlCrawlWebsiteTool(config={
"limit": 2,

View File

@@ -4,7 +4,7 @@ from crewai_tools.tools.firecrawl_scrape_website_tool.firecrawl_scrape_website_t
FirecrawlScrapeWebsiteTool,
)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_firecrawl_scrape_tool_integration():
tool = FirecrawlScrapeWebsiteTool()
result = tool.run(url="https://firecrawl.dev")

View File

@@ -3,7 +3,7 @@ import pytest
from crewai_tools.tools.firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_firecrawl_search_tool_integration():
tool = FirecrawlSearchTool()
result = tool.run(query="firecrawl")

View File

@@ -23,15 +23,13 @@ from crewai_tools.tools.rag.rag_tool import Adapter
import pytest
pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]
@pytest.fixture
def mock_adapter():
mock_adapter = MagicMock(spec=Adapter)
return mock_adapter
@pytest.mark.vcr()
def test_directory_search_tool():
with tempfile.TemporaryDirectory() as temp_dir:
test_file = Path(temp_dir) / "test.txt"
@@ -65,6 +63,7 @@ def test_pdf_search_tool(mock_adapter):
)
@pytest.mark.vcr()
def test_txt_search_tool():
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
temp_file.write(b"This is a test file for txt search")
@@ -102,6 +101,7 @@ def test_docx_search_tool(mock_adapter):
)
@pytest.mark.vcr()
def test_json_search_tool():
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
temp_file.write(b'{"test": "This is a test JSON file"}')
@@ -127,6 +127,7 @@ def test_xml_search_tool(mock_adapter):
)
@pytest.mark.vcr()
def test_csv_search_tool():
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
temp_file.write(b"name,description\ntest,This is a test CSV file")
@@ -141,6 +142,7 @@ def test_csv_search_tool():
os.unlink(temp_file_path)
@pytest.mark.vcr()
def test_mdx_search_tool():
with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as temp_file:
temp_file.write(b"# Test MDX\nThis is a test MDX file")

View File

@@ -9,35 +9,36 @@ authors = [
requires-python = ">=3.10, <3.14"
dependencies = [
# Core Dependencies
"pydantic>=2.11.9",
"openai>=1.13.3",
"pydantic~=2.11.9",
"openai~=1.83.0",
"instructor>=1.3.3",
# Text Processing
"pdfplumber>=0.11.4",
"regex>=2024.9.11",
"pdfplumber~=0.11.4",
"regex~=2024.9.11",
# Telemetry and Monitoring
"opentelemetry-api>=1.30.0",
"opentelemetry-sdk>=1.30.0",
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
"opentelemetry-api~=1.34.0",
"opentelemetry-sdk~=1.34.0",
"opentelemetry-exporter-otlp-proto-http~=1.34.0",
# Data Handling
"chromadb~=1.1.0",
"tokenizers>=0.20.3",
"openpyxl>=3.1.5",
"tokenizers~=0.20.3",
"openpyxl~=3.1.5",
# Authentication and Security
"python-dotenv>=1.1.1",
"pyjwt>=2.9.0",
"python-dotenv~=1.1.1",
"pyjwt~=2.9.0",
# Configuration and Utils
"click>=8.1.7",
"appdirs>=1.4.4",
"jsonref>=1.1.0",
"json-repair==0.25.2",
"uv>=0.4.25",
"tomli-w>=1.1.0",
"tomli>=2.0.2",
"json5>=0.10.0",
"portalocker==2.7.0",
"pydantic-settings>=2.10.1",
"mcp>=1.16.0",
"click~=8.1.7",
"appdirs~=1.4.4",
"jsonref~=1.1.0",
"json-repair~=0.25.2",
"tomli-w~=1.1.0",
"tomli~=2.0.2",
"json5~=0.10.0",
"portalocker~=2.7.0",
"pydantic-settings~=2.10.1",
"mcp~=1.16.0",
"uv~=0.9.13",
"aiosqlite~=0.21.0",
]
[project.urls]
@@ -48,55 +49,53 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.6.0",
"crewai-tools==1.6.1",
]
embeddings = [
"tiktoken~=0.8.0"
]
pdfplumber = [
"pdfplumber>=0.11.4",
]
pandas = [
"pandas>=2.2.3",
"pandas~=2.2.3",
]
openpyxl = [
"openpyxl>=3.1.5",
"openpyxl~=3.1.5",
]
mem0 = ["mem0ai>=0.1.94"]
mem0 = ["mem0ai~=0.1.94"]
docling = [
"docling>=2.12.0",
"docling~=2.63.0",
]
qdrant = [
"qdrant-client[fastembed]>=1.14.3",
"qdrant-client[fastembed]~=1.14.3",
]
aws = [
"boto3>=1.40.38",
"boto3~=1.40.38",
"aiobotocore~=2.25.2",
]
watson = [
"ibm-watsonx-ai>=1.3.39",
"ibm-watsonx-ai~=1.3.39",
]
voyageai = [
"voyageai>=0.3.5",
"voyageai~=0.3.5",
]
litellm = [
"litellm>=1.74.9",
"litellm~=1.74.9",
]
bedrock = [
"boto3>=1.40.45",
"boto3~=1.40.45",
]
google-genai = [
"google-genai>=1.2.0",
"google-genai~=1.2.0",
]
azure-ai-inference = [
"azure-ai-inference>=1.0.0b9",
"azure-ai-inference~=1.0.0b9",
]
anthropic = [
"anthropic>=0.69.0",
"anthropic~=0.71.0",
]
a2a = [
a2a = [
"a2a-sdk~=0.3.10",
"httpx-auth>=0.23.1",
"httpx-sse>=0.4.0",
"httpx-auth~=0.23.1",
"httpx-sse~=0.4.0",
]

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.6.0"
__version__ = "1.6.1"
_telemetry_submitted = False

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Sequence
import json
import shutil
import subprocess
import time
@@ -19,6 +18,19 @@ from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.a2a.config import A2AConfig
from crewai.agent.utils import (
ahandle_knowledge_retrieval,
apply_training_data,
build_task_prompt_with_schema,
format_task_with_context,
get_knowledge_config,
handle_knowledge_retrieval,
handle_reasoning,
prepare_tools,
process_tool_results,
save_last_messages,
validate_max_execution_time,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
@@ -27,9 +39,6 @@ from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
@@ -37,7 +46,6 @@ from crewai.events.types.memory_events import (
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent
from crewai.llms.base_llm import BaseLLM
from crewai.mcp import (
@@ -61,7 +69,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import Converter, generate_model_description
from crewai.utilities.converter import Converter
from crewai.utilities.guardrail_types import GuardrailType
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.prompts import Prompts
@@ -295,53 +303,15 @@ class Agent(BaseAgent):
ValueError: If the max execution time is not a positive integer.
RuntimeError: If the agent execution fails for other reasons.
"""
if self.reasoning:
try:
from crewai.utilities.reasoning_handler import (
AgentReasoning,
AgentReasoningOutput,
)
reasoning_handler = AgentReasoning(task=task, agent=self)
reasoning_output: AgentReasoningOutput = (
reasoning_handler.handle_agent_reasoning()
)
# Add the reasoning plan to the task description
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
self._logger.log("error", f"Error during reasoning process: {e!s}")
handle_reasoning(self, task)
self._inject_date_to_task(task)
if self.tools_handler:
self.tools_handler.last_used_tool = None
task_prompt = task.prompt()
# If the task requires output in JSON or Pydantic format,
# append specific instructions to the task prompt to ensure
# that the final answer does not include any code block markers
# Skip this if task.response_model is set, as native structured outputs handle schema automatically
if (task.output_json or task.output_pydantic) and not task.response_model:
# Generate the schema based on the output format
if task.output_json:
schema_dict = generate_model_description(task.output_json)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
elif task.output_pydantic:
schema_dict = generate_model_description(task.output_pydantic)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context
)
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
if self._is_any_available_memory():
crewai_event_bus.emit(
@@ -379,84 +349,20 @@ class Agent(BaseAgent):
from_task=task,
),
)
knowledge_config = (
self.knowledge_config.model_dump() if self.knowledge_config else {}
knowledge_config = get_knowledge_config(self)
task_prompt = handle_knowledge_retrieval(
self,
task,
task_prompt,
knowledge_config,
self.knowledge.query if self.knowledge else lambda *a, **k: None,
self.crew.query_knowledge if self.crew else lambda *a, **k: None,
)
if self.knowledge or (self.crew and self.crew.knowledge):
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=self,
),
)
try:
self.knowledge_search_query = self._get_knowledge_search_query(
task_prompt, task
)
if self.knowledge_search_query:
# Quering agent specific knowledge
if self.knowledge:
agent_knowledge_snippets = self.knowledge.query(
[self.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
self.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if self.agent_knowledge_context:
task_prompt += self.agent_knowledge_context
prepare_tools(self, tools, task)
task_prompt = apply_training_data(self, task_prompt)
# Quering crew specific knowledge
knowledge_snippets = self.crew.query_knowledge(
[self.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
self.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if self.crew_knowledge_context:
task_prompt += self.crew_knowledge_context
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalCompletedEvent(
query=self.knowledge_search_query,
from_task=task,
from_agent=self,
retrieved_knowledge=(
(self.agent_knowledge_context or "")
+ (
"\n"
if self.agent_knowledge_context
and self.crew_knowledge_context
else ""
)
+ (self.crew_knowledge_context or "")
),
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=KnowledgeSearchQueryFailedEvent(
query=self.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=self,
),
)
tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task)
if self.crew and self.crew._train:
task_prompt = self._training_handler(task_prompt=task_prompt)
else:
task_prompt = self._use_trained_data(task_prompt=task_prompt)
# Import agent events locally to avoid circular imports
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
@@ -474,15 +380,8 @@ class Agent(BaseAgent):
),
)
# Determine execution method based on timeout setting
validate_max_execution_time(self.max_execution_time)
if self.max_execution_time is not None:
if (
not isinstance(self.max_execution_time, int)
or self.max_execution_time <= 0
):
raise ValueError(
"Max Execution time must be a positive integer greater than zero"
)
result = self._execute_with_timeout(
task_prompt, task, self.max_execution_time
)
@@ -490,7 +389,6 @@ class Agent(BaseAgent):
result = self._execute_without_timeout(task_prompt, task)
except TimeoutError as e:
# Propagate TimeoutError without retry
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -502,7 +400,6 @@ class Agent(BaseAgent):
raise e
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -528,23 +425,13 @@ class Agent(BaseAgent):
if self.max_rpm and self._rpm_controller:
self._rpm_controller.stop_rpm_counter()
# If there was any tool in self.tools_results that had result_as_answer
# set to True, return the results of the last tool that had
# result_as_answer set to True
for tool_result in self.tools_results:
if tool_result.get("result_as_answer", False):
result = tool_result["result"]
result = process_tool_results(self, result)
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
)
self._last_messages = (
self.agent_executor.messages.copy()
if self.agent_executor and hasattr(self.agent_executor, "messages")
else []
)
save_last_messages(self)
self._cleanup_mcp_clients()
return result
@@ -604,6 +491,208 @@ class Agent(BaseAgent):
}
)["output"]
async def aexecute_task(
self,
task: Task,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> Any:
"""Execute a task with the agent asynchronously.
Args:
task: Task to execute.
context: Context to execute the task in.
tools: Tools to use for the task.
Returns:
Output of the agent.
Raises:
TimeoutError: If execution exceeds the maximum execution time.
ValueError: If the max execution time is not a positive integer.
RuntimeError: If the agent execution fails for other reasons.
"""
handle_reasoning(self, task)
self._inject_date_to_task(task)
if self.tools_handler:
self.tools_handler.last_used_tool = None
task_prompt = task.prompt()
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
if self._is_any_available_memory():
crewai_event_bus.emit(
self,
event=MemoryRetrievalStartedEvent(
task_id=str(task.id) if task else None,
source_type="agent",
from_agent=self,
from_task=task,
),
)
start_time = time.time()
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._external_memory,
agent=self,
task=task,
)
memory = await contextual_memory.abuild_context_for_task(
task, context or ""
)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
crewai_event_bus.emit(
self,
event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None,
memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent",
from_agent=self,
from_task=task,
),
)
knowledge_config = get_knowledge_config(self)
task_prompt = await ahandle_knowledge_retrieval(
self, task, task_prompt, knowledge_config
)
prepare_tools(self, tools, task)
task_prompt = apply_training_data(self, task_prompt)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
try:
crewai_event_bus.emit(
self,
event=AgentExecutionStartedEvent(
agent=self,
tools=self.tools,
task_prompt=task_prompt,
task=task,
),
)
validate_max_execution_time(self.max_execution_time)
if self.max_execution_time is not None:
result = await self._aexecute_with_timeout(
task_prompt, task, self.max_execution_time
)
else:
result = await self._aexecute_without_timeout(task_prompt, task)
except TimeoutError as e:
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
self._times_executed += 1
if self._times_executed > self.max_retry_limit:
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
result = await self.aexecute_task(task, context, tools)
if self.max_rpm and self._rpm_controller:
self._rpm_controller.stop_rpm_counter()
result = process_tool_results(self, result)
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
)
save_last_messages(self)
self._cleanup_mcp_clients()
return result
async def _aexecute_with_timeout(
self, task_prompt: str, task: Task, timeout: int
) -> Any:
"""Execute a task with a timeout asynchronously.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
timeout: Maximum execution time in seconds.
Returns:
The output of the agent.
Raises:
TimeoutError: If execution exceeds the timeout.
RuntimeError: If execution fails for other reasons.
"""
try:
return await asyncio.wait_for(
self._aexecute_without_timeout(task_prompt, task),
timeout=timeout,
)
except asyncio.TimeoutError as e:
raise TimeoutError(
f"Task '{task.description}' execution timed out after {timeout} seconds. "
"Consider increasing max_execution_time or optimizing the task."
) from e
async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
"""Execute a task without a timeout asynchronously.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
Returns:
The output of the agent.
"""
if not self.agent_executor:
raise RuntimeError("Agent executor is not initialized.")
result = await self.agent_executor.ainvoke(
{
"input": task_prompt,
"tool_names": self.agent_executor.tools_names,
"tools": self.agent_executor.tools_description,
"ask_for_human_input": task.human_input,
}
)
return result["output"]
def create_agent_executor(
self, tools: list[BaseTool] | None = None, task: Task | None = None
) -> None:
@@ -633,7 +722,7 @@ class Agent(BaseAgent):
)
self.agent_executor = CrewAgentExecutor(
llm=self.llm,
llm=self.llm, # type: ignore[arg-type]
task=task, # type: ignore[arg-type]
agent=self,
crew=self.crew,
@@ -810,6 +899,7 @@ class Agent(BaseAgent):
from crewai.tools.base_tool import BaseTool
from crewai.tools.mcp_native_tool import MCPNativeTool
transport: StdioTransport | HTTPTransport | SSETransport
if isinstance(mcp_config, MCPServerStdio):
transport = StdioTransport(
command=mcp_config.command,
@@ -903,10 +993,10 @@ class Agent(BaseAgent):
server_name=server_name,
run_context=None,
)
if mcp_config.tool_filter(context, tool):
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
except (TypeError, AttributeError):
if mcp_config.tool_filter(tool):
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
else:
# Not callable - include tool
@@ -981,7 +1071,9 @@ class Agent(BaseAgent):
path = parsed.path.replace("/", "_").strip("_")
return f"{domain}_{path}" if path else domain
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
def _get_mcp_tool_schemas(
self, server_params: dict[str, Any]
) -> dict[str, dict[str, Any]]:
"""Get tool schemas from MCP server for wrapper creation with caching."""
server_url = server_params["url"]
@@ -995,7 +1087,7 @@ class Agent(BaseAgent):
self._logger.log(
"debug", f"Using cached MCP tool schemas for {server_url}"
)
return cached_data
return cached_data # type: ignore[no-any-return]
try:
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
@@ -1013,7 +1105,7 @@ class Agent(BaseAgent):
async def _get_mcp_tool_schemas_async(
self, server_params: dict[str, Any]
) -> dict[str, dict]:
) -> dict[str, dict[str, Any]]:
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
server_url = server_params["url"]
return await self._retry_mcp_discovery(
@@ -1021,7 +1113,7 @@ class Agent(BaseAgent):
)
async def _retry_mcp_discovery(
self, operation_func, server_url: str
self, operation_func: Any, server_url: str
) -> dict[str, dict[str, Any]]:
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
last_error = None
@@ -1052,7 +1144,7 @@ class Agent(BaseAgent):
@staticmethod
async def _attempt_mcp_discovery(
operation_func, server_url: str
operation_func: Any, server_url: str
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
try:
@@ -1142,7 +1234,7 @@ class Agent(BaseAgent):
properties = json_schema.get("properties", {})
required_fields = json_schema.get("required", [])
field_definitions = {}
field_definitions: dict[str, Any] = {}
for field_name, field_schema in properties.items():
field_type = self._json_type_to_python(field_schema)
@@ -1162,7 +1254,7 @@ class Agent(BaseAgent):
)
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
return create_model(model_name, **field_definitions)
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
"""Convert JSON Schema type to Python type.
@@ -1177,7 +1269,7 @@ class Agent(BaseAgent):
json_type = field_schema.get("type")
if "anyOf" in field_schema:
types = []
types: list[type] = []
for option in field_schema["anyOf"]:
if "const" in option:
types.append(str)
@@ -1185,13 +1277,13 @@ class Agent(BaseAgent):
types.append(self._json_type_to_python(option))
unique_types = list(set(types))
if len(unique_types) > 1:
result = unique_types[0]
result: Any = unique_types[0]
for t in unique_types[1:]:
result = result | t
return result
return result # type: ignore[no-any-return]
return unique_types[0]
type_mapping = {
type_mapping: dict[str | None, type] = {
"string": str,
"number": float,
"integer": int,
@@ -1203,7 +1295,7 @@ class Agent(BaseAgent):
return type_mapping.get(json_type, Any)
@staticmethod
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
"""Fetch MCP server configurations from CrewAI AOP API."""
# TODO: Implement AMP API call to "integrations/mcps" endpoint
# Should return list of server configs with URLs
@@ -1438,11 +1530,11 @@ class Agent(BaseAgent):
"""
if self.apps:
platform_tools = self.get_platform_tools(self.apps)
if platform_tools:
if platform_tools and self.tools is not None:
self.tools.extend(platform_tools)
if self.mcps:
mcps = self.get_mcp_tools(self.mcps)
if mcps:
if mcps and self.tools is not None:
self.tools.extend(mcps)
lite_agent = LiteAgent(

View File

@@ -0,0 +1,355 @@
"""Utility functions for agent task execution.
This module contains shared logic extracted from the Agent's execute_task
and aexecute_task methods to reduce code duplication.
"""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.knowledge_events import (
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.utilities.converter import generate_model_description
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.i18n import I18N
def handle_reasoning(agent: Agent, task: Task) -> None:
"""Handle the reasoning process for an agent before task execution.
Args:
agent: The agent performing the task.
task: The task to execute.
"""
if not agent.reasoning:
return
try:
from crewai.utilities.reasoning_handler import (
AgentReasoning,
AgentReasoningOutput,
)
reasoning_handler = AgentReasoning(task=task, agent=agent)
reasoning_output: AgentReasoningOutput = (
reasoning_handler.handle_agent_reasoning()
)
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
agent._logger.log("error", f"Error during reasoning process: {e!s}")
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
"""Build task prompt with JSON/Pydantic schema instructions if applicable.
Args:
task: The task being executed.
task_prompt: The initial task prompt.
i18n: Internationalization instance.
Returns:
The task prompt potentially augmented with schema instructions.
"""
if (task.output_json or task.output_pydantic) and not task.response_model:
if task.output_json:
schema_dict = generate_model_description(task.output_json)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
output_format=schema
)
elif task.output_pydantic:
schema_dict = generate_model_description(task.output_pydantic)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
output_format=schema
)
return task_prompt
def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
"""Format task prompt with context if provided.
Args:
task_prompt: The task prompt.
context: Optional context string.
i18n: Internationalization instance.
Returns:
The task prompt formatted with context if provided.
"""
if context:
return i18n.slice("task_with_context").format(task=task_prompt, context=context)
return task_prompt
def get_knowledge_config(agent: Agent) -> dict[str, Any]:
"""Get knowledge configuration from agent.
Args:
agent: The agent instance.
Returns:
Dictionary of knowledge configuration.
"""
return agent.knowledge_config.model_dump() if agent.knowledge_config else {}
def handle_knowledge_retrieval(
agent: Agent,
task: Task,
task_prompt: str,
knowledge_config: dict[str, Any],
query_func: Any,
crew_query_func: Any,
) -> str:
"""Handle knowledge retrieval for task execution.
This function handles both agent-specific and crew-specific knowledge queries.
Args:
agent: The agent performing the task.
task: The task being executed.
task_prompt: The current task prompt.
knowledge_config: Knowledge configuration dictionary.
query_func: Function to query agent knowledge (sync or async).
crew_query_func: Function to query crew knowledge (sync or async).
Returns:
The task prompt potentially augmented with knowledge context.
"""
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
return task_prompt
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=agent,
),
)
try:
agent.knowledge_search_query = agent._get_knowledge_search_query(
task_prompt, task
)
if agent.knowledge_search_query:
if agent.knowledge:
agent_knowledge_snippets = query_func(
[agent.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
agent.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if agent.agent_knowledge_context:
task_prompt += agent.agent_knowledge_context
knowledge_snippets = crew_query_func(
[agent.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
agent.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if agent.crew_knowledge_context:
task_prompt += agent.crew_knowledge_context
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalCompletedEvent(
query=agent.knowledge_search_query,
from_task=task,
from_agent=agent,
retrieved_knowledge=_combine_knowledge_context(agent),
),
)
except Exception as e:
crewai_event_bus.emit(
agent,
event=KnowledgeSearchQueryFailedEvent(
query=agent.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=agent,
),
)
return task_prompt
def _combine_knowledge_context(agent: Agent) -> str:
"""Combine agent and crew knowledge contexts into a single string.
Args:
agent: The agent with knowledge contexts.
Returns:
Combined knowledge context string.
"""
agent_ctx = agent.agent_knowledge_context or ""
crew_ctx = agent.crew_knowledge_context or ""
separator = "\n" if agent_ctx and crew_ctx else ""
return agent_ctx + separator + crew_ctx
def apply_training_data(agent: Agent, task_prompt: str) -> str:
"""Apply training data to the task prompt.
Args:
agent: The agent performing the task.
task_prompt: The task prompt.
Returns:
The task prompt with training data applied.
"""
if agent.crew and agent.crew._train:
return agent._training_handler(task_prompt=task_prompt)
return agent._use_trained_data(task_prompt=task_prompt)
def process_tool_results(agent: Agent, result: Any) -> Any:
"""Process tool results, returning result_as_answer if applicable.
Args:
agent: The agent with tool results.
result: The current result.
Returns:
The final result, potentially overridden by tool result_as_answer.
"""
for tool_result in agent.tools_results:
if tool_result.get("result_as_answer", False):
result = tool_result["result"]
return result
def save_last_messages(agent: Agent) -> None:
"""Save the last messages from agent executor.
Args:
agent: The agent instance.
"""
agent._last_messages = (
agent.agent_executor.messages.copy()
if agent.agent_executor and hasattr(agent.agent_executor, "messages")
else []
)
def prepare_tools(
agent: Agent, tools: list[BaseTool] | None, task: Task
) -> list[BaseTool]:
"""Prepare tools for task execution and create agent executor.
Args:
agent: The agent instance.
tools: Optional list of tools.
task: The task being executed.
Returns:
The list of tools to use.
"""
final_tools = tools or agent.tools or []
agent.create_agent_executor(tools=final_tools, task=task)
return final_tools
def validate_max_execution_time(max_execution_time: int | None) -> None:
"""Validate max_execution_time parameter.
Args:
max_execution_time: The maximum execution time to validate.
Raises:
ValueError: If max_execution_time is not a positive integer.
"""
if max_execution_time is not None:
if not isinstance(max_execution_time, int) or max_execution_time <= 0:
raise ValueError(
"Max Execution time must be a positive integer greater than zero"
)
async def ahandle_knowledge_retrieval(
agent: Agent,
task: Task,
task_prompt: str,
knowledge_config: dict[str, Any],
) -> str:
"""Handle async knowledge retrieval for task execution.
Args:
agent: The agent performing the task.
task: The task being executed.
task_prompt: The current task prompt.
knowledge_config: Knowledge configuration dictionary.
Returns:
The task prompt potentially augmented with knowledge context.
"""
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
return task_prompt
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=agent,
),
)
try:
agent.knowledge_search_query = agent._get_knowledge_search_query(
task_prompt, task
)
if agent.knowledge_search_query:
if agent.knowledge:
agent_knowledge_snippets = await agent.knowledge.aquery(
[agent.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
agent.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if agent.agent_knowledge_context:
task_prompt += agent.agent_knowledge_context
knowledge_snippets = await agent.crew.aquery_knowledge(
[agent.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
agent.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if agent.crew_knowledge_context:
task_prompt += agent.crew_knowledge_context
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalCompletedEvent(
query=agent.knowledge_search_query,
from_task=task,
from_agent=agent,
retrieved_knowledge=_combine_knowledge_context(agent),
),
)
except Exception as e:
crewai_event_bus.emit(
agent,
event=KnowledgeSearchQueryFailedEvent(
query=agent.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=agent,
),
)
return task_prompt

View File

@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
if not mcps:
return mcps
validated_mcps = []
validated_mcps: list[str | MCPServerConfig] = []
for mcp in mcps:
if isinstance(mcp, str):
if mcp.startswith(("https://", "crewai-amp:")):
@@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
) -> str:
pass
@abstractmethod
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task asynchronously."""
@abstractmethod
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
pass

View File

@@ -28,6 +28,7 @@ from crewai.hooks.llm_hooks import (
get_before_llm_call_hooks,
)
from crewai.utilities.agent_utils import (
aget_llm_response,
enforce_rpm_limit,
format_message_for_llm,
get_llm_response,
@@ -43,7 +44,10 @@ from crewai.utilities.agent_utils import (
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.printer import Printer
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.tool_utils import (
aexecute_tool_and_check_finality,
execute_tool_and_check_finality,
)
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -134,8 +138,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages: list[LLMMessage] = []
self.iterations = 0
self.log_error_after = 3
self.before_llm_call_hooks: list[Callable] = []
self.after_llm_call_hooks: list[Callable] = []
self.before_llm_call_hooks: list[Callable[..., Any]] = []
self.after_llm_call_hooks: list[Callable[..., Any]] = []
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
@@ -312,6 +316,154 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Execute the agent asynchronously with given inputs.
Args:
inputs: Input dictionary containing prompt variables.
Returns:
Dictionary with agent output.
"""
if "system" in self.prompt:
system_prompt = self._format_prompt(
cast(str, self.prompt.get("system", "")), inputs
)
user_prompt = self._format_prompt(
cast(str, self.prompt.get("user", "")), inputs
)
self.messages.append(format_message_for_llm(system_prompt, role="system"))
self.messages.append(format_message_for_llm(user_prompt))
else:
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
self.messages.append(format_message_for_llm(user_prompt))
self._show_start_logs()
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
try:
formatted_answer = await self._ainvoke_loop()
except AssertionError:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
raise
except Exception as e:
handle_unknown_error(self._printer, e)
raise
if self.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer)
self._create_short_term_memory(formatted_answer)
self._create_long_term_memory(formatted_answer)
self._create_external_memory(formatted_answer)
return {"output": formatted_answer.output}
async def _ainvoke_loop(self) -> AgentFinish:
"""Execute agent loop asynchronously until completion.
Returns:
Final answer from the agent.
"""
formatted_answer = None
while not isinstance(formatted_answer, AgentFinish):
try:
if has_reached_max_iterations(self.iterations, self.max_iter):
formatted_answer = handle_max_iterations_exceeded(
formatted_answer,
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
)
break
enforce_rpm_limit(self.request_within_rpm_limit)
answer = await aget_llm_response(
llm=self.llm,
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
if isinstance(formatted_answer, AgentAction):
fingerprint_context = {}
if (
self.agent
and hasattr(self.agent, "security_config")
and hasattr(self.agent.security_config, "fingerprint")
):
fingerprint_context = {
"agent_fingerprint": str(
self.agent.security_config.fingerprint
)
}
tool_result = await aexecute_tool_and_check_finality(
agent_action=formatted_answer,
fingerprint_context=fingerprint_context,
tools=self.tools,
i18n=self._i18n,
agent_key=self.agent.key if self.agent else None,
agent_role=self.agent.role if self.agent else None,
tools_handler=self.tools_handler,
task=self.task,
agent=self.agent,
function_calling_llm=self.function_calling_llm,
crew=self.crew,
)
formatted_answer = self._handle_agent_action(
formatted_answer, tool_result
)
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
self._append_message(formatted_answer.text) # type: ignore[union-attr,attr-defined]
except OutputParserError as e:
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
e=e,
messages=self.messages,
iterations=self.iterations,
log_error_after=self.log_error_after,
printer=self._printer,
)
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
raise e
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
)
continue
handle_unknown_error(self._printer, e)
raise e
finally:
self.iterations += 1
if not isinstance(formatted_answer, AgentFinish):
raise RuntimeError(
"Agent execution ended without reaching a final answer. "
f"Got {type(formatted_answer).__name__} instead of AgentFinish."
)
self._show_logs(formatted_answer)
return formatted_answer
def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult
) -> AgentAction | AgentFinish:

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.6.0"
"crewai[tools]==1.6.1"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.6.0"
"crewai[tools]==1.6.1"
]
[project.scripts]

View File

@@ -327,7 +327,7 @@ class Crew(FlowTrackable, BaseModel):
def set_private_attrs(self) -> Crew:
"""set private attributes."""
self._cache_handler = CacheHandler()
event_listener = EventListener() # type: ignore[no-untyped-call]
event_listener = EventListener()
# Determine and set tracing state once for this execution
tracing_enabled = should_enable_tracing(override=self.tracing)
@@ -348,12 +348,12 @@ class Crew(FlowTrackable, BaseModel):
return self
def _initialize_default_memories(self) -> None:
self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call]
self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call]
self._long_term_memory = self._long_term_memory or LongTermMemory()
self._short_term_memory = self._short_term_memory or ShortTermMemory(
crew=self,
embedder_config=self.embedder,
)
self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call]
self._entity_memory = self.entity_memory or EntityMemory(
crew=self, embedder_config=self.embedder
)
@@ -948,6 +948,342 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
return list(results)
async def akickoff(
self, inputs: dict[str, Any] | None = None
) -> CrewOutput | CrewStreamingOutput:
"""Native async kickoff method using async task execution throughout.
Unlike kickoff_async which wraps sync kickoff in a thread, this method
uses native async/await for all operations including task execution,
memory operations, and knowledge queries.
"""
if self.stream:
for agent in self.agents:
if agent.llm is not None:
agent.llm.stream = True
result_holder: list[CrewOutput] = []
current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}
state = create_streaming_state(
current_task_info, result_holder, use_async=True
)
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
async def run_crew() -> None:
try:
self.stream = False
result = await self.akickoff(inputs)
if isinstance(result, CrewOutput):
result_holder.append(result)
except Exception as e:
signal_error(state, e, is_async=True)
finally:
self.stream = True
signal_end(state, is_async=True)
streaming_output = CrewStreamingOutput(
async_iterator=create_async_chunk_generator(
state, run_crew, output_holder
)
)
output_holder.append(streaming_output)
return streaming_output
ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
)
token = attach(ctx)
try:
for before_callback in self.before_kickoff_callbacks:
if inputs is None:
inputs = {}
inputs = before_callback(inputs)
crewai_event_bus.emit(
self,
CrewKickoffStartedEvent(crew_name=self.name, inputs=inputs),
)
self._task_output_handler.reset()
self._logging_color = "bold_purple"
if inputs is not None:
self._inputs = inputs
self._interpolate_inputs(inputs)
self._set_tasks_callbacks()
self._set_allow_crewai_trigger_context_for_first_task()
for agent in self.agents:
agent.crew = self
agent.set_knowledge(crew_embedder=self.embedder)
if not agent.function_calling_llm: # type: ignore[attr-defined]
agent.function_calling_llm = self.function_calling_llm # type: ignore[attr-defined]
if not agent.step_callback: # type: ignore[attr-defined]
agent.step_callback = self.step_callback # type: ignore[attr-defined]
agent.create_agent_executor()
if self.planning:
self._handle_crew_planning()
if self.process == Process.sequential:
result = await self._arun_sequential_process()
elif self.process == Process.hierarchical:
result = await self._arun_hierarchical_process()
else:
raise NotImplementedError(
f"The process '{self.process}' is not implemented yet."
)
for after_callback in self.after_kickoff_callbacks:
result = after_callback(result)
self.usage_metrics = self.calculate_usage_metrics()
return result
except Exception as e:
crewai_event_bus.emit(
self,
CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
)
raise
finally:
detach(token)
async def akickoff_for_each(
self, inputs: list[dict[str, Any]]
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
"""Native async execution of the Crew's workflow for each input.
Uses native async throughout rather than thread-based async.
If stream=True, returns a single CrewStreamingOutput that yields chunks
from all crews as they arrive.
"""
crew_copies = [self.copy() for _ in inputs]
if self.stream:
result_holder: list[list[CrewOutput]] = [[]]
current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}
state = create_streaming_state(
current_task_info, result_holder, use_async=True
)
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
async def run_all_crews() -> None:
try:
streaming_outputs: list[CrewStreamingOutput] = []
for i, crew in enumerate(crew_copies):
streaming = await crew.akickoff(inputs=inputs[i])
if isinstance(streaming, CrewStreamingOutput):
streaming_outputs.append(streaming)
async def consume_stream(
stream_output: CrewStreamingOutput,
) -> CrewOutput:
async for chunk in stream_output:
if state.async_queue is not None and state.loop is not None:
state.loop.call_soon_threadsafe(
state.async_queue.put_nowait, chunk
)
return stream_output.result
crew_results = await asyncio.gather(
*[consume_stream(s) for s in streaming_outputs]
)
result_holder[0] = list(crew_results)
except Exception as e:
signal_error(state, e, is_async=True)
finally:
signal_end(state, is_async=True)
streaming_output = CrewStreamingOutput(
async_iterator=create_async_chunk_generator(
state, run_all_crews, output_holder
)
)
def set_results_wrapper(result: Any) -> None:
streaming_output._set_results(result)
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
output_holder.append(streaming_output)
return streaming_output
tasks = [
asyncio.create_task(crew_copy.akickoff(inputs=input_data))
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
]
results = await asyncio.gather(*tasks)
total_usage_metrics = UsageMetrics()
for crew_copy in crew_copies:
if crew_copy.usage_metrics:
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
self.usage_metrics = total_usage_metrics
self._task_output_handler.reset()
return list(results)
async def _arun_sequential_process(self) -> CrewOutput:
"""Executes tasks sequentially using native async and returns the final output."""
return await self._aexecute_tasks(self.tasks)
async def _arun_hierarchical_process(self) -> CrewOutput:
"""Creates and assigns a manager agent to complete the tasks using native async."""
self._create_manager_agent()
return await self._aexecute_tasks(self.tasks)
async def _aexecute_tasks(
self,
tasks: list[Task],
start_index: int | None = 0,
was_replayed: bool = False,
) -> CrewOutput:
"""Executes tasks using native async and returns the final output.
Args:
tasks: List of tasks to execute
start_index: Index to start execution from (for replay)
was_replayed: Whether this is a replayed execution
Returns:
CrewOutput: Final output of the crew
"""
task_outputs: list[TaskOutput] = []
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
last_sync_output: TaskOutput | None = None
for task_index, task in enumerate(tasks):
if start_index is not None and task_index < start_index:
if task.output:
if task.async_execution:
task_outputs.append(task.output)
else:
task_outputs = [task.output]
last_sync_output = task.output
continue
agent_to_use = self._get_agent_to_use(task)
if agent_to_use is None:
raise ValueError(
f"No agent available for task: {task.description}. "
f"Ensure that either the task has an assigned agent "
f"or a manager agent is provided."
)
tools_for_task = task.tools or agent_to_use.tools or []
tools_for_task = self._prepare_tools(
agent_to_use,
task,
tools_for_task,
)
self._log_task_start(task, agent_to_use.role)
if isinstance(task, ConditionalTask):
skipped_task_output = await self._ahandle_conditional_task(
task, task_outputs, pending_tasks, task_index, was_replayed
)
if skipped_task_output:
task_outputs.append(skipped_task_output)
continue
if task.async_execution:
context = self._get_context(
task, [last_sync_output] if last_sync_output else []
)
async_task = asyncio.create_task(
task.aexecute_sync(
agent=agent_to_use,
context=context,
tools=tools_for_task,
)
)
pending_tasks.append((task, async_task, task_index))
else:
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(
pending_tasks, was_replayed
)
pending_tasks.clear()
context = self._get_context(task, task_outputs)
task_output = await task.aexecute_sync(
agent=agent_to_use,
context=context,
tools=tools_for_task,
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
self._store_execution_log(task, task_output, task_index, was_replayed)
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
return self._create_crew_output(task_outputs)
async def _ahandle_conditional_task(
self,
task: ConditionalTask,
task_outputs: list[TaskOutput],
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
task_index: int,
was_replayed: bool,
) -> TaskOutput | None:
"""Handle conditional task evaluation using native async."""
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
pending_tasks.clear()
previous_output = task_outputs[-1] if task_outputs else None
if previous_output is not None and not task.should_execute(previous_output):
self._logger.log(
"debug",
f"Skipping conditional task: {task.description}",
color="yellow",
)
skipped_task_output = task.get_skipped_task_output()
if not was_replayed:
self._store_execution_log(task, skipped_task_output, task_index)
return skipped_task_output
return None
async def _aprocess_async_tasks(
self,
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
was_replayed: bool = False,
) -> list[TaskOutput]:
"""Process pending async tasks and return their outputs."""
task_outputs: list[TaskOutput] = []
for future_task, async_task, task_index in pending_tasks:
task_output = await async_task
task_outputs.append(task_output)
self._process_task_result(future_task, task_output)
self._store_execution_log(
future_task, task_output, task_index, was_replayed
)
return task_outputs
def _handle_crew_planning(self) -> None:
"""Handles the Crew planning."""
self._logger.log("info", "Planning the crew execution")
@@ -1431,6 +1767,16 @@ class Crew(FlowTrackable, BaseModel):
)
return None
async def aquery_knowledge(
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[SearchResult] | None:
"""Query the crew's knowledge base for relevant information asynchronously."""
if self.knowledge:
return await self.knowledge.aquery(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
def fetch_inputs(self) -> set[str]:
"""
Gathers placeholders (e.g., {something}) referenced in tasks or agents.

View File

@@ -76,7 +76,7 @@ class TraceBatchManager:
use_ephemeral: bool = False,
) -> TraceBatch:
"""Initialize a new trace batch (thread-safe)"""
with self._init_lock:
with self._batch_ready_cv:
if self.current_batch is not None:
logger.debug(
"Batch already initialized, skipping duplicate initialization"
@@ -99,7 +99,6 @@ class TraceBatchManager:
self.backend_initialized = True
self._batch_ready_cv.notify_all()
return self.current_batch
def _initialize_backend_batch(
@@ -107,7 +106,7 @@ class TraceBatchManager:
user_context: dict[str, str],
execution_metadata: dict[str, Any],
use_ephemeral: bool = False,
):
) -> None:
"""Send batch initialization to backend"""
if not is_tracing_enabled_in_context():
@@ -204,7 +203,7 @@ class TraceBatchManager:
return False
return True
def add_event(self, trace_event: TraceEvent):
def add_event(self, trace_event: TraceEvent) -> None:
"""Add event to buffer"""
self.event_buffer.append(trace_event)
@@ -300,7 +299,7 @@ class TraceBatchManager:
return finalized_batch
def _finalize_backend_batch(self, events_count: int = 0):
def _finalize_backend_batch(self, events_count: int = 0) -> None:
"""Send batch finalization to backend
Args:
@@ -366,7 +365,7 @@ class TraceBatchManager:
logger.error(f"❌ Error finalizing trace batch: {e}")
self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))
def _cleanup_batch_data(self):
def _cleanup_batch_data(self) -> None:
"""Clean up batch data after successful finalization to free memory"""
try:
if hasattr(self, "event_buffer") and self.event_buffer:
@@ -411,7 +410,7 @@ class TraceBatchManager:
lambda: self.current_batch is not None, timeout=timeout
)
def record_start_time(self, key: str):
def record_start_time(self, key: str) -> None:
"""Record start time for duration calculation"""
self.execution_start_times[key] = datetime.now(timezone.utc)

View File

@@ -71,6 +71,7 @@ from crewai.events.types.reasoning_events import (
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from crewai.events.types.system_events import SignalEvent, on_signal
from crewai.events.types.task_events import (
TaskCompletedEvent,
TaskFailedEvent,
@@ -159,6 +160,7 @@ class TraceCollectionListener(BaseEventListener):
self._register_flow_event_handlers(crewai_event_bus)
self._register_context_event_handlers(crewai_event_bus)
self._register_action_event_handlers(crewai_event_bus)
self._register_system_event_handlers(crewai_event_bus)
self._listeners_setup = True
@@ -458,6 +460,15 @@ class TraceCollectionListener(BaseEventListener):
) -> None:
self._handle_action_event("knowledge_query_failed", source, event)
def _register_system_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for system signal events (SIGTERM, SIGINT, etc.)."""
@on_signal
def handle_signal(source: Any, event: SignalEvent) -> None:
"""Flush trace batch on system signals to prevent data loss."""
if self.batch_manager.is_batch_initialized():
self.batch_manager.finalize_batch()
def _initialize_crew_batch(self, source: Any, event: Any) -> None:
"""Initialize trace batch.

View File

@@ -0,0 +1,102 @@
"""System signal event types for CrewAI.
This module contains event types for system-level signals like SIGTERM,
allowing listeners to perform cleanup operations before process termination.
"""
from collections.abc import Callable
from enum import IntEnum
import signal
from typing import Annotated, Literal, TypeVar
from pydantic import Field, TypeAdapter
from crewai.events.base_events import BaseEvent
class SignalType(IntEnum):
"""Enumeration of supported system signals."""
SIGTERM = signal.SIGTERM
SIGINT = signal.SIGINT
SIGHUP = signal.SIGHUP
SIGTSTP = signal.SIGTSTP
SIGCONT = signal.SIGCONT
class SigTermEvent(BaseEvent):
"""Event emitted when SIGTERM is received."""
type: Literal["SIGTERM"] = "SIGTERM"
signal_number: SignalType = SignalType.SIGTERM
reason: str | None = None
class SigIntEvent(BaseEvent):
"""Event emitted when SIGINT is received."""
type: Literal["SIGINT"] = "SIGINT"
signal_number: SignalType = SignalType.SIGINT
reason: str | None = None
class SigHupEvent(BaseEvent):
"""Event emitted when SIGHUP is received."""
type: Literal["SIGHUP"] = "SIGHUP"
signal_number: SignalType = SignalType.SIGHUP
reason: str | None = None
class SigTStpEvent(BaseEvent):
"""Event emitted when SIGTSTP is received.
Note: SIGSTOP cannot be caught - it immediately suspends the process.
"""
type: Literal["SIGTSTP"] = "SIGTSTP"
signal_number: SignalType = SignalType.SIGTSTP
reason: str | None = None
class SigContEvent(BaseEvent):
"""Event emitted when SIGCONT is received."""
type: Literal["SIGCONT"] = "SIGCONT"
signal_number: SignalType = SignalType.SIGCONT
reason: str | None = None
SignalEvent = Annotated[
SigTermEvent | SigIntEvent | SigHupEvent | SigTStpEvent | SigContEvent,
Field(discriminator="type"),
]
signal_event_adapter: TypeAdapter[SignalEvent] = TypeAdapter(SignalEvent)
SIGNAL_EVENT_TYPES: tuple[type[BaseEvent], ...] = (
SigTermEvent,
SigIntEvent,
SigHupEvent,
SigTStpEvent,
SigContEvent,
)
T = TypeVar("T", bound=Callable[[object, SignalEvent], None])
def on_signal(func: T) -> T:
"""Decorator to register a handler for all signal events.
Args:
func: Handler function that receives (source, event) arguments.
Returns:
The original function, registered for all signal event types.
"""
from crewai.events.event_bus import crewai_event_bus
for event_type in SIGNAL_EVENT_TYPES:
crewai_event_bus.on(event_type)(func)
return func

View File

@@ -1032,6 +1032,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
finally:
detach(flow_token)
async def akickoff(
self, inputs: dict[str, Any] | None = None
) -> Any | FlowStreamingOutput:
"""Native async method to start the flow execution. Alias for kickoff_async.
Args:
inputs: Optional dictionary containing input values and/or a state ID for restoration.
Returns:
The final output from the flow, which is the result of the last executed method.
"""
return await self.kickoff_async(inputs)
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
"""Executes a flow's start method and its triggered listeners.

View File

@@ -32,8 +32,8 @@ class Knowledge(BaseModel):
sources: list[BaseKnowledgeSource],
embedder: EmbedderConfig | None = None,
storage: KnowledgeStorage | None = None,
**data,
):
**data: object,
) -> None:
super().__init__(**data)
if storage:
self.storage = storage
@@ -75,3 +75,44 @@ class Knowledge(BaseModel):
self.storage.reset()
else:
raise ValueError("Storage is not initialized.")
async def aquery(
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
) -> list[SearchResult]:
"""Query across all knowledge sources asynchronously.
Args:
query: List of query strings.
results_limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
The top results matching the query.
Raises:
ValueError: If storage is not initialized.
"""
if self.storage is None:
raise ValueError("Storage is not initialized.")
return await self.storage.asearch(
query,
limit=results_limit,
score_threshold=score_threshold,
)
async def aadd_sources(self) -> None:
"""Add all knowledge sources to storage asynchronously."""
try:
for source in self.sources:
source.storage = self.storage
await source.aadd()
except Exception as e:
raise e
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""
if self.storage:
await self.storage.areset()
else:
raise ValueError("Storage is not initialized.")

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from pydantic import Field, field_validator
@@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info): # noqa: N805
@classmethod
def validate_file_path(
cls, v: Path | list[Path] | str | list[str] | None, info: Any
) -> Path | list[Path] | str | list[str] | None:
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
raise ValueError("Either file_path or file_paths must be provided")
return v
def model_post_init(self, _):
def model_post_init(self, _: Any) -> None:
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
self.validate_content()
@@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def load_content(self) -> dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
color="red",
)
def _save_documents(self):
def _save_documents(self) -> None:
"""Save the documents to the storage."""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
async def _asave_documents(self) -> None:
"""Save the documents to the storage asynchronously."""
if self.storage:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

View File

@@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC):
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]
def _save_documents(self):
"""
Save the documents to the storage.
def _save_documents(self) -> None:
"""Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
@abstractmethod
async def aadd(self) -> None:
"""Process content, chunk it, compute embeddings, and save them asynchronously."""
async def _asave_documents(self) -> None:
"""Save the documents to the storage asynchronously.
This method should be called after the chunks and embeddings are generated.
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")

View File

@@ -2,27 +2,24 @@ from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
try:
from docling.datamodel.base_models import ( # type: ignore[import-not-found]
InputFormat,
)
from docling.document_converter import ( # type: ignore[import-not-found]
DocumentConverter,
)
from docling.exceptions import ConversionError # type: ignore[import-not-found]
from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found]
HierarchicalChunker,
)
from docling_core.types.doc.document import ( # type: ignore[import-not-found]
DoclingDocument,
)
from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter
from docling.exceptions import ConversionError
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.types.doc.document import DoclingDocument
DOCLING_AVAILABLE = True
except ImportError:
DOCLING_AVAILABLE = False
# Provide type stubs for when docling is not available
if TYPE_CHECKING:
from docling.document_converter import DocumentConverter
from docling_core.types.doc.document import DoclingDocument
from pydantic import Field
@@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger
class CrewDoclingSource(BaseKnowledgeSource):
"""Default Source class for converting documents to markdown or json
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
"""Default Source class for converting documents to markdown or json.
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without
any additional dependencies and follows the docling package as the source of truth.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
if not DOCLING_AVAILABLE:
raise ImportError(
"The docling package is required to use CrewDoclingSource. "
@@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
)
)
def model_post_init(self, _) -> None:
def model_post_init(self, _: Any) -> None:
if self.file_path:
self._logger.log(
"warning",
@@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()
async def aadd(self) -> None:
"""Add docling content asynchronously."""
if self.content is None:
return
for doc in self.content:
new_chunks_iterable = self._chunk_doc(doc)
self.chunks.extend(list(new_chunks_iterable))
await self._asave_documents()
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]

View File

@@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add CSV file content asynchronously."""
content_str = (
str(self.content) if isinstance(self.content, dict) else self.content
)
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -1,4 +1,6 @@
from pathlib import Path
from types import ModuleType
from typing import Any
from pydantic import Field, field_validator
@@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info): # noqa: N805
@classmethod
def validate_file_path(
cls, v: Path | list[Path] | str | list[str] | None, info: Any
) -> Path | list[Path] | str | list[str] | None:
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
return [self.convert_to_path(path) for path in path_list]
def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
color="red",
)
def model_post_init(self, _) -> None:
def model_post_init(self, _: Any) -> None:
if self.file_path:
self._logger.log(
"warning",
@@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
def _import_dependencies(self):
def _import_dependencies(self) -> ModuleType:
"""Dynamically import dependencies."""
try:
import pandas as pd # type: ignore[import-untyped,import-not-found]
import pandas as pd # type: ignore[import-untyped]
return pd
return pd # type: ignore[no-any-return]
except ImportError as e:
missing_package = str(e).split()[-1]
raise ImportError(
@@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add Excel file content asynchronously."""
content_str = ""
for value in self.content.values():
if isinstance(value, dict):
for sheet_value in value.values():
content_str += str(sheet_value) + "\n"
else:
content_str += str(value) + "\n"
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add JSON file content asynchronously."""
content_str = (
str(self.content) if isinstance(self.content, dict) else self.content
)
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -1,4 +1,5 @@
from pathlib import Path
from types import ModuleType
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
content[path] = text
return content
def _import_pdfplumber(self):
def _import_pdfplumber(self) -> ModuleType:
"""Dynamically import pdfplumber."""
try:
import pdfplumber
@@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add PDF file content asynchronously."""
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
@@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource):
content: str = Field(...)
collection_name: str | None = Field(default=None)
def model_post_init(self, _):
def model_post_init(self, _: Any) -> None:
"""Post-initialization method to validate content."""
self.validate_content()
def validate_content(self):
def validate_content(self) -> None:
"""Validate string content."""
if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content")
@@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add string content asynchronously."""
new_chunks = self._chunk_text(self.content)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add text file content asynchronously."""
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC):
) -> list[SearchResult]:
"""Search for documents in the knowledge base."""
@abstractmethod
async def asearch(
self,
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
"""Search for documents in the knowledge base asynchronously."""
@abstractmethod
def save(self, documents: list[str]) -> None:
"""Save documents to the knowledge base."""
@abstractmethod
async def asave(self, documents: list[str]) -> None:
"""Save documents to the knowledge base asynchronously."""
@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
@abstractmethod
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""

View File

@@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def __init__(
self,
embedder: ProviderSpec
| BaseEmbeddingsProvider
| type[BaseEmbeddingsProvider]
| BaseEmbeddingsProvider[Any]
| type[BaseEmbeddingsProvider[Any]]
| None = None,
collection_name: str | None = None,
) -> None:
@@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage):
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
async def asearch(
self,
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
"""Search for documents in the knowledge base asynchronously.
Args:
query: List of query strings.
limit: Maximum number of results to return.
metadata_filter: Optional metadata filter for the search.
score_threshold: Minimum similarity score for results.
Returns:
List of search results.
"""
try:
if not query:
raise ValueError("Query cannot be empty")
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
query_text = " ".join(query) if len(query) > 1 else query[0]
return await client.asearch(
collection_name=collection_name,
query=query_text,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
)
return []
async def asave(self, documents: list[str]) -> None:
"""Save documents to the knowledge base asynchronously.
Args:
documents: List of document strings to save.
"""
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
await client.aget_or_create_collection(collection_name=collection_name)
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
await client.aadd_documents(
collection_name=collection_name, documents=rag_documents
)
except Exception as e:
if "dimension mismatch" in str(e).lower():
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
await client.adelete_collection(collection_name=collection_name)
except Exception as e:
logging.error(
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
)

View File

@@ -57,7 +57,12 @@ if TYPE_CHECKING:
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
)
from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
from litellm.types.utils import (
ChatCompletionDeltaToolCall,
Choices,
Function,
ModelResponse,
)
from litellm.utils import supports_response_schema
from crewai.agent.core import Agent
@@ -73,7 +78,12 @@ try:
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
)
from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
from litellm.types.utils import (
ChatCompletionDeltaToolCall,
Choices,
Function,
ModelResponse,
)
from litellm.utils import supports_response_schema
LITELLM_AVAILABLE = True
@@ -84,6 +94,7 @@ except ImportError:
ContextWindowExceededError = Exception # type: ignore
get_supported_openai_params = None # type: ignore
ChatCompletionDeltaToolCall = None # type: ignore
Function = None # type: ignore
ModelResponse = None # type: ignore
supports_response_schema = None # type: ignore
CustomLogger = None # type: ignore
@@ -406,46 +417,100 @@ class LLM(BaseLLM):
instance.is_litellm = True
return instance
@classmethod
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
"""Check if a model name matches provider-specific patterns.
This allows supporting models that aren't in the hardcoded constants list,
including "latest" versions and new models that follow provider naming conventions.
Args:
model: The model name to check
provider: The provider to check against (canonical name)
Returns:
True if the model matches the provider's naming pattern, False otherwise
"""
model_lower = model.lower()
if provider == "openai":
return any(
model_lower.startswith(prefix)
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
)
if provider == "anthropic" or provider == "claude":
return any(
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
)
if provider == "gemini" or provider == "google":
return any(
model_lower.startswith(prefix)
for prefix in ["gemini-", "gemma-", "learnlm-"]
)
if provider == "bedrock":
return "." in model_lower
if provider == "azure":
return any(
model_lower.startswith(prefix)
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
)
return False
@classmethod
def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
"""Validate if a model name exists in the provider's constants.
"""Validate if a model name exists in the provider's constants or matches provider patterns.
This method first checks the hardcoded constants list for known models.
If not found, it falls back to pattern matching to support new models,
"latest" versions, and models that follow provider naming conventions.
Args:
model: The model name to validate
provider: The provider to check against (canonical name)
Returns:
True if the model exists in the provider's constants, False otherwise
True if the model exists in constants or matches provider patterns, False otherwise
"""
if provider == "openai":
return model in OPENAI_MODELS
if provider == "openai" and model in OPENAI_MODELS:
return True
if provider == "anthropic" or provider == "claude":
return model in ANTHROPIC_MODELS
if (
provider == "anthropic" or provider == "claude"
) and model in ANTHROPIC_MODELS:
return True
if provider == "gemini":
return model in GEMINI_MODELS
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
return True
if provider == "bedrock":
return model in BEDROCK_MODELS
if provider == "bedrock" and model in BEDROCK_MODELS:
return True
if provider == "azure":
# azure does not provide a list of available models, determine a better way to handle this
return True
return False
# Fallback to pattern matching for models not in constants
return cls._matches_provider_pattern(model, provider)
@classmethod
def _infer_provider_from_model(cls, model: str) -> str:
"""Infer the provider from the model name.
This method first checks the hardcoded constants list for known models.
If not found, it uses pattern matching to infer the provider from model name patterns.
This allows supporting new models and "latest" versions without hardcoding.
Args:
model: The model name without provider prefix
Returns:
The inferred provider name, defaults to "openai"
"""
if model in OPENAI_MODELS:
return "openai"
@@ -556,7 +621,9 @@ class LLM(BaseLLM):
self.callbacks = callbacks
self.context_window_size = 0
self.reasoning_effort = reasoning_effort
self.additional_params = kwargs
self.additional_params = {
k: v for k, v in kwargs.items() if k not in ("is_litellm", "provider")
}
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self.interceptor = interceptor
@@ -1150,6 +1217,281 @@ class LLM(BaseLLM):
)
return text_response
async def _ahandle_non_streaming_response(
self,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle an async non-streaming response from the LLM.
Args:
params: Parameters for the completion call
callbacks: Optional list of callback functions
available_functions: Dict of available functions
from_task: Optional Task that invoked the LLM
from_agent: Optional Agent that invoked the LLM
response_model: Optional Response model
Returns:
str: The response text
"""
if response_model and self.is_litellm:
from crewai.utilities.internal_instructor import InternalInstructor
messages = params.get("messages", [])
if not messages:
raise ValueError("Messages are required when using response_model")
combined_content = "\n\n".join(
f"{msg['role'].upper()}: {msg['content']}" for msg in messages
)
instructor_instance = InternalInstructor(
content=combined_content,
model=response_model,
llm=self,
)
result = instructor_instance.to_pydantic()
structured_response = result.model_dump_json()
self._handle_emit_call_events(
response=structured_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_response
try:
if response_model:
params["response_model"] = response_model
response = await litellm.acompletion(**params)
except ContextWindowExceededError as e:
raise LLMContextLengthExceededError(str(e)) from e
if response_model is not None:
if isinstance(response, BaseModel):
structured_response = response.model_dump_json()
self._handle_emit_call_events(
response=structured_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_response
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or ""
if callbacks and len(callbacks) > 0:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
usage_info = getattr(response, "usage", None)
if usage_info:
callback.log_success_event(
kwargs=params,
response_obj={"usage": usage_info},
start_time=0,
end_time=0,
)
tool_calls = getattr(response_message, "tool_calls", [])
if (not tool_calls or not available_functions) and text_response:
self._handle_emit_call_events(
response=text_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return text_response
if tool_calls and not available_functions and not text_response:
return tool_calls
tool_result = self._handle_tool_call(
tool_calls, available_functions, from_task, from_agent
)
if tool_result is not None:
return tool_result
self._handle_emit_call_events(
response=text_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return text_response
async def _ahandle_streaming_response(
self,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> Any:
"""Handle an async streaming response from the LLM.
Args:
params: Parameters for the completion call
callbacks: Optional list of callback functions
available_functions: Dict of available functions
from_task: Optional task object
from_agent: Optional agent object
response_model: Optional response model
Returns:
str: The complete response text
"""
full_response = ""
chunk_count = 0
usage_info = None
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs
)
params["stream"] = True
params["stream_options"] = {"include_usage": True}
try:
async for chunk in await litellm.acompletion(**params):
chunk_count += 1
chunk_content = None
try:
choices = None
if isinstance(chunk, dict) and "choices" in chunk:
choices = chunk["choices"]
elif hasattr(chunk, "choices"):
if not isinstance(chunk.choices, type):
choices = chunk.choices
if hasattr(chunk, "usage") and chunk.usage is not None:
usage_info = chunk.usage
if choices and len(choices) > 0:
first_choice = choices[0]
delta = None
if isinstance(first_choice, dict):
delta = first_choice.get("delta", {})
elif hasattr(first_choice, "delta"):
delta = first_choice.delta
if delta:
if isinstance(delta, dict):
chunk_content = delta.get("content")
elif hasattr(delta, "content"):
chunk_content = delta.content
tool_calls: list[ChatCompletionDeltaToolCall] | None = None
if isinstance(delta, dict):
tool_calls = delta.get("tool_calls")
elif hasattr(delta, "tool_calls"):
tool_calls = delta.tool_calls
if tool_calls:
for tool_call in tool_calls:
idx = tool_call.index
if tool_call.function:
if tool_call.function.name:
accumulated_tool_args[
idx
].function.name = tool_call.function.name
if tool_call.function.arguments:
accumulated_tool_args[
idx
].function.arguments += (
tool_call.function.arguments
)
except (AttributeError, KeyError, IndexError, TypeError):
pass
if chunk_content:
full_response += chunk_content
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
chunk=chunk_content,
from_task=from_task,
from_agent=from_agent,
),
)
if callbacks and len(callbacks) > 0 and usage_info:
for callback in callbacks:
if hasattr(callback, "log_success_event"):
callback.log_success_event(
kwargs=params,
response_obj={"usage": usage_info},
start_time=0,
end_time=0,
)
if accumulated_tool_args and available_functions:
# Convert accumulated tool args to ChatCompletionDeltaToolCall objects
tool_calls_list: list[ChatCompletionDeltaToolCall] = [
ChatCompletionDeltaToolCall(
index=idx,
function=Function(
name=tool_arg.function.name,
arguments=tool_arg.function.arguments,
),
)
for idx, tool_arg in accumulated_tool_args.items()
if tool_arg.function.name
]
if tool_calls_list:
result = self._handle_streaming_tool_calls(
tool_calls=tool_calls_list,
accumulated_tool_args=accumulated_tool_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params.get("messages"),
)
return full_response
except ContextWindowExceededError as e:
raise LLMContextLengthExceededError(str(e)) from e
except Exception:
if chunk_count == 0:
raise
if full_response:
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params.get("messages"),
)
return full_response
raise
def _handle_tool_call(
self,
tool_calls: list[Any],
@@ -1367,6 +1709,128 @@ class LLM(BaseLLM):
)
raise
async def acall(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async high-level LLM call method.
Args:
messages: Input messages for the LLM.
Can be a string or list of message dictionaries.
If string, it will be converted to a single user message.
If list, each dict must have 'role' and 'content' keys.
tools: Optional list of tool schemas for function calling.
Each tool should define its name, description, and parameters.
callbacks: Optional list of callback functions to be executed
during and after the LLM call.
available_functions: Optional dict mapping function names to callables
that can be invoked by the LLM.
from_task: Optional Task that invoked the LLM
from_agent: Optional Agent that invoked the LLM
response_model: Optional Model that contains a pydantic response model.
Returns:
Union[str, Any]: Either a text response from the LLM (str) or
the result of a tool function call (Any).
Raises:
TypeError: If messages format is invalid
ValueError: If response format is not supported
LLMContextLengthExceededError: If input exceeds model's context limit
"""
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
model=self.model,
),
)
self._validate_call_params()
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
if "o1" in self.model.lower():
for message in messages:
if message.get("role") == "system":
msg_role: Literal["assistant"] = "assistant"
message["role"] = msg_role
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
params = self._prepare_completion_params(messages, tools)
if self.stream:
return await self._ahandle_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
return await self._ahandle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
except LLMContextLengthExceededError:
raise
except Exception as e:
unsupported_stop = "Unsupported parameter" in str(
e
) and "'stop'" in str(e)
if unsupported_stop:
if (
"additional_drop_params" in self.additional_params
and isinstance(
self.additional_params["additional_drop_params"], list
)
):
self.additional_params["additional_drop_params"].append("stop")
else:
self.additional_params = {"additional_drop_params": ["stop"]}
logging.info("Retrying LLM call without the unsupported 'stop'")
return await self.acall(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise
def _handle_emit_call_events(
self,
response: Any,
@@ -1699,12 +2163,14 @@ class LLM(BaseLLM):
max_tokens=self.max_tokens,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
logit_bias=copy.deepcopy(self.logit_bias, memo)
if self.logit_bias
else None,
response_format=copy.deepcopy(self.response_format, memo)
if self.response_format
else None,
logit_bias=(
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
),
response_format=(
copy.deepcopy(self.response_format, memo)
if self.response_format
else None
),
seed=self.seed,
logprobs=self.logprobs,
top_logprobs=self.top_logprobs,

View File

@@ -158,6 +158,44 @@ class BaseLLM(ABC):
RuntimeError: If the LLM request fails for other reasons.
"""
async def acall(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call the LLM with the given messages.
Args:
messages: Input messages for the LLM.
Can be a string or list of message dictionaries.
If string, it will be converted to a single user message.
If list, each dict must have 'role' and 'content' keys.
tools: Optional list of tool schemas for function calling.
Each tool should define its name, description, and parameters.
callbacks: Optional list of callback functions to be executed
during and after the LLM call.
available_functions: Optional dict mapping function names to callables
that can be invoked by the LLM.
from_task: Optional task caller to be used for the LLM call.
from_agent: Optional agent caller to be used for the LLM call.
response_model: Optional response model to be used for the LLM call.
Returns:
Either a text response from the LLM (str) or
the result of a tool function call (Any).
Raises:
ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
"""
raise NotImplementedError
def _convert_tools_for_interference(
self, tools: list[dict[str, BaseTool]]
) -> list[dict[str, BaseTool]]:

View File

@@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [
AnthropicModels: TypeAlias = Literal[
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-3-7-sonnet-latest",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-latest",
@@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[
"claude-3-haiku-20240307",
]
ANTHROPIC_MODELS: list[AnthropicModels] = [
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-3-7-sonnet-latest",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-latest",
@@ -252,6 +256,7 @@ GeminiModels: TypeAlias = Literal[
"gemini-2.5-flash-preview-tts",
"gemini-2.5-pro-preview-tts",
"gemini-2.5-computer-use-preview-10-2025",
"gemini-2.5-pro-exp-03-25",
"gemini-2.0-flash",
"gemini-2.0-flash-001",
"gemini-2.0-flash-exp",
@@ -305,6 +310,7 @@ GEMINI_MODELS: list[GeminiModels] = [
"gemini-2.5-flash-preview-tts",
"gemini-2.5-pro-preview-tts",
"gemini-2.5-computer-use-preview-10-2025",
"gemini-2.5-pro-exp-03-25",
"gemini-2.0-flash",
"gemini-2.0-flash-001",
"gemini-2.0-flash-exp",
@@ -452,6 +458,7 @@ BedrockModels: TypeAlias = Literal[
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
"anthropic.claude-haiku-4-5-20251001-v1:0",
"anthropic.claude-instant-v1:2:100k",
"anthropic.claude-opus-4-5-20251101-v1:0",
"anthropic.claude-opus-4-1-20250805-v1:0",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",
@@ -524,6 +531,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
"anthropic.claude-haiku-4-5-20251001-v1:0",
"anthropic.claude-instant-v1:2:100k",
"anthropic.claude-opus-4-5-20251101-v1:0",
"anthropic.claude-opus-4-1-20250805-v1:0",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",

View File

@@ -9,7 +9,7 @@ from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor
try:
from anthropic import Anthropic
from anthropic import Anthropic, AsyncAnthropic
from anthropic.types import Message
from anthropic.types.tool_use_block import ToolUseBlock
import httpx
@@ -84,15 +84,20 @@ class AnthropicCompletion(BaseLLM):
self.client = Anthropic(**self._get_client_params())
async_client_params = self._get_client_params()
if self.interceptor:
async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
async_http_client = httpx.AsyncClient(transport=async_transport)
async_client_params["http_client"] = async_http_client
self.async_client = AsyncAnthropic(**async_client_params)
# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
self.stream = stream
self.stop_sequences = stop_sequences or []
# Model-specific settings
self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use
self.supports_tools = True
@property
def stop(self) -> list[str]:
@@ -213,6 +218,72 @@ class AnthropicCompletion(BaseLLM):
)
raise
async def acall(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async call to Anthropic messages API.
Args:
messages: Input messages for the chat completion
tools: List of tool/function definitions
callbacks: Callback functions (not used in native implementation)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Chat completion response or tool call result
"""
try:
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
formatted_messages, system_message = self._format_messages_for_anthropic(
messages
)
completion_params = self._prepare_completion_params(
formatted_messages, system_message, tools
)
if self.stream:
return await self._ahandle_streaming_completion(
completion_params,
available_functions,
from_task,
from_agent,
response_model,
)
return await self._ahandle_completion(
completion_params,
available_functions,
from_task,
from_agent,
response_model,
)
except Exception as e:
error_msg = f"Anthropic API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
def _prepare_completion_params(
self,
messages: list[LLMMessage],
@@ -546,7 +617,7 @@ class AnthropicCompletion(BaseLLM):
# Execute the tool
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args, # type: ignore
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
@@ -626,6 +697,275 @@ class AnthropicCompletion(BaseLLM):
return tool_results[0]["content"]
raise e
async def _ahandle_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming async message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try:
response: Message = await self.async_client.messages.create(**params)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
raise e from e
usage = self._extract_anthropic_token_usage(response)
self._track_token_usage_internal(usage)
if response_model and response.content:
tool_uses = [
block for block in response.content if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if response.content and available_functions:
tool_uses = [
block for block in response.content if isinstance(block, ToolUseBlock)
]
if tool_uses:
return await self._ahandle_tool_use_conversation(
response,
tool_uses,
params,
available_functions,
from_task,
from_agent,
)
content = ""
if response.content:
for content_block in response.content:
if hasattr(content_block, "text"):
content += content_block.text
content = self._apply_stop_words(content)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
if usage.get("total_tokens", 0) > 0:
logging.info(f"Anthropic API usage: {usage}")
return content
async def _ahandle_streaming_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle async streaming message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
full_response = ""
stream_params = {k: v for k, v in params.items() if k != "stream"}
async with self.async_client.messages.stream(**stream_params) as stream:
async for event in stream:
if hasattr(event, "delta") and hasattr(event.delta, "text"):
text_delta = event.delta.text
full_response += text_delta
self._emit_stream_chunk_event(
chunk=text_delta,
from_task=from_task,
from_agent=from_agent,
)
final_message: Message = await stream.get_final_message()
usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage)
if response_model and final_message.content:
tool_uses = [
block
for block in final_message.content
if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if final_message.content and available_functions:
tool_uses = [
block
for block in final_message.content
if isinstance(block, ToolUseBlock)
]
if tool_uses:
return await self._ahandle_tool_use_conversation(
final_message,
tool_uses,
params,
available_functions,
from_task,
from_agent,
)
full_response = self._apply_stop_words(full_response)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return full_response
async def _ahandle_tool_use_conversation(
self,
initial_response: Message,
tool_uses: list[ToolUseBlock],
params: dict[str, Any],
available_functions: dict[str, Any],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle the complete async tool use conversation flow.
This implements the proper Anthropic tool use pattern:
1. Claude requests tool use
2. We execute the tools
3. We send tool results back to Claude
4. Claude processes results and generates final response
"""
tool_results = []
for tool_use in tool_uses:
function_name = tool_use.name
function_args = tool_use.input
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
tool_result = {
"type": "tool_result",
"tool_use_id": tool_use.id,
"content": str(result)
if result is not None
else "Tool execution completed",
}
tool_results.append(tool_result)
follow_up_params = params.copy()
assistant_message = {"role": "assistant", "content": initial_response.content}
user_message = {"role": "user", "content": tool_results}
follow_up_params["messages"] = params["messages"] + [
assistant_message,
user_message,
]
try:
final_response: Message = await self.async_client.messages.create(
**follow_up_params
)
follow_up_usage = self._extract_anthropic_token_usage(final_response)
self._track_token_usage_internal(follow_up_usage)
final_content = ""
if final_response.content:
for content_block in final_response.content:
if hasattr(content_block, "text"):
final_content += content_block.text
final_content = self._apply_stop_words(final_content)
self._emit_call_completed_event(
response=final_content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=follow_up_params["messages"],
)
total_usage = {
"input_tokens": follow_up_usage.get("input_tokens", 0),
"output_tokens": follow_up_usage.get("output_tokens", 0),
"total_tokens": follow_up_usage.get("total_tokens", 0),
}
if total_usage.get("total_tokens", 0) > 0:
logging.info(f"Anthropic API tool conversation usage: {total_usage}")
return final_content
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded in tool follow-up: {e}")
raise LLMContextLengthExceededError(str(e)) from e
logging.error(f"Tool follow-up conversation failed: {e}")
if tool_results:
return tool_results[0]["content"]
raise e
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return self.supports_tools

View File

@@ -6,8 +6,10 @@ import os
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from typing_extensions import Self
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
@@ -23,9 +25,13 @@ try:
from azure.ai.inference import (
ChatCompletionsClient,
)
from azure.ai.inference.aio import (
ChatCompletionsClient as AsyncChatCompletionsClient,
)
from azure.ai.inference.models import (
ChatCompletions,
ChatCompletionsToolCall,
JsonSchemaFormat,
StreamingChatCompletionsUpdate,
)
from azure.core.credentials import (
@@ -133,6 +139,8 @@ class AzureCompletion(BaseLLM):
self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
@@ -256,6 +264,88 @@ class AzureCompletion(BaseLLM):
)
raise
async def acall(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call Azure AI Inference chat completions API asynchronously.
Args:
messages: Input messages for the chat completion
tools: List of tool/function definitions
callbacks: Callback functions (not used in native implementation)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Pydantic model for structured output
Returns:
Chat completion response or tool call result
"""
try:
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
formatted_messages = self._format_messages_for_azure(messages)
completion_params = self._prepare_completion_params(
formatted_messages, tools, response_model
)
if self.stream:
return await self._ahandle_streaming_completion(
completion_params,
available_functions,
from_task,
from_agent,
response_model,
)
return await self._ahandle_completion(
completion_params,
available_functions,
from_task,
from_agent,
response_model,
)
except HttpResponseError as e:
if e.status_code == 401:
error_msg = "Azure authentication failed. Check your API key."
elif e.status_code == 404:
error_msg = (
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
)
elif e.status_code == 429:
error_msg = "Azure API rate limit exceeded. Please retry later."
else:
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
except Exception as e:
error_msg = f"Azure API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
def _prepare_completion_params(
self,
messages: list[LLMMessage],
@@ -278,13 +368,16 @@ class AzureCompletion(BaseLLM):
}
if response_model and self.is_openai_model:
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": response_model.__name__,
"schema": response_model.model_json_schema(),
},
}
model_description = generate_model_description(response_model)
json_schema_info = model_description["json_schema"]
json_schema_name = json_schema_info["name"]
params["response_format"] = JsonSchemaFormat(
name=json_schema_name,
schema=json_schema_info["schema"],
description=f"Schema for {json_schema_name}",
strict=json_schema_info["strict"],
)
# Only include model parameter for non-Azure OpenAI endpoints
# Azure OpenAI endpoints have the deployment name in the URL
@@ -311,8 +404,8 @@ class AzureCompletion(BaseLLM):
params["tool_choice"] = "auto"
additional_params = self.additional_params
additional_drop_params = additional_params.get('additional_drop_params')
drop_params = additional_params.get('drop_params')
additional_drop_params = additional_params.get("additional_drop_params")
drop_params = additional_params.get("drop_params")
if drop_params and isinstance(additional_drop_params, list):
for drop_param in additional_drop_params:
@@ -551,6 +644,170 @@ class AzureCompletion(BaseLLM):
return full_response
async def _ahandle_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming chat completion asynchronously."""
try:
response: ChatCompletions = await self.async_client.complete(**params)
if not response.choices:
raise ValueError("No choices returned from Azure API")
choice = response.choices[0]
message = choice.message
usage = self._extract_azure_token_usage(response)
self._track_token_usage_internal(usage)
if response_model and self.is_openai_model:
content = message.content or ""
try:
structured_data = response_model.model_validate_json(content)
structured_json = structured_data.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
except Exception as e:
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)
raise ValueError(error_msg) from e
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0] # Handle first tool call
if isinstance(tool_call, ChatCompletionsToolCall):
function_name = tool_call.function.name
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
content = message.content or ""
content = self._apply_stop_words(content)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
error_msg = f"Azure API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise e
return content
async def _ahandle_streaming_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming chat completion asynchronously."""
full_response = ""
tool_calls = {}
stream = await self.async_client.complete(**params)
async for update in stream:
if isinstance(update, StreamingChatCompletionsUpdate):
if update.choices:
choice = update.choices[0]
if choice.delta and choice.delta.content:
content_delta = choice.delta.content
full_response += content_delta
self._emit_stream_chunk_event(
chunk=content_delta,
from_task=from_task,
from_agent=from_agent,
)
if choice.delta and choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
}
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += (
tool_call.function.arguments
)
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
try:
function_args = json.loads(call_data["arguments"])
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
continue
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
full_response = self._apply_stop_words(full_response)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
# Azure OpenAI models support function calling
@@ -604,3 +861,20 @@ class AzureCompletion(BaseLLM):
"total_tokens": getattr(usage, "total_tokens", 0),
}
return {"total_tokens": 0}
async def aclose(self) -> None:
"""Close the async client and clean up resources.
This ensures proper cleanup of the underlying aiohttp session
to avoid unclosed connector warnings.
"""
if hasattr(self.async_client, "close"):
await self.async_client.close()
async def __aenter__(self) -> Self:
"""Async context manager entry."""
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Async context manager exit."""
await self.aclose()

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from contextlib import AsyncExitStack
import json
import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict, cast
@@ -42,6 +44,16 @@ except ImportError:
'AWS Bedrock native provider not available, to install: uv add "crewai[bedrock]"'
) from None
try:
from aiobotocore.session import ( # type: ignore[import-untyped]
get_session as get_aiobotocore_session,
)
AIOBOTOCORE_AVAILABLE = True
except ImportError:
AIOBOTOCORE_AVAILABLE = False
get_aiobotocore_session = None
if TYPE_CHECKING:
@@ -221,6 +233,15 @@ class BedrockCompletion(BaseLLM):
self.client = session.client("bedrock-runtime", config=config)
self.region_name = region_name
self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
self.aws_secret_access_key = aws_secret_access_key or os.getenv(
"AWS_SECRET_ACCESS_KEY"
)
self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
self._async_client_initialized = False
# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
@@ -354,6 +375,110 @@ class BedrockCompletion(BaseLLM):
)
raise
async def acall(
self,
messages: str | list[LLMMessage],
tools: list[dict[Any, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async call to AWS Bedrock Converse API.
Args:
messages: Input messages as string or list of message dicts.
tools: Optional list of tool definitions.
callbacks: Optional list of callback handlers.
available_functions: Optional dict mapping function names to callables.
from_task: Optional task context for events.
from_agent: Optional agent context for events.
response_model: Optional Pydantic model for structured output.
Returns:
Generated text response or structured output.
Raises:
NotImplementedError: If aiobotocore is not installed.
LLMContextLengthExceededError: If context window is exceeded.
"""
if not AIOBOTOCORE_AVAILABLE:
raise NotImplementedError(
"Async support for AWS Bedrock requires aiobotocore. "
'Install with: uv add "crewai[bedrock-async]"'
)
try:
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
formatted_messages, system_message = self._format_messages_for_converse(
messages # type: ignore[arg-type]
)
body: BedrockConverseRequestBody = {
"inferenceConfig": self._get_inference_config(),
}
if system_message:
body["system"] = cast(
"list[SystemContentBlockTypeDef]",
cast(object, [{"text": system_message}]),
)
if tools:
tool_config: ToolConfigurationTypeDef = {
"tools": cast(
"Sequence[ToolTypeDef]",
cast(object, self._format_tools_for_converse(tools)),
)
}
body["toolConfig"] = tool_config
if self.guardrail_config:
guardrail_config: GuardrailConfigurationTypeDef = cast(
"GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
)
body["guardrailConfig"] = guardrail_config
if self.additional_model_request_fields:
body["additionalModelRequestFields"] = (
self.additional_model_request_fields
)
if self.additional_model_response_field_paths:
body["additionalModelResponseFieldPaths"] = (
self.additional_model_response_field_paths
)
if self.stream:
return await self._ahandle_streaming_converse(
formatted_messages, body, available_functions, from_task, from_agent
)
return await self._ahandle_converse(
formatted_messages, body, available_functions, from_task, from_agent
)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
error_msg = f"AWS Bedrock API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
def _handle_converse(
self,
messages: list[dict[str, Any]],
@@ -565,6 +690,341 @@ class BedrockCompletion(BaseLLM):
role = event["messageStart"].get("role")
logging.debug(f"Streaming message started with role: {role}")
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
if "toolUse" in start:
current_tool_use = start["toolUse"]
tool_use_id = current_tool_use.get("toolUseId")
logging.debug(
f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
)
elif "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
if "text" in delta:
text_chunk = delta["text"]
logging.debug(f"Streaming text chunk: {text_chunk[:50]}...")
full_response += text_chunk
self._emit_stream_chunk_event(
chunk=text_chunk,
from_task=from_task,
from_agent=from_agent,
)
elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "")
if tool_input:
logging.debug(f"Tool input delta: {tool_input}")
elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream")
if current_tool_use and available_functions:
function_name = current_tool_use["name"]
function_args = cast(
dict[str, Any], current_tool_use.get("input", {})
)
tool_result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if tool_result is not None and tool_use_id:
messages.append(
{
"role": "assistant",
"content": [{"toolUse": current_tool_use}],
}
)
messages.append(
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": tool_use_id,
"content": [
{"text": str(tool_result)}
],
}
}
],
}
)
return self._handle_converse(
messages,
body,
available_functions,
from_task,
from_agent,
)
current_tool_use = None
tool_use_id = None
elif "messageStop" in event:
stop_reason = event["messageStop"].get("stopReason")
logging.debug(f"Streaming message stopped: {stop_reason}")
if stop_reason == "max_tokens":
logging.warning(
"Streaming response truncated due to max_tokens"
)
elif stop_reason == "content_filtered":
logging.warning(
"Streaming response filtered due to content policy"
)
break
elif "metadata" in event:
metadata = event["metadata"]
if "usage" in metadata:
usage_metrics = metadata["usage"]
self._track_token_usage_internal(usage_metrics)
logging.debug(f"Token usage: {usage_metrics}")
if "trace" in metadata:
logging.debug(
f"Trace information available: {metadata['trace']}"
)
except ClientError as e:
error_msg = self._handle_client_error(e)
raise RuntimeError(error_msg) from e
except BotoCoreError as e:
error_msg = f"Bedrock streaming connection error: {e}"
logging.error(error_msg)
raise ConnectionError(error_msg) from e
full_response = self._apply_stop_words(full_response)
if not full_response or full_response.strip() == "":
logging.warning("Bedrock streaming returned empty content, using fallback")
full_response = (
"I apologize, but I couldn't generate a response. Please try again."
)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return full_response
async def _ensure_async_client(self) -> Any:
"""Ensure async client is initialized and return it."""
if not self._async_client_initialized and get_aiobotocore_session:
if self._async_exit_stack is None:
raise RuntimeError(
"Async exit stack not initialized - aiobotocore not available"
)
session = get_aiobotocore_session()
client = await self._async_exit_stack.enter_async_context(
session.create_client(
"bedrock-runtime",
region_name=self.region_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
)
)
self._async_client = client
self._async_client_initialized = True
return self._async_client
async def _ahandle_converse(
self,
messages: list[dict[str, Any]],
body: BedrockConverseRequestBody,
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle async non-streaming converse API call."""
try:
if not messages:
raise ValueError("Messages cannot be empty")
for i, msg in enumerate(messages):
if (
not isinstance(msg, dict)
or "role" not in msg
or "content" not in msg
):
raise ValueError(f"Invalid message format at index {i}")
async_client = await self._ensure_async_client()
response = await async_client.converse(
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body,
)
if "usage" in response:
self._track_token_usage_internal(response["usage"])
stop_reason = response.get("stopReason")
if stop_reason:
logging.debug(f"Response stop reason: {stop_reason}")
if stop_reason == "max_tokens":
logging.warning("Response truncated due to max_tokens limit")
elif stop_reason == "content_filtered":
logging.warning("Response was filtered due to content policy")
output = response.get("output", {})
message = output.get("message", {})
content = message.get("content", [])
if not content:
logging.warning("No content in Bedrock response")
return (
"I apologize, but I received an empty response. Please try again."
)
text_content = ""
for content_block in content:
if "text" in content_block:
text_content += content_block["text"]
elif "toolUse" in content_block and available_functions:
tool_use_block = content_block["toolUse"]
tool_use_id = tool_use_block.get("toolUseId")
function_name = tool_use_block["name"]
function_args = tool_use_block.get("input", {})
logging.debug(
f"Tool use requested: {function_name} with ID {tool_use_id}"
)
tool_result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=dict(available_functions),
from_task=from_task,
from_agent=from_agent,
)
if tool_result is not None:
messages.append(
{
"role": "assistant",
"content": [{"toolUse": tool_use_block}],
}
)
messages.append(
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": tool_use_id,
"content": [{"text": str(tool_result)}],
}
}
],
}
)
return await self._ahandle_converse(
messages, body, available_functions, from_task, from_agent
)
text_content = self._apply_stop_words(text_content)
if not text_content or text_content.strip() == "":
logging.warning("Extracted empty text content from Bedrock response")
text_content = "I apologize, but I couldn't generate a proper response. Please try again."
self._emit_call_completed_event(
response=text_content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return text_content
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "Unknown")
error_msg = e.response.get("Error", {}).get("Message", str(e))
logging.error(f"AWS Bedrock ClientError ({error_code}): {error_msg}")
if error_code == "ValidationException":
if "last turn" in error_msg and "user message" in error_msg:
raise ValueError(
f"Conversation format error: {error_msg}. Check message alternation."
) from e
raise ValueError(f"Request validation failed: {error_msg}") from e
if error_code == "AccessDeniedException":
raise PermissionError(
f"Access denied to model {self.model_id}: {error_msg}"
) from e
if error_code == "ResourceNotFoundException":
raise ValueError(f"Model {self.model_id} not found: {error_msg}") from e
if error_code == "ThrottlingException":
raise RuntimeError(
f"API throttled, please retry later: {error_msg}"
) from e
if error_code == "ModelTimeoutException":
raise TimeoutError(f"Model request timed out: {error_msg}") from e
if error_code == "ServiceQuotaExceededException":
raise RuntimeError(f"Service quota exceeded: {error_msg}") from e
if error_code == "ModelNotReadyException":
raise RuntimeError(
f"Model {self.model_id} not ready: {error_msg}"
) from e
if error_code == "ModelErrorException":
raise RuntimeError(f"Model error: {error_msg}") from e
if error_code == "InternalServerException":
raise RuntimeError(f"Internal server error: {error_msg}") from e
if error_code == "ServiceUnavailableException":
raise RuntimeError(f"Service unavailable: {error_msg}") from e
raise RuntimeError(f"Bedrock API error ({error_code}): {error_msg}") from e
except BotoCoreError as e:
error_msg = f"Bedrock connection error: {e}"
logging.error(error_msg)
raise ConnectionError(error_msg) from e
except Exception as e:
error_msg = f"Unexpected error in Bedrock converse call: {e}"
logging.error(error_msg)
raise RuntimeError(error_msg) from e
async def _ahandle_streaming_converse(
self,
messages: list[dict[str, Any]],
body: BedrockConverseRequestBody,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle async streaming converse API call."""
full_response = ""
current_tool_use = None
tool_use_id = None
try:
async_client = await self._ensure_async_client()
response = await async_client.converse_stream(
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body,
)
stream = response.get("stream")
if stream:
async for event in stream:
if "messageStart" in event:
role = event["messageStart"].get("role")
logging.debug(f"Streaming message started with role: {role}")
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
if "toolUse" in start:
@@ -590,17 +1050,14 @@ class BedrockCompletion(BaseLLM):
if tool_input:
logging.debug(f"Tool input delta: {tool_input}")
# Content block stop - end of a content block
elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream")
# If we were accumulating a tool use, it's now complete
if current_tool_use and available_functions:
function_name = current_tool_use["name"]
function_args = cast(
dict[str, Any], current_tool_use.get("input", {})
)
# Execute tool
tool_result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
@@ -610,7 +1067,6 @@ class BedrockCompletion(BaseLLM):
)
if tool_result is not None and tool_use_id:
# Continue conversation with tool result
messages.append(
{
"role": "assistant",
@@ -634,8 +1090,7 @@ class BedrockCompletion(BaseLLM):
}
)
# Recursive call - note this switches to non-streaming
return self._handle_converse(
return await self._ahandle_converse(
messages,
body,
available_functions,
@@ -643,10 +1098,9 @@ class BedrockCompletion(BaseLLM):
from_agent,
)
current_tool_use = None
tool_use_id = None
current_tool_use = None
tool_use_id = None
# Message stop - end of entire message
elif "messageStop" in event:
stop_reason = event["messageStop"].get("stopReason")
logging.debug(f"Streaming message stopped: {stop_reason}")
@@ -660,7 +1114,6 @@ class BedrockCompletion(BaseLLM):
)
break
# Metadata - contains usage information and trace details
elif "metadata" in event:
metadata = event["metadata"]
if "usage" in metadata:
@@ -680,17 +1133,14 @@ class BedrockCompletion(BaseLLM):
logging.error(error_msg)
raise ConnectionError(error_msg) from e
# Apply stop words to full response
full_response = self._apply_stop_words(full_response)
# Ensure we don't return empty content
if not full_response or full_response.strip() == "":
logging.warning("Bedrock streaming returned empty content, using fallback")
full_response = (
"I apologize, but I couldn't generate a response. Please try again."
)
# Emit completion event
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,

View File

@@ -1,13 +1,14 @@
from __future__ import annotations
import logging
import os
import re
from typing import Any, cast
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.base import BaseInterceptor
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
@@ -15,10 +16,15 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor
try:
from google import genai # type: ignore[import-untyped]
from google.genai import types # type: ignore[import-untyped]
from google.genai.errors import APIError # type: ignore[import-untyped]
from google import genai
from google.genai import types
from google.genai.errors import APIError
from google.genai.types import GenerateContentResponse, Schema
except ImportError:
raise ImportError(
'Google Gen AI native provider not available, to install: uv add "crewai[google-genai]"'
@@ -102,7 +108,9 @@ class GeminiCompletion(BaseLLM):
# Model-specific settings
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
self.supports_tools = bool(version_match and float(version_match.group(1)) >= 1.5)
self.supports_tools = bool(
version_match and float(version_match.group(1)) >= 1.5
)
@property
def stop(self) -> list[str]:
@@ -128,7 +136,7 @@ class GeminiCompletion(BaseLLM):
else:
self.stop_sequences = []
def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported]
def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
"""Initialize the Google Gen AI client with proper parameter handling.
Args:
@@ -277,7 +285,84 @@ class GeminiCompletion(BaseLLM):
)
raise
def _prepare_generation_config( # type: ignore[no-any-unimported]
async def acall(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async call to Google Gemini generate content API.
Args:
messages: Input messages for the chat completion
tools: List of tool/function definitions
callbacks: Callback functions (not used as token counts are handled by the response)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Chat completion response or tool call result
"""
try:
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
self.tools = tools
formatted_content, system_instruction = self._format_messages_for_gemini(
messages
)
config = self._prepare_generation_config(
system_instruction, tools, response_model
)
if self.stream:
return await self._ahandle_streaming_completion(
formatted_content,
config,
available_functions,
from_task,
from_agent,
response_model,
)
return await self._ahandle_completion(
formatted_content,
system_instruction,
config,
available_functions,
from_task,
from_agent,
response_model,
)
except APIError as e:
error_msg = f"Google Gemini API error: {e.code} - {e.message}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
except Exception as e:
error_msg = f"Google Gemini API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
def _prepare_generation_config(
self,
system_instruction: str | None = None,
tools: list[dict[str, Any]] | None = None,
@@ -294,7 +379,7 @@ class GeminiCompletion(BaseLLM):
GenerateContentConfig object for Gemini API
"""
self.tools = tools
config_params = {}
config_params: dict[str, Any] = {}
# Add system instruction if present
if system_instruction:
@@ -329,7 +414,7 @@ class GeminiCompletion(BaseLLM):
return types.GenerateContentConfig(**config_params)
def _convert_tools_for_interference( # type: ignore[no-any-unimported]
def _convert_tools_for_interference( # type: ignore[override]
self, tools: list[dict[str, Any]]
) -> list[types.Tool]:
"""Convert CrewAI tool format to Gemini function declaration format."""
@@ -346,7 +431,7 @@ class GeminiCompletion(BaseLLM):
)
# Add parameters if present - ensure parameters is a dict
if parameters and isinstance(parameters, dict):
if parameters and isinstance(parameters, Schema):
function_declaration.parameters = parameters
gemini_tool = types.Tool(function_declarations=[function_declaration])
@@ -354,7 +439,7 @@ class GeminiCompletion(BaseLLM):
return gemini_tools
def _format_messages_for_gemini( # type: ignore[no-any-unimported]
def _format_messages_for_gemini(
self, messages: str | list[LLMMessage]
) -> tuple[list[types.Content], str | None]:
"""Format messages for Gemini API.
@@ -373,32 +458,41 @@ class GeminiCompletion(BaseLLM):
# Use base class formatting first
base_formatted = super()._format_messages(messages)
contents = []
contents: list[types.Content] = []
system_instruction: str | None = None
for message in base_formatted:
role = message.get("role")
content = message.get("content", "")
role = message["role"]
content = message["content"]
# Convert content to string if it's a list
if isinstance(content, list):
text_content = " ".join(
str(item.get("text", "")) if isinstance(item, dict) else str(item)
for item in content
)
else:
text_content = str(content) if content else ""
if role == "system":
# Extract system instruction - Gemini handles it separately
if system_instruction:
system_instruction += f"\n\n{content}"
system_instruction += f"\n\n{text_content}"
else:
system_instruction = cast(str, content)
system_instruction = text_content
else:
# Convert role for Gemini (assistant -> model)
gemini_role = "model" if role == "assistant" else "user"
# Create Content object
gemini_content = types.Content(
role=gemini_role, parts=[types.Part.from_text(text=content)]
role=gemini_role, parts=[types.Part.from_text(text=text_content)]
)
contents.append(gemini_content)
return contents, system_instruction
def _handle_completion( # type: ignore[no-any-unimported]
def _handle_completion(
self,
contents: list[types.Content],
system_instruction: str | None,
@@ -409,14 +503,14 @@ class GeminiCompletion(BaseLLM):
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming content generation."""
api_params = {
"model": self.model,
"contents": contents,
"config": config,
}
try:
response = self.client.models.generate_content(**api_params)
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = self.client.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
)
usage = self._extract_token_usage(response)
except Exception as e:
@@ -433,6 +527,8 @@ class GeminiCompletion(BaseLLM):
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
function_name = part.function_call.name
if function_name is None:
continue
function_args = (
dict(part.function_call.args)
if part.function_call.args
@@ -442,7 +538,7 @@ class GeminiCompletion(BaseLLM):
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions, # type: ignore
available_functions=available_functions or {},
from_task=from_task,
from_agent=from_agent,
)
@@ -450,7 +546,7 @@ class GeminiCompletion(BaseLLM):
if result is not None:
return result
content = response.text if hasattr(response, "text") else ""
content = response.text or ""
content = self._apply_stop_words(content)
messages_for_event = self._convert_contents_to_dict(contents)
@@ -465,7 +561,7 @@ class GeminiCompletion(BaseLLM):
return content
def _handle_streaming_completion( # type: ignore[no-any-unimported]
def _handle_streaming_completion(
self,
contents: list[types.Content],
config: types.GenerateContentConfig,
@@ -476,16 +572,16 @@ class GeminiCompletion(BaseLLM):
) -> str:
"""Handle streaming content generation."""
full_response = ""
function_calls = {}
function_calls: dict[str, dict[str, Any]] = {}
api_params = {
"model": self.model,
"contents": contents,
"config": config,
}
for chunk in self.client.models.generate_content_stream(**api_params):
if hasattr(chunk, "text") and chunk.text:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
for chunk in self.client.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
):
if chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
@@ -493,7 +589,7 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent,
)
if hasattr(chunk, "candidates") and chunk.candidates:
if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
@@ -513,6 +609,14 @@ class GeminiCompletion(BaseLLM):
function_name = call_data["name"]
function_args = call_data["args"]
# Skip if function_name is None
if not isinstance(function_name, str):
continue
# Ensure function_args is a dict
if not isinstance(function_args, dict):
function_args = {}
# Execute tool
result = self._handle_tool_execution(
function_name=function_name,
@@ -537,6 +641,154 @@ class GeminiCompletion(BaseLLM):
return full_response
async def _ahandle_completion(
self,
contents: list[types.Content],
system_instruction: str | None,
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle async non-streaming content generation."""
try:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = await self.client.aio.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
)
usage = self._extract_token_usage(response)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
raise e from e
self._track_token_usage_internal(usage)
if response.candidates and (self.tools or available_functions):
candidate = response.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
function_name = part.function_call.name
if function_name is None:
continue
function_args = (
dict(part.function_call.args)
if part.function_call.args
else {}
)
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions or {},
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
content = response.text or ""
content = self._apply_stop_words(content)
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return content
async def _ahandle_streaming_completion(
self,
contents: list[types.Content],
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[str, dict[str, Any]] = {}
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
stream = await self.client.aio.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
)
async for chunk in stream:
if chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
)
if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
function_calls[call_id] = {
"name": part.function_call.name,
"args": dict(part.function_call.args)
if part.function_call.args
else {},
}
if function_calls and available_functions:
for call_data in function_calls.values():
function_name = call_data["name"]
function_args = call_data["args"]
# Skip if function_name is None
if not isinstance(function_name, str):
continue
# Ensure function_args is a dict
if not isinstance(function_args, dict):
function_args = {}
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return self.supports_tools
@@ -583,9 +835,10 @@ class GeminiCompletion(BaseLLM):
# Default context window size for Gemini models
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
@staticmethod
def _extract_token_usage(response: GenerateContentResponse) -> dict[str, Any]:
"""Extract token usage from Gemini response."""
if hasattr(response, "usage_metadata"):
if response.usage_metadata:
usage = response.usage_metadata
return {
"prompt_token_count": getattr(usage, "prompt_token_count", 0),
@@ -595,21 +848,23 @@ class GeminiCompletion(BaseLLM):
}
return {"total_tokens": 0}
def _convert_contents_to_dict( # type: ignore[no-any-unimported]
def _convert_contents_to_dict(
self,
contents: list[types.Content],
) -> list[dict[str, str]]:
"""Convert contents to dict format."""
return [
{
"role": "assistant"
if content_obj.role == "model"
else content_obj.role,
"content": " ".join(
part.text
for part in content_obj.parts
if hasattr(part, "text") and part.text
),
}
for content_obj in contents
]
result: list[dict[str, str]] = []
for content_obj in contents:
role = content_obj.role
if role == "model":
role = "assistant"
elif role is None:
role = "user"
parts = content_obj.parts or []
content = " ".join(
part.text for part in parts if hasattr(part, "text") and part.text
)
result.append({"role": role, "content": content})
return result

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
from collections.abc import Iterator
from collections.abc import AsyncIterator, Iterator
import json
import logging
import os
from typing import TYPE_CHECKING, Any
import httpx
from openai import APIConnectionError, NotFoundError, OpenAI
from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
@@ -15,7 +15,7 @@ from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -101,6 +101,14 @@ class OpenAICompletion(BaseLLM):
self.client = OpenAI(**client_config)
async_client_config = self._get_client_params()
if self.interceptor:
async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
async_http_client = httpx.AsyncClient(transport=async_transport)
async_client_config["http_client"] = async_http_client
self.async_client = AsyncOpenAI(**async_client_config)
# Completion parameters
self.top_p = top_p
self.frequency_penalty = frequency_penalty
@@ -210,6 +218,71 @@ class OpenAICompletion(BaseLLM):
)
raise
async def acall(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async call to OpenAI chat completion API.
Args:
messages: Input messages for the chat completion
tools: list of tool/function definitions
callbacks: Callback functions (not used in native implementation)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Response model for structured output.
Returns:
Chat completion response or tool call result
"""
try:
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
formatted_messages = self._format_messages(messages)
completion_params = self._prepare_completion_params(
messages=formatted_messages, tools=tools
)
if self.stream:
return await self._ahandle_streaming_completion(
params=completion_params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
return await self._ahandle_completion(
params=completion_params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
except Exception as e:
error_msg = f"OpenAI API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
def _prepare_completion_params(
self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
) -> dict[str, Any]:
@@ -352,10 +425,10 @@ class OpenAICompletion(BaseLLM):
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
function_name = tool_call.function.name # type: ignore[union-attr]
function_name = tool_call.function.name
try:
function_args = json.loads(tool_call.function.arguments) # type: ignore[union-attr]
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}
@@ -564,6 +637,266 @@ class OpenAICompletion(BaseLLM):
return full_response
async def _ahandle_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming async chat completion."""
try:
if response_model:
parse_params = {
k: v for k, v in params.items() if k != "response_format"
}
parsed_response = await self.async_client.beta.chat.completions.parse(
**parse_params,
response_format=response_model,
)
math_reasoning = parsed_response.choices[0].message
if math_reasoning.refusal:
pass
usage = self._extract_openai_token_usage(parsed_response)
self._track_token_usage_internal(usage)
parsed_object = parsed_response.choices[0].message.parsed
if parsed_object:
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
response: ChatCompletion = await self.async_client.chat.completions.create(
**params
)
usage = self._extract_openai_token_usage(response)
self._track_token_usage_internal(usage)
choice: Choice = response.choices[0]
message = choice.message
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
function_name = tool_call.function.name
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
content = message.content or ""
content = self._apply_stop_words(content)
if self.response_format and isinstance(self.response_format, type):
try:
structured_result = self._validate_structured_output(
content, self.response_format
)
self._emit_call_completed_event(
response=structured_result,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_result
except ValueError as e:
logging.warning(f"Structured output validation failed: {e}")
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
if usage.get("total_tokens", 0) > 0:
logging.info(f"OpenAI API usage: {usage}")
except NotFoundError as e:
error_msg = f"Model {self.model} not found: {e}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise ValueError(error_msg) from e
except APIConnectionError as e:
error_msg = f"Failed to connect to OpenAI API: {e}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise ConnectionError(error_msg) from e
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
error_msg = f"OpenAI API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise e from e
return content
async def _ahandle_streaming_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle async streaming chat completion."""
full_response = ""
tool_calls = {}
if response_model:
completion_stream: AsyncIterator[
ChatCompletionChunk
] = await self.async_client.chat.completions.create(**params)
accumulated_content = ""
async for chunk in completion_stream:
if not chunk.choices:
continue
choice = chunk.choices[0]
delta: ChoiceDelta = choice.delta
if delta.content:
accumulated_content += delta.content
self._emit_stream_chunk_event(
chunk=delta.content,
from_task=from_task,
from_agent=from_agent,
)
try:
parsed_object = response_model.model_validate_json(accumulated_content)
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
except Exception as e:
logging.error(f"Failed to parse structured output from stream: {e}")
self._emit_call_completed_event(
response=accumulated_content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return accumulated_content
stream: AsyncIterator[
ChatCompletionChunk
] = await self.async_client.chat.completions.create(**params)
async for chunk in stream:
if not chunk.choices:
continue
choice = chunk.choices[0]
chunk_delta: ChoiceDelta = choice.delta
if chunk_delta.content:
full_response += chunk_delta.content
self._emit_stream_chunk_event(
chunk=chunk_delta.content,
from_task=from_task,
from_agent=from_agent,
)
if chunk_delta.tool_calls:
for tool_call in chunk_delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
}
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
arguments = call_data["arguments"]
if not function_name or not arguments:
continue
if function_name not in available_functions:
logging.warning(
f"Function '{function_name}' not found in available functions"
)
continue
try:
function_args = json.loads(arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
continue
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
full_response = self._apply_stop_words(full_response)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return not self.is_o1_model

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from crewai.memory import (
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
class ContextualMemory:
"""Aggregates and retrieves context from multiple memory sources."""
def __init__(
self,
stm: ShortTermMemory,
@@ -46,9 +49,14 @@ class ContextualMemory:
self.exm.task = self.task
def build_context_for_task(self, task: Task, context: str) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
"""Build contextual information for a task synchronously.
Args:
task: The task to build context for.
context: Additional context string.
Returns:
Formatted context string from all memory sources.
"""
query = f"{task.description} {context}".strip()
@@ -63,6 +71,31 @@ class ContextualMemory:
]
return "\n".join(filter(None, context_parts))
async def abuild_context_for_task(self, task: Task, context: str) -> str:
"""Build contextual information for a task asynchronously.
Args:
task: The task to build context for.
context: Additional context string.
Returns:
Formatted context string from all memory sources.
"""
query = f"{task.description} {context}".strip()
if query == "":
return ""
# Fetch all contexts concurrently
results = await asyncio.gather(
self._afetch_ltm_context(task.description),
self._afetch_stm_context(query),
self._afetch_entity_context(query),
self._afetch_external_context(query),
)
return "\n".join(filter(None, results))
def _fetch_stm_context(self, query: str) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
@@ -135,3 +168,87 @@ class ContextualMemory:
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"
async def _afetch_stm_context(self, query: str) -> str:
"""Fetch recent relevant insights from STM asynchronously.
Args:
query: The search query.
Returns:
Formatted insights as bullet points, or empty string if none found.
"""
if self.stm is None:
return ""
stm_results = await self.stm.asearch(query)
formatted_results = "\n".join(
[f"- {result['content']}" for result in stm_results]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
async def _afetch_ltm_context(self, task: str) -> str | None:
"""Fetch historical data from LTM asynchronously.
Args:
task: The task description to search for.
Returns:
Formatted historical data as bullet points, or None if none found.
"""
if self.ltm is None:
return ""
ltm_results = await self.ltm.asearch(task, latest_n=2)
if not ltm_results:
return None
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
async def _afetch_entity_context(self, query: str) -> str:
"""Fetch relevant entity information asynchronously.
Args:
query: The search query.
Returns:
Formatted entity information as bullet points, or empty string if none found.
"""
if self.em is None:
return ""
em_results = await self.em.asearch(query)
formatted_results = "\n".join(
[f"- {result['content']}" for result in em_results]
)
return f"Entities:\n{formatted_results}" if em_results else ""
async def _afetch_external_context(self, query: str) -> str:
"""Fetch relevant information from External Memory asynchronously.
Args:
query: The search query.
Returns:
Formatted information as bullet points, or empty string if none found.
"""
if self.exm is None:
return ""
external_memories = await self.exm.asearch(query)
if not external_memories:
return ""
formatted_memories = "\n".join(
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"

View File

@@ -26,7 +26,13 @@ class EntityMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: str | None = None,
) -> None:
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
else:
storage = (
storage
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search entity memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
)
raise
async def asave(
self,
value: EntityMemoryItem | list[EntityMemoryItem],
metadata: dict[str, Any] | None = None,
) -> None:
"""Save entity items asynchronously.
Args:
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
metadata: Optional metadata dict (not used, for signature compatibility).
"""
if not value:
return
items = value if isinstance(value, list) else [value]
is_batch = len(items) > 1
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
metadata=metadata,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
saved_count = 0
errors: list[str | None] = []
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
"""Save a single item asynchronously."""
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
await super(EntityMemory, self).asave(data, item.metadata)
return True, None
except Exception as e:
return False, f"{item.name}: {e!s}"
try:
for item in items:
success, error = await save_single_item(item)
if success:
saved_count += 1
else:
errors.append(error)
if is_batch:
emit_value = f"Saved {saved_count} entities"
metadata = {"entity_count": saved_count, "errors": errors}
else:
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
metadata = items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=emit_value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
if errors:
raise Exception(
f"Partial save: {len(errors)} failed out of {len(items)}"
)
except Exception as e:
fail_metadata = (
{"entity_count": len(items), "saved": saved_count}
if is_batch
else items[0].metadata
)
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
metadata=fail_metadata,
error=str(e),
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search entity memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await super().asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="entity_memory",
),
)
raise
def reset(self) -> None:
try:
self.storage.reset()

View File

@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
@staticmethod
def external_supported_storages() -> dict[str, Any]:
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
if provider not in supported_storages:
raise ValueError(f"Provider {provider} not supported")
return supported_storages[provider](crew, embedder_config.get("config", {}))
storage: Storage = supported_storages[provider](
crew, embedder_config.get("config", {})
)
return storage
def save(
self,
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search external memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
)
raise
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to external memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ExternalMemoryItem(
value=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
await super().asave(value=item.value, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
error=str(e),
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search external memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await super().asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="external_memory",
),
)
raise
def reset(self) -> None:
self.storage.reset()

View File

@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""
def __init__(self, storage=None, path=None):
def __init__(
self,
storage: LTMSQLiteStorage | None = None,
path: str | None = None,
) -> None:
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
self.storage.save(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
)
raise
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
def search( # type: ignore[override]
self,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
) -> list[dict[str, Any]]:
"""Search long-term memory for relevant entries.
Args:
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):
start_time = time.time()
try:
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
results = self.storage.load(task, latest_n)
crewai_event_bus.emit(
self,
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
),
)
return results
return results or []
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=task,
limit=latest_n,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
"""Save an item to long-term memory asynchronously.
Args:
item: The LongTermMemoryItem to save.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
metadata = item.metadata
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
await self.storage.asave(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asearch( # type: ignore[override]
self,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]:
"""Search long-term memory asynchronously.
Args:
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=task,
limit=latest_n,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await self.storage.aload(task, latest_n)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=task,
results=results,
limit=latest_n,
query_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results or []
except Exception as e:
crewai_event_bus.emit(
self,
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
raise
def reset(self) -> None:
"""Reset long-term memory."""
self.storage.reset()

View File

@@ -13,9 +13,7 @@ if TYPE_CHECKING:
class Memory(BaseModel):
"""
Base class for memory, now supporting agent tags and generic metadata.
"""
"""Base class for memory, supporting agent tags and generic metadata."""
embedder_config: EmbedderConfig | dict[str, Any] | None = None
crew: Any | None = None
@@ -52,20 +50,72 @@ class Memory(BaseModel):
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
metadata = metadata or {}
"""Save a value to memory.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
metadata = metadata or {}
self.storage.save(value, metadata)
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
metadata = metadata or {}
await self.storage.asave(value, metadata)
def search(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
return self.storage.search(
"""Search memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
results: list[Any] = self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
)
return results
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search memory for relevant entries asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
results: list[Any] = await self.storage.asearch(
query=query, limit=limit, score_threshold=score_threshold
)
return results
def set_crew(self, crew: Any) -> Memory:
"""Set the crew for this memory instance."""
self.crew = crew
return self

View File

@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: str | None = None,
) -> None:
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
else:
storage = (
storage
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search short-term memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
try:
results = self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
)
crewai_event_bus.emit(
self,
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
),
)
return results
return list(results)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="short_term_memory",
),
)
raise
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to short-term memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ShortTermMemoryItem(
data=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
if self._memory_provider == "mem0":
item.data = (
f"Remember the following insights from Agent run: {item.data}"
)
await super().asave(value=item.data, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
error=str(e),
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search short-term memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await self.storage.asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return list(results)
except Exception as e:
crewai_event_bus.emit(
self,

View File

@@ -3,29 +3,30 @@ from pathlib import Path
import sqlite3
from typing import Any
import aiosqlite
from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
class LTMSQLiteStorage:
"""
An updated SQLite storage class for LTM data storage.
"""
"""SQLite storage class for long-term memory data."""
def __init__(self, db_path: str | None = None) -> None:
"""Initialize the SQLite storage.
Args:
db_path: Optional path to the database file.
"""
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
self.db_path = db_path
self._printer: Printer = Printer()
# Ensure parent directory exists
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._initialize_db()
def _initialize_db(self):
"""
Initializes the SQLite database and creates LTM table
"""
def _initialize_db(self) -> None:
"""Initialize the SQLite database and create LTM table."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
)
return None
def reset(
self,
) -> None:
def reset(self) -> None:
"""Resets the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return
async def asave(
self,
task_description: str,
metadata: dict[str, Any],
datetime: str,
score: int | float,
) -> None:
"""Save data to the LTM table asynchronously.
Args:
task_description: Description of the task.
metadata: Metadata associated with the memory.
datetime: Timestamp of the memory.
score: Quality score of the memory.
"""
try:
async with aiosqlite.connect(self.db_path) as conn:
await conn.execute(
"""
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
VALUES (?, ?, ?, ?)
""",
(task_description, json.dumps(metadata), datetime, score),
)
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
color="red",
)
async def aload(
self, task_description: str, latest_n: int
) -> list[dict[str, Any]] | None:
"""Query the LTM table by task description asynchronously.
Args:
task_description: Description of the task to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries or None if error occurs.
"""
try:
async with aiosqlite.connect(self.db_path) as conn:
cursor = await conn.execute(
f"""
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
ORDER BY datetime DESC, score ASC
LIMIT {latest_n}
""", # nosec # noqa: S608
(task_description,),
)
rows = await cursor.fetchall()
if rows:
return [
{
"metadata": json.loads(row[0]),
"datetime": row[1],
"score": row[2],
}
for row in rows
]
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
color="red",
)
return None
async def areset(self) -> None:
"""Reset the LTM table asynchronously."""
try:
async with aiosqlite.connect(self.db_path) as conn:
await conn.execute("DELETE FROM long_term_memories")
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)

View File

@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value to storage.
Args:
value: The value to save.
metadata: Metadata to associate with the value.
"""
try:
client = self._get_client()
collection_name = (
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value to storage asynchronously.
Args:
value: The value to save.
metadata: Metadata to associate with the value.
"""
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
await client.aget_or_create_collection(collection_name=collection_name)
document: BaseRecord = {"content": value}
if metadata:
document["metadata"] = metadata
batch_size = None
if (
self.embedder_config
and isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
await client.aadd_documents(
collection_name=collection_name,
documents=[document],
batch_size=cast(int, batch_size),
)
else:
await client.aadd_documents(
collection_name=collection_name, documents=[document]
)
except Exception as e:
logging.error(
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
)
def search(
self,
query: str,
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for matching entries in storage.
Args:
query: The search query.
limit: Maximum number of results to return.
filter: Optional metadata filter.
score_threshold: Minimum similarity score for results.
Returns:
List of matching entries.
"""
try:
client = self._get_client()
collection_name = (
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
)
return []
async def asearch(
self,
query: str,
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for matching entries in storage asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
filter: Optional metadata filter.
score_threshold: Minimum similarity score for results.
Returns:
List of matching entries.
"""
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
return await client.asearch(
collection_name=collection_name,
query=query,
limit=limit,
metadata_filter=filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None:
try:
client = self._get_client()

View File

@@ -2,8 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
from crewai.project.utils import memoize
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
return CacheHandlerMethod(memoize(meth))
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Call a method, awaiting it if async and running in an event loop."""
result = method(*args, **kwargs)
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result
@overload
def crew(
meth: Callable[Concatenate[SelfT, P], Crew],
@@ -198,7 +217,7 @@ def crew(
# Instantiate tasks in order
for _, task_method in tasks:
task_instance = task_method(self)
task_instance = _call_method(task_method, self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
if agent_instance and agent_instance.role not in agent_roles:
@@ -207,7 +226,7 @@ def crew(
# Instantiate agents not included by tasks
for _, agent_method in agents:
agent_instance = agent_method(self)
agent_instance = _call_method(agent_method, self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
agent_roles.add(agent_instance.role)
@@ -215,7 +234,7 @@ def crew(
self.agents = instantiated_agents
self.tasks = instantiated_tasks
crew_instance = meth(self, *args, **kwargs)
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
def callback_wrapper(
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance

View File

@@ -1,7 +1,8 @@
"""Utility functions for the crewai project module."""
from collections.abc import Callable
from collections.abc import Callable, Coroutine
from functools import wraps
import inspect
from typing import Any, ParamSpec, TypeVar, cast
from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a method by caching its results based on arguments.
Handles Pydantic BaseModel instances by converting them to JSON strings
before hashing for cache lookup.
Handles both sync and async methods. Pydantic BaseModel instances are
converted to JSON strings before hashing for cache lookup.
Args:
meth: The method to memoize.
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
Returns:
A memoized version of the method that caches results.
"""
if inspect.iscoroutinefunction(meth):
return cast(Callable[P, R], _memoize_async(meth))
return _memoize_sync(meth)
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a synchronous method."""
@wraps(meth)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
"""Wrapper that converts arguments to hashable form before caching.
Args:
*args: Positional arguments to the memoized method.
**kwargs: Keyword arguments to the memoized method.
Returns:
The result of the memoized method call.
"""
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
return result
return cast(Callable[P, R], wrapper)
def _memoize_async(
meth: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
"""Memoize an async method."""
@wraps(meth)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
)
cache_key = str((hashable_args, hashable_kwargs))
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
if cached_result is not None:
return cached_result
result = await meth(*args, **kwargs)
cache.add(tool=meth.__name__, input=cache_key, output=result)
return result
return wrapper

View File

@@ -2,8 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import partial
import inspect
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
crew: Callable[..., Crew]
def _resolve_result(result: Any) -> Any:
"""Resolve a potentially async result to its value."""
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result
class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata.
@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
"""
if obj is None:
return self
bound = partial(self._meth, obj)
inner = partial(self._meth, obj)
def _bound(*args: Any, **kwargs: Any) -> R:
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
return result
for attr in (
"is_agent",
"is_llm",
@@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]):
"is_crew",
):
if hasattr(self, attr):
setattr(bound, attr, getattr(self, attr))
return bound
setattr(_bound, attr, getattr(self, attr))
return _bound
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the wrapped method.
@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
The task result with name ensured.
"""
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
result = _resolve_result(result)
return self._task_method.ensure_task_name(result)
@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
Returns:
The task instance with name set if not already provided.
"""
return self.ensure_task_name(self._meth(*args, **kwargs))
result = self._meth(*args, **kwargs)
result = _resolve_result(result)
return self.ensure_task_name(result)
def unwrap(self) -> Callable[P, TaskResultT]:
"""Get the original unwrapped method.

View File

@@ -497,6 +497,107 @@ class Task(BaseModel):
result = self._execute_core(agent, context, tools)
future.set_result(result)
async def aexecute_sync(
self,
agent: BaseAgent | None = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> TaskOutput:
"""Execute the task asynchronously using native async/await."""
return await self._aexecute_core(agent, context, tools)
async def _aexecute_core(
self,
agent: BaseAgent | None,
context: str | None,
tools: list[Any] | None,
) -> TaskOutput:
"""Run the core execution logic of the task asynchronously."""
try:
agent = agent or self.agent
self.agent = agent
if not agent:
raise Exception(
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
)
self.start_time = datetime.datetime.now()
self.prompt_context = context
tools = tools or self.tools or []
self.processed_by_agents.add(agent.role)
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
result = await agent.aexecute_task(
task=self,
context=context,
tools=tools,
)
if not self._guardrails and not self._guardrail:
pydantic_output, json_output = self._export_output(result)
else:
pydantic_output, json_output = None, None
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
if self._guardrails:
for idx, guardrail in enumerate(self._guardrails):
task_output = await self._ainvoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=guardrail,
guardrail_index=idx,
)
if self._guardrail:
task_output = await self._ainvoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=self._guardrail,
)
self.output = task_output
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
if self.output_file:
content = (
json_output
if json_output
else (
pydantic_output.model_dump_json() if pydantic_output else result
)
)
self._save_file(content)
crewai_event_bus.emit(
self,
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
)
return task_output
except Exception as e:
self.end_time = datetime.datetime.now()
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
raise e # Re-raise the exception after emitting the event
def _execute_core(
self,
agent: BaseAgent | None,
@@ -539,7 +640,7 @@ class Task(BaseModel):
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages,
messages=agent.last_messages, # type: ignore[attr-defined]
)
if self._guardrails:
@@ -950,7 +1051,103 @@ Follow these guidelines:
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages,
messages=agent.last_messages, # type: ignore[attr-defined]
)
return task_output
async def _ainvoke_guardrail_function(
self,
task_output: TaskOutput,
agent: BaseAgent,
tools: list[BaseTool],
guardrail: GuardrailCallable | None,
guardrail_index: int | None = None,
) -> TaskOutput:
"""Invoke the guardrail function asynchronously."""
if not guardrail:
return task_output
if guardrail_index is not None:
current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
else:
current_retry_count = self.retry_count
max_attempts = self.guardrail_max_retries + 1
for attempt in range(max_attempts):
guardrail_result = process_guardrail(
output=task_output,
guardrail=guardrail,
retry_count=current_retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if guardrail_result.success:
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
return task_output
if attempt >= self.guardrail_max_retries:
guardrail_name = (
f"guardrail {guardrail_index}"
if guardrail_index is not None
else "guardrail"
)
raise Exception(
f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
if guardrail_index is not None:
current_retry_count += 1
self._guardrail_retry_counts[guardrail_index] = current_retry_count
else:
self.retry_count += 1
current_retry_count = self.retry_count
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
color="yellow",
)
result = await agent.aexecute_task(
task=self,
context=context,
tools=tools,
)
pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
return task_output

View File

@@ -9,12 +9,14 @@ data is collected. Users can opt-in to share more complete data using the
from __future__ import annotations
import asyncio
import atexit
from collections.abc import Callable
from importlib.metadata import version
import json
import logging
import os
import platform
import signal
import threading
from typing import TYPE_CHECKING, Any
@@ -31,6 +33,14 @@ from opentelemetry.sdk.trace.export import (
from opentelemetry.trace import Span
from typing_extensions import Self
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.system_events import (
SigContEvent,
SigHupEvent,
SigIntEvent,
SigTStpEvent,
SigTermEvent,
)
from crewai.telemetry.constants import (
CREWAI_TELEMETRY_BASE_URL,
CREWAI_TELEMETRY_SERVICE_NAME,
@@ -121,6 +131,7 @@ class Telemetry:
)
self.provider.add_span_processor(processor)
self._register_shutdown_handlers()
self.ready = True
except Exception as e:
if isinstance(
@@ -155,6 +166,71 @@ class Telemetry:
self.ready = False
self.trace_set = False
def _register_shutdown_handlers(self) -> None:
"""Register handlers for graceful shutdown on process exit and signals."""
atexit.register(self._shutdown)
self._original_handlers: dict[int, Any] = {}
self._register_signal_handler(signal.SIGTERM, SigTermEvent, shutdown=True)
self._register_signal_handler(signal.SIGINT, SigIntEvent, shutdown=True)
self._register_signal_handler(signal.SIGHUP, SigHupEvent, shutdown=False)
self._register_signal_handler(signal.SIGTSTP, SigTStpEvent, shutdown=False)
self._register_signal_handler(signal.SIGCONT, SigContEvent, shutdown=False)
def _register_signal_handler(
self,
sig: signal.Signals,
event_class: type,
shutdown: bool = False,
) -> None:
"""Register a signal handler that emits an event.
Args:
sig: The signal to handle.
event_class: The event class to instantiate and emit.
shutdown: Whether to trigger shutdown on this signal.
"""
try:
original_handler = signal.getsignal(sig)
self._original_handlers[sig] = original_handler
def handler(signum: int, frame: Any) -> None:
crewai_event_bus.emit(self, event_class())
if shutdown:
self._shutdown()
if original_handler not in (signal.SIG_DFL, signal.SIG_IGN, None):
if callable(original_handler):
original_handler(signum, frame)
elif shutdown:
raise SystemExit(0)
signal.signal(sig, handler)
except ValueError as e:
logger.warning(
f"Cannot register {sig.name} handler: not running in main thread",
exc_info=e,
)
except OSError as e:
logger.warning(f"Cannot register {sig.name} handler: {e}", exc_info=e)
def _shutdown(self) -> None:
"""Flush and shutdown the telemetry provider on process exit.
Uses a short timeout to avoid blocking process shutdown.
"""
if not self.ready:
return
try:
self.provider.force_flush(timeout_millis=5000)
self.provider.shutdown()
self.ready = False
except Exception as e:
logger.debug(f"Telemetry shutdown failed: {e}")
def _safe_telemetry_operation(
self, operation: Callable[[], Span | None]
) -> Span | None:

View File

@@ -2,9 +2,18 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from inspect import signature
from typing import Any, cast, get_args, get_origin
from typing import (
Any,
Generic,
ParamSpec,
TypeVar,
cast,
get_args,
get_origin,
overload,
)
from pydantic import (
BaseModel,
@@ -14,6 +23,7 @@ from pydantic import (
create_model,
field_validator,
)
from typing_extensions import TypeIs
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.printer import Printer
@@ -21,6 +31,19 @@ from crewai.utilities.printer import Printer
_printer = Printer()
P = ParamSpec("P")
R = TypeVar("R", covariant=True)
def _is_async_callable(func: Callable[..., Any]) -> bool:
"""Check if a callable is async."""
return asyncio.iscoroutinefunction(func)
def _is_awaitable(value: R | Awaitable[R]) -> TypeIs[Awaitable[R]]:
"""Type narrowing check for awaitable values."""
return asyncio.iscoroutine(value) or asyncio.isfuture(value)
class EnvVar(BaseModel):
name: str
@@ -55,7 +78,7 @@ class BaseTool(BaseModel, ABC):
default=False, description="Flag to check if the description has been updated."
)
cache_function: Callable = Field(
cache_function: Callable[..., bool] = Field(
default=lambda _args=None, _result=None: True,
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
)
@@ -123,6 +146,35 @@ class BaseTool(BaseModel, ABC):
return result
async def arun(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Execute the tool asynchronously.
Args:
*args: Positional arguments to pass to the tool.
**kwargs: Keyword arguments to pass to the tool.
Returns:
The result of the tool execution.
"""
result = await self._arun(*args, **kwargs)
self.current_usage_count += 1
return result
async def _arun(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Async implementation of the tool. Override for async support."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement _arun. "
"Override _arun for async support or use run() for sync execution."
)
def reset_usage_count(self) -> None:
"""Reset the current usage count to zero."""
self.current_usage_count = 0
@@ -133,7 +185,17 @@ class BaseTool(BaseModel, ABC):
*args: Any,
**kwargs: Any,
) -> Any:
"""Here goes the actual implementation of the tool."""
"""Sync implementation of the tool.
Subclasses must implement this method for synchronous execution.
Args:
*args: Positional arguments for the tool.
**kwargs: Keyword arguments for the tool.
Returns:
The result of the tool execution.
"""
def to_structured_tool(self) -> CrewStructuredTool:
"""Convert this tool to a CrewStructuredTool instance."""
@@ -239,21 +301,90 @@ class BaseTool(BaseModel, ABC):
if args:
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
return str(f"{origin.__name__}[{args_str}]")
return origin.__name__
return str(origin.__name__)
class Tool(BaseTool):
"""The function that will be executed when the tool is called."""
class Tool(BaseTool, Generic[P, R]):
"""Tool that wraps a callable function.
func: Callable
def _run(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)
Type Parameters:
P: ParamSpec capturing the function's parameters.
R: The return type of the function.
"""
func: Callable[P, R | Awaitable[R]]
def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the tool synchronously.
Args:
*args: Positional arguments for the tool.
**kwargs: Keyword arguments for the tool.
Returns:
The result of the tool execution.
"""
_printer.print(f"Using Tool: {self.name}", color="cyan")
result = self.func(*args, **kwargs)
if asyncio.iscoroutine(result):
result = asyncio.run(result)
self.current_usage_count += 1
return result # type: ignore[return-value]
def _run(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the wrapped function.
Args:
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
Returns:
The result of the function execution.
"""
return self.func(*args, **kwargs) # type: ignore[return-value]
async def arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the tool asynchronously.
Args:
*args: Positional arguments for the tool.
**kwargs: Keyword arguments for the tool.
Returns:
The result of the tool execution.
"""
result = await self._arun(*args, **kwargs)
self.current_usage_count += 1
return result
async def _arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the wrapped function asynchronously.
Args:
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
Returns:
The result of the async function execution.
Raises:
NotImplementedError: If the wrapped function is not async.
"""
result = self.func(*args, **kwargs)
if _is_awaitable(result):
return await result
raise NotImplementedError(
f"{self.name} does not have an async function. "
"Use run() for sync execution or provide an async function."
)
@classmethod
def from_langchain(cls, tool: Any) -> Tool:
def from_langchain(cls, tool: Any) -> Tool[..., Any]:
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
@@ -261,10 +392,10 @@ class Tool(BaseTool):
attribute and infers the argument schema if not explicitly provided.
Args:
tool (Any): The CrewStructuredTool object to be converted.
tool: The CrewStructuredTool object to be converted.
Returns:
Tool: A new Tool instance created from the provided CrewStructuredTool.
A new Tool instance created from the provided CrewStructuredTool.
Raises:
ValueError: If the provided tool does not have a callable 'func' attribute.
@@ -308,37 +439,83 @@ class Tool(BaseTool):
def to_langchain(
tools: list[BaseTool | CrewStructuredTool],
) -> list[CrewStructuredTool]:
"""Convert a list of tools to CrewStructuredTool instances."""
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
P2 = ParamSpec("P2")
R2 = TypeVar("R2")
@overload
def tool(func: Callable[P2, R2], /) -> Tool[P2, R2]: ...
@overload
def tool(
*args, result_as_answer: bool = False, max_usage_count: int | None = None
) -> Callable:
"""
Decorator to create a tool from a function.
name: str,
/,
*,
result_as_answer: bool = ...,
max_usage_count: int | None = ...,
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
@overload
def tool(
*,
result_as_answer: bool = ...,
max_usage_count: int | None = ...,
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
def tool(
*args: Callable[P2, R2] | str,
result_as_answer: bool = False,
max_usage_count: int | None = None,
) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]:
"""Decorator to create a Tool from a function.
Can be used in three ways:
1. @tool - decorator without arguments, uses function name
2. @tool("name") - decorator with custom name
3. @tool(result_as_answer=True) - decorator with options
Args:
*args: Positional arguments, either the function to decorate or the tool name.
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
*args: Either the function to decorate or a custom tool name.
result_as_answer: If True, the tool result becomes the final agent answer.
max_usage_count: Maximum times this tool can be used. None means unlimited.
Returns:
A Tool instance.
Example:
@tool
def greet(name: str) -> str:
'''Greet someone.'''
return f"Hello, {name}!"
result = greet.run("World")
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(f: Callable) -> BaseTool:
def _make_with_name(tool_name: str) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]:
def _make_tool(f: Callable[P2, R2]) -> Tool[P2, R2]:
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
if f.__annotations__ is None:
func_annotations = getattr(f, "__annotations__", None)
if func_annotations is None:
raise ValueError("Function must have type annotations")
class_name = "".join(tool_name.split()).title()
args_schema = cast(
tool_args_schema = cast(
type[PydanticBaseModel],
type(
class_name,
(PydanticBaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != "return"
k: v for k, v in func_annotations.items() if k != "return"
},
},
),
@@ -348,10 +525,9 @@ def tool(
name=tool_name,
description=f.__doc__,
func=f,
args_schema=args_schema,
args_schema=tool_args_schema,
result_as_answer=result_as_answer,
max_usage_count=max_usage_count,
current_usage_count=0,
)
return _make_tool
@@ -360,4 +536,10 @@ def tool(
return _make_with_name(args[0].__name__)(args[0])
if len(args) == 1 and isinstance(args[0], str):
return _make_with_name(args[0])
if len(args) == 0:
def decorator(f: Callable[P2, R2]) -> Tool[P2, R2]:
return _make_with_name(f.__name__)(f)
return decorator
raise ValueError("Invalid arguments")

View File

@@ -160,6 +160,251 @@ class ToolUsage:
return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}"
async def ause(
self, calling: ToolCalling | InstructorToolCalling, tool_string: str
) -> str:
"""Execute a tool asynchronously.
Args:
calling: The tool calling information.
tool_string: The raw tool string from the agent.
Returns:
The result of the tool execution as a string.
"""
if isinstance(calling, ToolUsageError):
error = calling.message
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
if self.task:
self.task.increment_tools_errors()
return error
try:
tool = self._select_tool(calling.tool_name)
except Exception as e:
error = getattr(e, "message", str(e))
if self.task:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
return error
if (
isinstance(tool, CrewStructuredTool)
and tool.name == self._i18n.tools("add_image")["name"] # type: ignore
):
try:
return await self._ause(
tool_string=tool_string, tool=tool, calling=calling
)
except Exception as e:
error = getattr(e, "message", str(e))
if self.task:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
return error
return (
f"{await self._ause(tool_string=tool_string, tool=tool, calling=calling)}"
)
async def _ause(
self,
tool_string: str,
tool: CrewStructuredTool,
calling: ToolCalling | InstructorToolCalling,
) -> str:
"""Internal async tool execution implementation.
Args:
tool_string: The raw tool string from the agent.
tool: The tool to execute.
calling: The tool calling information.
Returns:
The result of the tool execution as a string.
"""
if self._check_tool_repeated_usage(calling=calling):
try:
result = self._i18n.errors("task_repeated_usage").format(
tool_names=self.tools_names
)
self._telemetry.tool_repeated_usage(
llm=self.function_calling_llm,
tool_name=tool.name,
attempts=self._run_attempts,
)
return self._format_result(result=result)
except Exception:
if self.task:
self.task.increment_tools_errors()
if self.agent:
event_data = {
"agent_key": self.agent.key,
"agent_role": self.agent.role,
"tool_name": self.action.tool,
"tool_args": self.action.tool_input,
"tool_class": self.action.tool,
"agent": self.agent,
}
if self.agent.fingerprint: # type: ignore
event_data.update(self.agent.fingerprint) # type: ignore
if self.task:
event_data["task_name"] = self.task.name or self.task.description
event_data["task_id"] = str(self.task.id)
crewai_event_bus.emit(self, ToolUsageStartedEvent(**event_data))
started_at = time.time()
from_cache = False
result = None # type: ignore
if self.tools_handler and self.tools_handler.cache:
input_str = ""
if calling.arguments:
if isinstance(calling.arguments, dict):
input_str = json.dumps(calling.arguments)
else:
input_str = str(calling.arguments)
result = self.tools_handler.cache.read(
tool=calling.tool_name, input=input_str
) # type: ignore
from_cache = result is not None
available_tool = next(
(
available_tool
for available_tool in self.tools
if available_tool.name == tool.name
),
None,
)
usage_limit_error = self._check_usage_limit(available_tool, tool.name)
if usage_limit_error:
try:
result = usage_limit_error
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
return self._format_result(result=result)
except Exception:
if self.task:
self.task.increment_tools_errors()
if result is None:
try:
if calling.tool_name in [
"Delegate work to coworker",
"Ask question to coworker",
]:
coworker = (
calling.arguments.get("coworker") if calling.arguments else None
)
if self.task:
self.task.increment_delegations(coworker)
if calling.arguments:
try:
acceptable_args = tool.args_schema.model_json_schema()[
"properties"
].keys()
arguments = {
k: v
for k, v in calling.arguments.items()
if k in acceptable_args
}
arguments = self._add_fingerprint_metadata(arguments)
result = await tool.ainvoke(input=arguments)
except Exception:
arguments = calling.arguments
arguments = self._add_fingerprint_metadata(arguments)
result = await tool.ainvoke(input=arguments)
else:
arguments = self._add_fingerprint_metadata({})
result = await tool.ainvoke(input=arguments)
except Exception as e:
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
self._run_attempts += 1
if self._run_attempts > self._max_parsing_attempts:
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
error_message = self._i18n.errors("tool_usage_exception").format(
error=e, tool=tool.name, tool_inputs=tool.description
)
error = ToolUsageError(
f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
).message
if self.task:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(
content=f"\n\n{error_message}\n", color="red"
)
return error
if self.task:
self.task.increment_tools_errors()
return await self.ause(calling=calling, tool_string=tool_string)
if self.tools_handler:
should_cache = True
if (
hasattr(available_tool, "cache_function")
and available_tool.cache_function
):
should_cache = available_tool.cache_function(
calling.arguments, result
)
self.tools_handler.on_tool_use(
calling=calling, output=result, should_cache=should_cache
)
self._telemetry.tool_usage(
llm=self.function_calling_llm,
tool_name=tool.name,
attempts=self._run_attempts,
)
result = self._format_result(result=result)
data = {
"result": result,
"tool_name": tool.name,
"tool_args": calling.arguments,
}
self.on_tool_use_finished(
tool=tool,
tool_calling=calling,
from_cache=from_cache,
started_at=started_at,
result=result,
)
if (
hasattr(available_tool, "result_as_answer")
and available_tool.result_as_answer # type: ignore
):
result_as_answer = available_tool.result_as_answer # type: ignore
data["result_as_answer"] = result_as_answer # type: ignore
if self.agent and hasattr(self.agent, "tools_results"):
self.agent.tools_results.append(data)
if available_tool and hasattr(available_tool, "current_usage_count"):
available_tool.current_usage_count += 1
if (
hasattr(available_tool, "max_usage_count")
and available_tool.max_usage_count is not None
):
self._printer.print(
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
color="blue",
)
return result
def _use(
self,
tool_string: str,

View File

@@ -242,17 +242,17 @@ def get_llm_response(
"""Call the LLM and return the response, handling any invalid responses.
Args:
llm: The LLM instance to call
messages: The messages to send to the LLM
callbacks: List of callbacks for the LLM call
printer: Printer instance for output
from_task: Optional task context for the LLM call
from_agent: Optional agent context for the LLM call
response_model: Optional Pydantic model for structured outputs
executor_context: Optional executor context for hook invocation
llm: The LLM instance to call.
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
@@ -284,6 +284,60 @@ def get_llm_response(
return _setup_after_llm_call_hooks(executor_context, answer, printer)
async def aget_llm_response(
llm: LLM | BaseLLM,
messages: list[LLMMessage],
callbacks: list[TokenCalcHandler],
printer: Printer,
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | None = None,
) -> str:
"""Call the LLM asynchronously and return the response.
Args:
llm: The LLM instance to call.
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
ValueError: If the response is None or empty.
"""
if executor_context is not None:
if not _setup_before_llm_call_hooks(executor_context, printer):
raise ValueError("LLM call blocked by before_llm_call hook")
messages = executor_context.messages
try:
answer = await llm.acall(
messages,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
except Exception as e:
raise e
if not answer:
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
raise ValueError("Invalid response from LLM call - None or empty.")
return _setup_after_llm_call_hooks(executor_context, answer, printer)
def process_llm_response(
answer: str, use_stop_words: bool
) -> AgentAction | AgentFinish:

View File

@@ -26,6 +26,138 @@ if TYPE_CHECKING:
from crewai.task import Task
async def aexecute_tool_and_check_finality(
agent_action: AgentAction,
tools: list[CrewStructuredTool],
i18n: I18N,
agent_key: str | None = None,
agent_role: str | None = None,
tools_handler: ToolsHandler | None = None,
task: Task | None = None,
agent: Agent | BaseAgent | None = None,
function_calling_llm: BaseLLM | LLM | None = None,
fingerprint_context: dict[str, str] | None = None,
crew: Crew | None = None,
) -> ToolResult:
"""Execute a tool asynchronously and check if the result should be a final answer.
This is the async version of execute_tool_and_check_finality. It integrates tool
hooks for before and after tool execution, allowing programmatic interception
and modification of tool calls.
Args:
agent_action: The action containing the tool to execute.
tools: List of available tools.
i18n: Internationalization settings.
agent_key: Optional key for event emission.
agent_role: Optional role for event emission.
tools_handler: Optional tools handler for tool execution.
task: Optional task for tool execution.
agent: Optional agent instance for tool execution.
function_calling_llm: Optional LLM for function calling.
fingerprint_context: Optional context for fingerprinting.
crew: Optional crew instance for hook context.
Returns:
ToolResult containing the execution result and whether it should be
treated as a final answer.
"""
logger = Logger(verbose=crew.verbose if crew else False)
tool_name_to_tool_map = {tool.name: tool for tool in tools}
if agent_key and agent_role and agent:
fingerprint_context = fingerprint_context or {}
if agent:
if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint):
if isinstance(fingerprint_context, dict):
try:
fingerprint_obj = Fingerprint.from_dict(fingerprint_context)
agent.set_fingerprint(fingerprint=fingerprint_obj)
except Exception as e:
raise ValueError(f"Failed to set fingerprint: {e}") from e
tool_usage = ToolUsage(
tools_handler=tools_handler,
tools=tools,
function_calling_llm=function_calling_llm, # type: ignore[arg-type]
task=task,
agent=agent,
action=agent_action,
)
tool_calling = tool_usage.parse_tool_calling(agent_action.text)
if isinstance(tool_calling, ToolUsageError):
return ToolResult(tool_calling.message, False)
if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in tool_name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in tool_name_to_tool_map
]:
tool = tool_name_to_tool_map.get(tool_calling.tool_name)
if not tool:
tool_result = i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([t.name.casefold() for t in tools]),
)
return ToolResult(result=tool_result, result_as_answer=False)
tool_input = tool_calling.arguments if tool_calling.arguments else {}
hook_context = ToolCallHookContext(
tool_name=tool_calling.tool_name,
tool_input=tool_input,
tool=tool,
agent=agent,
task=task,
crew=crew,
)
before_hooks = get_before_tool_call_hooks()
try:
for hook in before_hooks:
result = hook(hook_context)
if result is False:
blocked_message = (
f"Tool execution blocked by hook. "
f"Tool: {tool_calling.tool_name}"
)
return ToolResult(blocked_message, False)
except Exception as e:
logger.log("error", f"Error in before_tool_call hook: {e}")
tool_result = await tool_usage.ause(tool_calling, agent_action.text)
after_hook_context = ToolCallHookContext(
tool_name=tool_calling.tool_name,
tool_input=tool_input,
tool=tool,
agent=agent,
task=task,
crew=crew,
tool_result=tool_result,
)
after_hooks = get_after_tool_call_hooks()
modified_result: str = tool_result
try:
for after_hook in after_hooks:
hook_result = after_hook(after_hook_context)
if hook_result is not None:
modified_result = hook_result
after_hook_context.tool_result = modified_result
except Exception as e:
logger.log("error", f"Error in after_tool_call hook: {e}")
return ToolResult(modified_result, tool.result_as_answer)
tool_result = i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in tools]),
)
return ToolResult(result=tool_result, result_as_answer=False)
def execute_tool_and_check_finality(
agent_action: AgentAction,
tools: list[CrewStructuredTool],
@@ -141,10 +273,10 @@ def execute_tool_and_check_finality(
# Execute after_tool_call hooks
after_hooks = get_after_tool_call_hooks()
modified_result = tool_result
modified_result: str = tool_result
try:
for hook in after_hooks:
hook_result = hook(after_hook_context)
for after_hook in after_hooks:
hook_result = after_hook(after_hook_context)
if hook_result is not None:
modified_result = hook_result
after_hook_context.tool_result = modified_result

View File

@@ -51,6 +51,15 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
# Dummy implementation for MCP tools
return []
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[Any] | None = None,
) -> str:
# Dummy async implementation
return "Task executed"
def test_base_agent_adapter_initialization():
"""Test initialization of the concrete agent adapter."""

View File

@@ -25,6 +25,14 @@ class MockAgent(BaseAgent):
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
return []
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
return ""
def get_output_converter(
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
): ...

View File

@@ -147,7 +147,7 @@ def test_custom_llm():
assert agent.llm.model == "gpt-4"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execution():
agent = Agent(
role="test role",
@@ -166,7 +166,7 @@ def test_agent_execution():
assert output == "1 + 1 is 2"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execution_with_tools():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -211,7 +211,7 @@ def test_agent_execution_with_tools():
assert received_events[0].tool_args == {"first_number": 3, "second_number": 4}
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_logging_tool_usage():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -245,7 +245,7 @@ def test_logging_tool_usage():
assert agent.tools_handler.last_used_tool.arguments == tool_usage.arguments
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_cache_hitting():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -325,7 +325,7 @@ def test_cache_hitting():
assert received_events[0].output == "12"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_disabling_cache_for_agent():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -389,7 +389,7 @@ def test_disabling_cache_for_agent():
read.assert_not_called()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execution_with_specific_tools():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -412,7 +412,7 @@ def test_agent_execution_with_specific_tools():
assert output == "The result of the multiplication is 12."
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -438,7 +438,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
assert output == "12"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_powered_by_new_o_model_family_that_uses_tool():
@tool
def comapny_customer_data() -> str:
@@ -464,7 +464,7 @@ def test_agent_powered_by_new_o_model_family_that_uses_tool():
assert output == "42"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_custom_max_iterations():
@tool
def get_final_answer() -> float:
@@ -509,7 +509,7 @@ def test_agent_custom_max_iterations():
assert call_count == 2
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
@pytest.mark.timeout(30)
def test_agent_max_iterations_stops_loop():
"""Test that agent execution terminates when max_iter is reached."""
@@ -546,7 +546,7 @@ def test_agent_max_iterations_stops_loop():
)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_repeated_tool_usage(capsys):
"""Test that agents handle repeated tool usage appropriately.
@@ -595,7 +595,7 @@ def test_agent_repeated_tool_usage(capsys):
)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
@tool
def get_final_answer(anything: str) -> float:
@@ -638,7 +638,7 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_moved_on_after_max_iterations():
@tool
def get_final_answer() -> float:
@@ -665,7 +665,7 @@ def test_agent_moved_on_after_max_iterations():
assert output == "42"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_respect_the_max_rpm_set(capsys):
@tool
def get_final_answer() -> float:
@@ -699,7 +699,7 @@ def test_agent_respect_the_max_rpm_set(capsys):
moveon.assert_called()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys):
from unittest.mock import patch
@@ -737,7 +737,7 @@ def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys):
moveon.assert_not_called()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_without_max_rpm_respects_crew_rpm(capsys):
from unittest.mock import patch
@@ -797,7 +797,7 @@ def test_agent_without_max_rpm_respects_crew_rpm(capsys):
moveon.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_error_on_parsing_tool(capsys):
from unittest.mock import patch
@@ -840,7 +840,7 @@ def test_agent_error_on_parsing_tool(capsys):
assert "Error on parsing tool." in captured.out
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_remembers_output_format_after_using_tools_too_many_times():
from unittest.mock import patch
@@ -875,7 +875,7 @@ def test_agent_remembers_output_format_after_using_tools_too_many_times():
remember_format.assert_called()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_use_specific_tasks_output_as_context(capsys):
agent1 = Agent(role="test role", goal="test goal", backstory="test backstory")
agent2 = Agent(role="test role2", goal="test goal2", backstory="test backstory2")
@@ -902,7 +902,7 @@ def test_agent_use_specific_tasks_output_as_context(capsys):
assert "hi" in result.raw.lower() or "hello" in result.raw.lower()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_step_callback():
class StepCallback:
def callback(self, step):
@@ -936,7 +936,7 @@ def test_agent_step_callback():
callback.assert_called()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_function_calling_llm():
from crewai.llm import LLM
llm = LLM(model="gpt-4o", is_litellm=True)
@@ -983,7 +983,7 @@ def test_agent_function_calling_llm():
mock_original_tool_calling.assert_called()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_tool_result_as_answer_is_the_final_answer_for_the_agent():
from crewai.tools import BaseTool
@@ -1013,7 +1013,7 @@ def test_tool_result_as_answer_is_the_final_answer_for_the_agent():
assert result.raw == "Howdy!"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_tool_usage_information_is_appended_to_agent():
from crewai.tools import BaseTool
@@ -1068,7 +1068,7 @@ def test_agent_definition_based_on_dict():
# test for human input
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_human_input():
# Agent configuration
config = {
@@ -1216,7 +1216,7 @@ Thought:<|eot_id|>
assert mock_format_prompt.return_value == expected_prompt
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_task_allow_crewai_trigger_context():
from crewai import Crew
@@ -1237,7 +1237,7 @@ def test_task_allow_crewai_trigger_context():
assert "Trigger Payload: Important context data" in prompt
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_task_without_allow_crewai_trigger_context():
from crewai import Crew
@@ -1260,7 +1260,7 @@ def test_task_without_allow_crewai_trigger_context():
assert "Important context data" not in prompt
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_task_allow_crewai_trigger_context_no_payload():
from crewai import Crew
@@ -1282,7 +1282,7 @@ def test_task_allow_crewai_trigger_context_no_payload():
assert "Trigger Payload:" not in prompt
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
from crewai import Crew
@@ -1311,7 +1311,7 @@ def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
assert "Trigger Payload: Initial context data" not in first_prompt
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_first_task_auto_inject_trigger():
from crewai import Crew
@@ -1344,7 +1344,7 @@ def test_first_task_auto_inject_trigger():
assert "Trigger Payload:" not in second_prompt
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_ensure_first_task_allow_crewai_trigger_context_is_false_does_not_inject():
from crewai import Crew
@@ -1549,7 +1549,7 @@ def test_agent_with_additional_kwargs():
assert agent.llm.frequency_penalty == 0.1
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_llm_call():
llm = LLM(model="gpt-3.5-turbo")
messages = [{"role": "user", "content": "Say 'Hello, World!'"}]
@@ -1558,7 +1558,7 @@ def test_llm_call():
assert "Hello, World!" in response
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_llm_call_with_error():
llm = LLM(model="non-existent-model")
messages = [{"role": "user", "content": "This should fail"}]
@@ -1567,7 +1567,7 @@ def test_llm_call_with_error():
llm.call(messages)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_handle_context_length_exceeds_limit():
# Import necessary modules
from crewai.utilities.agent_utils import handle_context_length
@@ -1620,7 +1620,7 @@ def test_handle_context_length_exceeds_limit():
mock_summarize.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_handle_context_length_exceeds_limit_cli_no():
agent = Agent(
role="test role",
@@ -1695,7 +1695,7 @@ def test_agent_with_all_llm_attributes():
assert agent.llm.api_key == "sk-your-api-key-here"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_llm_call_with_all_attributes():
llm = LLM(
model="gpt-3.5-turbo",
@@ -1712,7 +1712,7 @@ def test_llm_call_with_all_attributes():
assert "STOP" not in response
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_ollama_llama3():
agent = Agent(
role="test role",
@@ -1733,7 +1733,7 @@ def test_agent_with_ollama_llama3():
assert "Llama3" in response or "AI" in response or "language model" in response
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_llm_call_with_ollama_llama3():
llm = LLM(
model="ollama/llama3.2:3b",
@@ -1752,7 +1752,7 @@ def test_llm_call_with_ollama_llama3():
assert "Llama3" in response or "AI" in response or "language model" in response
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_basic():
agent = Agent(
role="test role",
@@ -1771,7 +1771,7 @@ def test_agent_execute_task_basic():
assert "4" in result
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_with_context():
agent = Agent(
role="test role",
@@ -1793,7 +1793,7 @@ def test_agent_execute_task_with_context():
assert "fox" in result.lower() and "dog" in result.lower()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_with_tool():
@tool
def dummy_tool(query: str) -> str:
@@ -1818,7 +1818,7 @@ def test_agent_execute_task_with_tool():
assert "Dummy result for: test query" in result
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_with_custom_llm():
agent = Agent(
role="test role",
@@ -1839,7 +1839,7 @@ def test_agent_execute_task_with_custom_llm():
)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_with_ollama():
agent = Agent(
role="test role",
@@ -1859,7 +1859,7 @@ def test_agent_execute_task_with_ollama():
assert "AI" in result or "artificial intelligence" in result.lower()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1891,7 +1891,7 @@ def test_agent_with_knowledge_sources():
assert "red" in result.raw.lower()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1939,7 +1939,7 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_default():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1988,7 +1988,7 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_extensive_role():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -2024,7 +2024,7 @@ def test_agent_with_knowledge_sources_extensive_role():
assert "red" in result.raw.lower()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_works_with_copy():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -2063,7 +2063,7 @@ def test_agent_with_knowledge_sources_works_with_copy():
assert isinstance(agent_copy.llm, BaseLLM)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_generate_search_query():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -2116,7 +2116,7 @@ def test_agent_with_knowledge_sources_generate_search_query():
assert "red" in result.raw.lower()
@pytest.mark.vcr(record_mode="none", filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_with_no_crewai_knowledge():
mock_knowledge = MagicMock(spec=Knowledge)
@@ -2143,7 +2143,7 @@ def test_agent_with_knowledge_with_no_crewai_knowledge():
mock_knowledge.query.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_only_crewai_knowledge():
mock_knowledge = MagicMock(spec=Knowledge)
@@ -2168,7 +2168,7 @@ def test_agent_with_only_crewai_knowledge():
mock_knowledge.query.assert_called_once()
@pytest.mark.vcr(record_mode="none", filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_knowledege_with_crewai_knowledge():
crew_knowledge = MagicMock(spec=Knowledge)
agent_knowledge = MagicMock(spec=Knowledge)
@@ -2197,7 +2197,7 @@ def test_agent_knowledege_with_crewai_knowledge():
crew_knowledge.query.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_litellm_auth_error_handling():
"""Test that LiteLLM authentication errors are handled correctly and not retried."""
from litellm import AuthenticationError as LiteLLMAuthenticationError
@@ -2326,7 +2326,7 @@ def test_litellm_anthropic_error_handling():
mock_llm_call.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_get_knowledge_search_query():
"""Test that _get_knowledge_search_query calls the LLM with the correct prompts."""
from crewai.utilities.i18n import I18N

View File

@@ -0,0 +1,345 @@
"""Tests for async agent executor functionality."""
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.agents.parser import AgentAction, AgentFinish
from crewai.tools.tool_types import ToolResult
@pytest.fixture
def mock_llm() -> MagicMock:
"""Create a mock LLM for testing."""
llm = MagicMock()
llm.supports_stop_words.return_value = True
llm.stop = []
return llm
@pytest.fixture
def mock_agent() -> MagicMock:
"""Create a mock agent for testing."""
agent = MagicMock()
agent.role = "Test Agent"
agent.key = "test_agent_key"
agent.verbose = False
agent.id = "test_agent_id"
return agent
@pytest.fixture
def mock_task() -> MagicMock:
"""Create a mock task for testing."""
task = MagicMock()
task.description = "Test task description"
return task
@pytest.fixture
def mock_crew() -> MagicMock:
"""Create a mock crew for testing."""
crew = MagicMock()
crew.verbose = False
crew._train = False
return crew
@pytest.fixture
def mock_tools_handler() -> MagicMock:
"""Create a mock tools handler."""
return MagicMock()
@pytest.fixture
def executor(
mock_llm: MagicMock,
mock_agent: MagicMock,
mock_task: MagicMock,
mock_crew: MagicMock,
mock_tools_handler: MagicMock,
) -> CrewAgentExecutor:
"""Create a CrewAgentExecutor instance for testing."""
return CrewAgentExecutor(
llm=mock_llm,
task=mock_task,
crew=mock_crew,
agent=mock_agent,
prompt={"prompt": "Test prompt {input} {tool_names} {tools}"},
max_iter=5,
tools=[],
tools_names="",
stop_words=["Observation:"],
tools_description="",
tools_handler=mock_tools_handler,
)
class TestAsyncAgentExecutor:
"""Tests for async agent executor methods."""
@pytest.mark.asyncio
async def test_ainvoke_returns_output(self, executor: CrewAgentExecutor) -> None:
"""Test that ainvoke returns the expected output."""
expected_output = "Final answer from agent"
with patch.object(
executor,
"_ainvoke_loop",
new_callable=AsyncMock,
return_value=AgentFinish(
thought="Done", output=expected_output, text="Final Answer: Done"
),
):
with patch.object(executor, "_show_start_logs"):
with patch.object(executor, "_create_short_term_memory"):
with patch.object(executor, "_create_long_term_memory"):
with patch.object(executor, "_create_external_memory"):
result = await executor.ainvoke(
{
"input": "test input",
"tool_names": "",
"tools": "",
}
)
assert result == {"output": expected_output}
@pytest.mark.asyncio
async def test_ainvoke_loop_calls_aget_llm_response(
self, executor: CrewAgentExecutor
) -> None:
"""Test that _ainvoke_loop calls aget_llm_response."""
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
return_value="Thought: I know the answer\nFinal Answer: Test result",
) as mock_aget_llm:
with patch.object(executor, "_show_logs"):
result = await executor._ainvoke_loop()
mock_aget_llm.assert_called_once()
assert isinstance(result, AgentFinish)
@pytest.mark.asyncio
async def test_ainvoke_loop_handles_tool_execution(
self,
executor: CrewAgentExecutor,
) -> None:
"""Test that _ainvoke_loop handles tool execution asynchronously."""
call_count = 0
async def mock_llm_response(*args: Any, **kwargs: Any) -> str:
nonlocal call_count
call_count += 1
if call_count == 1:
return (
"Thought: I need to use a tool\n"
"Action: test_tool\n"
'Action Input: {"arg": "value"}'
)
return "Thought: I have the answer\nFinal Answer: Tool result processed"
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
side_effect=mock_llm_response,
):
with patch(
"crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
new_callable=AsyncMock,
return_value=ToolResult(result="Tool executed", result_as_answer=False),
) as mock_tool_exec:
with patch.object(executor, "_show_logs"):
with patch.object(executor, "_handle_agent_action") as mock_handle:
mock_handle.return_value = AgentAction(
text="Tool result",
tool="test_tool",
tool_input='{"arg": "value"}',
thought="Used tool",
result="Tool executed",
)
result = await executor._ainvoke_loop()
assert mock_tool_exec.called
assert isinstance(result, AgentFinish)
@pytest.mark.asyncio
async def test_ainvoke_loop_respects_max_iterations(
self, executor: CrewAgentExecutor
) -> None:
"""Test that _ainvoke_loop respects max iterations."""
executor.max_iter = 2
async def always_return_action(*args: Any, **kwargs: Any) -> str:
return (
"Thought: I need to think more\n"
"Action: some_tool\n"
"Action Input: {}"
)
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
side_effect=always_return_action,
):
with patch(
"crewai.agents.crew_agent_executor.aexecute_tool_and_check_finality",
new_callable=AsyncMock,
return_value=ToolResult(result="Tool result", result_as_answer=False),
):
with patch(
"crewai.agents.crew_agent_executor.handle_max_iterations_exceeded",
return_value=AgentFinish(
thought="Max iterations",
output="Forced answer",
text="Max iterations reached",
),
) as mock_max_iter:
with patch.object(executor, "_show_logs"):
with patch.object(executor, "_handle_agent_action") as mock_ha:
mock_ha.return_value = AgentAction(
text="Action",
tool="some_tool",
tool_input="{}",
thought="Thinking",
)
result = await executor._ainvoke_loop()
mock_max_iter.assert_called_once()
assert isinstance(result, AgentFinish)
@pytest.mark.asyncio
async def test_ainvoke_handles_exceptions(
self, executor: CrewAgentExecutor
) -> None:
"""Test that ainvoke properly propagates exceptions."""
with patch.object(executor, "_show_start_logs"):
with patch.object(
executor,
"_ainvoke_loop",
new_callable=AsyncMock,
side_effect=ValueError("Test error"),
):
with pytest.raises(ValueError, match="Test error"):
await executor.ainvoke(
{"input": "test", "tool_names": "", "tools": ""}
)
@pytest.mark.asyncio
async def test_concurrent_ainvoke_calls(
self, mock_llm: MagicMock, mock_agent: MagicMock, mock_task: MagicMock,
mock_crew: MagicMock, mock_tools_handler: MagicMock
) -> None:
"""Test that multiple ainvoke calls can run concurrently."""
async def create_and_run_executor(executor_id: int) -> dict[str, Any]:
executor = CrewAgentExecutor(
llm=mock_llm,
task=mock_task,
crew=mock_crew,
agent=mock_agent,
prompt={"prompt": "Test {input} {tool_names} {tools}"},
max_iter=5,
tools=[],
tools_names="",
stop_words=["Observation:"],
tools_description="",
tools_handler=mock_tools_handler,
)
async def delayed_response(*args: Any, **kwargs: Any) -> str:
await asyncio.sleep(0.05)
return f"Thought: Done\nFinal Answer: Result from executor {executor_id}"
with patch(
"crewai.agents.crew_agent_executor.aget_llm_response",
new_callable=AsyncMock,
side_effect=delayed_response,
):
with patch.object(executor, "_show_start_logs"):
with patch.object(executor, "_show_logs"):
with patch.object(executor, "_create_short_term_memory"):
with patch.object(executor, "_create_long_term_memory"):
with patch.object(executor, "_create_external_memory"):
return await executor.ainvoke(
{
"input": f"test {executor_id}",
"tool_names": "",
"tools": "",
}
)
import time
start = time.time()
results = await asyncio.gather(
create_and_run_executor(1),
create_and_run_executor(2),
create_and_run_executor(3),
)
elapsed = time.time() - start
assert len(results) == 3
assert all("output" in r for r in results)
assert elapsed < 0.15, f"Expected concurrent execution, took {elapsed}s"
class TestAsyncLLMResponseHelper:
"""Tests for aget_llm_response helper function."""
@pytest.mark.asyncio
async def test_aget_llm_response_calls_acall(self) -> None:
"""Test that aget_llm_response calls llm.acall."""
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer
mock_llm = MagicMock()
mock_llm.acall = AsyncMock(return_value="LLM response")
result = await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=Printer(),
)
mock_llm.acall.assert_called_once()
assert result == "LLM response"
@pytest.mark.asyncio
async def test_aget_llm_response_raises_on_empty_response(self) -> None:
"""Test that aget_llm_response raises ValueError on empty response."""
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer
mock_llm = MagicMock()
mock_llm.acall = AsyncMock(return_value="")
with pytest.raises(ValueError, match="Invalid response from LLM call"):
await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=Printer(),
)
@pytest.mark.asyncio
async def test_aget_llm_response_propagates_exceptions(self) -> None:
"""Test that aget_llm_response propagates LLM exceptions."""
from crewai.utilities.agent_utils import aget_llm_response
from crewai.utilities.printer import Printer
mock_llm = MagicMock()
mock_llm.acall = AsyncMock(side_effect=RuntimeError("LLM error"))
with pytest.raises(RuntimeError, match="LLM error"):
await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=Printer(),
)

View File

@@ -70,7 +70,7 @@ class ResearchResult(BaseModel):
sources: list[str] = Field(description="List of sources used")
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
@pytest.mark.parametrize("verbose", [True, False])
def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose):
"""Test that LiteAgent is created with the correct parameters when Agent.kickoff() is called."""
@@ -130,7 +130,7 @@ def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose):
assert created_lite_agent["response_format"] == TestResponse
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_lite_agent_with_tools():
"""Test that Agent can use tools."""
# Create a LiteAgent with tools
@@ -174,7 +174,7 @@ def test_lite_agent_with_tools():
assert event.tool_name == "search_web"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_lite_agent_structured_output():
"""Test that Agent can return a simple structured output."""
@@ -217,7 +217,7 @@ def test_lite_agent_structured_output():
return result
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_lite_agent_returns_usage_metrics():
"""Test that LiteAgent returns usage metrics."""
llm = LLM(model="gpt-4o-mini")
@@ -238,7 +238,7 @@ def test_lite_agent_returns_usage_metrics():
assert result.usage_metrics["total_tokens"] > 0
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_lite_agent_output_includes_messages():
"""Test that LiteAgentOutput includes messages from agent execution."""
llm = LLM(model="gpt-4o-mini")
@@ -259,7 +259,7 @@ def test_lite_agent_output_includes_messages():
assert len(result.messages) > 0
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_lite_agent_returns_usage_metrics_async():
"""Test that LiteAgent returns usage metrics when run asynchronously."""
@@ -354,9 +354,9 @@ def test_sets_parent_flow_when_inside_flow():
assert captured_agent.parent_flow is flow
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_guardrail_is_called_using_string():
guardrail_events = defaultdict(list)
guardrail_events: dict[str, list] = defaultdict(list)
from crewai.events.event_types import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
@@ -369,35 +369,33 @@ def test_guardrail_is_called_using_string():
guardrail="""Only include Brazilian players, both women and men""",
)
all_events_received = threading.Event()
condition = threading.Condition()
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event):
assert isinstance(source, LiteAgent)
assert source.original_agent == agent
guardrail_events["started"].append(event)
if (
len(guardrail_events["started"]) == 2
and len(guardrail_events["completed"]) == 2
):
all_events_received.set()
with condition:
guardrail_events["started"].append(event)
condition.notify()
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event):
assert isinstance(source, LiteAgent)
assert source.original_agent == agent
guardrail_events["completed"].append(event)
if (
len(guardrail_events["started"]) == 2
and len(guardrail_events["completed"]) == 2
):
all_events_received.set()
with condition:
guardrail_events["completed"].append(event)
condition.notify()
result = agent.kickoff(messages="Top 10 best players in the world?")
assert all_events_received.wait(timeout=10), (
"Timeout waiting for all guardrail events"
)
with condition:
success = condition.wait_for(
lambda: len(guardrail_events["started"]) >= 2
and len(guardrail_events["completed"]) >= 2,
timeout=10,
)
assert success, "Timeout waiting for all guardrail events"
assert len(guardrail_events["started"]) == 2
assert len(guardrail_events["completed"]) == 2
assert not guardrail_events["completed"][0].success
@@ -408,33 +406,27 @@ def test_guardrail_is_called_using_string():
)
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_guardrail_is_called_using_callable():
guardrail_events = defaultdict(list)
guardrail_events: dict[str, list] = defaultdict(list)
from crewai.events.event_types import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
all_events_received = threading.Event()
condition = threading.Condition()
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event):
guardrail_events["started"].append(event)
if (
len(guardrail_events["started"]) == 1
and len(guardrail_events["completed"]) == 1
):
all_events_received.set()
with condition:
guardrail_events["started"].append(event)
condition.notify()
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event):
guardrail_events["completed"].append(event)
if (
len(guardrail_events["started"]) == 1
and len(guardrail_events["completed"]) == 1
):
all_events_received.set()
with condition:
guardrail_events["completed"].append(event)
condition.notify()
agent = Agent(
role="Sports Analyst",
@@ -445,42 +437,40 @@ def test_guardrail_is_called_using_callable():
result = agent.kickoff(messages="Top 1 best players in the world?")
assert all_events_received.wait(timeout=10), (
"Timeout waiting for all guardrail events"
)
with condition:
success = condition.wait_for(
lambda: len(guardrail_events["started"]) >= 1
and len(guardrail_events["completed"]) >= 1,
timeout=10,
)
assert success, "Timeout waiting for all guardrail events"
assert len(guardrail_events["started"]) == 1
assert len(guardrail_events["completed"]) == 1
assert guardrail_events["completed"][0].success
assert "Pelé - Santos, 1958" in result.raw
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_guardrail_reached_attempt_limit():
guardrail_events = defaultdict(list)
guardrail_events: dict[str, list] = defaultdict(list)
from crewai.events.event_types import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
all_events_received = threading.Event()
condition = threading.Condition()
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def capture_guardrail_started(source, event):
guardrail_events["started"].append(event)
if (
len(guardrail_events["started"]) == 3
and len(guardrail_events["completed"]) == 3
):
all_events_received.set()
with condition:
guardrail_events["started"].append(event)
condition.notify()
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def capture_guardrail_completed(source, event):
guardrail_events["completed"].append(event)
if (
len(guardrail_events["started"]) == 3
and len(guardrail_events["completed"]) == 3
):
all_events_received.set()
with condition:
guardrail_events["completed"].append(event)
condition.notify()
agent = Agent(
role="Sports Analyst",
@@ -498,9 +488,13 @@ def test_guardrail_reached_attempt_limit():
):
agent.kickoff(messages="Top 10 best players in the world?")
assert all_events_received.wait(timeout=10), (
"Timeout waiting for all guardrail events"
)
with condition:
success = condition.wait_for(
lambda: len(guardrail_events["started"]) >= 3
and len(guardrail_events["completed"]) >= 3,
timeout=10,
)
assert success, "Timeout waiting for all guardrail events"
assert len(guardrail_events["started"]) == 3 # 2 retries + 1 initial call
assert len(guardrail_events["completed"]) == 3 # 2 retries + 1 initial call
assert not guardrail_events["completed"][0].success
@@ -508,7 +502,7 @@ def test_guardrail_reached_attempt_limit():
assert not guardrail_events["completed"][2].success
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_output_when_guardrail_returns_base_model():
class Player(BaseModel):
name: str
@@ -599,7 +593,7 @@ def test_lite_agent_with_custom_llm_and_guardrails():
assert result2.raw == "Modified by guardrail"
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_lite_agent_with_invalid_llm():
"""Test that LiteAgent raises proper error when create_llm returns None."""
with patch("crewai.lite_agent.create_llm", return_value=None):
@@ -615,7 +609,7 @@ def test_lite_agent_with_invalid_llm():
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_kickoff_with_platform_tools(mock_get):
"""Test that Agent.kickoff() properly integrates platform tools with LiteAgent"""
mock_response = Mock()
@@ -657,7 +651,7 @@ def test_agent_kickoff_with_platform_tools(mock_get):
@patch.dict("os.environ", {"EXA_API_KEY": "test_exa_key"})
@patch("crewai.agent.Agent._get_external_mcp_tools")
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
"""Test that Agent.kickoff() properly integrates MCP tools with LiteAgent"""
# Setup mock MCP tools - create a proper BaseTool instance

View File

@@ -1,126 +0,0 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
This is VERY important to you, use the tools available and give your best Final
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate, zstd
connection:
- keep-alive
content-length:
- '825'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFJdi9swEHz3r1j0HBc7ZydXvx1HC0evhT6UUtrDKNLa1lXWqpJ8aTjy
34vsXOz0A/pi8M7OaGZ3nxMApiSrgImOB9Fbnd4Why/v79HeePeD37/Zfdx8evf5+rY4fNjdPbJV
ZNDuEUV4Yb0S1FuNQZGZYOGQB4yq+bYsr7J1nhcj0JNEHWmtDWlBaa+MStfZukizbZpfn9gdKYGe
VfA1AQB4Hr/Rp5H4k1WQrV4qPXrPW2TVuQmAOdKxwrj3ygduAlvNoCAT0IzW78DQHgQ30KonBA5t
tA3c+D06gG/mrTJcw834X0GHWhPsyWm5FHTYDJ7HUGbQegFwYyjwOJQxysMJOZ7Na2qto53/jcoa
ZZTvaofck4lGfSDLRvSYADyMQxoucjPrqLehDvQdx+fycjvpsXk3C/TqBAYKXC/q29NoL/VqiYEr
7RdjZoKLDuVMnXfCB6loASSL1H+6+Zv2lFyZ9n/kZ0AItAFlbR1KJS4Tz20O4+n+q+085dEw8+ie
lMA6KHRxExIbPujpoJg/+IB93SjTorNOTVfV2LrcZLzZYFm+Zskx+QUAAP//AwB1vYZ+YwMAAA==
headers:
CF-RAY:
- 96fc9f29dea3cf1f-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Fri, 15 Aug 2025 23:55:15 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=oA9oTa3cE0ZaEUDRf0hCpnarSAQKzrVUhl6qDS4j09w-1755302115-1.0.1.1-gUUDl4ZqvBQkg7244DTwOmSiDUT2z_AiQu0P1xUaABjaufSpZuIlI5G0H7OSnW.ldypvpxjj45NGWesJ62M_2U7r20tHz_gMmDFw6D5ZiNc;
path=/; expires=Sat, 16-Aug-25 00:25:15 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=ICenEGMmOE5jaOjwD30bAOwrF8.XRbSIKTBl1EyWs0o-1755302115700-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '735'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '753'
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999830'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999827'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_212fde9d945a462ba0d89ea856131dce
status:
code: 200
message: OK
version: 1

Some files were not shown because too many files have changed in this diff Show More