Mirror of https://github.com/crewAIInc/crewAI.git, synced 2025-12-16 20:38:29 +00:00

Compare commits: lorenze/la...1.7.0 (27 commits)
| SHA1 |
|---|
| 3e3b9df761 |
| 177294f588 |
| beef712646 |
| 6125b866fd |
| f2f994612c |
| 7fff2b654c |
| 34e09162ba |
| 24d1fad7ab |
| 9b8f31fa07 |
| d898d7c02c |
| f04c40babf |
| c456e5c5fa |
| 633e279b51 |
| a25778974d |
| 09f1ba6956 |
| 20704742e2 |
| 59180e9c9f |
| 3ce019b07b |
| 2355ec0733 |
| c925d2d519 |
| bc4e6a3127 |
| 37526c693b |
| c59173a762 |
| 4d8eec96e8 |
| 2025a26fc3 |
| bed9a3847a |
| 5239dc9859 |
.env.test (new file, 161 lines)
@@ -0,0 +1,161 @@
# =============================================================================
# Test Environment Variables
# =============================================================================
# This file contains all environment variables needed to run tests locally
# in a way that mimics the GitHub Actions CI environment.

# =============================================================================

# -----------------------------------------------------------------------------
# LLM Provider API Keys
# -----------------------------------------------------------------------------
OPENAI_API_KEY=fake-api-key
ANTHROPIC_API_KEY=fake-anthropic-key
GEMINI_API_KEY=fake-gemini-key
AZURE_API_KEY=fake-azure-key
OPENROUTER_API_KEY=fake-openrouter-key

# -----------------------------------------------------------------------------
# AWS Credentials
# -----------------------------------------------------------------------------
AWS_ACCESS_KEY_ID=fake-aws-access-key
AWS_SECRET_ACCESS_KEY=fake-aws-secret-key
AWS_DEFAULT_REGION=us-east-1
AWS_REGION_NAME=us-east-1

# -----------------------------------------------------------------------------
# Azure OpenAI Configuration
# -----------------------------------------------------------------------------
AZURE_ENDPOINT=https://fake-azure-endpoint.openai.azure.com
AZURE_OPENAI_ENDPOINT=https://fake-azure-endpoint.openai.azure.com
AZURE_OPENAI_API_KEY=fake-azure-openai-key
AZURE_API_VERSION=2024-02-15-preview
OPENAI_API_VERSION=2024-02-15-preview

# -----------------------------------------------------------------------------
# Google Cloud Configuration
# -----------------------------------------------------------------------------
#GOOGLE_CLOUD_PROJECT=fake-gcp-project
#GOOGLE_CLOUD_LOCATION=us-central1

# -----------------------------------------------------------------------------
# OpenAI Configuration
# -----------------------------------------------------------------------------
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_API_BASE=https://api.openai.com/v1

# -----------------------------------------------------------------------------
# Search & Scraping Tool API Keys
# -----------------------------------------------------------------------------
SERPER_API_KEY=fake-serper-key
EXA_API_KEY=fake-exa-key
BRAVE_API_KEY=fake-brave-key
FIRECRAWL_API_KEY=fake-firecrawl-key
TAVILY_API_KEY=fake-tavily-key
SERPAPI_API_KEY=fake-serpapi-key
SERPLY_API_KEY=fake-serply-key
LINKUP_API_KEY=fake-linkup-key
PARALLEL_API_KEY=fake-parallel-key

# -----------------------------------------------------------------------------
# Exa Configuration
# -----------------------------------------------------------------------------
EXA_BASE_URL=https://api.exa.ai

# -----------------------------------------------------------------------------
# Web Scraping & Automation
# -----------------------------------------------------------------------------
BRIGHT_DATA_API_KEY=fake-brightdata-key
BRIGHT_DATA_ZONE=fake-zone
BRIGHTDATA_API_URL=https://api.brightdata.com
BRIGHTDATA_DEFAULT_TIMEOUT=600
BRIGHTDATA_DEFAULT_POLLING_INTERVAL=1

OXYLABS_USERNAME=fake-oxylabs-user
OXYLABS_PASSWORD=fake-oxylabs-pass

SCRAPFLY_API_KEY=fake-scrapfly-key
SCRAPEGRAPH_API_KEY=fake-scrapegraph-key

BROWSERBASE_API_KEY=fake-browserbase-key
BROWSERBASE_PROJECT_ID=fake-browserbase-project

HYPERBROWSER_API_KEY=fake-hyperbrowser-key
MULTION_API_KEY=fake-multion-key
APIFY_API_TOKEN=fake-apify-token

# -----------------------------------------------------------------------------
# Database & Vector Store Credentials
# -----------------------------------------------------------------------------
SINGLESTOREDB_URL=mysql://fake:fake@localhost:3306/fake
SINGLESTOREDB_HOST=localhost
SINGLESTOREDB_PORT=3306
SINGLESTOREDB_USER=fake-user
SINGLESTOREDB_PASSWORD=fake-password
SINGLESTOREDB_DATABASE=fake-database
SINGLESTOREDB_CONNECT_TIMEOUT=30

SNOWFLAKE_USER=fake-snowflake-user
SNOWFLAKE_PASSWORD=fake-snowflake-password
SNOWFLAKE_ACCOUNT=fake-snowflake-account
SNOWFLAKE_WAREHOUSE=fake-snowflake-warehouse
SNOWFLAKE_DATABASE=fake-snowflake-database
SNOWFLAKE_SCHEMA=fake-snowflake-schema

WEAVIATE_URL=http://localhost:8080
WEAVIATE_API_KEY=fake-weaviate-key

EMBEDCHAIN_DB_URI=sqlite:///test.db

# Databricks Credentials
DATABRICKS_HOST=https://fake-databricks.cloud.databricks.com
DATABRICKS_TOKEN=fake-databricks-token
DATABRICKS_CONFIG_PROFILE=fake-profile

# MongoDB Credentials
MONGODB_URI=mongodb://fake:fake@localhost:27017/fake

# -----------------------------------------------------------------------------
# CrewAI Platform & Enterprise
# -----------------------------------------------------------------------------
# setting CREWAI_PLATFORM_INTEGRATION_TOKEN causes these test to fail:
#=========================== short test summary info ============================
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_platform_context_manager_basic_usage - AssertionError: assert 'fake-platform-token' is None
# + where 'fake-platform-token' = get_platform_integration_token()
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_context_var_isolation_between_tests - AssertionError: assert 'fake-platform-token' is None
# + where 'fake-platform-token' = get_platform_integration_token()
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_multiple_sequential_context_managers - AssertionError: assert 'fake-platform-token' is None
# + where 'fake-platform-token' = get_platform_integration_token()
#CREWAI_PLATFORM_INTEGRATION_TOKEN=fake-platform-token
CREWAI_PERSONAL_ACCESS_TOKEN=fake-personal-token
CREWAI_PLUS_URL=https://fake.crewai.com

# -----------------------------------------------------------------------------
# Other Service API Keys
# -----------------------------------------------------------------------------
ZAPIER_API_KEY=fake-zapier-key
PATRONUS_API_KEY=fake-patronus-key
MINDS_API_KEY=fake-minds-key
HF_TOKEN=fake-hf-token

# -----------------------------------------------------------------------------
# Feature Flags/Testing Modes
# -----------------------------------------------------------------------------
CREWAI_DISABLE_TELEMETRY=true
OTEL_SDK_DISABLED=true
CREWAI_TESTING=true
CREWAI_TRACING_ENABLED=false

# -----------------------------------------------------------------------------
# Testing/CI Configuration
# -----------------------------------------------------------------------------
# VCR recording mode: "none" (default), "new_episodes", "all", "once"
PYTEST_VCR_RECORD_MODE=none

# Set to "true" by GitHub when running in GitHub Actions
# GITHUB_ACTIONS=false

# -----------------------------------------------------------------------------
# Python Configuration
# -----------------------------------------------------------------------------
PYTHONUNBUFFERED=1
.github/workflows/tests.yml (vendored, 18 lines changed)
@@ -5,18 +5,6 @@ on: [pull_request]
permissions:
  contents: read

env:
  OPENAI_API_KEY: fake-api-key
  PYTHONUNBUFFERED: 1
  BRAVE_API_KEY: fake-brave-key
  SNOWFLAKE_USER: fake-snowflake-user
  SNOWFLAKE_PASSWORD: fake-snowflake-password
  SNOWFLAKE_ACCOUNT: fake-snowflake-account
  SNOWFLAKE_WAREHOUSE: fake-snowflake-warehouse
  SNOWFLAKE_DATABASE: fake-snowflake-database
  SNOWFLAKE_SCHEMA: fake-snowflake-schema
  EMBEDCHAIN_DB_URI: sqlite:///test.db

jobs:
  tests:
    name: tests (${{ matrix.python-version }})
@@ -84,26 +72,20 @@ jobs:
          # fi

          cd lib/crewai && uv run pytest \
            --block-network \
            --timeout=30 \
            -vv \
            --splits 8 \
            --group ${{ matrix.group }} \
            $DURATIONS_ARG \
            --durations=10 \
            -n auto \
            --maxfail=3

      - name: Run tool tests (group ${{ matrix.group }} of 8)
        run: |
          cd lib/crewai-tools && uv run pytest \
            --block-network \
            --timeout=30 \
            -vv \
            --splits 8 \
            --group ${{ matrix.group }} \
            --durations=10 \
            -n auto \
            --maxfail=3
conftest.py (new file, 193 lines)
@@ -0,0 +1,193 @@
"""Pytest configuration for crewAI workspace."""

from collections.abc import Generator
import os
from pathlib import Path
import tempfile
from typing import Any

from dotenv import load_dotenv
import pytest
from vcr.request import Request  # type: ignore[import-untyped]


env_test_path = Path(__file__).parent / ".env.test"
load_dotenv(env_test_path, override=True)
load_dotenv(override=True)


@pytest.fixture(autouse=True, scope="function")
def cleanup_event_handlers() -> Generator[None, Any, None]:
    """Clean up event bus handlers after each test to prevent test pollution."""
    yield

    try:
        from crewai.events.event_bus import crewai_event_bus

        with crewai_event_bus._rwlock.w_locked():
            crewai_event_bus._sync_handlers.clear()
            crewai_event_bus._async_handlers.clear()
    except Exception:  # noqa: S110
        pass


@pytest.fixture(autouse=True, scope="function")
def setup_test_environment() -> Generator[None, Any, None]:
    """Setup test environment for crewAI workspace."""
    with tempfile.TemporaryDirectory() as temp_dir:
        storage_dir = Path(temp_dir) / "crewai_test_storage"
        storage_dir.mkdir(parents=True, exist_ok=True)

        if not storage_dir.exists() or not storage_dir.is_dir():
            raise RuntimeError(
                f"Failed to create test storage directory: {storage_dir}"
            )

        try:
            test_file = storage_dir / ".permissions_test"
            test_file.touch()
            test_file.unlink()
        except (OSError, IOError) as e:
            raise RuntimeError(
                f"Test storage directory {storage_dir} is not writable: {e}"
            ) from e

        os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
        os.environ["CREWAI_TESTING"] = "true"

        try:
            yield
        finally:
            os.environ.pop("CREWAI_TESTING", "true")
            os.environ.pop("CREWAI_STORAGE_DIR", None)
            os.environ.pop("CREWAI_DISABLE_TELEMETRY", "true")
            os.environ.pop("OTEL_SDK_DISABLED", "true")
            os.environ.pop("OPENAI_BASE_URL", "https://api.openai.com/v1")
            os.environ.pop("OPENAI_API_BASE", "https://api.openai.com/v1")


HEADERS_TO_FILTER = {
    "authorization": "AUTHORIZATION-XXX",
    "content-security-policy": "CSP-FILTERED",
    "cookie": "COOKIE-XXX",
    "set-cookie": "SET-COOKIE-XXX",
    "permissions-policy": "PERMISSIONS-POLICY-XXX",
    "referrer-policy": "REFERRER-POLICY-XXX",
    "strict-transport-security": "STS-XXX",
    "x-content-type-options": "X-CONTENT-TYPE-XXX",
    "x-frame-options": "X-FRAME-OPTIONS-XXX",
    "x-permitted-cross-domain-policies": "X-PERMITTED-XXX",
    "x-request-id": "X-REQUEST-ID-XXX",
    "x-runtime": "X-RUNTIME-XXX",
    "x-xss-protection": "X-XSS-PROTECTION-XXX",
    "x-stainless-arch": "X-STAINLESS-ARCH-XXX",
    "x-stainless-os": "X-STAINLESS-OS-XXX",
    "x-stainless-read-timeout": "X-STAINLESS-READ-TIMEOUT-XXX",
    "cf-ray": "CF-RAY-XXX",
    "etag": "ETAG-XXX",
    "Strict-Transport-Security": "STS-XXX",
    "access-control-expose-headers": "ACCESS-CONTROL-XXX",
    "openai-organization": "OPENAI-ORG-XXX",
    "openai-project": "OPENAI-PROJECT-XXX",
    "x-ratelimit-limit-requests": "X-RATELIMIT-LIMIT-REQUESTS-XXX",
    "x-ratelimit-limit-tokens": "X-RATELIMIT-LIMIT-TOKENS-XXX",
    "x-ratelimit-remaining-requests": "X-RATELIMIT-REMAINING-REQUESTS-XXX",
    "x-ratelimit-remaining-tokens": "X-RATELIMIT-REMAINING-TOKENS-XXX",
    "x-ratelimit-reset-requests": "X-RATELIMIT-RESET-REQUESTS-XXX",
    "x-ratelimit-reset-tokens": "X-RATELIMIT-RESET-TOKENS-XXX",
    "x-goog-api-key": "X-GOOG-API-KEY-XXX",
    "api-key": "X-API-KEY-XXX",
    "User-Agent": "X-USER-AGENT-XXX",
    "apim-request-id:": "X-API-CLIENT-REQUEST-ID-XXX",
    "azureml-model-session": "AZUREML-MODEL-SESSION-XXX",
    "x-ms-client-request-id": "X-MS-CLIENT-REQUEST-ID-XXX",
    "x-ms-region": "X-MS-REGION-XXX",
    "apim-request-id": "APIM-REQUEST-ID-XXX",
    "x-api-key": "X-API-KEY-XXX",
    "anthropic-organization-id": "ANTHROPIC-ORGANIZATION-ID-XXX",
    "request-id": "REQUEST-ID-XXX",
    "anthropic-ratelimit-input-tokens-limit": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-LIMIT-XXX",
    "anthropic-ratelimit-input-tokens-remaining": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-REMAINING-XXX",
    "anthropic-ratelimit-input-tokens-reset": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-RESET-XXX",
    "anthropic-ratelimit-output-tokens-limit": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-LIMIT-XXX",
    "anthropic-ratelimit-output-tokens-remaining": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-REMAINING-XXX",
    "anthropic-ratelimit-output-tokens-reset": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-RESET-XXX",
    "anthropic-ratelimit-tokens-limit": "ANTHROPIC-RATELIMIT-TOKENS-LIMIT-XXX",
    "anthropic-ratelimit-tokens-remaining": "ANTHROPIC-RATELIMIT-TOKENS-REMAINING-XXX",
    "anthropic-ratelimit-tokens-reset": "ANTHROPIC-RATELIMIT-TOKENS-RESET-XXX",
    "x-amz-date": "X-AMZ-DATE-XXX",
    "amz-sdk-invocation-id": "AMZ-SDK-INVOCATION-ID-XXX",
    "accept-encoding": "ACCEPT-ENCODING-XXX",
    "x-amzn-requestid": "X-AMZN-REQUESTID-XXX",
    "x-amzn-RequestId": "X-AMZN-REQUESTID-XXX",
}


def _filter_request_headers(request: Request) -> Request:  # type: ignore[no-any-unimported]
    """Filter sensitive headers from request before recording."""
    for header_name, replacement in HEADERS_TO_FILTER.items():
        for variant in [header_name, header_name.upper(), header_name.title()]:
            if variant in request.headers:
                request.headers[variant] = [replacement]

    request.method = request.method.upper()
    return request


def _filter_response_headers(response: dict[str, Any]) -> dict[str, Any]:
    """Filter sensitive headers from response before recording."""
    for header_name, replacement in HEADERS_TO_FILTER.items():
        for variant in [header_name, header_name.upper(), header_name.title()]:
            if variant in response["headers"]:
                response["headers"][variant] = [replacement]
    return response


@pytest.fixture(scope="module")
def vcr_cassette_dir(request: Any) -> str:
    """Generate cassette directory path based on test module location.

    Organizes cassettes to mirror test directory structure within each package:
        lib/crewai/tests/llms/google/test_google.py -> lib/crewai/tests/cassettes/llms/google/
        lib/crewai-tools/tests/tools/test_search.py -> lib/crewai-tools/tests/cassettes/tools/
    """
    test_file = Path(request.fspath)

    for parent in test_file.parents:
        if parent.name in ("crewai", "crewai-tools") and parent.parent.name == "lib":
            package_root = parent
            break
    else:
        package_root = test_file.parent

    tests_root = package_root / "tests"
    test_dir = test_file.parent

    if test_dir != tests_root:
        relative_path = test_dir.relative_to(tests_root)
        cassette_dir = tests_root / "cassettes" / relative_path
    else:
        cassette_dir = tests_root / "cassettes"

    cassette_dir.mkdir(parents=True, exist_ok=True)

    return str(cassette_dir)


@pytest.fixture(scope="module")
def vcr_config(vcr_cassette_dir: str) -> dict[str, Any]:
    """Configure VCR with organized cassette storage."""
    config = {
        "cassette_library_dir": vcr_cassette_dir,
        "record_mode": os.getenv("PYTEST_VCR_RECORD_MODE", "once"),
        "filter_headers": [(k, v) for k, v in HEADERS_TO_FILTER.items()],
        "before_record_request": _filter_request_headers,
        "before_record_response": _filter_response_headers,
        "filter_query_parameters": ["key"],
        "match_on": ["method", "scheme", "host", "port", "path"],
    }

    if os.getenv("GITHUB_ACTIONS") == "true":
        config["record_mode"] = "none"

    return config
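A test that opts into this recording setup just applies the `vcr` marker, as the Firecrawl tool tests further down do; the cassette location and header scrubbing come from the fixtures above. A minimal, hypothetical example (the tool and query are illustrative):

```python
import pytest

from crewai_tools import SerperDevTool


@pytest.mark.vcr()  # cassette resolved via the vcr_config/vcr_cassette_dir fixtures above
def test_serper_search_returns_results():
    tool = SerperDevTool()
    result = tool.run(search_query="CrewAI framework")
    assert result
```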
@@ -283,11 +283,54 @@ In this section, you'll find detailed examples that help you select, configure,
)
```

**Extended Thinking (Claude Sonnet 4 and Beyond):**

CrewAI supports Anthropic's Extended Thinking feature, which allows Claude to think through problems in a more human-like way before responding. This is particularly useful for complex reasoning, analysis, and problem-solving tasks.

```python Code
from crewai import LLM

# Enable extended thinking with default settings
llm = LLM(
    model="anthropic/claude-sonnet-4",
    thinking={"type": "enabled"},
    max_tokens=10000
)

# Configure thinking with budget control
llm = LLM(
    model="anthropic/claude-sonnet-4",
    thinking={
        "type": "enabled",
        "budget_tokens": 5000  # Limit thinking tokens
    },
    max_tokens=10000
)
```

**Thinking Configuration Options:**
- `type`: Set to `"enabled"` to activate extended thinking mode
- `budget_tokens` (optional): Maximum tokens to use for thinking (helps control costs)

**Models Supporting Extended Thinking:**
- `claude-sonnet-4` and newer models
- `claude-3-7-sonnet` (with extended thinking capabilities)

**When to Use Extended Thinking:**
- Complex reasoning and multi-step problem solving
- Mathematical calculations and proofs
- Code analysis and debugging
- Strategic planning and decision making
- Research and analytical tasks

**Note:** Extended thinking consumes additional tokens but can significantly improve response quality for complex tasks.

**Supported Environment Variables:**
- `ANTHROPIC_API_KEY`: Your Anthropic API key (required)

**Features:**
- Native tool use support for Claude 3+ models
- Extended Thinking support for Claude Sonnet 4+
- Streaming support for real-time responses
- Automatic system message handling
- Stop sequences for controlled output
@@ -305,6 +348,7 @@ In this section, you'll find detailed examples that help you select, configure,

| Model                      | Context Window | Best For                                    |
|----------------------------|----------------|---------------------------------------------|
| claude-sonnet-4            | 200,000 tokens | Latest with extended thinking capabilities  |
| claude-3-7-sonnet          | 200,000 tokens | Advanced reasoning and agentic tasks        |
| claude-3-5-sonnet-20241022 | 200,000 tokens | Latest Sonnet with best performance         |
| claude-3-5-haiku           | 200,000 tokens | Fast, compact model for quick responses     |
@@ -1089,6 +1133,50 @@ CrewAI supports streaming responses from LLMs, allowing your application to rece
</Tab>
</Tabs>

## Async LLM Calls

CrewAI supports asynchronous LLM calls for improved performance and concurrency in your AI workflows. Async calls allow you to run multiple LLM requests concurrently without blocking, making them ideal for high-throughput applications and parallel agent operations.

<Tabs>
<Tab title="Basic Usage">
Use the `acall` method for asynchronous LLM requests:

```python
import asyncio
from crewai import LLM

async def main():
    llm = LLM(model="openai/gpt-4o")

    # Single async call
    response = await llm.acall("What is the capital of France?")
    print(response)

asyncio.run(main())
```

The `acall` method supports all the same parameters as the synchronous `call` method, including messages, tools, and callbacks.
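Because `acall` returns a coroutine, several independent requests can also be awaited concurrently. A minimal sketch using `asyncio.gather` (the prompts are illustrative):

```python
import asyncio
from crewai import LLM

async def main():
    llm = LLM(model="openai/gpt-4o")

    prompts = [
        "Summarize async IO in one sentence.",
        "Name three popular vector databases.",
    ]
    # Each acall runs concurrently; results come back in the same order as the prompts
    responses = await asyncio.gather(*(llm.acall(p) for p in prompts))
    for response in responses:
        print(response)

asyncio.run(main())
```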
</Tab>

<Tab title="With Streaming">
Combine async calls with streaming for real-time concurrent responses:

```python
import asyncio
from crewai import LLM

async def stream_async():
    llm = LLM(model="openai/gpt-4o", stream=True)

    response = await llm.acall("Write a short story about AI")

    print(response)

asyncio.run(stream_async())
```
</Tab>
</Tabs>

## Structured LLM Calls

CrewAI supports structured responses from LLM calls by allowing you to define a `response_format` using a Pydantic model. This enables the framework to automatically parse and validate the output, making it easier to integrate the response into your application without manual post-processing.
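For example, a minimal sketch (the Pydantic model and its fields are illustrative):

```python
from crewai import LLM
from pydantic import BaseModel

class CityInfo(BaseModel):
    name: str
    country: str
    population: int

llm = LLM(model="openai/gpt-4o", response_format=CityInfo)

# The output is parsed and validated against CityInfo
result = llm.call("Give me basic facts about Tokyo.")
print(result)
```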

@@ -515,8 +515,7 @@ crew = Crew(
        "provider": "huggingface",
        "config": {
            "api_key": "your-hf-token",  # Optional for public models
            "model": "sentence-transformers/all-MiniLM-L6-v2",
            "api_url": "https://api-inference.huggingface.co"  # or your custom endpoint
            "model": "sentence-transformers/all-MiniLM-L6-v2"
        }
    }
)
@@ -66,5 +66,55 @@ def my_cache_strategy(arguments: dict, result: str) -> bool:
cached_tool.cache_function = my_cache_strategy
```

### Creating Async Tools

CrewAI supports async tools for non-blocking I/O operations. This is useful when your tool needs to make HTTP requests, database queries, or other I/O-bound operations.

#### Using the `@tool` Decorator with Async Functions

The simplest way to create an async tool is using the `@tool` decorator with an async function:

```python Code
import aiohttp
from crewai.tools import tool

@tool("Async Web Fetcher")
async def fetch_webpage(url: str) -> str:
    """Fetch content from a webpage asynchronously."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.text()
```

#### Subclassing `BaseTool` with Async Support

For more control, subclass `BaseTool` and implement both `_run` (sync) and `_arun` (async) methods:

```python Code
import requests
import aiohttp
from crewai.tools import BaseTool
from pydantic import BaseModel, Field

class WebFetcherInput(BaseModel):
    """Input schema for WebFetcher."""
    url: str = Field(..., description="The URL to fetch")

class WebFetcherTool(BaseTool):
    name: str = "Web Fetcher"
    description: str = "Fetches content from a URL"
    args_schema: type[BaseModel] = WebFetcherInput

    def _run(self, url: str) -> str:
        """Synchronous implementation."""
        return requests.get(url).text

    async def _arun(self, url: str) -> str:
        """Asynchronous implementation for non-blocking I/O."""
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                return await response.text()
```

By adhering to these guidelines and incorporating new functionalities and collaboration tools into your tool creation and management processes,
you can leverage the full capabilities of the CrewAI framework, enhancing both the development experience and the efficiency of your AI agents.
@@ -515,8 +515,7 @@ crew = Crew(
        "provider": "huggingface",
        "config": {
            "api_key": "your-hf-token",  # Optional for public models
            "model": "sentence-transformers/all-MiniLM-L6-v2",
            "api_url": "https://api-inference.huggingface.co"  # or your custom endpoint
            "model": "sentence-transformers/all-MiniLM-L6-v2"
        }
    }
)
@@ -63,5 +63,55 @@ def my_cache_strategy(arguments: dict, result: str) -> bool:
cached_tool.cache_function = my_cache_strategy
```

### Creating Async Tools

CrewAI supports async tools for non-blocking I/O operations. This is useful when your tool needs to make HTTP requests, database queries, or other I/O-bound operations.

#### Using the `@tool` Decorator with Async Functions

The simplest way to create an async tool is to use the `@tool` decorator with an async function:

```python Code
import aiohttp
from crewai.tools import tool

@tool("Async Web Fetcher")
async def fetch_webpage(url: str) -> str:
    """Fetch content from a webpage asynchronously."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.text()
```

#### Subclassing `BaseTool` with Async Support

For more control, inherit from `BaseTool` and implement both the `_run` (sync) and `_arun` (async) methods:

```python Code
import requests
import aiohttp
from crewai.tools import BaseTool
from pydantic import BaseModel, Field

class WebFetcherInput(BaseModel):
    """Input schema for WebFetcher."""
    url: str = Field(..., description="The URL to fetch")

class WebFetcherTool(BaseTool):
    name: str = "Web Fetcher"
    description: str = "Fetches content from a URL"
    args_schema: type[BaseModel] = WebFetcherInput

    def _run(self, url: str) -> str:
        """Synchronous implementation."""
        return requests.get(url).text

    async def _arun(self, url: str) -> str:
        """Asynchronous implementation for non-blocking I/O."""
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                return await response.text()
```

By following these guidelines and incorporating new features and collaboration tools into your tool creation and management processes,
you can leverage the full capabilities of the CrewAI framework, improving both the development experience and the efficiency of your AI agents.
@@ -515,8 +515,7 @@ crew = Crew(
        "provider": "huggingface",
        "config": {
            "api_key": "your-hf-token",  # Optional for public models
            "model": "sentence-transformers/all-MiniLM-L6-v2",
            "api_url": "https://api-inference.huggingface.co"  # or your custom endpoint
            "model": "sentence-transformers/all-MiniLM-L6-v2"
        }
    }
)
@@ -66,5 +66,55 @@ def my_cache_strategy(arguments: dict, result: str) -> bool:
cached_tool.cache_function = my_cache_strategy
```

### Creating Async Tools

CrewAI supports async tools for non-blocking I/O operations. This is useful when your tool needs to make HTTP requests, database queries, or other I/O operations.

#### Using the `@tool` Decorator with Async Functions

The simplest way to create an async tool is to use the `@tool` decorator with an async function:

```python Code
import aiohttp
from crewai.tools import tool

@tool("Async Web Fetcher")
async def fetch_webpage(url: str) -> str:
    """Fetch content from a webpage asynchronously."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.text()
```

#### Subclassing `BaseTool` with Async Support

For more control, inherit from `BaseTool` and implement both the `_run` (synchronous) and `_arun` (asynchronous) methods:

```python Code
import requests
import aiohttp
from crewai.tools import BaseTool
from pydantic import BaseModel, Field

class WebFetcherInput(BaseModel):
    """Input schema for WebFetcher."""
    url: str = Field(..., description="The URL to fetch")

class WebFetcherTool(BaseTool):
    name: str = "Web Fetcher"
    description: str = "Fetches content from a URL"
    args_schema: type[BaseModel] = WebFetcherInput

    def _run(self, url: str) -> str:
        """Synchronous implementation."""
        return requests.get(url).text

    async def _arun(self, url: str) -> str:
        """Asynchronous implementation for non-blocking I/O."""
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                return await response.text()
```

By following these guidelines and incorporating new functionality and collaboration tools into your tool creation and management processes,
you can take full advantage of the CrewAI framework's capabilities, improving both the development experience and the efficiency of your AI agents.
@@ -218,7 +218,7 @@ Update the root `README.md` only if the tool introduces a new category or notabl

## Discovery and specs

Our internal tooling discovers classes whose names end with `Tool`. Keep your class exported from the module path under `crewai_tools/tools/...` to be picked up by scripts like `generate_tool_specs.py`.
Our internal tooling discovers classes whose names end with `Tool`. Keep your class exported from the module path under `crewai_tools/tools/...` to be picked up by scripts like `crewai_tools.generate_tool_specs.py`.
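For instance, a hypothetical tool module placed at `crewai_tools/tools/my_search_tool/my_search_tool.py` would be picked up as long as the class is importable from the package and its name ends with `Tool`:

```python
from crewai.tools import BaseTool


class MySearchTool(BaseTool):  # the trailing "Tool" in the class name is what discovery keys on
    name: str = "My Search"
    description: str = "Illustrative tool used to demonstrate the naming convention."

    def _run(self, query: str) -> str:
        return f"Results for {query}"
```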
---
@@ -8,17 +8,17 @@ authors = [
]
requires-python = ">=3.10, <3.14"
dependencies = [
    "lancedb>=0.5.4",
    "pytube>=15.0.0",
    "requests>=2.32.5",
    "docker>=7.1.0",
    "crewai==1.6.0",
    "lancedb>=0.5.4",
    "tiktoken>=0.8.0",
    "beautifulsoup4>=4.13.4",
    "pypdf>=5.9.0",
    "python-docx>=1.2.0",
    "youtube-transcript-api>=1.2.2",
    "lancedb~=0.5.4",
    "pytube~=15.0.0",
    "requests~=2.32.5",
    "docker~=7.1.0",
    "crewai==1.7.0",
    "lancedb~=0.5.4",
    "tiktoken~=0.8.0",
    "beautifulsoup4~=4.13.4",
    "python-docx~=1.2.0",
    "youtube-transcript-api~=1.2.2",
    "pymupdf~=1.26.6",
]
@@ -291,4 +291,4 @@ __all__ = [
    "ZapierActionTools",
]

__version__ = "1.6.0"
__version__ = "1.7.0"
@@ -3,8 +3,7 @@
from __future__ import annotations

import hashlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
from typing import TYPE_CHECKING, Any, cast
import uuid

from crewai.rag.config.types import RagConfigType
@@ -19,15 +18,13 @@ from typing_extensions import TypeIs, Unpack
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.tools.rag.rag_tool import Adapter
from crewai_tools.tools.rag.types import AddDocumentParams, ContentItem


if TYPE_CHECKING:
    from crewai.rag.qdrant.config import QdrantConfig


ContentItem: TypeAlias = str | Path | dict[str, Any]


def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
    """Check if config is a QdrantConfig using safe duck typing.

@@ -46,19 +43,6 @@ def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
    return False


class AddDocumentParams(TypedDict, total=False):
    """Parameters for adding documents to the RAG system."""

    data_type: DataType
    metadata: dict[str, Any]
    website: str
    url: str
    file_path: str | Path
    github_url: str
    youtube_url: str
    directory_path: str | Path


class CrewAIRagAdapter(Adapter):
    """Adapter that uses CrewAI's native RAG system.

@@ -131,13 +115,26 @@ class CrewAIRagAdapter(Adapter):
    def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
        """Add content to the knowledge base.

        This method handles various input types and converts them to documents
        for the vector database. It supports the data_type parameter for
        compatibility with existing tools.

        Args:
            *args: Content items to add (strings, paths, or document dicts)
            **kwargs: Additional parameters including data_type, metadata, etc.
            **kwargs: Additional parameters including:
                - data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
                - path: Path to file or directory (alternative to positional arg)
                - file_path: Alias for path
                - metadata: Additional metadata to attach to documents
                - url: URL to fetch content from
                - website: Website URL to scrape
                - github_url: GitHub repository URL
                - youtube_url: YouTube video URL
                - directory_path: Path to directory

        Examples:
            rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)

            rag_tool.add(path="path/to/document.pdf", data_type="file")
            rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")

            rag_tool.add("path/to/document.pdf")  # auto-detects PDF
        """
        import os

@@ -146,10 +143,54 @@
        from crewai_tools.rag.source_content import SourceContent

        documents: list[BaseRecord] = []
        data_type: DataType | None = kwargs.get("data_type")
        raw_data_type = kwargs.get("data_type")
        base_metadata: dict[str, Any] = kwargs.get("metadata", {})

        for arg in args:
        data_type: DataType | None = None
        if raw_data_type is not None:
            if isinstance(raw_data_type, DataType):
                if raw_data_type != DataType.FILE:
                    data_type = raw_data_type
            elif isinstance(raw_data_type, str):
                if raw_data_type != "file":
                    try:
                        data_type = DataType(raw_data_type)
                    except ValueError:
                        raise ValueError(
                            f"Invalid data_type: '{raw_data_type}'. "
                            f"Valid values are: 'file' (auto-detect), or one of: "
                            f"{', '.join(dt.value for dt in DataType)}"
                        ) from None

        content_items: list[ContentItem] = list(args)

        path_value = kwargs.get("path") or kwargs.get("file_path")
        if path_value is not None:
            content_items.append(path_value)

        if url := kwargs.get("url"):
            content_items.append(url)
        if website := kwargs.get("website"):
            content_items.append(website)
        if github_url := kwargs.get("github_url"):
            content_items.append(github_url)
        if youtube_url := kwargs.get("youtube_url"):
            content_items.append(youtube_url)
        if directory_path := kwargs.get("directory_path"):
            content_items.append(directory_path)

        file_extensions = {
            ".pdf",
            ".txt",
            ".csv",
            ".json",
            ".xml",
            ".docx",
            ".mdx",
            ".md",
        }

        for arg in content_items:
            source_ref: str
            if isinstance(arg, dict):
                source_ref = str(arg.get("source", arg.get("content", "")))
@@ -157,6 +198,14 @@ class CrewAIRagAdapter(Adapter):
                source_ref = str(arg)

            if not data_type:
                ext = os.path.splitext(source_ref)[1].lower()
                is_url = source_ref.startswith(("http://", "https://", "file://"))
                if (
                    ext in file_extensions
                    and not is_url
                    and not os.path.isfile(source_ref)
                ):
                    raise FileNotFoundError(f"File does not exist: {source_ref}")
                data_type = DataTypes.from_content(source_ref)

            if data_type == DataType.DIRECTORY:
@@ -4,17 +4,20 @@ from collections.abc import Mapping
import inspect
import json
from pathlib import Path
from typing import Any, cast
from typing import Any

from crewai.tools.base_tool import BaseTool, EnvVar
from crewai_tools import tools
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema
from pydantic_core import PydanticOmit

from crewai_tools import tools


class SchemaGenerator(GenerateJsonSchema):
    def handle_invalid_for_json_schema(self, schema, error_info):
    def handle_invalid_for_json_schema(
        self, schema: Any, error_info: Any
    ) -> dict[str, Any]:
        raise PydanticOmit


@@ -73,7 +76,7 @@ class ToolSpecExtractor:

    @staticmethod
    def _extract_field_default(
        field: dict | None, fallback: str | list[Any] = ""
        field: dict[str, Any] | None, fallback: str | list[Any] = ""
    ) -> str | list[Any] | int:
        if not field:
            return fallback
@@ -83,7 +86,7 @@ class ToolSpecExtractor:
        return default if isinstance(default, (list, str, int)) else fallback

    @staticmethod
    def _extract_params(args_schema_field: dict | None) -> dict[str, Any]:
    def _extract_params(args_schema_field: dict[str, Any] | None) -> dict[str, Any]:
        if not args_schema_field:
            return {}

@@ -94,15 +97,15 @@ class ToolSpecExtractor:
        ):
            return {}

        # Cast to type[BaseModel] after runtime check
        schema_class = cast(type[BaseModel], args_schema_class)
        try:
            return schema_class.model_json_schema(schema_generator=SchemaGenerator)
            return args_schema_class.model_json_schema(schema_generator=SchemaGenerator)
        except Exception:
            return {}

    @staticmethod
    def _extract_env_vars(env_vars_field: dict | None) -> list[dict[str, Any]]:
    def _extract_env_vars(
        env_vars_field: dict[str, Any] | None,
    ) -> list[dict[str, Any]]:
        if not env_vars_field:
            return []
@@ -1,6 +1,8 @@
from enum import Enum
from importlib import import_module
import os
from pathlib import Path
from typing import cast
from urllib.parse import urlparse

from crewai_tools.rag.base_loader import BaseLoader
@@ -8,6 +10,7 @@ from crewai_tools.rag.chunkers.base_chunker import BaseChunker


class DataType(str, Enum):
    FILE = "file"
    PDF_FILE = "pdf_file"
    TEXT_FILE = "text_file"
    CSV = "csv"
@@ -15,22 +18,14 @@ class DataType(str, Enum):
    XML = "xml"
    DOCX = "docx"
    MDX = "mdx"

    # Database types
    MYSQL = "mysql"
    POSTGRES = "postgres"

    # Repository types
    GITHUB = "github"
    DIRECTORY = "directory"

    # Web types
    WEBSITE = "website"
    DOCS_SITE = "docs_site"
    YOUTUBE_VIDEO = "youtube_video"
    YOUTUBE_CHANNEL = "youtube_channel"

    # Raw types
    TEXT = "text"

    def get_chunker(self) -> BaseChunker:
@@ -63,13 +58,11 @@ class DataType(str, Enum):

        try:
            module = import_module(module_path)
            return getattr(module, class_name)()
            return cast(BaseChunker, getattr(module, class_name)())
        except Exception as e:
            raise ValueError(f"Error loading chunker for {self}: {e}") from e

    def get_loader(self) -> BaseLoader:
        from importlib import import_module

        loaders = {
            DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
            DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
@@ -98,7 +91,7 @@ class DataType(str, Enum):
        module_path = f"crewai_tools.rag.loaders.{module_name}"
        try:
            module = import_module(module_path)
            return getattr(module, class_name)()
            return cast(BaseLoader, getattr(module, class_name)())
        except Exception as e:
            raise ValueError(f"Error loading loader for {self}: {e}") from e
@@ -2,70 +2,112 @@

import os
from pathlib import Path
from typing import Any
from typing import Any, cast
from urllib.parse import urlparse
import urllib.request

from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent


class PDFLoader(BaseLoader):
    """Loader for PDF files."""
    """Loader for PDF files and URLs."""

    def load(self, source: SourceContent, **kwargs) -> LoaderResult:  # type: ignore[override]
        """Load and extract text from a PDF file.
    @staticmethod
    def _is_url(path: str) -> bool:
        """Check if the path is a URL."""
        try:
            parsed = urlparse(path)
            return parsed.scheme in ("http", "https")
        except Exception:
            return False

    @staticmethod
    def _download_pdf(url: str) -> bytes:
        """Download PDF content from a URL.

        Args:
            source: The source content containing the PDF file path
            url: The URL to download from.

        Returns:
            LoaderResult with extracted text content
            The PDF content as bytes.

        Raises:
            FileNotFoundError: If the PDF file doesn't exist
            ImportError: If required PDF libraries aren't installed
            ValueError: If the download fails.
        """

        try:
            with urllib.request.urlopen(url, timeout=30) as response:  # noqa: S310
                return cast(bytes, response.read())
        except Exception as e:
            raise ValueError(f"Failed to download PDF from {url}: {e!s}") from e

    def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult:  # type: ignore[override]
        """Load and extract text from a PDF file or URL.

        Args:
            source: The source content containing the PDF file path or URL.

        Returns:
            LoaderResult with extracted text content.

        Raises:
            FileNotFoundError: If the PDF file doesn't exist.
            ImportError: If required PDF libraries aren't installed.
            ValueError: If the PDF cannot be read or downloaded.
        """
        try:
            import pypdf
        except ImportError:
            try:
                import PyPDF2 as pypdf  # type: ignore[import-not-found,no-redef] # noqa: N813
            except ImportError as e:
                raise ImportError(
                    "PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
                ) from e
            import pymupdf  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "PDF support requires pymupdf. Install with: uv add pymupdf"
            ) from e

        file_path = source.source
        is_url = self._is_url(file_path)

        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"PDF file not found: {file_path}")
        if is_url:
            source_name = Path(urlparse(file_path).path).name or "downloaded.pdf"
        else:
            source_name = Path(file_path).name

        text_content = []
        text_content: list[str] = []
        metadata: dict[str, Any] = {
            "source": str(file_path),
            "file_name": Path(file_path).name,
            "source": file_path,
            "file_name": source_name,
            "file_type": "pdf",
        }

        try:
            with open(file_path, "rb") as file:
                pdf_reader = pypdf.PdfReader(file)
                metadata["num_pages"] = len(pdf_reader.pages)
            if is_url:
                pdf_bytes = self._download_pdf(file_path)
                doc = pymupdf.open(stream=pdf_bytes, filetype="pdf")
            else:
                if not os.path.isfile(file_path):
                    raise FileNotFoundError(f"PDF file not found: {file_path}")
                doc = pymupdf.open(file_path)

                for page_num, page in enumerate(pdf_reader.pages, 1):
                    page_text = page.extract_text()
                    if page_text.strip():
                        text_content.append(f"Page {page_num}:\n{page_text}")
            metadata["num_pages"] = len(doc)

            for page_num, page in enumerate(doc, 1):
                page_text = page.get_text()
                if page_text.strip():
                    text_content.append(f"Page {page_num}:\n{page_text}")

            doc.close()
        except FileNotFoundError:
            raise
        except Exception as e:
            raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e
            raise ValueError(f"Error reading PDF from {file_path}: {e!s}") from e

        if not text_content:
            content = f"[PDF file with no extractable text: {Path(file_path).name}]"
            content = f"[PDF file with no extractable text: {source_name}]"
        else:
            content = "\n\n".join(text_content)

        return LoaderResult(
            content=content,
            source=str(file_path),
            source=file_path,
            metadata=metadata,
            doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
            doc_id=self.generate_doc_id(source_ref=file_path, content=content),
        )
@@ -14,9 +14,14 @@ from pydantic import (
    field_validator,
    model_validator,
)
from typing_extensions import Self
from typing_extensions import Self, Unpack

from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
from crewai_tools.tools.rag.types import (
    AddDocumentParams,
    ContentItem,
    RagToolConfig,
    VectorDbConfig,
)


def _validate_embedding_config(
@@ -72,6 +77,8 @@ def _validate_embedding_config(


class Adapter(BaseModel, ABC):
    """Abstract base class for RAG adapters."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
@@ -86,8 +93,8 @@ class Adapter(BaseModel, ABC):
    @abstractmethod
    def add(
        self,
        *args: Any,
        **kwargs: Any,
        *args: ContentItem,
        **kwargs: Unpack[AddDocumentParams],
    ) -> None:
        """Add content to the knowledge base."""

@@ -102,7 +109,11 @@ class RagTool(BaseTool):
    ) -> str:
        raise NotImplementedError

    def add(self, *args: Any, **kwargs: Any) -> None:
    def add(
        self,
        *args: ContentItem,
        **kwargs: Unpack[AddDocumentParams],
    ) -> None:
        raise NotImplementedError

    name: str = "Knowledge base"
@@ -207,9 +218,34 @@ class RagTool(BaseTool):

    def add(
        self,
        *args: Any,
        **kwargs: Any,
        *args: ContentItem,
        **kwargs: Unpack[AddDocumentParams],
    ) -> None:
        """Add content to the knowledge base.


        Args:
            *args: Content items to add (strings, paths, or document dicts)
            data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
            path: Path to file or directory, alias to positional arg
            file_path: Alias for path
            metadata: Additional metadata to attach to documents
            url: URL to fetch content from
            website: Website URL to scrape
            github_url: GitHub repository URL
            youtube_url: YouTube video URL
            directory_path: Path to directory

        Examples:
            rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)

            # Keyword argument (documented API)
            rag_tool.add(path="path/to/document.pdf", data_type="file")
            rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")

            # Auto-detect type from extension
            rag_tool.add("path/to/document.pdf")  # auto-detects PDF
        """
        self.adapter.add(*args, **kwargs)

    def _run(
@@ -1,10 +1,50 @@
"""Type definitions for RAG tool configuration."""

from typing import Any, Literal
from pathlib import Path
from typing import Any, Literal, TypeAlias

from crewai.rag.embeddings.types import ProviderSpec
from typing_extensions import TypedDict

from crewai_tools.rag.data_types import DataType


DataTypeStr: TypeAlias = Literal[
    "file",
    "pdf_file",
    "text_file",
    "csv",
    "json",
    "xml",
    "docx",
    "mdx",
    "mysql",
    "postgres",
    "github",
    "directory",
    "website",
    "docs_site",
    "youtube_video",
    "youtube_channel",
    "text",
]

ContentItem: TypeAlias = str | Path | dict[str, Any]


class AddDocumentParams(TypedDict, total=False):
    """Parameters for adding documents to the RAG system."""

    data_type: DataType | DataTypeStr
    metadata: dict[str, Any]
    path: str | Path
    file_path: str | Path
    website: str
    url: str
    github_url: str
    youtube_url: str
    directory_path: str | Path


class VectorDbConfig(TypedDict):
    """Configuration for vector database provider.
@@ -1,21 +0,0 @@
import pytest


def pytest_configure(config):
    """Register custom markers."""
    config.addinivalue_line("markers", "integration: mark test as an integration test")
    config.addinivalue_line("markers", "asyncio: mark test as an async test")

    # Set the asyncio loop scope through ini configuration
    config.inicfg["asyncio_mode"] = "auto"


@pytest.fixture(scope="function")
def event_loop():
    """Create an instance of the default event loop for each test case."""
    import asyncio

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield loop
    loop.close()
@@ -2,7 +2,7 @@ import json
from unittest import mock

from crewai.tools.base_tool import BaseTool, EnvVar
from generate_tool_specs import ToolSpecExtractor
from crewai_tools.generate_tool_specs import ToolSpecExtractor
from pydantic import BaseModel, Field
import pytest

@@ -61,8 +61,8 @@ def test_unwrap_schema(extractor):
@pytest.fixture
def mock_tool_extractor(extractor):
    with (
        mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
        mock.patch("generate_tool_specs.getattr", return_value=MockTool),
        mock.patch("crewai_tools.generate_tool_specs.dir", return_value=["MockTool"]),
        mock.patch("crewai_tools.generate_tool_specs.getattr", return_value=MockTool),
    ):
        extractor.extract_all_tools()
        assert len(extractor.tools_spec) == 1

@@ -4,7 +4,7 @@ from crewai_tools.tools.firecrawl_crawl_website_tool.firecrawl_crawl_website_too
    FirecrawlCrawlWebsiteTool,
)

@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_firecrawl_crawl_tool_integration():
    tool = FirecrawlCrawlWebsiteTool(config={
        "limit": 2,

@@ -4,7 +4,7 @@ from crewai_tools.tools.firecrawl_scrape_website_tool.firecrawl_scrape_website_t
    FirecrawlScrapeWebsiteTool,
)

@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_firecrawl_scrape_tool_integration():
    tool = FirecrawlScrapeWebsiteTool()
    result = tool.run(url="https://firecrawl.dev")

@@ -3,8 +3,7 @@ import pytest
from crewai_tools.tools.firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_firecrawl_search_tool_integration():
    tool = FirecrawlSearchTool()
    result = tool.run(query="firecrawl")
lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py (new file, 471 lines)
@@ -0,0 +1,471 @@
"""Tests for RagTool.add() method with various data_type values."""

from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, Mock, patch

import pytest

from crewai_tools.rag.data_types import DataType
from crewai_tools.tools.rag.rag_tool import RagTool


@pytest.fixture
def mock_rag_client() -> MagicMock:
    """Create a mock RAG client for testing."""
    mock_client = MagicMock()
    mock_client.get_or_create_collection = MagicMock(return_value=None)
    mock_client.add_documents = MagicMock(return_value=None)
    mock_client.search = MagicMock(return_value=[])
    return mock_client


@pytest.fixture
def rag_tool(mock_rag_client: MagicMock) -> RagTool:
    """Create a RagTool instance with mocked client."""
    with (
        patch(
            "crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
            return_value=mock_rag_client,
        ),
        patch(
            "crewai_tools.adapters.crewai_rag_adapter.create_client",
            return_value=mock_rag_client,
        ),
    ):
        return RagTool()


class TestDataTypeFileAlias:
    """Tests for data_type='file' alias."""

    def test_file_alias_with_existing_file(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test that data_type='file' works with existing files."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("Test content for file alias.")

            rag_tool.add(path=str(test_file), data_type="file")

            assert mock_rag_client.add_documents.called

    def test_file_alias_with_nonexistent_file_raises_error(
        self, rag_tool: RagTool
    ) -> None:
        """Test that data_type='file' raises FileNotFoundError for missing files."""
        with pytest.raises(FileNotFoundError, match="File does not exist"):
            rag_tool.add(path="nonexistent/path/to/file.pdf", data_type="file")

    def test_file_alias_with_path_keyword(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test that path keyword argument works with data_type='file'."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "document.txt"
            test_file.write_text("Content via path keyword.")

            rag_tool.add(data_type="file", path=str(test_file))

            assert mock_rag_client.add_documents.called

    def test_file_alias_with_file_path_keyword(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test that file_path keyword argument works with data_type='file'."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "document.txt"
            test_file.write_text("Content via file_path keyword.")

            rag_tool.add(data_type="file", file_path=str(test_file))

            assert mock_rag_client.add_documents.called


class TestDataTypeStringValues:
    """Tests for data_type as string values matching DataType enum."""

    def test_pdf_file_string(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type='pdf_file' with existing PDF file."""
        with TemporaryDirectory() as tmpdir:
            # Create a minimal valid PDF file
            test_file = Path(tmpdir) / "test.pdf"
            test_file.write_bytes(
                b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
                b"<<\n/Root 1 0 R\n>>\n%%EOF"
            )

            # Mock the PDF loader to avoid actual PDF parsing
            with patch(
                "crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
            ) as mock_loader:
                mock_loader_instance = MagicMock()
                mock_loader_instance.load.return_value = MagicMock(
                    content="PDF content", metadata={}, doc_id="test-id"
                )
                mock_loader.return_value = mock_loader_instance

                rag_tool.add(path=str(test_file), data_type="pdf_file")

                assert mock_rag_client.add_documents.called

    def test_text_file_string(
        self, rag_tool: RagTool, mock_rag_client: MagicMock
    ) -> None:
        """Test data_type='text_file' with existing text file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.txt"
            test_file.write_text("Plain text content.")

            rag_tool.add(path=str(test_file), data_type="text_file")

            assert mock_rag_client.add_documents.called

    def test_csv_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
        """Test data_type='csv' with existing CSV file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.csv"
            test_file.write_text("name,value\nfoo,1\nbar,2")

            rag_tool.add(path=str(test_file), data_type="csv")

            assert mock_rag_client.add_documents.called

    def test_json_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
        """Test data_type='json' with existing JSON file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.json"
            test_file.write_text('{"key": "value", "items": [1, 2, 3]}')

            rag_tool.add(path=str(test_file), data_type="json")

            assert mock_rag_client.add_documents.called

    def test_xml_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
        """Test data_type='xml' with existing XML file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.xml"
            test_file.write_text('<?xml version="1.0"?><root><item>value</item></root>')

            rag_tool.add(path=str(test_file), data_type="xml")

            assert mock_rag_client.add_documents.called

    def test_mdx_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
        """Test data_type='mdx' with existing MDX file."""
        with TemporaryDirectory() as tmpdir:
            test_file = Path(tmpdir) / "test.mdx"
|
||||
test_file.write_text("# Heading\n\nSome markdown content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="mdx")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_text_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='text' with raw text content."""
|
||||
rag_tool.add("This is raw text content.", data_type="text")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='directory' with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create some files in the directory
|
||||
(Path(tmpdir) / "file1.txt").write_text("Content 1")
|
||||
(Path(tmpdir) / "file2.txt").write_text("Content 2")
|
||||
|
||||
rag_tool.add(path=tmpdir, data_type="directory")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestDataTypeEnumValues:
|
||||
"""Tests for data_type as DataType enum values."""
|
||||
|
||||
def test_datatype_file_enum_with_existing_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE with existing file (auto-detect)."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File enum auto-detect content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_file_enum_with_nonexistent_file_raises_error(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("nonexistent/file.pdf", data_type=DataType.FILE)
|
||||
|
||||
def test_datatype_pdf_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.PDF_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
b"<<\n/Root 1 0 R\n>>\n%%EOF"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
|
||||
) as mock_loader:
|
||||
mock_loader_instance = MagicMock()
|
||||
mock_loader_instance.load.return_value = MagicMock(
|
||||
content="PDF content", metadata={}, doc_id="test-id"
|
||||
)
|
||||
mock_loader.return_value = mock_loader_instance
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.PDF_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Text file content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.TEXT_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT with raw text."""
|
||||
rag_tool.add("Raw text using enum.", data_type=DataType.TEXT)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_directory_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.DIRECTORY with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory file content.")
|
||||
|
||||
rag_tool.add(tmpdir, data_type=DataType.DIRECTORY)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestInvalidDataType:
|
||||
"""Tests for invalid data_type values."""
|
||||
|
||||
def test_invalid_string_data_type_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that invalid string data_type raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid data_type"):
|
||||
rag_tool.add("some content", data_type="invalid_type")
|
||||
|
||||
def test_invalid_data_type_error_message_contains_valid_values(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test that error message lists valid data_type values."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
rag_tool.add("some content", data_type="not_a_type")
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "file" in error_message
|
||||
assert "pdf_file" in error_message
|
||||
assert "text_file" in error_message
|
||||
|
||||
|
||||
class TestFileExistenceValidation:
|
||||
"""Tests for file existence validation."""
|
||||
|
||||
def test_pdf_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent PDF file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.pdf", data_type="pdf_file")
|
||||
|
||||
def test_text_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent text file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.txt", data_type="text_file")
|
||||
|
||||
def test_csv_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent CSV file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.csv", data_type="csv")
|
||||
|
||||
def test_json_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent JSON file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.json", data_type="json")
|
||||
|
||||
def test_xml_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent XML file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.xml", data_type="xml")
|
||||
|
||||
def test_docx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent DOCX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.docx", data_type="docx")
|
||||
|
||||
def test_mdx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent MDX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.mdx", data_type="mdx")
|
||||
|
||||
def test_directory_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent directory raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Directory does not exist"):
|
||||
rag_tool.add(path="nonexistent/directory", data_type="directory")
|
||||
|
||||
|
||||
class TestKeywordArgumentVariants:
|
||||
"""Tests for different keyword argument combinations."""
|
||||
|
||||
def test_positional_argument_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test positional argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Positional arg content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Path keyword content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test file_path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File path keyword content.")
|
||||
|
||||
rag_tool.add(file_path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test directory_path keyword argument."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory content.")
|
||||
|
||||
rag_tool.add(directory_path=tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestAutoDetection:
|
||||
"""Tests for auto-detection of data type from content."""
|
||||
|
||||
def test_auto_detect_nonexistent_file_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that auto-detection raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("path/to/document.pdf")
|
||||
|
||||
def test_auto_detect_txt_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .txt file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.txt"
|
||||
test_file.write_text("Auto-detected text file.")
|
||||
|
||||
# No data_type specified - should auto-detect
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_csv_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .csv file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.csv"
|
||||
test_file.write_text("col1,col2\nval1,val2")
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_json_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .json file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.json"
|
||||
test_file.write_text('{"auto": "detected"}')
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_directory(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of directory type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Auto-detected directory.")
|
||||
|
||||
rag_tool.add(tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_raw_text(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of raw text (non-file content)."""
|
||||
rag_tool.add("Just some raw text content")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestMetadataHandling:
|
||||
"""Tests for metadata handling with data_type."""
|
||||
|
||||
def test_metadata_passed_to_documents(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that metadata is properly passed to documents."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Content with metadata.")
|
||||
|
||||
rag_tool.add(
|
||||
path=str(test_file),
|
||||
data_type="text_file",
|
||||
metadata={"custom_key": "custom_value"},
|
||||
)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
call_args = mock_rag_client.add_documents.call_args
|
||||
documents = call_args.kwargs.get("documents", call_args.args[0] if call_args.args else [])
|
||||
|
||||
# Check that at least one document has the custom metadata
|
||||
assert any(
|
||||
doc.get("metadata", {}).get("custom_key") == "custom_value"
|
||||
for doc in documents
|
||||
)
|
||||
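The 471-line test module above exercises RagTool.add() with the "file" alias, explicit string and DataType enum values, keyword-argument variants, and auto-detection. A condensed usage sketch of the API those tests cover (paths are illustrative; per the tests, missing files raise FileNotFoundError and unknown types raise ValueError):

from crewai_tools.rag.data_types import DataType
from crewai_tools.tools.rag.rag_tool import RagTool

tool = RagTool()
tool.add("Just some raw text content")                          # auto-detected as raw text
tool.add(path="docs/report.pdf", data_type="pdf_file")          # explicit string data_type
tool.add(file_path="notes.txt", data_type=DataType.TEXT_FILE)   # DataType enum value
tool.add(directory_path="docs/")                                 # ingest a whole directory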
@@ -23,15 +23,13 @@ from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
import pytest
|
||||
|
||||
|
||||
pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adapter():
|
||||
mock_adapter = MagicMock(spec=Adapter)
|
||||
return mock_adapter
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_directory_search_tool():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
test_file = Path(temp_dir) / "test.txt"
|
||||
@@ -65,6 +63,7 @@ def test_pdf_search_tool(mock_adapter):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_txt_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
|
||||
temp_file.write(b"This is a test file for txt search")
|
||||
@@ -102,6 +101,7 @@ def test_docx_search_tool(mock_adapter):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_json_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
|
||||
temp_file.write(b'{"test": "This is a test JSON file"}')
|
||||
@@ -127,6 +127,7 @@ def test_xml_search_tool(mock_adapter):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_csv_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
|
||||
temp_file.write(b"name,description\ntest,This is a test CSV file")
|
||||
@@ -141,6 +142,7 @@ def test_csv_search_tool():
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_mdx_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as temp_file:
|
||||
temp_file.write(b"# Test MDX\nThis is a test MDX file")
|
||||
|
||||
@@ -9,35 +9,36 @@ authors = [
requires-python = ">=3.10, <3.14"
dependencies = [
    # Core Dependencies
    "pydantic>=2.11.9",
    "openai>=1.13.3",
    "pydantic~=2.11.9",
    "openai~=1.83.0",
    "instructor>=1.3.3",
    # Text Processing
    "pdfplumber>=0.11.4",
    "regex>=2024.9.11",
    "pdfplumber~=0.11.4",
    "regex~=2024.9.11",
    # Telemetry and Monitoring
    "opentelemetry-api>=1.30.0",
    "opentelemetry-sdk>=1.30.0",
    "opentelemetry-exporter-otlp-proto-http>=1.30.0",
    "opentelemetry-api~=1.34.0",
    "opentelemetry-sdk~=1.34.0",
    "opentelemetry-exporter-otlp-proto-http~=1.34.0",
    # Data Handling
    "chromadb~=1.1.0",
    "tokenizers>=0.20.3",
    "openpyxl>=3.1.5",
    "tokenizers~=0.20.3",
    "openpyxl~=3.1.5",
    # Authentication and Security
    "python-dotenv>=1.1.1",
    "pyjwt>=2.9.0",
    "python-dotenv~=1.1.1",
    "pyjwt~=2.9.0",
    # Configuration and Utils
    "click>=8.1.7",
    "appdirs>=1.4.4",
    "jsonref>=1.1.0",
    "json-repair==0.25.2",
    "uv>=0.4.25",
    "tomli-w>=1.1.0",
    "tomli>=2.0.2",
    "json5>=0.10.0",
    "portalocker==2.7.0",
    "pydantic-settings>=2.10.1",
    "mcp>=1.16.0",
    "click~=8.1.7",
    "appdirs~=1.4.4",
    "jsonref~=1.1.0",
    "json-repair~=0.25.2",
    "tomli-w~=1.1.0",
    "tomli~=2.0.2",
    "json5~=0.10.0",
    "portalocker~=2.7.0",
    "pydantic-settings~=2.10.1",
    "mcp~=1.16.0",
    "uv~=0.9.13",
    "aiosqlite~=0.21.0",
]

[project.urls]
@@ -48,55 +49,54 @@ Repository = "https://github.com/crewAIInc/crewAI"

[project.optional-dependencies]
tools = [
    "crewai-tools==1.6.0",
    "crewai-tools==1.7.0",
]
embeddings = [
    "tiktoken~=0.8.0"
]
pdfplumber = [
    "pdfplumber>=0.11.4",
]
pandas = [
    "pandas>=2.2.3",
    "pandas~=2.2.3",
]
openpyxl = [
    "openpyxl>=3.1.5",
    "openpyxl~=3.1.5",
]
mem0 = ["mem0ai>=0.1.94"]
mem0 = ["mem0ai~=0.1.94"]
docling = [
    "docling>=2.12.0",
    "docling~=2.63.0",
]
qdrant = [
    "qdrant-client[fastembed]>=1.14.3",
    "qdrant-client[fastembed]~=1.14.3",
]
aws = [
    "boto3>=1.40.38",
    "boto3~=1.40.38",
    "aiobotocore~=2.25.2",
]
watson = [
    "ibm-watsonx-ai>=1.3.39",
    "ibm-watsonx-ai~=1.3.39",
]
voyageai = [
    "voyageai>=0.3.5",
    "voyageai~=0.3.5",
]
litellm = [
    "litellm>=1.74.9",
    "litellm~=1.74.9",
]
bedrock = [
    "boto3>=1.40.45",
    "boto3~=1.40.45",
]
google-genai = [
    "google-genai>=1.2.0",
    "google-genai~=1.2.0",
]
azure-ai-inference = [
    "azure-ai-inference>=1.0.0b9",
    "azure-ai-inference~=1.0.0b9",
]
anthropic = [
    "anthropic>=0.69.0",
    "anthropic~=0.71.0",
]
a2a = [
    "a2a-sdk~=0.3.10",
    "httpx-auth>=0.23.1",
    "httpx-sse>=0.4.0",
    "httpx-auth~=0.23.1",
    "httpx-sse~=0.4.0",
    "aiocache[redis,memcached]~=0.12.3",
]

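The dependency changes above mostly replace open-ended ">=" pins with the compatible-release operator "~=", which keeps patch updates within one release series. A small sketch of what "~=" accepts, using the packaging library (version numbers are illustrative):

from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=2.11.9")   # equivalent to >=2.11.9, ==2.11.*
print("2.11.10" in spec)          # True: later patch releases stay in range
print("2.12.0" in spec)           # False: the next minor series is excluded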
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:

_suppress_pydantic_deprecation_warnings()

__version__ = "1.6.0"
__version__ = "1.7.0"
_telemetry_submitted = False

4
lib/crewai/src/crewai/a2a/extensions/__init__.py
Normal file
@@ -0,0 +1,4 @@
"""A2A Protocol Extensions for CrewAI.

This module contains extensions to the A2A (Agent-to-Agent) protocol.
"""
193
lib/crewai/src/crewai/a2a/extensions/base.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Base extension interface for A2A wrapper integrations.
|
||||
|
||||
This module defines the protocol for extending A2A wrapper functionality
|
||||
with custom logic for conversation processing, prompt augmentation, and
|
||||
agent response handling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Message
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
|
||||
|
||||
class ConversationState(Protocol):
|
||||
"""Protocol for extension-specific conversation state.
|
||||
|
||||
Extensions can define their own state classes that implement this protocol
|
||||
to track conversation-specific data extracted from message history.
|
||||
"""
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if the state indicates readiness for some action.
|
||||
|
||||
Returns:
|
||||
True if the state is ready, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class A2AExtension(Protocol):
|
||||
"""Protocol for A2A wrapper extensions.
|
||||
|
||||
Extensions can implement this protocol to inject custom logic into
|
||||
the A2A conversation flow at various integration points.
|
||||
"""
|
||||
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
"""Inject extension-specific tools into the agent.
|
||||
|
||||
Called when an agent is wrapped with A2A capabilities. Extensions
|
||||
can add tools that enable extension-specific functionality.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to inject tools into.
|
||||
"""
|
||||
...
|
||||
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> ConversationState | None:
|
||||
"""Extract extension-specific state from conversation history.
|
||||
|
||||
Called during prompt augmentation to allow extensions to analyze
|
||||
the conversation history and extract relevant state information.
|
||||
|
||||
Args:
|
||||
conversation_history: The sequence of A2A messages exchanged.
|
||||
|
||||
Returns:
|
||||
Extension-specific conversation state, or None if no relevant state.
|
||||
"""
|
||||
...
|
||||
|
||||
def augment_prompt(
|
||||
self,
|
||||
base_prompt: str,
|
||||
conversation_state: ConversationState | None,
|
||||
) -> str:
|
||||
"""Augment the task prompt with extension-specific instructions.
|
||||
|
||||
Called during prompt augmentation to allow extensions to add
|
||||
custom instructions based on conversation state.
|
||||
|
||||
Args:
|
||||
base_prompt: The base prompt to augment.
|
||||
conversation_state: Extension-specific state from extract_state_from_history.
|
||||
|
||||
Returns:
|
||||
The augmented prompt with extension-specific instructions.
|
||||
"""
|
||||
...
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
agent_response: Any,
|
||||
conversation_state: ConversationState | None,
|
||||
) -> Any:
|
||||
"""Process and potentially modify the agent response.
|
||||
|
||||
Called after parsing the agent's response, allowing extensions to
|
||||
enhance or modify the response based on conversation state.
|
||||
|
||||
Args:
|
||||
agent_response: The parsed agent response.
|
||||
conversation_state: Extension-specific state from extract_state_from_history.
|
||||
|
||||
Returns:
|
||||
The processed agent response (may be modified or original).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ExtensionRegistry:
|
||||
"""Registry for managing A2A extensions.
|
||||
|
||||
Maintains a collection of extensions and provides methods to invoke
|
||||
their hooks at various integration points.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the extension registry."""
|
||||
self._extensions: list[A2AExtension] = []
|
||||
|
||||
def register(self, extension: A2AExtension) -> None:
|
||||
"""Register an extension.
|
||||
|
||||
Args:
|
||||
extension: The extension to register.
|
||||
"""
|
||||
self._extensions.append(extension)
|
||||
|
||||
def inject_all_tools(self, agent: Agent) -> None:
|
||||
"""Inject tools from all registered extensions.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to inject tools into.
|
||||
"""
|
||||
for extension in self._extensions:
|
||||
extension.inject_tools(agent)
|
||||
|
||||
def extract_all_states(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> dict[type[A2AExtension], ConversationState]:
|
||||
"""Extract conversation states from all registered extensions.
|
||||
|
||||
Args:
|
||||
conversation_history: The sequence of A2A messages exchanged.
|
||||
|
||||
Returns:
|
||||
Mapping of extension types to their conversation states.
|
||||
"""
|
||||
states: dict[type[A2AExtension], ConversationState] = {}
|
||||
for extension in self._extensions:
|
||||
state = extension.extract_state_from_history(conversation_history)
|
||||
if state is not None:
|
||||
states[type(extension)] = state
|
||||
return states
|
||||
|
||||
def augment_prompt_with_all(
|
||||
self,
|
||||
base_prompt: str,
|
||||
extension_states: dict[type[A2AExtension], ConversationState],
|
||||
) -> str:
|
||||
"""Augment prompt with instructions from all registered extensions.
|
||||
|
||||
Args:
|
||||
base_prompt: The base prompt to augment.
|
||||
extension_states: Mapping of extension types to conversation states.
|
||||
|
||||
Returns:
|
||||
The fully augmented prompt.
|
||||
"""
|
||||
augmented = base_prompt
|
||||
for extension in self._extensions:
|
||||
state = extension_states.get(type(extension))
|
||||
augmented = extension.augment_prompt(augmented, state)
|
||||
return augmented
|
||||
|
||||
def process_response_with_all(
|
||||
self,
|
||||
agent_response: Any,
|
||||
extension_states: dict[type[A2AExtension], ConversationState],
|
||||
) -> Any:
|
||||
"""Process response through all registered extensions.
|
||||
|
||||
Args:
|
||||
agent_response: The parsed agent response.
|
||||
extension_states: Mapping of extension types to conversation states.
|
||||
|
||||
Returns:
|
||||
The processed agent response.
|
||||
"""
|
||||
processed = agent_response
|
||||
for extension in self._extensions:
|
||||
state = extension_states.get(type(extension))
|
||||
processed = extension.process_response(processed, state)
|
||||
return processed
|
||||
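base.py above defines the A2AExtension protocol and the ExtensionRegistry that fans calls out to every registered extension. A minimal sketch of a conforming extension; the class name PromptNoteExtension is illustrative and not part of this diff:

from crewai.a2a.extensions.base import ExtensionRegistry

class PromptNoteExtension:
    """Illustrative extension: appends one instruction to the task prompt."""

    def inject_tools(self, agent):
        pass  # this sketch adds no extra tools

    def extract_state_from_history(self, conversation_history):
        return None  # no conversation state is tracked

    def augment_prompt(self, base_prompt, conversation_state):
        return base_prompt + "\nAlways state which remote A2A agent you delegated to."

    def process_response(self, agent_response, conversation_state):
        return agent_response  # pass the parsed response through unchanged

registry = ExtensionRegistry()
registry.register(PromptNoteExtension())  # satisfies the A2AExtension protocol structurally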
34
lib/crewai/src/crewai/a2a/extensions/registry.py
Normal file
@@ -0,0 +1,34 @@
"""Extension registry factory for A2A configurations.

This module provides utilities for creating extension registries from A2A configurations.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from crewai.a2a.extensions.base import ExtensionRegistry


if TYPE_CHECKING:
    from crewai.a2a.config import A2AConfig


def create_extension_registry_from_config(
    a2a_config: list[A2AConfig] | A2AConfig,
) -> ExtensionRegistry:
    """Create an extension registry from A2A configuration.

    Args:
        a2a_config: A2A configuration (single or list)

    Returns:
        Configured extension registry with all applicable extensions
    """
    registry = ExtensionRegistry()
    configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]

    for _ in configs:
        pass

    return registry
@@ -23,6 +23,8 @@ from a2a.types import (
    TextPart,
    TransportProtocol,
)
from aiocache import cached  # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer  # type: ignore[import-untyped]
import httpx
from pydantic import BaseModel, Field, create_model

@@ -65,7 +67,7 @@ def _fetch_agent_card_cached(
        endpoint: A2A agent endpoint URL
        auth_hash: Hash of the auth object
        timeout: Request timeout
        _ttl_hash: Time-based hash for cache invalidation (unused in body)
        _ttl_hash: Time-based hash for cache invalidation

    Returns:
        Cached AgentCard
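The _ttl_hash parameter documented above is the usual time-bucket trick: a value derived from the current time is folded into the cache key so entries expire even though the cache itself has no TTL. A standalone sketch of the pattern with functools.lru_cache (names are illustrative):

import time
from functools import lru_cache

@lru_cache(maxsize=32)
def _cached_fetch(endpoint: str, _ttl_hash: int) -> str:
    # Stand-in for the expensive network call; _ttl_hash only varies the cache key.
    return f"agent card for {endpoint}"

def fetch(endpoint: str, cache_ttl: int = 300) -> str:
    # Calls within the same cache_ttl window share one bucket, so the cached
    # entry is effectively invalidated once the bucket value changes.
    return _cached_fetch(endpoint, int(time.time() // cache_ttl))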
@@ -106,7 +108,18 @@ fetch_agent_card(
        A2AClientHTTPError: If authentication fails
    """
    if use_cache:
        auth_hash = hash((type(auth).__name__, id(auth))) if auth else 0
        if auth:
            auth_data = auth.model_dump_json(
                exclude={
                    "_access_token",
                    "_token_expires_at",
                    "_refresh_token",
                    "_authorization_callback",
                }
            )
            auth_hash = hash((type(auth).__name__, auth_data))
        else:
            auth_hash = 0
        _auth_store[auth_hash] = auth
        ttl_hash = int(time.time() // cache_ttl)
        return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
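The hunk above replaces the identity-based id(auth) component of the cache key with a content-based one, so two equal auth configurations now map to the same cached AgentCard. The same hashing pattern in isolation, using a hypothetical Pydantic model:

from pydantic import BaseModel

class ApiKeyAuth(BaseModel):
    # Hypothetical auth scheme, only here to illustrate the hashing pattern.
    header_name: str = "X-API-Key"
    api_key: str = "secret"

auth = ApiKeyAuth()
auth_data = auth.model_dump_json(exclude={"api_key"})  # drop secret/volatile fields
auth_hash = hash((type(auth).__name__, auth_data))     # same value for equal configs within one process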
@@ -121,6 +134,26 @@ def fetch_agent_card(
|
||||
loop.close()
|
||||
|
||||
|
||||
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
|
||||
async def _fetch_agent_card_async_cached(
|
||||
endpoint: str,
|
||||
auth_hash: int,
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Cached async implementation of AgentCard fetching.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL
|
||||
auth_hash: Hash of the auth object
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Cached AgentCard object
|
||||
"""
|
||||
auth = _auth_store.get(auth_hash)
|
||||
return await _fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
|
||||
|
||||
async def _fetch_agent_card_async(
|
||||
endpoint: str,
|
||||
auth: AuthScheme | None,
|
||||
@@ -339,7 +372,22 @@ async def _execute_a2a_delegation_async(
|
||||
Returns:
|
||||
Dictionary with status, result/error, and new history
|
||||
"""
|
||||
agent_card = await _fetch_agent_card_async(endpoint, auth, timeout)
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
"_access_token",
|
||||
"_token_expires_at",
|
||||
"_refresh_token",
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = hash((type(auth).__name__, auth_data))
|
||||
else:
|
||||
auth_hash = 0
|
||||
_auth_store[auth_hash] = auth
|
||||
agent_card = await _fetch_agent_card_async_cached(
|
||||
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
|
||||
)
|
||||
|
||||
validate_auth_against_agent_card(agent_card, auth)
|
||||
|
||||
@@ -556,6 +604,34 @@ async def _execute_a2a_delegation_async(
|
||||
}
|
||||
break
|
||||
except Exception as e:
|
||||
if isinstance(e, A2AClientHTTPError):
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
agent_role=agent_role,
|
||||
),
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": error_msg,
|
||||
"history": new_messages,
|
||||
}
|
||||
|
||||
current_exception: Exception | BaseException | None = e
|
||||
while current_exception:
|
||||
if hasattr(current_exception, "response"):
|
||||
@@ -752,4 +828,5 @@ def get_a2a_agents_and_response_model(
|
||||
Tuple of A2A agent IDs and response model
|
||||
"""
|
||||
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
|
||||
|
||||
return a2a_agents, create_agent_response_model(agent_ids)
|
||||
|
||||
@@ -15,6 +15,7 @@ from a2a.types import Role
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.a2a.extensions.base import ExtensionRegistry
|
||||
from crewai.a2a.templates import (
|
||||
AVAILABLE_AGENTS_TEMPLATE,
|
||||
CONVERSATION_TURN_INFO_TEMPLATE,
|
||||
@@ -42,7 +43,9 @@ if TYPE_CHECKING:
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
def wrap_agent_with_a2a_instance(agent: Agent) -> None:
|
||||
def wrap_agent_with_a2a_instance(
|
||||
agent: Agent, extension_registry: ExtensionRegistry | None = None
|
||||
) -> None:
|
||||
"""Wrap an agent instance's execute_task method with A2A support.
|
||||
|
||||
This function modifies the agent instance by wrapping its execute_task
|
||||
@@ -51,7 +54,13 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:
|
||||
|
||||
Args:
|
||||
agent: The agent instance to wrap
|
||||
extension_registry: Optional registry of A2A extensions for injecting tools and custom logic
|
||||
"""
|
||||
if extension_registry is None:
|
||||
extension_registry = ExtensionRegistry()
|
||||
|
||||
extension_registry.inject_all_tools(agent)
|
||||
|
||||
original_execute_task = agent.execute_task.__func__ # type: ignore[attr-defined]
|
||||
|
||||
@wraps(original_execute_task)
|
||||
@@ -85,6 +94,7 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:
|
||||
agent_response_model=agent_response_model,
|
||||
context=context,
|
||||
tools=tools,
|
||||
extension_registry=extension_registry,
|
||||
)
|
||||
|
||||
object.__setattr__(agent, "execute_task", MethodType(execute_task_with_a2a, agent))
|
||||
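With the new optional parameter shown above, callers can hand a pre-built registry to the wrapper. A brief sketch, reusing the illustrative PromptNoteExtension from the earlier sketch and assuming agent is an already-constructed crewai Agent (the import path of wrap_agent_with_a2a_instance is not shown in this diff):

registry = ExtensionRegistry()
registry.register(PromptNoteExtension())
wrap_agent_with_a2a_instance(agent, extension_registry=registry)  # injects tools from every registered extension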
@@ -154,6 +164,7 @@ def _execute_task_with_a2a(
|
||||
agent_response_model: type[BaseModel],
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
extension_registry: ExtensionRegistry,
|
||||
) -> str:
|
||||
"""Wrap execute_task with A2A delegation logic.
|
||||
|
||||
@@ -165,6 +176,7 @@ def _execute_task_with_a2a(
|
||||
context: Optional context for task execution
|
||||
tools: Optional tools available to the agent
|
||||
agent_response_model: Optional agent response model
|
||||
extension_registry: Registry of A2A extensions
|
||||
|
||||
Returns:
|
||||
Task execution result (either from LLM or A2A agent)
|
||||
@@ -190,11 +202,12 @@ def _execute_task_with_a2a(
|
||||
finally:
|
||||
task.description = original_description
|
||||
|
||||
task.description = _augment_prompt_with_a2a(
|
||||
task.description, _ = _augment_prompt_with_a2a(
|
||||
a2a_agents=a2a_agents,
|
||||
task_description=original_description,
|
||||
agent_cards=agent_cards,
|
||||
failed_agents=failed_agents,
|
||||
extension_registry=extension_registry,
|
||||
)
|
||||
task.response_model = agent_response_model
|
||||
|
||||
@@ -204,6 +217,11 @@ def _execute_task_with_a2a(
|
||||
raw_result=raw_result, agent_response_model=agent_response_model
|
||||
)
|
||||
|
||||
if extension_registry and isinstance(agent_response, BaseModel):
|
||||
agent_response = extension_registry.process_response_with_all(
|
||||
agent_response, {}
|
||||
)
|
||||
|
||||
if isinstance(agent_response, BaseModel) and isinstance(
|
||||
agent_response, AgentResponseProtocol
|
||||
):
|
||||
@@ -217,6 +235,7 @@ def _execute_task_with_a2a(
|
||||
tools=tools,
|
||||
agent_cards=agent_cards,
|
||||
original_task_description=original_description,
|
||||
extension_registry=extension_registry,
|
||||
)
|
||||
return str(agent_response.message)
|
||||
|
||||
@@ -235,7 +254,8 @@ def _augment_prompt_with_a2a(
|
||||
turn_num: int = 0,
|
||||
max_turns: int | None = None,
|
||||
failed_agents: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
extension_registry: ExtensionRegistry | None = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Add A2A delegation instructions to prompt.
|
||||
|
||||
Args:
|
||||
@@ -246,13 +266,14 @@ def _augment_prompt_with_a2a(
|
||||
turn_num: Current turn number (0-indexed)
|
||||
max_turns: Maximum allowed turns (from config)
|
||||
failed_agents: Dictionary mapping failed agent endpoints to error messages
|
||||
extension_registry: Optional registry of A2A extensions
|
||||
|
||||
Returns:
|
||||
Augmented task description with A2A instructions
|
||||
Tuple of (augmented prompt, disable_structured_output flag)
|
||||
"""
|
||||
|
||||
if not agent_cards:
|
||||
return task_description
|
||||
return task_description, False
|
||||
|
||||
agents_text = ""
|
||||
|
||||
@@ -270,6 +291,7 @@ def _augment_prompt_with_a2a(
|
||||
agents_text = AVAILABLE_AGENTS_TEMPLATE.substitute(available_a2a_agents=agents_text)
|
||||
|
||||
history_text = ""
|
||||
|
||||
if conversation_history:
|
||||
for msg in conversation_history:
|
||||
history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n"
|
||||
@@ -277,6 +299,15 @@ def _augment_prompt_with_a2a(
|
||||
history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute(
|
||||
previous_a2a_conversation=history_text
|
||||
)
|
||||
|
||||
extension_states = {}
|
||||
disable_structured_output = False
|
||||
if extension_registry and conversation_history:
|
||||
extension_states = extension_registry.extract_all_states(conversation_history)
|
||||
for state in extension_states.values():
|
||||
if state.is_ready():
|
||||
disable_structured_output = True
|
||||
break
|
||||
turn_info = ""
|
||||
|
||||
if max_turns is not None and conversation_history:
|
||||
@@ -296,16 +327,22 @@ def _augment_prompt_with_a2a(
|
||||
warning=warning,
|
||||
)
|
||||
|
||||
return f"""{task_description}
|
||||
augmented_prompt = f"""{task_description}
|
||||
|
||||
IMPORTANT: You have the ability to delegate this task to remote A2A agents.
|
||||
|
||||
{agents_text}
|
||||
{history_text}{turn_info}
|
||||
|
||||
|
||||
"""
|
||||
|
||||
if extension_registry:
|
||||
augmented_prompt = extension_registry.augment_prompt_with_all(
|
||||
augmented_prompt, extension_states
|
||||
)
|
||||
|
||||
return augmented_prompt, disable_structured_output
|
||||
|
||||
|
||||
def _parse_agent_response(
|
||||
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
|
||||
@@ -373,7 +410,7 @@ def _handle_agent_response_and_continue(
|
||||
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
|
||||
agent_cards_dict[agent_id] = a2a_result["agent_card"]
|
||||
|
||||
task.description = _augment_prompt_with_a2a(
|
||||
task.description, disable_structured_output = _augment_prompt_with_a2a(
|
||||
a2a_agents=a2a_agents,
|
||||
task_description=original_task_description,
|
||||
conversation_history=conversation_history,
|
||||
@@ -382,7 +419,38 @@ def _handle_agent_response_and_continue(
|
||||
agent_cards=agent_cards_dict,
|
||||
)
|
||||
|
||||
original_response_model = task.response_model
|
||||
if disable_structured_output:
|
||||
task.response_model = None
|
||||
|
||||
raw_result = original_fn(self, task, context, tools)
|
||||
|
||||
if disable_structured_output:
|
||||
task.response_model = original_response_model
|
||||
|
||||
if disable_structured_output:
|
||||
final_turn_number = turn_num + 1
|
||||
result_text = str(raw_result)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AMessageSentEvent(
|
||||
message=result_text,
|
||||
turn_number=final_turn_number,
|
||||
is_multiturn=True,
|
||||
agent_role=self.role,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConversationCompletedEvent(
|
||||
status="completed",
|
||||
final_result=result_text,
|
||||
error=None,
|
||||
total_turns=final_turn_number,
|
||||
),
|
||||
)
|
||||
return result_text, None
|
||||
|
||||
llm_response = _parse_agent_response(
|
||||
raw_result=raw_result, agent_response_model=agent_response_model
|
||||
)
|
||||
@@ -425,6 +493,7 @@ def _delegate_to_a2a(
|
||||
tools: list[BaseTool] | None,
|
||||
agent_cards: dict[str, AgentCard] | None = None,
|
||||
original_task_description: str | None = None,
|
||||
extension_registry: ExtensionRegistry | None = None,
|
||||
) -> str:
|
||||
"""Delegate to A2A agent with multi-turn conversation support.
|
||||
|
||||
@@ -437,6 +506,7 @@ def _delegate_to_a2a(
|
||||
tools: Optional tools available to the agent
|
||||
agent_cards: Pre-fetched agent cards from _execute_task_with_a2a
|
||||
original_task_description: The original task description before A2A augmentation
|
||||
extension_registry: Optional registry of A2A extensions
|
||||
|
||||
Returns:
|
||||
Result from A2A agent
|
||||
@@ -447,9 +517,13 @@ def _delegate_to_a2a(
|
||||
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
|
||||
agent_ids = tuple(config.endpoint for config in a2a_agents)
|
||||
current_request = str(agent_response.message)
|
||||
agent_id = agent_response.a2a_ids[0]
|
||||
|
||||
if agent_id not in agent_ids:
|
||||
if hasattr(agent_response, "a2a_ids") and agent_response.a2a_ids:
|
||||
agent_id = agent_response.a2a_ids[0]
|
||||
else:
|
||||
agent_id = agent_ids[0] if agent_ids else ""
|
||||
|
||||
if agent_id and agent_id not in agent_ids:
|
||||
raise ValueError(
|
||||
f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}"
|
||||
)
|
||||
@@ -458,10 +532,11 @@ def _delegate_to_a2a(
|
||||
task_config = task.config or {}
|
||||
context_id = task_config.get("context_id")
|
||||
task_id_config = task_config.get("task_id")
|
||||
reference_task_ids = task_config.get("reference_task_ids")
|
||||
metadata = task_config.get("metadata")
|
||||
extensions = task_config.get("extensions")
|
||||
|
||||
reference_task_ids = task_config.get("reference_task_ids", [])
|
||||
|
||||
if original_task_description is None:
|
||||
original_task_description = task.description
|
||||
|
||||
@@ -497,11 +572,27 @@ def _delegate_to_a2a(
|
||||
|
||||
conversation_history = a2a_result.get("history", [])
|
||||
|
||||
if conversation_history:
|
||||
latest_message = conversation_history[-1]
|
||||
if latest_message.task_id is not None:
|
||||
task_id_config = latest_message.task_id
|
||||
if latest_message.context_id is not None:
|
||||
context_id = latest_message.context_id
|
||||
|
||||
if a2a_result["status"] in ["completed", "input_required"]:
|
||||
if (
|
||||
a2a_result["status"] == "completed"
|
||||
and agent_config.trust_remote_completion_status
|
||||
):
|
||||
if (
|
||||
task_id_config is not None
|
||||
and task_id_config not in reference_task_ids
|
||||
):
|
||||
reference_task_ids.append(task_id_config)
|
||||
if task.config is None:
|
||||
task.config = {}
|
||||
task.config["reference_task_ids"] = reference_task_ids
|
||||
|
||||
result_text = a2a_result.get("result", "")
|
||||
final_turn_number = turn_num + 1
|
||||
crewai_event_bus.emit(
|
||||
@@ -513,7 +604,7 @@ def _delegate_to_a2a(
|
||||
total_turns=final_turn_number,
|
||||
),
|
||||
)
|
||||
return result_text # type: ignore[no-any-return]
|
||||
return cast(str, result_text)
|
||||
|
||||
final_result, next_request = _handle_agent_response_and_continue(
|
||||
self=self,
|
||||
@@ -541,6 +632,31 @@ def _delegate_to_a2a(
|
||||
continue
|
||||
|
||||
error_msg = a2a_result.get("error", "Unknown error")
|
||||
|
||||
final_result, next_request = _handle_agent_response_and_continue(
|
||||
self=self,
|
||||
a2a_result=a2a_result,
|
||||
agent_id=agent_id,
|
||||
agent_cards=agent_cards,
|
||||
a2a_agents=a2a_agents,
|
||||
original_task_description=original_task_description,
|
||||
conversation_history=conversation_history,
|
||||
turn_num=turn_num,
|
||||
max_turns=max_turns,
|
||||
task=task,
|
||||
original_fn=original_fn,
|
||||
context=context,
|
||||
tools=tools,
|
||||
agent_response_model=agent_response_model,
|
||||
)
|
||||
|
||||
if final_result is not None:
|
||||
return final_result
|
||||
|
||||
if next_request is not None:
|
||||
current_request = next_request
|
||||
continue
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConversationCompletedEvent(
|
||||
@@ -550,7 +666,7 @@ def _delegate_to_a2a(
|
||||
total_turns=turn_num + 1,
|
||||
),
|
||||
)
|
||||
raise Exception(f"A2A delegation failed: {error_msg}")
|
||||
return f"A2A delegation failed: {error_msg}"
|
||||
|
||||
if conversation_history:
|
||||
for msg in reversed(conversation_history):
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
@@ -19,6 +18,19 @@ from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.agent.utils import (
|
||||
ahandle_knowledge_retrieval,
|
||||
apply_training_data,
|
||||
build_task_prompt_with_schema,
|
||||
format_task_with_context,
|
||||
get_knowledge_config,
|
||||
handle_knowledge_retrieval,
|
||||
handle_reasoning,
|
||||
prepare_tools,
|
||||
process_tool_results,
|
||||
save_last_messages,
|
||||
validate_max_execution_time,
|
||||
)
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
@@ -27,9 +39,6 @@ from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalCompletedEvent,
|
||||
@@ -37,7 +46,6 @@ from crewai.events.types.memory_events import (
|
||||
)
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.mcp import (
|
||||
@@ -61,7 +69,7 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import Converter, generate_model_description
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.guardrail_types import GuardrailType
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.prompts import Prompts
|
||||
@@ -295,53 +303,15 @@ class Agent(BaseAgent):
|
||||
ValueError: If the max execution time is not a positive integer.
|
||||
RuntimeError: If the agent execution fails for other reasons.
|
||||
"""
|
||||
if self.reasoning:
|
||||
try:
|
||||
from crewai.utilities.reasoning_handler import (
|
||||
AgentReasoning,
|
||||
AgentReasoningOutput,
|
||||
)
|
||||
|
||||
reasoning_handler = AgentReasoning(task=task, agent=self)
|
||||
reasoning_output: AgentReasoningOutput = (
|
||||
reasoning_handler.handle_agent_reasoning()
|
||||
)
|
||||
|
||||
# Add the reasoning plan to the task description
|
||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||
handle_reasoning(self, task)
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = None
|
||||
|
||||
task_prompt = task.prompt()
|
||||
|
||||
# If the task requires output in JSON or Pydantic format,
|
||||
# append specific instructions to the task prompt to ensure
|
||||
# that the final answer does not include any code block markers
|
||||
# Skip this if task.response_model is set, as native structured outputs handle schema automatically
|
||||
if (task.output_json or task.output_pydantic) and not task.response_model:
|
||||
# Generate the schema based on the output format
|
||||
if task.output_json:
|
||||
schema_dict = generate_model_description(task.output_json)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
elif task.output_pydantic:
|
||||
schema_dict = generate_model_description(task.output_pydantic)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
task=task_prompt, context=context
|
||||
)
|
||||
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
|
||||
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
|
||||
|
||||
if self._is_any_available_memory():
|
||||
crewai_event_bus.emit(
|
||||
@@ -379,84 +349,20 @@ class Agent(BaseAgent):
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
knowledge_config = (
|
||||
self.knowledge_config.model_dump() if self.knowledge_config else {}
|
||||
|
||||
knowledge_config = get_knowledge_config(self)
|
||||
task_prompt = handle_knowledge_retrieval(
|
||||
self,
|
||||
task,
|
||||
task_prompt,
|
||||
knowledge_config,
|
||||
self.knowledge.query if self.knowledge else lambda *a, **k: None,
|
||||
self.crew.query_knowledge if self.crew else lambda *a, **k: None,
|
||||
)
|
||||
|
||||
if self.knowledge or (self.crew and self.crew.knowledge):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
try:
|
||||
self.knowledge_search_query = self._get_knowledge_search_query(
|
||||
task_prompt, task
|
||||
)
|
||||
if self.knowledge_search_query:
|
||||
# Quering agent specific knowledge
|
||||
if self.knowledge:
|
||||
agent_knowledge_snippets = self.knowledge.query(
|
||||
[self.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
self.agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if self.agent_knowledge_context:
|
||||
task_prompt += self.agent_knowledge_context
|
||||
prepare_tools(self, tools, task)
|
||||
task_prompt = apply_training_data(self, task_prompt)
|
||||
|
||||
# Quering crew specific knowledge
|
||||
knowledge_snippets = self.crew.query_knowledge(
|
||||
[self.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if knowledge_snippets:
|
||||
self.crew_knowledge_context = extract_knowledge_context(
|
||||
knowledge_snippets
|
||||
)
|
||||
if self.crew_knowledge_context:
|
||||
task_prompt += self.crew_knowledge_context
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeRetrievalCompletedEvent(
|
||||
query=self.knowledge_search_query,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
retrieved_knowledge=(
|
||||
(self.agent_knowledge_context or "")
|
||||
+ (
|
||||
"\n"
|
||||
if self.agent_knowledge_context
|
||||
and self.crew_knowledge_context
|
||||
else ""
|
||||
)
|
||||
+ (self.crew_knowledge_context or "")
|
||||
),
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeSearchQueryFailedEvent(
|
||||
query=self.knowledge_search_query or "",
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
|
||||
tools = tools or self.tools or []
|
||||
self.create_agent_executor(tools=tools, task=task)
|
||||
|
||||
if self.crew and self.crew._train:
|
||||
task_prompt = self._training_handler(task_prompt=task_prompt)
|
||||
else:
|
||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
# Import agent events locally to avoid circular imports
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
@@ -474,15 +380,8 @@ class Agent(BaseAgent):
|
||||
),
|
||||
)
|
||||
|
||||
# Determine execution method based on timeout setting
|
||||
validate_max_execution_time(self.max_execution_time)
|
||||
if self.max_execution_time is not None:
|
||||
if (
|
||||
not isinstance(self.max_execution_time, int)
|
||||
or self.max_execution_time <= 0
|
||||
):
|
||||
raise ValueError(
|
||||
"Max Execution time must be a positive integer greater than zero"
|
||||
)
|
||||
result = self._execute_with_timeout(
|
||||
task_prompt, task, self.max_execution_time
|
||||
)
|
||||
@@ -490,7 +389,6 @@ class Agent(BaseAgent):
|
||||
result = self._execute_without_timeout(task_prompt, task)
|
||||
|
||||
except TimeoutError as e:
|
||||
# Propagate TimeoutError without retry
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
@@ -502,7 +400,6 @@ class Agent(BaseAgent):
|
||||
raise e
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
# Do not retry on litellm errors
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
@@ -528,23 +425,13 @@ class Agent(BaseAgent):
|
||||
if self.max_rpm and self._rpm_controller:
|
||||
self._rpm_controller.stop_rpm_counter()
|
||||
|
||||
# If there was any tool in self.tools_results that had result_as_answer
|
||||
# set to True, return the results of the last tool that had
|
||||
# result_as_answer set to True
|
||||
for tool_result in self.tools_results:
|
||||
if tool_result.get("result_as_answer", False):
|
||||
result = tool_result["result"]
|
||||
result = process_tool_results(self, result)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||
)
|
||||
|
||||
self._last_messages = (
|
||||
self.agent_executor.messages.copy()
|
||||
if self.agent_executor and hasattr(self.agent_executor, "messages")
|
||||
else []
|
||||
)
|
||||
|
||||
save_last_messages(self)
|
||||
self._cleanup_mcp_clients()
|
||||
|
||||
return result
|
||||
@@ -604,6 +491,208 @@ class Agent(BaseAgent):
|
||||
}
|
||||
)["output"]
|
||||
|
||||
async def aexecute_task(
|
||||
self,
|
||||
task: Task,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> Any:
|
||||
"""Execute a task with the agent asynchronously.
|
||||
|
||||
Args:
|
||||
task: Task to execute.
|
||||
context: Context to execute the task in.
|
||||
tools: Tools to use for the task.
|
||||
|
||||
Returns:
|
||||
Output of the agent.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If execution exceeds the maximum execution time.
|
||||
ValueError: If the max execution time is not a positive integer.
|
||||
RuntimeError: If the agent execution fails for other reasons.
|
||||
"""
|
||||
handle_reasoning(self, task)
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = None
|
||||
|
||||
task_prompt = task.prompt()
|
||||
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
|
||||
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
|
||||
|
||||
if self._is_any_available_memory():
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryRetrievalStartedEvent(
|
||||
task_id=str(task.id) if task else None,
|
||||
source_type="agent",
|
||||
from_agent=self,
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
self.crew._external_memory,
|
||||
agent=self,
|
||||
task=task,
|
||||
)
|
||||
memory = await contextual_memory.abuild_context_for_task(
|
||||
task, context or ""
|
||||
)
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryRetrievalCompletedEvent(
|
||||
task_id=str(task.id) if task else None,
|
||||
memory_content=memory,
|
||||
retrieval_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="agent",
|
||||
from_agent=self,
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
|
||||
knowledge_config = get_knowledge_config(self)
|
||||
task_prompt = await ahandle_knowledge_retrieval(
|
||||
self, task, task_prompt, knowledge_config
|
||||
)
|
||||
|
||||
prepare_tools(self, tools, task)
|
||||
task_prompt = apply_training_data(self, task_prompt)
|
||||
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionStartedEvent(
|
||||
agent=self,
|
||||
tools=self.tools,
|
||||
task_prompt=task_prompt,
|
||||
task=task,
|
||||
),
|
||||
)
|
||||
|
||||
validate_max_execution_time(self.max_execution_time)
|
||||
if self.max_execution_time is not None:
|
||||
result = await self._aexecute_with_timeout(
|
||||
task_prompt, task, self.max_execution_time
|
||||
)
|
||||
else:
|
||||
result = await self._aexecute_without_timeout(task_prompt, task)
|
||||
|
||||
except TimeoutError as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
agent=self,
|
||||
task=task,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
agent=self,
|
||||
task=task,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
self._times_executed += 1
|
||||
if self._times_executed > self.max_retry_limit:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
agent=self,
|
||||
task=task,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
result = await self.aexecute_task(task, context, tools)
|
||||
|
||||
if self.max_rpm and self._rpm_controller:
|
||||
self._rpm_controller.stop_rpm_counter()
|
||||
|
||||
result = process_tool_results(self, result)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||
)
|
||||
|
||||
save_last_messages(self)
|
||||
self._cleanup_mcp_clients()
|
||||
|
||||
return result
|
||||
|
||||
async def _aexecute_with_timeout(
|
||||
self, task_prompt: str, task: Task, timeout: int
|
||||
) -> Any:
|
||||
"""Execute a task with a timeout asynchronously.
|
||||
|
||||
Args:
|
||||
task_prompt: The prompt to send to the agent.
|
||||
task: The task being executed.
|
||||
timeout: Maximum execution time in seconds.
|
||||
|
||||
Returns:
|
||||
The output of the agent.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If execution exceeds the timeout.
|
||||
RuntimeError: If execution fails for other reasons.
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._aexecute_without_timeout(task_prompt, task),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError(
|
||||
f"Task '{task.description}' execution timed out after {timeout} seconds. "
|
||||
"Consider increasing max_execution_time or optimizing the task."
|
||||
) from e
|
||||
|
||||
async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
|
||||
"""Execute a task without a timeout asynchronously.
|
||||
|
||||
Args:
|
||||
task_prompt: The prompt to send to the agent.
|
||||
task: The task being executed.
|
||||
|
||||
Returns:
|
||||
The output of the agent.
|
||||
"""
|
||||
if not self.agent_executor:
|
||||
raise RuntimeError("Agent executor is not initialized.")
|
||||
|
||||
result = await self.agent_executor.ainvoke(
|
||||
{
|
||||
"input": task_prompt,
|
||||
"tool_names": self.agent_executor.tools_names,
|
||||
"tools": self.agent_executor.tools_description,
|
||||
"ask_for_human_input": task.human_input,
|
||||
}
|
||||
)
|
||||
return result["output"]
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: list[BaseTool] | None = None, task: Task | None = None
|
||||
) -> None:
|
||||
@@ -633,7 +722,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
self.agent_executor = CrewAgentExecutor(
|
||||
llm=self.llm,
|
||||
llm=self.llm, # type: ignore[arg-type]
|
||||
task=task, # type: ignore[arg-type]
|
||||
agent=self,
|
||||
crew=self.crew,
|
||||
@@ -810,6 +899,7 @@ class Agent(BaseAgent):
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
@@ -903,10 +993,10 @@ class Agent(BaseAgent):
|
||||
server_name=server_name,
|
||||
run_context=None,
|
||||
)
|
||||
if mcp_config.tool_filter(context, tool):
|
||||
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
except (TypeError, AttributeError):
|
||||
if mcp_config.tool_filter(tool):
|
||||
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
# Not callable - include tool
|
||||
@@ -981,7 +1071,9 @@ class Agent(BaseAgent):
|
||||
path = parsed.path.replace("/", "_").strip("_")
|
||||
return f"{domain}_{path}" if path else domain
|
||||
|
||||
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
|
||||
def _get_mcp_tool_schemas(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Get tool schemas from MCP server for wrapper creation with caching."""
|
||||
server_url = server_params["url"]
|
||||
|
||||
@@ -995,7 +1087,7 @@ class Agent(BaseAgent):
|
||||
self._logger.log(
|
||||
"debug", f"Using cached MCP tool schemas for {server_url}"
|
||||
)
|
||||
return cached_data
|
||||
return cached_data # type: ignore[no-any-return]
|
||||
|
||||
try:
|
||||
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
||||
@@ -1013,7 +1105,7 @@ class Agent(BaseAgent):
|
||||
|
||||
async def _get_mcp_tool_schemas_async(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict]:
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
|
||||
server_url = server_params["url"]
|
||||
return await self._retry_mcp_discovery(
|
||||
@@ -1021,7 +1113,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
async def _retry_mcp_discovery(
|
||||
self, operation_func, server_url: str
|
||||
self, operation_func: Any, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
|
||||
last_error = None
|
||||
@@ -1052,7 +1144,7 @@ class Agent(BaseAgent):
|
||||
|
||||
@staticmethod
|
||||
async def _attempt_mcp_discovery(
|
||||
operation_func, server_url: str
|
||||
operation_func: Any, server_url: str
|
||||
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
||||
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
|
||||
try:
|
||||
@@ -1142,7 +1234,7 @@ class Agent(BaseAgent):
|
||||
properties = json_schema.get("properties", {})
|
||||
required_fields = json_schema.get("required", [])
|
||||
|
||||
field_definitions = {}
|
||||
field_definitions: dict[str, Any] = {}
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
field_type = self._json_type_to_python(field_schema)
|
||||
@@ -1162,7 +1254,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
||||
return create_model(model_name, **field_definitions)
|
||||
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
|
||||
|
||||
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
|
||||
"""Convert JSON Schema type to Python type.
|
||||
@@ -1177,7 +1269,7 @@ class Agent(BaseAgent):
|
||||
json_type = field_schema.get("type")
|
||||
|
||||
if "anyOf" in field_schema:
|
||||
types = []
|
||||
types: list[type] = []
|
||||
for option in field_schema["anyOf"]:
|
||||
if "const" in option:
|
||||
types.append(str)
|
||||
@@ -1185,13 +1277,13 @@ class Agent(BaseAgent):
|
||||
types.append(self._json_type_to_python(option))
|
||||
unique_types = list(set(types))
|
||||
if len(unique_types) > 1:
|
||||
result = unique_types[0]
|
||||
result: Any = unique_types[0]
|
||||
for t in unique_types[1:]:
|
||||
result = result | t
|
||||
return result
|
||||
return result # type: ignore[no-any-return]
|
||||
return unique_types[0]
|
||||
|
||||
type_mapping = {
|
||||
type_mapping: dict[str | None, type] = {
|
||||
"string": str,
|
||||
"number": float,
|
||||
"integer": int,
|
||||
@@ -1203,7 +1295,7 @@ class Agent(BaseAgent):
|
||||
return type_mapping.get(json_type, Any)
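
To make the anyOf branch above concrete, here is a self-contained sketch of the same mapping (the helper name and the trimmed-down table are illustrative, and the "const" handling from the diff is omitted):

from typing import Any

type_mapping: dict[str | None, type] = {
    "string": str,
    "number": float,
    "integer": int,
    "boolean": bool,
    "array": list,
    "object": dict,
    "null": type(None),
}

def json_type_to_python(field_schema: dict[str, Any]) -> Any:
    if "anyOf" in field_schema:
        options = [json_type_to_python(o) for o in field_schema["anyOf"]]
        result: Any = options[0]
        for t in options[1:]:
            result = result | t  # builds a PEP 604 union such as str | int
        return result
    return type_mapping.get(field_schema.get("type"), Any)

print(json_type_to_python({"anyOf": [{"type": "string"}, {"type": "integer"}]}))  # str | int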
|
||||
|
||||
@staticmethod
|
||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
|
||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
|
||||
"""Fetch MCP server configurations from CrewAI AOP API."""
|
||||
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
||||
# Should return list of server configs with URLs
|
||||
@@ -1438,11 +1530,11 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
if self.apps:
|
||||
platform_tools = self.get_platform_tools(self.apps)
|
||||
if platform_tools:
|
||||
if platform_tools and self.tools is not None:
|
||||
self.tools.extend(platform_tools)
|
||||
if self.mcps:
|
||||
mcps = self.get_mcp_tools(self.mcps)
|
||||
if mcps:
|
||||
if mcps and self.tools is not None:
|
||||
self.tools.extend(mcps)
|
||||
|
||||
lite_agent = LiteAgent(
|
||||
|
||||
@@ -4,9 +4,8 @@ This metaclass enables extension capabilities for agents by detecting
|
||||
extension fields in class annotations and applying appropriate wrappers.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic._internal._model_construction import ModelMetaclass
|
||||
@@ -59,9 +58,15 @@ class AgentMeta(ModelMetaclass):
|
||||
|
||||
a2a_value = getattr(self, "a2a", None)
|
||||
if a2a_value is not None:
|
||||
from crewai.a2a.extensions.registry import (
|
||||
create_extension_registry_from_config,
|
||||
)
|
||||
from crewai.a2a.wrapper import wrap_agent_with_a2a_instance
|
||||
|
||||
wrap_agent_with_a2a_instance(self)
|
||||
extension_registry = create_extension_registry_from_config(
|
||||
a2a_value
|
||||
)
|
||||
wrap_agent_with_a2a_instance(self, extension_registry)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
355
lib/crewai/src/crewai/agent/utils.py
Normal file
355
lib/crewai/src/crewai/agent/utils.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Utility functions for agent task execution.
|
||||
|
||||
This module contains shared logic extracted from the Agent's execute_task
|
||||
and aexecute_task methods to reduce code duplication.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
|
||||
def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||
"""Handle the reasoning process for an agent before task execution.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task to execute.
|
||||
"""
|
||||
if not agent.reasoning:
|
||||
return
|
||||
|
||||
try:
|
||||
from crewai.utilities.reasoning_handler import (
|
||||
AgentReasoning,
|
||||
AgentReasoningOutput,
|
||||
)
|
||||
|
||||
reasoning_handler = AgentReasoning(task=task, agent=agent)
|
||||
reasoning_output: AgentReasoningOutput = (
|
||||
reasoning_handler.handle_agent_reasoning()
|
||||
)
|
||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||
except Exception as e:
|
||||
agent._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||
|
||||
|
||||
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
    """Build task prompt with JSON/Pydantic schema instructions if applicable.

    Args:
        task: The task being executed.
        task_prompt: The initial task prompt.
        i18n: Internationalization instance.

    Returns:
        The task prompt potentially augmented with schema instructions.
    """
    if (task.output_json or task.output_pydantic) and not task.response_model:
        if task.output_json:
            schema_dict = generate_model_description(task.output_json)
            schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
            task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
                output_format=schema
            )
        elif task.output_pydantic:
            schema_dict = generate_model_description(task.output_pydantic)
            schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
            task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
                output_format=schema
            )
    return task_prompt
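
A hypothetical illustration of what this helper appends when a task declares output_pydantic (the Report model and the instruction wording are invented; the real wording comes from the "formatted_task_instructions" i18n slice and the generate_model_description payload):

import json

from pydantic import BaseModel

class Report(BaseModel):
    title: str
    score: float

schema = json.dumps(Report.model_json_schema(), indent=2)
task_prompt = "Write the quarterly report."
task_prompt += "\n" + f"Ensure your final answer matches this schema:\n{schema}"
print(task_prompt)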
|
||||
|
||||
|
||||
def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
    """Format task prompt with context if provided.

    Args:
        task_prompt: The task prompt.
        context: Optional context string.
        i18n: Internationalization instance.

    Returns:
        The task prompt formatted with context if provided.
    """
    if context:
        return i18n.slice("task_with_context").format(task=task_prompt, context=context)
    return task_prompt


def get_knowledge_config(agent: Agent) -> dict[str, Any]:
    """Get knowledge configuration from agent.

    Args:
        agent: The agent instance.

    Returns:
        Dictionary of knowledge configuration.
    """
    return agent.knowledge_config.model_dump() if agent.knowledge_config else {}
|
||||
|
||||
|
||||
def handle_knowledge_retrieval(
|
||||
agent: Agent,
|
||||
task: Task,
|
||||
task_prompt: str,
|
||||
knowledge_config: dict[str, Any],
|
||||
query_func: Any,
|
||||
crew_query_func: Any,
|
||||
) -> str:
|
||||
"""Handle knowledge retrieval for task execution.
|
||||
|
||||
This function handles both agent-specific and crew-specific knowledge queries.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task being executed.
|
||||
task_prompt: The current task prompt.
|
||||
knowledge_config: Knowledge configuration dictionary.
|
||||
query_func: Function to query agent knowledge (sync or async).
|
||||
crew_query_func: Function to query crew knowledge (sync or async).
|
||||
|
||||
Returns:
|
||||
The task prompt potentially augmented with knowledge context.
|
||||
"""
|
||||
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||
return task_prompt
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
try:
|
||||
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||
task_prompt, task
|
||||
)
|
||||
if agent.knowledge_search_query:
|
||||
if agent.knowledge:
|
||||
agent_knowledge_snippets = query_func(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
agent.agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if agent.agent_knowledge_context:
|
||||
task_prompt += agent.agent_knowledge_context
|
||||
|
||||
knowledge_snippets = crew_query_func(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if knowledge_snippets:
|
||||
agent.crew_knowledge_context = extract_knowledge_context(
|
||||
knowledge_snippets
|
||||
)
|
||||
if agent.crew_knowledge_context:
|
||||
task_prompt += agent.crew_knowledge_context
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalCompletedEvent(
|
||||
query=agent.knowledge_search_query,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeSearchQueryFailedEvent(
|
||||
query=agent.knowledge_search_query or "",
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
|
||||
def _combine_knowledge_context(agent: Agent) -> str:
    """Combine agent and crew knowledge contexts into a single string.

    Args:
        agent: The agent with knowledge contexts.

    Returns:
        Combined knowledge context string.
    """
    agent_ctx = agent.agent_knowledge_context or ""
    crew_ctx = agent.crew_knowledge_context or ""
    separator = "\n" if agent_ctx and crew_ctx else ""
    return agent_ctx + separator + crew_ctx
|
||||
|
||||
|
||||
def apply_training_data(agent: Agent, task_prompt: str) -> str:
    """Apply training data to the task prompt.

    Args:
        agent: The agent performing the task.
        task_prompt: The task prompt.

    Returns:
        The task prompt with training data applied.
    """
    if agent.crew and agent.crew._train:
        return agent._training_handler(task_prompt=task_prompt)
    return agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
|
||||
def process_tool_results(agent: Agent, result: Any) -> Any:
    """Process tool results, returning result_as_answer if applicable.

    Args:
        agent: The agent with tool results.
        result: The current result.

    Returns:
        The final result, potentially overridden by tool result_as_answer.
    """
    for tool_result in agent.tools_results:
        if tool_result.get("result_as_answer", False):
            result = tool_result["result"]
    return result
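
A toy illustration of the override rule implemented above (the dict shape mirrors agent.tools_results; the values are invented):

tools_results = [
    {"result": "intermediate scrape", "result_as_answer": False},
    {"result": "final formatted report", "result_as_answer": True},
]

result = "answer drafted by the LLM"
for tool_result in tools_results:
    if tool_result.get("result_as_answer", False):
        result = tool_result["result"]

print(result)  # "final formatted report": the tool output replaces the LLM answer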
|
||||
|
||||
|
||||
def save_last_messages(agent: Agent) -> None:
    """Save the last messages from agent executor.

    Args:
        agent: The agent instance.
    """
    agent._last_messages = (
        agent.agent_executor.messages.copy()
        if agent.agent_executor and hasattr(agent.agent_executor, "messages")
        else []
    )


def prepare_tools(
    agent: Agent, tools: list[BaseTool] | None, task: Task
) -> list[BaseTool]:
    """Prepare tools for task execution and create agent executor.

    Args:
        agent: The agent instance.
        tools: Optional list of tools.
        task: The task being executed.

    Returns:
        The list of tools to use.
    """
    final_tools = tools or agent.tools or []
    agent.create_agent_executor(tools=final_tools, task=task)
    return final_tools
|
||||
|
||||
|
||||
def validate_max_execution_time(max_execution_time: int | None) -> None:
    """Validate max_execution_time parameter.

    Args:
        max_execution_time: The maximum execution time to validate.

    Raises:
        ValueError: If max_execution_time is not a positive integer.
    """
    if max_execution_time is not None:
        if not isinstance(max_execution_time, int) or max_execution_time <= 0:
            raise ValueError(
                "Max Execution time must be a positive integer greater than zero"
            )
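
A quick sketch of the contract (the import path is inferred from the new file's location, lib/crewai/src/crewai/agent/utils.py):

from crewai.agent.utils import validate_max_execution_time

validate_max_execution_time(None)   # no limit configured: no-op
validate_max_execution_time(120)    # valid positive integer: returns None
try:
    validate_max_execution_time(0)  # invalid: not a positive integer
except ValueError as err:
    print(err)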
|
||||
|
||||
|
||||
async def ahandle_knowledge_retrieval(
|
||||
agent: Agent,
|
||||
task: Task,
|
||||
task_prompt: str,
|
||||
knowledge_config: dict[str, Any],
|
||||
) -> str:
|
||||
"""Handle async knowledge retrieval for task execution.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task being executed.
|
||||
task_prompt: The current task prompt.
|
||||
knowledge_config: Knowledge configuration dictionary.
|
||||
|
||||
Returns:
|
||||
The task prompt potentially augmented with knowledge context.
|
||||
"""
|
||||
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||
return task_prompt
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
try:
|
||||
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||
task_prompt, task
|
||||
)
|
||||
if agent.knowledge_search_query:
|
||||
if agent.knowledge:
|
||||
agent_knowledge_snippets = await agent.knowledge.aquery(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
agent.agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if agent.agent_knowledge_context:
|
||||
task_prompt += agent.agent_knowledge_context
|
||||
|
||||
knowledge_snippets = await agent.crew.aquery_knowledge(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if knowledge_snippets:
|
||||
agent.crew_knowledge_context = extract_knowledge_context(
|
||||
knowledge_snippets
|
||||
)
|
||||
if agent.crew_knowledge_context:
|
||||
task_prompt += agent.crew_knowledge_context
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalCompletedEvent(
|
||||
query=agent.knowledge_search_query,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeSearchQueryFailedEvent(
|
||||
query=agent.knowledge_search_query or "",
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
return task_prompt
|
||||
@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
if not mcps:
|
||||
return mcps
|
||||
|
||||
validated_mcps = []
|
||||
validated_mcps: list[str | MCPServerConfig] = []
|
||||
for mcp in mcps:
|
||||
if isinstance(mcp, str):
|
||||
if mcp.startswith(("https://", "crewai-amp:")):
|
||||
@@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def aexecute_task(
|
||||
self,
|
||||
task: Any,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
"""Execute a task asynchronously."""
|
||||
|
||||
@abstractmethod
|
||||
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
||||
pass
|
||||
|
||||
@@ -28,6 +28,7 @@ from crewai.hooks.llm_hooks import (
|
||||
get_before_llm_call_hooks,
|
||||
)
|
||||
from crewai.utilities.agent_utils import (
|
||||
aget_llm_response,
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
get_llm_response,
|
||||
@@ -43,7 +44,10 @@ from crewai.utilities.agent_utils import (
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.tool_utils import (
|
||||
aexecute_tool_and_check_finality,
|
||||
execute_tool_and_check_finality,
|
||||
)
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
|
||||
@@ -134,8 +138,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.messages: list[LLMMessage] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.before_llm_call_hooks: list[Callable] = []
|
||||
self.after_llm_call_hooks: list[Callable] = []
|
||||
self.before_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.after_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
if self.llm:
|
||||
@@ -312,6 +316,154 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the agent asynchronously with given inputs.
|
||||
|
||||
Args:
|
||||
inputs: Input dictionary containing prompt variables.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent output.
|
||||
"""
|
||||
if "system" in self.prompt:
|
||||
system_prompt = self._format_prompt(
|
||||
cast(str, self.prompt.get("system", "")), inputs
|
||||
)
|
||||
user_prompt = self._format_prompt(
|
||||
cast(str, self.prompt.get("user", "")), inputs
|
||||
)
|
||||
self.messages.append(format_message_for_llm(system_prompt, role="system"))
|
||||
self.messages.append(format_message_for_llm(user_prompt))
|
||||
else:
|
||||
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
|
||||
self.messages.append(format_message_for_llm(user_prompt))
|
||||
|
||||
self._show_start_logs()
|
||||
|
||||
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
|
||||
|
||||
try:
|
||||
formatted_answer = await self._ainvoke_loop()
|
||||
except AssertionError:
|
||||
self._printer.print(
|
||||
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
|
||||
color="red",
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise
|
||||
|
||||
if self.ask_for_human_input:
|
||||
formatted_answer = self._handle_human_feedback(formatted_answer)
|
||||
|
||||
self._create_short_term_memory(formatted_answer)
|
||||
self._create_long_term_memory(formatted_answer)
|
||||
self._create_external_memory(formatted_answer)
|
||||
return {"output": formatted_answer.output}
|
||||
|
||||
async def _ainvoke_loop(self) -> AgentFinish:
|
||||
"""Execute agent loop asynchronously until completion.
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
formatted_answer = None
|
||||
while not isinstance(formatted_answer, AgentFinish):
|
||||
try:
|
||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
formatted_answer,
|
||||
printer=self._printer,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
break
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
answer = await aget_llm_response(
|
||||
llm=self.llm,
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
executor_context=self,
|
||||
)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
fingerprint_context = {}
|
||||
if (
|
||||
self.agent
|
||||
and hasattr(self.agent, "security_config")
|
||||
and hasattr(self.agent.security_config, "fingerprint")
|
||||
):
|
||||
fingerprint_context = {
|
||||
"agent_fingerprint": str(
|
||||
self.agent.security_config.fingerprint
|
||||
)
|
||||
}
|
||||
|
||||
tool_result = await aexecute_tool_and_check_finality(
|
||||
agent_action=formatted_answer,
|
||||
fingerprint_context=fingerprint_context,
|
||||
tools=self.tools,
|
||||
i18n=self._i18n,
|
||||
agent_key=self.agent.key if self.agent else None,
|
||||
agent_role=self.agent.role if self.agent else None,
|
||||
tools_handler=self.tools_handler,
|
||||
task=self.task,
|
||||
agent=self.agent,
|
||||
function_calling_llm=self.function_calling_llm,
|
||||
crew=self.crew,
|
||||
)
|
||||
formatted_answer = self._handle_agent_action(
|
||||
formatted_answer, tool_result
|
||||
)
|
||||
|
||||
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
|
||||
self._append_message(formatted_answer.text) # type: ignore[union-attr,attr-defined]
|
||||
|
||||
except OutputParserError as e:
|
||||
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
|
||||
e=e,
|
||||
messages=self.messages,
|
||||
iterations=self.iterations,
|
||||
log_error_after=self.log_error_after,
|
||||
printer=self._printer,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
raise e
|
||||
if is_context_length_exceeded(e):
|
||||
handle_context_length(
|
||||
respect_context_window=self.respect_context_window,
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
)
|
||||
continue
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
finally:
|
||||
self.iterations += 1
|
||||
|
||||
if not isinstance(formatted_answer, AgentFinish):
|
||||
raise RuntimeError(
|
||||
"Agent execution ended without reaching a final answer. "
|
||||
f"Got {type(formatted_answer).__name__} instead of AgentFinish."
|
||||
)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
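
A schematic of how the async executor entry point above is driven (executor construction is elided and my_executor is a stand-in; the input keys mirror the ones _aexecute_without_timeout passes in the Agent diff earlier):

import asyncio

async def drive(executor, prompt: str) -> str:
    # ainvoke formats the prompt, then loops in _ainvoke_loop until an AgentFinish is produced
    result = await executor.ainvoke(
        {
            "input": prompt,
            "tool_names": executor.tools_names,
            "tools": executor.tools_description,
            "ask_for_human_input": False,
        }
    )
    return result["output"]

# asyncio.run(drive(my_executor, "Summarize the findings"))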
|
||||
|
||||
def _handle_agent_action(
|
||||
self, formatted_answer: AgentAction, tool_result: ToolResult
|
||||
) -> AgentAction | AgentFinish:
|
||||
|
||||
@@ -73,6 +73,7 @@ CLI_SETTINGS_KEYS = [
|
||||
"oauth2_audience",
|
||||
"oauth2_client_id",
|
||||
"oauth2_domain",
|
||||
"oauth2_extra",
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
@@ -82,6 +83,7 @@ DEFAULT_CLI_SETTINGS = {
|
||||
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
"oauth2_extra": {},
|
||||
}
|
||||
|
||||
# Readonly settings - cannot be set by the user
|
||||
|
||||
@@ -14,7 +14,8 @@ import tomli
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.crew import Crew
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.types.crew_chat import ChatInputField, ChatInputs
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
@@ -27,7 +28,7 @@ MIN_REQUIRED_VERSION: Final[Literal["0.98.0"]] = "0.98.0"
|
||||
|
||||
|
||||
def check_conversational_crews_version(
|
||||
crewai_version: str, pyproject_data: dict
|
||||
crewai_version: str, pyproject_data: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the installed crewAI version supports conversational crews.
|
||||
@@ -53,7 +54,7 @@ def check_conversational_crews_version(
|
||||
return True
|
||||
|
||||
|
||||
def run_chat():
|
||||
def run_chat() -> None:
|
||||
"""
|
||||
Runs an interactive chat loop using the Crew's chat LLM with function calling.
|
||||
Incorporates crew_name, crew_description, and input fields to build a tool schema.
|
||||
@@ -101,7 +102,7 @@ def run_chat():
|
||||
|
||||
click.secho(f"Assistant: {introductory_message}\n", fg="green")
|
||||
|
||||
messages = [
|
||||
messages: list[LLMMessage] = [
|
||||
{"role": "system", "content": system_message},
|
||||
{"role": "assistant", "content": introductory_message},
|
||||
]
|
||||
@@ -113,7 +114,7 @@ def run_chat():
|
||||
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
|
||||
|
||||
|
||||
def show_loading(event: threading.Event):
|
||||
def show_loading(event: threading.Event) -> None:
|
||||
"""Display animated loading dots while processing."""
|
||||
while not event.is_set():
|
||||
_printer.print(".", end="")
|
||||
@@ -162,23 +163,23 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
||||
)
|
||||
|
||||
|
||||
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
|
||||
def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
|
||||
"""Creates a wrapper function for running the crew tool with messages."""
|
||||
|
||||
def run_crew_tool_with_messages(**kwargs):
|
||||
def run_crew_tool_with_messages(**kwargs: Any) -> str:
|
||||
return run_crew_tool(crew, messages, **kwargs)
|
||||
|
||||
return run_crew_tool_with_messages
|
||||
|
||||
|
||||
def flush_input():
|
||||
def flush_input() -> None:
|
||||
"""Flush any pending input from the user."""
|
||||
if platform.system() == "Windows":
|
||||
# Windows platform
|
||||
import msvcrt
|
||||
|
||||
while msvcrt.kbhit():
|
||||
msvcrt.getch()
|
||||
while msvcrt.kbhit(): # type: ignore[attr-defined]
|
||||
msvcrt.getch() # type: ignore[attr-defined]
|
||||
else:
|
||||
# Unix-like platforms (Linux, macOS)
|
||||
import termios
|
||||
@@ -186,7 +187,12 @@ def flush_input():
|
||||
termios.tcflush(sys.stdin, termios.TCIFLUSH)
|
||||
|
||||
|
||||
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
|
||||
def chat_loop(
|
||||
chat_llm: LLM | BaseLLM,
|
||||
messages: list[LLMMessage],
|
||||
crew_tool_schema: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
) -> None:
|
||||
"""Main chat loop for interacting with the user."""
|
||||
while True:
|
||||
try:
|
||||
@@ -225,7 +231,7 @@ def get_user_input() -> str:
|
||||
|
||||
def handle_user_input(
|
||||
user_input: str,
|
||||
chat_llm: LLM,
|
||||
chat_llm: LLM | BaseLLM,
|
||||
messages: list[LLMMessage],
|
||||
crew_tool_schema: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
@@ -255,7 +261,7 @@ def handle_user_input(
|
||||
click.secho(f"\nAssistant: {final_response}\n", fg="green")
|
||||
|
||||
|
||||
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
||||
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
|
||||
"""
|
||||
Dynamically build a LiteLLM 'function' schema for the given crew.
|
||||
|
||||
@@ -286,7 +292,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
|
||||
def run_crew_tool(crew: Crew, messages: list[LLMMessage], **kwargs: Any) -> str:
|
||||
"""
|
||||
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
||||
|
||||
@@ -372,7 +378,9 @@ def load_crew_and_name() -> tuple[Crew, str]:
|
||||
return crew_instance, crew_class_name
|
||||
|
||||
|
||||
def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
|
||||
def generate_crew_chat_inputs(
|
||||
crew: Crew, crew_name: str, chat_llm: LLM | BaseLLM
|
||||
) -> ChatInputs:
|
||||
"""
|
||||
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
|
||||
|
||||
@@ -410,23 +418,12 @@ def fetch_required_inputs(crew: Crew) -> set[str]:
|
||||
Returns:
|
||||
Set[str]: A set of placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks
|
||||
for task in crew.tasks:
|
||||
text = f"{task.description or ''} {task.expected_output or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
|
||||
# Scan agents
|
||||
for agent in crew.agents:
|
||||
text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
|
||||
return required_inputs
|
||||
return crew.fetch_inputs()
|
||||
|
||||
|
||||
def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
|
||||
def generate_input_description_with_ai(
|
||||
input_name: str, crew: Crew, chat_llm: LLM | BaseLLM
|
||||
) -> str:
|
||||
"""
|
||||
Generates an input description using AI based on the context of the crew.
|
||||
|
||||
@@ -484,10 +481,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
return response.strip()
|
||||
return str(response).strip()
|
||||
|
||||
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> str:
|
||||
"""
|
||||
Generates a brief description of the crew using AI.
|
||||
|
||||
@@ -534,4 +531,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
return response.strip()
|
||||
return str(response).strip()
|
||||
|
||||
@@ -3,103 +3,56 @@ import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import BinaryIO, cast
|
||||
import tempfile
|
||||
from typing import Final, Literal, cast
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
else:
|
||||
import fcntl
|
||||
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
Initialize the TokenManager class.
|
||||
"""Manages encrypted token storage."""
|
||||
|
||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""Initialize the TokenManager.
|
||||
|
||||
Args:
|
||||
file_path: The file path to store encrypted tokens.
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
@staticmethod
|
||||
def _acquire_lock(file_handle: BinaryIO) -> None:
|
||||
"""
|
||||
Acquire an exclusive lock on a file handle.
|
||||
|
||||
Args:
|
||||
file_handle: Open file handle to lock.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
|
||||
else:
|
||||
fcntl.flock(file_handle.fileno(), fcntl.LOCK_EX)
|
||||
|
||||
@staticmethod
|
||||
def _release_lock(file_handle: BinaryIO) -> None:
|
||||
"""
|
||||
Release the lock on a file handle.
|
||||
|
||||
Args:
|
||||
file_handle: Open file handle to unlock.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
else:
|
||||
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key with file locking to prevent race conditions.
|
||||
"""Get or create the encryption key.
|
||||
|
||||
Returns:
|
||||
The encryption key.
|
||||
The encryption key as bytes.
|
||||
"""
|
||||
key_filename = "secret.key"
|
||||
storage_path = self.get_secure_storage_path()
|
||||
key_filename: str = "secret.key"
|
||||
|
||||
key = self.read_secure_file(key_filename)
|
||||
if key is not None and len(key) == 44:
|
||||
key = self._read_secure_file(key_filename)
|
||||
if key is not None and len(key) == _FERNET_KEY_LENGTH:
|
||||
return key
|
||||
|
||||
lock_file_path = storage_path / f"{key_filename}.lock"
|
||||
|
||||
try:
|
||||
lock_file_path.touch()
|
||||
|
||||
with open(lock_file_path, "r+b") as lock_file:
|
||||
self._acquire_lock(lock_file)
|
||||
try:
|
||||
key = self.read_secure_file(key_filename)
|
||||
if key is not None and len(key) == 44:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
finally:
|
||||
try:
|
||||
self._release_lock(lock_file)
|
||||
except OSError:
|
||||
pass
|
||||
except OSError:
|
||||
key = self.read_secure_file(key_filename)
|
||||
if key is not None and len(key) == 44:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
new_key = Fernet.generate_key()
|
||||
if self._atomic_create_secure_file(key_filename, new_key):
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""
|
||||
Save the access token and its expiration time.
|
||||
key = self._read_secure_file(key_filename)
|
||||
if key is not None and len(key) == _FERNET_KEY_LENGTH:
|
||||
return key
|
||||
|
||||
:param access_token: The access token to save.
|
||||
:param expires_at: The UNIX timestamp of the expiration time.
|
||||
raise RuntimeError("Failed to create or read encryption key")
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""Save the access token and its expiration time.
|
||||
|
||||
Args:
|
||||
access_token: The access token to save.
|
||||
expires_at: The UNIX timestamp of the expiration time.
|
||||
"""
|
||||
expiration_time = datetime.fromtimestamp(expires_at)
|
||||
data = {
|
||||
@@ -107,15 +60,15 @@ class TokenManager:
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
self._atomic_write_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> str | None:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
"""Get the access token if it is valid and not expired.
|
||||
|
||||
:return: The access token if valid and not expired, otherwise None.
|
||||
Returns:
|
||||
The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
encrypted_data = self.read_secure_file(self.file_path)
|
||||
encrypted_data = self._read_secure_file(self.file_path)
|
||||
if encrypted_data is None:
|
||||
return None
|
||||
|
||||
@@ -126,20 +79,18 @@ class TokenManager:
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return cast(str | None, data["access_token"])
|
||||
return cast(str | None, data.get("access_token"))
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
"""
|
||||
Clear the tokens.
|
||||
"""
|
||||
self.delete_secure_file(self.file_path)
|
||||
"""Clear the stored tokens."""
|
||||
self._delete_secure_file(self.file_path)
|
||||
|
||||
@staticmethod
|
||||
def get_secure_storage_path() -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
def _get_secure_storage_path() -> Path:
|
||||
"""Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
Returns:
|
||||
The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
@@ -155,44 +106,81 @@ class TokenManager:
|
||||
|
||||
return storage_path
|
||||
|
||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
Save the content to a secure file.
|
||||
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
|
||||
"""Create a file only if it doesn't exist.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:param content: The content to save.
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
content: The content to write.
|
||||
|
||||
Returns:
|
||||
True if file was created, False if it already exists.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
try:
|
||||
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
|
||||
try:
|
||||
os.write(fd, content)
|
||||
finally:
|
||||
os.close(fd)
|
||||
return True
|
||||
except FileExistsError:
|
||||
return False
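
To spell out the race-handling above: os.open with O_CREAT | O_EXCL guarantees exactly one process creates the key file, while the loser gets FileExistsError and re-reads. A standalone sketch of the same pattern (function name and path handling are mine):

import os

from cryptography.fernet import Fernet

def create_key_once(path: str) -> bool:
    """Return True if this process created the key, False if it already existed."""
    try:
        fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
    except FileExistsError:
        return False  # another process won the race; the caller should re-read the file
    try:
        os.write(fd, Fernet.generate_key())
    finally:
        os.close(fd)
    return True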
|
||||
|
||||
os.chmod(file_path, 0o600)
|
||||
def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""Write content to a secure file.
|
||||
|
||||
def read_secure_file(self, filename: str) -> bytes | None:
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
content: The content to write.
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:return: The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
if not file_path.exists():
|
||||
fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
|
||||
fd_closed = False
|
||||
try:
|
||||
os.write(fd, content)
|
||||
os.close(fd)
|
||||
fd_closed = True
|
||||
os.chmod(temp_path, 0o600)
|
||||
os.replace(temp_path, file_path)
|
||||
except Exception:
|
||||
if not fd_closed:
|
||||
os.close(fd)
|
||||
if os.path.exists(temp_path):
|
||||
os.unlink(temp_path)
|
||||
raise
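
The write path above follows the classic temp-file-plus-rename recipe; here is a compact sketch of the same idea (the helper name is mine, and it assumes the temp file lives on the same filesystem as the target so os.replace stays atomic):

import os
import tempfile

def atomic_write(path: str, data: bytes) -> None:
    directory = os.path.dirname(path) or "."
    fd, tmp = tempfile.mkstemp(dir=directory)
    closed = False
    try:
        os.write(fd, data)
        os.close(fd)
        closed = True
        os.chmod(tmp, 0o600)
        os.replace(tmp, path)  # readers see either the old or the new file, never a partial write
    except Exception:
        if not closed:
            os.close(fd)
        if os.path.exists(tmp):
            os.unlink(tmp)
        raise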
|
||||
|
||||
def _read_secure_file(self, filename: str) -> bytes | None:
|
||||
"""Read the content of a secure file.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
|
||||
Returns:
|
||||
The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
def _delete_secure_file(self, filename: str) -> None:
|
||||
"""Delete a secure file.
|
||||
|
||||
def delete_secure_file(self, filename: str) -> None:
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
"""
|
||||
Delete the secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
if file_path.exists():
|
||||
file_path.unlink(missing_ok=True)
|
||||
try:
|
||||
file_path.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.0"
|
||||
"crewai[tools]==1.7.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.0"
|
||||
"crewai[tools]==1.7.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -35,6 +35,14 @@ from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.crews.utils import (
|
||||
StreamingContext,
|
||||
check_conditional_skip,
|
||||
enable_agent_streaming,
|
||||
prepare_kickoff,
|
||||
prepare_task_execution,
|
||||
run_for_each_async,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
@@ -47,7 +55,6 @@ from crewai.events.listeners.tracing.utils import (
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
@@ -74,7 +81,7 @@ from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
from crewai.types.streaming import CrewStreamingOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
@@ -92,10 +99,8 @@ from crewai.utilities.planning_handler import CrewPlanner
|
||||
from crewai.utilities.printer import PrinterColor
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.streaming import (
|
||||
TaskInfo,
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
create_streaming_state,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
@@ -268,7 +273,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
description="list of file paths for task execution JSON files.",
|
||||
)
|
||||
execution_logs: list[dict[str, Any]] = Field(
|
||||
default=[],
|
||||
default_factory=list,
|
||||
description="list of execution logs for tasks",
|
||||
)
|
||||
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
|
||||
@@ -327,7 +332,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def set_private_attrs(self) -> Crew:
|
||||
"""set private attributes."""
|
||||
self._cache_handler = CacheHandler()
|
||||
event_listener = EventListener() # type: ignore[no-untyped-call]
|
||||
event_listener = EventListener()
|
||||
|
||||
# Determine and set tracing state once for this execution
|
||||
tracing_enabled = should_enable_tracing(override=self.tracing)
|
||||
@@ -348,12 +353,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
def _initialize_default_memories(self) -> None:
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call]
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call]
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory()
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory(
|
||||
crew=self,
|
||||
embedder_config=self.embedder,
|
||||
)
|
||||
self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call]
|
||||
self._entity_memory = self.entity_memory or EntityMemory(
|
||||
crew=self, embedder_config=self.embedder
|
||||
)
|
||||
|
||||
@@ -404,8 +409,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise PydanticCustomError(
|
||||
"missing_manager_llm_or_manager_agent",
|
||||
(
|
||||
"Attribute `manager_llm` or `manager_agent` is required "
|
||||
"when using hierarchical process."
|
||||
"Attribute `manager_llm` or `manager_agent` is required when using hierarchical process."
|
||||
),
|
||||
{},
|
||||
)
|
||||
@@ -511,10 +515,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise PydanticCustomError(
|
||||
"invalid_async_conditional_task",
|
||||
(
|
||||
f"Conditional Task: {task.description}, "
|
||||
f"cannot be executed asynchronously."
|
||||
"Conditional Task: {description}, cannot be executed asynchronously."
|
||||
),
|
||||
{},
|
||||
{"description": task.description},
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -675,21 +678,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
if self.stream:
|
||||
for agent in self.agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
result_holder: list[CrewOutput] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(current_task_info, result_holder)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext()
|
||||
|
||||
def run_crew() -> None:
|
||||
"""Execute the crew and capture the result."""
|
||||
@@ -697,59 +687,28 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self.stream = False
|
||||
crew_result = self.kickoff(inputs=inputs)
|
||||
if isinstance(crew_result, CrewOutput):
|
||||
result_holder.append(crew_result)
|
||||
ctx.result_holder.append(crew_result)
|
||||
except Exception as exc:
|
||||
signal_error(state, exc)
|
||||
signal_error(ctx.state, exc)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state)
|
||||
signal_end(ctx.state)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
sync_iterator=create_chunk_generator(state, run_crew, output_holder)
|
||||
sync_iterator=create_chunk_generator(
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
return streaming_output
|
||||
|
||||
ctx = baggage.set_baggage(
|
||||
baggage_ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
)
|
||||
token = attach(ctx)
|
||||
token = attach(baggage_ctx)
|
||||
|
||||
try:
|
||||
for before_callback in self.before_kickoff_callbacks:
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
inputs = before_callback(inputs)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CrewKickoffStartedEvent(crew_name=self.name, inputs=inputs),
|
||||
)
|
||||
|
||||
# Starts the crew to work on its assigned tasks.
|
||||
self._task_output_handler.reset()
|
||||
self._logging_color = "bold_purple"
|
||||
|
||||
if inputs is not None:
|
||||
self._inputs = inputs
|
||||
self._interpolate_inputs(inputs)
|
||||
self._set_tasks_callbacks()
|
||||
self._set_allow_crewai_trigger_context_for_first_task()
|
||||
|
||||
for agent in self.agents:
|
||||
agent.crew = self
|
||||
agent.set_knowledge(crew_embedder=self.embedder)
|
||||
# TODO: Create an AgentFunctionCalling protocol for future refactoring
|
||||
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
|
||||
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
|
||||
agent.create_agent_executor()
|
||||
|
||||
if self.planning:
|
||||
self._handle_crew_planning()
|
||||
inputs = prepare_kickoff(self, inputs)
|
||||
|
||||
if self.process == Process.sequential:
|
||||
result = self._run_sequential_process()
|
||||
@@ -814,42 +773,27 @@ class Crew(FlowTrackable, BaseModel):
|
||||
inputs = inputs or {}
|
||||
|
||||
if self.stream:
|
||||
for agent in self.agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
result_holder: list[CrewOutput] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=True
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext(use_async=True)
|
||||
|
||||
async def run_crew() -> None:
|
||||
try:
|
||||
self.stream = False
|
||||
result = await asyncio.to_thread(self.kickoff, inputs)
|
||||
if isinstance(result, CrewOutput):
|
||||
result_holder.append(result)
|
||||
ctx.result_holder.append(result)
|
||||
except Exception as e:
|
||||
signal_error(state, e, is_async=True)
|
||||
signal_error(ctx.state, e, is_async=True)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state, is_async=True)
|
||||
signal_end(ctx.state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
state, run_crew, output_holder
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
@@ -864,89 +808,207 @@ class Crew(FlowTrackable, BaseModel):
|
||||
from all crews as they arrive. After iteration, access results via .results
|
||||
(list of CrewOutput).
|
||||
"""
|
||||
crew_copies = [self.copy() for _ in inputs]
|
||||
|
||||
async def kickoff_fn(
|
||||
crew: Crew, input_data: dict[str, Any]
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
return await crew.kickoff_async(inputs=input_data)
|
||||
|
||||
return await run_for_each_async(self, inputs, kickoff_fn)
|
||||
|
||||
async def akickoff(
|
||||
self, inputs: dict[str, Any] | None = None
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
"""Native async kickoff method using async task execution throughout.
|
||||
|
||||
Unlike kickoff_async which wraps sync kickoff in a thread, this method
|
||||
uses native async/await for all operations including task execution,
|
||||
memory operations, and knowledge queries.
|
||||
"""
|
||||
if self.stream:
|
||||
result_holder: list[list[CrewOutput]] = [[]]
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext(use_async=True)
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=True
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
async def run_all_crews() -> None:
|
||||
"""Run all crew copies and aggregate their streaming outputs."""
|
||||
async def run_crew() -> None:
|
||||
try:
|
||||
streaming_outputs: list[CrewStreamingOutput] = []
|
||||
for i, crew in enumerate(crew_copies):
|
||||
streaming = await crew.kickoff_async(inputs=inputs[i])
|
||||
if isinstance(streaming, CrewStreamingOutput):
|
||||
streaming_outputs.append(streaming)
|
||||
|
||||
async def consume_stream(
|
||||
stream_output: CrewStreamingOutput,
|
||||
) -> CrewOutput:
|
||||
"""Consume stream chunks and forward to parent queue.
|
||||
|
||||
Args:
|
||||
stream_output: The streaming output to consume.
|
||||
|
||||
Returns:
|
||||
The final CrewOutput result.
|
||||
"""
|
||||
async for chunk in stream_output:
|
||||
if state.async_queue is not None and state.loop is not None:
|
||||
state.loop.call_soon_threadsafe(
|
||||
state.async_queue.put_nowait, chunk
|
||||
)
|
||||
return stream_output.result
|
||||
|
||||
crew_results = await asyncio.gather(
|
||||
*[consume_stream(s) for s in streaming_outputs]
|
||||
)
|
||||
result_holder[0] = list(crew_results)
|
||||
except Exception as e:
|
||||
signal_error(state, e, is_async=True)
|
||||
self.stream = False
|
||||
inner_result = await self.akickoff(inputs)
|
||||
if isinstance(inner_result, CrewOutput):
|
||||
ctx.result_holder.append(inner_result)
|
||||
except Exception as exc:
|
||||
signal_error(ctx.state, exc, is_async=True)
|
||||
finally:
|
||||
signal_end(state, is_async=True)
|
||||
self.stream = True
|
||||
signal_end(ctx.state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
state, run_all_crews, output_holder
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
|
||||
def set_results_wrapper(result: Any) -> None:
|
||||
"""Wrap _set_results to match _set_result signature."""
|
||||
streaming_output._set_results(result)
|
||||
|
||||
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||
output_holder.append(streaming_output)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(crew_copy.kickoff_async(inputs=input_data))
|
||||
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
|
||||
]
|
||||
baggage_ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
)
|
||||
token = attach(baggage_ctx)
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
try:
|
||||
inputs = prepare_kickoff(self, inputs)
|
||||
|
||||
total_usage_metrics = UsageMetrics()
|
||||
for crew_copy in crew_copies:
|
||||
if crew_copy.usage_metrics:
|
||||
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
|
||||
self.usage_metrics = total_usage_metrics
|
||||
if self.process == Process.sequential:
|
||||
result = await self._arun_sequential_process()
|
||||
elif self.process == Process.hierarchical:
|
||||
result = await self._arun_hierarchical_process()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The process '{self.process}' is not implemented yet."
|
||||
)
|
||||
|
||||
self._task_output_handler.reset()
|
||||
return list(results)
|
||||
for after_callback in self.after_kickoff_callbacks:
|
||||
result = after_callback(result)
|
||||
|
||||
self.usage_metrics = self.calculate_usage_metrics()
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
async def akickoff_for_each(
|
||||
self, inputs: list[dict[str, Any]]
|
||||
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
|
||||
"""Native async execution of the Crew's workflow for each input.
|
||||
|
||||
Uses native async throughout rather than thread-based async.
|
||||
If stream=True, returns a single CrewStreamingOutput that yields chunks
|
||||
from all crews as they arrive.
|
||||
"""
|
||||
|
||||
async def kickoff_fn(
|
||||
crew: Crew, input_data: dict[str, Any]
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
return await crew.akickoff(inputs=input_data)
|
||||
|
||||
return await run_for_each_async(self, inputs, kickoff_fn)
|
||||
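When stream=True, the for_each variants return one CrewStreamingOutput whose chunks interleave output from every crew copy and whose .results holds the per-input CrewOutput list after iteration. A minimal consumption sketch, assuming a hypothetical factory that builds a streaming-enabled crew:

import asyncio

from my_project.crews import build_crew  # hypothetical factory


async def main() -> None:
    crew = build_crew(stream=True)  # assumes the crew was constructed with stream=True
    streaming = await crew.akickoff_for_each(
        [{"topic": "observability"}, {"topic": "tracing"}]
    )
    async for chunk in streaming:  # chunks arrive from all crew copies as they complete
        print(chunk, end="", flush=True)
    print(streaming.results)  # list of CrewOutput, one per input


asyncio.run(main())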
|
||||
async def _arun_sequential_process(self) -> CrewOutput:
|
||||
"""Executes tasks sequentially using native async and returns the final output."""
|
||||
return await self._aexecute_tasks(self.tasks)
|
||||
|
||||
async def _arun_hierarchical_process(self) -> CrewOutput:
|
||||
"""Creates and assigns a manager agent to complete the tasks using native async."""
|
||||
self._create_manager_agent()
|
||||
return await self._aexecute_tasks(self.tasks)
|
||||
|
||||
async def _aexecute_tasks(
|
||||
self,
|
||||
tasks: list[Task],
|
||||
start_index: int | None = 0,
|
||||
was_replayed: bool = False,
|
||||
) -> CrewOutput:
|
||||
"""Executes tasks using native async and returns the final output.
|
||||
|
||||
Args:
|
||||
tasks: List of tasks to execute
|
||||
start_index: Index to start execution from (for replay)
|
||||
was_replayed: Whether this is a replayed execution
|
||||
|
||||
Returns:
|
||||
CrewOutput: Final output of the crew
|
||||
"""
|
||||
task_outputs: list[TaskOutput] = []
|
||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
|
||||
last_sync_output: TaskOutput | None = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
exec_data, task_outputs, last_sync_output = prepare_task_execution(
|
||||
self, task, task_index, start_index, task_outputs, last_sync_output
|
||||
)
|
||||
if exec_data.should_skip:
|
||||
continue
|
||||
|
||||
if isinstance(task, ConditionalTask):
|
||||
skipped_task_output = await self._ahandle_conditional_task(
|
||||
task, task_outputs, pending_tasks, task_index, was_replayed
|
||||
)
|
||||
if skipped_task_output:
|
||||
task_outputs.append(skipped_task_output)
|
||||
continue
|
||||
|
||||
if task.async_execution:
|
||||
context = self._get_context(
|
||||
task, [last_sync_output] if last_sync_output else []
|
||||
)
|
||||
async_task = asyncio.create_task(
|
||||
task.aexecute_sync(
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
)
|
||||
pending_tasks.append((task, async_task, task_index))
|
||||
else:
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(
|
||||
pending_tasks, was_replayed
|
||||
)
|
||||
pending_tasks.clear()
|
||||
|
||||
context = self._get_context(task, task_outputs)
|
||||
task_output = await task.aexecute_sync(
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
self._store_execution_log(task, task_output, task_index, was_replayed)
|
||||
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
|
||||
return self._create_crew_output(task_outputs)
|
||||
|
||||
async def _ahandle_conditional_task(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: list[TaskOutput],
|
||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> TaskOutput | None:
|
||||
"""Handle conditional task evaluation using native async."""
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
pending_tasks.clear()
|
||||
|
||||
return check_conditional_skip(
|
||||
self, task, task_outputs, task_index, was_replayed
|
||||
)
|
||||
|
||||
async def _aprocess_async_tasks(
|
||||
self,
|
||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
|
||||
was_replayed: bool = False,
|
||||
) -> list[TaskOutput]:
|
||||
"""Process pending async tasks and return their outputs."""
|
||||
task_outputs: list[TaskOutput] = []
|
||||
for future_task, async_task, task_index in pending_tasks:
|
||||
task_output = await async_task
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(future_task, task_output)
|
||||
self._store_execution_log(
|
||||
future_task, task_output, task_index, was_replayed
|
||||
)
|
||||
return task_outputs
|
||||
|
||||
def _handle_crew_planning(self) -> None:
|
||||
"""Handles the Crew planning."""
|
||||
@@ -1048,33 +1110,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
last_sync_output: TaskOutput | None = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
if start_index is not None and task_index < start_index:
|
||||
if task.output:
|
||||
if task.async_execution:
|
||||
task_outputs.append(task.output)
|
||||
else:
|
||||
task_outputs = [task.output]
|
||||
last_sync_output = task.output
|
||||
continue
|
||||
|
||||
agent_to_use = self._get_agent_to_use(task)
|
||||
if agent_to_use is None:
|
||||
raise ValueError(
|
||||
f"No agent available for task: {task.description}. "
|
||||
f"Ensure that either the task has an assigned agent "
|
||||
f"or a manager agent is provided."
|
||||
)
|
||||
|
||||
# Determine which tools to use - task tools take precedence over agent tools
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
# Prepare tools and ensure they're compatible with task execution
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
tools_for_task,
|
||||
exec_data, task_outputs, last_sync_output = prepare_task_execution(
|
||||
self, task, task_index, start_index, task_outputs, last_sync_output
|
||||
)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
if exec_data.should_skip:
|
||||
continue
|
||||
|
||||
if isinstance(task, ConditionalTask):
|
||||
skipped_task_output = self._handle_conditional_task(
|
||||
@@ -1089,9 +1129,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task, [last_sync_output] if last_sync_output else []
|
||||
)
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
@@ -1101,9 +1141,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
context = self._get_context(task, task_outputs)
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
@@ -1126,19 +1166,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
futures.clear()
|
||||
|
||||
previous_output = task_outputs[-1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(previous_output):
|
||||
self._logger.log(
|
||||
"debug",
|
||||
f"Skipping conditional task: {task.description}",
|
||||
color="yellow",
|
||||
)
|
||||
skipped_task_output = task.get_skipped_task_output()
|
||||
|
||||
if not was_replayed:
|
||||
self._store_execution_log(task, skipped_task_output, task_index)
|
||||
return skipped_task_output
|
||||
return None
|
||||
return check_conditional_skip(
|
||||
self, task, task_outputs, task_index, was_replayed
|
||||
)
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: list[BaseTool]
|
||||
@@ -1302,7 +1332,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return tools
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
@staticmethod
|
||||
def _get_context(task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
if not task.context:
|
||||
return ""
|
||||
|
||||
@@ -1371,7 +1402,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return task_outputs
|
||||
|
||||
def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
|
||||
@staticmethod
|
||||
def _find_task_index(task_id: str, stored_outputs: list[Any]) -> int | None:
|
||||
return next(
|
||||
(
|
||||
index
|
||||
@@ -1431,6 +1463,16 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return None
|
||||
|
||||
async def aquery_knowledge(
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[SearchResult] | None:
|
||||
"""Query the crew's knowledge base for relevant information asynchronously."""
|
||||
if self.knowledge:
|
||||
return await self.knowledge.aquery(
|
||||
query, results_limit=results_limit, score_threshold=score_threshold
|
||||
)
|
||||
return None
|
||||
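aquery_knowledge gives the crew an async counterpart to query_knowledge that awaits the storage search instead of blocking. A minimal calling sketch, assuming a hypothetical factory that builds a crew with knowledge sources attached:

import asyncio

from my_project.crews import build_research_crew  # hypothetical factory


async def main() -> None:
    crew = build_research_crew()  # assumed to be configured with knowledge sources
    results = await crew.aquery_knowledge(
        ["release highlights"], results_limit=3, score_threshold=0.35
    )
    for hit in results or []:  # None is returned when the crew has no knowledge
        print(hit)


asyncio.run(main())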
|
||||
def fetch_inputs(self) -> set[str]:
|
||||
"""
|
||||
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
||||
@@ -1439,7 +1481,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
Returns a set of all discovered placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks for inputs
|
||||
@@ -1687,6 +1729,32 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._logger.log("error", error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
def _reset_memory_system(
|
||||
self, system: Any, name: str, reset_fn: Callable[[Any], Any]
|
||||
) -> None:
|
||||
"""Reset a single memory system.
|
||||
|
||||
Args:
|
||||
system: The memory system instance to reset.
|
||||
name: Display name of the memory system for logging.
|
||||
reset_fn: Function to call to reset the system.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the reset operation fails.
|
||||
"""
|
||||
try:
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
|
||||
def _reset_all_memories(self) -> None:
|
||||
"""Reset all available memory systems."""
|
||||
memory_systems = self._get_memory_systems()
|
||||
@@ -1694,21 +1762,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
for config in memory_systems.values():
|
||||
if (system := config.get("system")) is not None:
|
||||
name = config.get("name")
|
||||
try:
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
self._reset_memory_system(system, name, reset_fn)
|
||||
|
||||
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||
"""Reset a specific memory system.
|
||||
@@ -1727,21 +1784,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if system is None:
|
||||
raise RuntimeError(f"{name} memory system is not initialized")
|
||||
|
||||
try:
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
reset_fn: Callable[[Any], Any] = cast(Callable[[Any], Any], config.get("reset"))
|
||||
self._reset_memory_system(system, name, reset_fn)
|
||||
|
||||
def _get_memory_systems(self) -> dict[str, Any]:
|
||||
"""Get all available memory systems with their configuration.
|
||||
@@ -1829,7 +1873,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
):
|
||||
self.tasks[0].allow_crewai_trigger_context = True
|
||||
|
||||
def _show_tracing_disabled_message(self) -> None:
|
||||
@staticmethod
|
||||
def _show_tracing_disabled_message() -> None:
|
||||
"""Show a message when tracing is disabled."""
|
||||
from crewai.events.listeners.tracing.utils import has_user_declined_tracing
|
||||
|
||||
|
||||
363
lib/crewai/src/crewai/crews/utils.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Utility functions for crew operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
from crewai.utilities.streaming import (
|
||||
StreamingState,
|
||||
TaskInfo,
|
||||
create_streaming_state,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
|
||||
|
||||
def enable_agent_streaming(agents: Iterable[BaseAgent]) -> None:
|
||||
"""Enable streaming on all agents that have an LLM configured.
|
||||
|
||||
Args:
|
||||
agents: Iterable of agents to enable streaming on.
|
||||
"""
|
||||
for agent in agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
|
||||
def setup_agents(
|
||||
crew: Crew,
|
||||
agents: Iterable[BaseAgent],
|
||||
embedder: EmbedderConfig | None,
|
||||
function_calling_llm: Any,
|
||||
step_callback: Callable[..., Any] | None,
|
||||
) -> None:
|
||||
"""Set up agents for crew execution.
|
||||
|
||||
Args:
|
||||
crew: The crew instance agents belong to.
|
||||
agents: Iterable of agents to set up.
|
||||
embedder: Embedder configuration for knowledge.
|
||||
function_calling_llm: Default function calling LLM for agents.
|
||||
step_callback: Default step callback for agents.
|
||||
"""
|
||||
for agent in agents:
|
||||
agent.crew = crew
|
||||
agent.set_knowledge(crew_embedder=embedder)
|
||||
if not agent.function_calling_llm: # type: ignore[attr-defined]
|
||||
agent.function_calling_llm = function_calling_llm # type: ignore[attr-defined]
|
||||
if not agent.step_callback: # type: ignore[attr-defined]
|
||||
agent.step_callback = step_callback # type: ignore[attr-defined]
|
||||
agent.create_agent_executor()
|
||||
|
||||
|
||||
class TaskExecutionData:
|
||||
"""Data container for prepared task execution information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: BaseAgent | None,
|
||||
tools: list[Any],
|
||||
should_skip: bool = False,
|
||||
) -> None:
|
||||
"""Initialize task execution data.
|
||||
|
||||
Args:
|
||||
agent: The agent to use for task execution (None if skipped).
|
||||
tools: Prepared tools for the task.
|
||||
should_skip: Whether the task should be skipped (replay).
|
||||
"""
|
||||
self.agent = agent
|
||||
self.tools = tools
|
||||
self.should_skip = should_skip
|
||||
|
||||
|
||||
def prepare_task_execution(
|
||||
crew: Crew,
|
||||
task: Any,
|
||||
task_index: int,
|
||||
start_index: int | None,
|
||||
task_outputs: list[Any],
|
||||
last_sync_output: Any | None,
|
||||
) -> tuple[TaskExecutionData, list[Any], Any | None]:
|
||||
"""Prepare a task for execution, handling replay skip logic and agent/tool setup.
|
||||
|
||||
Args:
|
||||
crew: The crew instance.
|
||||
task: The task to prepare.
|
||||
task_index: Index of the current task.
|
||||
start_index: Index to start execution from (for replay).
|
||||
task_outputs: Current list of task outputs.
|
||||
last_sync_output: Last synchronous task output.
|
||||
|
||||
Returns:
|
||||
A tuple of (TaskExecutionData, updated task_outputs, updated last_sync_output).
|
||||
If the task should be skipped, TaskExecutionData will have should_skip=True.
|
||||
|
||||
Raises:
|
||||
ValueError: If no agent is available for the task.
|
||||
"""
|
||||
# Handle replay skip
|
||||
if start_index is not None and task_index < start_index:
|
||||
if task.output:
|
||||
if task.async_execution:
|
||||
task_outputs.append(task.output)
|
||||
else:
|
||||
task_outputs = [task.output]
|
||||
last_sync_output = task.output
|
||||
return (
|
||||
TaskExecutionData(agent=None, tools=[], should_skip=True),
|
||||
task_outputs,
|
||||
last_sync_output,
|
||||
)
|
||||
|
||||
agent_to_use = crew._get_agent_to_use(task)
|
||||
if agent_to_use is None:
|
||||
raise ValueError(
|
||||
f"No agent available for task: {task.description}. "
|
||||
f"Ensure that either the task has an assigned agent "
|
||||
f"or a manager agent is provided."
|
||||
)
|
||||
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
tools_for_task = crew._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
tools_for_task,
|
||||
)
|
||||
|
||||
crew._log_task_start(task, agent_to_use.role)
|
||||
|
||||
return (
|
||||
TaskExecutionData(agent=agent_to_use, tools=tools_for_task),
|
||||
task_outputs,
|
||||
last_sync_output,
|
||||
)
|
||||
|
||||
|
||||
def check_conditional_skip(
|
||||
crew: Crew,
|
||||
task: Any,
|
||||
task_outputs: list[Any],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> Any | None:
|
||||
"""Check if a conditional task should be skipped.
|
||||
|
||||
Args:
|
||||
crew: The crew instance.
|
||||
task: The conditional task to check.
|
||||
task_outputs: List of previous task outputs.
|
||||
task_index: Index of the current task.
|
||||
was_replayed: Whether this is a replayed execution.
|
||||
|
||||
Returns:
|
||||
The skipped task output if the task should be skipped, None otherwise.
|
||||
"""
|
||||
previous_output = task_outputs[-1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(previous_output):
|
||||
crew._logger.log(
|
||||
"debug",
|
||||
f"Skipping conditional task: {task.description}",
|
||||
color="yellow",
|
||||
)
|
||||
skipped_task_output = task.get_skipped_task_output()
|
||||
|
||||
if not was_replayed:
|
||||
crew._store_execution_log(task, skipped_task_output, task_index)
|
||||
return skipped_task_output
|
||||
return None
|
||||
|
||||
|
||||
def prepare_kickoff(crew: Crew, inputs: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Prepare crew for kickoff execution.
|
||||
|
||||
Handles before callbacks, event emission, task handler reset, input
|
||||
interpolation, task callbacks, agent setup, and planning.
|
||||
|
||||
Args:
|
||||
crew: The crew instance to prepare.
|
||||
inputs: Optional input dictionary to pass to the crew.
|
||||
|
||||
Returns:
|
||||
The potentially modified inputs dictionary after before callbacks.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.crew_events import CrewKickoffStartedEvent
|
||||
|
||||
for before_callback in crew.before_kickoff_callbacks:
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
inputs = before_callback(inputs)
|
||||
|
||||
future = crewai_event_bus.emit(
|
||||
crew,
|
||||
CrewKickoffStartedEvent(crew_name=crew.name, inputs=inputs),
|
||||
)
|
||||
if future is not None:
|
||||
try:
|
||||
future.result()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
crew._task_output_handler.reset()
|
||||
crew._logging_color = "bold_purple"
|
||||
|
||||
if inputs is not None:
|
||||
crew._inputs = inputs
|
||||
crew._interpolate_inputs(inputs)
|
||||
crew._set_tasks_callbacks()
|
||||
crew._set_allow_crewai_trigger_context_for_first_task()
|
||||
|
||||
setup_agents(
|
||||
crew,
|
||||
crew.agents,
|
||||
crew.embedder,
|
||||
crew.function_calling_llm,
|
||||
crew.step_callback,
|
||||
)
|
||||
|
||||
if crew.planning:
|
||||
crew._handle_crew_planning()
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class StreamingContext:
|
||||
"""Container for streaming state and holders used during crew execution."""
|
||||
|
||||
def __init__(self, use_async: bool = False) -> None:
|
||||
"""Initialize streaming context.
|
||||
|
||||
Args:
|
||||
use_async: Whether to use async streaming mode.
|
||||
"""
|
||||
self.result_holder: list[CrewOutput] = []
|
||||
self.current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
self.state: StreamingState = create_streaming_state(
|
||||
self.current_task_info, self.result_holder, use_async=use_async
|
||||
)
|
||||
self.output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
|
||||
class ForEachStreamingContext:
|
||||
"""Container for streaming state used in for_each crew execution methods."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize for_each streaming context."""
|
||||
self.result_holder: list[list[CrewOutput]] = [[]]
|
||||
self.current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
self.state: StreamingState = create_streaming_state(
|
||||
self.current_task_info, self.result_holder, use_async=True
|
||||
)
|
||||
self.output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
|
||||
async def run_for_each_async(
|
||||
crew: Crew,
|
||||
inputs: list[dict[str, Any]],
|
||||
kickoff_fn: Callable[
|
||||
[Crew, dict[str, Any]], Coroutine[Any, Any, CrewOutput | CrewStreamingOutput]
|
||||
],
|
||||
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
|
||||
"""Execute crew workflow for each input asynchronously.
|
||||
|
||||
Args:
|
||||
crew: The crew instance to execute.
|
||||
inputs: List of input dictionaries for each execution.
|
||||
kickoff_fn: Async function to call for each crew copy (kickoff_async or akickoff).
|
||||
|
||||
Returns:
|
||||
If streaming, a single CrewStreamingOutput that yields chunks from all crews.
|
||||
Otherwise, a list of CrewOutput results.
|
||||
"""
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.streaming import (
|
||||
create_async_chunk_generator,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
|
||||
crew_copies = [crew.copy() for _ in inputs]
|
||||
|
||||
if crew.stream:
|
||||
ctx = ForEachStreamingContext()
|
||||
|
||||
async def run_all_crews() -> None:
|
||||
try:
|
||||
streaming_outputs: list[CrewStreamingOutput] = []
|
||||
for i, crew_copy in enumerate(crew_copies):
|
||||
streaming = await kickoff_fn(crew_copy, inputs[i])
|
||||
if isinstance(streaming, CrewStreamingOutput):
|
||||
streaming_outputs.append(streaming)
|
||||
|
||||
async def consume_stream(
|
||||
stream_output: CrewStreamingOutput,
|
||||
) -> CrewOutput:
|
||||
async for chunk in stream_output:
|
||||
if (
|
||||
ctx.state.async_queue is not None
|
||||
and ctx.state.loop is not None
|
||||
):
|
||||
ctx.state.loop.call_soon_threadsafe(
|
||||
ctx.state.async_queue.put_nowait, chunk
|
||||
)
|
||||
return stream_output.result
|
||||
|
||||
crew_results = await asyncio.gather(
|
||||
*[consume_stream(s) for s in streaming_outputs]
|
||||
)
|
||||
ctx.result_holder[0] = list(crew_results)
|
||||
except Exception as e:
|
||||
signal_error(ctx.state, e, is_async=True)
|
||||
finally:
|
||||
signal_end(ctx.state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
ctx.state, run_all_crews, ctx.output_holder
|
||||
)
|
||||
)
|
||||
|
||||
def set_results_wrapper(result: Any) -> None:
|
||||
streaming_output._set_results(result)
|
||||
|
||||
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
async_tasks: list[asyncio.Task[CrewOutput | CrewStreamingOutput]] = [
|
||||
asyncio.create_task(kickoff_fn(crew_copy, input_data))
|
||||
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*async_tasks)
|
||||
|
||||
total_usage_metrics = UsageMetrics()
|
||||
for crew_copy in crew_copies:
|
||||
if crew_copy.usage_metrics:
|
||||
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
|
||||
crew.usage_metrics = total_usage_metrics
|
||||
|
||||
crew._task_output_handler.reset()
|
||||
return list(results)
|
||||
@@ -140,7 +140,9 @@ class EventListener(BaseEventListener):
|
||||
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
|
||||
with self._crew_tree_lock:
|
||||
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
|
||||
self._telemetry.crew_execution_span(source, event.inputs)
|
||||
source._execution_span = self._telemetry.crew_execution_span(
|
||||
source, event.inputs
|
||||
)
|
||||
self._crew_tree_lock.notify_all()
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffCompletedEvent)
|
||||
|
||||
@@ -76,7 +76,7 @@ class TraceBatchManager:
|
||||
use_ephemeral: bool = False,
|
||||
) -> TraceBatch:
|
||||
"""Initialize a new trace batch (thread-safe)"""
|
||||
with self._init_lock:
|
||||
with self._batch_ready_cv:
|
||||
if self.current_batch is not None:
|
||||
logger.debug(
|
||||
"Batch already initialized, skipping duplicate initialization"
|
||||
@@ -99,7 +99,6 @@ class TraceBatchManager:
|
||||
self.backend_initialized = True
|
||||
|
||||
self._batch_ready_cv.notify_all()
|
||||
|
||||
return self.current_batch
|
||||
|
||||
def _initialize_backend_batch(
|
||||
@@ -107,7 +106,7 @@ class TraceBatchManager:
|
||||
user_context: dict[str, str],
|
||||
execution_metadata: dict[str, Any],
|
||||
use_ephemeral: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Send batch initialization to backend"""
|
||||
|
||||
if not is_tracing_enabled_in_context():
|
||||
@@ -204,7 +203,7 @@ class TraceBatchManager:
|
||||
return False
|
||||
return True
|
||||
|
||||
def add_event(self, trace_event: TraceEvent):
|
||||
def add_event(self, trace_event: TraceEvent) -> None:
|
||||
"""Add event to buffer"""
|
||||
self.event_buffer.append(trace_event)
|
||||
|
||||
@@ -300,7 +299,7 @@ class TraceBatchManager:
|
||||
|
||||
return finalized_batch
|
||||
|
||||
def _finalize_backend_batch(self, events_count: int = 0):
|
||||
def _finalize_backend_batch(self, events_count: int = 0) -> None:
|
||||
"""Send batch finalization to backend
|
||||
|
||||
Args:
|
||||
@@ -366,7 +365,7 @@ class TraceBatchManager:
|
||||
logger.error(f"❌ Error finalizing trace batch: {e}")
|
||||
self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))
|
||||
|
||||
def _cleanup_batch_data(self):
|
||||
def _cleanup_batch_data(self) -> None:
|
||||
"""Clean up batch data after successful finalization to free memory"""
|
||||
try:
|
||||
if hasattr(self, "event_buffer") and self.event_buffer:
|
||||
@@ -411,7 +410,7 @@ class TraceBatchManager:
|
||||
lambda: self.current_batch is not None, timeout=timeout
|
||||
)
|
||||
|
||||
def record_start_time(self, key: str):
|
||||
def record_start_time(self, key: str) -> None:
|
||||
"""Record start time for duration calculation"""
|
||||
self.execution_start_times[key] = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from crewai.events.types.system_events import SignalEvent, on_signal
|
||||
from crewai.events.types.task_events import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
@@ -159,6 +160,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self._register_flow_event_handlers(crewai_event_bus)
|
||||
self._register_context_event_handlers(crewai_event_bus)
|
||||
self._register_action_event_handlers(crewai_event_bus)
|
||||
self._register_system_event_handlers(crewai_event_bus)
|
||||
|
||||
self._listeners_setup = True
|
||||
|
||||
@@ -458,6 +460,15 @@ class TraceCollectionListener(BaseEventListener):
|
||||
) -> None:
|
||||
self._handle_action_event("knowledge_query_failed", source, event)
|
||||
|
||||
def _register_system_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for system signal events (SIGTERM, SIGINT, etc.)."""
|
||||
|
||||
@on_signal
|
||||
def handle_signal(source: Any, event: SignalEvent) -> None:
|
||||
"""Flush trace batch on system signals to prevent data loss."""
|
||||
if self.batch_manager.is_batch_initialized():
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
def _initialize_crew_batch(self, source: Any, event: Any) -> None:
|
||||
"""Initialize trace batch.
|
||||
|
||||
|
||||
102
lib/crewai/src/crewai/events/types/system_events.py
Normal file
@@ -0,0 +1,102 @@
"""System signal event types for CrewAI.

This module contains event types for system-level signals like SIGTERM,
allowing listeners to perform cleanup operations before process termination.
"""

from collections.abc import Callable
from enum import IntEnum
import signal
from typing import Annotated, Literal, TypeVar

from pydantic import Field, TypeAdapter

from crewai.events.base_events import BaseEvent


class SignalType(IntEnum):
    """Enumeration of supported system signals."""

    SIGTERM = signal.SIGTERM
    SIGINT = signal.SIGINT
    SIGHUP = signal.SIGHUP
    SIGTSTP = signal.SIGTSTP
    SIGCONT = signal.SIGCONT


class SigTermEvent(BaseEvent):
    """Event emitted when SIGTERM is received."""

    type: Literal["SIGTERM"] = "SIGTERM"
    signal_number: SignalType = SignalType.SIGTERM
    reason: str | None = None


class SigIntEvent(BaseEvent):
    """Event emitted when SIGINT is received."""

    type: Literal["SIGINT"] = "SIGINT"
    signal_number: SignalType = SignalType.SIGINT
    reason: str | None = None


class SigHupEvent(BaseEvent):
    """Event emitted when SIGHUP is received."""

    type: Literal["SIGHUP"] = "SIGHUP"
    signal_number: SignalType = SignalType.SIGHUP
    reason: str | None = None


class SigTStpEvent(BaseEvent):
    """Event emitted when SIGTSTP is received.

    Note: SIGSTOP cannot be caught - it immediately suspends the process.
    """

    type: Literal["SIGTSTP"] = "SIGTSTP"
    signal_number: SignalType = SignalType.SIGTSTP
    reason: str | None = None


class SigContEvent(BaseEvent):
    """Event emitted when SIGCONT is received."""

    type: Literal["SIGCONT"] = "SIGCONT"
    signal_number: SignalType = SignalType.SIGCONT
    reason: str | None = None


SignalEvent = Annotated[
    SigTermEvent | SigIntEvent | SigHupEvent | SigTStpEvent | SigContEvent,
    Field(discriminator="type"),
]

signal_event_adapter: TypeAdapter[SignalEvent] = TypeAdapter(SignalEvent)

SIGNAL_EVENT_TYPES: tuple[type[BaseEvent], ...] = (
    SigTermEvent,
    SigIntEvent,
    SigHupEvent,
    SigTStpEvent,
    SigContEvent,
)


T = TypeVar("T", bound=Callable[[object, SignalEvent], None])


def on_signal(func: T) -> T:
    """Decorator to register a handler for all signal events.

    Args:
        func: Handler function that receives (source, event) arguments.

    Returns:
        The original function, registered for all signal event types.
    """
    from crewai.events.event_bus import crewai_event_bus

    for event_type in SIGNAL_EVENT_TYPES:
        crewai_event_bus.on(event_type)(func)
    return func
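The on_signal decorator registers a single handler for every signal event type on the shared event bus. A minimal usage sketch, assuming only the API shown above; the handler body and its side effects are illustrative:

from typing import Any

from crewai.events.types.system_events import SignalEvent, on_signal


@on_signal
def flush_on_signal(source: Any, event: SignalEvent) -> None:
    # Illustrative handler: report the signal before the process winds down.
    print(f"received {event.type} (signal {int(event.signal_number)}), reason={event.reason}")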
@@ -1032,6 +1032,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
finally:
|
||||
detach(flow_token)
|
||||
|
||||
async def akickoff(
|
||||
self, inputs: dict[str, Any] | None = None
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""Native async method to start the flow execution. Alias for kickoff_async.
|
||||
|
||||
|
||||
Args:
|
||||
inputs: Optional dictionary containing input values and/or a state ID for restoration.
|
||||
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
"""
|
||||
return await self.kickoff_async(inputs)
|
||||
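Flow.akickoff simply awaits kickoff_async, so existing flows gain a native-async entry point without other changes. A minimal calling sketch, assuming a hypothetical ResearchFlow subclass defined elsewhere:

import asyncio

from my_project.flows import ResearchFlow  # hypothetical Flow subclass


async def main() -> None:
    flow = ResearchFlow()
    # akickoff awaits kickoff_async directly rather than wrapping it in a thread.
    result = await flow.akickoff(inputs={"topic": "streaming"})
    print(result)


asyncio.run(main())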
|
||||
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
|
||||
"""Executes a flow's start method and its triggered listeners.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
|
||||
@@ -9,17 +9,22 @@ from crewai.utilities.printer import Printer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class LLMCallHookContext:
|
||||
"""Context object passed to LLM call hooks with full executor access.
|
||||
"""Context object passed to LLM call hooks.
|
||||
|
||||
Provides hooks with complete access to the executor state, allowing
|
||||
Provides hooks with complete access to the execution state, allowing
|
||||
modification of messages, responses, and executor attributes.
|
||||
|
||||
Supports both executor-based calls (agents in crews/flows) and direct LLM calls.
|
||||
|
||||
Attributes:
|
||||
executor: Full reference to the CrewAgentExecutor instance
|
||||
messages: Direct reference to executor.messages (mutable list).
|
||||
executor: Reference to the executor (CrewAgentExecutor/LiteAgent) or None for direct calls
|
||||
messages: Direct reference to messages (mutable list).
|
||||
Can be modified in both before_llm_call and after_llm_call hooks.
|
||||
Modifications in after_llm_call hooks persist to the next iteration,
|
||||
allowing hooks to modify conversation history for subsequent LLM calls.
|
||||
@@ -27,33 +32,75 @@ class LLMCallHookContext:
|
||||
Do NOT replace the list (e.g., context.messages = []), as this will break
|
||||
the executor. Use context.messages.append() or context.messages.extend()
|
||||
instead of assignment.
|
||||
agent: Reference to the agent executing the task
|
||||
task: Reference to the task being executed
|
||||
crew: Reference to the crew instance
|
||||
agent: Reference to the agent executing the task (None for direct LLM calls)
|
||||
task: Reference to the task being executed (None for direct LLM calls or LiteAgent)
|
||||
crew: Reference to the crew instance (None for direct LLM calls or LiteAgent)
|
||||
llm: Reference to the LLM instance
|
||||
iterations: Current iteration count
|
||||
iterations: Current iteration count (0 for direct LLM calls)
|
||||
response: LLM response string (only set for after_llm_call hooks).
|
||||
Can be modified by returning a new string from after_llm_call hook.
|
||||
"""
|
||||
|
||||
executor: CrewAgentExecutor | LiteAgent | None
|
||||
messages: list[LLMMessage]
|
||||
agent: Any
|
||||
task: Any
|
||||
crew: Any
|
||||
llm: BaseLLM | None | str | Any
|
||||
iterations: int
|
||||
response: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: CrewAgentExecutor,
|
||||
executor: CrewAgentExecutor | LiteAgent | None = None,
|
||||
response: str | None = None,
|
||||
messages: list[LLMMessage] | None = None,
|
||||
llm: BaseLLM | str | Any | None = None, # TODO: look into
|
||||
agent: Any | None = None,
|
||||
task: Any | None = None,
|
||||
crew: Any | None = None,
|
||||
) -> None:
|
||||
"""Initialize hook context with executor reference.
|
||||
"""Initialize hook context with executor reference or direct parameters.
|
||||
|
||||
Args:
|
||||
executor: The CrewAgentExecutor instance
|
||||
executor: The CrewAgentExecutor or LiteAgent instance (None for direct LLM calls)
|
||||
response: Optional response string (for after_llm_call hooks)
|
||||
messages: Optional messages list (for direct LLM calls when executor is None)
|
||||
llm: Optional LLM instance (for direct LLM calls when executor is None)
|
||||
agent: Optional agent reference (for direct LLM calls when executor is None)
|
||||
task: Optional task reference (for direct LLM calls when executor is None)
|
||||
crew: Optional crew reference (for direct LLM calls when executor is None)
|
||||
"""
|
||||
self.executor = executor
|
||||
self.messages = executor.messages
|
||||
self.agent = executor.agent
|
||||
self.task = executor.task
|
||||
self.crew = executor.crew
|
||||
self.llm = executor.llm
|
||||
self.iterations = executor.iterations
|
||||
if executor is not None:
|
||||
# Existing path: extract from executor
|
||||
self.executor = executor
|
||||
self.messages = executor.messages
|
||||
self.llm = executor.llm
|
||||
self.iterations = executor.iterations
|
||||
# Handle CrewAgentExecutor vs LiteAgent differences
|
||||
if hasattr(executor, "agent"):
|
||||
self.agent = executor.agent
|
||||
self.task = cast("CrewAgentExecutor", executor).task
|
||||
self.crew = cast("CrewAgentExecutor", executor).crew
|
||||
else:
|
||||
# LiteAgent case: the executor is the agent itself and has no task or crew
|
||||
self.agent = (
|
||||
executor.original_agent
|
||||
if hasattr(executor, "original_agent")
|
||||
else executor
|
||||
)
|
||||
self.task = None
|
||||
self.crew = None
|
||||
else:
|
||||
# New path: direct LLM call with explicit parameters
|
||||
self.executor = None
|
||||
self.messages = messages or []
|
||||
self.llm = llm
|
||||
self.agent = agent
|
||||
self.task = task
|
||||
self.crew = crew
|
||||
self.iterations = 0
|
||||
|
||||
self.response = response
|
||||
|
||||
def request_human_input(
|
||||
|
||||
@@ -32,8 +32,8 @@ class Knowledge(BaseModel):
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: EmbedderConfig | None = None,
|
||||
storage: KnowledgeStorage | None = None,
|
||||
**data,
|
||||
):
|
||||
**data: object,
|
||||
) -> None:
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
self.storage = storage
|
||||
@@ -75,3 +75,44 @@ class Knowledge(BaseModel):
|
||||
self.storage.reset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
async def aquery(
|
||||
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||
) -> list[SearchResult]:
|
||||
"""Query across all knowledge sources asynchronously.
|
||||
|
||||
Args:
|
||||
query: List of query strings.
|
||||
results_limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
The top results matching the query.
|
||||
|
||||
Raises:
|
||||
ValueError: If storage is not initialized.
|
||||
"""
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
return await self.storage.asearch(
|
||||
query,
|
||||
limit=results_limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
async def aadd_sources(self) -> None:
|
||||
"""Add all knowledge sources to storage asynchronously."""
|
||||
try:
|
||||
for source in self.sources:
|
||||
source.storage = self.storage
|
||||
await source.aadd()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
if self.storage:
|
||||
await self.storage.areset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
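Knowledge now exposes aquery, aadd_sources, and areset as async counterparts of the sync API. A minimal sketch, assuming a StringKnowledgeSource and default storage wiring; the collection name and content strings are illustrative:

import asyncio

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource


async def main() -> None:
    knowledge = Knowledge(
        collection_name="notes",  # assumed constructor field
        sources=[StringKnowledgeSource(content="Async knowledge APIs: aquery, aadd_sources, areset.")],
    )
    await knowledge.aadd_sources()
    hits = await knowledge.aquery(["which async APIs exist?"], results_limit=2)
    print(hits)


asyncio.run(main())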
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info): # noqa: N805
|
||||
@classmethod
|
||||
def validate_file_path(
|
||||
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||
) -> Path | list[Path] | str | list[str] | None:
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
@@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
raise ValueError("Either file_path or file_paths must be provided")
|
||||
return v
|
||||
|
||||
def model_post_init(self, _):
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
"""Post-initialization method to load content."""
|
||||
self.safe_file_paths = self._process_file_paths()
|
||||
self.validate_content()
|
||||
@@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate the paths."""
|
||||
for path in self.safe_file_paths:
|
||||
if not path.exists():
|
||||
@@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
color="red",
|
||||
)
|
||||
|
||||
def _save_documents(self):
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage."""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
async def _asave_documents(self) -> None:
|
||||
"""Save the documents to the storage asynchronously."""
|
||||
if self.storage:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
def convert_to_path(self, path: Path | str) -> Path:
|
||||
"""Convert a path to a Path object."""
|
||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||
|
||||
@@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
|
||||
]
|
||||
|
||||
def _save_documents(self):
|
||||
"""
|
||||
Save the documents to the storage.
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage.
|
||||
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
@abstractmethod
|
||||
async def aadd(self) -> None:
|
||||
"""Process content, chunk it, compute embeddings, and save them asynchronously."""
|
||||
|
||||
async def _asave_documents(self) -> None:
|
||||
"""Save the documents to the storage asynchronously.
|
||||
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
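Because aadd is now abstract on BaseKnowledgeSource, custom sources must provide an async ingestion path alongside add. A minimal sketch of a custom source, assuming add and validate_content are the other members a subclass typically supplies and that _chunk_text and the storage helpers behave as shown above:

from pydantic import Field

from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource


class InMemoryNotesSource(BaseKnowledgeSource):  # illustrative custom source
    notes: list[str] = Field(default_factory=list)

    def validate_content(self) -> None:
        if not all(isinstance(note, str) for note in self.notes):
            raise ValueError("InMemoryNotesSource only accepts string notes")

    def add(self) -> None:
        for note in self.notes:
            self.chunks.extend(self._chunk_text(note))
        self._save_documents()

    async def aadd(self) -> None:
        # Async mirror of add(): chunk the notes, then persist through the async storage path.
        for note in self.notes:
            self.chunks.extend(self._chunk_text(note))
        await self._asave_documents()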
|
||||
@@ -2,27 +2,24 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
try:
|
||||
from docling.datamodel.base_models import ( # type: ignore[import-not-found]
|
||||
InputFormat,
|
||||
)
|
||||
from docling.document_converter import ( # type: ignore[import-not-found]
|
||||
DocumentConverter,
|
||||
)
|
||||
from docling.exceptions import ConversionError # type: ignore[import-not-found]
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found]
|
||||
HierarchicalChunker,
|
||||
)
|
||||
from docling_core.types.doc.document import ( # type: ignore[import-not-found]
|
||||
DoclingDocument,
|
||||
)
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling.exceptions import ConversionError
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
DOCLING_AVAILABLE = True
|
||||
except ImportError:
|
||||
DOCLING_AVAILABLE = False
|
||||
# Provide type stubs for when docling is not available
|
||||
if TYPE_CHECKING:
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
class CrewDoclingSource(BaseKnowledgeSource):
|
||||
"""Default Source class for converting documents to markdown or json
|
||||
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
|
||||
"""Default Source class for converting documents to markdown or json.
|
||||
|
||||
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without
|
||||
any additional dependencies and follows the docling package as the source of truth.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
if not DOCLING_AVAILABLE:
|
||||
raise ImportError(
|
||||
"The docling package is required to use CrewDoclingSource. "
|
||||
@@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
)
|
||||
)
|
||||
|
||||
def model_post_init(self, _) -> None:
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
if self.file_path:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add docling content asynchronously."""
|
||||
if self.content is None:
|
||||
return
|
||||
for doc in self.content:
|
||||
new_chunks_iterable = self._chunk_doc(doc)
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
await self._asave_documents()
|
||||
|
||||
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
|
||||
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
||||
return [result.document for result in conv_results_iter]
|
||||
|
||||
@@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add CSV file content asynchronously."""
|
||||
content_str = (
|
||||
str(self.content) if isinstance(self.content, dict) else self.content
|
||||
)
|
||||
new_chunks = self._chunk_text(content_str)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -1,4 +1,6 @@
from pathlib import Path
from types import ModuleType
from typing import Any

from pydantic import Field, field_validator

@@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
    safe_file_paths: list[Path] = Field(default_factory=list)

    @field_validator("file_path", "file_paths", mode="before")
    def validate_file_path(cls, v, info): # noqa: N805
    @classmethod
    def validate_file_path(
        cls, v: Path | list[Path] | str | list[str] | None, info: Any
    ) -> Path | list[Path] | str | list[str] | None:
        """Validate that at least one of file_path or file_paths is provided."""
        # Single check if both are None, O(1) instead of nested conditions
        if (
@@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):

        return [self.convert_to_path(path) for path in path_list]

    def validate_content(self):
    def validate_content(self) -> None:
        """Validate the paths."""
        for path in self.safe_file_paths:
            if not path.exists():
@@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
                color="red",
            )

    def model_post_init(self, _) -> None:
    def model_post_init(self, _: Any) -> None:
        if self.file_path:
            self._logger.log(
                "warning",
@@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
        """Convert a path to a Path object."""
        return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

    def _import_dependencies(self):
    def _import_dependencies(self) -> ModuleType:
        """Dynamically import dependencies."""
        try:
            import pandas as pd # type: ignore[import-untyped,import-not-found]
            import pandas as pd # type: ignore[import-untyped]

            return pd
            return pd # type: ignore[no-any-return]
        except ImportError as e:
            missing_package = str(e).split()[-1]
            raise ImportError(
@@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add Excel file content asynchronously."""
        content_str = ""
        for value in self.content.values():
            if isinstance(value, dict):
                for sheet_value in value.values():
                    content_str += str(sheet_value) + "\n"
            else:
                content_str += str(value) + "\n"

        new_chunks = self._chunk_text(content_str)
        self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
@@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add JSON file content asynchronously."""
        content_str = (
            str(self.content) if isinstance(self.content, dict) else self.content
        )
        new_chunks = self._chunk_text(content_str)
        self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
@@ -1,4 +1,5 @@
from pathlib import Path
from types import ModuleType

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
            content[path] = text
        return content

    def _import_pdfplumber(self):
    def _import_pdfplumber(self) -> ModuleType:
        """Dynamically import pdfplumber."""
        try:
            import pdfplumber
@@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add PDF file content asynchronously."""
        for text in self.content.values():
            new_chunks = self._chunk_text(text)
            self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
@@ -1,3 +1,5 @@
from typing import Any

from pydantic import Field

from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
@@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource):
    content: str = Field(...)
    collection_name: str | None = Field(default=None)

    def model_post_init(self, _):
    def model_post_init(self, _: Any) -> None:
        """Post-initialization method to validate content."""
        self.validate_content()

    def validate_content(self):
    def validate_content(self) -> None:
        """Validate string content."""
        if not isinstance(self.content, str):
            raise ValueError("StringKnowledgeSource only accepts string content")
@@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add string content asynchronously."""
        new_chunks = self._chunk_text(self.content)
        self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
@@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

    async def aadd(self) -> None:
        """Add text file content asynchronously."""
        for text in self.content.values():
            new_chunks = self._chunk_text(text)
            self.chunks.extend(new_chunks)
        await self._asave_documents()

    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
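Taken together, these hunks give every knowledge source the same pair of entry points: the existing synchronous add() plus an async aadd() that chunks the content and persists it through _asave_documents(). A minimal sketch of driving the async path follows; the sample string is illustrative, and in a real setup the source is registered with a Knowledge instance, which attaches the storage backend before aadd() runs.

import asyncio

from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource


async def ingest() -> None:
    # Illustrative content; any string works here.
    source = StringKnowledgeSource(content="CrewAI ships async knowledge ingestion.")
    # aadd() chunks the text and saves it via the async storage path.
    await source.aadd()
    print(f"stored {len(source.chunks)} chunks")


asyncio.run(ingest())
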
@@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC):
    ) -> list[SearchResult]:
        """Search for documents in the knowledge base."""

    @abstractmethod
    async def asearch(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        """Search for documents in the knowledge base asynchronously."""

    @abstractmethod
    def save(self, documents: list[str]) -> None:
        """Save documents to the knowledge base."""

    @abstractmethod
    async def asave(self, documents: list[str]) -> None:
        """Save documents to the knowledge base asynchronously."""

    @abstractmethod
    def reset(self) -> None:
        """Reset the knowledge base."""

    @abstractmethod
    async def areset(self) -> None:
        """Reset the knowledge base asynchronously."""
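Every synchronous operation on the storage ABC now has an async twin, so a custom backend must implement both. The stub below only illustrates that contract: the class name, the naive substring match, and the plain-dict results are invented for this sketch, and it assumes the synchronous search mirrors the async signature added above.

from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage


class InMemoryKnowledgeStorage(BaseKnowledgeStorage):
    """Toy in-memory backend used only to illustrate the sync/async pairs."""

    def __init__(self) -> None:
        self._docs: list[str] = []

    def search(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
        hits = [{"content": d} for d in self._docs if any(q in d for q in query)]
        return hits[:limit]

    async def asearch(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
        return self.search(query, limit, metadata_filter, score_threshold)

    def save(self, documents: list[str]) -> None:
        self._docs.extend(documents)

    async def asave(self, documents: list[str]) -> None:
        self.save(documents)

    def reset(self) -> None:
        self._docs.clear()

    async def areset(self) -> None:
        self.reset()
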
@@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
    def __init__(
        self,
        embedder: ProviderSpec
        | BaseEmbeddingsProvider
        | type[BaseEmbeddingsProvider]
        | BaseEmbeddingsProvider[Any]
        | type[BaseEmbeddingsProvider[Any]]
        | None = None,
        collection_name: str | None = None,
    ) -> None:
@@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                ) from e
            Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
            raise

    async def asearch(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        """Search for documents in the knowledge base asynchronously.

        Args:
            query: List of query strings.
            limit: Maximum number of results to return.
            metadata_filter: Optional metadata filter for the search.
            score_threshold: Minimum similarity score for results.

        Returns:
            List of search results.
        """
        try:
            if not query:
                raise ValueError("Query cannot be empty")

            client = self._get_client()
            collection_name = (
                f"knowledge_{self.collection_name}"
                if self.collection_name
                else "knowledge"
            )
            query_text = " ".join(query) if len(query) > 1 else query[0]

            return await client.asearch(
                collection_name=collection_name,
                query=query_text,
                limit=limit,
                metadata_filter=metadata_filter,
                score_threshold=score_threshold,
            )
        except Exception as e:
            logging.error(
                f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
            )
            return []

    async def asave(self, documents: list[str]) -> None:
        """Save documents to the knowledge base asynchronously.

        Args:
            documents: List of document strings to save.
        """
        try:
            client = self._get_client()
            collection_name = (
                f"knowledge_{self.collection_name}"
                if self.collection_name
                else "knowledge"
            )
            await client.aget_or_create_collection(collection_name=collection_name)

            rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]

            await client.aadd_documents(
                collection_name=collection_name, documents=rag_documents
            )
        except Exception as e:
            if "dimension mismatch" in str(e).lower():
                Logger(verbose=True).log(
                    "error",
                    "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
                    "red",
                )
                raise ValueError(
                    "Embedding dimension mismatch. Make sure you're using the same embedding model "
                    "across all operations with this collection."
                    "Try resetting the collection using `crewai reset-memories -a`"
                ) from e
            Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
            raise

    async def areset(self) -> None:
        """Reset the knowledge base asynchronously."""
        try:
            client = self._get_client()
            collection_name = (
                f"knowledge_{self.collection_name}"
                if self.collection_name
                else "knowledge"
            )
            await client.adelete_collection(collection_name=collection_name)
        except Exception as e:
            logging.error(
                f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
            )
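With asearch, asave, and areset in place, the concrete KnowledgeStorage can be exercised end to end without blocking the event loop. A rough usage sketch, assuming the default embedder and RAG client configuration; the collection name and sample document are illustrative:

import asyncio

from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage


async def main() -> None:
    # "demo" is illustrative; internally it becomes the "knowledge_demo" collection.
    storage = KnowledgeStorage(collection_name="demo")

    await storage.asave(["CrewAI agents can share a common knowledge base."])
    results = await storage.asearch(["knowledge base"], limit=3)
    print(results)

    # Drop the collection once finished.
    await storage.areset()


asyncio.run(main())
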
@@ -38,6 +38,8 @@ from crewai.events.types.agent_events import (
)
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.flow.flow_trackable import FlowTrackable
from crewai.hooks.llm_hooks import get_after_llm_call_hooks, get_before_llm_call_hooks
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
from crewai.lite_agent_output import LiteAgentOutput
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
@@ -155,6 +157,12 @@ class LiteAgent(FlowTrackable, BaseModel):
    _guardrail: GuardrailCallable | None = PrivateAttr(default=None)
    _guardrail_retry_count: int = PrivateAttr(default=0)
    _callbacks: list[TokenCalcHandler] = PrivateAttr(default_factory=list)
    _before_llm_call_hooks: list[BeforeLLMCallHookType] = PrivateAttr(
        default_factory=get_before_llm_call_hooks
    )
    _after_llm_call_hooks: list[AfterLLMCallHookType] = PrivateAttr(
        default_factory=get_after_llm_call_hooks
    )

    @model_validator(mode="after")
    def setup_llm(self) -> Self:
@@ -246,6 +254,26 @@ class LiteAgent(FlowTrackable, BaseModel):
        """Return the original role for compatibility with tool interfaces."""
        return self.role

    @property
    def before_llm_call_hooks(self) -> list[BeforeLLMCallHookType]:
        """Get the before_llm_call hooks for this agent."""
        return self._before_llm_call_hooks

    @property
    def after_llm_call_hooks(self) -> list[AfterLLMCallHookType]:
        """Get the after_llm_call hooks for this agent."""
        return self._after_llm_call_hooks

    @property
    def messages(self) -> list[LLMMessage]:
        """Get the messages list for hook context compatibility."""
        return self._messages

    @property
    def iterations(self) -> int:
        """Get the current iteration count for hook context compatibility."""
        return self._iterations

    def kickoff(
        self,
        messages: str | list[LLMMessage],
@@ -504,7 +532,7 @@ class LiteAgent(FlowTrackable, BaseModel):
            AgentFinish: The final result of the agent execution.
        """
        # Execute the agent loop
        formatted_answer = None
        formatted_answer: AgentAction | AgentFinish | None = None
        while not isinstance(formatted_answer, AgentFinish):
            try:
                if has_reached_max_iterations(self._iterations, self.max_iterations):
@@ -526,6 +554,7 @@ class LiteAgent(FlowTrackable, BaseModel):
                    callbacks=self._callbacks,
                    printer=self._printer,
                    from_agent=self,
                    executor_context=self,
                )

            except Exception as e:
@@ -57,11 +57,17 @@ if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.get_supported_openai_params import (
|
||||
get_supported_openai_params,
|
||||
)
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionDeltaToolCall,
|
||||
Choices,
|
||||
Function,
|
||||
ModelResponse,
|
||||
)
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicThinkingConfig
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.types import LLMMessage
|
||||
@@ -73,7 +79,12 @@ try:
|
||||
from litellm.litellm_core_utils.get_supported_openai_params import (
|
||||
get_supported_openai_params,
|
||||
)
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionDeltaToolCall,
|
||||
Choices,
|
||||
Function,
|
||||
ModelResponse,
|
||||
)
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
LITELLM_AVAILABLE = True
|
||||
@@ -84,6 +95,7 @@ except ImportError:
|
||||
ContextWindowExceededError = Exception # type: ignore
|
||||
get_supported_openai_params = None # type: ignore
|
||||
ChatCompletionDeltaToolCall = None # type: ignore
|
||||
Function = None # type: ignore
|
||||
ModelResponse = None # type: ignore
|
||||
supports_response_schema = None # type: ignore
|
||||
CustomLogger = None # type: ignore
|
||||
@@ -406,46 +418,100 @@ class LLM(BaseLLM):
        instance.is_litellm = True
        return instance

    @classmethod
    def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
        """Check if a model name matches provider-specific patterns.

        This allows supporting models that aren't in the hardcoded constants list,
        including "latest" versions and new models that follow provider naming conventions.

        Args:
            model: The model name to check
            provider: The provider to check against (canonical name)

        Returns:
            True if the model matches the provider's naming pattern, False otherwise
        """
        model_lower = model.lower()

        if provider == "openai":
            return any(
                model_lower.startswith(prefix)
                for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
            )

        if provider == "anthropic" or provider == "claude":
            return any(
                model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
            )

        if provider == "gemini" or provider == "google":
            return any(
                model_lower.startswith(prefix)
                for prefix in ["gemini-", "gemma-", "learnlm-"]
            )

        if provider == "bedrock":
            return "." in model_lower

        if provider == "azure":
            return any(
                model_lower.startswith(prefix)
                for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
            )

        return False

    @classmethod
    def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
        """Validate if a model name exists in the provider's constants.
        """Validate if a model name exists in the provider's constants or matches provider patterns.

        This method first checks the hardcoded constants list for known models.
        If not found, it falls back to pattern matching to support new models,
        "latest" versions, and models that follow provider naming conventions.

        Args:
            model: The model name to validate
            provider: The provider to check against (canonical name)

        Returns:
            True if the model exists in the provider's constants, False otherwise
            True if the model exists in constants or matches provider patterns, False otherwise
        """
        if provider == "openai":
            return model in OPENAI_MODELS
        if provider == "openai" and model in OPENAI_MODELS:
            return True

        if provider == "anthropic" or provider == "claude":
            return model in ANTHROPIC_MODELS
        if (
            provider == "anthropic" or provider == "claude"
        ) and model in ANTHROPIC_MODELS:
            return True

        if provider == "gemini":
            return model in GEMINI_MODELS
        if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
            return True

        if provider == "bedrock":
            return model in BEDROCK_MODELS
        if provider == "bedrock" and model in BEDROCK_MODELS:
            return True

        if provider == "azure":
            # azure does not provide a list of available models, determine a better way to handle this
            return True

        return False
        # Fallback to pattern matching for models not in constants
        return cls._matches_provider_pattern(model, provider)

    @classmethod
    def _infer_provider_from_model(cls, model: str) -> str:
        """Infer the provider from the model name.

        This method first checks the hardcoded constants list for known models.
        If not found, it uses pattern matching to infer the provider from model name patterns.
        This allows supporting new models and "latest" versions without hardcoding.

        Args:
            model: The model name without provider prefix

        Returns:
            The inferred provider name, defaults to "openai"
        """

        if model in OPENAI_MODELS:
            return "openai"

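The net effect of the two checks is that a model id such as "claude-sonnet-latest" validates against the anthropic provider even though it never appears in ANTHROPIC_MODELS. The snippet below re-implements just the fallback rule as a standalone sketch; it is not a call into the LLM class, and the prefix table only mirrors the prefixes listed above.

# Standalone illustration of the prefix-based fallback described above.
PROVIDER_PREFIXES: dict[str, tuple[str, ...]] = {
    "openai": ("gpt-", "o1", "o3", "o4", "whisper-"),
    "anthropic": ("claude-", "anthropic."),
    "gemini": ("gemini-", "gemma-", "learnlm-"),
}


def matches_provider(model: str, provider: str) -> bool:
    model_lower = model.lower()
    if provider == "bedrock":
        # Bedrock ids look like "anthropic.claude-...", so a dot is enough here.
        return "." in model_lower
    return any(model_lower.startswith(p) for p in PROVIDER_PREFIXES.get(provider, ()))


assert matches_provider("claude-sonnet-latest", "anthropic")
assert matches_provider("gemini-3.0-flash", "gemini")  # not in GEMINI_MODELS, still accepted
assert not matches_provider("mistral-large", "openai")
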
@@ -520,6 +586,7 @@ class LLM(BaseLLM):
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize LLM instance.
|
||||
@@ -556,7 +623,9 @@ class LLM(BaseLLM):
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.additional_params = kwargs
|
||||
self.additional_params = {
|
||||
k: v for k, v in kwargs.items() if k not in ("is_litellm", "provider")
|
||||
}
|
||||
self.is_anthropic = self._is_anthropic_model(model)
|
||||
self.stream = stream
|
||||
self.interceptor = interceptor
|
||||
@@ -1150,6 +1219,281 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return text_response
|
||||
|
||||
async def _ahandle_non_streaming_response(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle an async non-streaming response from the LLM.
|
||||
|
||||
Args:
|
||||
params: Parameters for the completion call
|
||||
callbacks: Optional list of callback functions
|
||||
available_functions: Dict of available functions
|
||||
from_task: Optional Task that invoked the LLM
|
||||
from_agent: Optional Agent that invoked the LLM
|
||||
response_model: Optional Response model
|
||||
|
||||
Returns:
|
||||
str: The response text
|
||||
"""
|
||||
if response_model and self.is_litellm:
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
messages = params.get("messages", [])
|
||||
if not messages:
|
||||
raise ValueError("Messages are required when using response_model")
|
||||
|
||||
combined_content = "\n\n".join(
|
||||
f"{msg['role'].upper()}: {msg['content']}" for msg in messages
|
||||
)
|
||||
|
||||
instructor_instance = InternalInstructor(
|
||||
content=combined_content,
|
||||
model=response_model,
|
||||
llm=self,
|
||||
)
|
||||
result = instructor_instance.to_pydantic()
|
||||
structured_response = result.model_dump_json()
|
||||
self._handle_emit_call_events(
|
||||
response=structured_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_response
|
||||
|
||||
try:
|
||||
if response_model:
|
||||
params["response_model"] = response_model
|
||||
response = await litellm.acompletion(**params)
|
||||
|
||||
except ContextWindowExceededError as e:
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
if response_model is not None:
|
||||
if isinstance(response, BaseModel):
|
||||
structured_response = response.model_dump_json()
|
||||
self._handle_emit_call_events(
|
||||
response=structured_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_response
|
||||
|
||||
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
||||
0
|
||||
].message
|
||||
text_response = response_message.content or ""
|
||||
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
usage_info = getattr(response, "usage", None)
|
||||
if usage_info:
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
response_obj={"usage": usage_info},
|
||||
start_time=0,
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
|
||||
if (not tool_calls or not available_functions) and text_response:
|
||||
self._handle_emit_call_events(
|
||||
response=text_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return text_response
|
||||
|
||||
if tool_calls and not available_functions and not text_response:
|
||||
return tool_calls
|
||||
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
|
||||
self._handle_emit_call_events(
|
||||
response=text_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return text_response
|
||||
|
||||
async def _ahandle_streaming_response(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> Any:
|
||||
"""Handle an async streaming response from the LLM.
|
||||
|
||||
Args:
|
||||
params: Parameters for the completion call
|
||||
callbacks: Optional list of callback functions
|
||||
available_functions: Dict of available functions
|
||||
from_task: Optional task object
|
||||
from_agent: Optional agent object
|
||||
response_model: Optional response model
|
||||
|
||||
Returns:
|
||||
str: The complete response text
|
||||
"""
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
usage_info = None
|
||||
|
||||
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
|
||||
AccumulatedToolArgs
|
||||
)
|
||||
|
||||
params["stream"] = True
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
try:
|
||||
async for chunk in await litellm.acompletion(**params):
|
||||
chunk_count += 1
|
||||
chunk_content = None
|
||||
|
||||
try:
|
||||
choices = None
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choices = chunk["choices"]
|
||||
elif hasattr(chunk, "choices"):
|
||||
if not isinstance(chunk.choices, type):
|
||||
choices = chunk.choices
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
usage_info = chunk.usage
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
first_choice = choices[0]
|
||||
delta = None
|
||||
|
||||
if isinstance(first_choice, dict):
|
||||
delta = first_choice.get("delta", {})
|
||||
elif hasattr(first_choice, "delta"):
|
||||
delta = first_choice.delta
|
||||
|
||||
if delta:
|
||||
if isinstance(delta, dict):
|
||||
chunk_content = delta.get("content")
|
||||
elif hasattr(delta, "content"):
|
||||
chunk_content = delta.content
|
||||
|
||||
tool_calls: list[ChatCompletionDeltaToolCall] | None = None
|
||||
if isinstance(delta, dict):
|
||||
tool_calls = delta.get("tool_calls")
|
||||
elif hasattr(delta, "tool_calls"):
|
||||
tool_calls = delta.tool_calls
|
||||
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
idx = tool_call.index
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
accumulated_tool_args[
|
||||
idx
|
||||
].function.name = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
accumulated_tool_args[
|
||||
idx
|
||||
].function.arguments += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
except (AttributeError, KeyError, IndexError, TypeError):
|
||||
pass
|
||||
|
||||
if chunk_content:
|
||||
full_response += chunk_content
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
chunk=chunk_content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if callbacks and len(callbacks) > 0 and usage_info:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
response_obj={"usage": usage_info},
|
||||
start_time=0,
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
if accumulated_tool_args and available_functions:
|
||||
# Convert accumulated tool args to ChatCompletionDeltaToolCall objects
|
||||
tool_calls_list: list[ChatCompletionDeltaToolCall] = [
|
||||
ChatCompletionDeltaToolCall(
|
||||
index=idx,
|
||||
function=Function(
|
||||
name=tool_arg.function.name,
|
||||
arguments=tool_arg.function.arguments,
|
||||
),
|
||||
)
|
||||
for idx, tool_arg in accumulated_tool_args.items()
|
||||
if tool_arg.function.name
|
||||
]
|
||||
|
||||
if tool_calls_list:
|
||||
result = self._handle_streaming_tool_calls(
|
||||
tool_calls=tool_calls_list,
|
||||
accumulated_tool_args=accumulated_tool_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
self._handle_emit_call_events(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params.get("messages"),
|
||||
)
|
||||
return full_response
|
||||
|
||||
except ContextWindowExceededError as e:
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
except Exception:
|
||||
if chunk_count == 0:
|
||||
raise
|
||||
if full_response:
|
||||
self._handle_emit_call_events(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params.get("messages"),
|
||||
)
|
||||
return full_response
|
||||
raise
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_calls: list[Any],
|
||||
@@ -1300,6 +1644,10 @@ class LLM(BaseLLM):
|
||||
if message.get("role") == "system":
|
||||
msg_role: Literal["assistant"] = "assistant"
|
||||
message["role"] = msg_role
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# --- 5) Set up callbacks if provided
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
@@ -1309,7 +1657,16 @@ class LLM(BaseLLM):
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
# --- 7) Make the completion call and handle response
|
||||
if self.stream:
|
||||
return self._handle_streaming_response(
|
||||
result = self._handle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
else:
|
||||
result = self._handle_non_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
@@ -1318,14 +1675,12 @@ class LLM(BaseLLM):
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return self._handle_non_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
if isinstance(result, str):
|
||||
result = self._invoke_after_llm_call_hooks(
|
||||
messages, result, from_agent
|
||||
)
|
||||
|
||||
return result
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise LLMContextLengthExceededError as it should be handled
|
||||
# by the CrewAgentExecutor._invoke_loop method, which can then decide
|
||||
@@ -1367,6 +1722,128 @@ class LLM(BaseLLM):
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Async high-level LLM call method.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
Can be a string or list of message dictionaries.
|
||||
If string, it will be converted to a single user message.
|
||||
If list, each dict must have 'role' and 'content' keys.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
Each tool should define its name, description, and parameters.
|
||||
callbacks: Optional list of callback functions to be executed
|
||||
during and after the LLM call.
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
from_task: Optional Task that invoked the LLM
|
||||
from_agent: Optional Agent that invoked the LLM
|
||||
response_model: Optional Model that contains a pydantic response model.
|
||||
|
||||
Returns:
|
||||
Union[str, Any]: Either a text response from the LLM (str) or
|
||||
the result of a tool function call (Any).
|
||||
|
||||
Raises:
|
||||
TypeError: If messages format is invalid
|
||||
ValueError: If response format is not supported
|
||||
LLMContextLengthExceededError: If input exceeds model's context limit
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
),
|
||||
)
|
||||
|
||||
self._validate_call_params()
|
||||
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
if "o1" in self.model.lower():
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
msg_role: Literal["assistant"] = "assistant"
|
||||
message["role"] = msg_role
|
||||
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
try:
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_non_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
except LLMContextLengthExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
|
||||
if unsupported_stop:
|
||||
if (
|
||||
"additional_drop_params" in self.additional_params
|
||||
and isinstance(
|
||||
self.additional_params["additional_drop_params"], list
|
||||
)
|
||||
):
|
||||
self.additional_params["additional_drop_params"].append("stop")
|
||||
else:
|
||||
self.additional_params = {"additional_drop_params": ["stop"]}
|
||||
|
||||
logging.info("Retrying LLM call without the unsupported 'stop'")
|
||||
|
||||
return await self.acall(
|
||||
messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e), from_task=from_task, from_agent=from_agent
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def _handle_emit_call_events(
|
||||
self,
|
||||
response: Any,
|
||||
@@ -1699,12 +2176,14 @@ class LLM(BaseLLM):
|
||||
max_tokens=self.max_tokens,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
logit_bias=copy.deepcopy(self.logit_bias, memo)
|
||||
if self.logit_bias
|
||||
else None,
|
||||
response_format=copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None,
|
||||
logit_bias=(
|
||||
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
|
||||
),
|
||||
response_format=(
|
||||
copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None
|
||||
),
|
||||
seed=self.seed,
|
||||
logprobs=self.logprobs,
|
||||
top_logprobs=self.top_logprobs,
|
||||
|
||||
@@ -158,6 +158,44 @@ class BaseLLM(ABC):
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
"""
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call the LLM with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
Can be a string or list of message dictionaries.
|
||||
If string, it will be converted to a single user message.
|
||||
If list, each dict must have 'role' and 'content' keys.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
Each tool should define its name, description, and parameters.
|
||||
callbacks: Optional list of callback functions to be executed
|
||||
during and after the LLM call.
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
from_task: Optional task caller to be used for the LLM call.
|
||||
from_agent: Optional agent caller to be used for the LLM call.
|
||||
response_model: Optional response model to be used for the LLM call.
|
||||
|
||||
Returns:
|
||||
Either a text response from the LLM (str) or
|
||||
the result of a tool function call (Any).
|
||||
|
||||
Raises:
|
||||
ValueError: If the messages format is invalid.
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _convert_tools_for_interference(
|
||||
self, tools: list[dict[str, BaseTool]]
|
||||
) -> list[dict[str, BaseTool]]:
|
||||
@@ -276,7 +314,7 @@ class BaseLLM(ABC):
|
||||
call_type: LLMCallType,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
messages: str | list[dict[str, Any]] | None = None,
|
||||
messages: str | list[LLMMessage] | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM call completed event."""
|
||||
crewai_event_bus.emit(
|
||||
@@ -548,3 +586,134 @@ class BaseLLM(ABC):
|
||||
Dictionary with token usage totals
|
||||
"""
|
||||
return UsageMetrics(**self._token_usage)
|
||||
|
||||
def _invoke_before_llm_call_hooks(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
from_agent: Agent | None = None,
|
||||
) -> bool:
|
||||
"""Invoke before_llm_call hooks for direct LLM calls (no agent context).
|
||||
|
||||
This method should be called by native provider implementations before
|
||||
making the actual LLM call when from_agent is None (direct calls).
|
||||
|
||||
Args:
|
||||
messages: The messages being sent to the LLM
|
||||
from_agent: The agent making the call (None for direct calls)
|
||||
|
||||
Returns:
|
||||
True if LLM call should proceed, False if blocked by hook
|
||||
|
||||
Example:
|
||||
>>> # In a native provider's call() method:
|
||||
>>> if from_agent is None and not self._invoke_before_llm_call_hooks(
|
||||
... messages, from_agent
|
||||
... ):
|
||||
... raise ValueError("LLM call blocked by hook")
|
||||
"""
|
||||
# Only invoke hooks for direct calls (no agent context)
|
||||
if from_agent is not None:
|
||||
return True
|
||||
|
||||
from crewai.hooks.llm_hooks import (
|
||||
LLMCallHookContext,
|
||||
get_before_llm_call_hooks,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
before_hooks = get_before_llm_call_hooks()
|
||||
if not before_hooks:
|
||||
return True
|
||||
|
||||
hook_context = LLMCallHookContext(
|
||||
executor=None,
|
||||
messages=messages,
|
||||
llm=self,
|
||||
agent=None,
|
||||
task=None,
|
||||
crew=None,
|
||||
)
|
||||
printer = Printer()
|
||||
|
||||
try:
|
||||
for hook in before_hooks:
|
||||
result = hook(hook_context)
|
||||
if result is False:
|
||||
printer.print(
|
||||
content="LLM call blocked by before_llm_call hook",
|
||||
color="yellow",
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
printer.print(
|
||||
content=f"Error in before_llm_call hook: {e}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _invoke_after_llm_call_hooks(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
response: str,
|
||||
from_agent: Agent | None = None,
|
||||
) -> str:
|
||||
"""Invoke after_llm_call hooks for direct LLM calls (no agent context).
|
||||
|
||||
This method should be called by native provider implementations after
|
||||
receiving the LLM response when from_agent is None (direct calls).
|
||||
|
||||
Args:
|
||||
messages: The messages that were sent to the LLM
|
||||
response: The response from the LLM
|
||||
from_agent: The agent that made the call (None for direct calls)
|
||||
|
||||
Returns:
|
||||
The potentially modified response string
|
||||
|
||||
Example:
|
||||
>>> # In a native provider's call() method:
|
||||
>>> if from_agent is None and isinstance(result, str):
|
||||
... result = self._invoke_after_llm_call_hooks(
|
||||
... messages, result, from_agent
|
||||
... )
|
||||
"""
|
||||
# Only invoke hooks for direct calls (no agent context)
|
||||
if from_agent is not None or not isinstance(response, str):
|
||||
return response
|
||||
|
||||
from crewai.hooks.llm_hooks import (
|
||||
LLMCallHookContext,
|
||||
get_after_llm_call_hooks,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
after_hooks = get_after_llm_call_hooks()
|
||||
if not after_hooks:
|
||||
return response
|
||||
|
||||
hook_context = LLMCallHookContext(
|
||||
executor=None,
|
||||
messages=messages,
|
||||
llm=self,
|
||||
agent=None,
|
||||
task=None,
|
||||
crew=None,
|
||||
response=response,
|
||||
)
|
||||
printer = Printer()
|
||||
modified_response = response
|
||||
|
||||
try:
|
||||
for hook in after_hooks:
|
||||
result = hook(hook_context)
|
||||
if result is not None and isinstance(result, str):
|
||||
modified_response = result
|
||||
hook_context.response = modified_response
|
||||
except Exception as e:
|
||||
printer.print(
|
||||
content=f"Error in after_llm_call hook: {e}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
return modified_response
|
||||
|
||||
@@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [
|
||||
|
||||
|
||||
AnthropicModels: TypeAlias = Literal[
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[
|
||||
"claude-3-haiku-20240307",
|
||||
]
|
||||
ANTHROPIC_MODELS: list[AnthropicModels] = [
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -252,6 +256,7 @@ GeminiModels: TypeAlias = Literal[
|
||||
"gemini-2.5-flash-preview-tts",
|
||||
"gemini-2.5-pro-preview-tts",
|
||||
"gemini-2.5-computer-use-preview-10-2025",
|
||||
"gemini-2.5-pro-exp-03-25",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-001",
|
||||
"gemini-2.0-flash-exp",
|
||||
@@ -305,6 +310,7 @@ GEMINI_MODELS: list[GeminiModels] = [
|
||||
"gemini-2.5-flash-preview-tts",
|
||||
"gemini-2.5-pro-preview-tts",
|
||||
"gemini-2.5-computer-use-preview-10-2025",
|
||||
"gemini-2.5-pro-exp-03-25",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-001",
|
||||
"gemini-2.0-flash-exp",
|
||||
@@ -452,6 +458,7 @@ BedrockModels: TypeAlias = Literal[
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
@@ -524,6 +531,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
|
||||
@@ -3,13 +3,14 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from anthropic.types import ThinkingBlock
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.hooks.transport import HTTPTransport
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
@@ -21,9 +22,8 @@ if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic
|
||||
from anthropic.types import Message
|
||||
from anthropic.types.tool_use_block import ToolUseBlock
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
|
||||
import httpx
|
||||
except ImportError:
|
||||
raise ImportError(
@@ -31,6 +31,11 @@ except ImportError:
    ) from None


class AnthropicThinkingConfig(BaseModel):
    type: Literal["enabled", "disabled"]
    budget_tokens: int | None = None

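AnthropicThinkingConfig gives callers a typed handle on Claude's extended thinking: "enabled" or "disabled" plus an optional token budget, and the completion class below also accepts the equivalent plain dict unchanged. A hedged construction sketch follows; the model name and the 1024-token budget are illustrative values, and a real call still needs ANTHROPIC_API_KEY set.

from crewai.llms.providers.anthropic.completion import (
    AnthropicCompletion,
    AnthropicThinkingConfig,
)

# Illustrative budget; tune per workload.
thinking = AnthropicThinkingConfig(type="enabled", budget_tokens=1024)

llm = AnthropicCompletion(
    model="claude-sonnet-4-20250514",
    thinking=thinking,  # a raw {"type": "enabled", "budget_tokens": 1024} dict also works
)
print(llm.call("Summarize the trade-offs of streaming vs. non-streaming responses."))
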
class AnthropicCompletion(BaseLLM):
|
||||
"""Anthropic native completion implementation.
|
||||
|
||||
@@ -52,6 +57,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
stream: bool = False,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
@@ -84,15 +90,24 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
self.client = Anthropic(**self._get_client_params())
|
||||
|
||||
async_client_params = self._get_client_params()
|
||||
if self.interceptor:
|
||||
async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_params["http_client"] = async_http_client
|
||||
|
||||
self.async_client = AsyncAnthropic(**async_client_params)
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences or []
|
||||
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use
|
||||
self.supports_tools = True
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
@@ -182,6 +197,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
messages
|
||||
)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, system_message, tools
|
||||
@@ -213,6 +231,72 @@ class AnthropicCompletion(BaseLLM):
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Async call to Anthropic messages API.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: List of tool/function definitions
|
||||
callbacks: Callback functions (not used in native implementation)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages, system_message = self._format_messages_for_anthropic(
|
||||
messages
|
||||
)
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, system_message, tools
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Anthropic API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
@@ -252,6 +336,12 @@ class AnthropicCompletion(BaseLLM):
|
||||
if tools and self.supports_tools:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
|
||||
if self.thinking:
|
||||
if isinstance(self.thinking, AnthropicThinkingConfig):
|
||||
params["thinking"] = self.thinking.model_dump()
|
||||
else:
|
||||
params["thinking"] = self.thinking
|
||||
|
||||
return params
|
||||
|
||||
def _convert_tools_for_interference(
|
||||
@@ -291,6 +381,34 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
def _extract_thinking_block(
|
||||
self, content_block: Any
|
||||
) -> ThinkingBlock | dict[str, Any] | None:
|
||||
"""Extract and format thinking block from content block.
|
||||
|
||||
Args:
|
||||
content_block: Content block from Anthropic response
|
||||
|
||||
Returns:
|
||||
Dictionary with thinking block data including signature, or None if not a thinking block
|
||||
"""
|
||||
if content_block.type == "thinking":
|
||||
thinking_block = {
|
||||
"type": "thinking",
|
||||
"thinking": content_block.thinking,
|
||||
}
|
||||
if hasattr(content_block, "signature"):
|
||||
thinking_block["signature"] = content_block.signature
|
||||
return thinking_block
|
||||
if content_block.type == "redacted_thinking":
|
||||
redacted_block = {"type": "redacted_thinking"}
|
||||
if hasattr(content_block, "thinking"):
|
||||
redacted_block["thinking"] = content_block.thinking
|
||||
if hasattr(content_block, "signature"):
|
||||
redacted_block["signature"] = content_block.signature
|
||||
return redacted_block
|
||||
return None
|
||||
|
||||
def _format_messages_for_anthropic(
|
||||
self, messages: str | list[LLMMessage]
|
||||
) -> tuple[list[LLMMessage], str | None]:
|
||||
@@ -300,6 +418,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
- System messages are separate from conversation messages
|
||||
- Messages must alternate between user and assistant
|
||||
- First message must be from user
|
||||
- When thinking is enabled, assistant messages must start with thinking blocks
|
||||
|
||||
Args:
|
||||
messages: Input messages
|
||||
@@ -324,8 +443,29 @@ class AnthropicCompletion(BaseLLM):
|
||||
system_message = cast(str, content)
|
||||
else:
|
||||
role_str = role if role is not None else "user"
|
||||
content_str = content if content is not None else ""
|
||||
formatted_messages.append({"role": role_str, "content": content_str})
|
||||
|
||||
if isinstance(content, list):
|
||||
formatted_messages.append({"role": role_str, "content": content})
|
||||
elif (
|
||||
role_str == "assistant"
|
||||
and self.thinking
|
||||
and self.previous_thinking_blocks
|
||||
):
|
||||
structured_content = cast(
|
||||
list[dict[str, Any]],
|
||||
[
|
||||
*self.previous_thinking_blocks,
|
||||
{"type": "text", "text": content if content else ""},
|
||||
],
|
||||
)
|
||||
formatted_messages.append(
|
||||
LLMMessage(role=role_str, content=structured_content)
|
||||
)
|
||||
else:
|
||||
content_str = content if content is not None else ""
|
||||
formatted_messages.append(
|
||||
LLMMessage(role=role_str, content=content_str)
|
||||
)
|
||||
|
||||
# Ensure first message is from user (Anthropic requirement)
|
||||
if not formatted_messages:
|
||||
@@ -375,7 +515,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -403,15 +542,22 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_agent,
|
||||
)
|
||||
|
||||
# Extract text content
|
||||
content = ""
|
||||
thinking_blocks: list[ThinkingBlock] = []
|
||||
|
||||
if response.content:
|
||||
for content_block in response.content:
|
||||
if hasattr(content_block, "text"):
|
||||
content += content_block.text
|
||||
else:
|
||||
thinking_block = self._extract_thinking_block(content_block)
|
||||
if thinking_block:
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -423,7 +569,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"Anthropic API usage: {usage}")
|
||||
|
||||
return content
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
params["messages"], content, from_agent
|
||||
)
|
||||
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
@@ -464,6 +612,16 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
final_message: Message = stream.get_final_message()
|
||||
|
||||
thinking_blocks: list[ThinkingBlock] = []
|
||||
if final_message.content:
|
||||
for content_block in final_message.content:
|
||||
thinking_block = self._extract_thinking_block(content_block)
|
||||
if thinking_block:
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
@@ -517,7 +675,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
params["messages"], full_response, from_agent
|
||||
)
|
||||
|
||||
def _handle_tool_use_conversation(
|
||||
self,
|
||||
@@ -546,7 +706,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
# Execute the tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args, # type: ignore
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
@@ -566,7 +726,26 @@ class AnthropicCompletion(BaseLLM):
|
||||
follow_up_params = params.copy()
|
||||
|
||||
# Add Claude's tool use response to conversation
|
||||
assistant_message = {"role": "assistant", "content": initial_response.content}
|
||||
assistant_content: list[
|
||||
ThinkingBlock | ToolUseBlock | TextBlock | dict[str, Any]
|
||||
] = []
|
||||
for block in initial_response.content:
|
||||
thinking_block = self._extract_thinking_block(block)
|
||||
if thinking_block:
|
||||
assistant_content.append(thinking_block)
|
||||
elif block.type == "tool_use":
|
||||
assistant_content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"input": block.input,
|
||||
}
|
||||
)
|
||||
elif hasattr(block, "text"):
|
||||
assistant_content.append({"type": "text", "text": block.text})
|
||||
|
||||
assistant_message = {"role": "assistant", "content": assistant_content}
|
||||
|
||||
# Add user message with tool results
|
||||
user_message = {"role": "user", "content": tool_results}
|
||||
@@ -585,12 +764,20 @@ class AnthropicCompletion(BaseLLM):
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
self._track_token_usage_internal(follow_up_usage)
|
||||
|
||||
# Extract final text content
|
||||
final_content = ""
|
||||
thinking_blocks: list[ThinkingBlock] = []
|
||||
|
||||
if final_response.content:
|
||||
for content_block in final_response.content:
|
||||
if hasattr(content_block, "text"):
|
||||
final_content += content_block.text
|
||||
else:
|
||||
thinking_block = self._extract_thinking_block(content_block)
|
||||
if thinking_block:
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
|
||||
final_content = self._apply_stop_words(final_content)
|
||||
|
||||
@@ -626,6 +813,275 @@ class AnthropicCompletion(BaseLLM):
                return tool_results[0]["content"]
            raise e

    async def _ahandle_completion(
        self,
        params: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming async message completion."""
        if response_model:
            structured_tool = {
                "name": "structured_output",
                "description": "Returns structured data according to the schema",
                "input_schema": response_model.model_json_schema(),
            }

            params["tools"] = [structured_tool]
            params["tool_choice"] = {"type": "tool", "name": "structured_output"}

        try:
            response: Message = await self.async_client.messages.create(**params)

        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e
            raise e from e

        usage = self._extract_anthropic_token_usage(response)
        self._track_token_usage_internal(usage)

        if response_model and response.content:
            tool_uses = [
                block for block in response.content if isinstance(block, ToolUseBlock)
            ]
            if tool_uses and tool_uses[0].name == "structured_output":
                structured_data = tool_uses[0].input
                structured_json = json.dumps(structured_data)

                self._emit_call_completed_event(
                    response=structured_json,
                    call_type=LLMCallType.LLM_CALL,
                    from_task=from_task,
                    from_agent=from_agent,
                    messages=params["messages"],
                )

                return structured_json

        if response.content and available_functions:
            tool_uses = [
                block for block in response.content if isinstance(block, ToolUseBlock)
            ]

            if tool_uses:
                return await self._ahandle_tool_use_conversation(
                    response,
                    tool_uses,
                    params,
                    available_functions,
                    from_task,
                    from_agent,
                )

        content = ""
        if response.content:
            for content_block in response.content:
                if hasattr(content_block, "text"):
                    content += content_block.text

        content = self._apply_stop_words(content)

        self._emit_call_completed_event(
            response=content,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=params["messages"],
        )

        if usage.get("total_tokens", 0) > 0:
            logging.info(f"Anthropic API usage: {usage}")

        return content

async def _ahandle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle async streaming message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
full_response = ""
|
||||
|
||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||
|
||||
async with self.async_client.messages.stream(**stream_params) as stream:
|
||||
async for event in stream:
|
||||
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
||||
text_delta = event.delta.text
|
||||
full_response += text_delta
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=text_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
final_message: Message = await stream.get_final_message()
|
||||
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and final_message.content:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
|
||||
if final_message.content and available_functions:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
return await self._ahandle_tool_use_conversation(
|
||||
final_message,
|
||||
tool_uses,
|
||||
params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
async def _ahandle_tool_use_conversation(
|
||||
self,
|
||||
initial_response: Message,
|
||||
tool_uses: list[ToolUseBlock],
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle the complete async tool use conversation flow.
|
||||
|
||||
This implements the proper Anthropic tool use pattern:
|
||||
1. Claude requests tool use
|
||||
2. We execute the tools
|
||||
3. We send tool results back to Claude
|
||||
4. Claude processes results and generates final response
|
||||
"""
|
||||
tool_results = []
|
||||
|
||||
for tool_use in tool_uses:
|
||||
function_name = tool_use.name
|
||||
function_args = tool_use.input
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use.id,
|
||||
"content": str(result)
|
||||
if result is not None
|
||||
else "Tool execution completed",
|
||||
}
|
||||
tool_results.append(tool_result)
|
||||
|
||||
follow_up_params = params.copy()
|
||||
|
||||
assistant_message = {"role": "assistant", "content": initial_response.content}
|
||||
|
||||
user_message = {"role": "user", "content": tool_results}
|
||||
|
||||
follow_up_params["messages"] = params["messages"] + [
|
||||
assistant_message,
|
||||
user_message,
|
||||
]
|
||||
|
||||
try:
|
||||
final_response: Message = await self.async_client.messages.create(
|
||||
**follow_up_params
|
||||
)
|
||||
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
self._track_token_usage_internal(follow_up_usage)
|
||||
|
||||
final_content = ""
|
||||
if final_response.content:
|
||||
for content_block in final_response.content:
|
||||
if hasattr(content_block, "text"):
|
||||
final_content += content_block.text
|
||||
|
||||
final_content = self._apply_stop_words(final_content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=final_content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=follow_up_params["messages"],
|
||||
)
|
||||
|
||||
total_usage = {
|
||||
"input_tokens": follow_up_usage.get("input_tokens", 0),
|
||||
"output_tokens": follow_up_usage.get("output_tokens", 0),
|
||||
"total_tokens": follow_up_usage.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
if total_usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"Anthropic API tool conversation usage: {total_usage}")
|
||||
|
||||
return final_content
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded in tool follow-up: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
logging.error(f"Tool follow-up conversation failed: {e}")
|
||||
if tool_results:
|
||||
return tool_results[0]["content"]
|
||||
raise e
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
|
||||
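A minimal usage sketch for the async Anthropic path added in this section, assuming the provider exposes an `acall` entry point like the Azure and Bedrock classes below; the import path, constructor arguments, and prompt are illustrative assumptions, not part of the diff:

# Hypothetical usage sketch for the async AnthropicCompletion flow shown above.
import asyncio

from crewai.llms.providers.anthropic.completion import AnthropicCompletion  # assumed module path


async def main() -> None:
    llm = AnthropicCompletion(model="claude-3-5-sonnet-20241022")  # assumed constructor signature
    # Non-streaming async call; tool use and structured output follow the
    # _ahandle_completion / _ahandle_tool_use_conversation flow in the diff.
    answer = await llm.acall("Summarize the latest release notes in one sentence.")
    print(answer)


asyncio.run(main())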
@@ -6,8 +6,10 @@ import os
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel
from typing_extensions import Self

from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)
@@ -23,9 +25,13 @@ try:
    from azure.ai.inference import (
        ChatCompletionsClient,
    )
    from azure.ai.inference.aio import (
        ChatCompletionsClient as AsyncChatCompletionsClient,
    )
    from azure.ai.inference.models import (
        ChatCompletions,
        ChatCompletionsToolCall,
        JsonSchemaFormat,
        StreamingChatCompletionsUpdate,
    )
    from azure.core.credentials import (
@@ -133,6 +139,8 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
@@ -208,6 +216,9 @@ class AzureCompletion(BaseLLM):
|
||||
# Format messages for Azure
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, response_model
|
||||
@@ -256,6 +267,88 @@ class AzureCompletion(BaseLLM):
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call Azure AI Inference chat completions API asynchronously.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: List of tool/function definitions
|
||||
callbacks: Callback functions (not used in native implementation)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Pydantic model for structured output
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
except HttpResponseError as e:
|
||||
if e.status_code == 401:
|
||||
error_msg = "Azure authentication failed. Check your API key."
|
||||
elif e.status_code == 404:
|
||||
error_msg = (
|
||||
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
error_msg = "Azure API rate limit exceeded. Please retry later."
|
||||
else:
|
||||
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
|
||||
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
@@ -278,13 +371,16 @@ class AzureCompletion(BaseLLM):
        }

        if response_model and self.is_openai_model:
            params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": response_model.__name__,
                    "schema": response_model.model_json_schema(),
                },
            }
            model_description = generate_model_description(response_model)
            json_schema_info = model_description["json_schema"]
            json_schema_name = json_schema_info["name"]

            params["response_format"] = JsonSchemaFormat(
                name=json_schema_name,
                schema=json_schema_info["schema"],
                description=f"Schema for {json_schema_name}",
                strict=json_schema_info["strict"],
            )

        # Only include model parameter for non-Azure OpenAI endpoints
        # Azure OpenAI endpoints have the deployment name in the URL
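For illustration, a hedged caller-side sketch of the structured-output branch above: with `response_model` supplied, `generate_model_description` provides the schema that backs the `JsonSchemaFormat`. The import path, model name, and the exact `call(..., response_model=...)` signature are assumptions here:

# Hypothetical caller-side sketch; deployment/model values are placeholders.
from pydantic import BaseModel

from crewai.llms.providers.azure.completion import AzureCompletion  # assumed module path


class ReleaseSummary(BaseModel):
    version: str
    highlights: list[str]


llm = AzureCompletion(model="gpt-4o")  # assumed signature; endpoint and key come from env vars
result = llm.call(
    "Summarize release 1.7.0 as structured data.",
    response_model=ReleaseSummary,  # drives the JsonSchemaFormat branch shown above
)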
@@ -311,8 +407,8 @@ class AzureCompletion(BaseLLM):
            params["tool_choice"] = "auto"

        additional_params = self.additional_params
        additional_drop_params = additional_params.get('additional_drop_params')
        drop_params = additional_params.get('drop_params')
        additional_drop_params = additional_params.get("additional_drop_params")
        drop_params = additional_params.get("drop_params")

        if drop_params and isinstance(additional_drop_params, list):
            for drop_param in additional_drop_params:
@@ -457,6 +553,10 @@ class AzureCompletion(BaseLLM):
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
content = self._invoke_after_llm_call_hooks(
|
||||
params["messages"], content, from_agent
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
@@ -549,6 +649,172 @@ class AzureCompletion(BaseLLM):
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
params["messages"], full_response, from_agent
|
||||
)
|
||||
|
||||
async def _ahandle_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion asynchronously."""
|
||||
try:
|
||||
response: ChatCompletions = await self.async_client.complete(**params)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError("No choices returned from Azure API")
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
usage = self._extract_azure_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
content = message.content or ""
|
||||
try:
|
||||
structured_data = response_model.model_validate_json(content)
|
||||
structured_json = structured_data.model_dump_json()
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0] # Handle first tool call
|
||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
||||
function_name = tool_call.function.name
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = message.content or ""
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise e
|
||||
|
||||
return content
|
||||
|
||||
async def _ahandle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming chat completion asynchronously."""
|
||||
full_response = ""
|
||||
tool_calls = {}
|
||||
|
||||
stream = await self.async_client.complete(**params)
|
||||
async for update in stream:
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.choices:
|
||||
choice = update.choices[0]
|
||||
if choice.delta and choice.delta.content:
|
||||
content_delta = choice.delta.content
|
||||
full_response += content_delta
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=content_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if choice.delta and choice.delta.tool_calls:
|
||||
for tool_call in choice.delta.tool_calls:
|
||||
call_id = tool_call.id or "default"
|
||||
if call_id not in tool_calls:
|
||||
tool_calls[call_id] = {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call.function and tool_call.function.name:
|
||||
tool_calls[call_id]["name"] = tool_call.function.name
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
function_name = call_data["name"]
|
||||
|
||||
try:
|
||||
function_args = json.loads(call_data["arguments"])
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
||||
continue
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
@@ -604,3 +870,20 @@ class AzureCompletion(BaseLLM):
                "total_tokens": getattr(usage, "total_tokens", 0),
            }
        return {"total_tokens": 0}

    async def aclose(self) -> None:
        """Close the async client and clean up resources.

        This ensures proper cleanup of the underlying aiohttp session
        to avoid unclosed connector warnings.
        """
        if hasattr(self.async_client, "close"):
            await self.async_client.close()

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit."""
        await self.aclose()

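A short sketch of the async context manager support added above, so the `AsyncChatCompletionsClient` session is closed deterministically; the constructor arguments and import path are illustrative assumptions:

# Sketch: async context manager usage for AzureCompletion.
import asyncio

from crewai.llms.providers.azure.completion import AzureCompletion  # assumed module path


async def main() -> None:
    async with AzureCompletion(model="gpt-4o") as llm:  # __aenter__ returns the instance
        text = await llm.acall("Give me one sentence about CrewAI flows.")
        print(text)
    # __aexit__ awaited llm.aclose(), closing the underlying async client session.


asyncio.run(main())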
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from contextlib import AsyncExitStack
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
@@ -42,6 +44,16 @@ except ImportError:
|
||||
'AWS Bedrock native provider not available, to install: uv add "crewai[bedrock]"'
|
||||
) from None
|
||||
|
||||
try:
|
||||
from aiobotocore.session import ( # type: ignore[import-untyped]
|
||||
get_session as get_aiobotocore_session,
|
||||
)
|
||||
|
||||
AIOBOTOCORE_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOBOTOCORE_AVAILABLE = False
|
||||
get_aiobotocore_session = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -221,6 +233,15 @@ class BedrockCompletion(BaseLLM):
|
||||
self.client = session.client("bedrock-runtime", config=config)
|
||||
self.region_name = region_name
|
||||
|
||||
self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
|
||||
self.aws_secret_access_key = aws_secret_access_key or os.getenv(
|
||||
"AWS_SECRET_ACCESS_KEY"
|
||||
)
|
||||
self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
|
||||
|
||||
self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
|
||||
self._async_client_initialized = False
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
@@ -291,9 +312,14 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
# Format messages for Converse API
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages # type: ignore[arg-type]
|
||||
messages
|
||||
)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(
|
||||
cast(list[LLMMessage], formatted_messages), from_agent
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare request body
|
||||
body: BedrockConverseRequestBody = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
@@ -335,10 +361,122 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
if self.stream:
|
||||
return self._handle_streaming_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
cast(list[LLMMessage], formatted_messages),
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
return self._handle_converse(
|
||||
cast(list[LLMMessage], formatted_messages),
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"AWS Bedrock API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[Any, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Async call to AWS Bedrock Converse API.
|
||||
|
||||
Args:
|
||||
messages: Input messages as string or list of message dicts.
|
||||
tools: Optional list of tool definitions.
|
||||
callbacks: Optional list of callback handlers.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
from_task: Optional task context for events.
|
||||
from_agent: Optional agent context for events.
|
||||
response_model: Optional Pydantic model for structured output.
|
||||
|
||||
Returns:
|
||||
Generated text response or structured output.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If aiobotocore is not installed.
|
||||
LLMContextLengthExceededError: If context window is exceeded.
|
||||
"""
|
||||
if not AIOBOTOCORE_AVAILABLE:
|
||||
raise NotImplementedError(
|
||||
"Async support for AWS Bedrock requires aiobotocore. "
|
||||
'Install with: uv add "crewai[bedrock-async]"'
|
||||
)
|
||||
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
body: BedrockConverseRequestBody = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
}
|
||||
|
||||
if system_message:
|
||||
body["system"] = cast(
|
||||
"list[SystemContentBlockTypeDef]",
|
||||
cast(object, [{"text": system_message}]),
|
||||
)
|
||||
|
||||
if tools:
|
||||
tool_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
"Sequence[ToolTypeDef]",
|
||||
cast(object, self._format_tools_for_converse(tools)),
|
||||
)
|
||||
}
|
||||
body["toolConfig"] = tool_config
|
||||
|
||||
if self.guardrail_config:
|
||||
guardrail_config: GuardrailConfigurationTypeDef = cast(
|
||||
"GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
|
||||
)
|
||||
body["guardrailConfig"] = guardrail_config
|
||||
|
||||
if self.additional_model_request_fields:
|
||||
body["additionalModelRequestFields"] = (
|
||||
self.additional_model_request_fields
|
||||
)
|
||||
|
||||
if self.additional_model_response_field_paths:
|
||||
body["additionalModelResponseFieldPaths"] = (
|
||||
self.additional_model_response_field_paths
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
return await self._ahandle_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
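A hedged sketch of an async Bedrock call that would exercise the tool path above (requires the optional aiobotocore extra); the import path, constructor signature, and the exact tool schema accepted by `_format_tools_for_converse` are assumptions for illustration:

# Hypothetical async Bedrock usage with a tool; values are placeholders.
import asyncio

from crewai.llms.providers.bedrock.completion import BedrockCompletion  # assumed module path


def get_weather(city: str) -> str:
    return f"Sunny in {city}"


tools = [
    {
        "name": "get_weather",
        "description": "Look up current weather for a city",
        "parameters": {  # assumed shape; converted by _format_tools_for_converse
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
]


async def main() -> None:
    llm = BedrockCompletion(model="anthropic.claude-3-5-sonnet-20240620-v1:0")  # assumed signature
    answer = await llm.acall(
        "What is the weather in Lisbon?",
        tools=tools,
        available_functions={"get_weather": get_weather},
    )
    print(answer)


asyncio.run(main())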
@@ -356,7 +494,7 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
def _handle_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
messages: list[LLMMessage],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
@@ -480,7 +618,11 @@ class BedrockCompletion(BaseLLM):
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return text_content
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
messages,
|
||||
text_content,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
except ClientError as e:
|
||||
# Handle all AWS ClientError exceptions as per documentation
|
||||
@@ -537,7 +679,7 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
def _handle_streaming_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
messages: list[LLMMessage],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
@@ -565,6 +707,341 @@ class BedrockCompletion(BaseLLM):
|
||||
role = event["messageStart"].get("role")
|
||||
logging.debug(f"Streaming message started with role: {role}")
|
||||
|
||||
elif "contentBlockStart" in event:
|
||||
start = event["contentBlockStart"].get("start", {})
|
||||
if "toolUse" in start:
|
||||
current_tool_use = start["toolUse"]
|
||||
tool_use_id = current_tool_use.get("toolUseId")
|
||||
logging.debug(
|
||||
f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
|
||||
)
|
||||
|
||||
elif "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"]["delta"]
|
||||
if "text" in delta:
|
||||
text_chunk = delta["text"]
|
||||
logging.debug(f"Streaming text chunk: {text_chunk[:50]}...")
|
||||
full_response += text_chunk
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=text_chunk,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
elif "toolUse" in delta and current_tool_use:
|
||||
tool_input = delta["toolUse"].get("input", "")
|
||||
if tool_input:
|
||||
logging.debug(f"Tool input delta: {tool_input}")
|
||||
elif "contentBlockStop" in event:
|
||||
logging.debug("Content block stopped in stream")
|
||||
if current_tool_use and available_functions:
|
||||
function_name = current_tool_use["name"]
|
||||
function_args = cast(
|
||||
dict[str, Any], current_tool_use.get("input", {})
|
||||
)
|
||||
tool_result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if tool_result is not None and tool_use_id:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"toolUse": current_tool_use}],
|
||||
}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": tool_use_id,
|
||||
"content": [
|
||||
{"text": str(tool_result)}
|
||||
],
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
return self._handle_converse(
|
||||
messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
current_tool_use = None
|
||||
tool_use_id = None
|
||||
elif "messageStop" in event:
|
||||
stop_reason = event["messageStop"].get("stopReason")
|
||||
logging.debug(f"Streaming message stopped: {stop_reason}")
|
||||
if stop_reason == "max_tokens":
|
||||
logging.warning(
|
||||
"Streaming response truncated due to max_tokens"
|
||||
)
|
||||
elif stop_reason == "content_filtered":
|
||||
logging.warning(
|
||||
"Streaming response filtered due to content policy"
|
||||
)
|
||||
break
|
||||
elif "metadata" in event:
|
||||
metadata = event["metadata"]
|
||||
if "usage" in metadata:
|
||||
usage_metrics = metadata["usage"]
|
||||
self._track_token_usage_internal(usage_metrics)
|
||||
logging.debug(f"Token usage: {usage_metrics}")
|
||||
if "trace" in metadata:
|
||||
logging.debug(
|
||||
f"Trace information available: {metadata['trace']}"
|
||||
)
|
||||
|
||||
except ClientError as e:
|
||||
error_msg = self._handle_client_error(e)
|
||||
raise RuntimeError(error_msg) from e
|
||||
except BotoCoreError as e:
|
||||
error_msg = f"Bedrock streaming connection error: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ConnectionError(error_msg) from e
|
||||
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
if not full_response or full_response.strip() == "":
|
||||
logging.warning("Bedrock streaming returned empty content, using fallback")
|
||||
full_response = (
|
||||
"I apologize, but I couldn't generate a response. Please try again."
|
||||
)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
async def _ensure_async_client(self) -> Any:
|
||||
"""Ensure async client is initialized and return it."""
|
||||
if not self._async_client_initialized and get_aiobotocore_session:
|
||||
if self._async_exit_stack is None:
|
||||
raise RuntimeError(
|
||||
"Async exit stack not initialized - aiobotocore not available"
|
||||
)
|
||||
session = get_aiobotocore_session()
|
||||
client = await self._async_exit_stack.enter_async_context(
|
||||
session.create_client(
|
||||
"bedrock-runtime",
|
||||
region_name=self.region_name,
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
aws_session_token=self.aws_session_token,
|
||||
)
|
||||
)
|
||||
self._async_client = client
|
||||
self._async_client_initialized = True
|
||||
return self._async_client
|
||||
|
||||
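Condensed, standalone sketch of the lazy aiobotocore client pattern used by `_ensure_async_client` above; region and lifecycle handling are simplified for illustration:

# The client is created once inside an AsyncExitStack so a single aclose()-style
# call can later tear down the client and its HTTP session.
import asyncio
from contextlib import AsyncExitStack

from aiobotocore.session import get_session


async def main() -> None:
    stack = AsyncExitStack()
    client = await stack.enter_async_context(
        get_session().create_client("bedrock-runtime", region_name="us-east-1")
    )
    try:
        ...  # reuse `client` for converse / converse_stream calls
    finally:
        await stack.aclose()  # closes the client and its connections


asyncio.run(main())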
async def _ahandle_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle async non-streaming converse API call."""
|
||||
try:
|
||||
if not messages:
|
||||
raise ValueError("Messages cannot be empty")
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
if (
|
||||
not isinstance(msg, dict)
|
||||
or "role" not in msg
|
||||
or "content" not in msg
|
||||
):
|
||||
raise ValueError(f"Invalid message format at index {i}")
|
||||
|
||||
async_client = await self._ensure_async_client()
|
||||
response = await async_client.converse(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
cast(object, messages),
|
||||
),
|
||||
**body,
|
||||
)
|
||||
|
||||
if "usage" in response:
|
||||
self._track_token_usage_internal(response["usage"])
|
||||
|
||||
stop_reason = response.get("stopReason")
|
||||
if stop_reason:
|
||||
logging.debug(f"Response stop reason: {stop_reason}")
|
||||
if stop_reason == "max_tokens":
|
||||
logging.warning("Response truncated due to max_tokens limit")
|
||||
elif stop_reason == "content_filtered":
|
||||
logging.warning("Response was filtered due to content policy")
|
||||
|
||||
output = response.get("output", {})
|
||||
message = output.get("message", {})
|
||||
content = message.get("content", [])
|
||||
|
||||
if not content:
|
||||
logging.warning("No content in Bedrock response")
|
||||
return (
|
||||
"I apologize, but I received an empty response. Please try again."
|
||||
)
|
||||
|
||||
text_content = ""
|
||||
|
||||
for content_block in content:
|
||||
if "text" in content_block:
|
||||
text_content += content_block["text"]
|
||||
|
||||
elif "toolUse" in content_block and available_functions:
|
||||
tool_use_block = content_block["toolUse"]
|
||||
tool_use_id = tool_use_block.get("toolUseId")
|
||||
function_name = tool_use_block["name"]
|
||||
function_args = tool_use_block.get("input", {})
|
||||
|
||||
logging.debug(
|
||||
f"Tool use requested: {function_name} with ID {tool_use_id}"
|
||||
)
|
||||
|
||||
tool_result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=dict(available_functions),
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if tool_result is not None:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"toolUse": tool_use_block}],
|
||||
}
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": tool_use_id,
|
||||
"content": [{"text": str(tool_result)}],
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return await self._ahandle_converse(
|
||||
messages, body, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
text_content = self._apply_stop_words(text_content)
|
||||
|
||||
if not text_content or text_content.strip() == "":
|
||||
logging.warning("Extracted empty text content from Bedrock response")
|
||||
text_content = "I apologize, but I couldn't generate a proper response. Please try again."
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=text_content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return text_content
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code", "Unknown")
|
||||
error_msg = e.response.get("Error", {}).get("Message", str(e))
|
||||
logging.error(f"AWS Bedrock ClientError ({error_code}): {error_msg}")
|
||||
|
||||
if error_code == "ValidationException":
|
||||
if "last turn" in error_msg and "user message" in error_msg:
|
||||
raise ValueError(
|
||||
f"Conversation format error: {error_msg}. Check message alternation."
|
||||
) from e
|
||||
raise ValueError(f"Request validation failed: {error_msg}") from e
|
||||
if error_code == "AccessDeniedException":
|
||||
raise PermissionError(
|
||||
f"Access denied to model {self.model_id}: {error_msg}"
|
||||
) from e
|
||||
if error_code == "ResourceNotFoundException":
|
||||
raise ValueError(f"Model {self.model_id} not found: {error_msg}") from e
|
||||
if error_code == "ThrottlingException":
|
||||
raise RuntimeError(
|
||||
f"API throttled, please retry later: {error_msg}"
|
||||
) from e
|
||||
if error_code == "ModelTimeoutException":
|
||||
raise TimeoutError(f"Model request timed out: {error_msg}") from e
|
||||
if error_code == "ServiceQuotaExceededException":
|
||||
raise RuntimeError(f"Service quota exceeded: {error_msg}") from e
|
||||
if error_code == "ModelNotReadyException":
|
||||
raise RuntimeError(
|
||||
f"Model {self.model_id} not ready: {error_msg}"
|
||||
) from e
|
||||
if error_code == "ModelErrorException":
|
||||
raise RuntimeError(f"Model error: {error_msg}") from e
|
||||
if error_code == "InternalServerException":
|
||||
raise RuntimeError(f"Internal server error: {error_msg}") from e
|
||||
if error_code == "ServiceUnavailableException":
|
||||
raise RuntimeError(f"Service unavailable: {error_msg}") from e
|
||||
|
||||
raise RuntimeError(f"Bedrock API error ({error_code}): {error_msg}") from e
|
||||
|
||||
except BotoCoreError as e:
|
||||
error_msg = f"Bedrock connection error: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error in Bedrock converse call: {e}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
async def _ahandle_streaming_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle async streaming converse API call."""
|
||||
full_response = ""
|
||||
current_tool_use = None
|
||||
tool_use_id = None
|
||||
|
||||
try:
|
||||
async_client = await self._ensure_async_client()
|
||||
response = await async_client.converse_stream(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
cast(object, messages),
|
||||
),
|
||||
**body,
|
||||
)
|
||||
|
||||
stream = response.get("stream")
|
||||
if stream:
|
||||
async for event in stream:
|
||||
if "messageStart" in event:
|
||||
role = event["messageStart"].get("role")
|
||||
logging.debug(f"Streaming message started with role: {role}")
|
||||
|
||||
elif "contentBlockStart" in event:
|
||||
start = event["contentBlockStart"].get("start", {})
|
||||
if "toolUse" in start:
|
||||
@@ -590,17 +1067,14 @@ class BedrockCompletion(BaseLLM):
|
||||
if tool_input:
|
||||
logging.debug(f"Tool input delta: {tool_input}")
|
||||
|
||||
# Content block stop - end of a content block
|
||||
elif "contentBlockStop" in event:
|
||||
logging.debug("Content block stopped in stream")
|
||||
# If we were accumulating a tool use, it's now complete
|
||||
if current_tool_use and available_functions:
|
||||
function_name = current_tool_use["name"]
|
||||
function_args = cast(
|
||||
dict[str, Any], current_tool_use.get("input", {})
|
||||
)
|
||||
|
||||
# Execute tool
|
||||
tool_result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
@@ -610,7 +1084,6 @@ class BedrockCompletion(BaseLLM):
|
||||
)
|
||||
|
||||
if tool_result is not None and tool_use_id:
|
||||
# Continue conversation with tool result
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -634,8 +1107,7 @@ class BedrockCompletion(BaseLLM):
|
||||
}
|
||||
)
|
||||
|
||||
# Recursive call - note this switches to non-streaming
|
||||
return self._handle_converse(
|
||||
return await self._ahandle_converse(
|
||||
messages,
|
||||
body,
|
||||
available_functions,
|
||||
@@ -643,10 +1115,9 @@ class BedrockCompletion(BaseLLM):
|
||||
from_agent,
|
||||
)
|
||||
|
||||
current_tool_use = None
|
||||
tool_use_id = None
|
||||
current_tool_use = None
|
||||
tool_use_id = None
|
||||
|
||||
# Message stop - end of entire message
|
||||
elif "messageStop" in event:
|
||||
stop_reason = event["messageStop"].get("stopReason")
|
||||
logging.debug(f"Streaming message stopped: {stop_reason}")
|
||||
@@ -660,7 +1131,6 @@ class BedrockCompletion(BaseLLM):
|
||||
)
|
||||
break
|
||||
|
||||
# Metadata - contains usage information and trace details
|
||||
elif "metadata" in event:
|
||||
metadata = event["metadata"]
|
||||
if "usage" in metadata:
|
||||
@@ -680,17 +1150,14 @@ class BedrockCompletion(BaseLLM):
|
||||
logging.error(error_msg)
|
||||
raise ConnectionError(error_msg) from e
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
# Ensure we don't return empty content
|
||||
if not full_response or full_response.strip() == "":
|
||||
logging.warning("Bedrock streaming returned empty content, using fallback")
|
||||
full_response = (
|
||||
"I apologize, but I couldn't generate a response. Please try again."
|
||||
)
|
||||
|
||||
# Emit completion event
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -699,16 +1166,25 @@ class BedrockCompletion(BaseLLM):
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return full_response
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
messages,
|
||||
full_response,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
    def _format_messages_for_converse(
        self, messages: str | list[dict[str, str]]
        self, messages: str | list[LLMMessage]
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Format messages for Converse API following AWS documentation."""
        """Format messages for Converse API following AWS documentation.

        Note: Returns dict[str, Any] instead of LLMMessage because Bedrock uses
        a different content structure: {"role": str, "content": [{"text": str}]}
        rather than the standard {"role": str, "content": str}.
        """
        # Use base class formatting first
        formatted_messages = self._format_messages(messages)  # type: ignore[arg-type]
        formatted_messages = self._format_messages(messages)

        converse_messages = []
        converse_messages: list[dict[str, Any]] = []
        system_message: str | None = None

        for message in formatted_messages:

@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
@@ -15,10 +16,15 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
from google import genai # type: ignore[import-untyped]
|
||||
from google.genai import types # type: ignore[import-untyped]
|
||||
from google.genai.errors import APIError # type: ignore[import-untyped]
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
from google.genai.types import GenerateContentResponse, Schema
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Google Gen AI native provider not available, to install: uv add "crewai[google-genai]"'
|
||||
@@ -102,7 +108,9 @@ class GeminiCompletion(BaseLLM):

        # Model-specific settings
        version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
        self.supports_tools = bool(version_match and float(version_match.group(1)) >= 1.5)
        self.supports_tools = bool(
            version_match and float(version_match.group(1)) >= 1.5
        )

    @property
    def stop(self) -> list[str]:
@@ -128,7 +136,7 @@ class GeminiCompletion(BaseLLM):
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
|
||||
def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported]
|
||||
def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
|
||||
"""Initialize the Google Gen AI client with proper parameter handling.
|
||||
|
||||
Args:
|
||||
@@ -238,6 +246,11 @@ class GeminiCompletion(BaseLLM):
|
||||
messages
|
||||
)
|
||||
|
||||
messages_for_hooks = self._convert_contents_to_dict(formatted_content)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(messages_for_hooks, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, response_model
|
||||
)
|
||||
@@ -277,7 +290,84 @@ class GeminiCompletion(BaseLLM):
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_generation_config( # type: ignore[no-any-unimported]
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Async call to Google Gemini generate content API.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: List of tool/function definitions
|
||||
callbacks: Callback functions (not used as token counts are handled by the response)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
self.tools = tools
|
||||
|
||||
formatted_content, system_instruction = self._format_messages_for_gemini(
|
||||
messages
|
||||
)
|
||||
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
formatted_content,
|
||||
system_instruction,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
error_msg = f"Google Gemini API error: {e.code} - {e.message}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Google Gemini API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_generation_config(
|
||||
self,
|
||||
system_instruction: str | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
@@ -294,7 +384,7 @@ class GeminiCompletion(BaseLLM):
|
||||
GenerateContentConfig object for Gemini API
|
||||
"""
|
||||
self.tools = tools
|
||||
config_params = {}
|
||||
config_params: dict[str, Any] = {}
|
||||
|
||||
# Add system instruction if present
|
||||
if system_instruction:
|
||||
@@ -329,7 +419,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
return types.GenerateContentConfig(**config_params)
|
||||
|
||||
def _convert_tools_for_interference( # type: ignore[no-any-unimported]
|
||||
def _convert_tools_for_interference( # type: ignore[override]
|
||||
self, tools: list[dict[str, Any]]
|
||||
) -> list[types.Tool]:
|
||||
"""Convert CrewAI tool format to Gemini function declaration format."""
|
||||
@@ -346,7 +436,7 @@ class GeminiCompletion(BaseLLM):
|
||||
)
|
||||
|
||||
# Add parameters if present - ensure parameters is a dict
|
||||
if parameters and isinstance(parameters, dict):
|
||||
if parameters and isinstance(parameters, Schema):
|
||||
function_declaration.parameters = parameters
|
||||
|
||||
gemini_tool = types.Tool(function_declarations=[function_declaration])
|
||||
@@ -354,7 +444,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
return gemini_tools
|
||||
|
||||
def _format_messages_for_gemini( # type: ignore[no-any-unimported]
|
||||
def _format_messages_for_gemini(
|
||||
self, messages: str | list[LLMMessage]
|
||||
) -> tuple[list[types.Content], str | None]:
|
||||
"""Format messages for Gemini API.
|
||||
@@ -373,32 +463,41 @@ class GeminiCompletion(BaseLLM):
        # Use base class formatting first
        base_formatted = super()._format_messages(messages)

        contents = []
        contents: list[types.Content] = []
        system_instruction: str | None = None

        for message in base_formatted:
            role = message.get("role")
            content = message.get("content", "")
            role = message["role"]
            content = message["content"]

            # Convert content to string if it's a list
            if isinstance(content, list):
                text_content = " ".join(
                    str(item.get("text", "")) if isinstance(item, dict) else str(item)
                    for item in content
                )
            else:
                text_content = str(content) if content else ""

            if role == "system":
                # Extract system instruction - Gemini handles it separately
                if system_instruction:
                    system_instruction += f"\n\n{content}"
                    system_instruction += f"\n\n{text_content}"
                else:
                    system_instruction = cast(str, content)
                    system_instruction = text_content
            else:
                # Convert role for Gemini (assistant -> model)
                gemini_role = "model" if role == "assistant" else "user"

                # Create Content object
                gemini_content = types.Content(
                    role=gemini_role, parts=[types.Part.from_text(text=content)]
                    role=gemini_role, parts=[types.Part.from_text(text=text_content)]
                )
                contents.append(gemini_content)

        return contents, system_instruction

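Condensed sketch of the same role mapping performed by `_format_messages_for_gemini`: system messages collapse into `system_instruction`, assistant turns become the `model` role, and everything else is sent as `user` content; the sample messages are illustrative:

# Standalone illustration of the message-to-Content mapping above.
from google.genai import types

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Name one CrewAI feature."},
    {"role": "assistant", "content": "Flows."},
]

system_instruction = None
contents: list[types.Content] = []
for message in messages:
    role, text = message["role"], str(message["content"])
    if role == "system":
        # Gemini takes system text separately, concatenated if repeated.
        system_instruction = text if system_instruction is None else f"{system_instruction}\n\n{text}"
    else:
        gemini_role = "model" if role == "assistant" else "user"
        contents.append(
            types.Content(role=gemini_role, parts=[types.Part.from_text(text=text)])
        )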
def _handle_completion( # type: ignore[no-any-unimported]
|
||||
def _handle_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
system_instruction: str | None,
|
||||
@@ -409,14 +508,14 @@ class GeminiCompletion(BaseLLM):
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming content generation."""
|
||||
api_params = {
|
||||
"model": self.model,
|
||||
"contents": contents,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.client.models.generate_content(**api_params)
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
)
|
||||
|
||||
usage = self._extract_token_usage(response)
|
||||
except Exception as e:
|
||||
@@ -433,6 +532,8 @@ class GeminiCompletion(BaseLLM):
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
function_name = part.function_call.name
|
||||
if function_name is None:
|
||||
continue
|
||||
function_args = (
|
||||
dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
@@ -442,7 +543,7 @@ class GeminiCompletion(BaseLLM):
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions, # type: ignore
|
||||
available_functions=available_functions or {},
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
@@ -450,7 +551,7 @@ class GeminiCompletion(BaseLLM):
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = response.text if hasattr(response, "text") else ""
|
||||
content = response.text or ""
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
@@ -463,9 +564,11 @@ class GeminiCompletion(BaseLLM):
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return content
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
messages_for_event, content, from_agent
|
||||
)
|
||||
|
||||
def _handle_streaming_completion( # type: ignore[no-any-unimported]
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
config: types.GenerateContentConfig,
|
||||
@@ -476,16 +579,16 @@ class GeminiCompletion(BaseLLM):
|
||||
) -> str:
|
||||
"""Handle streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls = {}
|
||||
function_calls: dict[str, dict[str, Any]] = {}
|
||||
|
||||
api_params = {
|
||||
"model": self.model,
|
||||
"contents": contents,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
for chunk in self.client.models.generate_content_stream(**api_params):
|
||||
if hasattr(chunk, "text") and chunk.text:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
for chunk in self.client.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
):
|
||||
if chunk.text:
|
||||
full_response += chunk.text
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=chunk.text,
|
||||
@@ -493,7 +596,7 @@ class GeminiCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
if chunk.candidates:
|
||||
candidate = chunk.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
@@ -513,6 +616,14 @@ class GeminiCompletion(BaseLLM):
|
||||
function_name = call_data["name"]
|
||||
function_args = call_data["args"]
|
||||
|
||||
# Skip if function_name is None
|
||||
if not isinstance(function_name, str):
|
||||
continue
|
||||
|
||||
# Ensure function_args is a dict
|
||||
if not isinstance(function_args, dict):
|
||||
function_args = {}
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
@@ -535,7 +646,309 @@ class GeminiCompletion(BaseLLM):
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return full_response
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
messages_for_event, full_response, from_agent
|
||||
)
|
||||
|
||||
async def _ahandle_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
system_instruction: str | None,
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle async non-streaming content generation."""
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = await self.client.aio.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
)
|
||||
|
||||
usage = self._extract_token_usage(response)
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
raise e from e
|
||||
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response.candidates and (self.tools or available_functions):
|
||||
candidate = response.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
function_name = part.function_call.name
|
||||
if function_name is None:
|
||||
continue
|
||||
function_args = (
|
||||
dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {}
|
||||
)
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions or {},
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = response.text or ""
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
async def _ahandle_streaming_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle async streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
stream = await self.client.aio.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
)
|
||||
async for chunk in stream:
|
||||
if chunk.text:
|
||||
full_response += chunk.text
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=chunk.text,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if chunk.candidates:
|
||||
candidate = chunk.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
call_id = part.function_call.name or "default"
|
||||
if call_id not in function_calls:
|
||||
function_calls[call_id] = {
|
||||
"name": part.function_call.name,
|
||||
"args": dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {},
|
||||
}
|
||||
|
||||
if function_calls and available_functions:
|
||||
for call_data in function_calls.values():
|
||||
function_name = call_data["name"]
|
||||
function_args = call_data["args"]
|
||||
|
||||
# Skip if function_name is None
|
||||
if not isinstance(function_name, str):
|
||||
continue
|
||||
|
||||
# Ensure function_args is a dict
|
||||
if not isinstance(function_args, dict):
|
||||
function_args = {}
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
messages_for_event, full_response, from_agent
|
||||
)
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
@@ -583,9 +996,10 @@ class GeminiCompletion(BaseLLM):
# Default context window size for Gemini models
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens

def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
@staticmethod
def _extract_token_usage(response: GenerateContentResponse) -> dict[str, Any]:
"""Extract token usage from Gemini response."""
if hasattr(response, "usage_metadata"):
if response.usage_metadata:
usage = response.usage_metadata
return {
"prompt_token_count": getattr(usage, "prompt_token_count", 0),
@@ -595,21 +1009,23 @@ class GeminiCompletion(BaseLLM):
}
return {"total_tokens": 0}

def _convert_contents_to_dict( # type: ignore[no-any-unimported]
def _convert_contents_to_dict(
self,
contents: list[types.Content],
) -> list[dict[str, str]]:
) -> list[LLMMessage]:
"""Convert contents to dict format."""
return [
{
"role": "assistant"
if content_obj.role == "model"
else content_obj.role,
"content": " ".join(
part.text
for part in content_obj.parts
if hasattr(part, "text") and part.text
),
}
for content_obj in contents
]
result: list[dict[str, str]] = []
for content_obj in contents:
role = content_obj.role
if role == "model":
role = "assistant"
elif role is None:
role = "user"

parts = content_obj.parts or []
content = " ".join(
part.text for part in parts if hasattr(part, "text") and part.text
)

result.append({"role": role, "content": content})
return result
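A tiny standalone check of the role normalization performed by the rewritten _convert_contents_to_dict: Gemini's "model" role maps to "assistant", a missing role defaults to "user", and anything else passes through unchanged:

def normalize_role(role: str | None) -> str:
    # Mirrors the branches in _convert_contents_to_dict above.
    if role == "model":
        return "assistant"
    if role is None:
        return "user"
    return role

assert normalize_role("model") == "assistant"
assert normalize_role(None) == "user"
assert normalize_role("user") == "user"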
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import AsyncIterator
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from openai import APIConnectionError, NotFoundError, OpenAI
|
||||
from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI, Stream
|
||||
from openai.lib.streaming.chat import ChatCompletionStream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
@@ -15,7 +16,7 @@ from pydantic import BaseModel
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.hooks.transport import HTTPTransport
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -101,6 +102,14 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
self.client = OpenAI(**client_config)
|
||||
|
||||
async_client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_config["http_client"] = async_http_client
|
||||
|
||||
self.async_client = AsyncOpenAI(**async_client_config)
|
||||
|
||||
# Completion parameters
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
@@ -181,6 +190,9 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
formatted_messages = self._format_messages(messages)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
messages=formatted_messages, tools=tools
|
||||
)
|
||||
@@ -210,6 +222,71 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Async call to OpenAI chat completion API.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: list of tool/function definitions
|
||||
callbacks: Callback functions (not used in native implementation)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Response model for structured output.
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages = self._format_messages(messages)
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
messages=formatted_messages, tools=tools
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
params=completion_params,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
params=completion_params,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
error_msg = f"OpenAI API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
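A hedged usage sketch for the new async entry point; the constructor arguments and model name are placeholders (the constructor signature is not shown in this diff), and a valid OPENAI_API_KEY is required:

import asyncio

async def main() -> None:
    llm = OpenAICompletion(model="gpt-4o-mini")  # constructor signature assumed
    answer = await llm.acall(messages=[{"role": "user", "content": "Say hello"}])
    print(answer)

if __name__ == "__main__":
    asyncio.run(main())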
|
||||
|
||||
def _prepare_completion_params(
|
||||
self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
@@ -352,10 +429,272 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0]
|
||||
function_name = tool_call.function.name # type: ignore[union-attr]
|
||||
function_name = tool_call.function.name
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments) # type: ignore[union-attr]
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = message.content or ""
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
if self.response_format and isinstance(self.response_format, type):
|
||||
try:
|
||||
structured_result = self._validate_structured_output(
|
||||
content, self.response_format
|
||||
)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_result,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_result
|
||||
except ValueError as e:
|
||||
logging.warning(f"Structured output validation failed: {e}")
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"OpenAI API usage: {usage}")
|
||||
|
||||
content = self._invoke_after_llm_call_hooks(
|
||||
params["messages"], content, from_agent
|
||||
)
|
||||
except NotFoundError as e:
|
||||
error_msg = f"Model {self.model} not found: {e}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise ValueError(error_msg) from e
|
||||
except APIConnectionError as e:
|
||||
error_msg = f"Failed to connect to OpenAI API: {e}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise ConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
# Handle context length exceeded and other errors
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"OpenAI API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise e from e
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming chat completion."""
|
||||
full_response = ""
|
||||
tool_calls = {}
|
||||
|
||||
if response_model:
|
||||
parse_params = {
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k not in ("response_format", "stream")
|
||||
}
|
||||
|
||||
stream: ChatCompletionStream[BaseModel]
|
||||
with self.client.beta.chat.completions.stream(
|
||||
**parse_params, response_format=response_model
|
||||
) as stream:
|
||||
for chunk in stream:
|
||||
if chunk.type == "content.delta":
|
||||
delta_content = chunk.delta
|
||||
if delta_content:
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=delta_content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
final_completion = stream.get_final_completion()
|
||||
if final_completion and final_completion.choices:
|
||||
parsed_result = final_completion.choices[0].message.parsed
|
||||
if parsed_result:
|
||||
structured_json = parsed_result.model_dump_json()
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
logging.error("Failed to get parsed result from stream")
|
||||
return ""
|
||||
|
||||
completion_stream: Stream[ChatCompletionChunk] = (
|
||||
self.client.chat.completions.create(**params)
|
||||
)
|
||||
|
||||
for completion_chunk in completion_stream:
|
||||
if not completion_chunk.choices:
|
||||
continue
|
||||
|
||||
choice = completion_chunk.choices[0]
|
||||
chunk_delta: ChoiceDelta = choice.delta
|
||||
|
||||
if chunk_delta.content:
|
||||
full_response += chunk_delta.content
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=chunk_delta.content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if chunk_delta.tool_calls:
|
||||
for tool_call in chunk_delta.tool_calls:
|
||||
call_id = tool_call.id or "default"
|
||||
if call_id not in tool_calls:
|
||||
tool_calls[call_id] = {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call.function and tool_call.function.name:
|
||||
tool_calls[call_id]["name"] = tool_call.function.name
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += tool_call.function.arguments
|
||||
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
function_name = call_data["name"]
|
||||
arguments = call_data["arguments"]
|
||||
|
||||
# Skip if function name is empty or arguments are empty
|
||||
if not function_name or not arguments:
|
||||
continue
|
||||
|
||||
# Check if function exists in available functions
|
||||
if function_name not in available_functions:
|
||||
logging.warning(
|
||||
f"Function '{function_name}' not found in available functions"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
function_args = json.loads(arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
||||
continue
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
full_response = self._apply_stop_words(full_response)

self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)

return self._invoke_after_llm_call_hooks(
params["messages"], full_response, from_agent
)
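The streaming path above accumulates tool-call fragments keyed by call id, keeping the first non-empty function name and concatenating argument text before parsing it as JSON. A self-contained illustration with fabricated fragments:

import json

fragments = [
    {"id": "call_1", "name": "add", "arguments": '{"a": 2,'},
    {"id": "call_1", "name": None, "arguments": ' "b": 3}'},
]

tool_calls: dict[str, dict[str, str]] = {}
for frag in fragments:
    call = tool_calls.setdefault(frag["id"], {"name": "", "arguments": ""})
    if frag["name"]:
        call["name"] = frag["name"]
    call["arguments"] += frag["arguments"]

for call in tool_calls.values():
    print(call["name"], json.loads(call["arguments"]))  # add {'a': 2, 'b': 3}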
|
||||
|
||||
async def _ahandle_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming async chat completion."""
|
||||
try:
|
||||
if response_model:
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = await self.async_client.beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
math_reasoning = parsed_response.choices[0].message
|
||||
|
||||
if math_reasoning.refusal:
|
||||
pass
|
||||
|
||||
usage = self._extract_openai_token_usage(parsed_response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
parsed_object = parsed_response.choices[0].message.parsed
|
||||
if parsed_object:
|
||||
structured_json = parsed_object.model_dump_json()
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
response: ChatCompletion = await self.async_client.chat.completions.create(
|
||||
**params
|
||||
)
|
||||
|
||||
usage = self._extract_openai_token_usage(response)
|
||||
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
choice: Choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0]
|
||||
function_name = tool_call.function.name
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
@@ -415,7 +754,6 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
raise ConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
# Handle context length exceeded and other errors
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
@@ -429,7 +767,7 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
async def _ahandle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
@@ -437,17 +775,17 @@ class OpenAICompletion(BaseLLM):
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming chat completion."""
|
||||
"""Handle async streaming chat completion."""
|
||||
full_response = ""
|
||||
tool_calls = {}
|
||||
|
||||
if response_model:
|
||||
completion_stream: Iterator[ChatCompletionChunk] = (
|
||||
self.client.chat.completions.create(**params)
|
||||
)
|
||||
completion_stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
|
||||
accumulated_content = ""
|
||||
for chunk in completion_stream:
|
||||
async for chunk in completion_stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
@@ -486,11 +824,11 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return accumulated_content
|
||||
|
||||
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
|
||||
**params
|
||||
)
|
||||
stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
|
||||
for chunk in stream:
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
@@ -524,11 +862,9 @@ class OpenAICompletion(BaseLLM):
|
||||
function_name = call_data["name"]
|
||||
arguments = call_data["arguments"]
|
||||
|
||||
# Skip if function name is empty or arguments are empty
|
||||
if not function_name or not arguments:
|
||||
continue
|
||||
|
||||
# Check if function exists in available functions
|
||||
if function_name not in available_functions:
|
||||
logging.warning(
|
||||
f"Function '{function_name}' not found in available functions"
|
||||
|
||||
@@ -66,7 +66,6 @@ class SSETransport(BaseTransport):
|
||||
self._transport_context = sse_client(
|
||||
self.url,
|
||||
headers=self.headers if self.headers else None,
|
||||
terminate_on_close=True,
|
||||
)
|
||||
|
||||
read, write = await self._transport_context.__aenter__()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.memory import (
|
||||
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ContextualMemory:
|
||||
"""Aggregates and retrieves context from multiple memory sources."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stm: ShortTermMemory,
|
||||
@@ -46,9 +49,14 @@ class ContextualMemory:
|
||||
self.exm.task = self.task
|
||||
|
||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
"""Build contextual information for a task synchronously.
|
||||
|
||||
Args:
|
||||
task: The task to build context for.
|
||||
context: Additional context string.
|
||||
|
||||
Returns:
|
||||
Formatted context string from all memory sources.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
|
||||
@@ -63,6 +71,31 @@ class ContextualMemory:
|
||||
]
|
||||
return "\n".join(filter(None, context_parts))
|
||||
|
||||
async def abuild_context_for_task(self, task: Task, context: str) -> str:
"""Build contextual information for a task asynchronously.

Args:
task: The task to build context for.
context: Additional context string.

Returns:
Formatted context string from all memory sources.
"""
query = f"{task.description} {context}".strip()

if query == "":
return ""

# Fetch all contexts concurrently
results = await asyncio.gather(
self._afetch_ltm_context(task.description),
self._afetch_stm_context(query),
self._afetch_entity_context(query),
self._afetch_external_context(query),
)

return "\n".join(filter(None, results))
|
||||
|
||||
def _fetch_stm_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
@@ -135,3 +168,87 @@ class ContextualMemory:
|
||||
f"- {result['content']}" for result in external_memories
|
||||
)
|
||||
return f"External memories:\n{formatted_memories}"
|
||||
|
||||
async def _afetch_stm_context(self, query: str) -> str:
|
||||
"""Fetch recent relevant insights from STM asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted insights as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.stm is None:
|
||||
return ""
|
||||
|
||||
stm_results = await self.stm.asearch(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in stm_results]
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
async def _afetch_ltm_context(self, task: str) -> str | None:
|
||||
"""Fetch historical data from LTM asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
|
||||
Returns:
|
||||
Formatted historical data as bullet points, or None if none found.
|
||||
"""
|
||||
if self.ltm is None:
|
||||
return ""
|
||||
|
||||
ltm_results = await self.ltm.asearch(task, latest_n=2)
|
||||
if not ltm_results:
|
||||
return None
|
||||
|
||||
formatted_results = [
|
||||
suggestion
|
||||
for result in ltm_results
|
||||
for suggestion in result["metadata"]["suggestions"]
|
||||
]
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
async def _afetch_entity_context(self, query: str) -> str:
|
||||
"""Fetch relevant entity information asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted entity information as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.em is None:
|
||||
return ""
|
||||
|
||||
em_results = await self.em.asearch(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in em_results]
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
async def _afetch_external_context(self, query: str) -> str:
|
||||
"""Fetch relevant information from External Memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted information as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.exm is None:
|
||||
return ""
|
||||
|
||||
external_memories = await self.exm.asearch(query)
|
||||
|
||||
if not external_memories:
|
||||
return ""
|
||||
|
||||
formatted_memories = "\n".join(
|
||||
f"- {result['content']}" for result in external_memories
|
||||
)
|
||||
return f"External memories:\n{formatted_memories}"
|
||||
|
||||
@@ -26,7 +26,13 @@ class EntityMemory(Memory):
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
) -> list[Any]:
|
||||
"""Search entity memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save entity items asynchronously.
|
||||
|
||||
Args:
|
||||
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||
metadata: Optional metadata dict (not used, for signature compatibility).
|
||||
"""
|
||||
if not value:
|
||||
return
|
||||
|
||||
items = value if isinstance(value, list) else [value]
|
||||
is_batch = len(items) > 1
|
||||
|
||||
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
metadata=metadata,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
saved_count = 0
|
||||
errors: list[str | None] = []
|
||||
|
||||
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||
"""Save a single item asynchronously."""
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
await super(EntityMemory, self).asave(data, item.metadata)
|
||||
return True, None
|
||||
except Exception as e:
|
||||
return False, f"{item.name}: {e!s}"
|
||||
|
||||
try:
|
||||
for item in items:
|
||||
success, error = await save_single_item(item)
|
||||
if success:
|
||||
saved_count += 1
|
||||
else:
|
||||
errors.append(error)
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
metadata = {"entity_count": saved_count, "errors": errors}
|
||||
else:
|
||||
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||
metadata = items[0].metadata
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=emit_value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
if errors:
|
||||
raise Exception(
|
||||
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
fail_metadata = (
|
||||
{"entity_count": len(items), "saved": saved_count}
|
||||
if is_batch
|
||||
else items[0].metadata
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
metadata=fail_metadata,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search entity memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await super().asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
),
|
||||
)
|
||||
raise
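A hypothetical async round-trip with the asave/asearch methods added above; the EntityMemory and EntityMemoryItem constructor arguments are assumptions based only on the fields visible in this diff:

import asyncio

async def remember_and_recall(memory: EntityMemory) -> list:
    item = EntityMemoryItem(
        name="Ada Lovelace",          # field names taken from the diff; other
        type="person",                # required fields, if any, are omitted
        description="Early computing pioneer",
        metadata={"source": "example"},
    )
    await memory.asave(item)          # emits MemorySaveStarted/Completed events
    return await memory.asearch("Ada", limit=3)

# results = asyncio.run(remember_and_recall(EntityMemory()))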
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
|
||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
return Mem0Storage(type="external", crew=crew, config=config)
|
||||
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
|
||||
@staticmethod
|
||||
def external_supported_storages() -> dict[str, Any]:
|
||||
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
|
||||
if provider not in supported_storages:
|
||||
raise ValueError(f"Provider {provider} not supported")
|
||||
|
||||
return supported_storages[provider](crew, embedder_config.get("config", {}))
|
||||
storage: Storage = supported_storages[provider](
|
||||
crew, embedder_config.get("config", {})
|
||||
)
|
||||
return storage
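The provider lookup above follows a simple registry pattern: resolve a storage factory by name and fail loudly otherwise. A generic sketch with an illustrative registry, not the real supported-storage table:

from typing import Any, Callable

SUPPORTED: dict[str, Callable[..., Any]] = {
    "mem0": lambda crew, config: ("Mem0Storage", config),  # placeholder factory
}

def create_storage(provider: str, crew: Any, config: dict[str, Any]) -> Any:
    if provider not in SUPPORTED:
        raise ValueError(f"Provider {provider} not supported")
    return SUPPORTED[provider](crew, config)

print(create_storage("mem0", None, {"api_key": "..."}))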
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
) -> list[Any]:
|
||||
"""Search external memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to external memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ExternalMemoryItem(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
await super().asave(value=item.value, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search external memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await super().asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
self.storage.reset()
|
||||
|
||||
|
||||
@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
storage: LTMSQLiteStorage | None = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
def search( # type: ignore[override]
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory for relevant entries.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
results = self.storage.load(task, latest_n)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
return results or []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
|
||||
"""Save an item to long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
item: The LongTermMemoryItem to save.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
await self.storage.asave(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch( # type: ignore[override]
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await self.storage.aload(task, latest_n)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=task,
|
||||
results=results,
|
||||
limit=latest_n,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results or []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset long-term memory."""
|
||||
self.storage.reset()
|
||||
|
||||
@@ -13,9 +13,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
"""
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
"""Base class for memory, supporting agent tags and generic metadata."""
|
||||
|
||||
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
@@ -52,20 +50,72 @@ class Memory(BaseModel):
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
"""Save a value to memory.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
await self.storage.asave(value, metadata)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
return self.storage.search(
|
||||
"""Search memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
results: list[Any] = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
return results
|
||||
|
||||
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search memory for relevant entries asynchronously.

Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.

Returns:
List of matching memory entries.
"""
results: list[Any] = await self.storage.asearch(
query=query, limit=limit, score_threshold=score_threshold
)
return results
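The base Memory class now delegates to paired sync/async storage methods. An illustrative in-memory stand-in showing the shape of that contract, not the real Storage interface:

import asyncio
from typing import Any

class InMemoryStorage:
    def __init__(self) -> None:
        self._items: list[dict[str, Any]] = []

    def save(self, value: Any, metadata: dict[str, Any]) -> None:
        self._items.append({"content": value, "metadata": metadata})

    async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
        self.save(value, metadata)

    def search(self, query: str, limit: int = 5, score_threshold: float = 0.6) -> list[Any]:
        # Naive substring match stands in for vector similarity.
        return [i for i in self._items if query.lower() in str(i["content"]).lower()][:limit]

    async def asearch(self, query: str, limit: int = 5, score_threshold: float = 0.6) -> list[Any]:
        return self.search(query, limit, score_threshold)

storage = InMemoryStorage()
storage.save("CrewAI supports async memory", {})
print(asyncio.run(storage.asearch("async")))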
|
||||
|
||||
def set_crew(self, crew: Any) -> Memory:
|
||||
"""Set the crew for this memory instance."""
|
||||
self.crew = crew
|
||||
return self
|
||||
|
||||
@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
) -> list[Any]:
|
||||
"""Search short-term memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
|
||||
try:
|
||||
results = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
return list(results)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to short-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ShortTermMemoryItem(
|
||||
data=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
if self._memory_provider == "mem0":
|
||||
item.data = (
|
||||
f"Remember the following insights from Agent run: {item.data}"
|
||||
)
|
||||
|
||||
await super().asave(value=item.data, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search short-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await self.storage.asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return list(results)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -3,29 +3,30 @@ from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class LTMSQLiteStorage:
|
||||
"""
|
||||
An updated SQLite storage class for LTM data storage.
|
||||
"""
|
||||
"""SQLite storage class for long-term memory data."""
|
||||
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
"""Initialize the SQLite storage.
|
||||
|
||||
Args:
|
||||
db_path: Optional path to the database file.
|
||||
"""
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||
self.db_path = db_path
|
||||
self._printer: Printer = Printer()
|
||||
# Ensure parent directory exists
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self):
|
||||
"""
|
||||
Initializes the SQLite database and creates LTM table
|
||||
"""
|
||||
def _initialize_db(self) -> None:
|
||||
"""Initialize the SQLite database and create LTM table."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
|
||||
)
|
||||
return None
|
||||
|
||||
def reset(
|
||||
self,
|
||||
) -> None:
|
||||
def reset(self) -> None:
|
||||
"""Resets the LTM table with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
task_description: str,
|
||||
metadata: dict[str, Any],
|
||||
datetime: str,
|
||||
score: int | float,
|
||||
) -> None:
|
||||
"""Save data to the LTM table asynchronously.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task.
|
||||
metadata: Metadata associated with the memory.
|
||||
datetime: Timestamp of the memory.
|
||||
score: Quality score of the memory.
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(task_description, json.dumps(metadata), datetime, score),
|
||||
)
|
||||
await conn.commit()
|
||||
except aiosqlite.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
async def aload(
|
||||
self, task_description: str, latest_n: int
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Query the LTM table by task description asynchronously.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries or None if error occurs.
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
cursor = await conn.execute(
|
||||
f"""
|
||||
SELECT metadata, datetime, score
|
||||
FROM long_term_memories
|
||||
WHERE task_description = ?
|
||||
ORDER BY datetime DESC, score ASC
|
||||
LIMIT {latest_n}
|
||||
""", # nosec # noqa: S608
|
||||
(task_description,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
if rows:
|
||||
return [
|
||||
{
|
||||
"metadata": json.loads(row[0]),
|
||||
"datetime": row[1],
|
||||
"score": row[2],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
except aiosqlite.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return None
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the LTM table asynchronously."""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
await conn.execute("DELETE FROM long_term_memories")
|
||||
await conn.commit()
|
||||
except aiosqlite.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
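A minimal usage sketch for the new async LTM methods above (asave, aload, areset). The import path below is an assumption made for illustration; adjust it to wherever LTMSQLiteStorage lives in your tree.

import asyncio
from datetime import datetime, timezone

from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage  # assumed module path


async def main() -> None:
    storage = LTMSQLiteStorage(db_path="/tmp/ltm_demo.db")

    # Persist one long-term memory entry; the datetime argument is a string.
    await storage.asave(
        task_description="summarize quarterly report",
        metadata={"quality": "high"},
        datetime=datetime.now(timezone.utc).isoformat(),
        score=9,
    )

    # Fetch the latest entries recorded for the same task description.
    entries = await storage.aload("summarize quarterly report", latest_n=3)
    print(entries)

    # Clear the table again (demo only).
    await storage.areset()


asyncio.run(main())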
@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value to storage.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Metadata to associate with the value.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
|
||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value to storage asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Metadata to associate with the value.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
await client.aget_or_create_collection(collection_name=collection_name)
|
||||
|
||||
document: BaseRecord = {"content": value}
|
||||
if metadata:
|
||||
document["metadata"] = metadata
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
self.embedder_config
|
||||
and isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name, documents=[document]
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for matching entries in storage.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
filter: Optional metadata filter.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching entries.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
|
||||
)
|
||||
return []
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for matching entries in storage asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
filter: Optional metadata filter.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching entries.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
return await client.asearch(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
limit=limit,
|
||||
metadata_filter=filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
client = self._get_client()
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
||||
|
||||
from crewai.project.utils import memoize
|
||||
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
return CacheHandlerMethod(memoize(meth))
|
||||
|
||||
|
||||
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call a method, awaiting it if async and running in an event loop."""
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
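The _call_method helper above lets @crew builders accept either sync or async task/agent factory methods. A standalone sketch of the same pattern, with no crewAI imports and illustrative function names:

import asyncio
import concurrent.futures
import inspect
from typing import Any, Callable


def call_maybe_async(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run method; if it returns a coroutine, drive it to completion."""
    result = method(*args, **kwargs)
    if not inspect.iscoroutine(result):
        return result
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop and loop.is_running():
        # Already inside an event loop: run the coroutine on a fresh loop in a
        # worker thread so we neither block nor re-enter the current loop.
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return pool.submit(asyncio.run, result).result()
    return asyncio.run(result)


def sync_factory() -> str:
    return "built synchronously"


async def async_factory() -> str:
    await asyncio.sleep(0)
    return "built asynchronously"


print(call_maybe_async(sync_factory))   # built synchronously
print(call_maybe_async(async_factory))  # built asynchronously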
@overload
|
||||
def crew(
|
||||
meth: Callable[Concatenate[SelfT, P], Crew],
|
||||
@@ -198,7 +217,7 @@ def crew(
|
||||
|
||||
# Instantiate tasks in order
|
||||
for _, task_method in tasks:
|
||||
task_instance = task_method(self)
|
||||
task_instance = _call_method(task_method, self)
|
||||
instantiated_tasks.append(task_instance)
|
||||
agent_instance = getattr(task_instance, "agent", None)
|
||||
if agent_instance and agent_instance.role not in agent_roles:
|
||||
@@ -207,7 +226,7 @@ def crew(
|
||||
|
||||
# Instantiate agents not included by tasks
|
||||
for _, agent_method in agents:
|
||||
agent_instance = agent_method(self)
|
||||
agent_instance = _call_method(agent_method, self)
|
||||
if agent_instance.role not in agent_roles:
|
||||
instantiated_agents.append(agent_instance)
|
||||
agent_roles.add(agent_instance.role)
|
||||
@@ -215,7 +234,7 @@ def crew(
|
||||
self.agents = instantiated_agents
|
||||
self.tasks = instantiated_tasks
|
||||
|
||||
crew_instance = meth(self, *args, **kwargs)
|
||||
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
|
||||
|
||||
def callback_wrapper(
|
||||
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Utility functions for the crewai project module."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
|
||||
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a method by caching its results based on arguments.
|
||||
|
||||
Handles Pydantic BaseModel instances by converting them to JSON strings
|
||||
before hashing for cache lookup.
|
||||
Handles both sync and async methods. Pydantic BaseModel instances are
|
||||
converted to JSON strings before hashing for cache lookup.
|
||||
|
||||
Args:
|
||||
meth: The method to memoize.
|
||||
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
Returns:
|
||||
A memoized version of the method that caches results.
|
||||
"""
|
||||
if inspect.iscoroutinefunction(meth):
|
||||
return cast(Callable[P, R], _memoize_async(meth))
|
||||
return _memoize_sync(meth)
|
||||
|
||||
|
||||
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a synchronous method."""
|
||||
|
||||
@wraps(meth)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Wrapper that converts arguments to hashable form before caching.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to the memoized method.
|
||||
**kwargs: Keyword arguments to the memoized method.
|
||||
|
||||
Returns:
|
||||
The result of the memoized method call.
|
||||
"""
|
||||
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(
|
||||
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
return result
|
||||
|
||||
return cast(Callable[P, R], wrapper)
|
||||
|
||||
|
||||
def _memoize_async(
|
||||
meth: Callable[P, Coroutine[Any, Any, R]],
|
||||
) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
"""Memoize an async method."""
|
||||
|
||||
@wraps(meth)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(
|
||||
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||
)
|
||||
cache_key = str((hashable_args, hashable_kwargs))
|
||||
|
||||
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
result = await meth(*args, **kwargs)
|
||||
cache.add(tool=meth.__name__, input=cache_key, output=result)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
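A self-contained sketch of the caching pattern behind _memoize_sync/_memoize_async above, backed by a plain dict rather than the project's cache handler; the cache object and key format here are illustrative assumptions only:

import asyncio
from functools import wraps
from typing import Any, Callable

_cache: dict[str, Any] = {}


def _make_key(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
    # Same idea as above: derive a stable, hashable key from the call arguments.
    return str((args, tuple(sorted(kwargs.items()))))


def memoize_async(meth: Callable[..., Any]) -> Callable[..., Any]:
    @wraps(meth)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        key = f"{meth.__name__}:{_make_key(args, kwargs)}"
        if key in _cache:
            return _cache[key]
        result = await meth(*args, **kwargs)
        _cache[key] = result
        return result

    return wrapper


@memoize_async
async def slow_double(x: int) -> int:
    await asyncio.sleep(0.1)
    return x * 2


async def main() -> None:
    print(await slow_double(21))  # computed
    print(await slow_double(21))  # served from the cache


asyncio.run(main())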
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
|
||||
crew: Callable[..., Crew]
|
||||
|
||||
|
||||
def _resolve_result(result: Any) -> Any:
|
||||
"""Resolve a potentially async result to its value."""
|
||||
if inspect.iscoroutine(result):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
class DecoratedMethod(Generic[P, R]):
|
||||
"""Base wrapper for methods with decorator metadata.
|
||||
|
||||
@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
bound = partial(self._meth, obj)
|
||||
inner = partial(self._meth, obj)
|
||||
|
||||
def _bound(*args: Any, **kwargs: Any) -> R:
|
||||
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
|
||||
return result
|
||||
|
||||
for attr in (
|
||||
"is_agent",
|
||||
"is_llm",
|
||||
@@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]):
|
||||
"is_crew",
|
||||
):
|
||||
if hasattr(self, attr):
|
||||
setattr(bound, attr, getattr(self, attr))
|
||||
return bound
|
||||
setattr(_bound, attr, getattr(self, attr))
|
||||
return _bound
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Call the wrapped method.
|
||||
@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
|
||||
The task result with name ensured.
|
||||
"""
|
||||
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
|
||||
result = _resolve_result(result)
|
||||
return self._task_method.ensure_task_name(result)
|
||||
|
||||
|
||||
@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
|
||||
Returns:
|
||||
The task instance with name set if not already provided.
|
||||
"""
|
||||
return self.ensure_task_name(self._meth(*args, **kwargs))
|
||||
result = self._meth(*args, **kwargs)
|
||||
result = _resolve_result(result)
|
||||
return self.ensure_task_name(result)
|
||||
|
||||
def unwrap(self) -> Callable[P, TaskResultT]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
@@ -1,21 +1,35 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
"""HuggingFace embeddings provider."""
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
|
||||
"""HuggingFace embeddings provider for the HuggingFace Inference API."""
|
||||
|
||||
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
|
||||
default=HuggingFaceEmbeddingServer,
|
||||
embedding_callable: type[HuggingFaceEmbeddingFunction] = Field(
|
||||
default=HuggingFaceEmbeddingFunction,
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="HuggingFace API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_API_KEY",
|
||||
"HUGGINGFACE_API_KEY",
|
||||
"HF_TOKEN",
|
||||
),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_MODEL_NAME",
|
||||
"HUGGINGFACE_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Type definitions for HuggingFace embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
@@ -8,7 +8,11 @@ from typing_extensions import Required, TypedDict
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for HuggingFace provider."""
|
||||
|
||||
url: str
|
||||
api_key: str
|
||||
model: Annotated[
|
||||
str, "sentence-transformers/all-MiniLM-L6-v2"
|
||||
] # alias for model_name for backward compat
|
||||
model_name: Annotated[str, "sentence-transformers/all-MiniLM-L6-v2"]
|
||||
|
||||
|
||||
class HuggingFaceProviderSpec(TypedDict, total=False):
|
||||
|
||||
@@ -497,6 +497,107 @@ class Task(BaseModel):
|
||||
result = self._execute_core(agent, context, tools)
|
||||
future.set_result(result)
|
||||
|
||||
async def aexecute_sync(
|
||||
self,
|
||||
agent: BaseAgent | None = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> TaskOutput:
|
||||
"""Execute the task asynchronously using native async/await."""
|
||||
return await self._aexecute_core(agent, context, tools)
|
||||
|
||||
async def _aexecute_core(
|
||||
self,
|
||||
agent: BaseAgent | None,
|
||||
context: str | None,
|
||||
tools: list[Any] | None,
|
||||
) -> TaskOutput:
|
||||
"""Run the core execution logic of the task asynchronously."""
|
||||
try:
|
||||
agent = agent or self.agent
|
||||
self.agent = agent
|
||||
if not agent:
|
||||
raise Exception(
|
||||
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
|
||||
)
|
||||
|
||||
self.start_time = datetime.datetime.now()
|
||||
|
||||
self.prompt_context = context
|
||||
tools = tools or self.tools or []
|
||||
|
||||
self.processed_by_agents.add(agent.role)
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
||||
result = await agent.aexecute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if not self._guardrails and not self._guardrail:
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
else:
|
||||
pydantic_output, json_output = None, None
|
||||
|
||||
task_output = TaskOutput(
|
||||
name=self.name or self.description,
|
||||
description=self.description,
|
||||
expected_output=self.expected_output,
|
||||
raw=result,
|
||||
pydantic=pydantic_output,
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
output_format=self._get_output_format(),
|
||||
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if self._guardrails:
|
||||
for idx, guardrail in enumerate(self._guardrails):
|
||||
task_output = await self._ainvoke_guardrail_function(
|
||||
task_output=task_output,
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
guardrail=guardrail,
|
||||
guardrail_index=idx,
|
||||
)
|
||||
|
||||
if self._guardrail:
|
||||
task_output = await self._ainvoke_guardrail_function(
|
||||
task_output=task_output,
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
guardrail=self._guardrail,
|
||||
)
|
||||
|
||||
self.output = task_output
|
||||
self.end_time = datetime.datetime.now()
|
||||
|
||||
if self.callback:
|
||||
self.callback(self.output)
|
||||
|
||||
crew = self.agent.crew # type: ignore[union-attr]
|
||||
if crew and crew.task_callback and crew.task_callback != self.callback:
|
||||
crew.task_callback(self.output)
|
||||
|
||||
if self.output_file:
|
||||
content = (
|
||||
json_output
|
||||
if json_output
|
||||
else (
|
||||
pydantic_output.model_dump_json() if pydantic_output else result
|
||||
)
|
||||
)
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
||||
)
|
||||
return task_output
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
|
||||
def _execute_core(
|
||||
self,
|
||||
agent: BaseAgent | None,
|
||||
@@ -539,7 +640,7 @@ class Task(BaseModel):
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
output_format=self._get_output_format(),
|
||||
messages=agent.last_messages,
|
||||
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if self._guardrails:
|
||||
@@ -950,7 +1051,103 @@ Follow these guidelines:
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
output_format=self._get_output_format(),
|
||||
messages=agent.last_messages,
|
||||
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
return task_output
|
||||
|
||||
async def _ainvoke_guardrail_function(
|
||||
self,
|
||||
task_output: TaskOutput,
|
||||
agent: BaseAgent,
|
||||
tools: list[BaseTool],
|
||||
guardrail: GuardrailCallable | None,
|
||||
guardrail_index: int | None = None,
|
||||
) -> TaskOutput:
|
||||
"""Invoke the guardrail function asynchronously."""
|
||||
if not guardrail:
|
||||
return task_output
|
||||
|
||||
if guardrail_index is not None:
|
||||
current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
|
||||
else:
|
||||
current_retry_count = self.retry_count
|
||||
|
||||
max_attempts = self.guardrail_max_retries + 1
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
guardrail_result = process_guardrail(
|
||||
output=task_output,
|
||||
guardrail=guardrail,
|
||||
retry_count=current_retry_count,
|
||||
event_source=self,
|
||||
from_task=self,
|
||||
from_agent=agent,
|
||||
)
|
||||
|
||||
if guardrail_result.success:
|
||||
if guardrail_result.result is None:
|
||||
raise Exception(
|
||||
"Task guardrail returned None as result. This is not allowed."
|
||||
)
|
||||
|
||||
if isinstance(guardrail_result.result, str):
|
||||
task_output.raw = guardrail_result.result
|
||||
pydantic_output, json_output = self._export_output(
|
||||
guardrail_result.result
|
||||
)
|
||||
task_output.pydantic = pydantic_output
|
||||
task_output.json_dict = json_output
|
||||
elif isinstance(guardrail_result.result, TaskOutput):
|
||||
task_output = guardrail_result.result
|
||||
|
||||
return task_output
|
||||
|
||||
if attempt >= self.guardrail_max_retries:
|
||||
guardrail_name = (
|
||||
f"guardrail {guardrail_index}"
|
||||
if guardrail_index is not None
|
||||
else "guardrail"
|
||||
)
|
||||
raise Exception(
|
||||
f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
|
||||
f"Last error: {guardrail_result.error}"
|
||||
)
|
||||
|
||||
if guardrail_index is not None:
|
||||
current_retry_count += 1
|
||||
self._guardrail_retry_counts[guardrail_index] = current_retry_count
|
||||
else:
|
||||
self.retry_count += 1
|
||||
current_retry_count = self.retry_count
|
||||
|
||||
context = self.i18n.errors("validation_error").format(
|
||||
guardrail_result_error=guardrail_result.error,
|
||||
task_output=task_output.raw,
|
||||
)
|
||||
printer = Printer()
|
||||
printer.print(
|
||||
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
result = await agent.aexecute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
task_output = TaskOutput(
|
||||
name=self.name or self.description,
|
||||
description=self.description,
|
||||
expected_output=self.expected_output,
|
||||
raw=result,
|
||||
pydantic=pydantic_output,
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
output_format=self._get_output_format(),
|
||||
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
return task_output
|
||||
|
||||
@@ -9,12 +9,14 @@ data is collected. Users can opt-in to share more complete data using the
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
from collections.abc import Callable
|
||||
from importlib.metadata import version
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -31,6 +33,14 @@ from opentelemetry.sdk.trace.export import (
|
||||
from opentelemetry.trace import Span
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.system_events import (
|
||||
SigContEvent,
|
||||
SigHupEvent,
|
||||
SigIntEvent,
|
||||
SigTStpEvent,
|
||||
SigTermEvent,
|
||||
)
|
||||
from crewai.telemetry.constants import (
|
||||
CREWAI_TELEMETRY_BASE_URL,
|
||||
CREWAI_TELEMETRY_SERVICE_NAME,
|
||||
@@ -121,6 +131,7 @@ class Telemetry:
|
||||
)
|
||||
|
||||
self.provider.add_span_processor(processor)
|
||||
self._register_shutdown_handlers()
|
||||
self.ready = True
|
||||
except Exception as e:
|
||||
if isinstance(
|
||||
@@ -155,6 +166,71 @@ class Telemetry:
|
||||
self.ready = False
|
||||
self.trace_set = False
|
||||
|
||||
def _register_shutdown_handlers(self) -> None:
|
||||
"""Register handlers for graceful shutdown on process exit and signals."""
|
||||
atexit.register(self._shutdown)
|
||||
|
||||
self._original_handlers: dict[int, Any] = {}
|
||||
|
||||
self._register_signal_handler(signal.SIGTERM, SigTermEvent, shutdown=True)
|
||||
self._register_signal_handler(signal.SIGINT, SigIntEvent, shutdown=True)
|
||||
self._register_signal_handler(signal.SIGHUP, SigHupEvent, shutdown=False)
|
||||
self._register_signal_handler(signal.SIGTSTP, SigTStpEvent, shutdown=False)
|
||||
self._register_signal_handler(signal.SIGCONT, SigContEvent, shutdown=False)
|
||||
|
||||
def _register_signal_handler(
|
||||
self,
|
||||
sig: signal.Signals,
|
||||
event_class: type,
|
||||
shutdown: bool = False,
|
||||
) -> None:
|
||||
"""Register a signal handler that emits an event.
|
||||
|
||||
Args:
|
||||
sig: The signal to handle.
|
||||
event_class: The event class to instantiate and emit.
|
||||
shutdown: Whether to trigger shutdown on this signal.
|
||||
"""
|
||||
try:
|
||||
original_handler = signal.getsignal(sig)
|
||||
self._original_handlers[sig] = original_handler
|
||||
|
||||
def handler(signum: int, frame: Any) -> None:
|
||||
crewai_event_bus.emit(self, event_class())
|
||||
|
||||
if shutdown:
|
||||
self._shutdown()
|
||||
|
||||
if original_handler not in (signal.SIG_DFL, signal.SIG_IGN, None):
|
||||
if callable(original_handler):
|
||||
original_handler(signum, frame)
|
||||
elif shutdown:
|
||||
raise SystemExit(0)
|
||||
|
||||
signal.signal(sig, handler)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"Cannot register {sig.name} handler: not running in main thread",
|
||||
exc_info=e,
|
||||
)
|
||||
except OSError as e:
|
||||
logger.warning(f"Cannot register {sig.name} handler: {e}", exc_info=e)
|
||||
|
||||
def _shutdown(self) -> None:
|
||||
"""Flush and shutdown the telemetry provider on process exit.
|
||||
|
||||
Uses a short timeout to avoid blocking process shutdown.
|
||||
"""
|
||||
if not self.ready:
|
||||
return
|
||||
|
||||
try:
|
||||
self.provider.force_flush(timeout_millis=5000)
|
||||
self.provider.shutdown()
|
||||
self.ready = False
|
||||
except Exception as e:
|
||||
logger.debug(f"Telemetry shutdown failed: {e}")
|
||||
|
||||
def _safe_telemetry_operation(
|
||||
self, operation: Callable[[], Span | None]
|
||||
) -> Span | None:
|
||||
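A minimal sketch of the handler-chaining pattern used by _register_signal_handler above: do the shutdown work, then defer to whatever handler was registered before us. The print call stands in for the event-bus emit and telemetry flush, and is an illustrative assumption:

import signal
import types
from typing import Any


def install_chained_handler(sig: signal.Signals) -> None:
    original = signal.getsignal(sig)

    def handler(signum: int, frame: types.FrameType | None) -> Any:
        print(f"flushing telemetry on {signal.Signals(signum).name}")  # stand-in for emit + shutdown
        # Chain to the previous handler so existing behavior is preserved.
        if original not in (signal.SIG_DFL, signal.SIG_IGN, None) and callable(original):
            return original(signum, frame)
        raise SystemExit(0)

    signal.signal(sig, handler)


install_chained_handler(signal.SIGTERM)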
@@ -316,9 +392,7 @@ class Telemetry:
|
||||
self._add_attribute(span, "platform_system", platform.system())
|
||||
self._add_attribute(span, "platform_version", platform.version())
|
||||
self._add_attribute(span, "cpus", os.cpu_count())
|
||||
self._add_attribute(
|
||||
span, "crew_inputs", json.dumps(inputs) if inputs else None
|
||||
)
|
||||
self._add_attribute(span, "crew_inputs", json.dumps(inputs or {}))
|
||||
else:
|
||||
self._add_attribute(
|
||||
span,
|
||||
@@ -631,9 +705,7 @@ class Telemetry:
|
||||
self._add_attribute(span, "model_name", model_name)
|
||||
|
||||
if crew.share_crew:
|
||||
self._add_attribute(
|
||||
span, "inputs", json.dumps(inputs) if inputs else None
|
||||
)
|
||||
self._add_attribute(span, "inputs", json.dumps(inputs or {}))
|
||||
|
||||
close_span(span)
|
||||
|
||||
@@ -738,9 +810,7 @@ class Telemetry:
|
||||
add_crew_attributes(
|
||||
span, crew, self._add_attribute, include_fingerprint=False
|
||||
)
|
||||
self._add_attribute(
|
||||
span, "crew_inputs", json.dumps(inputs) if inputs else None
|
||||
)
|
||||
self._add_attribute(span, "crew_inputs", json.dumps(inputs or {}))
|
||||
self._add_attribute(
|
||||
span,
|
||||
"crew_agents",
|
||||
|
||||
@@ -2,9 +2,18 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import signature
|
||||
from typing import Any, cast, get_args, get_origin
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -14,6 +23,7 @@ from pydantic import (
|
||||
create_model,
|
||||
field_validator,
|
||||
)
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities.printer import Printer
|
||||
@@ -21,6 +31,19 @@ from crewai.utilities.printer import Printer
|
||||
|
||||
_printer = Printer()
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R", covariant=True)
|
||||
|
||||
|
||||
def _is_async_callable(func: Callable[..., Any]) -> bool:
|
||||
"""Check if a callable is async."""
|
||||
return asyncio.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def _is_awaitable(value: R | Awaitable[R]) -> TypeIs[Awaitable[R]]:
|
||||
"""Type narrowing check for awaitable values."""
|
||||
return asyncio.iscoroutine(value) or asyncio.isfuture(value)
|
||||
|
||||
|
||||
class EnvVar(BaseModel):
|
||||
name: str
|
||||
@@ -55,7 +78,7 @@ class BaseTool(BaseModel, ABC):
|
||||
default=False, description="Flag to check if the description has been updated."
|
||||
)
|
||||
|
||||
cache_function: Callable = Field(
|
||||
cache_function: Callable[..., bool] = Field(
|
||||
default=lambda _args=None, _result=None: True,
|
||||
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
|
||||
)
|
||||
@@ -123,6 +146,35 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
return result
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Execute the tool asynchronously.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to pass to the tool.
|
||||
**kwargs: Keyword arguments to pass to the tool.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Async implementation of the tool. Override for async support."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement _arun. "
|
||||
"Override _arun for async support or use run() for sync execution."
|
||||
)
|
||||
|
||||
def reset_usage_count(self) -> None:
|
||||
"""Reset the current usage count to zero."""
|
||||
self.current_usage_count = 0
|
||||
@@ -133,7 +185,17 @@ class BaseTool(BaseModel, ABC):
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Here goes the actual implementation of the tool."""
|
||||
"""Sync implementation of the tool.
|
||||
|
||||
Subclasses must implement this method for synchronous execution.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the tool.
|
||||
**kwargs: Keyword arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
|
||||
def to_structured_tool(self) -> CrewStructuredTool:
|
||||
"""Convert this tool to a CrewStructuredTool instance."""
|
||||
@@ -239,21 +301,90 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
if args:
|
||||
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
||||
return f"{origin.__name__}[{args_str}]"
|
||||
return str(f"{origin.__name__}[{args_str}]")
|
||||
|
||||
return origin.__name__
|
||||
return str(origin.__name__)
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
"""The function that will be executed when the tool is called."""
|
||||
class Tool(BaseTool, Generic[P, R]):
|
||||
"""Tool that wraps a callable function.
|
||||
|
||||
func: Callable
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.func(*args, **kwargs)
|
||||
Type Parameters:
|
||||
P: ParamSpec capturing the function's parameters.
|
||||
R: The return type of the function.
|
||||
"""
|
||||
|
||||
func: Callable[P, R | Awaitable[R]]
|
||||
|
||||
def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the tool synchronously.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the tool.
|
||||
**kwargs: Keyword arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
_printer.print(f"Using Tool: {self.name}", color="cyan")
|
||||
result = self.func(*args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.run(result)
|
||||
|
||||
self.current_usage_count += 1
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
def _run(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the wrapped function.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the function.
|
||||
**kwargs: Keyword arguments for the function.
|
||||
|
||||
Returns:
|
||||
The result of the function execution.
|
||||
"""
|
||||
return self.func(*args, **kwargs) # type: ignore[return-value]
|
||||
|
||||
async def arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the tool asynchronously.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the tool.
|
||||
**kwargs: Keyword arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
|
||||
async def _arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the wrapped function asynchronously.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the function.
|
||||
**kwargs: Keyword arguments for the function.
|
||||
|
||||
Returns:
|
||||
The result of the async function execution.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the wrapped function is not async.
|
||||
"""
|
||||
result = self.func(*args, **kwargs)
|
||||
if _is_awaitable(result):
|
||||
return await result
|
||||
raise NotImplementedError(
|
||||
f"{self.name} does not have an async function. "
|
||||
"Use run() for sync execution or provide an async function."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_langchain(cls, tool: Any) -> Tool:
|
||||
def from_langchain(cls, tool: Any) -> Tool[..., Any]:
|
||||
"""Create a Tool instance from a CrewStructuredTool.
|
||||
|
||||
This method takes a CrewStructuredTool object and converts it into a
|
||||
@@ -261,10 +392,10 @@ class Tool(BaseTool):
|
||||
attribute and infers the argument schema if not explicitly provided.
|
||||
|
||||
Args:
|
||||
tool (Any): The CrewStructuredTool object to be converted.
|
||||
tool: The CrewStructuredTool object to be converted.
|
||||
|
||||
Returns:
|
||||
Tool: A new Tool instance created from the provided CrewStructuredTool.
|
||||
A new Tool instance created from the provided CrewStructuredTool.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided tool does not have a callable 'func' attribute.
|
||||
@@ -308,37 +439,83 @@ class Tool(BaseTool):
|
||||
def to_langchain(
|
||||
tools: list[BaseTool | CrewStructuredTool],
|
||||
) -> list[CrewStructuredTool]:
|
||||
"""Convert a list of tools to CrewStructuredTool instances."""
|
||||
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
|
||||
|
||||
|
||||
P2 = ParamSpec("P2")
|
||||
R2 = TypeVar("R2")
|
||||
|
||||
|
||||
@overload
|
||||
def tool(func: Callable[P2, R2], /) -> Tool[P2, R2]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def tool(
|
||||
*args, result_as_answer: bool = False, max_usage_count: int | None = None
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to create a tool from a function.
|
||||
name: str,
|
||||
/,
|
||||
*,
|
||||
result_as_answer: bool = ...,
|
||||
max_usage_count: int | None = ...,
|
||||
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def tool(
|
||||
*,
|
||||
result_as_answer: bool = ...,
|
||||
max_usage_count: int | None = ...,
|
||||
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
|
||||
|
||||
|
||||
def tool(
|
||||
*args: Callable[P2, R2] | str,
|
||||
result_as_answer: bool = False,
|
||||
max_usage_count: int | None = None,
|
||||
) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]:
|
||||
"""Decorator to create a Tool from a function.
|
||||
|
||||
Can be used in three ways:
|
||||
1. @tool - decorator without arguments, uses function name
|
||||
2. @tool("name") - decorator with custom name
|
||||
3. @tool(result_as_answer=True) - decorator with options
|
||||
|
||||
Args:
|
||||
*args: Positional arguments, either the function to decorate or the tool name.
|
||||
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
|
||||
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
||||
*args: Either the function to decorate or a custom tool name.
|
||||
result_as_answer: If True, the tool result becomes the final agent answer.
|
||||
max_usage_count: Maximum times this tool can be used. None means unlimited.
|
||||
|
||||
Returns:
|
||||
A Tool instance.
|
||||
|
||||
Example:
|
||||
@tool
|
||||
def greet(name: str) -> str:
|
||||
'''Greet someone.'''
|
||||
return f"Hello, {name}!"
|
||||
|
||||
result = greet.run("World")
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(f: Callable) -> BaseTool:
|
||||
def _make_with_name(tool_name: str) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]:
|
||||
def _make_tool(f: Callable[P2, R2]) -> Tool[P2, R2]:
|
||||
if f.__doc__ is None:
|
||||
raise ValueError("Function must have a docstring")
|
||||
if f.__annotations__ is None:
|
||||
|
||||
func_annotations = getattr(f, "__annotations__", None)
|
||||
if func_annotations is None:
|
||||
raise ValueError("Function must have type annotations")
|
||||
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = cast(
|
||||
tool_args_schema = cast(
|
||||
type[PydanticBaseModel],
|
||||
type(
|
||||
class_name,
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != "return"
|
||||
k: v for k, v in func_annotations.items() if k != "return"
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -348,10 +525,9 @@ def tool(
|
||||
name=tool_name,
|
||||
description=f.__doc__,
|
||||
func=f,
|
||||
args_schema=args_schema,
|
||||
args_schema=tool_args_schema,
|
||||
result_as_answer=result_as_answer,
|
||||
max_usage_count=max_usage_count,
|
||||
current_usage_count=0,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
@@ -360,4 +536,10 @@ def tool(
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
return _make_with_name(args[0])
|
||||
if len(args) == 0:
|
||||
|
||||
def decorator(f: Callable[P2, R2]) -> Tool[P2, R2]:
|
||||
return _make_with_name(f.__name__)(f)
|
||||
|
||||
return decorator
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
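A brief usage sketch for the updated decorator and the async execution path it now exposes; the import path below is the usual public one and is assumed rather than confirmed by this diff:

import asyncio

from crewai.tools import tool  # assumed public import path


@tool("Fetch Greeting")
async def fetch_greeting(name: str) -> str:
    """Return a greeting for the given name."""
    await asyncio.sleep(0)
    return f"Hello, {name}!"


# Sync path: Tool.run detects the coroutine and drives it with asyncio.run.
print(fetch_greeting.run(name="World"))

# Async path: Tool.arun awaits the wrapped coroutine directly.
print(asyncio.run(fetch_greeting.arun(name="Async World")))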
@@ -160,6 +160,251 @@ class ToolUsage:
|
||||
|
||||
return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}"
|
||||
|
||||
async def ause(
|
||||
self, calling: ToolCalling | InstructorToolCalling, tool_string: str
|
||||
) -> str:
|
||||
"""Execute a tool asynchronously.
|
||||
|
||||
Args:
|
||||
calling: The tool calling information.
|
||||
tool_string: The raw tool string from the agent.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution as a string.
|
||||
"""
|
||||
if isinstance(calling, ToolUsageError):
|
||||
error = calling.message
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
return error
|
||||
|
||||
try:
|
||||
tool = self._select_tool(calling.tool_name)
|
||||
except Exception as e:
|
||||
error = getattr(e, "message", str(e))
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||
return error
|
||||
|
||||
if (
|
||||
isinstance(tool, CrewStructuredTool)
|
||||
and tool.name == self._i18n.tools("add_image")["name"] # type: ignore
|
||||
):
|
||||
try:
|
||||
return await self._ause(
|
||||
tool_string=tool_string, tool=tool, calling=calling
|
||||
)
|
||||
except Exception as e:
|
||||
error = getattr(e, "message", str(e))
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||
return error
|
||||
|
||||
return (
|
||||
f"{await self._ause(tool_string=tool_string, tool=tool, calling=calling)}"
|
||||
)
|
||||
|
||||
async def _ause(
|
||||
self,
|
||||
tool_string: str,
|
||||
tool: CrewStructuredTool,
|
||||
calling: ToolCalling | InstructorToolCalling,
|
||||
) -> str:
|
||||
"""Internal async tool execution implementation.
|
||||
|
||||
Args:
|
||||
tool_string: The raw tool string from the agent.
|
||||
tool: The tool to execute.
|
||||
calling: The tool calling information.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution as a string.
|
||||
"""
|
||||
if self._check_tool_repeated_usage(calling=calling):
|
||||
try:
|
||||
result = self._i18n.errors("task_repeated_usage").format(
|
||||
tool_names=self.tools_names
|
||||
)
|
||||
self._telemetry.tool_repeated_usage(
|
||||
llm=self.function_calling_llm,
|
||||
tool_name=tool.name,
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
return self._format_result(result=result)
|
||||
except Exception:
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
|
||||
if self.agent:
|
||||
event_data = {
|
||||
"agent_key": self.agent.key,
|
||||
"agent_role": self.agent.role,
|
||||
"tool_name": self.action.tool,
|
||||
"tool_args": self.action.tool_input,
|
||||
"tool_class": self.action.tool,
|
||||
"agent": self.agent,
|
||||
}
|
||||
|
||||
if self.agent.fingerprint: # type: ignore
|
||||
event_data.update(self.agent.fingerprint) # type: ignore
|
||||
if self.task:
|
||||
event_data["task_name"] = self.task.name or self.task.description
|
||||
event_data["task_id"] = str(self.task.id)
|
||||
crewai_event_bus.emit(self, ToolUsageStartedEvent(**event_data))
|
||||
|
||||
started_at = time.time()
|
||||
from_cache = False
|
||||
result = None # type: ignore
|
||||
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
input_str = ""
|
||||
if calling.arguments:
|
||||
if isinstance(calling.arguments, dict):
|
||||
input_str = json.dumps(calling.arguments)
|
||||
else:
|
||||
input_str = str(calling.arguments)
|
||||
|
||||
result = self.tools_handler.cache.read(
|
||||
tool=calling.tool_name, input=input_str
|
||||
) # type: ignore
|
||||
from_cache = result is not None
|
||||
|
||||
available_tool = next(
|
||||
(
|
||||
available_tool
|
||||
for available_tool in self.tools
|
||||
if available_tool.name == tool.name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
usage_limit_error = self._check_usage_limit(available_tool, tool.name)
|
||||
if usage_limit_error:
|
||||
try:
|
||||
result = usage_limit_error
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
return self._format_result(result=result)
|
||||
except Exception:
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
|
||||
if result is None:
|
||||
try:
|
||||
if calling.tool_name in [
|
||||
"Delegate work to coworker",
|
||||
"Ask question to coworker",
|
||||
]:
|
||||
coworker = (
|
||||
calling.arguments.get("coworker") if calling.arguments else None
|
||||
)
|
||||
if self.task:
|
||||
self.task.increment_delegations(coworker)
|
||||
|
||||
if calling.arguments:
|
||||
try:
|
||||
acceptable_args = tool.args_schema.model_json_schema()[
|
||||
"properties"
|
||||
].keys()
|
||||
arguments = {
|
||||
k: v
|
||||
for k, v in calling.arguments.items()
|
||||
if k in acceptable_args
|
||||
}
|
||||
arguments = self._add_fingerprint_metadata(arguments)
|
||||
result = await tool.ainvoke(input=arguments)
|
||||
except Exception:
|
||||
arguments = calling.arguments
|
||||
arguments = self._add_fingerprint_metadata(arguments)
|
||||
result = await tool.ainvoke(input=arguments)
|
||||
else:
|
||||
arguments = self._add_fingerprint_metadata({})
|
||||
result = await tool.ainvoke(input=arguments)
|
||||
except Exception as e:
|
||||
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
|
||||
self._run_attempts += 1
|
||||
if self._run_attempts > self._max_parsing_attempts:
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
error_message = self._i18n.errors("tool_usage_exception").format(
|
||||
error=e, tool=tool.name, tool_inputs=tool.description
|
||||
)
|
||||
error = ToolUsageError(
|
||||
f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||
).message
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(
|
||||
content=f"\n\n{error_message}\n", color="red"
|
||||
)
|
||||
return error
|
||||
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
return await self.ause(calling=calling, tool_string=tool_string)
|
||||
|
||||
if self.tools_handler:
|
||||
should_cache = True
|
||||
if (
|
||||
hasattr(available_tool, "cache_function")
|
||||
and available_tool.cache_function
|
||||
):
|
||||
should_cache = available_tool.cache_function(
|
||||
calling.arguments, result
|
||||
)
|
||||
|
||||
self.tools_handler.on_tool_use(
|
||||
calling=calling, output=result, should_cache=should_cache
|
||||
)
|
||||
|
||||
self._telemetry.tool_usage(
|
||||
llm=self.function_calling_llm,
|
||||
tool_name=tool.name,
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
result = self._format_result(result=result)
|
||||
data = {
|
||||
"result": result,
|
||||
"tool_name": tool.name,
|
||||
"tool_args": calling.arguments,
|
||||
}
|
||||
|
||||
self.on_tool_use_finished(
|
||||
tool=tool,
|
||||
tool_calling=calling,
|
||||
from_cache=from_cache,
|
||||
started_at=started_at,
|
||||
result=result,
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(available_tool, "result_as_answer")
|
||||
and available_tool.result_as_answer # type: ignore
|
||||
):
|
||||
result_as_answer = available_tool.result_as_answer # type: ignore
|
||||
data["result_as_answer"] = result_as_answer # type: ignore
|
||||
|
||||
if self.agent and hasattr(self.agent, "tools_results"):
|
||||
self.agent.tools_results.append(data)
|
||||
|
||||
if available_tool and hasattr(available_tool, "current_usage_count"):
|
||||
available_tool.current_usage_count += 1
|
||||
if (
|
||||
hasattr(available_tool, "max_usage_count")
|
||||
and available_tool.max_usage_count is not None
|
||||
):
|
||||
self._printer.print(
|
||||
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
|
||||
color="blue",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _use(
|
||||
self,
|
||||
tool_string: str,
|
||||
|
||||
@@ -237,22 +237,22 @@ def get_llm_response(
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | None = None,
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None = None,
|
||||
) -> str:
|
||||
"""Call the LLM and return the response, handling any invalid responses.
|
||||
|
||||
Args:
|
||||
llm: The LLM instance to call
|
||||
messages: The messages to send to the LLM
|
||||
callbacks: List of callbacks for the LLM call
|
||||
printer: Printer instance for output
|
||||
from_task: Optional task context for the LLM call
|
||||
from_agent: Optional agent context for the LLM call
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
executor_context: Optional executor context for hook invocation
|
||||
llm: The LLM instance to call.
|
||||
messages: The messages to send to the LLM.
|
||||
callbacks: List of callbacks for the LLM call.
|
||||
printer: Printer instance for output.
|
||||
from_task: Optional task context for the LLM call.
|
||||
from_agent: Optional agent context for the LLM call.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
executor_context: Optional executor context for hook invocation.
|
||||
|
||||
Returns:
|
||||
The response from the LLM as a string
|
||||
The response from the LLM as a string.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs.
|
||||
@@ -284,6 +284,60 @@ def get_llm_response(
|
||||
return _setup_after_llm_call_hooks(executor_context, answer, printer)
|
||||
|
||||
|
||||
async def aget_llm_response(
|
||||
llm: LLM | BaseLLM,
|
||||
messages: list[LLMMessage],
|
||||
callbacks: list[TokenCalcHandler],
|
||||
printer: Printer,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | None = None,
|
||||
) -> str:
|
||||
"""Call the LLM asynchronously and return the response.
|
||||
|
||||
Args:
|
||||
llm: The LLM instance to call.
|
||||
messages: The messages to send to the LLM.
|
||||
callbacks: List of callbacks for the LLM call.
|
||||
printer: Printer instance for output.
|
||||
from_task: Optional task context for the LLM call.
|
||||
from_agent: Optional agent context for the LLM call.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
executor_context: Optional executor context for hook invocation.
|
||||
|
||||
Returns:
|
||||
The response from the LLM as a string.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs.
|
||||
ValueError: If the response is None or empty.
|
||||
"""
|
||||
if executor_context is not None:
|
||||
if not _setup_before_llm_call_hooks(executor_context, printer):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
messages = executor_context.messages
|
||||
|
||||
try:
|
||||
answer = await llm.acall(
|
||||
messages,
|
||||
callbacks=callbacks,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent, # type: ignore[arg-type]
|
||||
response_model=response_model,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
if not answer:
|
||||
printer.print(
|
||||
content="Received None or empty response from LLM call.",
|
||||
color="red",
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
|
||||
return _setup_after_llm_call_hooks(executor_context, answer, printer)
|
||||
|
||||
|
||||
def process_llm_response(
|
||||
answer: str, use_stop_words: bool
|
||||
) -> AgentAction | AgentFinish:
|
||||
@@ -673,7 +727,7 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
|
||||
|
||||
|
||||
def _setup_before_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | None, printer: Printer
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None, printer: Printer
|
||||
) -> bool:
|
||||
"""Setup and invoke before_llm_call hooks for the executor context.
|
||||
|
||||
@@ -723,7 +777,7 @@ def _setup_before_llm_call_hooks(
|
||||
|
||||
|
||||
def _setup_after_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | None,
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None,
|
||||
answer: str,
|
||||
printer: Printer,
|
||||
) -> str:
|
||||
|
||||
@@ -26,6 +26,138 @@ if TYPE_CHECKING:
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
async def aexecute_tool_and_check_finality(
|
||||
agent_action: AgentAction,
|
||||
tools: list[CrewStructuredTool],
|
||||
i18n: I18N,
|
||||
agent_key: str | None = None,
|
||||
agent_role: str | None = None,
|
||||
tools_handler: ToolsHandler | None = None,
|
||||
task: Task | None = None,
|
||||
agent: Agent | BaseAgent | None = None,
|
||||
function_calling_llm: BaseLLM | LLM | None = None,
|
||||
fingerprint_context: dict[str, str] | None = None,
|
||||
crew: Crew | None = None,
|
||||
) -> ToolResult:
|
||||
"""Execute a tool asynchronously and check if the result should be a final answer.
|
||||
|
||||
This is the async version of execute_tool_and_check_finality. It integrates tool
|
||||
hooks for before and after tool execution, allowing programmatic interception
|
||||
and modification of tool calls.
|
||||
|
||||
Args:
|
||||
agent_action: The action containing the tool to execute.
|
||||
tools: List of available tools.
|
||||
i18n: Internationalization settings.
|
||||
agent_key: Optional key for event emission.
|
||||
agent_role: Optional role for event emission.
|
||||
tools_handler: Optional tools handler for tool execution.
|
||||
task: Optional task for tool execution.
|
||||
agent: Optional agent instance for tool execution.
|
||||
function_calling_llm: Optional LLM for function calling.
|
||||
fingerprint_context: Optional context for fingerprinting.
|
||||
crew: Optional crew instance for hook context.
|
||||
|
||||
Returns:
|
||||
ToolResult containing the execution result and whether it should be
|
||||
treated as a final answer.
|
||||
"""
|
||||
logger = Logger(verbose=crew.verbose if crew else False)
|
||||
tool_name_to_tool_map = {tool.name: tool for tool in tools}
|
||||
|
||||
if agent_key and agent_role and agent:
|
||||
fingerprint_context = fingerprint_context or {}
|
||||
if agent:
|
||||
if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint):
|
||||
if isinstance(fingerprint_context, dict):
|
||||
try:
|
||||
fingerprint_obj = Fingerprint.from_dict(fingerprint_context)
|
||||
agent.set_fingerprint(fingerprint=fingerprint_obj)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to set fingerprint: {e}") from e
|
||||
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=tools_handler,
|
||||
tools=tools,
|
||||
function_calling_llm=function_calling_llm, # type: ignore[arg-type]
|
||||
task=task,
|
||||
agent=agent,
|
||||
action=agent_action,
|
||||
)
|
||||
|
||||
tool_calling = tool_usage.parse_tool_calling(agent_action.text)
|
||||
|
||||
if isinstance(tool_calling, ToolUsageError):
|
||||
return ToolResult(tool_calling.message, False)
|
||||
|
||||
if tool_calling.tool_name.casefold().strip() in [
|
||||
name.casefold().strip() for name in tool_name_to_tool_map
|
||||
] or tool_calling.tool_name.casefold().replace("_", " ") in [
|
||||
name.casefold().strip() for name in tool_name_to_tool_map
|
||||
]:
|
||||
tool = tool_name_to_tool_map.get(tool_calling.tool_name)
|
||||
if not tool:
|
||||
tool_result = i18n.errors("wrong_tool_name").format(
|
||||
tool=tool_calling.tool_name,
|
||||
tools=", ".join([t.name.casefold() for t in tools]),
|
||||
)
|
||||
return ToolResult(result=tool_result, result_as_answer=False)
|
||||
|
||||
tool_input = tool_calling.arguments if tool_calling.arguments else {}
|
||||
hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
task=task,
|
||||
crew=crew,
|
||||
)
|
||||
|
||||
before_hooks = get_before_tool_call_hooks()
|
||||
try:
|
||||
for hook in before_hooks:
|
||||
result = hook(hook_context)
|
||||
if result is False:
|
||||
blocked_message = (
|
||||
f"Tool execution blocked by hook. "
|
||||
f"Tool: {tool_calling.tool_name}"
|
||||
)
|
||||
return ToolResult(blocked_message, False)
|
||||
except Exception as e:
|
||||
logger.log("error", f"Error in before_tool_call hook: {e}")
|
||||
|
||||
tool_result = await tool_usage.ause(tool_calling, agent_action.text)
|
||||
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
task=task,
|
||||
crew=crew,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
modified_result: str = tool_result
|
||||
try:
|
||||
for after_hook in after_hooks:
|
||||
hook_result = after_hook(after_hook_context)
|
||||
if hook_result is not None:
|
||||
modified_result = hook_result
|
||||
after_hook_context.tool_result = modified_result
|
||||
except Exception as e:
|
||||
logger.log("error", f"Error in after_tool_call hook: {e}")
|
||||
|
||||
return ToolResult(modified_result, tool.result_as_answer)
|
||||
|
||||
tool_result = i18n.errors("wrong_tool_name").format(
|
||||
tool=tool_calling.tool_name,
|
||||
tools=", ".join([tool.name.casefold() for tool in tools]),
|
||||
)
|
||||
return ToolResult(result=tool_result, result_as_answer=False)
|
||||
|
||||
|
||||
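The hook flow implemented above (before hooks can veto a call, after hooks can rewrite its result) as a self-contained sketch; the registry lists and context class are simplified stand-ins, not the crewAI hook API:

from dataclasses import dataclass
from typing import Callable


@dataclass
class HookContext:
    tool_name: str
    tool_input: dict
    tool_result: str | None = None


before_hooks: list[Callable[[HookContext], bool | None]] = []
after_hooks: list[Callable[[HookContext], str | None]] = []


def run_tool_with_hooks(name: str, args: dict, execute: Callable[[dict], str]) -> str:
    ctx = HookContext(tool_name=name, tool_input=args)
    for hook in before_hooks:
        if hook(ctx) is False:  # a before hook can veto the call
            return f"Tool execution blocked by hook. Tool: {name}"
    result = execute(args)
    ctx.tool_result = result
    for hook in after_hooks:
        replacement = hook(ctx)
        if replacement is not None:  # an after hook can rewrite the result
            result = replacement
            ctx.tool_result = result
    return result


before_hooks.append(lambda ctx: False if ctx.tool_name == "dangerous_tool" else None)
after_hooks.append(lambda ctx: ctx.tool_result.upper() if ctx.tool_result else None)

print(run_tool_with_hooks("echo", {"text": "hi"}, lambda a: a["text"]))       # HI
print(run_tool_with_hooks("dangerous_tool", {}, lambda a: "should not run"))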
def execute_tool_and_check_finality(
|
||||
agent_action: AgentAction,
|
||||
tools: list[CrewStructuredTool],
|
||||
@@ -141,10 +273,10 @@ def execute_tool_and_check_finality(
|
||||
|
||||
# Execute after_tool_call hooks
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
modified_result = tool_result
|
||||
modified_result: str = tool_result
|
||||
try:
|
||||
for hook in after_hooks:
|
||||
hook_result = hook(after_hook_context)
|
||||
for after_hook in after_hooks:
|
||||
hook_result = after_hook(after_hook_context)
|
||||
if hook_result is not None:
|
||||
modified_result = hook_result
|
||||
after_hook_context.tool_result = modified_result
|
||||
|
||||
@@ -51,6 +51,15 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
|
||||
# Dummy implementation for MCP tools
|
||||
return []
|
||||
|
||||
async def aexecute_task(
|
||||
self,
|
||||
task: Any,
|
||||
context: str | None = None,
|
||||
tools: list[Any] | None = None,
|
||||
) -> str:
|
||||
# Dummy async implementation
|
||||
return "Task executed"
|
||||
|
||||
|
||||
def test_base_agent_adapter_initialization():
|
||||
"""Test initialization of the concrete agent adapter."""
|
||||
|
||||
@@ -25,6 +25,14 @@ class MockAgent(BaseAgent):
|
||||
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
|
||||
return []
|
||||
|
||||
async def aexecute_task(
|
||||
self,
|
||||
task: Any,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
return ""
|
||||
|
||||
def get_output_converter(
|
||||
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
|
||||
): ...
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.