mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-19 12:58:14 +00:00

Compare commits: devin/1763...devin/1764 (6 commits)

- 347381be57
- 20704742e2
- 59180e9c9f
- 3ce019b07b
- 2355ec0733
- c925d2d519
.env.test (new file, 161 lines)
@@ -0,0 +1,161 @@
# =============================================================================
# Test Environment Variables
# =============================================================================
# This file contains all environment variables needed to run tests locally
# in a way that mimics the GitHub Actions CI environment.

# =============================================================================

# -----------------------------------------------------------------------------
# LLM Provider API Keys
# -----------------------------------------------------------------------------
OPENAI_API_KEY=fake-api-key
ANTHROPIC_API_KEY=fake-anthropic-key
GEMINI_API_KEY=fake-gemini-key
AZURE_API_KEY=fake-azure-key
OPENROUTER_API_KEY=fake-openrouter-key

# -----------------------------------------------------------------------------
# AWS Credentials
# -----------------------------------------------------------------------------
AWS_ACCESS_KEY_ID=fake-aws-access-key
AWS_SECRET_ACCESS_KEY=fake-aws-secret-key
AWS_DEFAULT_REGION=us-east-1
AWS_REGION_NAME=us-east-1

# -----------------------------------------------------------------------------
# Azure OpenAI Configuration
# -----------------------------------------------------------------------------
AZURE_ENDPOINT=https://fake-azure-endpoint.openai.azure.com
AZURE_OPENAI_ENDPOINT=https://fake-azure-endpoint.openai.azure.com
AZURE_OPENAI_API_KEY=fake-azure-openai-key
AZURE_API_VERSION=2024-02-15-preview
OPENAI_API_VERSION=2024-02-15-preview

# -----------------------------------------------------------------------------
# Google Cloud Configuration
# -----------------------------------------------------------------------------
#GOOGLE_CLOUD_PROJECT=fake-gcp-project
#GOOGLE_CLOUD_LOCATION=us-central1

# -----------------------------------------------------------------------------
# OpenAI Configuration
# -----------------------------------------------------------------------------
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_API_BASE=https://api.openai.com/v1

# -----------------------------------------------------------------------------
# Search & Scraping Tool API Keys
# -----------------------------------------------------------------------------
SERPER_API_KEY=fake-serper-key
EXA_API_KEY=fake-exa-key
BRAVE_API_KEY=fake-brave-key
FIRECRAWL_API_KEY=fake-firecrawl-key
TAVILY_API_KEY=fake-tavily-key
SERPAPI_API_KEY=fake-serpapi-key
SERPLY_API_KEY=fake-serply-key
LINKUP_API_KEY=fake-linkup-key
PARALLEL_API_KEY=fake-parallel-key

# -----------------------------------------------------------------------------
# Exa Configuration
# -----------------------------------------------------------------------------
EXA_BASE_URL=https://api.exa.ai

# -----------------------------------------------------------------------------
# Web Scraping & Automation
# -----------------------------------------------------------------------------
BRIGHT_DATA_API_KEY=fake-brightdata-key
BRIGHT_DATA_ZONE=fake-zone
BRIGHTDATA_API_URL=https://api.brightdata.com
BRIGHTDATA_DEFAULT_TIMEOUT=600
BRIGHTDATA_DEFAULT_POLLING_INTERVAL=1

OXYLABS_USERNAME=fake-oxylabs-user
OXYLABS_PASSWORD=fake-oxylabs-pass

SCRAPFLY_API_KEY=fake-scrapfly-key
SCRAPEGRAPH_API_KEY=fake-scrapegraph-key

BROWSERBASE_API_KEY=fake-browserbase-key
BROWSERBASE_PROJECT_ID=fake-browserbase-project

HYPERBROWSER_API_KEY=fake-hyperbrowser-key
MULTION_API_KEY=fake-multion-key
APIFY_API_TOKEN=fake-apify-token

# -----------------------------------------------------------------------------
# Database & Vector Store Credentials
# -----------------------------------------------------------------------------
SINGLESTOREDB_URL=mysql://fake:fake@localhost:3306/fake
SINGLESTOREDB_HOST=localhost
SINGLESTOREDB_PORT=3306
SINGLESTOREDB_USER=fake-user
SINGLESTOREDB_PASSWORD=fake-password
SINGLESTOREDB_DATABASE=fake-database
SINGLESTOREDB_CONNECT_TIMEOUT=30

SNOWFLAKE_USER=fake-snowflake-user
SNOWFLAKE_PASSWORD=fake-snowflake-password
SNOWFLAKE_ACCOUNT=fake-snowflake-account
SNOWFLAKE_WAREHOUSE=fake-snowflake-warehouse
SNOWFLAKE_DATABASE=fake-snowflake-database
SNOWFLAKE_SCHEMA=fake-snowflake-schema

WEAVIATE_URL=http://localhost:8080
WEAVIATE_API_KEY=fake-weaviate-key

EMBEDCHAIN_DB_URI=sqlite:///test.db

# Databricks Credentials
DATABRICKS_HOST=https://fake-databricks.cloud.databricks.com
DATABRICKS_TOKEN=fake-databricks-token
DATABRICKS_CONFIG_PROFILE=fake-profile

# MongoDB Credentials
MONGODB_URI=mongodb://fake:fake@localhost:27017/fake

# -----------------------------------------------------------------------------
# CrewAI Platform & Enterprise
# -----------------------------------------------------------------------------
# setting CREWAI_PLATFORM_INTEGRATION_TOKEN causes these tests to fail:
#=========================== short test summary info ============================
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_platform_context_manager_basic_usage - AssertionError: assert 'fake-platform-token' is None
# + where 'fake-platform-token' = get_platform_integration_token()
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_context_var_isolation_between_tests - AssertionError: assert 'fake-platform-token' is None
# + where 'fake-platform-token' = get_platform_integration_token()
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_multiple_sequential_context_managers - AssertionError: assert 'fake-platform-token' is None
# + where 'fake-platform-token' = get_platform_integration_token()
#CREWAI_PLATFORM_INTEGRATION_TOKEN=fake-platform-token
CREWAI_PERSONAL_ACCESS_TOKEN=fake-personal-token
CREWAI_PLUS_URL=https://fake.crewai.com

# -----------------------------------------------------------------------------
# Other Service API Keys
# -----------------------------------------------------------------------------
ZAPIER_API_KEY=fake-zapier-key
PATRONUS_API_KEY=fake-patronus-key
MINDS_API_KEY=fake-minds-key
HF_TOKEN=fake-hf-token

# -----------------------------------------------------------------------------
# Feature Flags/Testing Modes
# -----------------------------------------------------------------------------
CREWAI_DISABLE_TELEMETRY=true
OTEL_SDK_DISABLED=true
CREWAI_TESTING=true
CREWAI_TRACING_ENABLED=false

# -----------------------------------------------------------------------------
# Testing/CI Configuration
# -----------------------------------------------------------------------------
# VCR recording mode: "none" (default), "new_episodes", "all", "once"
PYTEST_VCR_RECORD_MODE=none

# Set to "true" by GitHub when running in GitHub Actions
# GITHUB_ACTIONS=false

# -----------------------------------------------------------------------------
# Python Configuration
# -----------------------------------------------------------------------------
PYTHONUNBUFFERED=1
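The workspace conftest.py shown later in this diff loads this file automatically via python-dotenv. As a quick sanity check that the fake values are picked up, a minimal sketch of the same mechanism, assuming python-dotenv is installed and the file sits in the workspace root:

```python
import os
from pathlib import Path

from dotenv import load_dotenv

# Mirror what the workspace conftest.py does: load .env.test with override.
load_dotenv(Path(".env.test"), override=True)
assert os.environ["OPENAI_API_KEY"] == "fake-api-key"
assert os.environ["PYTEST_VCR_RECORD_MODE"] == "none"
```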
.github/workflows/tests.yml (vendored, 18 changed lines)
@@ -5,18 +5,6 @@ on: [pull_request]
permissions:
  contents: read

env:
  OPENAI_API_KEY: fake-api-key
  PYTHONUNBUFFERED: 1
  BRAVE_API_KEY: fake-brave-key
  SNOWFLAKE_USER: fake-snowflake-user
  SNOWFLAKE_PASSWORD: fake-snowflake-password
  SNOWFLAKE_ACCOUNT: fake-snowflake-account
  SNOWFLAKE_WAREHOUSE: fake-snowflake-warehouse
  SNOWFLAKE_DATABASE: fake-snowflake-database
  SNOWFLAKE_SCHEMA: fake-snowflake-schema
  EMBEDCHAIN_DB_URI: sqlite:///test.db

jobs:
  tests:
    name: tests (${{ matrix.python-version }})
@@ -84,26 +72,20 @@ jobs:
          # fi

          cd lib/crewai && uv run pytest \
            --block-network \
            --timeout=30 \
            -vv \
            --splits 8 \
            --group ${{ matrix.group }} \
            $DURATIONS_ARG \
            --durations=10 \
            -n auto \
            --maxfail=3

      - name: Run tool tests (group ${{ matrix.group }} of 8)
        run: |
          cd lib/crewai-tools && uv run pytest \
            --block-network \
            --timeout=30 \
            -vv \
            --splits 8 \
            --group ${{ matrix.group }} \
            --durations=10 \
            -n auto \
            --maxfail=3
conftest.py (new file, 193 lines)
@@ -0,0 +1,193 @@
"""Pytest configuration for crewAI workspace."""

from collections.abc import Generator
import os
from pathlib import Path
import tempfile
from typing import Any

from dotenv import load_dotenv
import pytest
from vcr.request import Request  # type: ignore[import-untyped]


env_test_path = Path(__file__).parent / ".env.test"
load_dotenv(env_test_path, override=True)
load_dotenv(override=True)


@pytest.fixture(autouse=True, scope="function")
def cleanup_event_handlers() -> Generator[None, Any, None]:
    """Clean up event bus handlers after each test to prevent test pollution."""
    yield

    try:
        from crewai.events.event_bus import crewai_event_bus

        with crewai_event_bus._rwlock.w_locked():
            crewai_event_bus._sync_handlers.clear()
            crewai_event_bus._async_handlers.clear()
    except Exception:  # noqa: S110
        pass


@pytest.fixture(autouse=True, scope="function")
def setup_test_environment() -> Generator[None, Any, None]:
    """Setup test environment for crewAI workspace."""
    with tempfile.TemporaryDirectory() as temp_dir:
        storage_dir = Path(temp_dir) / "crewai_test_storage"
        storage_dir.mkdir(parents=True, exist_ok=True)

        if not storage_dir.exists() or not storage_dir.is_dir():
            raise RuntimeError(
                f"Failed to create test storage directory: {storage_dir}"
            )

        try:
            test_file = storage_dir / ".permissions_test"
            test_file.touch()
            test_file.unlink()
        except (OSError, IOError) as e:
            raise RuntimeError(
                f"Test storage directory {storage_dir} is not writable: {e}"
            ) from e

        os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
        os.environ["CREWAI_TESTING"] = "true"

        try:
            yield
        finally:
            os.environ.pop("CREWAI_TESTING", "true")
            os.environ.pop("CREWAI_STORAGE_DIR", None)
            os.environ.pop("CREWAI_DISABLE_TELEMETRY", "true")
            os.environ.pop("OTEL_SDK_DISABLED", "true")
            os.environ.pop("OPENAI_BASE_URL", "https://api.openai.com/v1")
            os.environ.pop("OPENAI_API_BASE", "https://api.openai.com/v1")


HEADERS_TO_FILTER = {
    "authorization": "AUTHORIZATION-XXX",
    "content-security-policy": "CSP-FILTERED",
    "cookie": "COOKIE-XXX",
    "set-cookie": "SET-COOKIE-XXX",
    "permissions-policy": "PERMISSIONS-POLICY-XXX",
    "referrer-policy": "REFERRER-POLICY-XXX",
    "strict-transport-security": "STS-XXX",
    "x-content-type-options": "X-CONTENT-TYPE-XXX",
    "x-frame-options": "X-FRAME-OPTIONS-XXX",
    "x-permitted-cross-domain-policies": "X-PERMITTED-XXX",
    "x-request-id": "X-REQUEST-ID-XXX",
    "x-runtime": "X-RUNTIME-XXX",
    "x-xss-protection": "X-XSS-PROTECTION-XXX",
    "x-stainless-arch": "X-STAINLESS-ARCH-XXX",
    "x-stainless-os": "X-STAINLESS-OS-XXX",
    "x-stainless-read-timeout": "X-STAINLESS-READ-TIMEOUT-XXX",
    "cf-ray": "CF-RAY-XXX",
    "etag": "ETAG-XXX",
    "Strict-Transport-Security": "STS-XXX",
    "access-control-expose-headers": "ACCESS-CONTROL-XXX",
    "openai-organization": "OPENAI-ORG-XXX",
    "openai-project": "OPENAI-PROJECT-XXX",
    "x-ratelimit-limit-requests": "X-RATELIMIT-LIMIT-REQUESTS-XXX",
    "x-ratelimit-limit-tokens": "X-RATELIMIT-LIMIT-TOKENS-XXX",
    "x-ratelimit-remaining-requests": "X-RATELIMIT-REMAINING-REQUESTS-XXX",
    "x-ratelimit-remaining-tokens": "X-RATELIMIT-REMAINING-TOKENS-XXX",
    "x-ratelimit-reset-requests": "X-RATELIMIT-RESET-REQUESTS-XXX",
    "x-ratelimit-reset-tokens": "X-RATELIMIT-RESET-TOKENS-XXX",
    "x-goog-api-key": "X-GOOG-API-KEY-XXX",
    "api-key": "X-API-KEY-XXX",
    "User-Agent": "X-USER-AGENT-XXX",
"apim-request-id:": "X-API-CLIENT-REQUEST-ID-XXX",
|
||||
"azureml-model-session": "AZUREML-MODEL-SESSION-XXX",
|
||||
"x-ms-client-request-id": "X-MS-CLIENT-REQUEST-ID-XXX",
|
||||
"x-ms-region": "X-MS-REGION-XXX",
|
||||
"apim-request-id": "APIM-REQUEST-ID-XXX",
|
||||
"x-api-key": "X-API-KEY-XXX",
|
||||
"anthropic-organization-id": "ANTHROPIC-ORGANIZATION-ID-XXX",
|
||||
"request-id": "REQUEST-ID-XXX",
|
||||
"anthropic-ratelimit-input-tokens-limit": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-LIMIT-XXX",
|
||||
"anthropic-ratelimit-input-tokens-remaining": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-REMAINING-XXX",
|
||||
"anthropic-ratelimit-input-tokens-reset": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-RESET-XXX",
|
||||
"anthropic-ratelimit-output-tokens-limit": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-LIMIT-XXX",
|
||||
"anthropic-ratelimit-output-tokens-remaining": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-REMAINING-XXX",
|
||||
"anthropic-ratelimit-output-tokens-reset": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-RESET-XXX",
|
||||
"anthropic-ratelimit-tokens-limit": "ANTHROPIC-RATELIMIT-TOKENS-LIMIT-XXX",
|
||||
"anthropic-ratelimit-tokens-remaining": "ANTHROPIC-RATELIMIT-TOKENS-REMAINING-XXX",
|
||||
"anthropic-ratelimit-tokens-reset": "ANTHROPIC-RATELIMIT-TOKENS-RESET-XXX",
|
||||
"x-amz-date": "X-AMZ-DATE-XXX",
|
||||
"amz-sdk-invocation-id": "AMZ-SDK-INVOCATION-ID-XXX",
|
||||
"accept-encoding": "ACCEPT-ENCODING-XXX",
|
||||
"x-amzn-requestid": "X-AMZN-REQUESTID-XXX",
|
||||
"x-amzn-RequestId": "X-AMZN-REQUESTID-XXX",
|
||||
}
|
||||
|
||||
|
||||
def _filter_request_headers(request: Request) -> Request: # type: ignore[no-any-unimported]
|
||||
"""Filter sensitive headers from request before recording."""
|
||||
for header_name, replacement in HEADERS_TO_FILTER.items():
|
||||
for variant in [header_name, header_name.upper(), header_name.title()]:
|
||||
if variant in request.headers:
|
||||
request.headers[variant] = [replacement]
|
||||
|
||||
request.method = request.method.upper()
|
||||
return request
|
||||
|
||||
|
||||
def _filter_response_headers(response: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Filter sensitive headers from response before recording."""
|
||||
for header_name, replacement in HEADERS_TO_FILTER.items():
|
||||
for variant in [header_name, header_name.upper(), header_name.title()]:
|
||||
if variant in response["headers"]:
|
||||
response["headers"][variant] = [replacement]
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_cassette_dir(request: Any) -> str:
|
||||
"""Generate cassette directory path based on test module location.
|
||||
|
||||
Organizes cassettes to mirror test directory structure within each package:
|
||||
lib/crewai/tests/llms/google/test_google.py -> lib/crewai/tests/cassettes/llms/google/
|
||||
lib/crewai-tools/tests/tools/test_search.py -> lib/crewai-tools/tests/cassettes/tools/
|
||||
"""
|
||||
test_file = Path(request.fspath)
|
||||
|
||||
for parent in test_file.parents:
|
||||
if parent.name in ("crewai", "crewai-tools") and parent.parent.name == "lib":
|
||||
package_root = parent
|
||||
break
|
||||
else:
|
||||
package_root = test_file.parent
|
||||
|
||||
tests_root = package_root / "tests"
|
||||
test_dir = test_file.parent
|
||||
|
||||
if test_dir != tests_root:
|
||||
relative_path = test_dir.relative_to(tests_root)
|
||||
cassette_dir = tests_root / "cassettes" / relative_path
|
||||
else:
|
||||
cassette_dir = tests_root / "cassettes"
|
||||
|
||||
cassette_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return str(cassette_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config(vcr_cassette_dir: str) -> dict[str, Any]:
|
||||
"""Configure VCR with organized cassette storage."""
|
||||
config = {
|
||||
"cassette_library_dir": vcr_cassette_dir,
|
||||
"record_mode": os.getenv("PYTEST_VCR_RECORD_MODE", "once"),
|
||||
"filter_headers": [(k, v) for k, v in HEADERS_TO_FILTER.items()],
|
||||
"before_record_request": _filter_request_headers,
|
||||
"before_record_response": _filter_response_headers,
|
||||
"filter_query_parameters": ["key"],
|
||||
"match_on": ["method", "scheme", "host", "port", "path"],
|
||||
}
|
||||
|
||||
if os.getenv("GITHUB_ACTIONS") == "true":
|
||||
config["record_mode"] = "none"
|
||||
|
||||
return config
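For context, a test opts into this configuration simply by carrying the VCR marker, as the tool tests later in this diff do. A minimal sketch, assuming pytest-recording picks up the `vcr_config` fixture and the URL is illustrative:

```python
import pytest
import requests


@pytest.mark.vcr()
def test_records_and_replays_http():
    # First run records a cassette under tests/cassettes/...; later runs replay it.
    response = requests.get("https://api.example.com/ping")
    assert response.status_code == 200
```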
@@ -1089,6 +1089,50 @@ CrewAI supports streaming responses from LLMs, allowing your application to rece
  </Tab>
</Tabs>

## Async LLM Calls

CrewAI supports asynchronous LLM calls for improved performance and concurrency in your AI workflows. Async calls allow you to run multiple LLM requests concurrently without blocking, making them ideal for high-throughput applications and parallel agent operations.

<Tabs>
  <Tab title="Basic Usage">
    Use the `acall` method for asynchronous LLM requests:

    ```python
    import asyncio
    from crewai import LLM

    async def main():
        llm = LLM(model="openai/gpt-4o")

        # Single async call
        response = await llm.acall("What is the capital of France?")
        print(response)

    asyncio.run(main())
    ```

    The `acall` method supports all the same parameters as the synchronous `call` method, including messages, tools, and callbacks.
  </Tab>

  <Tab title="With Streaming">
    Combine async calls with streaming for real-time concurrent responses:

    ```python
    import asyncio
    from crewai import LLM

    async def stream_async():
        llm = LLM(model="openai/gpt-4o", stream=True)

        response = await llm.acall("Write a short story about AI")

        print(response)

    asyncio.run(stream_async())
    ```

    With `stream=True`, chunks are emitted as stream events while they arrive; `acall` still returns the full accumulated text once the stream completes.
  </Tab>
</Tabs>
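Because `acall` is a coroutine, several requests can also be dispatched concurrently with `asyncio.gather`; a minimal sketch (the prompts are illustrative):

```python
import asyncio
from crewai import LLM

async def fan_out():
    llm = LLM(model="openai/gpt-4o")
    # Launch both requests concurrently and wait for both results.
    answers = await asyncio.gather(
        llm.acall("Summarize the Eiffel Tower in one sentence."),
        llm.acall("Summarize the Colosseum in one sentence."),
    )
    for answer in answers:
        print(answer)

asyncio.run(fan_out())
```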

## Structured LLM Calls

CrewAI supports structured responses from LLM calls by allowing you to define a `response_format` using a Pydantic model. This enables the framework to automatically parse and validate the output, making it easier to integrate the response into your application without manual post-processing.

@@ -218,7 +218,7 @@ Update the root `README.md` only if the tool introduces a new category or notabl

## Discovery and specs

Our internal tooling discovers classes whose names end with `Tool`. Keep your class exported from the module path under `crewai_tools/tools/...` to be picked up by scripts like `generate_tool_specs.py`.
Our internal tooling discovers classes whose names end with `Tool`. Keep your class exported from the module path under `crewai_tools/tools/...` to be picked up by scripts like `crewai_tools.generate_tool_specs.py`.
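To illustrate the naming convention, a minimal sketch of a discoverable tool; the class and its behavior are hypothetical, not part of the package:

```python
from crewai.tools.base_tool import BaseTool


class ExampleLookupTool(BaseTool):
    """Hypothetical tool; discoverable because its class name ends with 'Tool'."""

    name: str = "Example Lookup"
    description: str = "Illustrative placeholder for the discovery convention."

    def _run(self, query: str) -> str:
        return f"looked up: {query}"
```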

---

@@ -8,17 +8,17 @@ authors = [
]
requires-python = ">=3.10, <3.14"
dependencies = [
    "lancedb>=0.5.4",
    "pytube>=15.0.0",
    "requests>=2.32.5",
    "docker>=7.1.0",
    "lancedb~=0.5.4",
    "pytube~=15.0.0",
    "requests~=2.32.5",
    "docker~=7.1.0",
    "crewai==1.6.1",
    "lancedb>=0.5.4",
    "tiktoken>=0.8.0",
    "beautifulsoup4>=4.13.4",
    "python-docx>=1.2.0",
    "youtube-transcript-api>=1.2.2",
    "pymupdf>=1.26.6",
    "lancedb~=0.5.4",
    "tiktoken~=0.8.0",
    "beautifulsoup4~=4.13.4",
    "python-docx~=1.2.0",
    "youtube-transcript-api~=1.2.2",
    "pymupdf~=1.26.6",
]
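The recurring change in these manifests swaps `>=` lower bounds for the compatible-release operator `~=`, which additionally caps accepted versions at the next minor release. A small sketch of the semantics using the `packaging` library (the version numbers are illustrative):

```python
from packaging.specifiers import SpecifierSet

# "~=0.5.4" is equivalent to ">=0.5.4, <0.6.0": patch updates in, minor bumps out.
spec = SpecifierSet("~=0.5.4")
assert spec.contains("0.5.9")
assert not spec.contains("0.6.0")
```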

@@ -4,17 +4,20 @@ from collections.abc import Mapping
import inspect
import json
from pathlib import Path
from typing import Any, cast
from typing import Any

from crewai.tools.base_tool import BaseTool, EnvVar
from crewai_tools import tools
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema
from pydantic_core import PydanticOmit

from crewai_tools import tools


class SchemaGenerator(GenerateJsonSchema):
    def handle_invalid_for_json_schema(self, schema, error_info):
    def handle_invalid_for_json_schema(
        self, schema: Any, error_info: Any
    ) -> dict[str, Any]:
        raise PydanticOmit

@@ -73,7 +76,7 @@ class ToolSpecExtractor:

    @staticmethod
    def _extract_field_default(
        field: dict | None, fallback: str | list[Any] = ""
        field: dict[str, Any] | None, fallback: str | list[Any] = ""
    ) -> str | list[Any] | int:
        if not field:
            return fallback
@@ -83,7 +86,7 @@ class ToolSpecExtractor:
        return default if isinstance(default, (list, str, int)) else fallback

    @staticmethod
    def _extract_params(args_schema_field: dict | None) -> dict[str, Any]:
    def _extract_params(args_schema_field: dict[str, Any] | None) -> dict[str, Any]:
        if not args_schema_field:
            return {}

@@ -94,15 +97,15 @@ class ToolSpecExtractor:
        ):
            return {}

        # Cast to type[BaseModel] after runtime check
        schema_class = cast(type[BaseModel], args_schema_class)
        try:
            return schema_class.model_json_schema(schema_generator=SchemaGenerator)
            return args_schema_class.model_json_schema(schema_generator=SchemaGenerator)
        except Exception:
            return {}

    @staticmethod
    def _extract_env_vars(env_vars_field: dict | None) -> list[dict[str, Any]]:
    def _extract_env_vars(
        env_vars_field: dict[str, Any] | None,
    ) -> list[dict[str, Any]]:
        if not env_vars_field:
            return []

@@ -1,21 +0,0 @@
import pytest


def pytest_configure(config):
    """Register custom markers."""
    config.addinivalue_line("markers", "integration: mark test as an integration test")
    config.addinivalue_line("markers", "asyncio: mark test as an async test")

    # Set the asyncio loop scope through ini configuration
    config.inicfg["asyncio_mode"] = "auto"


@pytest.fixture(scope="function")
def event_loop():
    """Create an instance of the default event loop for each test case."""
    import asyncio

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield loop
    loop.close()

@@ -2,7 +2,7 @@ import json
from unittest import mock

from crewai.tools.base_tool import BaseTool, EnvVar
from generate_tool_specs import ToolSpecExtractor
from crewai_tools.generate_tool_specs import ToolSpecExtractor
from pydantic import BaseModel, Field
import pytest

@@ -61,8 +61,8 @@ def test_unwrap_schema(extractor):
@pytest.fixture
def mock_tool_extractor(extractor):
    with (
        mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
        mock.patch("generate_tool_specs.getattr", return_value=MockTool),
        mock.patch("crewai_tools.generate_tool_specs.dir", return_value=["MockTool"]),
        mock.patch("crewai_tools.generate_tool_specs.getattr", return_value=MockTool),
    ):
        extractor.extract_all_tools()
        assert len(extractor.tools_spec) == 1

@@ -4,7 +4,7 @@ from crewai_tools.tools.firecrawl_crawl_website_tool.firecrawl_crawl_website_too
    FirecrawlCrawlWebsiteTool,
)

@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_firecrawl_crawl_tool_integration():
    tool = FirecrawlCrawlWebsiteTool(config={
        "limit": 2,

@@ -4,7 +4,7 @@ from crewai_tools.tools.firecrawl_scrape_website_tool.firecrawl_scrape_website_t
    FirecrawlScrapeWebsiteTool,
)

@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_firecrawl_scrape_tool_integration():
    tool = FirecrawlScrapeWebsiteTool()
    result = tool.run(url="https://firecrawl.dev")

@@ -3,7 +3,7 @@ import pytest
from crewai_tools.tools.firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_firecrawl_search_tool_integration():
    tool = FirecrawlSearchTool()
    result = tool.run(query="firecrawl")

@@ -23,15 +23,13 @@ from crewai_tools.tools.rag.rag_tool import Adapter
import pytest


pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]


@pytest.fixture
def mock_adapter():
    mock_adapter = MagicMock(spec=Adapter)
    return mock_adapter


@pytest.mark.vcr()
def test_directory_search_tool():
    with tempfile.TemporaryDirectory() as temp_dir:
        test_file = Path(temp_dir) / "test.txt"
@@ -65,6 +63,7 @@ def test_pdf_search_tool(mock_adapter):
    )


@pytest.mark.vcr()
def test_txt_search_tool():
    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
        temp_file.write(b"This is a test file for txt search")
@@ -102,6 +101,7 @@ def test_docx_search_tool(mock_adapter):
    )


@pytest.mark.vcr()
def test_json_search_tool():
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
        temp_file.write(b'{"test": "This is a test JSON file"}')
@@ -127,6 +127,7 @@ def test_xml_search_tool(mock_adapter):
    )


@pytest.mark.vcr()
def test_csv_search_tool():
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
        temp_file.write(b"name,description\ntest,This is a test CSV file")
@@ -141,6 +142,7 @@ def test_csv_search_tool():
    os.unlink(temp_file_path)


@pytest.mark.vcr()
def test_mdx_search_tool():
    with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as temp_file:
        temp_file.write(b"# Test MDX\nThis is a test MDX file")

@@ -9,35 +9,35 @@ authors = [
requires-python = ">=3.10, <3.14"
dependencies = [
    # Core Dependencies
    "pydantic>=2.11.9",
    "openai>=1.13.3",
    "pydantic~=2.11.9",
    "openai~=1.83.0",
    "instructor>=1.3.3",
    # Text Processing
    "pdfplumber>=0.11.4",
    "regex>=2024.9.11",
    "pdfplumber~=0.11.4",
    "regex~=2024.9.11",
    # Telemetry and Monitoring
    "opentelemetry-api>=1.30.0",
    "opentelemetry-sdk>=1.30.0",
    "opentelemetry-exporter-otlp-proto-http>=1.30.0",
    "opentelemetry-api~=1.34.0",
    "opentelemetry-sdk~=1.34.0",
    "opentelemetry-exporter-otlp-proto-http~=1.34.0",
    # Data Handling
    "chromadb~=1.1.0",
    "tokenizers>=0.20.3",
    "openpyxl>=3.1.5",
    "tokenizers~=0.20.3",
    "openpyxl~=3.1.5",
    # Authentication and Security
    "python-dotenv>=1.1.1",
    "pyjwt>=2.9.0",
    "python-dotenv~=1.1.1",
    "pyjwt~=2.9.0",
    # Configuration and Utils
    "click>=8.1.7",
    "appdirs>=1.4.4",
    "jsonref>=1.1.0",
    "json-repair==0.25.2",
    "uv>=0.4.25",
    "tomli-w>=1.1.0",
    "tomli>=2.0.2",
    "json5>=0.10.0",
    "portalocker==2.7.0",
    "pydantic-settings>=2.10.1",
    "mcp>=1.16.0",
    "click~=8.1.7",
    "appdirs~=1.4.4",
    "jsonref~=1.1.0",
    "json-repair~=0.25.2",
    "tomli-w~=1.1.0",
    "tomli~=2.0.2",
    "json5~=0.10.0",
    "portalocker~=2.7.0",
    "pydantic-settings~=2.10.1",
    "mcp~=1.16.0",
    "uv~=0.9.13",
]

[project.urls]
@@ -53,50 +53,48 @@ tools = [
embeddings = [
    "tiktoken~=0.8.0"
]
pdfplumber = [
    "pdfplumber>=0.11.4",
]
pandas = [
    "pandas>=2.2.3",
    "pandas~=2.2.3",
]
openpyxl = [
    "openpyxl>=3.1.5",
    "openpyxl~=3.1.5",
]
mem0 = ["mem0ai>=0.1.94"]
mem0 = ["mem0ai~=0.1.94"]
docling = [
    "docling>=2.12.0",
    "docling~=2.63.0",
]
qdrant = [
    "qdrant-client[fastembed]>=1.14.3",
    "qdrant-client[fastembed]~=1.14.3",
]
aws = [
    "boto3>=1.40.38",
    "boto3~=1.40.38",
    "aiobotocore~=2.25.2",
]
watson = [
    "ibm-watsonx-ai>=1.3.39",
    "ibm-watsonx-ai~=1.3.39",
]
voyageai = [
    "voyageai>=0.3.5",
    "voyageai~=0.3.5",
]
litellm = [
    "litellm>=1.74.9",
    "litellm~=1.74.9",
]
bedrock = [
    "boto3>=1.40.45",
    "boto3~=1.40.45",
]
google-genai = [
    "google-genai>=1.2.0",
    "google-genai~=1.2.0",
]
azure-ai-inference = [
    "azure-ai-inference>=1.0.0b9",
    "azure-ai-inference~=1.0.0b9",
]
anthropic = [
    "anthropic>=0.69.0",
    "anthropic~=0.71.0",
]
a2a = [
a2a = [
    "a2a-sdk~=0.3.10",
    "httpx-auth>=0.23.1",
    "httpx-sse>=0.4.0",
    "httpx-auth~=0.23.1",
    "httpx-sse~=0.4.0",
]

@@ -950,34 +950,15 @@ class Crew(FlowTrackable, BaseModel):

    def _handle_crew_planning(self) -> None:
        """Handles the Crew planning."""
        import re

        self._logger.log("info", "Planning the crew execution")
        result = CrewPlanner(
            tasks=self.tasks, planning_agent_llm=self.planning_llm
        )._handle_crew_planning()

        plan_map: dict[int, str] = {}
        for step_plan in result.list_of_plans_per_task:
            match = re.search(r"Task Number (\d+)", step_plan.task, re.IGNORECASE)
            if match:
                task_number = int(match.group(1))
                plan_map[task_number] = step_plan.plan
            else:
                self._logger.log(
                    "warning",
                    f"Could not extract task number from plan task field: {step_plan.task}",
                )

        for idx, task in enumerate(self.tasks):
            task_number = idx + 1  # Task numbers are 1-indexed
            if task_number in plan_map:
                task.description += plan_map[task_number]
            else:
                self._logger.log(
                    "warning",
                    f"No plan found for task {task_number}. Task description: {task.description}",
                )
        for task, step_plan in zip(
            self.tasks, result.list_of_plans_per_task, strict=False
        ):
            task.description += step_plan.plan
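The replacement pairs tasks and plans positionally instead of parsing task numbers out of the plan text; note that with `strict=False`, `zip` silently stops at the shorter sequence. A tiny illustration (the values are made up):

```python
tasks = ["research", "write", "review"]
plans = ["plan A", "plan B"]

# Positional pairing; the third task simply receives no plan.
assert list(zip(tasks, plans, strict=False)) == [
    ("research", "plan A"),
    ("write", "plan B"),
]
```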

    def _store_execution_log(
        self,

@@ -76,7 +76,7 @@ class TraceBatchManager:
        use_ephemeral: bool = False,
    ) -> TraceBatch:
        """Initialize a new trace batch (thread-safe)"""
        with self._init_lock:
            with self._batch_ready_cv:
                if self.current_batch is not None:
                    logger.debug(
                        "Batch already initialized, skipping duplicate initialization"
@@ -99,7 +99,6 @@ class TraceBatchManager:
                    self.backend_initialized = True

            self._batch_ready_cv.notify_all()

        return self.current_batch

    def _initialize_backend_batch(
@@ -107,7 +106,7 @@ class TraceBatchManager:
        user_context: dict[str, str],
        execution_metadata: dict[str, Any],
        use_ephemeral: bool = False,
    ):
    ) -> None:
        """Send batch initialization to backend"""

        if not is_tracing_enabled_in_context():
@@ -204,7 +203,7 @@ class TraceBatchManager:
            return False
        return True

    def add_event(self, trace_event: TraceEvent):
    def add_event(self, trace_event: TraceEvent) -> None:
        """Add event to buffer"""
        self.event_buffer.append(trace_event)

@@ -300,7 +299,7 @@ class TraceBatchManager:

        return finalized_batch

    def _finalize_backend_batch(self, events_count: int = 0):
    def _finalize_backend_batch(self, events_count: int = 0) -> None:
        """Send batch finalization to backend

        Args:
@@ -366,7 +365,7 @@ class TraceBatchManager:
            logger.error(f"❌ Error finalizing trace batch: {e}")
            self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))

    def _cleanup_batch_data(self):
    def _cleanup_batch_data(self) -> None:
        """Clean up batch data after successful finalization to free memory"""
        try:
            if hasattr(self, "event_buffer") and self.event_buffer:
@@ -411,7 +410,7 @@ class TraceBatchManager:
            lambda: self.current_batch is not None, timeout=timeout
        )

    def record_start_time(self, key: str):
    def record_start_time(self, key: str) -> None:
        """Record start time for duration calculation"""
        self.execution_start_times[key] = datetime.now(timezone.utc)

@@ -71,6 +71,7 @@ from crewai.events.types.reasoning_events import (
    AgentReasoningFailedEvent,
    AgentReasoningStartedEvent,
)
from crewai.events.types.system_events import SignalEvent, on_signal
from crewai.events.types.task_events import (
    TaskCompletedEvent,
    TaskFailedEvent,
@@ -159,6 +160,7 @@ class TraceCollectionListener(BaseEventListener):
        self._register_flow_event_handlers(crewai_event_bus)
        self._register_context_event_handlers(crewai_event_bus)
        self._register_action_event_handlers(crewai_event_bus)
        self._register_system_event_handlers(crewai_event_bus)

        self._listeners_setup = True

@@ -458,6 +460,15 @@ class TraceCollectionListener(BaseEventListener):
    ) -> None:
        self._handle_action_event("knowledge_query_failed", source, event)

    def _register_system_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
        """Register handlers for system signal events (SIGTERM, SIGINT, etc.)."""

        @on_signal
        def handle_signal(source: Any, event: SignalEvent) -> None:
            """Flush trace batch on system signals to prevent data loss."""
            if self.batch_manager.is_batch_initialized():
                self.batch_manager.finalize_batch()

    def _initialize_crew_batch(self, source: Any, event: Any) -> None:
        """Initialize trace batch.

lib/crewai/src/crewai/events/types/system_events.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""System signal event types for CrewAI.

This module contains event types for system-level signals like SIGTERM,
allowing listeners to perform cleanup operations before process termination.
"""

from collections.abc import Callable
from enum import IntEnum
import signal
from typing import Annotated, Literal, TypeVar

from pydantic import Field, TypeAdapter

from crewai.events.base_events import BaseEvent


class SignalType(IntEnum):
    """Enumeration of supported system signals."""

    SIGTERM = signal.SIGTERM
    SIGINT = signal.SIGINT
    SIGHUP = signal.SIGHUP
    SIGTSTP = signal.SIGTSTP
    SIGCONT = signal.SIGCONT


class SigTermEvent(BaseEvent):
    """Event emitted when SIGTERM is received."""

    type: Literal["SIGTERM"] = "SIGTERM"
    signal_number: SignalType = SignalType.SIGTERM
    reason: str | None = None


class SigIntEvent(BaseEvent):
    """Event emitted when SIGINT is received."""

    type: Literal["SIGINT"] = "SIGINT"
    signal_number: SignalType = SignalType.SIGINT
    reason: str | None = None


class SigHupEvent(BaseEvent):
    """Event emitted when SIGHUP is received."""

    type: Literal["SIGHUP"] = "SIGHUP"
    signal_number: SignalType = SignalType.SIGHUP
    reason: str | None = None


class SigTStpEvent(BaseEvent):
    """Event emitted when SIGTSTP is received.

    Note: SIGSTOP cannot be caught - it immediately suspends the process.
    """

    type: Literal["SIGTSTP"] = "SIGTSTP"
    signal_number: SignalType = SignalType.SIGTSTP
    reason: str | None = None


class SigContEvent(BaseEvent):
    """Event emitted when SIGCONT is received."""

    type: Literal["SIGCONT"] = "SIGCONT"
    signal_number: SignalType = SignalType.SIGCONT
    reason: str | None = None


SignalEvent = Annotated[
    SigTermEvent | SigIntEvent | SigHupEvent | SigTStpEvent | SigContEvent,
    Field(discriminator="type"),
]

signal_event_adapter: TypeAdapter[SignalEvent] = TypeAdapter(SignalEvent)
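The adapter validates serialized payloads back into the correct concrete event class via the `type` discriminator. A small sketch, assuming `BaseEvent` supplies defaults for its remaining fields:

```python
event = signal_event_adapter.validate_python({"type": "SIGTERM", "reason": "shutdown"})
assert isinstance(event, SigTermEvent)
assert event.signal_number is SignalType.SIGTERM
```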

SIGNAL_EVENT_TYPES: tuple[type[BaseEvent], ...] = (
    SigTermEvent,
    SigIntEvent,
    SigHupEvent,
    SigTStpEvent,
    SigContEvent,
)


T = TypeVar("T", bound=Callable[[object, SignalEvent], None])


def on_signal(func: T) -> T:
    """Decorator to register a handler for all signal events.

    Args:
        func: Handler function that receives (source, event) arguments.

    Returns:
        The original function, registered for all signal event types.
    """
    from crewai.events.event_bus import crewai_event_bus

    for event_type in SIGNAL_EVENT_TYPES:
        crewai_event_bus.on(event_type)(func)
    return func
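Registering a cleanup hook is then a one-liner, as the trace listener above does. A minimal sketch (the handler body is illustrative):

```python
from typing import Any

from crewai.events.types.system_events import SignalEvent, on_signal


@on_signal
def flush_on_shutdown(source: Any, event: SignalEvent) -> None:
    # Runs for every registered signal event type; inspect event.type to specialize.
    print(f"received {event.type}, flushing buffers")
```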
@@ -57,7 +57,12 @@ if TYPE_CHECKING:
    from litellm.litellm_core_utils.get_supported_openai_params import (
        get_supported_openai_params,
    )
    from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
    from litellm.types.utils import (
        ChatCompletionDeltaToolCall,
        Choices,
        Function,
        ModelResponse,
    )
    from litellm.utils import supports_response_schema

    from crewai.agent.core import Agent
@@ -73,7 +78,12 @@ try:
    from litellm.litellm_core_utils.get_supported_openai_params import (
        get_supported_openai_params,
    )
    from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
    from litellm.types.utils import (
        ChatCompletionDeltaToolCall,
        Choices,
        Function,
        ModelResponse,
    )
    from litellm.utils import supports_response_schema

    LITELLM_AVAILABLE = True
@@ -84,6 +94,7 @@ except ImportError:
    ContextWindowExceededError = Exception  # type: ignore
    get_supported_openai_params = None  # type: ignore
    ChatCompletionDeltaToolCall = None  # type: ignore
    Function = None  # type: ignore
    ModelResponse = None  # type: ignore
    supports_response_schema = None  # type: ignore
    CustomLogger = None  # type: ignore
@@ -610,7 +621,9 @@ class LLM(BaseLLM):
        self.callbacks = callbacks
        self.context_window_size = 0
        self.reasoning_effort = reasoning_effort
        self.additional_params = kwargs
        self.additional_params = {
            k: v for k, v in kwargs.items() if k not in ("is_litellm", "provider")
        }
        self.is_anthropic = self._is_anthropic_model(model)
        self.stream = stream
        self.interceptor = interceptor
@@ -1204,6 +1217,281 @@ class LLM(BaseLLM):
        )
        return text_response

    async def _ahandle_non_streaming_response(
        self,
        params: dict[str, Any],
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle an async non-streaming response from the LLM.

        Args:
            params: Parameters for the completion call
            callbacks: Optional list of callback functions
            available_functions: Dict of available functions
            from_task: Optional Task that invoked the LLM
            from_agent: Optional Agent that invoked the LLM
            response_model: Optional Response model

        Returns:
            str: The response text
        """
        if response_model and self.is_litellm:
            from crewai.utilities.internal_instructor import InternalInstructor

            messages = params.get("messages", [])
            if not messages:
                raise ValueError("Messages are required when using response_model")

            combined_content = "\n\n".join(
                f"{msg['role'].upper()}: {msg['content']}" for msg in messages
            )

            instructor_instance = InternalInstructor(
                content=combined_content,
                model=response_model,
                llm=self,
            )
            result = instructor_instance.to_pydantic()
            structured_response = result.model_dump_json()
            self._handle_emit_call_events(
                response=structured_response,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=params["messages"],
            )
            return structured_response

        try:
            if response_model:
                params["response_model"] = response_model
            response = await litellm.acompletion(**params)

        except ContextWindowExceededError as e:
            raise LLMContextLengthExceededError(str(e)) from e

        if response_model is not None:
            if isinstance(response, BaseModel):
                structured_response = response.model_dump_json()
                self._handle_emit_call_events(
                    response=structured_response,
                    call_type=LLMCallType.LLM_CALL,
                    from_task=from_task,
                    from_agent=from_agent,
                    messages=params["messages"],
                )
                return structured_response

        response_message = cast(Choices, cast(ModelResponse, response).choices)[
            0
        ].message
        text_response = response_message.content or ""

        if callbacks and len(callbacks) > 0:
            for callback in callbacks:
                if hasattr(callback, "log_success_event"):
                    usage_info = getattr(response, "usage", None)
                    if usage_info:
                        callback.log_success_event(
                            kwargs=params,
                            response_obj={"usage": usage_info},
                            start_time=0,
                            end_time=0,
                        )

        tool_calls = getattr(response_message, "tool_calls", [])

        if (not tool_calls or not available_functions) and text_response:
            self._handle_emit_call_events(
                response=text_response,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=params["messages"],
            )
            return text_response

        if tool_calls and not available_functions and not text_response:
            return tool_calls

        tool_result = self._handle_tool_call(
            tool_calls, available_functions, from_task, from_agent
        )
        if tool_result is not None:
            return tool_result

        self._handle_emit_call_events(
            response=text_response,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=params["messages"],
        )
        return text_response

    async def _ahandle_streaming_response(
        self,
        params: dict[str, Any],
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> Any:
        """Handle an async streaming response from the LLM.

        Args:
            params: Parameters for the completion call
            callbacks: Optional list of callback functions
            available_functions: Dict of available functions
            from_task: Optional task object
            from_agent: Optional agent object
            response_model: Optional response model

        Returns:
            str: The complete response text
        """
        full_response = ""
        chunk_count = 0
        usage_info = None

        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
            AccumulatedToolArgs
        )

        params["stream"] = True
        params["stream_options"] = {"include_usage": True}

        try:
            async for chunk in await litellm.acompletion(**params):
                chunk_count += 1
                chunk_content = None

                try:
                    choices = None
                    if isinstance(chunk, dict) and "choices" in chunk:
                        choices = chunk["choices"]
                    elif hasattr(chunk, "choices"):
                        if not isinstance(chunk.choices, type):
                            choices = chunk.choices

                    if hasattr(chunk, "usage") and chunk.usage is not None:
                        usage_info = chunk.usage

                    if choices and len(choices) > 0:
                        first_choice = choices[0]
                        delta = None

                        if isinstance(first_choice, dict):
                            delta = first_choice.get("delta", {})
                        elif hasattr(first_choice, "delta"):
                            delta = first_choice.delta

                        if delta:
                            if isinstance(delta, dict):
                                chunk_content = delta.get("content")
                            elif hasattr(delta, "content"):
                                chunk_content = delta.content

                            tool_calls: list[ChatCompletionDeltaToolCall] | None = None
                            if isinstance(delta, dict):
                                tool_calls = delta.get("tool_calls")
                            elif hasattr(delta, "tool_calls"):
                                tool_calls = delta.tool_calls

                            if tool_calls:
                                for tool_call in tool_calls:
                                    idx = tool_call.index
                                    if tool_call.function:
                                        if tool_call.function.name:
                                            accumulated_tool_args[
                                                idx
                                            ].function.name = tool_call.function.name
                                        if tool_call.function.arguments:
                                            accumulated_tool_args[
                                                idx
                                            ].function.arguments += (
                                                tool_call.function.arguments
                                            )

                except (AttributeError, KeyError, IndexError, TypeError):
                    pass

                if chunk_content:
                    full_response += chunk_content
                    crewai_event_bus.emit(
                        self,
                        event=LLMStreamChunkEvent(
                            chunk=chunk_content,
                            from_task=from_task,
                            from_agent=from_agent,
                        ),
                    )

            if callbacks and len(callbacks) > 0 and usage_info:
                for callback in callbacks:
                    if hasattr(callback, "log_success_event"):
                        callback.log_success_event(
                            kwargs=params,
                            response_obj={"usage": usage_info},
                            start_time=0,
                            end_time=0,
                        )

            if accumulated_tool_args and available_functions:
                # Convert accumulated tool args to ChatCompletionDeltaToolCall objects
                tool_calls_list: list[ChatCompletionDeltaToolCall] = [
                    ChatCompletionDeltaToolCall(
                        index=idx,
                        function=Function(
                            name=tool_arg.function.name,
                            arguments=tool_arg.function.arguments,
                        ),
                    )
                    for idx, tool_arg in accumulated_tool_args.items()
                    if tool_arg.function.name
                ]

                if tool_calls_list:
                    result = self._handle_streaming_tool_calls(
                        tool_calls=tool_calls_list,
                        accumulated_tool_args=accumulated_tool_args,
                        available_functions=available_functions,
                        from_task=from_task,
                        from_agent=from_agent,
                    )
                    if result is not None:
                        return result

            self._handle_emit_call_events(
                response=full_response,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=params.get("messages"),
            )
            return full_response

        except ContextWindowExceededError as e:
            raise LLMContextLengthExceededError(str(e)) from e
        except Exception:
            if chunk_count == 0:
                raise
            if full_response:
                self._handle_emit_call_events(
                    response=full_response,
                    call_type=LLMCallType.LLM_CALL,
                    from_task=from_task,
                    from_agent=from_agent,
                    messages=params.get("messages"),
                )
                return full_response
            raise

    def _handle_tool_call(
        self,
        tool_calls: list[Any],
@@ -1421,6 +1709,128 @@ class LLM(BaseLLM):
            )
            raise

    async def acall(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, BaseTool]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Async high-level LLM call method.

        Args:
            messages: Input messages for the LLM.
                Can be a string or list of message dictionaries.
                If string, it will be converted to a single user message.
                If list, each dict must have 'role' and 'content' keys.
            tools: Optional list of tool schemas for function calling.
                Each tool should define its name, description, and parameters.
            callbacks: Optional list of callback functions to be executed
                during and after the LLM call.
            available_functions: Optional dict mapping function names to callables
                that can be invoked by the LLM.
            from_task: Optional Task that invoked the LLM
            from_agent: Optional Agent that invoked the LLM
            response_model: Optional Model that contains a pydantic response model.

        Returns:
            Union[str, Any]: Either a text response from the LLM (str) or
                the result of a tool function call (Any).

        Raises:
            TypeError: If messages format is invalid
            ValueError: If response format is not supported
            LLMContextLengthExceededError: If input exceeds model's context limit
        """
        crewai_event_bus.emit(
            self,
            event=LLMCallStartedEvent(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
                model=self.model,
            ),
        )

        self._validate_call_params()

        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        if "o1" in self.model.lower():
            for message in messages:
                if message.get("role") == "system":
                    msg_role: Literal["assistant"] = "assistant"
                    message["role"] = msg_role

        with suppress_warnings():
            if callbacks and len(callbacks) > 0:
                self.set_callbacks(callbacks)
            try:
                params = self._prepare_completion_params(messages, tools)

                if self.stream:
                    return await self._ahandle_streaming_response(
                        params=params,
                        callbacks=callbacks,
                        available_functions=available_functions,
                        from_task=from_task,
                        from_agent=from_agent,
                        response_model=response_model,
                    )

                return await self._ahandle_non_streaming_response(
                    params=params,
                    callbacks=callbacks,
                    available_functions=available_functions,
                    from_task=from_task,
                    from_agent=from_agent,
                    response_model=response_model,
                )
            except LLMContextLengthExceededError:
                raise
            except Exception as e:
                unsupported_stop = "Unsupported parameter" in str(
                    e
                ) and "'stop'" in str(e)

                if unsupported_stop:
                    if (
                        "additional_drop_params" in self.additional_params
                        and isinstance(
                            self.additional_params["additional_drop_params"], list
                        )
                    ):
                        self.additional_params["additional_drop_params"].append("stop")
                    else:
                        self.additional_params = {"additional_drop_params": ["stop"]}

                    logging.info("Retrying LLM call without the unsupported 'stop'")

                    return await self.acall(
                        messages,
                        tools=tools,
                        callbacks=callbacks,
                        available_functions=available_functions,
                        from_task=from_task,
                        from_agent=from_agent,
                        response_model=response_model,
                    )

                crewai_event_bus.emit(
                    self,
                    event=LLMCallFailedEvent(
                        error=str(e), from_task=from_task, from_agent=from_agent
                    ),
                )
                raise

    def _handle_emit_call_events(
        self,
        response: Any,

@@ -158,6 +158,44 @@ class BaseLLM(ABC):
            RuntimeError: If the LLM request fails for other reasons.
        """

    async def acall(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, BaseTool]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Call the LLM with the given messages.

        Args:
            messages: Input messages for the LLM.
                Can be a string or list of message dictionaries.
                If string, it will be converted to a single user message.
                If list, each dict must have 'role' and 'content' keys.
            tools: Optional list of tool schemas for function calling.
                Each tool should define its name, description, and parameters.
            callbacks: Optional list of callback functions to be executed
                during and after the LLM call.
            available_functions: Optional dict mapping function names to callables
                that can be invoked by the LLM.
            from_task: Optional task caller to be used for the LLM call.
            from_agent: Optional agent caller to be used for the LLM call.
            response_model: Optional response model to be used for the LLM call.

        Returns:
            Either a text response from the LLM (str) or
            the result of a tool function call (Any).

        Raises:
            ValueError: If the messages format is invalid.
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
        """
        raise NotImplementedError
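Because the base implementation raises NotImplementedError, custom providers opt into async support by overriding `acall`. A minimal sketch; the class is hypothetical and the signatures are abbreviated with `**kwargs` for brevity:

```python
class EchoLLM(BaseLLM):
    """Hypothetical provider that just echoes the last user message."""

    def call(self, messages, **kwargs):
        return messages if isinstance(messages, str) else messages[-1]["content"]

    async def acall(self, messages, **kwargs):
        # A real provider would await its HTTP client here.
        return self.call(messages, **kwargs)
```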
|
||||
|
||||
def _convert_tools_for_interference(
|
||||
self, tools: list[dict[str, BaseTool]]
|
||||
) -> list[dict[str, BaseTool]]:
|
||||
|
||||
@@ -256,6 +256,7 @@ GeminiModels: TypeAlias = Literal[
|
||||
"gemini-2.5-flash-preview-tts",
|
||||
"gemini-2.5-pro-preview-tts",
|
||||
"gemini-2.5-computer-use-preview-10-2025",
|
||||
"gemini-2.5-pro-exp-03-25",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-001",
|
||||
"gemini-2.0-flash-exp",
|
||||
@@ -309,6 +310,7 @@ GEMINI_MODELS: list[GeminiModels] = [
|
||||
"gemini-2.5-flash-preview-tts",
|
||||
"gemini-2.5-pro-preview-tts",
|
||||
"gemini-2.5-computer-use-preview-10-2025",
|
||||
"gemini-2.5-pro-exp-03-25",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-001",
|
||||
"gemini-2.0-flash-exp",
|
||||
|
||||
@@ -9,7 +9,7 @@ from pydantic import BaseModel

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
    from crewai.llms.hooks.base import BaseInterceptor

try:
    from anthropic import Anthropic
    from anthropic import Anthropic, AsyncAnthropic
    from anthropic.types import Message
    from anthropic.types.tool_use_block import ToolUseBlock
    import httpx
@@ -84,15 +84,20 @@ class AnthropicCompletion(BaseLLM):

        self.client = Anthropic(**self._get_client_params())

        async_client_params = self._get_client_params()
        if self.interceptor:
            async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
            async_http_client = httpx.AsyncClient(transport=async_transport)
            async_client_params["http_client"] = async_http_client

        self.async_client = AsyncAnthropic(**async_client_params)

        # Store completion parameters
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.stream = stream
        self.stop_sequences = stop_sequences or []

        # Model-specific settings
        self.is_claude_3 = "claude-3" in model.lower()
        self.supports_tools = self.is_claude_3  # Claude 3+ supports tool use
        self.supports_tools = True

    @property
    def stop(self) -> list[str]:
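# --- Illustrative sketch (not part of the diff): exercising the async client
# path above from user code. The LoggingInterceptor name and the constructor
# keyword are assumptions; the diff only shows that self.interceptor, when set,
# routes the AsyncAnthropic client through an interceptor-backed httpx transport.
llm = AnthropicCompletion(
    model="claude-3-5-sonnet-latest",  # assumed model id
    interceptor=LoggingInterceptor(),  # hypothetical BaseInterceptor subclass
)
# Both self.client and self.async_client now share the same client params,
# with HTTP traffic flowing through the interceptor transports.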
@@ -213,6 +218,72 @@ class AnthropicCompletion(BaseLLM):
            )
            raise

    async def acall(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Async call to Anthropic messages API.

        Args:
            messages: Input messages for the chat completion
            tools: List of tool/function definitions
            callbacks: Callback functions (not used in native implementation)
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Returns:
            Chat completion response or tool call result
        """
        try:
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            formatted_messages, system_message = self._format_messages_for_anthropic(
                messages
            )

            completion_params = self._prepare_completion_params(
                formatted_messages, system_message, tools
            )

            if self.stream:
                return await self._ahandle_streaming_completion(
                    completion_params,
                    available_functions,
                    from_task,
                    from_agent,
                    response_model,
                )

            return await self._ahandle_completion(
                completion_params,
                available_functions,
                from_task,
                from_agent,
                response_model,
            )

        except Exception as e:
            error_msg = f"Anthropic API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise

    def _prepare_completion_params(
        self,
        messages: list[LLMMessage],
@@ -546,7 +617,7 @@ class AnthropicCompletion(BaseLLM):
                # Execute the tool
                result = self._handle_tool_execution(
                    function_name=function_name,
                    function_args=function_args,  # type: ignore
                    function_args=function_args,
                    available_functions=available_functions,
                    from_task=from_task,
                    from_agent=from_agent,
@@ -626,6 +697,275 @@ class AnthropicCompletion(BaseLLM):
                return tool_results[0]["content"]
            raise e

    async def _ahandle_completion(
        self,
        params: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming async message completion."""
        if response_model:
            structured_tool = {
                "name": "structured_output",
                "description": "Returns structured data according to the schema",
                "input_schema": response_model.model_json_schema(),
            }

            params["tools"] = [structured_tool]
            params["tool_choice"] = {"type": "tool", "name": "structured_output"}

        try:
            response: Message = await self.async_client.messages.create(**params)

        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e
            raise e from e

        usage = self._extract_anthropic_token_usage(response)
        self._track_token_usage_internal(usage)

        if response_model and response.content:
            tool_uses = [
                block for block in response.content if isinstance(block, ToolUseBlock)
            ]
            if tool_uses and tool_uses[0].name == "structured_output":
                structured_data = tool_uses[0].input
                structured_json = json.dumps(structured_data)

                self._emit_call_completed_event(
                    response=structured_json,
                    call_type=LLMCallType.LLM_CALL,
                    from_task=from_task,
                    from_agent=from_agent,
                    messages=params["messages"],
                )

                return structured_json

        if response.content and available_functions:
            tool_uses = [
                block for block in response.content if isinstance(block, ToolUseBlock)
            ]

            if tool_uses:
                return await self._ahandle_tool_use_conversation(
                    response,
                    tool_uses,
                    params,
                    available_functions,
                    from_task,
                    from_agent,
                )

        content = ""
        if response.content:
            for content_block in response.content:
                if hasattr(content_block, "text"):
                    content += content_block.text

        content = self._apply_stop_words(content)

        self._emit_call_completed_event(
            response=content,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=params["messages"],
        )

        if usage.get("total_tokens", 0) > 0:
            logging.info(f"Anthropic API usage: {usage}")

        return content

    async def _ahandle_streaming_completion(
        self,
        params: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str:
        """Handle async streaming message completion."""
        if response_model:
            structured_tool = {
                "name": "structured_output",
                "description": "Returns structured data according to the schema",
                "input_schema": response_model.model_json_schema(),
            }

            params["tools"] = [structured_tool]
            params["tool_choice"] = {"type": "tool", "name": "structured_output"}

        full_response = ""

        stream_params = {k: v for k, v in params.items() if k != "stream"}

        async with self.async_client.messages.stream(**stream_params) as stream:
            async for event in stream:
                if hasattr(event, "delta") and hasattr(event.delta, "text"):
                    text_delta = event.delta.text
                    full_response += text_delta
                    self._emit_stream_chunk_event(
                        chunk=text_delta,
                        from_task=from_task,
                        from_agent=from_agent,
                    )

            final_message: Message = await stream.get_final_message()

        usage = self._extract_anthropic_token_usage(final_message)
        self._track_token_usage_internal(usage)

        if response_model and final_message.content:
            tool_uses = [
                block
                for block in final_message.content
                if isinstance(block, ToolUseBlock)
            ]
            if tool_uses and tool_uses[0].name == "structured_output":
                structured_data = tool_uses[0].input
                structured_json = json.dumps(structured_data)

                self._emit_call_completed_event(
                    response=structured_json,
                    call_type=LLMCallType.LLM_CALL,
                    from_task=from_task,
                    from_agent=from_agent,
                    messages=params["messages"],
                )

                return structured_json

        if final_message.content and available_functions:
            tool_uses = [
                block
                for block in final_message.content
                if isinstance(block, ToolUseBlock)
            ]

            if tool_uses:
                return await self._ahandle_tool_use_conversation(
                    final_message,
                    tool_uses,
                    params,
                    available_functions,
                    from_task,
                    from_agent,
                )

        full_response = self._apply_stop_words(full_response)

        self._emit_call_completed_event(
            response=full_response,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=params["messages"],
        )

        return full_response

    async def _ahandle_tool_use_conversation(
        self,
        initial_response: Message,
        tool_uses: list[ToolUseBlock],
        params: dict[str, Any],
        available_functions: dict[str, Any],
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle the complete async tool use conversation flow.

        This implements the proper Anthropic tool use pattern:
        1. Claude requests tool use
        2. We execute the tools
        3. We send tool results back to Claude
        4. Claude processes results and generates final response
        """
        tool_results = []

        for tool_use in tool_uses:
            function_name = tool_use.name
            function_args = tool_use.input

            result = self._handle_tool_execution(
                function_name=function_name,
                function_args=function_args,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            tool_result = {
                "type": "tool_result",
                "tool_use_id": tool_use.id,
                "content": str(result)
                if result is not None
                else "Tool execution completed",
            }
            tool_results.append(tool_result)

        follow_up_params = params.copy()

        assistant_message = {"role": "assistant", "content": initial_response.content}

        user_message = {"role": "user", "content": tool_results}

        follow_up_params["messages"] = params["messages"] + [
            assistant_message,
            user_message,
        ]

        try:
            final_response: Message = await self.async_client.messages.create(
                **follow_up_params
            )

            follow_up_usage = self._extract_anthropic_token_usage(final_response)
            self._track_token_usage_internal(follow_up_usage)

            final_content = ""
            if final_response.content:
                for content_block in final_response.content:
                    if hasattr(content_block, "text"):
                        final_content += content_block.text

            final_content = self._apply_stop_words(final_content)

            self._emit_call_completed_event(
                response=final_content,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=follow_up_params["messages"],
            )

            total_usage = {
                "input_tokens": follow_up_usage.get("input_tokens", 0),
                "output_tokens": follow_up_usage.get("output_tokens", 0),
                "total_tokens": follow_up_usage.get("total_tokens", 0),
            }

            if total_usage.get("total_tokens", 0) > 0:
                logging.info(f"Anthropic API tool conversation usage: {total_usage}")

            return final_content

        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded in tool follow-up: {e}")
                raise LLMContextLengthExceededError(str(e)) from e

            logging.error(f"Tool follow-up conversation failed: {e}")
            if tool_results:
                return tool_results[0]["content"]
            raise e

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        return self.supports_tools
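# --- Illustrative sketch (not part of the diff): structured output through the
# forced "structured_output" tool, as implemented in _ahandle_completion above.
# The Person model and prompt are hypothetical; acall returns a JSON string.
import asyncio

from pydantic import BaseModel


class Person(BaseModel):
    name: str
    age: int


async def main() -> None:
    llm = AnthropicCompletion(model="claude-3-5-sonnet-latest")  # assumed model id
    raw = await llm.acall("Extract: Ada Lovelace, 36", response_model=Person)
    person = Person.model_validate_json(raw)
    print(person)


asyncio.run(main())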
@@ -1,11 +1,14 @@
from __future__ import annotations

from collections.abc import Callable
import json
import logging
import os
import time
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel
from typing_extensions import Self

from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
@@ -24,6 +27,9 @@ try:
    from azure.ai.inference import (
        ChatCompletionsClient,
    )
    from azure.ai.inference.aio import (
        ChatCompletionsClient as AsyncChatCompletionsClient,
    )
    from azure.ai.inference.models import (
        ChatCompletions,
        ChatCompletionsToolCall,
@@ -31,7 +37,9 @@ try:
        StreamingChatCompletionsUpdate,
    )
    from azure.core.credentials import (
        AccessToken,
        AzureKeyCredential,
        TokenCredential,
    )
    from azure.core.exceptions import (
        HttpResponseError,
@@ -46,6 +54,41 @@ except ImportError:
    ) from None


class _TokenProviderCredential(TokenCredential):
    """Wrapper class to convert an azure_ad_token_provider callable into a TokenCredential.

    This allows users to pass a token provider function (like the one returned by
    azure.identity.get_bearer_token_provider) to the Azure AI Inference client.
    """

    def __init__(self, provider: Callable[..., Any]):
        """Initialize with a token provider callable.

        Args:
            provider: A callable that returns an access token. This is typically
                the result of azure.identity.get_bearer_token_provider().
        """
        self._provider = provider

    def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
        """Get an access token from the provider.

        Args:
            *scopes: The scopes for the token (ignored, as the provider handles this).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            An AccessToken instance.
        """
        raw = self._provider()

        if isinstance(raw, AccessToken):
            return raw

        # If it's a bare string, wrap it with a default expiry of 1 hour
        return AccessToken(str(raw), int(time.time()) + 3600)

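# --- Illustrative sketch (not part of the diff): passing an Azure AD token
# provider instead of an API key. azure.identity and get_bearer_token_provider
# are the real helpers named in the docstring above; the model and endpoint
# values are placeholders.
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
llm = AzureCompletion(
    model="gpt-4o",  # placeholder deployment/model name
    endpoint="https://example.openai.azure.com",  # placeholder endpoint
    azure_ad_token_provider=token_provider,
)
# Internally the provider is wrapped by _TokenProviderCredential and handed to
# ChatCompletionsClient as its credential.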
class AzureCompletion(BaseLLM):
    """Azure AI Inference native completion implementation.

@@ -69,6 +112,8 @@ class AzureCompletion(BaseLLM):
        stop: list[str] | None = None,
        stream: bool = False,
        interceptor: BaseInterceptor[Any, Any] | None = None,
        azure_ad_token_provider: Callable[..., Any] | None = None,
        credential: TokenCredential | None = None,
        **kwargs: Any,
    ):
        """Initialize Azure AI Inference chat completion client.
@@ -88,6 +133,13 @@ class AzureCompletion(BaseLLM):
            stop: Stop sequences
            stream: Enable streaming responses
            interceptor: HTTP interceptor (not yet supported for Azure).
            azure_ad_token_provider: A callable that returns an Azure AD token.
                This is typically the result of azure.identity.get_bearer_token_provider().
                Use this for Azure AD token-based authentication instead of API keys.
            credential: An Azure TokenCredential instance for authentication.
                This can be any credential from azure.identity (e.g., DefaultAzureCredential,
                ManagedIdentityCredential). Takes precedence over azure_ad_token_provider
                and api_key.
            **kwargs: Additional parameters
        """
        if interceptor is not None:
@@ -103,6 +155,7 @@ class AzureCompletion(BaseLLM):
        self.api_key = api_key or os.getenv("AZURE_API_KEY")
        self.endpoint = (
            endpoint
            or kwargs.get("base_url")
            or os.getenv("AZURE_ENDPOINT")
            or os.getenv("AZURE_OPENAI_ENDPOINT")
            or os.getenv("AZURE_API_BASE")
@@ -111,29 +164,45 @@ class AzureCompletion(BaseLLM):
        self.timeout = timeout
        self.max_retries = max_retries

        if not self.api_key:
            raise ValueError(
                "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
            )
        if not self.endpoint:
            raise ValueError(
                "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
            )

        # Determine the credential to use (priority: credential > azure_ad_token_provider > api_key)
        chosen_credential: TokenCredential | AzureKeyCredential | None = None

        if credential is not None:
            chosen_credential = credential
        elif azure_ad_token_provider is not None:
            chosen_credential = _TokenProviderCredential(azure_ad_token_provider)
        elif self.api_key:
            chosen_credential = AzureKeyCredential(self.api_key)

        if chosen_credential is None:
            raise ValueError(
                "Azure authentication is required. Provide one of: "
                "api_key (or set AZURE_API_KEY environment variable), "
                "azure_ad_token_provider (callable from azure.identity.get_bearer_token_provider), "
                "or credential (TokenCredential instance from azure.identity)."
            )

        # Validate and potentially fix Azure OpenAI endpoint URL
        self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)

        # Build client kwargs
        client_kwargs = {
        client_kwargs: dict[str, Any] = {
            "endpoint": self.endpoint,
            "credential": AzureKeyCredential(self.api_key),
            "credential": chosen_credential,
        }

        # Add api_version if specified (primarily for Azure OpenAI endpoints)
        if self.api_version:
            client_kwargs["api_version"] = self.api_version

        self.client = ChatCompletionsClient(**client_kwargs)  # type: ignore[arg-type]
        self.client = ChatCompletionsClient(**client_kwargs)

        self.async_client = AsyncChatCompletionsClient(**client_kwargs)

        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
@@ -258,6 +327,88 @@ class AzureCompletion(BaseLLM):
            )
            raise

    async def acall(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, BaseTool]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Call Azure AI Inference chat completions API asynchronously.

        Args:
            messages: Input messages for the chat completion
            tools: List of tool/function definitions
            callbacks: Callback functions (not used in native implementation)
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call
            response_model: Pydantic model for structured output

        Returns:
            Chat completion response or tool call result
        """
        try:
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            formatted_messages = self._format_messages_for_azure(messages)

            completion_params = self._prepare_completion_params(
                formatted_messages, tools, response_model
            )

            if self.stream:
                return await self._ahandle_streaming_completion(
                    completion_params,
                    available_functions,
                    from_task,
                    from_agent,
                    response_model,
                )

            return await self._ahandle_completion(
                completion_params,
                available_functions,
                from_task,
                from_agent,
                response_model,
            )

        except HttpResponseError as e:
            if e.status_code == 401:
                error_msg = "Azure authentication failed. Check your API key."
            elif e.status_code == 404:
                error_msg = (
                    f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
                )
            elif e.status_code == 429:
                error_msg = "Azure API rate limit exceeded. Please retry later."
            else:
                error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"

            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise
        except Exception as e:
            error_msg = f"Azure API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise
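# --- Illustrative sketch (not part of the diff): driving the async Azure path
# end to end. The model, key, and endpoint values are placeholders.
import asyncio


async def run() -> None:
    llm = AzureCompletion(
        model="gpt-4o",  # placeholder
        api_key="...",  # placeholder; AZURE_API_KEY also works
        endpoint="https://example.openai.azure.com",  # placeholder
    )
    # HttpResponseError codes (401/404/429) are mapped to friendly messages
    # by the except block above before being re-raised.
    print(await llm.acall("Say hello"))


asyncio.run(run())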
    def _prepare_completion_params(
        self,
        messages: list[LLMMessage],
@@ -556,6 +707,170 @@ class AzureCompletion(BaseLLM):

        return full_response

    async def _ahandle_completion(
        self,
        params: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming chat completion asynchronously."""
        try:
            response: ChatCompletions = await self.async_client.complete(**params)

            if not response.choices:
                raise ValueError("No choices returned from Azure API")

            choice = response.choices[0]
            message = choice.message

            usage = self._extract_azure_token_usage(response)
            self._track_token_usage_internal(usage)

            if response_model and self.is_openai_model:
                content = message.content or ""
                try:
                    structured_data = response_model.model_validate_json(content)
                    structured_json = structured_data.model_dump_json()

                    self._emit_call_completed_event(
                        response=structured_json,
                        call_type=LLMCallType.LLM_CALL,
                        from_task=from_task,
                        from_agent=from_agent,
                        messages=params["messages"],
                    )

                    return structured_json
                except Exception as e:
                    error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
                    logging.error(error_msg)
                    raise ValueError(error_msg) from e

            if message.tool_calls and available_functions:
                tool_call = message.tool_calls[0]  # Handle first tool call
                if isinstance(tool_call, ChatCompletionsToolCall):
                    function_name = tool_call.function.name

                    try:
                        function_args = json.loads(tool_call.function.arguments)
                    except json.JSONDecodeError as e:
                        logging.error(f"Failed to parse tool arguments: {e}")
                        function_args = {}

                    result = self._handle_tool_execution(
                        function_name=function_name,
                        function_args=function_args,
                        available_functions=available_functions,
                        from_task=from_task,
                        from_agent=from_agent,
                    )

                    if result is not None:
                        return result

            content = message.content or ""

            content = self._apply_stop_words(content)

            self._emit_call_completed_event(
                response=content,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=params["messages"],
            )

        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e

            error_msg = f"Azure API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise e

        return content

    async def _ahandle_streaming_completion(
        self,
        params: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str:
        """Handle streaming chat completion asynchronously."""
        full_response = ""
        tool_calls = {}

        stream = await self.async_client.complete(**params)
        async for update in stream:
            if isinstance(update, StreamingChatCompletionsUpdate):
                if update.choices:
                    choice = update.choices[0]
                    if choice.delta and choice.delta.content:
                        content_delta = choice.delta.content
                        full_response += content_delta
                        self._emit_stream_chunk_event(
                            chunk=content_delta,
                            from_task=from_task,
                            from_agent=from_agent,
                        )

                    if choice.delta and choice.delta.tool_calls:
                        for tool_call in choice.delta.tool_calls:
                            call_id = tool_call.id or "default"
                            if call_id not in tool_calls:
                                tool_calls[call_id] = {
                                    "name": "",
                                    "arguments": "",
                                }

                            if tool_call.function and tool_call.function.name:
                                tool_calls[call_id]["name"] = tool_call.function.name
                            if tool_call.function and tool_call.function.arguments:
                                tool_calls[call_id]["arguments"] += (
                                    tool_call.function.arguments
                                )

        if tool_calls and available_functions:
            for call_data in tool_calls.values():
                function_name = call_data["name"]

                try:
                    function_args = json.loads(call_data["arguments"])
                except json.JSONDecodeError as e:
                    logging.error(f"Failed to parse streamed tool arguments: {e}")
                    continue

                result = self._handle_tool_execution(
                    function_name=function_name,
                    function_args=function_args,
                    available_functions=available_functions,
                    from_task=from_task,
                    from_agent=from_agent,
                )

                if result is not None:
                    return result

        full_response = self._apply_stop_words(full_response)

        self._emit_call_completed_event(
            response=full_response,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=params["messages"],
        )

        return full_response

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        # Azure OpenAI models support function calling
@@ -609,3 +924,20 @@ class AzureCompletion(BaseLLM):
            "total_tokens": getattr(usage, "total_tokens", 0),
        }
        return {"total_tokens": 0}

    async def aclose(self) -> None:
        """Close the async client and clean up resources.

        This ensures proper cleanup of the underlying aiohttp session
        to avoid unclosed connector warnings.
        """
        if hasattr(self.async_client, "close"):
            await self.async_client.close()

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit."""
        await self.aclose()
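# --- Illustrative sketch (not part of the diff): the context-manager protocol
# added above guarantees aclose() runs even on errors. Values are placeholders.
import asyncio


async def run() -> None:
    async with AzureCompletion(
        model="gpt-4o",  # placeholder
        api_key="...",  # placeholder
        endpoint="https://example.openai.azure.com",  # placeholder
    ) as llm:
        print(await llm.acall("Ping"))
    # On exit, __aexit__ awaited aclose(), closing the aio ChatCompletionsClient.


asyncio.run(run())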
@@ -1,6 +1,8 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from contextlib import AsyncExitStack
import json
import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict, cast
@@ -42,6 +44,16 @@ except ImportError:
        'AWS Bedrock native provider not available, to install: uv add "crewai[bedrock]"'
    ) from None

try:
    from aiobotocore.session import (  # type: ignore[import-untyped]
        get_session as get_aiobotocore_session,
    )

    AIOBOTOCORE_AVAILABLE = True
except ImportError:
    AIOBOTOCORE_AVAILABLE = False
    get_aiobotocore_session = None


if TYPE_CHECKING:

@@ -221,6 +233,15 @@ class BedrockCompletion(BaseLLM):
        self.client = session.client("bedrock-runtime", config=config)
        self.region_name = region_name

        self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        self.aws_secret_access_key = aws_secret_access_key or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
        self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
        self._async_client_initialized = False

        # Store completion parameters
        self.max_tokens = max_tokens
        self.top_p = top_p
@@ -354,6 +375,110 @@ class BedrockCompletion(BaseLLM):
            )
            raise

    async def acall(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[Any, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Async call to AWS Bedrock Converse API.

        Args:
            messages: Input messages as string or list of message dicts.
            tools: Optional list of tool definitions.
            callbacks: Optional list of callback handlers.
            available_functions: Optional dict mapping function names to callables.
            from_task: Optional task context for events.
            from_agent: Optional agent context for events.
            response_model: Optional Pydantic model for structured output.

        Returns:
            Generated text response or structured output.

        Raises:
            NotImplementedError: If aiobotocore is not installed.
            LLMContextLengthExceededError: If context window is exceeded.
        """
        if not AIOBOTOCORE_AVAILABLE:
            raise NotImplementedError(
                "Async support for AWS Bedrock requires aiobotocore. "
                'Install with: uv add "crewai[bedrock-async]"'
            )

        try:
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            formatted_messages, system_message = self._format_messages_for_converse(
                messages  # type: ignore[arg-type]
            )

            body: BedrockConverseRequestBody = {
                "inferenceConfig": self._get_inference_config(),
            }

            if system_message:
                body["system"] = cast(
                    "list[SystemContentBlockTypeDef]",
                    cast(object, [{"text": system_message}]),
                )

            if tools:
                tool_config: ToolConfigurationTypeDef = {
                    "tools": cast(
                        "Sequence[ToolTypeDef]",
                        cast(object, self._format_tools_for_converse(tools)),
                    )
                }
                body["toolConfig"] = tool_config

            if self.guardrail_config:
                guardrail_config: GuardrailConfigurationTypeDef = cast(
                    "GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
                )
                body["guardrailConfig"] = guardrail_config

            if self.additional_model_request_fields:
                body["additionalModelRequestFields"] = (
                    self.additional_model_request_fields
                )

            if self.additional_model_response_field_paths:
                body["additionalModelResponseFieldPaths"] = (
                    self.additional_model_response_field_paths
                )

            if self.stream:
                return await self._ahandle_streaming_converse(
                    formatted_messages, body, available_functions, from_task, from_agent
                )

            return await self._ahandle_converse(
                formatted_messages, body, available_functions, from_task, from_agent
            )

        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e

            error_msg = f"AWS Bedrock API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise
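# --- Illustrative sketch (not part of the diff): the async Bedrock path needs
# the optional aiobotocore extra; without it, acall raises NotImplementedError
# up front. The model id and constructor keyword are assumptions.
import asyncio


async def run() -> None:
    llm = BedrockCompletion(model="anthropic.claude-3-5-sonnet-20240620-v1:0")  # placeholder
    try:
        print(await llm.acall("Hello"))
    except NotImplementedError:
        print('Install async support: uv add "crewai[bedrock-async]"')


asyncio.run(run())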
    def _handle_converse(
        self,
        messages: list[dict[str, Any]],
@@ -565,6 +690,341 @@ class BedrockCompletion(BaseLLM):
                        role = event["messageStart"].get("role")
                        logging.debug(f"Streaming message started with role: {role}")

                    elif "contentBlockStart" in event:
                        start = event["contentBlockStart"].get("start", {})
                        if "toolUse" in start:
                            current_tool_use = start["toolUse"]
                            tool_use_id = current_tool_use.get("toolUseId")
                            logging.debug(
                                f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
                            )

                    elif "contentBlockDelta" in event:
                        delta = event["contentBlockDelta"]["delta"]
                        if "text" in delta:
                            text_chunk = delta["text"]
                            logging.debug(f"Streaming text chunk: {text_chunk[:50]}...")
                            full_response += text_chunk
                            self._emit_stream_chunk_event(
                                chunk=text_chunk,
                                from_task=from_task,
                                from_agent=from_agent,
                            )
                        elif "toolUse" in delta and current_tool_use:
                            tool_input = delta["toolUse"].get("input", "")
                            if tool_input:
                                logging.debug(f"Tool input delta: {tool_input}")
                    elif "contentBlockStop" in event:
                        logging.debug("Content block stopped in stream")
                        if current_tool_use and available_functions:
                            function_name = current_tool_use["name"]
                            function_args = cast(
                                dict[str, Any], current_tool_use.get("input", {})
                            )
                            tool_result = self._handle_tool_execution(
                                function_name=function_name,
                                function_args=function_args,
                                available_functions=available_functions,
                                from_task=from_task,
                                from_agent=from_agent,
                            )
                            if tool_result is not None and tool_use_id:
                                messages.append(
                                    {
                                        "role": "assistant",
                                        "content": [{"toolUse": current_tool_use}],
                                    }
                                )
                                messages.append(
                                    {
                                        "role": "user",
                                        "content": [
                                            {
                                                "toolResult": {
                                                    "toolUseId": tool_use_id,
                                                    "content": [
                                                        {"text": str(tool_result)}
                                                    ],
                                                }
                                            }
                                        ],
                                    }
                                )
                                return self._handle_converse(
                                    messages,
                                    body,
                                    available_functions,
                                    from_task,
                                    from_agent,
                                )
                        current_tool_use = None
                        tool_use_id = None
                    elif "messageStop" in event:
                        stop_reason = event["messageStop"].get("stopReason")
                        logging.debug(f"Streaming message stopped: {stop_reason}")
                        if stop_reason == "max_tokens":
                            logging.warning(
                                "Streaming response truncated due to max_tokens"
                            )
                        elif stop_reason == "content_filtered":
                            logging.warning(
                                "Streaming response filtered due to content policy"
                            )
                        break
                    elif "metadata" in event:
                        metadata = event["metadata"]
                        if "usage" in metadata:
                            usage_metrics = metadata["usage"]
                            self._track_token_usage_internal(usage_metrics)
                            logging.debug(f"Token usage: {usage_metrics}")
                        if "trace" in metadata:
                            logging.debug(
                                f"Trace information available: {metadata['trace']}"
                            )

        except ClientError as e:
            error_msg = self._handle_client_error(e)
            raise RuntimeError(error_msg) from e
        except BotoCoreError as e:
            error_msg = f"Bedrock streaming connection error: {e}"
            logging.error(error_msg)
            raise ConnectionError(error_msg) from e

        full_response = self._apply_stop_words(full_response)

        if not full_response or full_response.strip() == "":
            logging.warning("Bedrock streaming returned empty content, using fallback")
            full_response = (
                "I apologize, but I couldn't generate a response. Please try again."
            )

        self._emit_call_completed_event(
            response=full_response,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=messages,
        )

        return full_response

    async def _ensure_async_client(self) -> Any:
        """Ensure async client is initialized and return it."""
        if not self._async_client_initialized and get_aiobotocore_session:
            if self._async_exit_stack is None:
                raise RuntimeError(
                    "Async exit stack not initialized - aiobotocore not available"
                )
            session = get_aiobotocore_session()
            client = await self._async_exit_stack.enter_async_context(
                session.create_client(
                    "bedrock-runtime",
                    region_name=self.region_name,
                    aws_access_key_id=self.aws_access_key_id,
                    aws_secret_access_key=self.aws_secret_access_key,
                    aws_session_token=self.aws_session_token,
                )
            )
            self._async_client = client
            self._async_client_initialized = True
        return self._async_client
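# --- Illustrative sketch (not part of the diff): the lazy-client pattern used
# by _ensure_async_client above, reduced to its essentials. aiobotocore clients
# are async context managers, so an AsyncExitStack keeps one open across many
# calls and closes it later in one place.
from contextlib import AsyncExitStack

stack = AsyncExitStack()
# client = await stack.enter_async_context(session.create_client("bedrock-runtime"))
# ... reuse client across many converse() calls ...
# await stack.aclose()  # closes every entered context in reverse order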
    async def _ahandle_converse(
        self,
        messages: list[dict[str, Any]],
        body: BedrockConverseRequestBody,
        available_functions: Mapping[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle async non-streaming converse API call."""
        try:
            if not messages:
                raise ValueError("Messages cannot be empty")

            for i, msg in enumerate(messages):
                if (
                    not isinstance(msg, dict)
                    or "role" not in msg
                    or "content" not in msg
                ):
                    raise ValueError(f"Invalid message format at index {i}")

            async_client = await self._ensure_async_client()
            response = await async_client.converse(
                modelId=self.model_id,
                messages=cast(
                    "Sequence[MessageTypeDef | MessageOutputTypeDef]",
                    cast(object, messages),
                ),
                **body,
            )

            if "usage" in response:
                self._track_token_usage_internal(response["usage"])

            stop_reason = response.get("stopReason")
            if stop_reason:
                logging.debug(f"Response stop reason: {stop_reason}")
                if stop_reason == "max_tokens":
                    logging.warning("Response truncated due to max_tokens limit")
                elif stop_reason == "content_filtered":
                    logging.warning("Response was filtered due to content policy")

            output = response.get("output", {})
            message = output.get("message", {})
            content = message.get("content", [])

            if not content:
                logging.warning("No content in Bedrock response")
                return (
                    "I apologize, but I received an empty response. Please try again."
                )

            text_content = ""

            for content_block in content:
                if "text" in content_block:
                    text_content += content_block["text"]

                elif "toolUse" in content_block and available_functions:
                    tool_use_block = content_block["toolUse"]
                    tool_use_id = tool_use_block.get("toolUseId")
                    function_name = tool_use_block["name"]
                    function_args = tool_use_block.get("input", {})

                    logging.debug(
                        f"Tool use requested: {function_name} with ID {tool_use_id}"
                    )

                    tool_result = self._handle_tool_execution(
                        function_name=function_name,
                        function_args=function_args,
                        available_functions=dict(available_functions),
                        from_task=from_task,
                        from_agent=from_agent,
                    )

                    if tool_result is not None:
                        messages.append(
                            {
                                "role": "assistant",
                                "content": [{"toolUse": tool_use_block}],
                            }
                        )

                        messages.append(
                            {
                                "role": "user",
                                "content": [
                                    {
                                        "toolResult": {
                                            "toolUseId": tool_use_id,
                                            "content": [{"text": str(tool_result)}],
                                        }
                                    }
                                ],
                            }
                        )

                        return await self._ahandle_converse(
                            messages, body, available_functions, from_task, from_agent
                        )

            text_content = self._apply_stop_words(text_content)

            if not text_content or text_content.strip() == "":
                logging.warning("Extracted empty text content from Bedrock response")
                text_content = "I apologize, but I couldn't generate a proper response. Please try again."

            self._emit_call_completed_event(
                response=text_content,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=messages,
            )

            return text_content

        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            error_msg = e.response.get("Error", {}).get("Message", str(e))
            logging.error(f"AWS Bedrock ClientError ({error_code}): {error_msg}")

            if error_code == "ValidationException":
                if "last turn" in error_msg and "user message" in error_msg:
                    raise ValueError(
                        f"Conversation format error: {error_msg}. Check message alternation."
                    ) from e
                raise ValueError(f"Request validation failed: {error_msg}") from e
            if error_code == "AccessDeniedException":
                raise PermissionError(
                    f"Access denied to model {self.model_id}: {error_msg}"
                ) from e
            if error_code == "ResourceNotFoundException":
                raise ValueError(f"Model {self.model_id} not found: {error_msg}") from e
            if error_code == "ThrottlingException":
                raise RuntimeError(
                    f"API throttled, please retry later: {error_msg}"
                ) from e
            if error_code == "ModelTimeoutException":
                raise TimeoutError(f"Model request timed out: {error_msg}") from e
            if error_code == "ServiceQuotaExceededException":
                raise RuntimeError(f"Service quota exceeded: {error_msg}") from e
            if error_code == "ModelNotReadyException":
                raise RuntimeError(
                    f"Model {self.model_id} not ready: {error_msg}"
                ) from e
            if error_code == "ModelErrorException":
                raise RuntimeError(f"Model error: {error_msg}") from e
            if error_code == "InternalServerException":
                raise RuntimeError(f"Internal server error: {error_msg}") from e
            if error_code == "ServiceUnavailableException":
                raise RuntimeError(f"Service unavailable: {error_msg}") from e

            raise RuntimeError(f"Bedrock API error ({error_code}): {error_msg}") from e

        except BotoCoreError as e:
            error_msg = f"Bedrock connection error: {e}"
            logging.error(error_msg)
            raise ConnectionError(error_msg) from e
        except Exception as e:
            error_msg = f"Unexpected error in Bedrock converse call: {e}"
            logging.error(error_msg)
            raise RuntimeError(error_msg) from e

    async def _ahandle_streaming_converse(
        self,
        messages: list[dict[str, Any]],
        body: BedrockConverseRequestBody,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Handle async streaming converse API call."""
        full_response = ""
        current_tool_use = None
        tool_use_id = None

        try:
            async_client = await self._ensure_async_client()
            response = await async_client.converse_stream(
                modelId=self.model_id,
                messages=cast(
                    "Sequence[MessageTypeDef | MessageOutputTypeDef]",
                    cast(object, messages),
                ),
                **body,
            )

            stream = response.get("stream")
            if stream:
                async for event in stream:
                    if "messageStart" in event:
                        role = event["messageStart"].get("role")
                        logging.debug(f"Streaming message started with role: {role}")

                    elif "contentBlockStart" in event:
                        start = event["contentBlockStart"].get("start", {})
                        if "toolUse" in start:
@@ -590,17 +1050,14 @@ class BedrockCompletion(BaseLLM):
                            if tool_input:
                                logging.debug(f"Tool input delta: {tool_input}")

                    # Content block stop - end of a content block
                    elif "contentBlockStop" in event:
                        logging.debug("Content block stopped in stream")
                        # If we were accumulating a tool use, it's now complete
                        if current_tool_use and available_functions:
                            function_name = current_tool_use["name"]
                            function_args = cast(
                                dict[str, Any], current_tool_use.get("input", {})
                            )

                            # Execute tool
                            tool_result = self._handle_tool_execution(
                                function_name=function_name,
                                function_args=function_args,
@@ -610,7 +1067,6 @@ class BedrockCompletion(BaseLLM):
                            )

                            if tool_result is not None and tool_use_id:
                                # Continue conversation with tool result
                                messages.append(
                                    {
                                        "role": "assistant",
@@ -634,8 +1090,7 @@ class BedrockCompletion(BaseLLM):
                                    }
                                )

                                # Recursive call - note this switches to non-streaming
                                return self._handle_converse(
                                return await self._ahandle_converse(
                                    messages,
                                    body,
                                    available_functions,
@@ -643,10 +1098,9 @@ class BedrockCompletion(BaseLLM):
                                    from_agent,
                                )

                            current_tool_use = None
                            tool_use_id = None
                        current_tool_use = None
                        tool_use_id = None

                    # Message stop - end of entire message
                    elif "messageStop" in event:
                        stop_reason = event["messageStop"].get("stopReason")
                        logging.debug(f"Streaming message stopped: {stop_reason}")
@@ -660,7 +1114,6 @@ class BedrockCompletion(BaseLLM):
                            )
                        break

                    # Metadata - contains usage information and trace details
                    elif "metadata" in event:
                        metadata = event["metadata"]
                        if "usage" in metadata:
@@ -680,17 +1133,14 @@ class BedrockCompletion(BaseLLM):
            logging.error(error_msg)
            raise ConnectionError(error_msg) from e

        # Apply stop words to full response
        full_response = self._apply_stop_words(full_response)

        # Ensure we don't return empty content
        if not full_response or full_response.strip() == "":
            logging.warning("Bedrock streaming returned empty content, using fallback")
            full_response = (
                "I apologize, but I couldn't generate a response. Please try again."
            )

        # Emit completion event
        self._emit_call_completed_event(
            response=full_response,
            call_type=LLMCallType.LLM_CALL,
@@ -1,13 +1,14 @@
from __future__ import annotations

import logging
import os
import re
from typing import Any, cast
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.base import BaseInterceptor
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
@@ -15,10 +16,15 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
    from crewai.llms.hooks.base import BaseInterceptor


try:
    from google import genai  # type: ignore[import-untyped]
    from google.genai import types  # type: ignore[import-untyped]
    from google.genai.errors import APIError  # type: ignore[import-untyped]
    from google import genai
    from google.genai import types
    from google.genai.errors import APIError
    from google.genai.types import GenerateContentResponse, Schema
except ImportError:
    raise ImportError(
        'Google Gen AI native provider not available, to install: uv add "crewai[google-genai]"'
@@ -102,7 +108,9 @@ class GeminiCompletion(BaseLLM):

        # Model-specific settings
        version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
        self.supports_tools = bool(version_match and float(version_match.group(1)) >= 1.5)
        self.supports_tools = bool(
            version_match and float(version_match.group(1)) >= 1.5
        )

    @property
    def stop(self) -> list[str]:
@@ -128,7 +136,7 @@ class GeminiCompletion(BaseLLM):
        else:
            self.stop_sequences = []

    def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:  # type: ignore[no-any-unimported]
    def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
        """Initialize the Google Gen AI client with proper parameter handling.

        Args:
@@ -277,7 +285,84 @@ class GeminiCompletion(BaseLLM):
            )
            raise

    def _prepare_generation_config(  # type: ignore[no-any-unimported]
    async def acall(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Async call to Google Gemini generate content API.

        Args:
            messages: Input messages for the chat completion
            tools: List of tool/function definitions
            callbacks: Callback functions (not used as token counts are handled by the response)
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Returns:
            Chat completion response or tool call result
        """
        try:
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )
            self.tools = tools

            formatted_content, system_instruction = self._format_messages_for_gemini(
                messages
            )

            config = self._prepare_generation_config(
                system_instruction, tools, response_model
            )

            if self.stream:
                return await self._ahandle_streaming_completion(
                    formatted_content,
                    config,
                    available_functions,
                    from_task,
                    from_agent,
                    response_model,
                )

            return await self._ahandle_completion(
                formatted_content,
                system_instruction,
                config,
                available_functions,
                from_task,
                from_agent,
                response_model,
            )

        except APIError as e:
            error_msg = f"Google Gemini API error: {e.code} - {e.message}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise
        except Exception as e:
            error_msg = f"Google Gemini API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise
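# --- Illustrative sketch (not part of the diff): streaming through the async
# Gemini path added above. The model name is a placeholder and the constructor
# keywords are assumed from the attributes the class uses (self.stream);
# GEMINI_API_KEY is expected in the environment.
import asyncio


async def run() -> None:
    llm = GeminiCompletion(model="gemini-2.0-flash", stream=True)  # assumed args
    text = await llm.acall("Write one sentence about oceans.")
    print(text)  # chunks were emitted as stream events while accumulating


asyncio.run(run())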
    def _prepare_generation_config(
        self,
        system_instruction: str | None = None,
        tools: list[dict[str, Any]] | None = None,
@@ -294,7 +379,7 @@ class GeminiCompletion(BaseLLM):
            GenerateContentConfig object for Gemini API
        """
        self.tools = tools
        config_params = {}
        config_params: dict[str, Any] = {}

        # Add system instruction if present
        if system_instruction:
@@ -329,7 +414,7 @@ class GeminiCompletion(BaseLLM):

        return types.GenerateContentConfig(**config_params)

    def _convert_tools_for_interference(  # type: ignore[no-any-unimported]
    def _convert_tools_for_interference(  # type: ignore[override]
        self, tools: list[dict[str, Any]]
    ) -> list[types.Tool]:
        """Convert CrewAI tool format to Gemini function declaration format."""
@@ -346,7 +431,7 @@ class GeminiCompletion(BaseLLM):
            )

            # Add parameters if present - ensure parameters is a dict
            if parameters and isinstance(parameters, dict):
            if parameters and isinstance(parameters, Schema):
                function_declaration.parameters = parameters

            gemini_tool = types.Tool(function_declarations=[function_declaration])
@@ -354,7 +439,7 @@ class GeminiCompletion(BaseLLM):

        return gemini_tools

    def _format_messages_for_gemini(  # type: ignore[no-any-unimported]
    def _format_messages_for_gemini(
        self, messages: str | list[LLMMessage]
    ) -> tuple[list[types.Content], str | None]:
        """Format messages for Gemini API.
@@ -373,32 +458,41 @@ class GeminiCompletion(BaseLLM):
        # Use base class formatting first
        base_formatted = super()._format_messages(messages)

        contents = []
        contents: list[types.Content] = []
        system_instruction: str | None = None

        for message in base_formatted:
            role = message.get("role")
            content = message.get("content", "")
            role = message["role"]
            content = message["content"]

            # Convert content to string if it's a list
            if isinstance(content, list):
                text_content = " ".join(
                    str(item.get("text", "")) if isinstance(item, dict) else str(item)
                    for item in content
                )
            else:
                text_content = str(content) if content else ""

            if role == "system":
                # Extract system instruction - Gemini handles it separately
                if system_instruction:
                    system_instruction += f"\n\n{content}"
                    system_instruction += f"\n\n{text_content}"
                else:
                    system_instruction = cast(str, content)
                    system_instruction = text_content
            else:
                # Convert role for Gemini (assistant -> model)
                gemini_role = "model" if role == "assistant" else "user"

                # Create Content object
                gemini_content = types.Content(
                    role=gemini_role, parts=[types.Part.from_text(text=content)]
                    role=gemini_role, parts=[types.Part.from_text(text=text_content)]
                )
                contents.append(gemini_content)

        return contents, system_instruction
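# --- Illustrative sketch (not part of the diff): what the formatting above
# produces for a mixed conversation. The values are worked by hand.
messages = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello."},
]
# contents, system_instruction = llm._format_messages_for_gemini(messages)
# -> system_instruction == "Be terse."
# -> contents == [Content(role="user", parts=["Hi"]),
#                 Content(role="model", parts=["Hello."])]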
    def _handle_completion(  # type: ignore[no-any-unimported]
    def _handle_completion(
        self,
        contents: list[types.Content],
        system_instruction: str | None,
@@ -409,14 +503,14 @@ class GeminiCompletion(BaseLLM):
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming content generation."""
        api_params = {
            "model": self.model,
            "contents": contents,
            "config": config,
        }

        try:
            response = self.client.models.generate_content(**api_params)
            # The API accepts list[Content] but mypy is overly strict about variance
            contents_for_api: Any = contents
            response = self.client.models.generate_content(
                model=self.model,
                contents=contents_for_api,
                config=config,
            )

            usage = self._extract_token_usage(response)
        except Exception as e:
@@ -433,6 +527,8 @@ class GeminiCompletion(BaseLLM):
                for part in candidate.content.parts:
                    if hasattr(part, "function_call") and part.function_call:
                        function_name = part.function_call.name
                        if function_name is None:
                            continue
                        function_args = (
                            dict(part.function_call.args)
                            if part.function_call.args
@@ -442,7 +538,7 @@ class GeminiCompletion(BaseLLM):
                        result = self._handle_tool_execution(
                            function_name=function_name,
                            function_args=function_args,
                            available_functions=available_functions,  # type: ignore
                            available_functions=available_functions or {},
                            from_task=from_task,
                            from_agent=from_agent,
                        )
@@ -450,7 +546,7 @@ class GeminiCompletion(BaseLLM):
                        if result is not None:
                            return result

        content = response.text if hasattr(response, "text") else ""
        content = response.text or ""
        content = self._apply_stop_words(content)

        messages_for_event = self._convert_contents_to_dict(contents)
@@ -465,7 +561,7 @@ class GeminiCompletion(BaseLLM):

        return content

    def _handle_streaming_completion(  # type: ignore[no-any-unimported]
    def _handle_streaming_completion(
        self,
        contents: list[types.Content],
        config: types.GenerateContentConfig,
@@ -476,16 +572,16 @@ class GeminiCompletion(BaseLLM):
    ) -> str:
        """Handle streaming content generation."""
        full_response = ""
        function_calls = {}
        function_calls: dict[str, dict[str, Any]] = {}

        api_params = {
            "model": self.model,
            "contents": contents,
            "config": config,
        }

        for chunk in self.client.models.generate_content_stream(**api_params):
            if hasattr(chunk, "text") and chunk.text:
        # The API accepts list[Content] but mypy is overly strict about variance
        contents_for_api: Any = contents
        for chunk in self.client.models.generate_content_stream(
            model=self.model,
            contents=contents_for_api,
            config=config,
        ):
            if chunk.text:
                full_response += chunk.text
                self._emit_stream_chunk_event(
                    chunk=chunk.text,
@@ -493,7 +589,7 @@ class GeminiCompletion(BaseLLM):
                    from_agent=from_agent,
                )

            if hasattr(chunk, "candidates") and chunk.candidates:
            if chunk.candidates:
                candidate = chunk.candidates[0]
                if candidate.content and candidate.content.parts:
                    for part in candidate.content.parts:
@@ -513,6 +609,14 @@ class GeminiCompletion(BaseLLM):
                function_name = call_data["name"]
                function_args = call_data["args"]

                # Skip if function_name is None
                if not isinstance(function_name, str):
                    continue

                # Ensure function_args is a dict
                if not isinstance(function_args, dict):
                    function_args = {}

                # Execute tool
                result = self._handle_tool_execution(
                    function_name=function_name,
@@ -537,6 +641,154 @@ class GeminiCompletion(BaseLLM):

        return full_response

    async def _ahandle_completion(
        self,
        contents: list[types.Content],
        system_instruction: str | None,
        config: types.GenerateContentConfig,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle async non-streaming content generation."""
        try:
            # The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = await self.client.aio.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
)
|
||||
|
||||
usage = self._extract_token_usage(response)
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
raise e from e
|
||||
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response.candidates and (self.tools or available_functions):
|
||||
candidate = response.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
function_name = part.function_call.name
|
||||
if function_name is None:
|
||||
continue
|
||||
function_args = (
|
||||
dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {}
|
||||
)
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions or {},
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = response.text or ""
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
async def _ahandle_streaming_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle async streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
stream = await self.client.aio.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
)
|
||||
async for chunk in stream:
|
||||
if chunk.text:
|
||||
full_response += chunk.text
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=chunk.text,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if chunk.candidates:
|
||||
candidate = chunk.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
call_id = part.function_call.name or "default"
|
||||
if call_id not in function_calls:
|
||||
function_calls[call_id] = {
|
||||
"name": part.function_call.name,
|
||||
"args": dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {},
|
||||
}
|
||||
|
||||
if function_calls and available_functions:
|
||||
for call_data in function_calls.values():
|
||||
function_name = call_data["name"]
|
||||
function_args = call_data["args"]
|
||||
|
||||
# Skip if function_name is None
|
||||
if not isinstance(function_name, str):
|
||||
continue
|
||||
|
||||
# Ensure function_args is a dict
|
||||
if not isinstance(function_args, dict):
|
||||
function_args = {}
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
@@ -583,9 +835,10 @@ class GeminiCompletion(BaseLLM):
|
||||
# Default context window size for Gemini models
|
||||
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
|
||||
|
||||
def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
@staticmethod
|
||||
def _extract_token_usage(response: GenerateContentResponse) -> dict[str, Any]:
|
||||
"""Extract token usage from Gemini response."""
|
||||
if hasattr(response, "usage_metadata"):
|
||||
if response.usage_metadata:
|
||||
usage = response.usage_metadata
|
||||
return {
|
||||
"prompt_token_count": getattr(usage, "prompt_token_count", 0),
|
||||
@@ -595,21 +848,23 @@ class GeminiCompletion(BaseLLM):
|
||||
}
|
||||
return {"total_tokens": 0}
|
||||
|
||||
def _convert_contents_to_dict( # type: ignore[no-any-unimported]
|
||||
def _convert_contents_to_dict(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Convert contents to dict format."""
|
||||
return [
|
||||
{
|
||||
"role": "assistant"
|
||||
if content_obj.role == "model"
|
||||
else content_obj.role,
|
||||
"content": " ".join(
|
||||
part.text
|
||||
for part in content_obj.parts
|
||||
if hasattr(part, "text") and part.text
|
||||
),
|
||||
}
|
||||
for content_obj in contents
|
||||
]
|
||||
result: list[dict[str, str]] = []
|
||||
for content_obj in contents:
|
||||
role = content_obj.role
|
||||
if role == "model":
|
||||
role = "assistant"
|
||||
elif role is None:
|
||||
role = "user"
|
||||
|
||||
parts = content_obj.parts or []
|
||||
content = " ".join(
|
||||
part.text for part in parts if hasattr(part, "text") and part.text
|
||||
)
|
||||
|
||||
result.append({"role": role, "content": content})
|
||||
return result
|
||||
|
||||
@@ -1,13 +1,13 @@
 from __future__ import annotations

-from collections.abc import Iterator
+from collections.abc import AsyncIterator, Iterator
 import json
 import logging
 import os
 from typing import TYPE_CHECKING, Any

 import httpx
-from openai import APIConnectionError, NotFoundError, OpenAI
+from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
@@ -15,7 +15,7 @@ from pydantic import BaseModel

 from crewai.events.types.llm_events import LLMCallType
 from crewai.llms.base_llm import BaseLLM
-from crewai.llms.hooks.transport import HTTPTransport
+from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.converter import generate_model_description
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -101,6 +101,14 @@ class OpenAICompletion(BaseLLM):

         self.client = OpenAI(**client_config)

+        async_client_config = self._get_client_params()
+        if self.interceptor:
+            async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
+            async_http_client = httpx.AsyncClient(transport=async_transport)
+            async_client_config["http_client"] = async_http_client
+
+        self.async_client = AsyncOpenAI(**async_client_config)
+
         # Completion parameters
         self.top_p = top_p
         self.frequency_penalty = frequency_penalty
@@ -210,6 +218,71 @@ class OpenAICompletion(BaseLLM):
             )
             raise

+    async def acall(
+        self,
+        messages: str | list[LLMMessage],
+        tools: list[dict[str, BaseTool]] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
+        response_model: type[BaseModel] | None = None,
+    ) -> str | Any:
+        """Async call to OpenAI chat completion API.
+
+        Args:
+            messages: Input messages for the chat completion
+            tools: list of tool/function definitions
+            callbacks: Callback functions (not used in native implementation)
+            available_functions: Available functions for tool calling
+            from_task: Task that initiated the call
+            from_agent: Agent that initiated the call
+            response_model: Response model for structured output.
+
+        Returns:
+            Chat completion response or tool call result
+        """
+        try:
+            self._emit_call_started_event(
+                messages=messages,
+                tools=tools,
+                callbacks=callbacks,
+                available_functions=available_functions,
+                from_task=from_task,
+                from_agent=from_agent,
+            )
+
+            formatted_messages = self._format_messages(messages)
+
+            completion_params = self._prepare_completion_params(
+                messages=formatted_messages, tools=tools
+            )
+
+            if self.stream:
+                return await self._ahandle_streaming_completion(
+                    params=completion_params,
+                    available_functions=available_functions,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                    response_model=response_model,
+                )
+
+            return await self._ahandle_completion(
+                params=completion_params,
+                available_functions=available_functions,
+                from_task=from_task,
+                from_agent=from_agent,
+                response_model=response_model,
+            )
+
+        except Exception as e:
+            error_msg = f"OpenAI API call failed: {e!s}"
+            logging.error(error_msg)
+            self._emit_call_failed_event(
+                error=error_msg, from_task=from_task, from_agent=from_agent
+            )
+            raise
+
     def _prepare_completion_params(
         self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
     ) -> dict[str, Any]:
@@ -352,10 +425,10 @@ class OpenAICompletion(BaseLLM):

         if message.tool_calls and available_functions:
             tool_call = message.tool_calls[0]
-            function_name = tool_call.function.name  # type: ignore[union-attr]
+            function_name = tool_call.function.name

             try:
-                function_args = json.loads(tool_call.function.arguments)  # type: ignore[union-attr]
+                function_args = json.loads(tool_call.function.arguments)
             except json.JSONDecodeError as e:
                 logging.error(f"Failed to parse tool arguments: {e}")
                 function_args = {}
@@ -564,6 +637,266 @@ class OpenAICompletion(BaseLLM):

         return full_response

+    async def _ahandle_completion(
+        self,
+        params: dict[str, Any],
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+        response_model: type[BaseModel] | None = None,
+    ) -> str | Any:
+        """Handle non-streaming async chat completion."""
+        try:
+            if response_model:
+                parse_params = {
+                    k: v for k, v in params.items() if k != "response_format"
+                }
+                parsed_response = await self.async_client.beta.chat.completions.parse(
+                    **parse_params,
+                    response_format=response_model,
+                )
+                math_reasoning = parsed_response.choices[0].message
+
+                if math_reasoning.refusal:
+                    pass
+
+                usage = self._extract_openai_token_usage(parsed_response)
+                self._track_token_usage_internal(usage)
+
+                parsed_object = parsed_response.choices[0].message.parsed
+                if parsed_object:
+                    structured_json = parsed_object.model_dump_json()
+                    self._emit_call_completed_event(
+                        response=structured_json,
+                        call_type=LLMCallType.LLM_CALL,
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        messages=params["messages"],
+                    )
+                    return structured_json
+
+            response: ChatCompletion = await self.async_client.chat.completions.create(
+                **params
+            )
+
+            usage = self._extract_openai_token_usage(response)
+
+            self._track_token_usage_internal(usage)
+
+            choice: Choice = response.choices[0]
+            message = choice.message
+
+            if message.tool_calls and available_functions:
+                tool_call = message.tool_calls[0]
+                function_name = tool_call.function.name
+
+                try:
+                    function_args = json.loads(tool_call.function.arguments)
+                except json.JSONDecodeError as e:
+                    logging.error(f"Failed to parse tool arguments: {e}")
+                    function_args = {}
+
+                result = self._handle_tool_execution(
+                    function_name=function_name,
+                    function_args=function_args,
+                    available_functions=available_functions,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                )
+
+                if result is not None:
+                    return result
+
+            content = message.content or ""
+            content = self._apply_stop_words(content)
+
+            if self.response_format and isinstance(self.response_format, type):
+                try:
+                    structured_result = self._validate_structured_output(
+                        content, self.response_format
+                    )
+                    self._emit_call_completed_event(
+                        response=structured_result,
+                        call_type=LLMCallType.LLM_CALL,
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        messages=params["messages"],
+                    )
+                    return structured_result
+                except ValueError as e:
+                    logging.warning(f"Structured output validation failed: {e}")
+
+            self._emit_call_completed_event(
+                response=content,
+                call_type=LLMCallType.LLM_CALL,
+                from_task=from_task,
+                from_agent=from_agent,
+                messages=params["messages"],
+            )
+
+            if usage.get("total_tokens", 0) > 0:
+                logging.info(f"OpenAI API usage: {usage}")
+        except NotFoundError as e:
+            error_msg = f"Model {self.model} not found: {e}"
+            logging.error(error_msg)
+            self._emit_call_failed_event(
+                error=error_msg, from_task=from_task, from_agent=from_agent
+            )
+            raise ValueError(error_msg) from e
+        except APIConnectionError as e:
+            error_msg = f"Failed to connect to OpenAI API: {e}"
+            logging.error(error_msg)
+            self._emit_call_failed_event(
+                error=error_msg, from_task=from_task, from_agent=from_agent
+            )
+            raise ConnectionError(error_msg) from e
+        except Exception as e:
+            if is_context_length_exceeded(e):
+                logging.error(f"Context window exceeded: {e}")
+                raise LLMContextLengthExceededError(str(e)) from e
+
+            error_msg = f"OpenAI API call failed: {e!s}"
+            logging.error(error_msg)
+            self._emit_call_failed_event(
+                error=error_msg, from_task=from_task, from_agent=from_agent
+            )
+            raise e from e
+
+        return content
+
+    async def _ahandle_streaming_completion(
+        self,
+        params: dict[str, Any],
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+        response_model: type[BaseModel] | None = None,
+    ) -> str:
+        """Handle async streaming chat completion."""
+        full_response = ""
+        tool_calls = {}
+
+        if response_model:
+            completion_stream: AsyncIterator[
+                ChatCompletionChunk
+            ] = await self.async_client.chat.completions.create(**params)
+
+            accumulated_content = ""
+            async for chunk in completion_stream:
+                if not chunk.choices:
+                    continue
+
+                choice = chunk.choices[0]
+                delta: ChoiceDelta = choice.delta
+
+                if delta.content:
+                    accumulated_content += delta.content
+                    self._emit_stream_chunk_event(
+                        chunk=delta.content,
+                        from_task=from_task,
+                        from_agent=from_agent,
+                    )
+
+            try:
+                parsed_object = response_model.model_validate_json(accumulated_content)
+                structured_json = parsed_object.model_dump_json()
+
+                self._emit_call_completed_event(
+                    response=structured_json,
+                    call_type=LLMCallType.LLM_CALL,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                    messages=params["messages"],
+                )
+
+                return structured_json
+            except Exception as e:
+                logging.error(f"Failed to parse structured output from stream: {e}")
+                self._emit_call_completed_event(
+                    response=accumulated_content,
+                    call_type=LLMCallType.LLM_CALL,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                    messages=params["messages"],
+                )
+                return accumulated_content
+
+        stream: AsyncIterator[
+            ChatCompletionChunk
+        ] = await self.async_client.chat.completions.create(**params)
+
+        async for chunk in stream:
+            if not chunk.choices:
+                continue
+
+            choice = chunk.choices[0]
+            chunk_delta: ChoiceDelta = choice.delta
+
+            if chunk_delta.content:
+                full_response += chunk_delta.content
+                self._emit_stream_chunk_event(
+                    chunk=chunk_delta.content,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                )
+
+            if chunk_delta.tool_calls:
+                for tool_call in chunk_delta.tool_calls:
+                    call_id = tool_call.id or "default"
+                    if call_id not in tool_calls:
+                        tool_calls[call_id] = {
+                            "name": "",
+                            "arguments": "",
+                        }
+
+                    if tool_call.function and tool_call.function.name:
+                        tool_calls[call_id]["name"] = tool_call.function.name
+                    if tool_call.function and tool_call.function.arguments:
+                        tool_calls[call_id]["arguments"] += tool_call.function.arguments
+
+        if tool_calls and available_functions:
+            for call_data in tool_calls.values():
+                function_name = call_data["name"]
+                arguments = call_data["arguments"]
+
+                if not function_name or not arguments:
+                    continue
+
+                if function_name not in available_functions:
+                    logging.warning(
+                        f"Function '{function_name}' not found in available functions"
+                    )
+                    continue
+
+                try:
+                    function_args = json.loads(arguments)
+                except json.JSONDecodeError as e:
+                    logging.error(f"Failed to parse streamed tool arguments: {e}")
+                    continue
+
+                result = self._handle_tool_execution(
+                    function_name=function_name,
+                    function_args=function_args,
+                    available_functions=available_functions,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                )
+
+                if result is not None:
+                    return result
+
+        full_response = self._apply_stop_words(full_response)
+
+        self._emit_call_completed_event(
+            response=full_response,
+            call_type=LLMCallType.LLM_CALL,
+            from_task=from_task,
+            from_agent=from_agent,
+            messages=params["messages"],
+        )
+
+        return full_response
+
     def supports_function_calling(self) -> bool:
         """Check if the model supports function calling."""
         return not self.is_o1_model

@@ -9,12 +9,14 @@ data is collected. Users can opt-in to share more complete data using the
 from __future__ import annotations

 import asyncio
+import atexit
 from collections.abc import Callable
 from importlib.metadata import version
 import json
 import logging
 import os
 import platform
+import signal
 import threading
 from typing import TYPE_CHECKING, Any

@@ -31,6 +33,14 @@ from opentelemetry.sdk.trace.export import (
 from opentelemetry.trace import Span
 from typing_extensions import Self

+from crewai.events.event_bus import crewai_event_bus
+from crewai.events.types.system_events import (
+    SigContEvent,
+    SigHupEvent,
+    SigIntEvent,
+    SigTStpEvent,
+    SigTermEvent,
+)
 from crewai.telemetry.constants import (
     CREWAI_TELEMETRY_BASE_URL,
     CREWAI_TELEMETRY_SERVICE_NAME,
@@ -121,6 +131,7 @@ class Telemetry:
             )

             self.provider.add_span_processor(processor)
+            self._register_shutdown_handlers()
             self.ready = True
         except Exception as e:
             if isinstance(
@@ -155,6 +166,71 @@ class Telemetry:
             self.ready = False
             self.trace_set = False

+    def _register_shutdown_handlers(self) -> None:
+        """Register handlers for graceful shutdown on process exit and signals."""
+        atexit.register(self._shutdown)
+
+        self._original_handlers: dict[int, Any] = {}
+
+        self._register_signal_handler(signal.SIGTERM, SigTermEvent, shutdown=True)
+        self._register_signal_handler(signal.SIGINT, SigIntEvent, shutdown=True)
+        self._register_signal_handler(signal.SIGHUP, SigHupEvent, shutdown=False)
+        self._register_signal_handler(signal.SIGTSTP, SigTStpEvent, shutdown=False)
+        self._register_signal_handler(signal.SIGCONT, SigContEvent, shutdown=False)
+
+    def _register_signal_handler(
+        self,
+        sig: signal.Signals,
+        event_class: type,
+        shutdown: bool = False,
+    ) -> None:
+        """Register a signal handler that emits an event.
+
+        Args:
+            sig: The signal to handle.
+            event_class: The event class to instantiate and emit.
+            shutdown: Whether to trigger shutdown on this signal.
+        """
+        try:
+            original_handler = signal.getsignal(sig)
+            self._original_handlers[sig] = original_handler
+
+            def handler(signum: int, frame: Any) -> None:
+                crewai_event_bus.emit(self, event_class())
+
+                if shutdown:
+                    self._shutdown()
+
+                if original_handler not in (signal.SIG_DFL, signal.SIG_IGN, None):
+                    if callable(original_handler):
+                        original_handler(signum, frame)
+                elif shutdown:
+                    raise SystemExit(0)
+
+            signal.signal(sig, handler)
+        except ValueError as e:
+            logger.warning(
+                f"Cannot register {sig.name} handler: not running in main thread",
+                exc_info=e,
+            )
+        except OSError as e:
+            logger.warning(f"Cannot register {sig.name} handler: {e}", exc_info=e)
+
+    def _shutdown(self) -> None:
+        """Flush and shutdown the telemetry provider on process exit.
+
+        Uses a short timeout to avoid blocking process shutdown.
+        """
+        if not self.ready:
+            return
+
+        try:
+            self.provider.force_flush(timeout_millis=5000)
+            self.provider.shutdown()
+            self.ready = False
+        except Exception as e:
+            logger.debug(f"Telemetry shutdown failed: {e}")
+
     def _safe_telemetry_operation(
         self, operation: Callable[[], Span | None]
     ) -> Span | None:

@@ -147,7 +147,7 @@ def test_custom_llm():
     assert agent.llm.model == "gpt-4"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_execution():
     agent = Agent(
         role="test role",
@@ -166,7 +166,7 @@ def test_agent_execution():
     assert output == "1 + 1 is 2"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_execution_with_tools():
     @tool
     def multiplier(first_number: int, second_number: int) -> float:
@@ -211,7 +211,7 @@ def test_agent_execution_with_tools():
     assert received_events[0].tool_args == {"first_number": 3, "second_number": 4}


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_logging_tool_usage():
     @tool
     def multiplier(first_number: int, second_number: int) -> float:
@@ -245,7 +245,7 @@ def test_logging_tool_usage():
     assert agent.tools_handler.last_used_tool.arguments == tool_usage.arguments


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_cache_hitting():
     @tool
     def multiplier(first_number: int, second_number: int) -> float:
@@ -325,7 +325,7 @@ def test_cache_hitting():
     assert received_events[0].output == "12"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_disabling_cache_for_agent():
     @tool
     def multiplier(first_number: int, second_number: int) -> float:
@@ -389,7 +389,7 @@ def test_disabling_cache_for_agent():
     read.assert_not_called()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_execution_with_specific_tools():
     @tool
     def multiplier(first_number: int, second_number: int) -> float:
@@ -412,7 +412,7 @@ def test_agent_execution_with_specific_tools():
     assert output == "The result of the multiplication is 12."


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
     @tool
     def multiplier(first_number: int, second_number: int) -> float:
@@ -438,7 +438,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
     assert output == "12"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_powered_by_new_o_model_family_that_uses_tool():
     @tool
     def comapny_customer_data() -> str:
@@ -464,7 +464,7 @@ def test_agent_powered_by_new_o_model_family_that_uses_tool():
     assert output == "42"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_custom_max_iterations():
     @tool
     def get_final_answer() -> float:
@@ -509,7 +509,7 @@ def test_agent_custom_max_iterations():
     assert call_count == 2


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 @pytest.mark.timeout(30)
 def test_agent_max_iterations_stops_loop():
     """Test that agent execution terminates when max_iter is reached."""
@@ -546,7 +546,7 @@ def test_agent_max_iterations_stops_loop():
     )


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_repeated_tool_usage(capsys):
     """Test that agents handle repeated tool usage appropriately.

@@ -595,7 +595,7 @@ def test_agent_repeated_tool_usage(capsys):
     )


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
     @tool
     def get_final_answer(anything: str) -> float:
@@ -638,7 +638,7 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
     )


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_moved_on_after_max_iterations():
     @tool
     def get_final_answer() -> float:
@@ -665,7 +665,7 @@ def test_agent_moved_on_after_max_iterations():
     assert output == "42"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_respect_the_max_rpm_set(capsys):
     @tool
     def get_final_answer() -> float:
@@ -699,7 +699,7 @@ def test_agent_respect_the_max_rpm_set(capsys):
     moveon.assert_called()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys):
     from unittest.mock import patch

@@ -737,7 +737,7 @@ def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys):
     moveon.assert_not_called()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_without_max_rpm_respects_crew_rpm(capsys):
     from unittest.mock import patch

@@ -797,7 +797,7 @@ def test_agent_without_max_rpm_respects_crew_rpm(capsys):
     moveon.assert_called_once()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_error_on_parsing_tool(capsys):
     from unittest.mock import patch

@@ -840,7 +840,7 @@ def test_agent_error_on_parsing_tool(capsys):
     assert "Error on parsing tool." in captured.out


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_remembers_output_format_after_using_tools_too_many_times():
     from unittest.mock import patch

@@ -875,7 +875,7 @@ def test_agent_remembers_output_format_after_using_tools_too_many_times():
     remember_format.assert_called()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_use_specific_tasks_output_as_context(capsys):
     agent1 = Agent(role="test role", goal="test goal", backstory="test backstory")
     agent2 = Agent(role="test role2", goal="test goal2", backstory="test backstory2")
@@ -902,7 +902,7 @@ def test_agent_use_specific_tasks_output_as_context(capsys):
     assert "hi" in result.raw.lower() or "hello" in result.raw.lower()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_step_callback():
     class StepCallback:
         def callback(self, step):
@@ -936,7 +936,7 @@ def test_agent_step_callback():
     callback.assert_called()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_function_calling_llm():
     from crewai.llm import LLM
     llm = LLM(model="gpt-4o", is_litellm=True)
@@ -983,7 +983,7 @@ def test_agent_function_calling_llm():
     mock_original_tool_calling.assert_called()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_tool_result_as_answer_is_the_final_answer_for_the_agent():
     from crewai.tools import BaseTool

@@ -1013,7 +1013,7 @@ def test_tool_result_as_answer_is_the_final_answer_for_the_agent():
     assert result.raw == "Howdy!"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_tool_usage_information_is_appended_to_agent():
     from crewai.tools import BaseTool

@@ -1068,7 +1068,7 @@ def test_agent_definition_based_on_dict():


 # test for human input
-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_human_input():
     # Agent configuration
     config = {
@@ -1216,7 +1216,7 @@ Thought:<|eot_id|>
     assert mock_format_prompt.return_value == expected_prompt


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_task_allow_crewai_trigger_context():
     from crewai import Crew

@@ -1237,7 +1237,7 @@ def test_task_allow_crewai_trigger_context():
     assert "Trigger Payload: Important context data" in prompt


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_task_without_allow_crewai_trigger_context():
     from crewai import Crew

@@ -1260,7 +1260,7 @@ def test_task_without_allow_crewai_trigger_context():
     assert "Important context data" not in prompt


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_task_allow_crewai_trigger_context_no_payload():
     from crewai import Crew

@@ -1282,7 +1282,7 @@ def test_task_allow_crewai_trigger_context_no_payload():
     assert "Trigger Payload:" not in prompt


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
     from crewai import Crew

@@ -1311,7 +1311,7 @@ def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
     assert "Trigger Payload: Initial context data" not in first_prompt


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_first_task_auto_inject_trigger():
     from crewai import Crew

@@ -1344,7 +1344,7 @@ def test_first_task_auto_inject_trigger():
     assert "Trigger Payload:" not in second_prompt


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_ensure_first_task_allow_crewai_trigger_context_is_false_does_not_inject():
     from crewai import Crew

@@ -1549,7 +1549,7 @@ def test_agent_with_additional_kwargs():
     assert agent.llm.frequency_penalty == 0.1


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_llm_call():
     llm = LLM(model="gpt-3.5-turbo")
     messages = [{"role": "user", "content": "Say 'Hello, World!'"}]
@@ -1558,7 +1558,7 @@ def test_llm_call():
     assert "Hello, World!" in response


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_llm_call_with_error():
     llm = LLM(model="non-existent-model")
     messages = [{"role": "user", "content": "This should fail"}]
@@ -1567,7 +1567,7 @@ def test_llm_call_with_error():
         llm.call(messages)


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_handle_context_length_exceeds_limit():
     # Import necessary modules
     from crewai.utilities.agent_utils import handle_context_length
@@ -1620,7 +1620,7 @@ def test_handle_context_length_exceeds_limit():
     mock_summarize.assert_called_once()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_handle_context_length_exceeds_limit_cli_no():
     agent = Agent(
         role="test role",
@@ -1695,7 +1695,7 @@ def test_agent_with_all_llm_attributes():
     assert agent.llm.api_key == "sk-your-api-key-here"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_llm_call_with_all_attributes():
     llm = LLM(
         model="gpt-3.5-turbo",
@@ -1712,7 +1712,7 @@ def test_llm_call_with_all_attributes():
     assert "STOP" not in response


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_with_ollama_llama3():
     agent = Agent(
         role="test role",
@@ -1733,7 +1733,7 @@ def test_agent_with_ollama_llama3():
     assert "Llama3" in response or "AI" in response or "language model" in response


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_llm_call_with_ollama_llama3():
     llm = LLM(
         model="ollama/llama3.2:3b",
@@ -1752,7 +1752,7 @@ def test_llm_call_with_ollama_llama3():
     assert "Llama3" in response or "AI" in response or "language model" in response


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_execute_task_basic():
     agent = Agent(
         role="test role",
@@ -1771,7 +1771,7 @@ def test_agent_execute_task_basic():
     assert "4" in result


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_execute_task_with_context():
     agent = Agent(
         role="test role",
@@ -1793,7 +1793,7 @@ def test_agent_execute_task_with_context():
     assert "fox" in result.lower() and "dog" in result.lower()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_execute_task_with_tool():
     @tool
     def dummy_tool(query: str) -> str:
@@ -1818,7 +1818,7 @@ def test_agent_execute_task_with_tool():
     assert "Dummy result for: test query" in result


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_execute_task_with_custom_llm():
     agent = Agent(
         role="test role",
@@ -1839,7 +1839,7 @@ def test_agent_execute_task_with_custom_llm():
     )


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_execute_task_with_ollama():
     agent = Agent(
         role="test role",
@@ -1859,7 +1859,7 @@ def test_agent_execute_task_with_ollama():
     assert "AI" in result or "artificial intelligence" in result.lower()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_with_knowledge_sources():
     content = "Brandon's favorite color is red and he likes Mexican food."
     string_source = StringKnowledgeSource(content=content)
@@ -1891,7 +1891,7 @@ def test_agent_with_knowledge_sources():
     assert "red" in result.raw.lower()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
     content = "Brandon's favorite color is red and he likes Mexican food."
     string_source = StringKnowledgeSource(content=content)
@@ -1939,7 +1939,7 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
     )


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_default():
     content = "Brandon's favorite color is red and he likes Mexican food."
     string_source = StringKnowledgeSource(content=content)
@@ -1988,7 +1988,7 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
     )


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_with_knowledge_sources_extensive_role():
     content = "Brandon's favorite color is red and he likes Mexican food."
     string_source = StringKnowledgeSource(content=content)
@@ -2024,7 +2024,7 @@ def test_agent_with_knowledge_sources_extensive_role():
     assert "red" in result.raw.lower()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_with_knowledge_sources_works_with_copy():
     content = "Brandon's favorite color is red and he likes Mexican food."
     string_source = StringKnowledgeSource(content=content)
@@ -2063,7 +2063,7 @@ def test_agent_with_knowledge_sources_works_with_copy():
     assert isinstance(agent_copy.llm, BaseLLM)


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_with_knowledge_sources_generate_search_query():
     content = "Brandon's favorite color is red and he likes Mexican food."
     string_source = StringKnowledgeSource(content=content)
@@ -2116,7 +2116,7 @@ def test_agent_with_knowledge_sources_generate_search_query():
     assert "red" in result.raw.lower()


-@pytest.mark.vcr(record_mode="none", filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_with_knowledge_with_no_crewai_knowledge():
     mock_knowledge = MagicMock(spec=Knowledge)

@@ -2143,7 +2143,7 @@ def test_agent_with_knowledge_with_no_crewai_knowledge():
     mock_knowledge.query.assert_called_once()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_with_only_crewai_knowledge():
     mock_knowledge = MagicMock(spec=Knowledge)

@@ -2168,7 +2168,7 @@ def test_agent_with_only_crewai_knowledge():
     mock_knowledge.query.assert_called_once()


-@pytest.mark.vcr(record_mode="none", filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_knowledege_with_crewai_knowledge():
     crew_knowledge = MagicMock(spec=Knowledge)
     agent_knowledge = MagicMock(spec=Knowledge)
@@ -2197,7 +2197,7 @@ def test_agent_knowledege_with_crewai_knowledge():
     crew_knowledge.query.assert_called_once()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_litellm_auth_error_handling():
     """Test that LiteLLM authentication errors are handled correctly and not retried."""
     from litellm import AuthenticationError as LiteLLMAuthenticationError
@@ -2326,7 +2326,7 @@ def test_litellm_anthropic_error_handling():
     mock_llm_call.assert_called_once()


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_get_knowledge_search_query():
     """Test that _get_knowledge_search_query calls the LLM with the correct prompts."""
     from crewai.utilities.i18n import I18N

@@ -70,7 +70,7 @@ class ResearchResult(BaseModel):
     sources: list[str] = Field(description="List of sources used")


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 @pytest.mark.parametrize("verbose", [True, False])
 def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose):
     """Test that LiteAgent is created with the correct parameters when Agent.kickoff() is called."""
@@ -130,7 +130,7 @@ def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose):
     assert created_lite_agent["response_format"] == TestResponse


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_lite_agent_with_tools():
     """Test that Agent can use tools."""
     # Create a LiteAgent with tools
@@ -174,7 +174,7 @@ def test_lite_agent_with_tools():
     assert event.tool_name == "search_web"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_lite_agent_structured_output():
     """Test that Agent can return a simple structured output."""

@@ -217,7 +217,7 @@ def test_lite_agent_structured_output():
     return result


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_lite_agent_returns_usage_metrics():
     """Test that LiteAgent returns usage metrics."""
     llm = LLM(model="gpt-4o-mini")
@@ -238,7 +238,7 @@ def test_lite_agent_returns_usage_metrics():
     assert result.usage_metrics["total_tokens"] > 0


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_lite_agent_output_includes_messages():
     """Test that LiteAgentOutput includes messages from agent execution."""
     llm = LLM(model="gpt-4o-mini")
@@ -259,7 +259,7 @@ def test_lite_agent_output_includes_messages():
     assert len(result.messages) > 0


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 @pytest.mark.asyncio
 async def test_lite_agent_returns_usage_metrics_async():
     """Test that LiteAgent returns usage metrics when run asynchronously."""
@@ -354,9 +354,9 @@ def test_sets_parent_flow_when_inside_flow():
     assert captured_agent.parent_flow is flow


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_guardrail_is_called_using_string():
-    guardrail_events = defaultdict(list)
+    guardrail_events: dict[str, list] = defaultdict(list)
     from crewai.events.event_types import (
         LLMGuardrailCompletedEvent,
         LLMGuardrailStartedEvent,
@@ -369,35 +369,33 @@ def test_guardrail_is_called_using_string():
         guardrail="""Only include Brazilian players, both women and men""",
     )

-    all_events_received = threading.Event()
+    condition = threading.Condition()

     @crewai_event_bus.on(LLMGuardrailStartedEvent)
     def capture_guardrail_started(source, event):
         assert isinstance(source, LiteAgent)
         assert source.original_agent == agent
-        guardrail_events["started"].append(event)
-        if (
-            len(guardrail_events["started"]) == 2
-            and len(guardrail_events["completed"]) == 2
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["started"].append(event)
+            condition.notify()

     @crewai_event_bus.on(LLMGuardrailCompletedEvent)
     def capture_guardrail_completed(source, event):
         assert isinstance(source, LiteAgent)
         assert source.original_agent == agent
-        guardrail_events["completed"].append(event)
-        if (
-            len(guardrail_events["started"]) == 2
-            and len(guardrail_events["completed"]) == 2
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["completed"].append(event)
+            condition.notify()

     result = agent.kickoff(messages="Top 10 best players in the world?")

-    assert all_events_received.wait(timeout=10), (
-        "Timeout waiting for all guardrail events"
-    )
+    with condition:
+        success = condition.wait_for(
+            lambda: len(guardrail_events["started"]) >= 2
+            and len(guardrail_events["completed"]) >= 2,
+            timeout=10,
+        )
+    assert success, "Timeout waiting for all guardrail events"
     assert len(guardrail_events["started"]) == 2
     assert len(guardrail_events["completed"]) == 2
     assert not guardrail_events["completed"][0].success
@@ -408,33 +406,27 @@ def test_guardrail_is_called_using_string():
     )


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_guardrail_is_called_using_callable():
-    guardrail_events = defaultdict(list)
+    guardrail_events: dict[str, list] = defaultdict(list)
     from crewai.events.event_types import (
         LLMGuardrailCompletedEvent,
         LLMGuardrailStartedEvent,
     )

-    all_events_received = threading.Event()
+    condition = threading.Condition()

     @crewai_event_bus.on(LLMGuardrailStartedEvent)
     def capture_guardrail_started(source, event):
-        guardrail_events["started"].append(event)
-        if (
-            len(guardrail_events["started"]) == 1
-            and len(guardrail_events["completed"]) == 1
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["started"].append(event)
+            condition.notify()

     @crewai_event_bus.on(LLMGuardrailCompletedEvent)
     def capture_guardrail_completed(source, event):
-        guardrail_events["completed"].append(event)
-        if (
-            len(guardrail_events["started"]) == 1
-            and len(guardrail_events["completed"]) == 1
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["completed"].append(event)
+            condition.notify()

     agent = Agent(
         role="Sports Analyst",
@@ -445,42 +437,40 @@ def test_guardrail_is_called_using_callable():

     result = agent.kickoff(messages="Top 1 best players in the world?")

-    assert all_events_received.wait(timeout=10), (
-        "Timeout waiting for all guardrail events"
-    )
+    with condition:
+        success = condition.wait_for(
+            lambda: len(guardrail_events["started"]) >= 1
+            and len(guardrail_events["completed"]) >= 1,
+            timeout=10,
+        )
+    assert success, "Timeout waiting for all guardrail events"
     assert len(guardrail_events["started"]) == 1
     assert len(guardrail_events["completed"]) == 1
     assert guardrail_events["completed"][0].success
     assert "Pelé - Santos, 1958" in result.raw


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_guardrail_reached_attempt_limit():
-    guardrail_events = defaultdict(list)
+    guardrail_events: dict[str, list] = defaultdict(list)
     from crewai.events.event_types import (
         LLMGuardrailCompletedEvent,
         LLMGuardrailStartedEvent,
     )

-    all_events_received = threading.Event()
+    condition = threading.Condition()

     @crewai_event_bus.on(LLMGuardrailStartedEvent)
     def capture_guardrail_started(source, event):
-        guardrail_events["started"].append(event)
-        if (
-            len(guardrail_events["started"]) == 3
-            and len(guardrail_events["completed"]) == 3
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["started"].append(event)
+            condition.notify()

     @crewai_event_bus.on(LLMGuardrailCompletedEvent)
     def capture_guardrail_completed(source, event):
-        guardrail_events["completed"].append(event)
-        if (
-            len(guardrail_events["started"]) == 3
-            and len(guardrail_events["completed"]) == 3
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["completed"].append(event)
+            condition.notify()

     agent = Agent(
         role="Sports Analyst",
@@ -498,9 +488,13 @@ def test_guardrail_reached_attempt_limit():
     ):
         agent.kickoff(messages="Top 10 best players in the world?")

-    assert all_events_received.wait(timeout=10), (
-        "Timeout waiting for all guardrail events"
-    )
+    with condition:
+        success = condition.wait_for(
+            lambda: len(guardrail_events["started"]) >= 3
+            and len(guardrail_events["completed"]) >= 3,
+            timeout=10,
+        )
+    assert success, "Timeout waiting for all guardrail events"
     assert len(guardrail_events["started"]) == 3  # 2 retries + 1 initial call
     assert len(guardrail_events["completed"]) == 3  # 2 retries + 1 initial call
     assert not guardrail_events["completed"][0].success
@@ -508,7 +502,7 @@ def test_guardrail_reached_attempt_limit():
     assert not guardrail_events["completed"][2].success


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_output_when_guardrail_returns_base_model():
     class Player(BaseModel):
         name: str
@@ -599,7 +593,7 @@ def test_lite_agent_with_custom_llm_and_guardrails():
     assert result2.raw == "Modified by guardrail"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_lite_agent_with_invalid_llm():
     """Test that LiteAgent raises proper error when create_llm returns None."""
     with patch("crewai.lite_agent.create_llm", return_value=None):
@@ -615,7 +609,7 @@ def test_lite_agent_with_invalid_llm():

 @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
 @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_kickoff_with_platform_tools(mock_get):
     """Test that Agent.kickoff() properly integrates platform tools with LiteAgent"""
     mock_response = Mock()
@@ -657,7 +651,7 @@ def test_agent_kickoff_with_platform_tools(mock_get):

 @patch.dict("os.environ", {"EXA_API_KEY": "test_exa_key"})
 @patch("crewai.agent.Agent._get_external_mcp_tools")
-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
     """Test that Agent.kickoff() properly integrates MCP tools with LiteAgent"""
     # Setup mock MCP tools - create a proper BaseTool instance

@@ -1,126 +0,0 @@
interactions:
- request:
    body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
      personal goal is: Test goal\nTo give my best complete final answer to the task
      respond using the exact following format:\n\nThought: I now can give a great
      answer\nFinal Answer: Your final answer must be the great and the most complete
      as possible, it must be outcome described.\n\nI MUST use these formats, my job
      depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
      the world\n\nThis is the expected criteria for your final answer: hello world\nyou
      MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
      This is VERY important to you, use the tools available and give your best Final
      Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
      ["\nObservation:"]}'
    headers:
      accept:
      - application/json
      accept-encoding:
      - gzip, deflate, zstd
      connection:
      - keep-alive
      content-length:
      - '825'
      content-type:
      - application/json
      host:
      - api.openai.com
      user-agent:
      - OpenAI/Python 1.93.0
      x-stainless-arch:
      - arm64
      x-stainless-async:
      - 'false'
      x-stainless-lang:
      - python
      x-stainless-os:
      - MacOS
      x-stainless-package-version:
      - 1.93.0
      x-stainless-raw-response:
      - 'true'
      x-stainless-read-timeout:
      - '600.0'
      x-stainless-retry-count:
      - '0'
      x-stainless-runtime:
      - CPython
      x-stainless-runtime-version:
      - 3.12.9
    method: POST
    uri: https://api.openai.com/v1/chat/completions
  response:
    body:
      string: !!binary |
H4sIAAAAAAAAAwAAAP//jFJdi9swEHz3r1j0HBc7ZydXvx1HC0evhT6UUtrDKNLa1lXWqpJ8aTjy
34vsXOz0A/pi8M7OaGZ3nxMApiSrgImOB9Fbnd4Why/v79HeePeD37/Zfdx8evf5+rY4fNjdPbJV
ZNDuEUV4Yb0S1FuNQZGZYOGQB4yq+bYsr7J1nhcj0JNEHWmtDWlBaa+MStfZukizbZpfn9gdKYGe
VfA1AQB4Hr/Rp5H4k1WQrV4qPXrPW2TVuQmAOdKxwrj3ygduAlvNoCAT0IzW78DQHgQ30KonBA5t
tA3c+D06gG/mrTJcw834X0GHWhPsyWm5FHTYDJ7HUGbQegFwYyjwOJQxysMJOZ7Na2qto53/jcoa
ZZTvaofck4lGfSDLRvSYADyMQxoucjPrqLehDvQdx+fycjvpsXk3C/TqBAYKXC/q29NoL/VqiYEr
7RdjZoKLDuVMnXfCB6loASSL1H+6+Zv2lFyZ9n/kZ0AItAFlbR1KJS4Tz20O4+n+q+085dEw8+ie
lMA6KHRxExIbPujpoJg/+IB93SjTorNOTVfV2LrcZLzZYFm+Zskx+QUAAP//AwB1vYZ+YwMAAA==
    headers:
      CF-RAY:
      - 96fc9f29dea3cf1f-SJC
      Connection:
      - keep-alive
      Content-Encoding:
      - gzip
      Content-Type:
      - application/json
      Date:
      - Fri, 15 Aug 2025 23:55:15 GMT
      Server:
      - cloudflare
      Set-Cookie:
      - __cf_bm=oA9oTa3cE0ZaEUDRf0hCpnarSAQKzrVUhl6qDS4j09w-1755302115-1.0.1.1-gUUDl4ZqvBQkg7244DTwOmSiDUT2z_AiQu0P1xUaABjaufSpZuIlI5G0H7OSnW.ldypvpxjj45NGWesJ62M_2U7r20tHz_gMmDFw6D5ZiNc;
        path=/; expires=Sat, 16-Aug-25 00:25:15 GMT; domain=.api.openai.com; HttpOnly;
        Secure; SameSite=None
      - _cfuvid=ICenEGMmOE5jaOjwD30bAOwrF8.XRbSIKTBl1EyWs0o-1755302115700-0.0.1.1-604800000;
        path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
      Strict-Transport-Security:
      - max-age=31536000; includeSubDomains; preload
      Transfer-Encoding:
      - chunked
      X-Content-Type-Options:
      - nosniff
      access-control-expose-headers:
      - X-Request-ID
      alt-svc:
      - h3=":443"; ma=86400
      cf-cache-status:
      - DYNAMIC
      openai-organization:
      - crewai-iuxna1
      openai-processing-ms:
      - '735'
      openai-project:
      - proj_xitITlrFeen7zjNSzML82h9x
      openai-version:
      - '2020-10-01'
      x-envoy-upstream-service-time:
      - '753'
      x-ratelimit-limit-project-tokens:
      - '150000000'
      x-ratelimit-limit-requests:
      - '30000'
      x-ratelimit-limit-tokens:
      - '150000000'
      x-ratelimit-remaining-project-tokens:
      - '149999830'
      x-ratelimit-remaining-requests:
      - '29999'
      x-ratelimit-remaining-tokens:
      - '149999827'
      x-ratelimit-reset-project-tokens:
      - 0s
      x-ratelimit-reset-requests:
      - 2ms
      x-ratelimit-reset-tokens:
      - 0s
      x-request-id:
      - req_212fde9d945a462ba0d89ea856131dce
    status:
      code: 200
      message: OK
version: 1
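The !!binary scalar in the deleted cassette above is the gzip-compressed HTTP body, base64-encoded by YAML. A small sketch of how such a cassette can be inspected offline (the helper name is ours; vcrpy itself decodes transparently on replay):

# --- sketch: decoding recorded response bodies from a vcrpy cassette ---
import gzip
import json
import yaml

def response_bodies(cassette_path: str) -> list[dict]:
    # yaml.safe_load turns !!binary scalars into bytes; the recorded OpenAI
    # bodies are additionally gzipped (see Content-Encoding: gzip above).
    with open(cassette_path) as fh:
        cassette = yaml.safe_load(fh)
    bodies = []
    for interaction in cassette["interactions"]:
        raw = interaction["response"]["body"]["string"]
        if isinstance(raw, bytes):
            raw = gzip.decompress(raw).decode("utf-8")
        bodies.append(json.loads(raw))
    return bodies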
@@ -0,0 +1,211 @@
interactions:
- request:
    body: '{"trace_id": "4d0d2b51-d83a-4054-b41e-8c2d17baa88f", "execution_type":
      "crew", "user_identifier": null, "execution_context": {"crew_fingerprint": null,
      "crew_name": "crew", "flow_name": null, "crewai_version": "1.6.0", "privacy_level":
      "standard"}, "execution_metadata": {"expected_duration_estimate": 300, "agent_count":
      0, "task_count": 0, "flow_method_count": 0, "execution_started_at": "2025-11-29T02:50:39.376314+00:00"},
      "ephemeral_trace_id": "4d0d2b51-d83a-4054-b41e-8c2d17baa88f"}'
    headers:
      Accept:
      - '*/*'
      Accept-Encoding:
      - gzip, deflate
      Connection:
      - keep-alive
      Content-Length:
      - '488'
      Content-Type:
      - application/json
      User-Agent:
      - CrewAI-CLI/1.6.0
      X-Crewai-Version:
      - 1.6.0
      authorization:
      - AUTHORIZATION-XXX
    method: POST
    uri: https://app.crewai.com/crewai_plus/api/v1/tracing/ephemeral/batches
  response:
    body:
      string: '{"id":"71726285-2e63-4d2a-b4c4-4bbd0ff6a9f1","ephemeral_trace_id":"4d0d2b51-d83a-4054-b41e-8c2d17baa88f","execution_type":"crew","crew_name":"crew","flow_name":null,"status":"running","duration_ms":null,"crewai_version":"1.6.0","total_events":0,"execution_context":{"crew_fingerprint":null,"crew_name":"crew","flow_name":null,"crewai_version":"1.6.0","privacy_level":"standard"},"created_at":"2025-11-29T02:50:39.931Z","updated_at":"2025-11-29T02:50:39.931Z","access_code":"TRACE-bf7f3f49b3","user_identifier":null}'
    headers:
      Connection:
      - keep-alive
      Content-Length:
      - '515'
      Content-Type:
      - application/json; charset=utf-8
      Date:
      - Sat, 29 Nov 2025 02:50:39 GMT
      cache-control:
      - no-store
      content-security-policy:
      - CSP-FILTERED
      etag:
      - ETAG-XXX
      expires:
      - '0'
      permissions-policy:
      - PERMISSIONS-POLICY-XXX
      pragma:
      - no-cache
      referrer-policy:
      - REFERRER-POLICY-XXX
      strict-transport-security:
      - STS-XXX
      vary:
      - Accept
      x-content-type-options:
      - X-CONTENT-TYPE-XXX
      x-frame-options:
      - X-FRAME-OPTIONS-XXX
      x-permitted-cross-domain-policies:
      - X-PERMITTED-XXX
      x-request-id:
      - X-REQUEST-ID-XXX
      x-runtime:
      - X-RUNTIME-XXX
      x-xss-protection:
      - X-XSS-PROTECTION-XXX
    status:
      code: 201
      message: Created
- request:
    body: '{"messages":[{"role":"system","content":"You are test role. test backstory\nYour
      personal goal is: test goal\nTo give my best complete final answer to the task
      respond using the exact following format:\n\nThought: I now can give a great
      answer\nFinal Answer: Your final answer must be the great and the most complete
      as possible, it must be outcome described.\n\nI MUST use these formats, my job
      depends on it!"},{"role":"user","content":"\nCurrent Task: Analyze the data\n\nThis
      is the expected criteria for your final answer: Analysis report\nyou MUST return
      the actual complete content as the final answer, not a summary.\n\nBegin! This
      is VERY important to you, use the tools available and give your best Final Answer,
      your job depends on it!\n\nThought:"}],"model":"gpt-4.1-mini"}'
    headers:
      accept:
      - application/json
      accept-encoding:
      - gzip, deflate
      authorization:
      - AUTHORIZATION-XXX
      connection:
      - keep-alive
      content-length:
      - '785'
      content-type:
      - application/json
      host:
      - api.openai.com
      user-agent:
      - OpenAI/Python 1.109.1
      x-stainless-arch:
      - X-STAINLESS-ARCH-XXX
      x-stainless-async:
      - 'false'
      x-stainless-lang:
      - python
      x-stainless-os:
      - X-STAINLESS-OS-XXX
      x-stainless-package-version:
      - 1.109.1
      x-stainless-read-timeout:
      - X-STAINLESS-READ-TIMEOUT-XXX
      x-stainless-retry-count:
      - '0'
      x-stainless-runtime:
      - CPython
      x-stainless-runtime-version:
      - 3.12.10
    method: POST
    uri: https://api.openai.com/v1/chat/completions
  response:
    body:
      string: !!binary |
H4sIAAAAAAAAAwAAAP//jFbbbhtHDH33VxD71AKyECeynegtdW9uiiZI3AJpXRj0DHeXzSxnO5yV
ohT594Kzq5WcpEBfdBkOycPDy/CfE4CKfbWGyrWYXdeH06v2/EX7li7T5uenL364rH/6NUn+/Ze/
3nP99mW1MI14/xe5vNdautj1gTJHGcUuEWYyq2eXF6snT1cXT54VQRc9BVNr+ny6Wp6ddix8+vjR
4/PTR6vTs9Wk3kZ2pNUa/jgBAPinfBpQ8fS+WsOjxf6kI1VsqFrPlwCqFIOdVKjKmlFytTgIXZRM
UrDftHFo2ryGa5C4BYcCDW8IEBoLAFB0S+lWvmfBAM/LvzXcyq08Fww7ZYXX1MeU7ehsCdeSU/SD
MyJu5aYlyKjvINHfAydSQMhtTOYTcG8g1pBbKn4FPGZcwk2EPsUNe0Ni1CZqSdSQpeJuUVTsMrSo
cE8koDvN1GFmhyHsINGGaUt+AVvOLWC2mDkK5AieMnIAFA+JAm1QHNl5/gRwR5J1abE9XsK35u3l
hpLZvZU3bEoSQXtyXLMb4WxR99g9sBSTfYpdXzCzTgEAqg5dYaQhobRXV8p7SHmPyMCQZqhjmlkz
qgEF0OUBA6gjwcRxMVrpI0tW0MG1gAoydOYBA2wwDKQLcJipiYntd+aOQGn8ExPE3FKCDSbG+0AK
2zgEb867guYej5I2BlMYejIx9CpFR6osza2cjkdXgVBYmjV8JzokluaQPlaoExHUKXbA4qJYxZK4
AqfjYmnGbRmjlGKyrEzWX6YGhT+gJXcNV1NkH0zNrmtOg8uj1+LRaKS6JpdLpe8Jne39xjpgmA3+
WgC4FlPWYrBJ2LdqyWFvJVXvICcSP0p7K7QkCl9xDdj3gZ3R+HXhaLWEuW9uyLXCltnimdQl7guk
Nxkza2anFk5wQyhQjPOOUBbQkefyHT0twPrbY/LgacO4L/FBPKUiAkeSEwbIJJ7E7QpOz9pTUo5S
Ir+xCGZwa7geQ2M3ux76rTmJCXzcSvk9hR03lMYqsnrnjk7Hahqb2axfxZRoiuLg46ol987I3cu0
5d6aOW+tnz3XNSWSfKjFQuL5Er5n8SxNYe4F7eDVRPr6wOMIWrmREoXkQ2YsfJTYYTCQU4/OWK9F
uWmzcSCZUp8ozxyU+jlK9rbFbNo74K4Pu6L/KpZBgwGucFDSNfy4662nlBSijIkJO4u7RpdjMgh1
GKzkjxqjhHqxhKsoLgyWpxLtm6HrMO1KLSAL1BMTI/SulFuhspS5GauZQknb/aAspApl/r/PI+3k
92NmZuA1udh1JP7IUj2kMhbQjVwkYNmQZm7KpYL2cgk/c8cjXQXtc9mZN80Jy0AicXEwVsnPwymY
Cvlp/LnY02hdh7pmx5b/aVxz15v7PUnU59Z4OOrgW3m6hOd9T+Lt+Sw9+NlU/rpUZOnnBeRSV+Ng
Gd0YtH0DYoA45H6YHoFvyFlKzX0im1yfTf+Hk5/14eifn7zpDYhDDpaSEk9HuY0+htjsHswtOsz9
D/MMm+dX2C3hut4/A/uJihvkYJEtygjaTZwpgWbqFbYcAuxKYeAh7NIXJb+m+inawsB34o3y6eR4
qUhUD4q22cgQwpEAReJUErbO/DlJPs4LTIhNn+K9fqJa1Sys7V0i1Ci2rGiOfVWkH08A/iyL0vBg
96lGuu9yfEfF3dn5+WivOixoB+mTp88maY4Zw0FwvlotvmDwbqRKj3atyqFryR9UD4sZDp7jkeDk
KOzP4XzJ9hg6S/N/zB8EzlGfyd/1iTy7hyEfriWyBfa/rs00F8CV2tbj6C4zJUuFpxqHMG6V1bh4
3dUsjc1LHlfLur97dnlxQeerZ/ePq5OPJ/8CAAD//wMAwDPj9GkLAAA=
    headers:
      CF-RAY:
      - CF-RAY-XXX
      Connection:
      - keep-alive
      Content-Encoding:
      - gzip
      Content-Type:
      - application/json
      Date:
      - Sat, 29 Nov 2025 02:50:45 GMT
      Server:
      - cloudflare
      Set-Cookie:
      - SET-COOKIE-XXX
      Strict-Transport-Security:
      - STS-XXX
      Transfer-Encoding:
      - chunked
      X-Content-Type-Options:
      - X-CONTENT-TYPE-XXX
      access-control-expose-headers:
      - ACCESS-CONTROL-XXX
      alt-svc:
      - h3=":443"; ma=86400
      cf-cache-status:
      - DYNAMIC
      openai-organization:
      - OPENAI-ORG-XXX
      openai-processing-ms:
      - '5125'
      openai-project:
      - OPENAI-PROJECT-XXX
      openai-version:
      - '2020-10-01'
      x-envoy-upstream-service-time:
      - '5227'
      x-openai-proxy-wasm:
      - v0.1
      x-ratelimit-limit-project-tokens:
      - '150000000'
      x-ratelimit-limit-requests:
      - X-RATELIMIT-LIMIT-REQUESTS-XXX
      x-ratelimit-limit-tokens:
      - X-RATELIMIT-LIMIT-TOKENS-XXX
      x-ratelimit-remaining-project-tokens:
      - '149999830'
      x-ratelimit-remaining-requests:
      - X-RATELIMIT-REMAINING-REQUESTS-XXX
      x-ratelimit-remaining-tokens:
      - X-RATELIMIT-REMAINING-TOKENS-XXX
      x-ratelimit-reset-project-tokens:
      - 0s
      x-ratelimit-reset-requests:
      - X-RATELIMIT-RESET-REQUESTS-XXX
      x-ratelimit-reset-tokens:
      - X-RATELIMIT-RESET-TOKENS-XXX
      x-request-id:
      - X-REQUEST-ID-XXX
    status:
      code: 200
      message: OK
version: 1
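Unlike the deleted cassette, this new one stores scrubbed placeholders (AUTHORIZATION-XXX, SET-COOKIE-XXX, CF-RAY-XXX), which fits the decorator change earlier in the diff: sanitization moves from per-test filter_headers into shared configuration. A hedged sketch of how that could be wired, assuming pytest-recording with vcrpy; the repository's actual conftest is not shown in this diff:

# --- sketch: shared vcr_config producing placeholder headers (assumed setup) ---
import pytest

@pytest.fixture(scope="module")
def vcr_config():
    def scrub_response(response):
        # Replace volatile or sensitive response headers with stable markers.
        for header, marker in [("Set-Cookie", "SET-COOKIE-XXX"), ("CF-RAY", "CF-RAY-XXX")]:
            if header in response["headers"]:
                response["headers"][header] = [marker]
        return response

    return {
        "filter_headers": [("authorization", "AUTHORIZATION-XXX")],
        "before_record_response": scrub_response,
    }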
Some files were not shown because too many files have changed in this diff