Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-04 05:38:33 +00:00)

Compare commits: 13 commits, lorenze/la ... devin/1764

Commits: 347381be57, 20704742e2, 59180e9c9f, 3ce019b07b, 2355ec0733, c925d2d519, bc4e6a3127, 37526c693b, c59173a762, 4d8eec96e8, 2025a26fc3, bed9a3847a, 5239dc9859
.env.test (new file, 161 lines)
@@ -0,0 +1,161 @@
|
||||
# =============================================================================
|
||||
# Test Environment Variables
|
||||
# =============================================================================
|
||||
# This file contains all environment variables needed to run tests locally
|
||||
# in a way that mimics the GitHub Actions CI environment.
|
||||
|
||||
# =============================================================================
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# LLM Provider API Keys
|
||||
# -----------------------------------------------------------------------------
|
||||
OPENAI_API_KEY=fake-api-key
|
||||
ANTHROPIC_API_KEY=fake-anthropic-key
|
||||
GEMINI_API_KEY=fake-gemini-key
|
||||
AZURE_API_KEY=fake-azure-key
|
||||
OPENROUTER_API_KEY=fake-openrouter-key
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# AWS Credentials
|
||||
# -----------------------------------------------------------------------------
|
||||
AWS_ACCESS_KEY_ID=fake-aws-access-key
|
||||
AWS_SECRET_ACCESS_KEY=fake-aws-secret-key
|
||||
AWS_DEFAULT_REGION=us-east-1
|
||||
AWS_REGION_NAME=us-east-1
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Azure OpenAI Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
AZURE_ENDPOINT=https://fake-azure-endpoint.openai.azure.com
|
||||
AZURE_OPENAI_ENDPOINT=https://fake-azure-endpoint.openai.azure.com
|
||||
AZURE_OPENAI_API_KEY=fake-azure-openai-key
|
||||
AZURE_API_VERSION=2024-02-15-preview
|
||||
OPENAI_API_VERSION=2024-02-15-preview
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Google Cloud Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
#GOOGLE_CLOUD_PROJECT=fake-gcp-project
|
||||
#GOOGLE_CLOUD_LOCATION=us-central1
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OpenAI Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
OPENAI_API_BASE=https://api.openai.com/v1
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Search & Scraping Tool API Keys
|
||||
# -----------------------------------------------------------------------------
|
||||
SERPER_API_KEY=fake-serper-key
|
||||
EXA_API_KEY=fake-exa-key
|
||||
BRAVE_API_KEY=fake-brave-key
|
||||
FIRECRAWL_API_KEY=fake-firecrawl-key
|
||||
TAVILY_API_KEY=fake-tavily-key
|
||||
SERPAPI_API_KEY=fake-serpapi-key
|
||||
SERPLY_API_KEY=fake-serply-key
|
||||
LINKUP_API_KEY=fake-linkup-key
|
||||
PARALLEL_API_KEY=fake-parallel-key
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Exa Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
EXA_BASE_URL=https://api.exa.ai
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Web Scraping & Automation
|
||||
# -----------------------------------------------------------------------------
|
||||
BRIGHT_DATA_API_KEY=fake-brightdata-key
|
||||
BRIGHT_DATA_ZONE=fake-zone
|
||||
BRIGHTDATA_API_URL=https://api.brightdata.com
|
||||
BRIGHTDATA_DEFAULT_TIMEOUT=600
|
||||
BRIGHTDATA_DEFAULT_POLLING_INTERVAL=1
|
||||
|
||||
OXYLABS_USERNAME=fake-oxylabs-user
|
||||
OXYLABS_PASSWORD=fake-oxylabs-pass
|
||||
|
||||
SCRAPFLY_API_KEY=fake-scrapfly-key
|
||||
SCRAPEGRAPH_API_KEY=fake-scrapegraph-key
|
||||
|
||||
BROWSERBASE_API_KEY=fake-browserbase-key
|
||||
BROWSERBASE_PROJECT_ID=fake-browserbase-project
|
||||
|
||||
HYPERBROWSER_API_KEY=fake-hyperbrowser-key
|
||||
MULTION_API_KEY=fake-multion-key
|
||||
APIFY_API_TOKEN=fake-apify-token
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Database & Vector Store Credentials
|
||||
# -----------------------------------------------------------------------------
|
||||
SINGLESTOREDB_URL=mysql://fake:fake@localhost:3306/fake
|
||||
SINGLESTOREDB_HOST=localhost
|
||||
SINGLESTOREDB_PORT=3306
|
||||
SINGLESTOREDB_USER=fake-user
|
||||
SINGLESTOREDB_PASSWORD=fake-password
|
||||
SINGLESTOREDB_DATABASE=fake-database
|
||||
SINGLESTOREDB_CONNECT_TIMEOUT=30
|
||||
|
||||
SNOWFLAKE_USER=fake-snowflake-user
|
||||
SNOWFLAKE_PASSWORD=fake-snowflake-password
|
||||
SNOWFLAKE_ACCOUNT=fake-snowflake-account
|
||||
SNOWFLAKE_WAREHOUSE=fake-snowflake-warehouse
|
||||
SNOWFLAKE_DATABASE=fake-snowflake-database
|
||||
SNOWFLAKE_SCHEMA=fake-snowflake-schema
|
||||
|
||||
WEAVIATE_URL=http://localhost:8080
|
||||
WEAVIATE_API_KEY=fake-weaviate-key
|
||||
|
||||
EMBEDCHAIN_DB_URI=sqlite:///test.db
|
||||
|
||||
# Databricks Credentials
|
||||
DATABRICKS_HOST=https://fake-databricks.cloud.databricks.com
|
||||
DATABRICKS_TOKEN=fake-databricks-token
|
||||
DATABRICKS_CONFIG_PROFILE=fake-profile
|
||||
|
||||
# MongoDB Credentials
|
||||
MONGODB_URI=mongodb://fake:fake@localhost:27017/fake
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# CrewAI Platform & Enterprise
|
||||
# -----------------------------------------------------------------------------
|
||||
# Setting CREWAI_PLATFORM_INTEGRATION_TOKEN causes these tests to fail:
|
||||
#=========================== short test summary info ============================
|
||||
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_platform_context_manager_basic_usage - AssertionError: assert 'fake-platform-token' is None
|
||||
# + where 'fake-platform-token' = get_platform_integration_token()
|
||||
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_context_var_isolation_between_tests - AssertionError: assert 'fake-platform-token' is None
|
||||
# + where 'fake-platform-token' = get_platform_integration_token()
|
||||
#FAILED tests/test_context.py::TestPlatformIntegrationToken::test_multiple_sequential_context_managers - AssertionError: assert 'fake-platform-token' is None
|
||||
# + where 'fake-platform-token' = get_platform_integration_token()
|
||||
#CREWAI_PLATFORM_INTEGRATION_TOKEN=fake-platform-token
|
||||
CREWAI_PERSONAL_ACCESS_TOKEN=fake-personal-token
|
||||
CREWAI_PLUS_URL=https://fake.crewai.com
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Other Service API Keys
|
||||
# -----------------------------------------------------------------------------
|
||||
ZAPIER_API_KEY=fake-zapier-key
|
||||
PATRONUS_API_KEY=fake-patronus-key
|
||||
MINDS_API_KEY=fake-minds-key
|
||||
HF_TOKEN=fake-hf-token
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Feature Flags/Testing Modes
|
||||
# -----------------------------------------------------------------------------
|
||||
CREWAI_DISABLE_TELEMETRY=true
|
||||
OTEL_SDK_DISABLED=true
|
||||
CREWAI_TESTING=true
|
||||
CREWAI_TRACING_ENABLED=false
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Testing/CI Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
# VCR recording mode: "none" (default), "new_episodes", "all", "once"
|
||||
PYTEST_VCR_RECORD_MODE=none
|
||||
|
||||
# Set to "true" by GitHub when running in GitHub Actions
|
||||
# GITHUB_ACTIONS=false
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Python Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
PYTHONUNBUFFERED=1
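These values are loaded automatically by the root `conftest.py` shown below via python-dotenv. For ad-hoc scripts that run outside pytest, the same file can be loaded explicitly; a minimal sketch, assuming the script sits next to `.env.test`:

```python
import os
from pathlib import Path

from dotenv import load_dotenv

# Load the fake test credentials exactly as the root conftest.py does.
load_dotenv(Path(__file__).parent / ".env.test", override=True)
print(os.environ["OPENAI_API_KEY"])  # -> fake-api-key
```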
.github/workflows/tests.yml (vendored, 18 lines changed)
@@ -5,18 +5,6 @@ on: [pull_request]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: fake-api-key
|
||||
PYTHONUNBUFFERED: 1
|
||||
BRAVE_API_KEY: fake-brave-key
|
||||
SNOWFLAKE_USER: fake-snowflake-user
|
||||
SNOWFLAKE_PASSWORD: fake-snowflake-password
|
||||
SNOWFLAKE_ACCOUNT: fake-snowflake-account
|
||||
SNOWFLAKE_WAREHOUSE: fake-snowflake-warehouse
|
||||
SNOWFLAKE_DATABASE: fake-snowflake-database
|
||||
SNOWFLAKE_SCHEMA: fake-snowflake-schema
|
||||
EMBEDCHAIN_DB_URI: sqlite:///test.db
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
name: tests (${{ matrix.python-version }})
|
||||
@@ -84,26 +72,20 @@ jobs:
|
||||
# fi
|
||||
|
||||
cd lib/crewai && uv run pytest \
|
||||
--block-network \
|
||||
--timeout=30 \
|
||||
-vv \
|
||||
--splits 8 \
|
||||
--group ${{ matrix.group }} \
|
||||
$DURATIONS_ARG \
|
||||
--durations=10 \
|
||||
-n auto \
|
||||
--maxfail=3
|
||||
|
||||
- name: Run tool tests (group ${{ matrix.group }} of 8)
|
||||
run: |
|
||||
cd lib/crewai-tools && uv run pytest \
|
||||
--block-network \
|
||||
--timeout=30 \
|
||||
-vv \
|
||||
--splits 8 \
|
||||
--group ${{ matrix.group }} \
|
||||
--durations=10 \
|
||||
-n auto \
|
||||
--maxfail=3
|
||||
|
||||
|
||||
|
||||
conftest.py (new file, 193 lines)
@@ -0,0 +1,193 @@
|
||||
"""Pytest configuration for crewAI workspace."""
|
||||
|
||||
from collections.abc import Generator
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import pytest
|
||||
from vcr.request import Request # type: ignore[import-untyped]
|
||||
|
||||
|
||||
env_test_path = Path(__file__).parent / ".env.test"
|
||||
load_dotenv(env_test_path, override=True)
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def cleanup_event_handlers() -> Generator[None, Any, None]:
|
||||
"""Clean up event bus handlers after each test to prevent test pollution."""
|
||||
yield
|
||||
|
||||
try:
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers.clear()
|
||||
crewai_event_bus._async_handlers.clear()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def setup_test_environment() -> Generator[None, Any, None]:
|
||||
"""Setup test environment for crewAI workspace."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
storage_dir = Path(temp_dir) / "crewai_test_storage"
|
||||
storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not storage_dir.exists() or not storage_dir.is_dir():
|
||||
raise RuntimeError(
|
||||
f"Failed to create test storage directory: {storage_dir}"
|
||||
)
|
||||
|
||||
try:
|
||||
test_file = storage_dir / ".permissions_test"
|
||||
test_file.touch()
|
||||
test_file.unlink()
|
||||
except (OSError, IOError) as e:
|
||||
raise RuntimeError(
|
||||
f"Test storage directory {storage_dir} is not writable: {e}"
|
||||
) from e
|
||||
|
||||
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
|
||||
os.environ["CREWAI_TESTING"] = "true"
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.environ.pop("CREWAI_TESTING", "true")
|
||||
os.environ.pop("CREWAI_STORAGE_DIR", None)
|
||||
os.environ.pop("CREWAI_DISABLE_TELEMETRY", "true")
|
||||
os.environ.pop("OTEL_SDK_DISABLED", "true")
|
||||
os.environ.pop("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
os.environ.pop("OPENAI_API_BASE", "https://api.openai.com/v1")
|
||||
|
||||
|
||||
HEADERS_TO_FILTER = {
|
||||
"authorization": "AUTHORIZATION-XXX",
|
||||
"content-security-policy": "CSP-FILTERED",
|
||||
"cookie": "COOKIE-XXX",
|
||||
"set-cookie": "SET-COOKIE-XXX",
|
||||
"permissions-policy": "PERMISSIONS-POLICY-XXX",
|
||||
"referrer-policy": "REFERRER-POLICY-XXX",
|
||||
"strict-transport-security": "STS-XXX",
|
||||
"x-content-type-options": "X-CONTENT-TYPE-XXX",
|
||||
"x-frame-options": "X-FRAME-OPTIONS-XXX",
|
||||
"x-permitted-cross-domain-policies": "X-PERMITTED-XXX",
|
||||
"x-request-id": "X-REQUEST-ID-XXX",
|
||||
"x-runtime": "X-RUNTIME-XXX",
|
||||
"x-xss-protection": "X-XSS-PROTECTION-XXX",
|
||||
"x-stainless-arch": "X-STAINLESS-ARCH-XXX",
|
||||
"x-stainless-os": "X-STAINLESS-OS-XXX",
|
||||
"x-stainless-read-timeout": "X-STAINLESS-READ-TIMEOUT-XXX",
|
||||
"cf-ray": "CF-RAY-XXX",
|
||||
"etag": "ETAG-XXX",
|
||||
"Strict-Transport-Security": "STS-XXX",
|
||||
"access-control-expose-headers": "ACCESS-CONTROL-XXX",
|
||||
"openai-organization": "OPENAI-ORG-XXX",
|
||||
"openai-project": "OPENAI-PROJECT-XXX",
|
||||
"x-ratelimit-limit-requests": "X-RATELIMIT-LIMIT-REQUESTS-XXX",
|
||||
"x-ratelimit-limit-tokens": "X-RATELIMIT-LIMIT-TOKENS-XXX",
|
||||
"x-ratelimit-remaining-requests": "X-RATELIMIT-REMAINING-REQUESTS-XXX",
|
||||
"x-ratelimit-remaining-tokens": "X-RATELIMIT-REMAINING-TOKENS-XXX",
|
||||
"x-ratelimit-reset-requests": "X-RATELIMIT-RESET-REQUESTS-XXX",
|
||||
"x-ratelimit-reset-tokens": "X-RATELIMIT-RESET-TOKENS-XXX",
|
||||
"x-goog-api-key": "X-GOOG-API-KEY-XXX",
|
||||
"api-key": "X-API-KEY-XXX",
|
||||
"User-Agent": "X-USER-AGENT-XXX",
|
||||
"apim-request-id:": "X-API-CLIENT-REQUEST-ID-XXX",
|
||||
"azureml-model-session": "AZUREML-MODEL-SESSION-XXX",
|
||||
"x-ms-client-request-id": "X-MS-CLIENT-REQUEST-ID-XXX",
|
||||
"x-ms-region": "X-MS-REGION-XXX",
|
||||
"apim-request-id": "APIM-REQUEST-ID-XXX",
|
||||
"x-api-key": "X-API-KEY-XXX",
|
||||
"anthropic-organization-id": "ANTHROPIC-ORGANIZATION-ID-XXX",
|
||||
"request-id": "REQUEST-ID-XXX",
|
||||
"anthropic-ratelimit-input-tokens-limit": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-LIMIT-XXX",
|
||||
"anthropic-ratelimit-input-tokens-remaining": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-REMAINING-XXX",
|
||||
"anthropic-ratelimit-input-tokens-reset": "ANTHROPIC-RATELIMIT-INPUT-TOKENS-RESET-XXX",
|
||||
"anthropic-ratelimit-output-tokens-limit": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-LIMIT-XXX",
|
||||
"anthropic-ratelimit-output-tokens-remaining": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-REMAINING-XXX",
|
||||
"anthropic-ratelimit-output-tokens-reset": "ANTHROPIC-RATELIMIT-OUTPUT-TOKENS-RESET-XXX",
|
||||
"anthropic-ratelimit-tokens-limit": "ANTHROPIC-RATELIMIT-TOKENS-LIMIT-XXX",
|
||||
"anthropic-ratelimit-tokens-remaining": "ANTHROPIC-RATELIMIT-TOKENS-REMAINING-XXX",
|
||||
"anthropic-ratelimit-tokens-reset": "ANTHROPIC-RATELIMIT-TOKENS-RESET-XXX",
|
||||
"x-amz-date": "X-AMZ-DATE-XXX",
|
||||
"amz-sdk-invocation-id": "AMZ-SDK-INVOCATION-ID-XXX",
|
||||
"accept-encoding": "ACCEPT-ENCODING-XXX",
|
||||
"x-amzn-requestid": "X-AMZN-REQUESTID-XXX",
|
||||
"x-amzn-RequestId": "X-AMZN-REQUESTID-XXX",
|
||||
}
|
||||
|
||||
|
||||
def _filter_request_headers(request: Request) -> Request: # type: ignore[no-any-unimported]
|
||||
"""Filter sensitive headers from request before recording."""
|
||||
for header_name, replacement in HEADERS_TO_FILTER.items():
|
||||
for variant in [header_name, header_name.upper(), header_name.title()]:
|
||||
if variant in request.headers:
|
||||
request.headers[variant] = [replacement]
|
||||
|
||||
request.method = request.method.upper()
|
||||
return request
|
||||
|
||||
|
||||
def _filter_response_headers(response: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Filter sensitive headers from response before recording."""
|
||||
for header_name, replacement in HEADERS_TO_FILTER.items():
|
||||
for variant in [header_name, header_name.upper(), header_name.title()]:
|
||||
if variant in response["headers"]:
|
||||
response["headers"][variant] = [replacement]
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_cassette_dir(request: Any) -> str:
|
||||
"""Generate cassette directory path based on test module location.
|
||||
|
||||
Organizes cassettes to mirror test directory structure within each package:
|
||||
lib/crewai/tests/llms/google/test_google.py -> lib/crewai/tests/cassettes/llms/google/
|
||||
lib/crewai-tools/tests/tools/test_search.py -> lib/crewai-tools/tests/cassettes/tools/
|
||||
"""
|
||||
test_file = Path(request.fspath)
|
||||
|
||||
for parent in test_file.parents:
|
||||
if parent.name in ("crewai", "crewai-tools") and parent.parent.name == "lib":
|
||||
package_root = parent
|
||||
break
|
||||
else:
|
||||
package_root = test_file.parent
|
||||
|
||||
tests_root = package_root / "tests"
|
||||
test_dir = test_file.parent
|
||||
|
||||
if test_dir != tests_root:
|
||||
relative_path = test_dir.relative_to(tests_root)
|
||||
cassette_dir = tests_root / "cassettes" / relative_path
|
||||
else:
|
||||
cassette_dir = tests_root / "cassettes"
|
||||
|
||||
cassette_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return str(cassette_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config(vcr_cassette_dir: str) -> dict[str, Any]:
|
||||
"""Configure VCR with organized cassette storage."""
|
||||
config = {
|
||||
"cassette_library_dir": vcr_cassette_dir,
|
||||
"record_mode": os.getenv("PYTEST_VCR_RECORD_MODE", "once"),
|
||||
"filter_headers": [(k, v) for k, v in HEADERS_TO_FILTER.items()],
|
||||
"before_record_request": _filter_request_headers,
|
||||
"before_record_response": _filter_response_headers,
|
||||
"filter_query_parameters": ["key"],
|
||||
"match_on": ["method", "scheme", "host", "port", "path"],
|
||||
}
|
||||
|
||||
if os.getenv("GITHUB_ACTIONS") == "true":
|
||||
config["record_mode"] = "none"
|
||||
|
||||
return config
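Tests opt into cassette playback with the `vcr` marker (as the tool tests further down do); a minimal sketch of how a test consumes this configuration, with the body left illustrative:

```python
import pytest


@pytest.mark.vcr()  # picks up vcr_config and vcr_cassette_dir from this conftest
def test_example_recorded_call():
    # Hypothetical test body: any HTTP traffic made here is matched against
    # the cassette directory computed by vcr_cassette_dir for this module.
    ...
```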
@@ -1089,6 +1089,50 @@ CrewAI supports streaming responses from LLMs, allowing your application to rece
</Tab>
</Tabs>

## Async LLM Calls

CrewAI supports asynchronous LLM calls for improved performance and concurrency in your AI workflows. Async calls allow you to run multiple LLM requests concurrently without blocking, making them ideal for high-throughput applications and parallel agent operations.

<Tabs>
<Tab title="Basic Usage">
Use the `acall` method for asynchronous LLM requests:

```python
import asyncio
from crewai import LLM

async def main():
    llm = LLM(model="openai/gpt-4o")

    # Single async call
    response = await llm.acall("What is the capital of France?")
    print(response)

asyncio.run(main())
```

The `acall` method supports all the same parameters as the synchronous `call` method, including messages, tools, and callbacks.
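Because `acall` is a coroutine, several requests can also be issued concurrently with `asyncio.gather`. A minimal sketch (assuming a single `LLM` instance can serve concurrent requests; the prompts are illustrative):

```python
import asyncio
from crewai import LLM

async def main():
    llm = LLM(model="openai/gpt-4o")

    # Run three independent requests concurrently and wait for all of them.
    answers = await asyncio.gather(
        llm.acall("What is the capital of France?"),
        llm.acall("What is the capital of Japan?"),
        llm.acall("What is the capital of Brazil?"),
    )
    for answer in answers:
        print(answer)

asyncio.run(main())
```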
</Tab>

<Tab title="With Streaming">
Combine async calls with streaming for real-time concurrent responses:

```python
import asyncio
from crewai import LLM

async def stream_async():
    llm = LLM(model="openai/gpt-4o", stream=True)

    response = await llm.acall("Write a short story about AI")

    print(response)

asyncio.run(stream_async())
```
</Tab>
</Tabs>

## Structured LLM Calls

CrewAI supports structured responses from LLM calls by allowing you to define a `response_format` using a Pydantic model. This enables the framework to automatically parse and validate the output, making it easier to integrate the response into your application without manual post-processing.
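A minimal sketch, assuming `response_format` is passed when constructing the `LLM` (the model fields are illustrative):

```python
from pydantic import BaseModel
from crewai import LLM

class CityInfo(BaseModel):
    name: str
    country: str
    population: int

# Assumption: response_format is supplied at construction time.
llm = LLM(model="openai/gpt-4o", response_format=CityInfo)
result = llm.call("Give me basic facts about Tokyo")
print(result)
```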
@@ -218,7 +218,7 @@ Update the root `README.md` only if the tool introduces a new category or notabl
|
||||
|
||||
## Discovery and specs
|
||||
|
||||
Our internal tooling discovers classes whose names end with `Tool`. Keep your class exported from the module path under `crewai_tools/tools/...` to be picked up by scripts like `generate_tool_specs.py`.
|
||||
Our internal tooling discovers classes whose names end with `Tool`. Keep your class exported from the module path under `crewai_tools/tools/...` to be picked up by scripts like `crewai_tools.generate_tool_specs.py`.
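A minimal sketch of a class that satisfies this convention (the tool name and module path are illustrative):

```python
from crewai.tools.base_tool import BaseTool


class WordCountTool(BaseTool):
    """Illustrative tool, exported from crewai_tools/tools/word_count_tool/word_count_tool.py."""

    name: str = "Word count"
    description: str = "Counts the words in a piece of text."

    def _run(self, text: str) -> int:
        # The spec extractor only needs the class to be importable and end in "Tool".
        return len(text.split())
```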
---
|
||||
|
||||
|
||||
@@ -8,17 +8,17 @@ authors = [
|
||||
]
|
||||
requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
"lancedb>=0.5.4",
|
||||
"pytube>=15.0.0",
|
||||
"requests>=2.32.5",
|
||||
"docker>=7.1.0",
|
||||
"crewai==1.6.0",
|
||||
"lancedb>=0.5.4",
|
||||
"tiktoken>=0.8.0",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
"pypdf>=5.9.0",
|
||||
"python-docx>=1.2.0",
|
||||
"youtube-transcript-api>=1.2.2",
|
||||
"lancedb~=0.5.4",
|
||||
"pytube~=15.0.0",
|
||||
"requests~=2.32.5",
|
||||
"docker~=7.1.0",
|
||||
"crewai==1.6.1",
|
||||
"lancedb~=0.5.4",
|
||||
"tiktoken~=0.8.0",
|
||||
"beautifulsoup4~=4.13.4",
|
||||
"python-docx~=1.2.0",
|
||||
"youtube-transcript-api~=1.2.2",
|
||||
"pymupdf~=1.26.6",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -291,4 +291,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.6.0"
|
||||
__version__ = "1.6.1"
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
import uuid
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
@@ -19,15 +18,13 @@ from typing_extensions import TypeIs, Unpack
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
from crewai_tools.tools.rag.types import AddDocumentParams, ContentItem
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
|
||||
"""Check if config is a QdrantConfig using safe duck typing.
|
||||
|
||||
@@ -46,19 +43,6 @@ def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
|
||||
return False
|
||||
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
|
||||
data_type: DataType
|
||||
metadata: dict[str, Any]
|
||||
website: str
|
||||
url: str
|
||||
file_path: str | Path
|
||||
github_url: str
|
||||
youtube_url: str
|
||||
directory_path: str | Path
|
||||
|
||||
|
||||
class CrewAIRagAdapter(Adapter):
|
||||
"""Adapter that uses CrewAI's native RAG system.
|
||||
|
||||
@@ -131,13 +115,26 @@ class CrewAIRagAdapter(Adapter):
|
||||
def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
|
||||
"""Add content to the knowledge base.
|
||||
|
||||
This method handles various input types and converts them to documents
|
||||
for the vector database. It supports the data_type parameter for
|
||||
compatibility with existing tools.
|
||||
|
||||
Args:
|
||||
*args: Content items to add (strings, paths, or document dicts)
|
||||
**kwargs: Additional parameters including data_type, metadata, etc.
|
||||
**kwargs: Additional parameters including:
|
||||
- data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
|
||||
- path: Path to file or directory (alternative to positional arg)
|
||||
- file_path: Alias for path
|
||||
- metadata: Additional metadata to attach to documents
|
||||
- url: URL to fetch content from
|
||||
- website: Website URL to scrape
|
||||
- github_url: GitHub repository URL
|
||||
- youtube_url: YouTube video URL
|
||||
- directory_path: Path to directory
|
||||
|
||||
Examples:
|
||||
rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
|
||||
|
||||
rag_tool.add(path="path/to/document.pdf", data_type="file")
|
||||
rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
|
||||
|
||||
rag_tool.add("path/to/document.pdf") # auto-detects PDF
|
||||
"""
|
||||
import os
|
||||
|
||||
@@ -146,10 +143,54 @@ class CrewAIRagAdapter(Adapter):
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
documents: list[BaseRecord] = []
|
||||
data_type: DataType | None = kwargs.get("data_type")
|
||||
raw_data_type = kwargs.get("data_type")
|
||||
base_metadata: dict[str, Any] = kwargs.get("metadata", {})
|
||||
|
||||
for arg in args:
|
||||
data_type: DataType | None = None
|
||||
if raw_data_type is not None:
|
||||
if isinstance(raw_data_type, DataType):
|
||||
if raw_data_type != DataType.FILE:
|
||||
data_type = raw_data_type
|
||||
elif isinstance(raw_data_type, str):
|
||||
if raw_data_type != "file":
|
||||
try:
|
||||
data_type = DataType(raw_data_type)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid data_type: '{raw_data_type}'. "
|
||||
f"Valid values are: 'file' (auto-detect), or one of: "
|
||||
f"{', '.join(dt.value for dt in DataType)}"
|
||||
) from None
|
||||
|
||||
content_items: list[ContentItem] = list(args)
|
||||
|
||||
path_value = kwargs.get("path") or kwargs.get("file_path")
|
||||
if path_value is not None:
|
||||
content_items.append(path_value)
|
||||
|
||||
if url := kwargs.get("url"):
|
||||
content_items.append(url)
|
||||
if website := kwargs.get("website"):
|
||||
content_items.append(website)
|
||||
if github_url := kwargs.get("github_url"):
|
||||
content_items.append(github_url)
|
||||
if youtube_url := kwargs.get("youtube_url"):
|
||||
content_items.append(youtube_url)
|
||||
if directory_path := kwargs.get("directory_path"):
|
||||
content_items.append(directory_path)
|
||||
|
||||
file_extensions = {
|
||||
".pdf",
|
||||
".txt",
|
||||
".csv",
|
||||
".json",
|
||||
".xml",
|
||||
".docx",
|
||||
".mdx",
|
||||
".md",
|
||||
}
|
||||
|
||||
for arg in content_items:
|
||||
source_ref: str
|
||||
if isinstance(arg, dict):
|
||||
source_ref = str(arg.get("source", arg.get("content", "")))
|
||||
@@ -157,6 +198,14 @@ class CrewAIRagAdapter(Adapter):
|
||||
source_ref = str(arg)
|
||||
|
||||
if not data_type:
|
||||
ext = os.path.splitext(source_ref)[1].lower()
|
||||
is_url = source_ref.startswith(("http://", "https://", "file://"))
|
||||
if (
|
||||
ext in file_extensions
|
||||
and not is_url
|
||||
and not os.path.isfile(source_ref)
|
||||
):
|
||||
raise FileNotFoundError(f"File does not exist: {source_ref}")
|
||||
data_type = DataTypes.from_content(source_ref)
|
||||
|
||||
if data_type == DataType.DIRECTORY:
|
||||
|
||||
@@ -4,17 +4,20 @@ from collections.abc import Mapping
|
||||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools.base_tool import BaseTool, EnvVar
|
||||
from crewai_tools import tools
|
||||
from pydantic import BaseModel
|
||||
from pydantic.json_schema import GenerateJsonSchema
|
||||
from pydantic_core import PydanticOmit
|
||||
|
||||
from crewai_tools import tools
|
||||
|
||||
|
||||
class SchemaGenerator(GenerateJsonSchema):
|
||||
def handle_invalid_for_json_schema(self, schema, error_info):
|
||||
def handle_invalid_for_json_schema(
|
||||
self, schema: Any, error_info: Any
|
||||
) -> dict[str, Any]:
|
||||
raise PydanticOmit
|
||||
|
||||
|
||||
@@ -73,7 +76,7 @@ class ToolSpecExtractor:
|
||||
|
||||
@staticmethod
|
||||
def _extract_field_default(
|
||||
field: dict | None, fallback: str | list[Any] = ""
|
||||
field: dict[str, Any] | None, fallback: str | list[Any] = ""
|
||||
) -> str | list[Any] | int:
|
||||
if not field:
|
||||
return fallback
|
||||
@@ -83,7 +86,7 @@ class ToolSpecExtractor:
|
||||
return default if isinstance(default, (list, str, int)) else fallback
|
||||
|
||||
@staticmethod
|
||||
def _extract_params(args_schema_field: dict | None) -> dict[str, Any]:
|
||||
def _extract_params(args_schema_field: dict[str, Any] | None) -> dict[str, Any]:
|
||||
if not args_schema_field:
|
||||
return {}
|
||||
|
||||
@@ -94,15 +97,15 @@ class ToolSpecExtractor:
|
||||
):
|
||||
return {}
|
||||
|
||||
# Cast to type[BaseModel] after runtime check
|
||||
schema_class = cast(type[BaseModel], args_schema_class)
|
||||
try:
|
||||
return schema_class.model_json_schema(schema_generator=SchemaGenerator)
|
||||
return args_schema_class.model_json_schema(schema_generator=SchemaGenerator)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _extract_env_vars(env_vars_field: dict | None) -> list[dict[str, Any]]:
|
||||
def _extract_env_vars(
|
||||
env_vars_field: dict[str, Any] | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not env_vars_field:
|
||||
return []
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from enum import Enum
|
||||
from importlib import import_module
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
@@ -8,6 +10,7 @@ from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
|
||||
class DataType(str, Enum):
|
||||
FILE = "file"
|
||||
PDF_FILE = "pdf_file"
|
||||
TEXT_FILE = "text_file"
|
||||
CSV = "csv"
|
||||
@@ -15,22 +18,14 @@ class DataType(str, Enum):
|
||||
XML = "xml"
|
||||
DOCX = "docx"
|
||||
MDX = "mdx"
|
||||
|
||||
# Database types
|
||||
MYSQL = "mysql"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
# Repository types
|
||||
GITHUB = "github"
|
||||
DIRECTORY = "directory"
|
||||
|
||||
# Web types
|
||||
WEBSITE = "website"
|
||||
DOCS_SITE = "docs_site"
|
||||
YOUTUBE_VIDEO = "youtube_video"
|
||||
YOUTUBE_CHANNEL = "youtube_channel"
|
||||
|
||||
# Raw types
|
||||
TEXT = "text"
|
||||
|
||||
def get_chunker(self) -> BaseChunker:
|
||||
@@ -63,13 +58,11 @@ class DataType(str, Enum):
|
||||
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)()
|
||||
return cast(BaseChunker, getattr(module, class_name)())
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading chunker for {self}: {e}") from e
|
||||
|
||||
def get_loader(self) -> BaseLoader:
|
||||
from importlib import import_module
|
||||
|
||||
loaders = {
|
||||
DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
|
||||
DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
|
||||
@@ -98,7 +91,7 @@ class DataType(str, Enum):
|
||||
module_path = f"crewai_tools.rag.loaders.{module_name}"
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)()
|
||||
return cast(BaseLoader, getattr(module, class_name)())
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading loader for {self}: {e}") from e
|
||||
|
||||
|
||||
@@ -2,70 +2,112 @@
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
import urllib.request
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class PDFLoader(BaseLoader):
|
||||
"""Loader for PDF files."""
|
||||
"""Loader for PDF files and URLs."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
"""Load and extract text from a PDF file.
|
||||
@staticmethod
|
||||
def _is_url(path: str) -> bool:
|
||||
"""Check if the path is a URL."""
|
||||
try:
|
||||
parsed = urlparse(path)
|
||||
return parsed.scheme in ("http", "https")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _download_pdf(url: str) -> bytes:
|
||||
"""Download PDF content from a URL.
|
||||
|
||||
Args:
|
||||
source: The source content containing the PDF file path
|
||||
url: The URL to download from.
|
||||
|
||||
Returns:
|
||||
LoaderResult with extracted text content
|
||||
The PDF content as bytes.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the PDF file doesn't exist
|
||||
ImportError: If required PDF libraries aren't installed
|
||||
ValueError: If the download fails.
|
||||
"""
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=30) as response: # noqa: S310
|
||||
return cast(bytes, response.read())
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to download PDF from {url}: {e!s}") from e
|
||||
|
||||
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
|
||||
"""Load and extract text from a PDF file or URL.
|
||||
|
||||
Args:
|
||||
source: The source content containing the PDF file path or URL.
|
||||
|
||||
Returns:
|
||||
LoaderResult with extracted text content.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the PDF file doesn't exist.
|
||||
ImportError: If required PDF libraries aren't installed.
|
||||
ValueError: If the PDF cannot be read or downloaded.
|
||||
"""
|
||||
try:
|
||||
import pypdf
|
||||
except ImportError:
|
||||
try:
|
||||
import PyPDF2 as pypdf # type: ignore[import-not-found,no-redef] # noqa: N813
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
|
||||
) from e
|
||||
import pymupdf # type: ignore[import-untyped]
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pymupdf. Install with: uv add pymupdf"
|
||||
) from e
|
||||
|
||||
file_path = source.source
|
||||
is_url = self._is_url(file_path)
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
if is_url:
|
||||
source_name = Path(urlparse(file_path).path).name or "downloaded.pdf"
|
||||
else:
|
||||
source_name = Path(file_path).name
|
||||
|
||||
text_content = []
|
||||
text_content: list[str] = []
|
||||
metadata: dict[str, Any] = {
|
||||
"source": str(file_path),
|
||||
"file_name": Path(file_path).name,
|
||||
"source": file_path,
|
||||
"file_name": source_name,
|
||||
"file_type": "pdf",
|
||||
}
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as file:
|
||||
pdf_reader = pypdf.PdfReader(file)
|
||||
metadata["num_pages"] = len(pdf_reader.pages)
|
||||
if is_url:
|
||||
pdf_bytes = self._download_pdf(file_path)
|
||||
doc = pymupdf.open(stream=pdf_bytes, filetype="pdf")
|
||||
else:
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
doc = pymupdf.open(file_path)
|
||||
|
||||
for page_num, page in enumerate(pdf_reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
metadata["num_pages"] = len(doc)
|
||||
|
||||
for page_num, page in enumerate(doc, 1):
|
||||
page_text = page.get_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
|
||||
doc.close()
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e
|
||||
raise ValueError(f"Error reading PDF from {file_path}: {e!s}") from e
|
||||
|
||||
if not text_content:
|
||||
content = f"[PDF file with no extractable text: {Path(file_path).name}]"
|
||||
content = f"[PDF file with no extractable text: {source_name}]"
|
||||
else:
|
||||
content = "\n\n".join(text_content)
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=str(file_path),
|
||||
source=file_path,
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
|
||||
doc_id=self.generate_doc_id(source_ref=file_path, content=content),
|
||||
)
|
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
from typing_extensions import Self, Unpack
|
||||
|
||||
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
|
||||
from crewai_tools.tools.rag.types import (
|
||||
AddDocumentParams,
|
||||
ContentItem,
|
||||
RagToolConfig,
|
||||
VectorDbConfig,
|
||||
)
|
||||
|
||||
|
||||
def _validate_embedding_config(
|
||||
@@ -72,6 +77,8 @@ def _validate_embedding_config(
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
"""Abstract base class for RAG adapters."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
@@ -86,8 +93,8 @@ class Adapter(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
) -> None:
|
||||
"""Add content to the knowledge base."""
|
||||
|
||||
@@ -102,7 +109,11 @@ class RagTool(BaseTool):
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, *args: Any, **kwargs: Any) -> None:
|
||||
def add(
|
||||
self,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
name: str = "Knowledge base"
|
||||
@@ -207,9 +218,34 @@ class RagTool(BaseTool):
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
) -> None:
|
||||
"""Add content to the knowledge base.
|
||||
|
||||
|
||||
Args:
|
||||
*args: Content items to add (strings, paths, or document dicts)
|
||||
data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
|
||||
path: Path to file or directory, alias to positional arg
|
||||
file_path: Alias for path
|
||||
metadata: Additional metadata to attach to documents
|
||||
url: URL to fetch content from
|
||||
website: Website URL to scrape
|
||||
github_url: GitHub repository URL
|
||||
youtube_url: YouTube video URL
|
||||
directory_path: Path to directory
|
||||
|
||||
Examples:
|
||||
rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
|
||||
|
||||
# Keyword argument (documented API)
|
||||
rag_tool.add(path="path/to/document.pdf", data_type="file")
|
||||
rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
|
||||
|
||||
# Auto-detect type from extension
|
||||
rag_tool.add("path/to/document.pdf") # auto-detects PDF
|
||||
"""
|
||||
self.adapter.add(*args, **kwargs)
|
||||
|
||||
def _run(
|
||||
|
||||
@@ -1,10 +1,50 @@
|
||||
"""Type definitions for RAG tool configuration."""
|
||||
|
||||
from typing import Any, Literal
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
DataTypeStr: TypeAlias = Literal[
|
||||
"file",
|
||||
"pdf_file",
|
||||
"text_file",
|
||||
"csv",
|
||||
"json",
|
||||
"xml",
|
||||
"docx",
|
||||
"mdx",
|
||||
"mysql",
|
||||
"postgres",
|
||||
"github",
|
||||
"directory",
|
||||
"website",
|
||||
"docs_site",
|
||||
"youtube_video",
|
||||
"youtube_channel",
|
||||
"text",
|
||||
]
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
|
||||
data_type: DataType | DataTypeStr
|
||||
metadata: dict[str, Any]
|
||||
path: str | Path
|
||||
file_path: str | Path
|
||||
website: str
|
||||
url: str
|
||||
github_url: str
|
||||
youtube_url: str
|
||||
directory_path: str | Path
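A small sketch of how these types constrain callers of `RagTool.add` (the file path and metadata are illustrative):

```python
from crewai_tools.tools.rag.types import AddDocumentParams

params: AddDocumentParams = {
    "data_type": "pdf_file",
    "path": "docs/handbook.pdf",
    "metadata": {"team": "support"},
}
# A type checker flags unknown keys or wrong value types here;
# the dict can then be expanded into rag_tool.add(**params).
```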
|
||||
|
||||
|
||||
class VectorDbConfig(TypedDict):
|
||||
"""Configuration for vector database provider.
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Register custom markers."""
|
||||
config.addinivalue_line("markers", "integration: mark test as an integration test")
|
||||
config.addinivalue_line("markers", "asyncio: mark test as an async test")
|
||||
|
||||
# Set the asyncio loop scope through ini configuration
|
||||
config.inicfg["asyncio_mode"] = "auto"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for each test case."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
yield loop
|
||||
loop.close()
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
from unittest import mock
|
||||
|
||||
from crewai.tools.base_tool import BaseTool, EnvVar
|
||||
from generate_tool_specs import ToolSpecExtractor
|
||||
from crewai_tools.generate_tool_specs import ToolSpecExtractor
|
||||
from pydantic import BaseModel, Field
|
||||
import pytest
|
||||
|
||||
@@ -61,8 +61,8 @@ def test_unwrap_schema(extractor):
|
||||
@pytest.fixture
|
||||
def mock_tool_extractor(extractor):
|
||||
with (
|
||||
mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
|
||||
mock.patch("generate_tool_specs.getattr", return_value=MockTool),
|
||||
mock.patch("crewai_tools.generate_tool_specs.dir", return_value=["MockTool"]),
|
||||
mock.patch("crewai_tools.generate_tool_specs.getattr", return_value=MockTool),
|
||||
):
|
||||
extractor.extract_all_tools()
|
||||
assert len(extractor.tools_spec) == 1
|
||||
|
||||
@@ -4,7 +4,7 @@ from crewai_tools.tools.firecrawl_crawl_website_tool.firecrawl_crawl_website_too
|
||||
FirecrawlCrawlWebsiteTool,
|
||||
)
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.vcr()
|
||||
def test_firecrawl_crawl_tool_integration():
|
||||
tool = FirecrawlCrawlWebsiteTool(config={
|
||||
"limit": 2,
|
||||
|
||||
@@ -4,7 +4,7 @@ from crewai_tools.tools.firecrawl_scrape_website_tool.firecrawl_scrape_website_t
|
||||
FirecrawlScrapeWebsiteTool,
|
||||
)
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.vcr()
|
||||
def test_firecrawl_scrape_tool_integration():
|
||||
tool = FirecrawlScrapeWebsiteTool()
|
||||
result = tool.run(url="https://firecrawl.dev")
|
||||
|
||||
@@ -3,7 +3,7 @@ import pytest
|
||||
from crewai_tools.tools.firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.vcr()
|
||||
def test_firecrawl_search_tool_integration():
|
||||
tool = FirecrawlSearchTool()
|
||||
result = tool.run(query="firecrawl")
|
||||
|
||||
lib/crewai-tools/tests/tools/rag/test_rag_tool_add_data_type.py (new file, 471 lines)
@@ -0,0 +1,471 @@
|
||||
"""Tests for RagTool.add() method with various data_type values."""
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_client() -> MagicMock:
|
||||
"""Create a mock RAG client for testing."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.add_documents = MagicMock(return_value=None)
|
||||
mock_client.search = MagicMock(return_value=[])
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rag_tool(mock_rag_client: MagicMock) -> RagTool:
|
||||
"""Create a RagTool instance with mocked client."""
|
||||
with (
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.create_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
):
|
||||
return RagTool()
|
||||
|
||||
|
||||
class TestDataTypeFileAlias:
|
||||
"""Tests for data_type='file' alias."""
|
||||
|
||||
def test_file_alias_with_existing_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that data_type='file' works with existing files."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Test content for file alias.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_alias_with_nonexistent_file_raises_error(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test that data_type='file' raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent/path/to/file.pdf", data_type="file")
|
||||
|
||||
def test_file_alias_with_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that path keyword argument works with data_type='file'."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "document.txt"
|
||||
test_file.write_text("Content via path keyword.")
|
||||
|
||||
rag_tool.add(data_type="file", path=str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_alias_with_file_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that file_path keyword argument works with data_type='file'."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "document.txt"
|
||||
test_file.write_text("Content via file_path keyword.")
|
||||
|
||||
rag_tool.add(data_type="file", file_path=str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestDataTypeStringValues:
|
||||
"""Tests for data_type as string values matching DataType enum."""
|
||||
|
||||
def test_pdf_file_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='pdf_file' with existing PDF file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create a minimal valid PDF file
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
b"<<\n/Root 1 0 R\n>>\n%%EOF"
|
||||
)
|
||||
|
||||
# Mock the PDF loader to avoid actual PDF parsing
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
|
||||
) as mock_loader:
|
||||
mock_loader_instance = MagicMock()
|
||||
mock_loader_instance.load.return_value = MagicMock(
|
||||
content="PDF content", metadata={}, doc_id="test-id"
|
||||
)
|
||||
mock_loader.return_value = mock_loader_instance
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="pdf_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_text_file_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='text_file' with existing text file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Plain text content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_csv_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='csv' with existing CSV file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.csv"
|
||||
test_file.write_text("name,value\nfoo,1\nbar,2")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="csv")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_json_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='json' with existing JSON file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.json"
|
||||
test_file.write_text('{"key": "value", "items": [1, 2, 3]}')
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="json")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_xml_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='xml' with existing XML file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.xml"
|
||||
test_file.write_text('<?xml version="1.0"?><root><item>value</item></root>')
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="xml")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_mdx_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='mdx' with existing MDX file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.mdx"
|
||||
test_file.write_text("# Heading\n\nSome markdown content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="mdx")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_text_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='text' with raw text content."""
|
||||
rag_tool.add("This is raw text content.", data_type="text")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='directory' with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create some files in the directory
|
||||
(Path(tmpdir) / "file1.txt").write_text("Content 1")
|
||||
(Path(tmpdir) / "file2.txt").write_text("Content 2")
|
||||
|
||||
rag_tool.add(path=tmpdir, data_type="directory")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestDataTypeEnumValues:
|
||||
"""Tests for data_type as DataType enum values."""
|
||||
|
||||
def test_datatype_file_enum_with_existing_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE with existing file (auto-detect)."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File enum auto-detect content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_file_enum_with_nonexistent_file_raises_error(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("nonexistent/file.pdf", data_type=DataType.FILE)
|
||||
|
||||
def test_datatype_pdf_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.PDF_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
b"<<\n/Root 1 0 R\n>>\n%%EOF"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
|
||||
) as mock_loader:
|
||||
mock_loader_instance = MagicMock()
|
||||
mock_loader_instance.load.return_value = MagicMock(
|
||||
content="PDF content", metadata={}, doc_id="test-id"
|
||||
)
|
||||
mock_loader.return_value = mock_loader_instance
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.PDF_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Text file content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.TEXT_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT with raw text."""
|
||||
rag_tool.add("Raw text using enum.", data_type=DataType.TEXT)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_directory_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.DIRECTORY with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory file content.")
|
||||
|
||||
rag_tool.add(tmpdir, data_type=DataType.DIRECTORY)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestInvalidDataType:
|
||||
"""Tests for invalid data_type values."""
|
||||
|
||||
def test_invalid_string_data_type_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that invalid string data_type raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid data_type"):
|
||||
rag_tool.add("some content", data_type="invalid_type")
|
||||
|
||||
def test_invalid_data_type_error_message_contains_valid_values(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test that error message lists valid data_type values."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
rag_tool.add("some content", data_type="not_a_type")
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "file" in error_message
|
||||
assert "pdf_file" in error_message
|
||||
assert "text_file" in error_message
|
||||
|
||||
|
||||
class TestFileExistenceValidation:
|
||||
"""Tests for file existence validation."""
|
||||
|
||||
def test_pdf_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent PDF file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.pdf", data_type="pdf_file")
|
||||
|
||||
def test_text_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent text file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.txt", data_type="text_file")
|
||||
|
||||
def test_csv_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent CSV file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.csv", data_type="csv")
|
||||
|
||||
def test_json_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent JSON file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.json", data_type="json")
|
||||
|
||||
def test_xml_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent XML file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.xml", data_type="xml")
|
||||
|
||||
def test_docx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent DOCX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.docx", data_type="docx")
|
||||
|
||||
def test_mdx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent MDX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.mdx", data_type="mdx")
|
||||
|
||||
def test_directory_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent directory raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Directory does not exist"):
|
||||
rag_tool.add(path="nonexistent/directory", data_type="directory")
|
||||
|
||||
|
||||
class TestKeywordArgumentVariants:
|
||||
"""Tests for different keyword argument combinations."""
|
||||
|
||||
def test_positional_argument_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test positional argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Positional arg content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Path keyword content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test file_path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File path keyword content.")
|
||||
|
||||
rag_tool.add(file_path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test directory_path keyword argument."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory content.")
|
||||
|
||||
rag_tool.add(directory_path=tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestAutoDetection:
|
||||
"""Tests for auto-detection of data type from content."""
|
||||
|
||||
def test_auto_detect_nonexistent_file_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that auto-detection raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("path/to/document.pdf")
|
||||
|
||||
def test_auto_detect_txt_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .txt file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.txt"
|
||||
test_file.write_text("Auto-detected text file.")
|
||||
|
||||
# No data_type specified - should auto-detect
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_csv_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .csv file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.csv"
|
||||
test_file.write_text("col1,col2\nval1,val2")
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_json_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .json file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.json"
|
||||
test_file.write_text('{"auto": "detected"}')
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_directory(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of directory type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Auto-detected directory.")
|
||||
|
||||
rag_tool.add(tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_raw_text(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of raw text (non-file content)."""
|
||||
rag_tool.add("Just some raw text content")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestMetadataHandling:
|
||||
"""Tests for metadata handling with data_type."""
|
||||
|
||||
def test_metadata_passed_to_documents(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that metadata is properly passed to documents."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Content with metadata.")
|
||||
|
||||
rag_tool.add(
|
||||
path=str(test_file),
|
||||
data_type="text_file",
|
||||
metadata={"custom_key": "custom_value"},
|
||||
)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
call_args = mock_rag_client.add_documents.call_args
|
||||
documents = call_args.kwargs.get("documents", call_args.args[0] if call_args.args else [])
|
||||
|
||||
# Check that at least one document has the custom metadata
|
||||
assert any(
|
||||
doc.get("metadata", {}).get("custom_key") == "custom_value"
|
||||
for doc in documents
|
||||
)
|
||||
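# Hedged addition (not part of the upstream diff): a compact sketch, in the same style
# as the tests above, of the add() call variants this file exercises. It assumes the
# same `rag_tool` fixture and mocked RAG client used throughout the tests.
def test_add_variants_sketch(rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
    with TemporaryDirectory() as tmpdir:
        doc = Path(tmpdir) / "sketch.txt"
        doc.write_text("Example content.")

        # Explicit data_type plus metadata, mirroring TestMetadataHandling above.
        rag_tool.add(path=str(doc), data_type="text_file", metadata={"source": "sketch"})

        # Auto-detection path: no data_type, type inferred from the .txt suffix.
        rag_tool.add(str(doc))

    assert mock_rag_client.add_documents.called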
@@ -23,15 +23,13 @@ from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
import pytest
|
||||
|
||||
|
||||
pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adapter():
|
||||
mock_adapter = MagicMock(spec=Adapter)
|
||||
return mock_adapter
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_directory_search_tool():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
test_file = Path(temp_dir) / "test.txt"
|
||||
@@ -65,6 +63,7 @@ def test_pdf_search_tool(mock_adapter):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_txt_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
|
||||
temp_file.write(b"This is a test file for txt search")
|
||||
@@ -102,6 +101,7 @@ def test_docx_search_tool(mock_adapter):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_json_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
|
||||
temp_file.write(b'{"test": "This is a test JSON file"}')
|
||||
@@ -127,6 +127,7 @@ def test_xml_search_tool(mock_adapter):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_csv_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
|
||||
temp_file.write(b"name,description\ntest,This is a test CSV file")
|
||||
@@ -141,6 +142,7 @@ def test_csv_search_tool():
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_mdx_search_tool():
|
||||
with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as temp_file:
|
||||
temp_file.write(b"# Test MDX\nThis is a test MDX file")
|
||||
|
||||
@@ -9,35 +9,35 @@ authors = [
|
||||
requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
# Core Dependencies
|
||||
"pydantic>=2.11.9",
|
||||
"openai>=1.13.3",
|
||||
"pydantic~=2.11.9",
|
||||
"openai~=1.83.0",
|
||||
"instructor>=1.3.3",
|
||||
# Text Processing
|
||||
"pdfplumber>=0.11.4",
|
||||
"regex>=2024.9.11",
|
||||
"pdfplumber~=0.11.4",
|
||||
"regex~=2024.9.11",
|
||||
# Telemetry and Monitoring
|
||||
"opentelemetry-api>=1.30.0",
|
||||
"opentelemetry-sdk>=1.30.0",
|
||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
|
||||
"opentelemetry-api~=1.34.0",
|
||||
"opentelemetry-sdk~=1.34.0",
|
||||
"opentelemetry-exporter-otlp-proto-http~=1.34.0",
|
||||
# Data Handling
|
||||
"chromadb~=1.1.0",
|
||||
"tokenizers>=0.20.3",
|
||||
"openpyxl>=3.1.5",
|
||||
"tokenizers~=0.20.3",
|
||||
"openpyxl~=3.1.5",
|
||||
# Authentication and Security
|
||||
"python-dotenv>=1.1.1",
|
||||
"pyjwt>=2.9.0",
|
||||
"python-dotenv~=1.1.1",
|
||||
"pyjwt~=2.9.0",
|
||||
# Configuration and Utils
|
||||
"click>=8.1.7",
|
||||
"appdirs>=1.4.4",
|
||||
"jsonref>=1.1.0",
|
||||
"json-repair==0.25.2",
|
||||
"uv>=0.4.25",
|
||||
"tomli-w>=1.1.0",
|
||||
"tomli>=2.0.2",
|
||||
"json5>=0.10.0",
|
||||
"portalocker==2.7.0",
|
||||
"pydantic-settings>=2.10.1",
|
||||
"mcp>=1.16.0",
|
||||
"click~=8.1.7",
|
||||
"appdirs~=1.4.4",
|
||||
"jsonref~=1.1.0",
|
||||
"json-repair~=0.25.2",
|
||||
"tomli-w~=1.1.0",
|
||||
"tomli~=2.0.2",
|
||||
"json5~=0.10.0",
|
||||
"portalocker~=2.7.0",
|
||||
"pydantic-settings~=2.10.1",
|
||||
"mcp~=1.16.0",
|
||||
"uv~=0.9.13",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -48,55 +48,53 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.6.0",
|
||||
"crewai-tools==1.6.1",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
]
|
||||
pdfplumber = [
|
||||
"pdfplumber>=0.11.4",
|
||||
]
|
||||
pandas = [
|
||||
"pandas>=2.2.3",
|
||||
"pandas~=2.2.3",
|
||||
]
|
||||
openpyxl = [
|
||||
"openpyxl>=3.1.5",
|
||||
"openpyxl~=3.1.5",
|
||||
]
|
||||
mem0 = ["mem0ai>=0.1.94"]
|
||||
mem0 = ["mem0ai~=0.1.94"]
|
||||
docling = [
|
||||
"docling>=2.12.0",
|
||||
"docling~=2.63.0",
|
||||
]
|
||||
qdrant = [
|
||||
"qdrant-client[fastembed]>=1.14.3",
|
||||
"qdrant-client[fastembed]~=1.14.3",
|
||||
]
|
||||
aws = [
|
||||
"boto3>=1.40.38",
|
||||
"boto3~=1.40.38",
|
||||
"aiobotocore~=2.25.2",
|
||||
]
|
||||
watson = [
|
||||
"ibm-watsonx-ai>=1.3.39",
|
||||
"ibm-watsonx-ai~=1.3.39",
|
||||
]
|
||||
voyageai = [
|
||||
"voyageai>=0.3.5",
|
||||
"voyageai~=0.3.5",
|
||||
]
|
||||
litellm = [
|
||||
"litellm>=1.74.9",
|
||||
"litellm~=1.74.9",
|
||||
]
|
||||
bedrock = [
|
||||
"boto3>=1.40.45",
|
||||
"boto3~=1.40.45",
|
||||
]
|
||||
google-genai = [
|
||||
"google-genai>=1.2.0",
|
||||
"google-genai~=1.2.0",
|
||||
]
|
||||
azure-ai-inference = [
|
||||
"azure-ai-inference>=1.0.0b9",
|
||||
"azure-ai-inference~=1.0.0b9",
|
||||
]
|
||||
anthropic = [
|
||||
"anthropic>=0.69.0",
|
||||
"anthropic~=0.71.0",
|
||||
]
|
||||
a2a = [
|
||||
"a2a-sdk~=0.3.10",
|
||||
"httpx-auth>=0.23.1",
|
||||
"httpx-sse>=0.4.0",
|
||||
"httpx-auth~=0.23.1",
|
||||
"httpx-sse~=0.4.0",
|
||||
]
|
||||
|
||||
|
||||
|
||||
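# Hedged aside (not part of the upstream diff): the dependency changes above swap ">="
# pins for "~=" (compatible release) pins, which allow patch upgrades but exclude the
# next minor release. A quick check with the `packaging` library (assumed available):
from packaging.specifiers import SpecifierSet

compatible = SpecifierSet("~=2.11.9")
print(compatible.contains("2.11.12"))  # True  -> later patch releases still satisfy the pin
print(compatible.contains("2.12.0"))   # False -> next minor release is excluded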
@@ -3,6 +3,19 @@ from typing import Any
|
||||
import urllib.request
|
||||
import warnings
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
|
||||
def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
"""Suppress Pydantic deprecation warnings using targeted monkey patch."""
|
||||
@@ -27,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.6.0"
|
||||
__version__ = "1.6.1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
@@ -35,8 +48,6 @@ def _track_install() -> None:
|
||||
"""Track package installation/first-use via Scarf analytics."""
|
||||
global _telemetry_submitted
|
||||
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
if _telemetry_submitted or Telemetry._is_telemetry_disabled():
|
||||
return
|
||||
|
||||
@@ -54,15 +65,12 @@ def _track_install() -> None:
|
||||
|
||||
def _track_install_async() -> None:
|
||||
"""Track installation in background thread to avoid blocking imports."""
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
if not Telemetry._is_telemetry_disabled():
|
||||
thread = threading.Thread(target=_track_install, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
_track_install_async()
|
||||
|
||||
__all__ = [
|
||||
"LLM",
|
||||
"Agent",
|
||||
@@ -77,51 +85,3 @@ __all__ = [
|
||||
"TaskOutput",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "Agent":
|
||||
from crewai.agent.core import Agent
|
||||
|
||||
return Agent
|
||||
if name == "Crew":
|
||||
from crewai.crew import Crew
|
||||
|
||||
return Crew
|
||||
if name == "CrewOutput":
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
|
||||
return CrewOutput
|
||||
if name == "Flow":
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
return Flow
|
||||
if name == "Knowledge":
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
|
||||
return Knowledge
|
||||
if name == "LLM":
|
||||
from crewai.llm import LLM
|
||||
|
||||
return LLM
|
||||
if name == "BaseLLM":
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
return BaseLLM
|
||||
if name == "Process":
|
||||
from crewai.process import Process
|
||||
|
||||
return Process
|
||||
if name == "Task":
|
||||
from crewai.task import Task
|
||||
|
||||
return Task
|
||||
if name == "LLMGuardrail":
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
|
||||
return LLMGuardrail
|
||||
if name == "TaskOutput":
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
return TaskOutput
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -4,6 +4,32 @@ import subprocess
|
||||
|
||||
import click
|
||||
|
||||
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
||||
from crewai.cli.authentication.main import AuthenticationCommand
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.create_crew import create_crew
|
||||
from crewai.cli.create_flow import create_flow
|
||||
from crewai.cli.crew_chat import run_chat
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
|
||||
from crewai.cli.evaluate_crew import evaluate_crew
|
||||
from crewai.cli.install_crew import install_crew
|
||||
from crewai.cli.kickoff_flow import kickoff_flow
|
||||
from crewai.cli.organization.main import OrganizationCommand
|
||||
from crewai.cli.plot_flow import plot_flow
|
||||
from crewai.cli.replay_from_task import replay_task_command
|
||||
from crewai.cli.reset_memories_command import reset_memories_command
|
||||
from crewai.cli.run_crew import run_crew
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
from crewai.cli.train_crew import train_crew
|
||||
from crewai.cli.triggers.main import TriggersCommand
|
||||
from crewai.cli.update_crew import update_crew
|
||||
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
|
||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
KickoffTaskOutputsSQLiteStorage,
|
||||
)
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(get_version("crewai"))
|
||||
@@ -20,8 +46,6 @@ def crewai():
|
||||
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
|
||||
def uv(uv_args):
|
||||
"""A wrapper around uv commands that adds custom tool authentication through env vars."""
|
||||
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
|
||||
|
||||
env = os.environ.copy()
|
||||
try:
|
||||
pyproject_data = read_toml()
|
||||
@@ -61,12 +85,8 @@ def uv(uv_args):
|
||||
def create(type, name, provider, skip_provider=False):
|
||||
"""Create a new crew, or flow."""
|
||||
if type == "crew":
|
||||
from crewai.cli.create_crew import create_crew
|
||||
|
||||
create_crew(name, provider, skip_provider)
|
||||
elif type == "flow":
|
||||
from crewai.cli.create_flow import create_flow
|
||||
|
||||
create_flow(name)
|
||||
else:
|
||||
click.secho("Error: Invalid type. Must be 'crew' or 'flow'.", fg="red")
|
||||
@@ -109,8 +129,6 @@ def version(tools):
|
||||
)
|
||||
def train(n_iterations: int, filename: str):
|
||||
"""Train the crew."""
|
||||
from crewai.cli.train_crew import train_crew
|
||||
|
||||
click.echo(f"Training the Crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations, filename)
|
||||
|
||||
@@ -130,8 +148,6 @@ def replay(task_id: str) -> None:
|
||||
task_id (str): The ID of the task to replay from.
|
||||
"""
|
||||
try:
|
||||
from crewai.cli.replay_from_task import replay_task_command
|
||||
|
||||
click.echo(f"Replaying the crew from task {task_id}")
|
||||
replay_task_command(task_id)
|
||||
except Exception as e:
|
||||
@@ -144,10 +160,6 @@ def log_tasks_outputs() -> None:
|
||||
Retrieve your latest crew.kickoff() task outputs.
|
||||
"""
|
||||
try:
|
||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
KickoffTaskOutputsSQLiteStorage,
|
||||
)
|
||||
|
||||
storage = KickoffTaskOutputsSQLiteStorage()
|
||||
tasks = storage.load()
|
||||
|
||||
@@ -205,8 +217,6 @@ def reset_memories(
|
||||
"Please specify at least one memory type to reset using the appropriate flags."
|
||||
)
|
||||
return
|
||||
from crewai.cli.reset_memories_command import reset_memories_command
|
||||
|
||||
reset_memories_command(
|
||||
long, short, entities, knowledge, agent_knowledge, kickoff_outputs, all
|
||||
)
|
||||
@@ -231,8 +241,6 @@ def reset_memories(
|
||||
)
|
||||
def test(n_iterations: int, model: str):
|
||||
"""Test the crew and evaluate the results."""
|
||||
from crewai.cli.evaluate_crew import evaluate_crew
|
||||
|
||||
click.echo(f"Testing the crew for {n_iterations} iterations with model {model}")
|
||||
evaluate_crew(n_iterations, model)
|
||||
|
||||
@@ -246,33 +254,24 @@ def test(n_iterations: int, model: str):
|
||||
@click.pass_context
|
||||
def install(context):
|
||||
"""Install the Crew."""
|
||||
from crewai.cli.install_crew import install_crew
|
||||
|
||||
install_crew(context.args)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def run():
|
||||
"""Run the Crew."""
|
||||
from crewai.cli.run_crew import run_crew
|
||||
|
||||
run_crew()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def update():
|
||||
"""Update the pyproject.toml of the Crew project to use uv."""
|
||||
from crewai.cli.update_crew import update_crew
|
||||
|
||||
update_crew()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def login():
|
||||
"""Sign Up/Login to CrewAI AOP."""
|
||||
from crewai.cli.authentication.main import AuthenticationCommand
|
||||
from crewai.cli.config import Settings
|
||||
|
||||
Settings().clear_user_settings()
|
||||
AuthenticationCommand().login()
|
||||
|
||||
@@ -287,8 +286,6 @@ def deploy():
|
||||
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
|
||||
def deploy_create(yes: bool):
|
||||
"""Create a Crew deployment."""
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.create_crew(yes)
|
||||
|
||||
@@ -296,8 +293,6 @@ def deploy_create(yes: bool):
|
||||
@deploy.command(name="list")
|
||||
def deploy_list():
|
||||
"""List all deployments."""
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.list_crews()
|
||||
|
||||
@@ -306,8 +301,6 @@ def deploy_list():
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_push(uuid: str | None):
|
||||
"""Deploy the Crew."""
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.deploy(uuid=uuid)
|
||||
|
||||
@@ -316,8 +309,6 @@ def deploy_push(uuid: str | None):
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deply_status(uuid: str | None):
|
||||
"""Get the status of a deployment."""
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_status(uuid=uuid)
|
||||
|
||||
@@ -326,8 +317,6 @@ def deply_status(uuid: str | None):
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_logs(uuid: str | None):
|
||||
"""Get the logs of a deployment."""
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_logs(uuid=uuid)
|
||||
|
||||
@@ -336,8 +325,6 @@ def deploy_logs(uuid: str | None):
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_remove(uuid: str | None):
|
||||
"""Remove a deployment."""
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
|
||||
@@ -350,8 +337,6 @@ def tool():
|
||||
@tool.command(name="create")
|
||||
@click.argument("handle")
|
||||
def tool_create(handle: str):
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.create(handle)
|
||||
|
||||
@@ -359,8 +344,6 @@ def tool_create(handle: str):
|
||||
@tool.command(name="install")
|
||||
@click.argument("handle")
|
||||
def tool_install(handle: str):
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.login()
|
||||
tool_cmd.install(handle)
|
||||
@@ -377,8 +360,6 @@ def tool_install(handle: str):
|
||||
@click.option("--public", "is_public", flag_value=True, default=False)
|
||||
@click.option("--private", "is_public", flag_value=False)
|
||||
def tool_publish(is_public: bool, force: bool):
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.login()
|
||||
tool_cmd.publish(is_public, force)
|
||||
@@ -392,8 +373,6 @@ def flow():
|
||||
@flow.command(name="kickoff")
|
||||
def flow_run():
|
||||
"""Kickoff the Flow."""
|
||||
from crewai.cli.kickoff_flow import kickoff_flow
|
||||
|
||||
click.echo("Running the Flow")
|
||||
kickoff_flow()
|
||||
|
||||
@@ -401,8 +380,6 @@ def flow_run():
|
||||
@flow.command(name="plot")
|
||||
def flow_plot():
|
||||
"""Plot the Flow."""
|
||||
from crewai.cli.plot_flow import plot_flow
|
||||
|
||||
click.echo("Plotting the Flow")
|
||||
plot_flow()
|
||||
|
||||
@@ -411,8 +388,6 @@ def flow_plot():
|
||||
@click.argument("crew_name")
|
||||
def flow_add_crew(crew_name):
|
||||
"""Add a crew to an existing flow."""
|
||||
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
||||
|
||||
click.echo(f"Adding crew {crew_name} to the flow")
|
||||
add_crew_to_flow(crew_name)
|
||||
|
||||
@@ -425,8 +400,6 @@ def triggers():
|
||||
@triggers.command(name="list")
|
||||
def triggers_list():
|
||||
"""List all available triggers from integrations."""
|
||||
from crewai.cli.triggers.main import TriggersCommand
|
||||
|
||||
triggers_cmd = TriggersCommand()
|
||||
triggers_cmd.list_triggers()
|
||||
|
||||
@@ -435,8 +408,6 @@ def triggers_list():
|
||||
@click.argument("trigger_path")
|
||||
def triggers_run(trigger_path: str):
|
||||
"""Execute crew with trigger payload. Format: app_slug/trigger_slug"""
|
||||
from crewai.cli.triggers.main import TriggersCommand
|
||||
|
||||
triggers_cmd = TriggersCommand()
|
||||
triggers_cmd.execute_with_trigger(trigger_path)
|
||||
|
||||
@@ -451,8 +422,6 @@ def chat():
|
||||
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
|
||||
)
|
||||
|
||||
from crewai.cli.crew_chat import run_chat
|
||||
|
||||
run_chat()
|
||||
|
||||
|
||||
@@ -464,8 +433,6 @@ def org():
|
||||
@org.command("list")
|
||||
def org_list():
|
||||
"""List available organizations."""
|
||||
from crewai.cli.organization.main import OrganizationCommand
|
||||
|
||||
org_command = OrganizationCommand()
|
||||
org_command.list()
|
||||
|
||||
@@ -474,8 +441,6 @@ def org_list():
|
||||
@click.argument("id")
|
||||
def switch(id):
|
||||
"""Switch to a specific organization."""
|
||||
from crewai.cli.organization.main import OrganizationCommand
|
||||
|
||||
org_command = OrganizationCommand()
|
||||
org_command.switch(id)
|
||||
|
||||
@@ -483,8 +448,6 @@ def switch(id):
|
||||
@org.command()
|
||||
def current():
|
||||
"""Show current organization when 'crewai org' is called without subcommands."""
|
||||
from crewai.cli.organization.main import OrganizationCommand
|
||||
|
||||
org_command = OrganizationCommand()
|
||||
org_command.current()
|
||||
|
||||
@@ -498,8 +461,6 @@ def enterprise():
|
||||
@click.argument("enterprise_url")
|
||||
def enterprise_configure(enterprise_url: str):
|
||||
"""Configure CrewAI AOP OAuth2 settings from the provided Enterprise URL."""
|
||||
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
|
||||
|
||||
enterprise_command = EnterpriseConfigureCommand()
|
||||
enterprise_command.configure(enterprise_url)
|
||||
|
||||
@@ -512,8 +473,6 @@ def config():
|
||||
@config.command("list")
|
||||
def config_list():
|
||||
"""List all CLI configuration parameters."""
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
|
||||
config_command = SettingsCommand()
|
||||
config_command.list()
|
||||
|
||||
@@ -523,8 +482,6 @@ def config_list():
|
||||
@click.argument("value")
|
||||
def config_set(key: str, value: str):
|
||||
"""Set a CLI configuration parameter."""
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
|
||||
config_command = SettingsCommand()
|
||||
config_command.set(key, value)
|
||||
|
||||
@@ -532,8 +489,6 @@ def config_set(key: str, value: str):
|
||||
@config.command("reset")
|
||||
def config_reset():
|
||||
"""Reset all CLI configuration parameters to default values."""
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
|
||||
config_command = SettingsCommand()
|
||||
config_command.reset_all_settings()
|
||||
|
||||
|
||||
@@ -73,6 +73,7 @@ CLI_SETTINGS_KEYS = [
|
||||
"oauth2_audience",
|
||||
"oauth2_client_id",
|
||||
"oauth2_domain",
|
||||
"oauth2_extra",
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
@@ -82,6 +83,7 @@ DEFAULT_CLI_SETTINGS = {
|
||||
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
"oauth2_extra": {},
|
||||
}
|
||||
|
||||
# Readonly settings - cannot be set by the user
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.0"
|
||||
"crewai[tools]==1.6.1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.0"
|
||||
"crewai[tools]==1.6.1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -76,7 +76,7 @@ class TraceBatchManager:
|
||||
use_ephemeral: bool = False,
|
||||
) -> TraceBatch:
|
||||
"""Initialize a new trace batch (thread-safe)"""
|
||||
with self._init_lock:
|
||||
with self._batch_ready_cv:
|
||||
if self.current_batch is not None:
|
||||
logger.debug(
|
||||
"Batch already initialized, skipping duplicate initialization"
|
||||
@@ -99,7 +99,6 @@ class TraceBatchManager:
|
||||
self.backend_initialized = True
|
||||
|
||||
self._batch_ready_cv.notify_all()
|
||||
|
||||
return self.current_batch
|
||||
|
||||
def _initialize_backend_batch(
|
||||
@@ -107,7 +106,7 @@ class TraceBatchManager:
|
||||
user_context: dict[str, str],
|
||||
execution_metadata: dict[str, Any],
|
||||
use_ephemeral: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Send batch initialization to backend"""
|
||||
|
||||
if not is_tracing_enabled_in_context():
|
||||
@@ -204,7 +203,7 @@ class TraceBatchManager:
|
||||
return False
|
||||
return True
|
||||
|
||||
def add_event(self, trace_event: TraceEvent):
|
||||
def add_event(self, trace_event: TraceEvent) -> None:
|
||||
"""Add event to buffer"""
|
||||
self.event_buffer.append(trace_event)
|
||||
|
||||
@@ -300,7 +299,7 @@ class TraceBatchManager:
|
||||
|
||||
return finalized_batch
|
||||
|
||||
def _finalize_backend_batch(self, events_count: int = 0):
|
||||
def _finalize_backend_batch(self, events_count: int = 0) -> None:
|
||||
"""Send batch finalization to backend
|
||||
|
||||
Args:
|
||||
@@ -366,7 +365,7 @@ class TraceBatchManager:
|
||||
logger.error(f"❌ Error finalizing trace batch: {e}")
|
||||
self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))
|
||||
|
||||
def _cleanup_batch_data(self):
|
||||
def _cleanup_batch_data(self) -> None:
|
||||
"""Clean up batch data after successful finalization to free memory"""
|
||||
try:
|
||||
if hasattr(self, "event_buffer") and self.event_buffer:
|
||||
@@ -411,7 +410,7 @@ class TraceBatchManager:
|
||||
lambda: self.current_batch is not None, timeout=timeout
|
||||
)
|
||||
|
||||
def record_start_time(self, key: str):
|
||||
def record_start_time(self, key: str) -> None:
|
||||
"""Record start time for duration calculation"""
|
||||
self.execution_start_times[key] = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from crewai.events.types.system_events import SignalEvent, on_signal
|
||||
from crewai.events.types.task_events import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
@@ -159,6 +160,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self._register_flow_event_handlers(crewai_event_bus)
|
||||
self._register_context_event_handlers(crewai_event_bus)
|
||||
self._register_action_event_handlers(crewai_event_bus)
|
||||
self._register_system_event_handlers(crewai_event_bus)
|
||||
|
||||
self._listeners_setup = True
|
||||
|
||||
@@ -458,6 +460,15 @@ class TraceCollectionListener(BaseEventListener):
|
||||
) -> None:
|
||||
self._handle_action_event("knowledge_query_failed", source, event)
|
||||
|
||||
def _register_system_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for system signal events (SIGTERM, SIGINT, etc.)."""
|
||||
|
||||
@on_signal
|
||||
def handle_signal(source: Any, event: SignalEvent) -> None:
|
||||
"""Flush trace batch on system signals to prevent data loss."""
|
||||
if self.batch_manager.is_batch_initialized():
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
def _initialize_crew_batch(self, source: Any, event: Any) -> None:
|
||||
"""Initialize trace batch.
|
||||
|
||||
|
||||
lib/crewai/src/crewai/events/types/system_events.py (new file, 102 lines)
@@ -0,0 +1,102 @@
|
||||
"""System signal event types for CrewAI.
|
||||
|
||||
This module contains event types for system-level signals like SIGTERM,
|
||||
allowing listeners to perform cleanup operations before process termination.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from enum import IntEnum
|
||||
import signal
|
||||
from typing import Annotated, Literal, TypeVar
|
||||
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class SignalType(IntEnum):
|
||||
"""Enumeration of supported system signals."""
|
||||
|
||||
SIGTERM = signal.SIGTERM
|
||||
SIGINT = signal.SIGINT
|
||||
SIGHUP = signal.SIGHUP
|
||||
SIGTSTP = signal.SIGTSTP
|
||||
SIGCONT = signal.SIGCONT
|
||||
|
||||
|
||||
class SigTermEvent(BaseEvent):
|
||||
"""Event emitted when SIGTERM is received."""
|
||||
|
||||
type: Literal["SIGTERM"] = "SIGTERM"
|
||||
signal_number: SignalType = SignalType.SIGTERM
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class SigIntEvent(BaseEvent):
|
||||
"""Event emitted when SIGINT is received."""
|
||||
|
||||
type: Literal["SIGINT"] = "SIGINT"
|
||||
signal_number: SignalType = SignalType.SIGINT
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class SigHupEvent(BaseEvent):
|
||||
"""Event emitted when SIGHUP is received."""
|
||||
|
||||
type: Literal["SIGHUP"] = "SIGHUP"
|
||||
signal_number: SignalType = SignalType.SIGHUP
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class SigTStpEvent(BaseEvent):
|
||||
"""Event emitted when SIGTSTP is received.
|
||||
|
||||
Note: SIGSTOP cannot be caught - it immediately suspends the process.
|
||||
"""
|
||||
|
||||
type: Literal["SIGTSTP"] = "SIGTSTP"
|
||||
signal_number: SignalType = SignalType.SIGTSTP
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class SigContEvent(BaseEvent):
|
||||
"""Event emitted when SIGCONT is received."""
|
||||
|
||||
type: Literal["SIGCONT"] = "SIGCONT"
|
||||
signal_number: SignalType = SignalType.SIGCONT
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
SignalEvent = Annotated[
|
||||
SigTermEvent | SigIntEvent | SigHupEvent | SigTStpEvent | SigContEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
signal_event_adapter: TypeAdapter[SignalEvent] = TypeAdapter(SignalEvent)
|
||||
|
||||
SIGNAL_EVENT_TYPES: tuple[type[BaseEvent], ...] = (
|
||||
SigTermEvent,
|
||||
SigIntEvent,
|
||||
SigHupEvent,
|
||||
SigTStpEvent,
|
||||
SigContEvent,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable[[object, SignalEvent], None])
|
||||
|
||||
|
||||
def on_signal(func: T) -> T:
|
||||
"""Decorator to register a handler for all signal events.
|
||||
|
||||
Args:
|
||||
func: Handler function that receives (source, event) arguments.
|
||||
|
||||
Returns:
|
||||
The original function, registered for all signal event types.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
for event_type in SIGNAL_EVENT_TYPES:
|
||||
crewai_event_bus.on(event_type)(func)
|
||||
return func
|
||||
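# Hedged usage sketch (not part of the upstream diff): registering a cleanup handler
# with the on_signal decorator defined above. The decorator simply subscribes the
# function to every signal event type on the shared event bus; the SigTermEvent /
# SigIntEvent instances themselves are assumed to be emitted elsewhere (for example
# by an OS-level signal handler that wraps signal.signal()).
from crewai.events.types.system_events import SignalEvent, on_signal


@on_signal
def flush_on_signal(source: object, event: SignalEvent) -> None:
    # Invoked for every registered signal event type (SIGTERM, SIGINT, SIGHUP, ...).
    print(f"received {event.type}, flushing pending work")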
@@ -57,7 +57,12 @@ if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.get_supported_openai_params import (
|
||||
get_supported_openai_params,
|
||||
)
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionDeltaToolCall,
|
||||
Choices,
|
||||
Function,
|
||||
ModelResponse,
|
||||
)
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
@@ -73,7 +78,12 @@ try:
|
||||
from litellm.litellm_core_utils.get_supported_openai_params import (
|
||||
get_supported_openai_params,
|
||||
)
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionDeltaToolCall,
|
||||
Choices,
|
||||
Function,
|
||||
ModelResponse,
|
||||
)
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
LITELLM_AVAILABLE = True
|
||||
@@ -84,6 +94,7 @@ except ImportError:
|
||||
ContextWindowExceededError = Exception # type: ignore
|
||||
get_supported_openai_params = None # type: ignore
|
||||
ChatCompletionDeltaToolCall = None # type: ignore
|
||||
Function = None # type: ignore
|
||||
ModelResponse = None # type: ignore
|
||||
supports_response_schema = None # type: ignore
|
||||
CustomLogger = None # type: ignore
|
||||
@@ -406,46 +417,100 @@ class LLM(BaseLLM):
|
||||
instance.is_litellm = True
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
|
||||
"""Check if a model name matches provider-specific patterns.
|
||||
|
||||
This allows supporting models that aren't in the hardcoded constants list,
|
||||
including "latest" versions and new models that follow provider naming conventions.
|
||||
|
||||
Args:
|
||||
model: The model name to check
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model matches the provider's naming pattern, False otherwise
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
if provider == "openai":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
|
||||
)
|
||||
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return any(
|
||||
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
|
||||
)
|
||||
|
||||
if provider == "gemini" or provider == "google":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gemini-", "gemma-", "learnlm-"]
|
||||
)
|
||||
|
||||
if provider == "bedrock":
|
||||
return "." in model_lower
|
||||
|
||||
if provider == "azure":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
|
||||
"""Validate if a model name exists in the provider's constants.
|
||||
"""Validate if a model name exists in the provider's constants or matches provider patterns.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it falls back to pattern matching to support new models,
|
||||
"latest" versions, and models that follow provider naming conventions.
|
||||
|
||||
Args:
|
||||
model: The model name to validate
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model exists in the provider's constants, False otherwise
|
||||
True if the model exists in constants or matches provider patterns, False otherwise
|
||||
"""
|
||||
if provider == "openai":
|
||||
return model in OPENAI_MODELS
|
||||
if provider == "openai" and model in OPENAI_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return model in ANTHROPIC_MODELS
|
||||
if (
|
||||
provider == "anthropic" or provider == "claude"
|
||||
) and model in ANTHROPIC_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "gemini":
|
||||
return model in GEMINI_MODELS
|
||||
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "bedrock":
|
||||
return model in BEDROCK_MODELS
|
||||
if provider == "bedrock" and model in BEDROCK_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "azure":
|
||||
# azure does not provide a list of available models, determine a better way to handle this
|
||||
return True
|
||||
|
||||
return False
|
||||
# Fallback to pattern matching for models not in constants
|
||||
return cls._matches_provider_pattern(model, provider)
|
||||
|
||||
@classmethod
|
||||
def _infer_provider_from_model(cls, model: str) -> str:
|
||||
"""Infer the provider from the model name.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it uses pattern matching to infer the provider from model name patterns.
|
||||
This allows supporting new models and "latest" versions without hardcoding.
|
||||
|
||||
Args:
|
||||
model: The model name without provider prefix
|
||||
|
||||
Returns:
|
||||
The inferred provider name, defaults to "openai"
|
||||
"""
|
||||
|
||||
if model in OPENAI_MODELS:
|
||||
return "openai"
|
||||
|
||||
@@ -556,7 +621,9 @@ class LLM(BaseLLM):
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.additional_params = kwargs
|
||||
self.additional_params = {
|
||||
k: v for k, v in kwargs.items() if k not in ("is_litellm", "provider")
|
||||
}
|
||||
self.is_anthropic = self._is_anthropic_model(model)
|
||||
self.stream = stream
|
||||
self.interceptor = interceptor
|
||||
@@ -1150,6 +1217,281 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return text_response
|
||||
|
||||
async def _ahandle_non_streaming_response(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle an async non-streaming response from the LLM.
|
||||
|
||||
Args:
|
||||
params: Parameters for the completion call
|
||||
callbacks: Optional list of callback functions
|
||||
available_functions: Dict of available functions
|
||||
from_task: Optional Task that invoked the LLM
|
||||
from_agent: Optional Agent that invoked the LLM
|
||||
response_model: Optional Response model
|
||||
|
||||
Returns:
|
||||
str: The response text
|
||||
"""
|
||||
if response_model and self.is_litellm:
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
messages = params.get("messages", [])
|
||||
if not messages:
|
||||
raise ValueError("Messages are required when using response_model")
|
||||
|
||||
combined_content = "\n\n".join(
|
||||
f"{msg['role'].upper()}: {msg['content']}" for msg in messages
|
||||
)
|
||||
|
||||
instructor_instance = InternalInstructor(
|
||||
content=combined_content,
|
||||
model=response_model,
|
||||
llm=self,
|
||||
)
|
||||
result = instructor_instance.to_pydantic()
|
||||
structured_response = result.model_dump_json()
|
||||
self._handle_emit_call_events(
|
||||
response=structured_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_response
|
||||
|
||||
try:
|
||||
if response_model:
|
||||
params["response_model"] = response_model
|
||||
response = await litellm.acompletion(**params)
|
||||
|
||||
except ContextWindowExceededError as e:
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
if response_model is not None:
|
||||
if isinstance(response, BaseModel):
|
||||
structured_response = response.model_dump_json()
|
||||
self._handle_emit_call_events(
|
||||
response=structured_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_response
|
||||
|
||||
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
||||
0
|
||||
].message
|
||||
text_response = response_message.content or ""
|
||||
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
usage_info = getattr(response, "usage", None)
|
||||
if usage_info:
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
response_obj={"usage": usage_info},
|
||||
start_time=0,
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
|
||||
if (not tool_calls or not available_functions) and text_response:
|
||||
self._handle_emit_call_events(
|
||||
response=text_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return text_response
|
||||
|
||||
if tool_calls and not available_functions and not text_response:
|
||||
return tool_calls
|
||||
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
|
||||
self._handle_emit_call_events(
|
||||
response=text_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return text_response
|
||||
|
||||
async def _ahandle_streaming_response(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> Any:
|
||||
"""Handle an async streaming response from the LLM.
|
||||
|
||||
Args:
|
||||
params: Parameters for the completion call
|
||||
callbacks: Optional list of callback functions
|
||||
available_functions: Dict of available functions
|
||||
from_task: Optional task object
|
||||
from_agent: Optional agent object
|
||||
response_model: Optional response model
|
||||
|
||||
Returns:
|
||||
str: The complete response text
|
||||
"""
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
usage_info = None
|
||||
|
||||
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
|
||||
AccumulatedToolArgs
|
||||
)
|
||||
|
||||
params["stream"] = True
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
try:
|
||||
async for chunk in await litellm.acompletion(**params):
|
||||
chunk_count += 1
|
||||
chunk_content = None
|
||||
|
||||
try:
|
||||
choices = None
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choices = chunk["choices"]
|
||||
elif hasattr(chunk, "choices"):
|
||||
if not isinstance(chunk.choices, type):
|
||||
choices = chunk.choices
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
usage_info = chunk.usage
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
first_choice = choices[0]
|
||||
delta = None
|
||||
|
||||
if isinstance(first_choice, dict):
|
||||
delta = first_choice.get("delta", {})
|
||||
elif hasattr(first_choice, "delta"):
|
||||
delta = first_choice.delta
|
||||
|
||||
if delta:
|
||||
if isinstance(delta, dict):
|
||||
chunk_content = delta.get("content")
|
||||
elif hasattr(delta, "content"):
|
||||
chunk_content = delta.content
|
||||
|
||||
tool_calls: list[ChatCompletionDeltaToolCall] | None = None
|
||||
if isinstance(delta, dict):
|
||||
tool_calls = delta.get("tool_calls")
|
||||
elif hasattr(delta, "tool_calls"):
|
||||
tool_calls = delta.tool_calls
|
||||
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
idx = tool_call.index
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
accumulated_tool_args[
|
||||
idx
|
||||
].function.name = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
accumulated_tool_args[
|
||||
idx
|
||||
].function.arguments += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
except (AttributeError, KeyError, IndexError, TypeError):
|
||||
pass
|
||||
|
||||
if chunk_content:
|
||||
full_response += chunk_content
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
chunk=chunk_content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if callbacks and len(callbacks) > 0 and usage_info:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "log_success_event"):
|
||||
callback.log_success_event(
|
||||
kwargs=params,
|
||||
response_obj={"usage": usage_info},
|
||||
start_time=0,
|
||||
end_time=0,
|
||||
)
|
||||
|
||||
if accumulated_tool_args and available_functions:
|
||||
# Convert accumulated tool args to ChatCompletionDeltaToolCall objects
|
||||
tool_calls_list: list[ChatCompletionDeltaToolCall] = [
|
||||
ChatCompletionDeltaToolCall(
|
||||
index=idx,
|
||||
function=Function(
|
||||
name=tool_arg.function.name,
|
||||
arguments=tool_arg.function.arguments,
|
||||
),
|
||||
)
|
||||
for idx, tool_arg in accumulated_tool_args.items()
|
||||
if tool_arg.function.name
|
||||
]
|
||||
|
||||
if tool_calls_list:
|
||||
result = self._handle_streaming_tool_calls(
|
||||
tool_calls=tool_calls_list,
|
||||
accumulated_tool_args=accumulated_tool_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
self._handle_emit_call_events(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params.get("messages"),
|
||||
)
|
||||
return full_response
|
||||
|
||||
except ContextWindowExceededError as e:
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
except Exception:
|
||||
if chunk_count == 0:
|
||||
raise
|
||||
if full_response:
|
||||
self._handle_emit_call_events(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params.get("messages"),
|
||||
)
|
||||
return full_response
|
||||
raise
|
||||
|
||||
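# Hedged aside (not part of the upstream diff): the streaming handler above only
# accumulates tool-call argument fragments per tool-call index; reassembly is plain
# string concatenation followed by a JSON parse, roughly like this:
import json

fragments = ['{"que', 'ry": "crew', 'ai docs"}']  # delta.tool_calls[i].function.arguments chunks
print(json.loads("".join(fragments)))  # {'query': 'crewai docs'}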
def _handle_tool_call(
|
||||
self,
|
||||
tool_calls: list[Any],
|
||||
@@ -1367,6 +1709,128 @@ class LLM(BaseLLM):
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Async high-level LLM call method.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
Can be a string or list of message dictionaries.
|
||||
If string, it will be converted to a single user message.
|
||||
If list, each dict must have 'role' and 'content' keys.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
Each tool should define its name, description, and parameters.
|
||||
callbacks: Optional list of callback functions to be executed
|
||||
during and after the LLM call.
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
from_task: Optional Task that invoked the LLM
|
||||
from_agent: Optional Agent that invoked the LLM
|
||||
response_model: Optional Model that contains a pydantic response model.
|
||||
|
||||
Returns:
|
||||
Union[str, Any]: Either a text response from the LLM (str) or
|
||||
the result of a tool function call (Any).
|
||||
|
||||
Raises:
|
||||
TypeError: If messages format is invalid
|
||||
ValueError: If response format is not supported
|
||||
LLMContextLengthExceededError: If input exceeds model's context limit
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
),
|
||||
)
|
||||
|
||||
self._validate_call_params()
|
||||
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
if "o1" in self.model.lower():
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
msg_role: Literal["assistant"] = "assistant"
|
||||
message["role"] = msg_role
|
||||
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
try:
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_non_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
except LLMContextLengthExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
|
||||
if unsupported_stop:
|
||||
if (
|
||||
"additional_drop_params" in self.additional_params
|
||||
and isinstance(
|
||||
self.additional_params["additional_drop_params"], list
|
||||
)
|
||||
):
|
||||
self.additional_params["additional_drop_params"].append("stop")
|
||||
else:
|
||||
self.additional_params = {"additional_drop_params": ["stop"]}
|
||||
|
||||
logging.info("Retrying LLM call without the unsupported 'stop'")
|
||||
|
||||
return await self.acall(
|
||||
messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e), from_task=from_task, from_agent=from_agent
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
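# Hedged usage sketch (not part of the upstream diff): calling the async entry point
# added above. The model name and prompt are placeholders, and a valid provider API
# key is assumed to be configured in the environment.
import asyncio

from crewai.llm import LLM


async def main() -> None:
    llm = LLM(model="gpt-4o-mini", stream=False)
    answer = await llm.acall("Summarize what a CrewAI Flow is in one sentence.")
    print(answer)


asyncio.run(main())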
def _handle_emit_call_events(
|
||||
self,
|
||||
response: Any,
|
||||
@@ -1699,12 +2163,14 @@ class LLM(BaseLLM):
|
||||
max_tokens=self.max_tokens,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
logit_bias=copy.deepcopy(self.logit_bias, memo)
|
||||
if self.logit_bias
|
||||
else None,
|
||||
response_format=copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None,
|
||||
logit_bias=(
|
||||
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
|
||||
),
|
||||
response_format=(
|
||||
copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None
|
||||
),
|
||||
seed=self.seed,
|
||||
logprobs=self.logprobs,
|
||||
top_logprobs=self.top_logprobs,
|
||||
|
||||
@@ -158,6 +158,44 @@ class BaseLLM(ABC):
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
"""
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call the LLM with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
Can be a string or list of message dictionaries.
|
||||
If string, it will be converted to a single user message.
|
||||
If list, each dict must have 'role' and 'content' keys.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
Each tool should define its name, description, and parameters.
|
||||
callbacks: Optional list of callback functions to be executed
|
||||
during and after the LLM call.
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
from_task: Optional task caller to be used for the LLM call.
|
||||
from_agent: Optional agent caller to be used for the LLM call.
|
||||
response_model: Optional response model to be used for the LLM call.
|
||||
|
||||
Returns:
|
||||
Either a text response from the LLM (str) or
|
||||
the result of a tool function call (Any).
|
||||
|
||||
Raises:
|
||||
ValueError: If the messages format is invalid.
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
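# Hedged sketch (not part of the upstream diff): a custom BaseLLM subclass opting into
# the new async path overrides acall(); the base implementation above raises
# NotImplementedError. The parameter list mirrors the acall signature shown in this
# hunk; the exact signature of the synchronous call() in BaseLLM is an assumption here.
from crewai.llms.base_llm import BaseLLM


class EchoLLM(BaseLLM):
    def call(self, messages, tools=None, callbacks=None, available_functions=None,
             from_task=None, from_agent=None, response_model=None):
        # Trivial sync behaviour: echo the last user message back.
        return messages if isinstance(messages, str) else messages[-1]["content"]

    async def acall(self, messages, tools=None, callbacks=None, available_functions=None,
                    from_task=None, from_agent=None, response_model=None):
        # A real implementation would await a provider SDK here; this one reuses call().
        return self.call(messages, tools, callbacks, available_functions,
                         from_task, from_agent, response_model)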
def _convert_tools_for_interference(
|
||||
self, tools: list[dict[str, BaseTool]]
|
||||
) -> list[dict[str, BaseTool]]:
|
||||
|
||||
@@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [


AnthropicModels: TypeAlias = Literal[
    "claude-opus-4-5-20251101",
    "claude-opus-4-5",
    "claude-3-7-sonnet-latest",
    "claude-3-7-sonnet-20250219",
    "claude-3-5-haiku-latest",
@@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[
    "claude-3-haiku-20240307",
]
ANTHROPIC_MODELS: list[AnthropicModels] = [
    "claude-opus-4-5-20251101",
    "claude-opus-4-5",
    "claude-3-7-sonnet-latest",
    "claude-3-7-sonnet-20250219",
    "claude-3-5-haiku-latest",
@@ -252,6 +256,7 @@ GeminiModels: TypeAlias = Literal[
    "gemini-2.5-flash-preview-tts",
    "gemini-2.5-pro-preview-tts",
    "gemini-2.5-computer-use-preview-10-2025",
    "gemini-2.5-pro-exp-03-25",
    "gemini-2.0-flash",
    "gemini-2.0-flash-001",
    "gemini-2.0-flash-exp",
@@ -305,6 +310,7 @@ GEMINI_MODELS: list[GeminiModels] = [
    "gemini-2.5-flash-preview-tts",
    "gemini-2.5-pro-preview-tts",
    "gemini-2.5-computer-use-preview-10-2025",
    "gemini-2.5-pro-exp-03-25",
    "gemini-2.0-flash",
    "gemini-2.0-flash-001",
    "gemini-2.0-flash-exp",
@@ -452,6 +458,7 @@ BedrockModels: TypeAlias = Literal[
    "anthropic.claude-3-sonnet-20240229-v1:0:28k",
    "anthropic.claude-haiku-4-5-20251001-v1:0",
    "anthropic.claude-instant-v1:2:100k",
    "anthropic.claude-opus-4-5-20251101-v1:0",
    "anthropic.claude-opus-4-1-20250805-v1:0",
    "anthropic.claude-opus-4-20250514-v1:0",
    "anthropic.claude-sonnet-4-20250514-v1:0",
@@ -524,6 +531,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
    "anthropic.claude-3-sonnet-20240229-v1:0:28k",
    "anthropic.claude-haiku-4-5-20251001-v1:0",
    "anthropic.claude-instant-v1:2:100k",
    "anthropic.claude-opus-4-5-20251101-v1:0",
    "anthropic.claude-opus-4-1-20250805-v1:0",
    "anthropic.claude-opus-4-20250514-v1:0",
    "anthropic.claude-sonnet-4-20250514-v1:0",
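The newly listed claude-opus-4-5 identifiers become valid values for the typed model aliases. A short selection sketch (assumption: crewAI's LLM wrapper accepts provider-prefixed model ids):

from crewai import LLM

llm = LLM(model="anthropic/claude-opus-4-5")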
@@ -9,7 +9,7 @@ from pydantic import BaseModel

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
    from crewai.llms.hooks.base import BaseInterceptor

try:
    from anthropic import Anthropic
    from anthropic import Anthropic, AsyncAnthropic
    from anthropic.types import Message
    from anthropic.types.tool_use_block import ToolUseBlock
    import httpx
@@ -84,15 +84,20 @@ class AnthropicCompletion(BaseLLM):

        self.client = Anthropic(**self._get_client_params())

        async_client_params = self._get_client_params()
        if self.interceptor:
            async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
            async_http_client = httpx.AsyncClient(transport=async_transport)
            async_client_params["http_client"] = async_http_client

        self.async_client = AsyncAnthropic(**async_client_params)

        # Store completion parameters
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.stream = stream
        self.stop_sequences = stop_sequences or []

        # Model-specific settings
        self.is_claude_3 = "claude-3" in model.lower()
        self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use
        self.supports_tools = True

    @property
    def stop(self) -> list[str]:
@@ -213,6 +218,72 @@ class AnthropicCompletion(BaseLLM):
            )
            raise

    async def acall(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Async call to Anthropic messages API.

        Args:
            messages: Input messages for the chat completion
            tools: List of tool/function definitions
            callbacks: Callback functions (not used in native implementation)
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Returns:
            Chat completion response or tool call result
        """
        try:
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            formatted_messages, system_message = self._format_messages_for_anthropic(
                messages
            )

            completion_params = self._prepare_completion_params(
                formatted_messages, system_message, tools
            )

            if self.stream:
                return await self._ahandle_streaming_completion(
                    completion_params,
                    available_functions,
                    from_task,
                    from_agent,
                    response_model,
                )

            return await self._ahandle_completion(
                completion_params,
                available_functions,
                from_task,
                from_agent,
                response_model,
            )

        except Exception as e:
            error_msg = f"Anthropic API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise
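A usage sketch for the async Anthropic path (illustrative; the constructor arguments and model name are assumptions, and response_model is routed through the synthetic structured_output tool introduced further down in this diff):

import asyncio
from pydantic import BaseModel

class Greeting(BaseModel):
    text: str

async def ask(llm: "AnthropicCompletion") -> str:
    # structured output is returned as a JSON string validated against Greeting
    return await llm.acall("Reply with a short greeting.", response_model=Greeting)

# asyncio.run(ask(AnthropicCompletion(model="claude-3-5-haiku-latest")))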
def _prepare_completion_params(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
@@ -546,7 +617,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
# Execute the tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args, # type: ignore
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
@@ -626,6 +697,275 @@ class AnthropicCompletion(BaseLLM):
|
||||
return tool_results[0]["content"]
|
||||
raise e
|
||||
|
||||
async def _ahandle_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming async message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
try:
|
||||
response: Message = await self.async_client.messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
raise e from e
|
||||
|
||||
usage = self._extract_anthropic_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and response.content:
|
||||
tool_uses = [
|
||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
|
||||
if response.content and available_functions:
|
||||
tool_uses = [
|
||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
return await self._ahandle_tool_use_conversation(
|
||||
response,
|
||||
tool_uses,
|
||||
params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
content = ""
|
||||
if response.content:
|
||||
for content_block in response.content:
|
||||
if hasattr(content_block, "text"):
|
||||
content += content_block.text
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"Anthropic API usage: {usage}")
|
||||
|
||||
return content
|
||||
|
||||
async def _ahandle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle async streaming message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
full_response = ""
|
||||
|
||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||
|
||||
async with self.async_client.messages.stream(**stream_params) as stream:
|
||||
async for event in stream:
|
||||
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
||||
text_delta = event.delta.text
|
||||
full_response += text_delta
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=text_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
final_message: Message = await stream.get_final_message()
|
||||
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and final_message.content:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
|
||||
if final_message.content and available_functions:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
return await self._ahandle_tool_use_conversation(
|
||||
final_message,
|
||||
tool_uses,
|
||||
params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
async def _ahandle_tool_use_conversation(
|
||||
self,
|
||||
initial_response: Message,
|
||||
tool_uses: list[ToolUseBlock],
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle the complete async tool use conversation flow.
|
||||
|
||||
This implements the proper Anthropic tool use pattern:
|
||||
1. Claude requests tool use
|
||||
2. We execute the tools
|
||||
3. We send tool results back to Claude
|
||||
4. Claude processes results and generates final response
|
||||
"""
|
||||
tool_results = []
|
||||
|
||||
for tool_use in tool_uses:
|
||||
function_name = tool_use.name
|
||||
function_args = tool_use.input
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use.id,
|
||||
"content": str(result)
|
||||
if result is not None
|
||||
else "Tool execution completed",
|
||||
}
|
||||
tool_results.append(tool_result)
|
||||
|
||||
follow_up_params = params.copy()
|
||||
|
||||
assistant_message = {"role": "assistant", "content": initial_response.content}
|
||||
|
||||
user_message = {"role": "user", "content": tool_results}
|
||||
|
||||
follow_up_params["messages"] = params["messages"] + [
|
||||
assistant_message,
|
||||
user_message,
|
||||
]
|
||||
|
||||
try:
|
||||
final_response: Message = await self.async_client.messages.create(
|
||||
**follow_up_params
|
||||
)
|
||||
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
self._track_token_usage_internal(follow_up_usage)
|
||||
|
||||
final_content = ""
|
||||
if final_response.content:
|
||||
for content_block in final_response.content:
|
||||
if hasattr(content_block, "text"):
|
||||
final_content += content_block.text
|
||||
|
||||
final_content = self._apply_stop_words(final_content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=final_content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=follow_up_params["messages"],
|
||||
)
|
||||
|
||||
total_usage = {
|
||||
"input_tokens": follow_up_usage.get("input_tokens", 0),
|
||||
"output_tokens": follow_up_usage.get("output_tokens", 0),
|
||||
"total_tokens": follow_up_usage.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
if total_usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"Anthropic API tool conversation usage: {total_usage}")
|
||||
|
||||
return final_content
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded in tool follow-up: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
logging.error(f"Tool follow-up conversation failed: {e}")
|
||||
if tool_results:
|
||||
return tool_results[0]["content"]
|
||||
raise e
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
|
||||
@@ -1,13 +1,17 @@
from __future__ import annotations

from collections.abc import Callable
import json
import logging
import os
import time
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel
from typing_extensions import Self

from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)
@@ -23,13 +27,19 @@ try:
    from azure.ai.inference import (
        ChatCompletionsClient,
    )
    from azure.ai.inference.aio import (
        ChatCompletionsClient as AsyncChatCompletionsClient,
    )
    from azure.ai.inference.models import (
        ChatCompletions,
        ChatCompletionsToolCall,
        JsonSchemaFormat,
        StreamingChatCompletionsUpdate,
    )
    from azure.core.credentials import (
        AccessToken,
        AzureKeyCredential,
        TokenCredential,
    )
    from azure.core.exceptions import (
        HttpResponseError,
@@ -44,6 +54,41 @@ except ImportError:
    ) from None


class _TokenProviderCredential(TokenCredential):
    """Wrapper class to convert an azure_ad_token_provider callable into a TokenCredential.

    This allows users to pass a token provider function (like the one returned by
    azure.identity.get_bearer_token_provider) to the Azure AI Inference client.
    """

    def __init__(self, provider: Callable[..., Any]):
        """Initialize with a token provider callable.

        Args:
            provider: A callable that returns an access token. This is typically
                the result of azure.identity.get_bearer_token_provider().
        """
        self._provider = provider

    def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
        """Get an access token from the provider.

        Args:
            *scopes: The scopes for the token (ignored, as the provider handles this).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            An AccessToken instance.
        """
        raw = self._provider()

        if isinstance(raw, AccessToken):
            return raw

        # If it's a bare string, wrap it with a default expiry of 1 hour
        return AccessToken(str(raw), int(time.time()) + 3600)


class AzureCompletion(BaseLLM):
    """Azure AI Inference native completion implementation.

@@ -67,6 +112,8 @@ class AzureCompletion(BaseLLM):
        stop: list[str] | None = None,
        stream: bool = False,
        interceptor: BaseInterceptor[Any, Any] | None = None,
        azure_ad_token_provider: Callable[..., Any] | None = None,
        credential: TokenCredential | None = None,
        **kwargs: Any,
    ):
        """Initialize Azure AI Inference chat completion client.
@@ -86,6 +133,13 @@ class AzureCompletion(BaseLLM):
            stop: Stop sequences
            stream: Enable streaming responses
            interceptor: HTTP interceptor (not yet supported for Azure).
            azure_ad_token_provider: A callable that returns an Azure AD token.
                This is typically the result of azure.identity.get_bearer_token_provider().
                Use this for Azure AD token-based authentication instead of API keys.
            credential: An Azure TokenCredential instance for authentication.
                This can be any credential from azure.identity (e.g., DefaultAzureCredential,
                ManagedIdentityCredential). Takes precedence over azure_ad_token_provider
                and api_key.
            **kwargs: Additional parameters
        """
        if interceptor is not None:
@@ -101,6 +155,7 @@ class AzureCompletion(BaseLLM):
        self.api_key = api_key or os.getenv("AZURE_API_KEY")
        self.endpoint = (
            endpoint
            or kwargs.get("base_url")
            or os.getenv("AZURE_ENDPOINT")
            or os.getenv("AZURE_OPENAI_ENDPOINT")
            or os.getenv("AZURE_API_BASE")
@@ -109,29 +164,45 @@ class AzureCompletion(BaseLLM):
        self.timeout = timeout
        self.max_retries = max_retries

        if not self.api_key:
            raise ValueError(
                "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
            )
        if not self.endpoint:
            raise ValueError(
                "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
            )

        # Determine the credential to use (priority: credential > azure_ad_token_provider > api_key)
        chosen_credential: TokenCredential | AzureKeyCredential | None = None

        if credential is not None:
            chosen_credential = credential
        elif azure_ad_token_provider is not None:
            chosen_credential = _TokenProviderCredential(azure_ad_token_provider)
        elif self.api_key:
            chosen_credential = AzureKeyCredential(self.api_key)

        if chosen_credential is None:
            raise ValueError(
                "Azure authentication is required. Provide one of: "
                "api_key (or set AZURE_API_KEY environment variable), "
                "azure_ad_token_provider (callable from azure.identity.get_bearer_token_provider), "
                "or credential (TokenCredential instance from azure.identity)."
            )

        # Validate and potentially fix Azure OpenAI endpoint URL
        self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)

        # Build client kwargs
        client_kwargs = {
        client_kwargs: dict[str, Any] = {
            "endpoint": self.endpoint,
            "credential": AzureKeyCredential(self.api_key),
            "credential": chosen_credential,
        }

        # Add api_version if specified (primarily for Azure OpenAI endpoints)
        if self.api_version:
            client_kwargs["api_version"] = self.api_version

        self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
        self.client = ChatCompletionsClient(**client_kwargs)

        self.async_client = AsyncChatCompletionsClient(**client_kwargs)

        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
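A construction sketch for the new Azure AD authentication options (illustrative; the endpoint is a placeholder, the scope shown is the standard Azure Cognitive Services scope, and the AzureCompletion import path is omitted since it is the module patched above):

from azure.identity import DefaultAzureCredential, get_bearer_token_provider

token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)

llm = AzureCompletion(
    model="gpt-4o",
    endpoint="https://<resource>.openai.azure.com",
    azure_ad_token_provider=token_provider,
)
# Alternatively, pass credential=DefaultAzureCredential(); it takes precedence
# over azure_ad_token_provider and api_key.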
@@ -256,6 +327,88 @@ class AzureCompletion(BaseLLM):
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call Azure AI Inference chat completions API asynchronously.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: List of tool/function definitions
|
||||
callbacks: Callback functions (not used in native implementation)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Pydantic model for structured output
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
except HttpResponseError as e:
|
||||
if e.status_code == 401:
|
||||
error_msg = "Azure authentication failed. Check your API key."
|
||||
elif e.status_code == 404:
|
||||
error_msg = (
|
||||
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
error_msg = "Azure API rate limit exceeded. Please retry later."
|
||||
else:
|
||||
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
|
||||
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
@@ -278,13 +431,16 @@ class AzureCompletion(BaseLLM):
|
||||
}
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
params["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": response_model.__name__,
|
||||
"schema": response_model.model_json_schema(),
|
||||
},
|
||||
}
|
||||
model_description = generate_model_description(response_model)
|
||||
json_schema_info = model_description["json_schema"]
|
||||
json_schema_name = json_schema_info["name"]
|
||||
|
||||
params["response_format"] = JsonSchemaFormat(
|
||||
name=json_schema_name,
|
||||
schema=json_schema_info["schema"],
|
||||
description=f"Schema for {json_schema_name}",
|
||||
strict=json_schema_info["strict"],
|
||||
)
|
||||
|
||||
# Only include model parameter for non-Azure OpenAI endpoints
|
||||
# Azure OpenAI endpoints have the deployment name in the URL
|
||||
@@ -311,8 +467,8 @@ class AzureCompletion(BaseLLM):
|
||||
params["tool_choice"] = "auto"
|
||||
|
||||
additional_params = self.additional_params
|
||||
additional_drop_params = additional_params.get('additional_drop_params')
|
||||
drop_params = additional_params.get('drop_params')
|
||||
additional_drop_params = additional_params.get("additional_drop_params")
|
||||
drop_params = additional_params.get("drop_params")
|
||||
|
||||
if drop_params and isinstance(additional_drop_params, list):
|
||||
for drop_param in additional_drop_params:
|
||||
@@ -551,6 +707,170 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
return full_response
|
||||
|
||||
async def _ahandle_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion asynchronously."""
|
||||
try:
|
||||
response: ChatCompletions = await self.async_client.complete(**params)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError("No choices returned from Azure API")
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
usage = self._extract_azure_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
content = message.content or ""
|
||||
try:
|
||||
structured_data = response_model.model_validate_json(content)
|
||||
structured_json = structured_data.model_dump_json()
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0] # Handle first tool call
|
||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
||||
function_name = tool_call.function.name
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = message.content or ""
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise e
|
||||
|
||||
return content
|
||||
|
||||
async def _ahandle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming chat completion asynchronously."""
|
||||
full_response = ""
|
||||
tool_calls = {}
|
||||
|
||||
stream = await self.async_client.complete(**params)
|
||||
async for update in stream:
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.choices:
|
||||
choice = update.choices[0]
|
||||
if choice.delta and choice.delta.content:
|
||||
content_delta = choice.delta.content
|
||||
full_response += content_delta
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=content_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if choice.delta and choice.delta.tool_calls:
|
||||
for tool_call in choice.delta.tool_calls:
|
||||
call_id = tool_call.id or "default"
|
||||
if call_id not in tool_calls:
|
||||
tool_calls[call_id] = {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call.function and tool_call.function.name:
|
||||
tool_calls[call_id]["name"] = tool_call.function.name
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
function_name = call_data["name"]
|
||||
|
||||
try:
|
||||
function_args = json.loads(call_data["arguments"])
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
||||
continue
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
# Azure OpenAI models support function calling
|
||||
@@ -604,3 +924,20 @@ class AzureCompletion(BaseLLM):
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
return {"total_tokens": 0}
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the async client and clean up resources.
|
||||
|
||||
This ensures proper cleanup of the underlying aiohttp session
|
||||
to avoid unclosed connector warnings.
|
||||
"""
|
||||
if hasattr(self.async_client, "close"):
|
||||
await self.async_client.close()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Async context manager entry."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Async context manager exit."""
|
||||
await self.aclose()
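A short sketch of the new async context-manager support (illustrative; constructor arguments are placeholders and the AzureCompletion import path is omitted):

import os

async def ping() -> str:
    # aclose() is awaited automatically when the block exits
    async with AzureCompletion(
        model="gpt-4o",
        endpoint="https://<resource>.openai.azure.com",
        api_key=os.environ["AZURE_API_KEY"],
    ) as llm:
        return await llm.acall("ping")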
|
||||
|
||||
@@ -1,6 +1,8 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from contextlib import AsyncExitStack
import json
import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict, cast
@@ -42,6 +44,16 @@ except ImportError:
        'AWS Bedrock native provider not available, to install: uv add "crewai[bedrock]"'
    ) from None

try:
    from aiobotocore.session import ( # type: ignore[import-untyped]
        get_session as get_aiobotocore_session,
    )

    AIOBOTOCORE_AVAILABLE = True
except ImportError:
    AIOBOTOCORE_AVAILABLE = False
    get_aiobotocore_session = None


if TYPE_CHECKING:

@@ -221,6 +233,15 @@ class BedrockCompletion(BaseLLM):
        self.client = session.client("bedrock-runtime", config=config)
        self.region_name = region_name

        self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        self.aws_secret_access_key = aws_secret_access_key or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
        self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
        self._async_client_initialized = False

        # Store completion parameters
        self.max_tokens = max_tokens
        self.top_p = top_p
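A usage sketch for the async Bedrock path (illustrative; it assumes the optional aiobotocore extra is installed and that the constructor accepts a model id and region_name as in the synchronous path):

# uv add "crewai[bedrock-async]"
import asyncio

async def ask_bedrock(llm: "BedrockCompletion") -> str:
    # raises NotImplementedError if aiobotocore is not available
    return await llm.acall("Name one AWS region.")

# asyncio.run(ask_bedrock(
#     BedrockCompletion(
#         model="anthropic.claude-haiku-4-5-20251001-v1:0", region_name="us-east-1"
#     )
# ))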
@@ -354,6 +375,110 @@ class BedrockCompletion(BaseLLM):
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[Any, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Async call to AWS Bedrock Converse API.
|
||||
|
||||
Args:
|
||||
messages: Input messages as string or list of message dicts.
|
||||
tools: Optional list of tool definitions.
|
||||
callbacks: Optional list of callback handlers.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
from_task: Optional task context for events.
|
||||
from_agent: Optional agent context for events.
|
||||
response_model: Optional Pydantic model for structured output.
|
||||
|
||||
Returns:
|
||||
Generated text response or structured output.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If aiobotocore is not installed.
|
||||
LLMContextLengthExceededError: If context window is exceeded.
|
||||
"""
|
||||
if not AIOBOTOCORE_AVAILABLE:
|
||||
raise NotImplementedError(
|
||||
"Async support for AWS Bedrock requires aiobotocore. "
|
||||
'Install with: uv add "crewai[bedrock-async]"'
|
||||
)
|
||||
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
body: BedrockConverseRequestBody = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
}
|
||||
|
||||
if system_message:
|
||||
body["system"] = cast(
|
||||
"list[SystemContentBlockTypeDef]",
|
||||
cast(object, [{"text": system_message}]),
|
||||
)
|
||||
|
||||
if tools:
|
||||
tool_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
"Sequence[ToolTypeDef]",
|
||||
cast(object, self._format_tools_for_converse(tools)),
|
||||
)
|
||||
}
|
||||
body["toolConfig"] = tool_config
|
||||
|
||||
if self.guardrail_config:
|
||||
guardrail_config: GuardrailConfigurationTypeDef = cast(
|
||||
"GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
|
||||
)
|
||||
body["guardrailConfig"] = guardrail_config
|
||||
|
||||
if self.additional_model_request_fields:
|
||||
body["additionalModelRequestFields"] = (
|
||||
self.additional_model_request_fields
|
||||
)
|
||||
|
||||
if self.additional_model_response_field_paths:
|
||||
body["additionalModelResponseFieldPaths"] = (
|
||||
self.additional_model_response_field_paths
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
return await self._ahandle_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"AWS Bedrock API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _handle_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -565,6 +690,341 @@ class BedrockCompletion(BaseLLM):
|
||||
role = event["messageStart"].get("role")
|
||||
logging.debug(f"Streaming message started with role: {role}")
|
||||
|
||||
elif "contentBlockStart" in event:
|
||||
start = event["contentBlockStart"].get("start", {})
|
||||
if "toolUse" in start:
|
||||
current_tool_use = start["toolUse"]
|
||||
tool_use_id = current_tool_use.get("toolUseId")
|
||||
logging.debug(
|
||||
f"Tool use started in stream: {json.dumps(current_tool_use)} (ID: {tool_use_id})"
|
||||
)
|
||||
|
||||
elif "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"]["delta"]
|
||||
if "text" in delta:
|
||||
text_chunk = delta["text"]
|
||||
logging.debug(f"Streaming text chunk: {text_chunk[:50]}...")
|
||||
full_response += text_chunk
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=text_chunk,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
elif "toolUse" in delta and current_tool_use:
|
||||
tool_input = delta["toolUse"].get("input", "")
|
||||
if tool_input:
|
||||
logging.debug(f"Tool input delta: {tool_input}")
|
||||
elif "contentBlockStop" in event:
|
||||
logging.debug("Content block stopped in stream")
|
||||
if current_tool_use and available_functions:
|
||||
function_name = current_tool_use["name"]
|
||||
function_args = cast(
|
||||
dict[str, Any], current_tool_use.get("input", {})
|
||||
)
|
||||
tool_result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if tool_result is not None and tool_use_id:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"toolUse": current_tool_use}],
|
||||
}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": tool_use_id,
|
||||
"content": [
|
||||
{"text": str(tool_result)}
|
||||
],
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
return self._handle_converse(
|
||||
messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
current_tool_use = None
|
||||
tool_use_id = None
|
||||
elif "messageStop" in event:
|
||||
stop_reason = event["messageStop"].get("stopReason")
|
||||
logging.debug(f"Streaming message stopped: {stop_reason}")
|
||||
if stop_reason == "max_tokens":
|
||||
logging.warning(
|
||||
"Streaming response truncated due to max_tokens"
|
||||
)
|
||||
elif stop_reason == "content_filtered":
|
||||
logging.warning(
|
||||
"Streaming response filtered due to content policy"
|
||||
)
|
||||
break
|
||||
elif "metadata" in event:
|
||||
metadata = event["metadata"]
|
||||
if "usage" in metadata:
|
||||
usage_metrics = metadata["usage"]
|
||||
self._track_token_usage_internal(usage_metrics)
|
||||
logging.debug(f"Token usage: {usage_metrics}")
|
||||
if "trace" in metadata:
|
||||
logging.debug(
|
||||
f"Trace information available: {metadata['trace']}"
|
||||
)
|
||||
|
||||
except ClientError as e:
|
||||
error_msg = self._handle_client_error(e)
|
||||
raise RuntimeError(error_msg) from e
|
||||
except BotoCoreError as e:
|
||||
error_msg = f"Bedrock streaming connection error: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ConnectionError(error_msg) from e
|
||||
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
if not full_response or full_response.strip() == "":
|
||||
logging.warning("Bedrock streaming returned empty content, using fallback")
|
||||
full_response = (
|
||||
"I apologize, but I couldn't generate a response. Please try again."
|
||||
)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
async def _ensure_async_client(self) -> Any:
|
||||
"""Ensure async client is initialized and return it."""
|
||||
if not self._async_client_initialized and get_aiobotocore_session:
|
||||
if self._async_exit_stack is None:
|
||||
raise RuntimeError(
|
||||
"Async exit stack not initialized - aiobotocore not available"
|
||||
)
|
||||
session = get_aiobotocore_session()
|
||||
client = await self._async_exit_stack.enter_async_context(
|
||||
session.create_client(
|
||||
"bedrock-runtime",
|
||||
region_name=self.region_name,
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
aws_session_token=self.aws_session_token,
|
||||
)
|
||||
)
|
||||
self._async_client = client
|
||||
self._async_client_initialized = True
|
||||
return self._async_client
|
||||
|
||||
async def _ahandle_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle async non-streaming converse API call."""
|
||||
try:
|
||||
if not messages:
|
||||
raise ValueError("Messages cannot be empty")
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
if (
|
||||
not isinstance(msg, dict)
|
||||
or "role" not in msg
|
||||
or "content" not in msg
|
||||
):
|
||||
raise ValueError(f"Invalid message format at index {i}")
|
||||
|
||||
async_client = await self._ensure_async_client()
|
||||
response = await async_client.converse(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
cast(object, messages),
|
||||
),
|
||||
**body,
|
||||
)
|
||||
|
||||
if "usage" in response:
|
||||
self._track_token_usage_internal(response["usage"])
|
||||
|
||||
stop_reason = response.get("stopReason")
|
||||
if stop_reason:
|
||||
logging.debug(f"Response stop reason: {stop_reason}")
|
||||
if stop_reason == "max_tokens":
|
||||
logging.warning("Response truncated due to max_tokens limit")
|
||||
elif stop_reason == "content_filtered":
|
||||
logging.warning("Response was filtered due to content policy")
|
||||
|
||||
output = response.get("output", {})
|
||||
message = output.get("message", {})
|
||||
content = message.get("content", [])
|
||||
|
||||
if not content:
|
||||
logging.warning("No content in Bedrock response")
|
||||
return (
|
||||
"I apologize, but I received an empty response. Please try again."
|
||||
)
|
||||
|
||||
text_content = ""
|
||||
|
||||
for content_block in content:
|
||||
if "text" in content_block:
|
||||
text_content += content_block["text"]
|
||||
|
||||
elif "toolUse" in content_block and available_functions:
|
||||
tool_use_block = content_block["toolUse"]
|
||||
tool_use_id = tool_use_block.get("toolUseId")
|
||||
function_name = tool_use_block["name"]
|
||||
function_args = tool_use_block.get("input", {})
|
||||
|
||||
logging.debug(
|
||||
f"Tool use requested: {function_name} with ID {tool_use_id}"
|
||||
)
|
||||
|
||||
tool_result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=dict(available_functions),
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if tool_result is not None:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"toolUse": tool_use_block}],
|
||||
}
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": tool_use_id,
|
||||
"content": [{"text": str(tool_result)}],
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return await self._ahandle_converse(
|
||||
messages, body, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
text_content = self._apply_stop_words(text_content)
|
||||
|
||||
if not text_content or text_content.strip() == "":
|
||||
logging.warning("Extracted empty text content from Bedrock response")
|
||||
text_content = "I apologize, but I couldn't generate a proper response. Please try again."
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=text_content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return text_content
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code", "Unknown")
|
||||
error_msg = e.response.get("Error", {}).get("Message", str(e))
|
||||
logging.error(f"AWS Bedrock ClientError ({error_code}): {error_msg}")
|
||||
|
||||
if error_code == "ValidationException":
|
||||
if "last turn" in error_msg and "user message" in error_msg:
|
||||
raise ValueError(
|
||||
f"Conversation format error: {error_msg}. Check message alternation."
|
||||
) from e
|
||||
raise ValueError(f"Request validation failed: {error_msg}") from e
|
||||
if error_code == "AccessDeniedException":
|
||||
raise PermissionError(
|
||||
f"Access denied to model {self.model_id}: {error_msg}"
|
||||
) from e
|
||||
if error_code == "ResourceNotFoundException":
|
||||
raise ValueError(f"Model {self.model_id} not found: {error_msg}") from e
|
||||
if error_code == "ThrottlingException":
|
||||
raise RuntimeError(
|
||||
f"API throttled, please retry later: {error_msg}"
|
||||
) from e
|
||||
if error_code == "ModelTimeoutException":
|
||||
raise TimeoutError(f"Model request timed out: {error_msg}") from e
|
||||
if error_code == "ServiceQuotaExceededException":
|
||||
raise RuntimeError(f"Service quota exceeded: {error_msg}") from e
|
||||
if error_code == "ModelNotReadyException":
|
||||
raise RuntimeError(
|
||||
f"Model {self.model_id} not ready: {error_msg}"
|
||||
) from e
|
||||
if error_code == "ModelErrorException":
|
||||
raise RuntimeError(f"Model error: {error_msg}") from e
|
||||
if error_code == "InternalServerException":
|
||||
raise RuntimeError(f"Internal server error: {error_msg}") from e
|
||||
if error_code == "ServiceUnavailableException":
|
||||
raise RuntimeError(f"Service unavailable: {error_msg}") from e
|
||||
|
||||
raise RuntimeError(f"Bedrock API error ({error_code}): {error_msg}") from e
|
||||
|
||||
except BotoCoreError as e:
|
||||
error_msg = f"Bedrock connection error: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error in Bedrock converse call: {e}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
async def _ahandle_streaming_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle async streaming converse API call."""
|
||||
full_response = ""
|
||||
current_tool_use = None
|
||||
tool_use_id = None
|
||||
|
||||
try:
|
||||
async_client = await self._ensure_async_client()
|
||||
response = await async_client.converse_stream(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
cast(object, messages),
|
||||
),
|
||||
**body,
|
||||
)
|
||||
|
||||
stream = response.get("stream")
|
||||
if stream:
|
||||
async for event in stream:
|
||||
if "messageStart" in event:
|
||||
role = event["messageStart"].get("role")
|
||||
logging.debug(f"Streaming message started with role: {role}")
|
||||
|
||||
elif "contentBlockStart" in event:
|
||||
start = event["contentBlockStart"].get("start", {})
|
||||
if "toolUse" in start:
|
||||
@@ -590,17 +1050,14 @@ class BedrockCompletion(BaseLLM):
|
||||
if tool_input:
|
||||
logging.debug(f"Tool input delta: {tool_input}")
|
||||
|
||||
# Content block stop - end of a content block
|
||||
elif "contentBlockStop" in event:
|
||||
logging.debug("Content block stopped in stream")
|
||||
# If we were accumulating a tool use, it's now complete
|
||||
if current_tool_use and available_functions:
|
||||
function_name = current_tool_use["name"]
|
||||
function_args = cast(
|
||||
dict[str, Any], current_tool_use.get("input", {})
|
||||
)
|
||||
|
||||
# Execute tool
|
||||
tool_result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
@@ -610,7 +1067,6 @@ class BedrockCompletion(BaseLLM):
|
||||
)
|
||||
|
||||
if tool_result is not None and tool_use_id:
|
||||
# Continue conversation with tool result
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -634,8 +1090,7 @@ class BedrockCompletion(BaseLLM):
|
||||
}
|
||||
)
|
||||
|
||||
# Recursive call - note this switches to non-streaming
|
||||
return self._handle_converse(
|
||||
return await self._ahandle_converse(
|
||||
messages,
|
||||
body,
|
||||
available_functions,
|
||||
@@ -643,10 +1098,9 @@ class BedrockCompletion(BaseLLM):
|
||||
from_agent,
|
||||
)
|
||||
|
||||
current_tool_use = None
|
||||
tool_use_id = None
|
||||
current_tool_use = None
|
||||
tool_use_id = None
|
||||
|
||||
# Message stop - end of entire message
|
||||
elif "messageStop" in event:
|
||||
stop_reason = event["messageStop"].get("stopReason")
|
||||
logging.debug(f"Streaming message stopped: {stop_reason}")
|
||||
@@ -660,7 +1114,6 @@ class BedrockCompletion(BaseLLM):
|
||||
)
|
||||
break
|
||||
|
||||
# Metadata - contains usage information and trace details
|
||||
elif "metadata" in event:
|
||||
metadata = event["metadata"]
|
||||
if "usage" in metadata:
|
||||
@@ -680,17 +1133,14 @@ class BedrockCompletion(BaseLLM):
|
||||
logging.error(error_msg)
|
||||
raise ConnectionError(error_msg) from e
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
# Ensure we don't return empty content
|
||||
if not full_response or full_response.strip() == "":
|
||||
logging.warning("Bedrock streaming returned empty content, using fallback")
|
||||
full_response = (
|
||||
"I apologize, but I couldn't generate a response. Please try again."
|
||||
)
|
||||
|
||||
# Emit completion event
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
|
||||
@@ -1,13 +1,14 @@
from __future__ import annotations

import logging
import os
import re
from typing import Any, cast
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.base import BaseInterceptor
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
@@ -15,10 +16,15 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
    from crewai.llms.hooks.base import BaseInterceptor


try:
    from google import genai # type: ignore[import-untyped]
    from google.genai import types # type: ignore[import-untyped]
    from google.genai.errors import APIError # type: ignore[import-untyped]
    from google import genai
    from google.genai import types
    from google.genai.errors import APIError
    from google.genai.types import GenerateContentResponse, Schema
except ImportError:
    raise ImportError(
        'Google Gen AI native provider not available, to install: uv add "crewai[google-genai]"'
@@ -102,7 +108,9 @@ class GeminiCompletion(BaseLLM):

        # Model-specific settings
        version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
        self.supports_tools = bool(version_match and float(version_match.group(1)) >= 1.5)
        self.supports_tools = bool(
            version_match and float(version_match.group(1)) >= 1.5
        )

    @property
    def stop(self) -> list[str]:
@@ -128,7 +136,7 @@ class GeminiCompletion(BaseLLM):
        else:
            self.stop_sequences = []

    def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported]
    def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
"""Initialize the Google Gen AI client with proper parameter handling.
|
||||
|
||||
Args:
|
||||
@@ -277,7 +285,84 @@ class GeminiCompletion(BaseLLM):
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_generation_config( # type: ignore[no-any-unimported]
|
||||
async def acall(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Async call to Google Gemini generate content API.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: List of tool/function definitions
|
||||
callbacks: Callback functions (not used as token counts are handled by the response)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
self.tools = tools
|
||||
|
||||
formatted_content, system_instruction = self._format_messages_for_gemini(
|
||||
messages
|
||||
)
|
||||
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
formatted_content,
|
||||
system_instruction,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
error_msg = f"Google Gemini API error: {e.code} - {e.message}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Google Gemini API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_generation_config(
|
||||
self,
|
||||
system_instruction: str | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
@@ -294,7 +379,7 @@ class GeminiCompletion(BaseLLM):
GenerateContentConfig object for Gemini API
"""
self.tools = tools
config_params = {}
config_params: dict[str, Any] = {}

# Add system instruction if present
if system_instruction:
@@ -329,7 +414,7 @@ class GeminiCompletion(BaseLLM):

return types.GenerateContentConfig(**config_params)

def _convert_tools_for_interference(  # type: ignore[no-any-unimported]
def _convert_tools_for_interference(  # type: ignore[override]
self, tools: list[dict[str, Any]]
) -> list[types.Tool]:
"""Convert CrewAI tool format to Gemini function declaration format."""
@@ -346,7 +431,7 @@ class GeminiCompletion(BaseLLM):
)

# Add parameters if present - ensure parameters is a dict
if parameters and isinstance(parameters, dict):
if parameters and isinstance(parameters, Schema):
function_declaration.parameters = parameters

gemini_tool = types.Tool(function_declarations=[function_declaration])
@@ -354,7 +439,7 @@ class GeminiCompletion(BaseLLM):

return gemini_tools

def _format_messages_for_gemini(  # type: ignore[no-any-unimported]
def _format_messages_for_gemini(
self, messages: str | list[LLMMessage]
) -> tuple[list[types.Content], str | None]:
"""Format messages for Gemini API.
@@ -373,32 +458,41 @@ class GeminiCompletion(BaseLLM):
# Use base class formatting first
base_formatted = super()._format_messages(messages)

contents = []
contents: list[types.Content] = []
system_instruction: str | None = None

for message in base_formatted:
role = message.get("role")
content = message.get("content", "")
role = message["role"]
content = message["content"]

# Convert content to string if it's a list
if isinstance(content, list):
text_content = " ".join(
str(item.get("text", "")) if isinstance(item, dict) else str(item)
for item in content
)
else:
text_content = str(content) if content else ""

if role == "system":
# Extract system instruction - Gemini handles it separately
if system_instruction:
system_instruction += f"\n\n{content}"
system_instruction += f"\n\n{text_content}"
else:
system_instruction = cast(str, content)
system_instruction = text_content
else:
# Convert role for Gemini (assistant -> model)
gemini_role = "model" if role == "assistant" else "user"

# Create Content object
gemini_content = types.Content(
role=gemini_role, parts=[types.Part.from_text(text=content)]
role=gemini_role, parts=[types.Part.from_text(text=text_content)]
)
contents.append(gemini_content)

return contents, system_instruction

def _handle_completion(  # type: ignore[no-any-unimported]
def _handle_completion(
self,
contents: list[types.Content],
system_instruction: str | None,
@@ -409,14 +503,14 @@ class GeminiCompletion(BaseLLM):
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming content generation."""
api_params = {
"model": self.model,
"contents": contents,
"config": config,
}

try:
response = self.client.models.generate_content(**api_params)
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = self.client.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
)

usage = self._extract_token_usage(response)
except Exception as e:
@@ -433,6 +527,8 @@ class GeminiCompletion(BaseLLM):
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
function_name = part.function_call.name
if function_name is None:
continue
function_args = (
dict(part.function_call.args)
if part.function_call.args
@@ -442,7 +538,7 @@ class GeminiCompletion(BaseLLM):
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,  # type: ignore
available_functions=available_functions or {},
from_task=from_task,
from_agent=from_agent,
)
@@ -450,7 +546,7 @@ class GeminiCompletion(BaseLLM):
if result is not None:
return result

content = response.text if hasattr(response, "text") else ""
content = response.text or ""
content = self._apply_stop_words(content)

messages_for_event = self._convert_contents_to_dict(contents)
@@ -465,7 +561,7 @@ class GeminiCompletion(BaseLLM):

return content

def _handle_streaming_completion(  # type: ignore[no-any-unimported]
def _handle_streaming_completion(
self,
contents: list[types.Content],
config: types.GenerateContentConfig,
@@ -476,16 +572,16 @@ class GeminiCompletion(BaseLLM):
) -> str:
"""Handle streaming content generation."""
full_response = ""
function_calls = {}
function_calls: dict[str, dict[str, Any]] = {}

api_params = {
"model": self.model,
"contents": contents,
"config": config,
}

for chunk in self.client.models.generate_content_stream(**api_params):
if hasattr(chunk, "text") and chunk.text:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
for chunk in self.client.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
):
if chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
@@ -493,7 +589,7 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent,
)

if hasattr(chunk, "candidates") and chunk.candidates:
if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
@@ -513,6 +609,14 @@ class GeminiCompletion(BaseLLM):
function_name = call_data["name"]
function_args = call_data["args"]

# Skip if function_name is None
if not isinstance(function_name, str):
continue

# Ensure function_args is a dict
if not isinstance(function_args, dict):
function_args = {}

# Execute tool
result = self._handle_tool_execution(
function_name=function_name,
@@ -537,6 +641,154 @@ class GeminiCompletion(BaseLLM):

return full_response

async def _ahandle_completion(
self,
contents: list[types.Content],
system_instruction: str | None,
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle async non-streaming content generation."""
try:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = await self.client.aio.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
)

usage = self._extract_token_usage(response)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
raise e from e

self._track_token_usage_internal(usage)

if response.candidates and (self.tools or available_functions):
candidate = response.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
function_name = part.function_call.name
if function_name is None:
continue
function_args = (
dict(part.function_call.args)
if part.function_call.args
else {}
)

result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions or {},
from_task=from_task,
from_agent=from_agent,
)

if result is not None:
return result

content = response.text or ""
content = self._apply_stop_words(content)

messages_for_event = self._convert_contents_to_dict(contents)

self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)

return content

async def _ahandle_streaming_completion(
self,
contents: list[types.Content],
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[str, dict[str, Any]] = {}

# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
stream = await self.client.aio.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
)
async for chunk in stream:
if chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
)

if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
function_calls[call_id] = {
"name": part.function_call.name,
"args": dict(part.function_call.args)
if part.function_call.args
else {},
}

if function_calls and available_functions:
for call_data in function_calls.values():
function_name = call_data["name"]
function_args = call_data["args"]

# Skip if function_name is None
if not isinstance(function_name, str):
continue

# Ensure function_args is a dict
if not isinstance(function_args, dict):
function_args = {}

result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)

if result is not None:
return result

messages_for_event = self._convert_contents_to_dict(contents)

self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)

return full_response

def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return self.supports_tools
@@ -583,9 +835,10 @@ class GeminiCompletion(BaseLLM):
# Default context window size for Gemini models
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO)  # 1M tokens

def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
@staticmethod
def _extract_token_usage(response: GenerateContentResponse) -> dict[str, Any]:
"""Extract token usage from Gemini response."""
if hasattr(response, "usage_metadata"):
if response.usage_metadata:
usage = response.usage_metadata
return {
"prompt_token_count": getattr(usage, "prompt_token_count", 0),
@@ -595,21 +848,23 @@ class GeminiCompletion(BaseLLM):
}
return {"total_tokens": 0}

def _convert_contents_to_dict(  # type: ignore[no-any-unimported]
def _convert_contents_to_dict(
self,
contents: list[types.Content],
) -> list[dict[str, str]]:
"""Convert contents to dict format."""
return [
{
"role": "assistant"
if content_obj.role == "model"
else content_obj.role,
"content": " ".join(
part.text
for part in content_obj.parts
if hasattr(part, "text") and part.text
),
}
for content_obj in contents
]
result: list[dict[str, str]] = []
for content_obj in contents:
role = content_obj.role
if role == "model":
role = "assistant"
elif role is None:
role = "user"

parts = content_obj.parts or []
content = " ".join(
part.text for part in parts if hasattr(part, "text") and part.text
)

result.append({"role": role, "content": content})
return result

@@ -1,13 +1,13 @@
from __future__ import annotations

from collections.abc import Iterator
from collections.abc import AsyncIterator, Iterator
import json
import logging
import os
from typing import TYPE_CHECKING, Any

import httpx
from openai import APIConnectionError, NotFoundError, OpenAI
from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
@@ -15,7 +15,7 @@ from pydantic import BaseModel

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -101,6 +101,14 @@ class OpenAICompletion(BaseLLM):

self.client = OpenAI(**client_config)

async_client_config = self._get_client_params()
if self.interceptor:
async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
async_http_client = httpx.AsyncClient(transport=async_transport)
async_client_config["http_client"] = async_http_client

self.async_client = AsyncOpenAI(**async_client_config)

# Completion parameters
self.top_p = top_p
self.frequency_penalty = frequency_penalty
@@ -210,6 +218,71 @@ class OpenAICompletion(BaseLLM):
)
raise

async def acall(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Async call to OpenAI chat completion API.

Args:
messages: Input messages for the chat completion
tools: list of tool/function definitions
callbacks: Callback functions (not used in native implementation)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Response model for structured output.

Returns:
Chat completion response or tool call result
"""
try:
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)

formatted_messages = self._format_messages(messages)

completion_params = self._prepare_completion_params(
messages=formatted_messages, tools=tools
)

if self.stream:
return await self._ahandle_streaming_completion(
params=completion_params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)

return await self._ahandle_completion(
params=completion_params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)

except Exception as e:
error_msg = f"OpenAI API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise

def _prepare_completion_params(
self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
) -> dict[str, Any]:
@@ -352,10 +425,10 @@ class OpenAICompletion(BaseLLM):

if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
function_name = tool_call.function.name  # type: ignore[union-attr]
function_name = tool_call.function.name

try:
function_args = json.loads(tool_call.function.arguments)  # type: ignore[union-attr]
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}
@@ -564,6 +637,266 @@ class OpenAICompletion(BaseLLM):

return full_response

async def _ahandle_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming async chat completion."""
try:
if response_model:
parse_params = {
k: v for k, v in params.items() if k != "response_format"
}
parsed_response = await self.async_client.beta.chat.completions.parse(
**parse_params,
response_format=response_model,
)
math_reasoning = parsed_response.choices[0].message

if math_reasoning.refusal:
pass

usage = self._extract_openai_token_usage(parsed_response)
self._track_token_usage_internal(usage)

parsed_object = parsed_response.choices[0].message.parsed
if parsed_object:
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json

response: ChatCompletion = await self.async_client.chat.completions.create(
**params
)

usage = self._extract_openai_token_usage(response)

self._track_token_usage_internal(usage)

choice: Choice = response.choices[0]
message = choice.message

if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
function_name = tool_call.function.name

try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}

result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)

if result is not None:
return result

content = message.content or ""
content = self._apply_stop_words(content)

if self.response_format and isinstance(self.response_format, type):
try:
structured_result = self._validate_structured_output(
content, self.response_format
)
self._emit_call_completed_event(
response=structured_result,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_result
except ValueError as e:
logging.warning(f"Structured output validation failed: {e}")

self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)

if usage.get("total_tokens", 0) > 0:
logging.info(f"OpenAI API usage: {usage}")
except NotFoundError as e:
error_msg = f"Model {self.model} not found: {e}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise ValueError(error_msg) from e
except APIConnectionError as e:
error_msg = f"Failed to connect to OpenAI API: {e}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise ConnectionError(error_msg) from e
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e

error_msg = f"OpenAI API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise e from e

return content

async def _ahandle_streaming_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle async streaming chat completion."""
full_response = ""
tool_calls = {}

if response_model:
completion_stream: AsyncIterator[
ChatCompletionChunk
] = await self.async_client.chat.completions.create(**params)

accumulated_content = ""
async for chunk in completion_stream:
if not chunk.choices:
continue

choice = chunk.choices[0]
delta: ChoiceDelta = choice.delta

if delta.content:
accumulated_content += delta.content
self._emit_stream_chunk_event(
chunk=delta.content,
from_task=from_task,
from_agent=from_agent,
)

try:
parsed_object = response_model.model_validate_json(accumulated_content)
structured_json = parsed_object.model_dump_json()

self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)

return structured_json
except Exception as e:
logging.error(f"Failed to parse structured output from stream: {e}")
self._emit_call_completed_event(
response=accumulated_content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return accumulated_content

stream: AsyncIterator[
ChatCompletionChunk
] = await self.async_client.chat.completions.create(**params)

async for chunk in stream:
if not chunk.choices:
continue

choice = chunk.choices[0]
chunk_delta: ChoiceDelta = choice.delta

if chunk_delta.content:
full_response += chunk_delta.content
self._emit_stream_chunk_event(
chunk=chunk_delta.content,
from_task=from_task,
from_agent=from_agent,
)

if chunk_delta.tool_calls:
for tool_call in chunk_delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
}

if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments

if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
arguments = call_data["arguments"]

if not function_name or not arguments:
continue

if function_name not in available_functions:
logging.warning(
f"Function '{function_name}' not found in available functions"
)
continue

try:
function_args = json.loads(arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
continue

result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)

if result is not None:
return result

full_response = self._apply_stop_words(full_response)

self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)

return full_response

def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return not self.is_o1_model

@@ -66,7 +66,6 @@ class SSETransport(BaseTransport):
self._transport_context = sse_client(
self.url,
headers=self.headers if self.headers else None,
terminate_on_close=True,
)

read, write = await self._transport_context.__aenter__()

@@ -2,8 +2,10 @@

from __future__ import annotations

import asyncio
from collections.abc import Callable
from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload

from crewai.project.utils import memoize
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
return CacheHandlerMethod(memoize(meth))


def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Call a method, awaiting it if async and running in an event loop."""
result = method(*args, **kwargs)
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures

with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result


@overload
def crew(
meth: Callable[Concatenate[SelfT, P], Crew],
@@ -198,7 +217,7 @@ def crew(

# Instantiate tasks in order
for _, task_method in tasks:
task_instance = task_method(self)
task_instance = _call_method(task_method, self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
if agent_instance and agent_instance.role not in agent_roles:
@@ -207,7 +226,7 @@ def crew(

# Instantiate agents not included by tasks
for _, agent_method in agents:
agent_instance = agent_method(self)
agent_instance = _call_method(agent_method, self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
agent_roles.add(agent_instance.role)
@@ -215,7 +234,7 @@ def crew(
self.agents = instantiated_agents
self.tasks = instantiated_tasks

crew_instance = meth(self, *args, **kwargs)
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)

def callback_wrapper(
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance

@@ -1,7 +1,8 @@
"""Utility functions for the crewai project module."""

from collections.abc import Callable
from collections.abc import Callable, Coroutine
from functools import wraps
import inspect
from typing import Any, ParamSpec, TypeVar, cast

from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a method by caching its results based on arguments.

Handles Pydantic BaseModel instances by converting them to JSON strings
before hashing for cache lookup.
Handles both sync and async methods. Pydantic BaseModel instances are
converted to JSON strings before hashing for cache lookup.

Args:
meth: The method to memoize.
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
Returns:
A memoized version of the method that caches results.
"""
if inspect.iscoroutinefunction(meth):
return cast(Callable[P, R], _memoize_async(meth))
return _memoize_sync(meth)


def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a synchronous method."""

@wraps(meth)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
"""Wrapper that converts arguments to hashable form before caching.

Args:
*args: Positional arguments to the memoized method.
**kwargs: Keyword arguments to the memoized method.

Returns:
The result of the memoized method call.
"""
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
return result

return cast(Callable[P, R], wrapper)


def _memoize_async(
meth: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
"""Memoize an async method."""

@wraps(meth)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
)
cache_key = str((hashable_args, hashable_kwargs))

cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
if cached_result is not None:
return cached_result

result = await meth(*args, **kwargs)
cache.add(tool=meth.__name__, input=cache_key, output=result)
return result

return wrapper

@@ -2,8 +2,10 @@

from __future__ import annotations

import asyncio
from collections.abc import Callable
from functools import partial
import inspect
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
crew: Callable[..., Crew]


def _resolve_result(result: Any) -> Any:
"""Resolve a potentially async result to its value."""
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures

with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result


class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata.

@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
"""
if obj is None:
return self
bound = partial(self._meth, obj)
inner = partial(self._meth, obj)

def _bound(*args: Any, **kwargs: Any) -> R:
result: R = _resolve_result(inner(*args, **kwargs))  # type: ignore[call-arg]
return result

for attr in (
"is_agent",
"is_llm",
@@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]):
"is_crew",
):
if hasattr(self, attr):
setattr(bound, attr, getattr(self, attr))
return bound
setattr(_bound, attr, getattr(self, attr))
return _bound

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the wrapped method.

@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
The task result with name ensured.
"""
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
result = _resolve_result(result)
return self._task_method.ensure_task_name(result)


@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
Returns:
The task instance with name set if not already provided.
"""
return self.ensure_task_name(self._meth(*args, **kwargs))
result = self._meth(*args, **kwargs)
result = _resolve_result(result)
return self.ensure_task_name(result)

def unwrap(self) -> Callable[P, TaskResultT]:
"""Get the original unwrapped method.

@@ -9,12 +9,14 @@ data is collected. Users can opt-in to share more complete data using the
from __future__ import annotations

import asyncio
import atexit
from collections.abc import Callable
from importlib.metadata import version
import json
import logging
import os
import platform
import signal
import threading
from typing import TYPE_CHECKING, Any

@@ -31,6 +33,14 @@ from opentelemetry.sdk.trace.export import (
from opentelemetry.trace import Span
from typing_extensions import Self

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.system_events import (
SigContEvent,
SigHupEvent,
SigIntEvent,
SigTStpEvent,
SigTermEvent,
)
from crewai.telemetry.constants import (
CREWAI_TELEMETRY_BASE_URL,
CREWAI_TELEMETRY_SERVICE_NAME,
@@ -121,6 +131,7 @@ class Telemetry:
)

self.provider.add_span_processor(processor)
self._register_shutdown_handlers()
self.ready = True
except Exception as e:
if isinstance(
@@ -155,6 +166,71 @@ class Telemetry:
self.ready = False
self.trace_set = False

def _register_shutdown_handlers(self) -> None:
"""Register handlers for graceful shutdown on process exit and signals."""
atexit.register(self._shutdown)

self._original_handlers: dict[int, Any] = {}

self._register_signal_handler(signal.SIGTERM, SigTermEvent, shutdown=True)
self._register_signal_handler(signal.SIGINT, SigIntEvent, shutdown=True)
self._register_signal_handler(signal.SIGHUP, SigHupEvent, shutdown=False)
self._register_signal_handler(signal.SIGTSTP, SigTStpEvent, shutdown=False)
self._register_signal_handler(signal.SIGCONT, SigContEvent, shutdown=False)

def _register_signal_handler(
self,
sig: signal.Signals,
event_class: type,
shutdown: bool = False,
) -> None:
"""Register a signal handler that emits an event.

Args:
sig: The signal to handle.
event_class: The event class to instantiate and emit.
shutdown: Whether to trigger shutdown on this signal.
"""
try:
original_handler = signal.getsignal(sig)
self._original_handlers[sig] = original_handler

def handler(signum: int, frame: Any) -> None:
crewai_event_bus.emit(self, event_class())

if shutdown:
self._shutdown()

if original_handler not in (signal.SIG_DFL, signal.SIG_IGN, None):
if callable(original_handler):
original_handler(signum, frame)
elif shutdown:
raise SystemExit(0)

signal.signal(sig, handler)
except ValueError as e:
logger.warning(
f"Cannot register {sig.name} handler: not running in main thread",
exc_info=e,
)
except OSError as e:
logger.warning(f"Cannot register {sig.name} handler: {e}", exc_info=e)

def _shutdown(self) -> None:
"""Flush and shutdown the telemetry provider on process exit.

Uses a short timeout to avoid blocking process shutdown.
"""
if not self.ready:
return

try:
self.provider.force_flush(timeout_millis=5000)
self.provider.shutdown()
self.ready = False
except Exception as e:
logger.debug(f"Telemetry shutdown failed: {e}")

def _safe_telemetry_operation(
self, operation: Callable[[], Span | None]
) -> Span | None:

@@ -147,7 +147,7 @@ def test_custom_llm():
assert agent.llm.model == "gpt-4"


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execution():
agent = Agent(
role="test role",
@@ -166,7 +166,7 @@ def test_agent_execution():
assert output == "1 + 1 is 2"


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execution_with_tools():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -211,7 +211,7 @@ def test_agent_execution_with_tools():
assert received_events[0].tool_args == {"first_number": 3, "second_number": 4}


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_logging_tool_usage():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -245,7 +245,7 @@ def test_logging_tool_usage():
assert agent.tools_handler.last_used_tool.arguments == tool_usage.arguments


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_cache_hitting():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -325,7 +325,7 @@ def test_cache_hitting():
assert received_events[0].output == "12"


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_disabling_cache_for_agent():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -389,7 +389,7 @@ def test_disabling_cache_for_agent():
read.assert_not_called()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execution_with_specific_tools():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -412,7 +412,7 @@ def test_agent_execution_with_specific_tools():
assert output == "The result of the multiplication is 12."


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
@tool
def multiplier(first_number: int, second_number: int) -> float:
@@ -438,7 +438,7 @@ def test_agent_powered_by_new_o_model_family_that_allows_skipping_tool():
assert output == "12"


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_powered_by_new_o_model_family_that_uses_tool():
@tool
def comapny_customer_data() -> str:
@@ -464,7 +464,7 @@ def test_agent_powered_by_new_o_model_family_that_uses_tool():
assert output == "42"


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_custom_max_iterations():
@tool
def get_final_answer() -> float:
@@ -509,7 +509,7 @@ def test_agent_custom_max_iterations():
assert call_count == 2


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
@pytest.mark.timeout(30)
def test_agent_max_iterations_stops_loop():
"""Test that agent execution terminates when max_iter is reached."""
@@ -546,7 +546,7 @@ def test_agent_max_iterations_stops_loop():
)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_repeated_tool_usage(capsys):
"""Test that agents handle repeated tool usage appropriately.

@@ -595,7 +595,7 @@ def test_agent_repeated_tool_usage(capsys):
)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
@tool
def get_final_answer(anything: str) -> float:
@@ -638,7 +638,7 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_moved_on_after_max_iterations():
@tool
def get_final_answer() -> float:
@@ -665,7 +665,7 @@ def test_agent_moved_on_after_max_iterations():
assert output == "42"


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_respect_the_max_rpm_set(capsys):
@tool
def get_final_answer() -> float:
@@ -699,7 +699,7 @@ def test_agent_respect_the_max_rpm_set(capsys):
moveon.assert_called()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys):
from unittest.mock import patch

@@ -737,7 +737,7 @@ def test_agent_respect_the_max_rpm_set_over_crew_rpm(capsys):
moveon.assert_not_called()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_without_max_rpm_respects_crew_rpm(capsys):
from unittest.mock import patch

@@ -797,7 +797,7 @@ def test_agent_without_max_rpm_respects_crew_rpm(capsys):
moveon.assert_called_once()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_error_on_parsing_tool(capsys):
from unittest.mock import patch

@@ -840,7 +840,7 @@ def test_agent_error_on_parsing_tool(capsys):
assert "Error on parsing tool." in captured.out


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_remembers_output_format_after_using_tools_too_many_times():
from unittest.mock import patch

@@ -875,7 +875,7 @@ def test_agent_remembers_output_format_after_using_tools_too_many_times():
remember_format.assert_called()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_use_specific_tasks_output_as_context(capsys):
agent1 = Agent(role="test role", goal="test goal", backstory="test backstory")
agent2 = Agent(role="test role2", goal="test goal2", backstory="test backstory2")
@@ -902,7 +902,7 @@ def test_agent_use_specific_tasks_output_as_context(capsys):
assert "hi" in result.raw.lower() or "hello" in result.raw.lower()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_step_callback():
class StepCallback:
def callback(self, step):
@@ -936,7 +936,7 @@ def test_agent_step_callback():
callback.assert_called()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_function_calling_llm():
from crewai.llm import LLM
llm = LLM(model="gpt-4o", is_litellm=True)
@@ -983,7 +983,7 @@ def test_agent_function_calling_llm():
mock_original_tool_calling.assert_called()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_tool_result_as_answer_is_the_final_answer_for_the_agent():
from crewai.tools import BaseTool

@@ -1013,7 +1013,7 @@ def test_tool_result_as_answer_is_the_final_answer_for_the_agent():
assert result.raw == "Howdy!"


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_tool_usage_information_is_appended_to_agent():
from crewai.tools import BaseTool

@@ -1068,7 +1068,7 @@ def test_agent_definition_based_on_dict():


# test for human input
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_human_input():
# Agent configuration
config = {
@@ -1216,7 +1216,7 @@ Thought:<|eot_id|>
assert mock_format_prompt.return_value == expected_prompt


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_task_allow_crewai_trigger_context():
from crewai import Crew

@@ -1237,7 +1237,7 @@ def test_task_allow_crewai_trigger_context():
assert "Trigger Payload: Important context data" in prompt


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_task_without_allow_crewai_trigger_context():
from crewai import Crew

@@ -1260,7 +1260,7 @@ def test_task_without_allow_crewai_trigger_context():
assert "Important context data" not in prompt


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_task_allow_crewai_trigger_context_no_payload():
from crewai import Crew

@@ -1282,7 +1282,7 @@ def test_task_allow_crewai_trigger_context_no_payload():
assert "Trigger Payload:" not in prompt


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
from crewai import Crew

@@ -1311,7 +1311,7 @@ def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
assert "Trigger Payload: Initial context data" not in first_prompt


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_first_task_auto_inject_trigger():
from crewai import Crew

@@ -1344,7 +1344,7 @@ def test_first_task_auto_inject_trigger():
assert "Trigger Payload:" not in second_prompt


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_ensure_first_task_allow_crewai_trigger_context_is_false_does_not_inject():
from crewai import Crew

@@ -1549,7 +1549,7 @@ def test_agent_with_additional_kwargs():
assert agent.llm.frequency_penalty == 0.1


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_llm_call():
llm = LLM(model="gpt-3.5-turbo")
messages = [{"role": "user", "content": "Say 'Hello, World!'"}]
@@ -1558,7 +1558,7 @@ def test_llm_call():
assert "Hello, World!" in response


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_llm_call_with_error():
llm = LLM(model="non-existent-model")
messages = [{"role": "user", "content": "This should fail"}]
@@ -1567,7 +1567,7 @@ def test_llm_call_with_error():
llm.call(messages)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_handle_context_length_exceeds_limit():
# Import necessary modules
from crewai.utilities.agent_utils import handle_context_length
@@ -1620,7 +1620,7 @@ def test_handle_context_length_exceeds_limit():
mock_summarize.assert_called_once()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_handle_context_length_exceeds_limit_cli_no():
agent = Agent(
role="test role",
@@ -1695,7 +1695,7 @@ def test_agent_with_all_llm_attributes():
assert agent.llm.api_key == "sk-your-api-key-here"


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_llm_call_with_all_attributes():
llm = LLM(
model="gpt-3.5-turbo",
@@ -1712,7 +1712,7 @@ def test_llm_call_with_all_attributes():
assert "STOP" not in response


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_ollama_llama3():
agent = Agent(
role="test role",
@@ -1733,7 +1733,7 @@ def test_agent_with_ollama_llama3():
assert "Llama3" in response or "AI" in response or "language model" in response


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_llm_call_with_ollama_llama3():
llm = LLM(
model="ollama/llama3.2:3b",
@@ -1752,7 +1752,7 @@ def test_llm_call_with_ollama_llama3():
assert "Llama3" in response or "AI" in response or "language model" in response


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_basic():
agent = Agent(
role="test role",
@@ -1771,7 +1771,7 @@ def test_agent_execute_task_basic():
assert "4" in result


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_with_context():
agent = Agent(
role="test role",
@@ -1793,7 +1793,7 @@ def test_agent_execute_task_with_context():
assert "fox" in result.lower() and "dog" in result.lower()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_with_tool():
@tool
def dummy_tool(query: str) -> str:
@@ -1818,7 +1818,7 @@ def test_agent_execute_task_with_tool():
assert "Dummy result for: test query" in result


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_with_custom_llm():
agent = Agent(
role="test role",
@@ -1839,7 +1839,7 @@ def test_agent_execute_task_with_custom_llm():
)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_execute_task_with_ollama():
agent = Agent(
role="test role",
@@ -1859,7 +1859,7 @@ def test_agent_execute_task_with_ollama():
assert "AI" in result or "artificial intelligence" in result.lower()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1891,7 +1891,7 @@ def test_agent_with_knowledge_sources():
assert "red" in result.raw.lower()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1939,7 +1939,7 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_default():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -1988,7 +1988,7 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_defau
)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_extensive_role():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -2024,7 +2024,7 @@ def test_agent_with_knowledge_sources_extensive_role():
assert "red" in result.raw.lower()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_works_with_copy():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -2063,7 +2063,7 @@ def test_agent_with_knowledge_sources_works_with_copy():
assert isinstance(agent_copy.llm, BaseLLM)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_sources_generate_search_query():
content = "Brandon's favorite color is red and he likes Mexican food."
string_source = StringKnowledgeSource(content=content)
@@ -2116,7 +2116,7 @@ def test_agent_with_knowledge_sources_generate_search_query():
assert "red" in result.raw.lower()


@pytest.mark.vcr(record_mode="none", filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_knowledge_with_no_crewai_knowledge():
mock_knowledge = MagicMock(spec=Knowledge)

@@ -2143,7 +2143,7 @@ def test_agent_with_knowledge_with_no_crewai_knowledge():
mock_knowledge.query.assert_called_once()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_with_only_crewai_knowledge():
mock_knowledge = MagicMock(spec=Knowledge)

@@ -2168,7 +2168,7 @@ def test_agent_with_only_crewai_knowledge():
mock_knowledge.query.assert_called_once()


@pytest.mark.vcr(record_mode="none", filter_headers=["authorization"])
@pytest.mark.vcr()
def test_agent_knowledege_with_crewai_knowledge():
crew_knowledge = MagicMock(spec=Knowledge)
agent_knowledge = MagicMock(spec=Knowledge)
@@ -2197,7 +2197,7 @@ def test_agent_knowledege_with_crewai_knowledge():
crew_knowledge.query.assert_called_once()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_litellm_auth_error_handling():
"""Test that LiteLLM authentication errors are handled correctly and not retried."""
from litellm import AuthenticationError as LiteLLMAuthenticationError
@@ -2326,7 +2326,7 @@ def test_litellm_anthropic_error_handling():
mock_llm_call.assert_called_once()


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr()
def test_get_knowledge_search_query():
"""Test that _get_knowledge_search_query calls the LLM with the correct prompts."""
from crewai.utilities.i18n import I18N

@@ -70,7 +70,7 @@ class ResearchResult(BaseModel):
|
||||
sources: list[str] = Field(description="List of sources used")
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.parametrize("verbose", [True, False])
|
||||
def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose):
|
||||
"""Test that LiteAgent is created with the correct parameters when Agent.kickoff() is called."""
|
||||
@@ -130,7 +130,7 @@ def test_lite_agent_created_with_correct_parameters(monkeypatch, verbose):
|
||||
assert created_lite_agent["response_format"] == TestResponse
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@pytest.mark.vcr()
|
||||
def test_lite_agent_with_tools():
|
||||
"""Test that Agent can use tools."""
|
||||
# Create a LiteAgent with tools
|
||||
@@ -174,7 +174,7 @@ def test_lite_agent_with_tools():
         assert event.tool_name == "search_web"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_lite_agent_structured_output():
     """Test that Agent can return a simple structured output."""

@@ -217,7 +217,7 @@ def test_lite_agent_structured_output():
         return result


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_lite_agent_returns_usage_metrics():
     """Test that LiteAgent returns usage metrics."""
     llm = LLM(model="gpt-4o-mini")
@@ -238,7 +238,7 @@ def test_lite_agent_returns_usage_metrics():
     assert result.usage_metrics["total_tokens"] > 0


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_lite_agent_output_includes_messages():
     """Test that LiteAgentOutput includes messages from agent execution."""
     llm = LLM(model="gpt-4o-mini")
@@ -259,7 +259,7 @@ def test_lite_agent_output_includes_messages():
     assert len(result.messages) > 0


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 @pytest.mark.asyncio
 async def test_lite_agent_returns_usage_metrics_async():
     """Test that LiteAgent returns usage metrics when run asynchronously."""
@@ -354,9 +354,9 @@ def test_sets_parent_flow_when_inside_flow():
     assert captured_agent.parent_flow is flow


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_guardrail_is_called_using_string():
-    guardrail_events = defaultdict(list)
+    guardrail_events: dict[str, list] = defaultdict(list)
     from crewai.events.event_types import (
         LLMGuardrailCompletedEvent,
         LLMGuardrailStartedEvent,
@@ -369,35 +369,33 @@ def test_guardrail_is_called_using_string():
         guardrail="""Only include Brazilian players, both women and men""",
     )

-    all_events_received = threading.Event()
+    condition = threading.Condition()

     @crewai_event_bus.on(LLMGuardrailStartedEvent)
     def capture_guardrail_started(source, event):
         assert isinstance(source, LiteAgent)
         assert source.original_agent == agent
-        guardrail_events["started"].append(event)
-        if (
-            len(guardrail_events["started"]) == 2
-            and len(guardrail_events["completed"]) == 2
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["started"].append(event)
+            condition.notify()

     @crewai_event_bus.on(LLMGuardrailCompletedEvent)
     def capture_guardrail_completed(source, event):
         assert isinstance(source, LiteAgent)
         assert source.original_agent == agent
-        guardrail_events["completed"].append(event)
-        if (
-            len(guardrail_events["started"]) == 2
-            and len(guardrail_events["completed"]) == 2
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["completed"].append(event)
+            condition.notify()

     result = agent.kickoff(messages="Top 10 best players in the world?")

-    assert all_events_received.wait(timeout=10), (
-        "Timeout waiting for all guardrail events"
-    )
+    with condition:
+        success = condition.wait_for(
+            lambda: len(guardrail_events["started"]) >= 2
+            and len(guardrail_events["completed"]) >= 2,
+            timeout=10,
+        )
+        assert success, "Timeout waiting for all guardrail events"
     assert len(guardrail_events["started"]) == 2
     assert len(guardrail_events["completed"]) == 2
     assert not guardrail_events["completed"][0].success
@@ -408,33 +406,27 @@ def test_guardrail_is_called_using_string():
     )


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_guardrail_is_called_using_callable():
-    guardrail_events = defaultdict(list)
+    guardrail_events: dict[str, list] = defaultdict(list)
     from crewai.events.event_types import (
         LLMGuardrailCompletedEvent,
         LLMGuardrailStartedEvent,
     )

-    all_events_received = threading.Event()
+    condition = threading.Condition()

     @crewai_event_bus.on(LLMGuardrailStartedEvent)
     def capture_guardrail_started(source, event):
-        guardrail_events["started"].append(event)
-        if (
-            len(guardrail_events["started"]) == 1
-            and len(guardrail_events["completed"]) == 1
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["started"].append(event)
+            condition.notify()

     @crewai_event_bus.on(LLMGuardrailCompletedEvent)
     def capture_guardrail_completed(source, event):
-        guardrail_events["completed"].append(event)
-        if (
-            len(guardrail_events["started"]) == 1
-            and len(guardrail_events["completed"]) == 1
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["completed"].append(event)
+            condition.notify()

     agent = Agent(
         role="Sports Analyst",
@@ -445,42 +437,40 @@ def test_guardrail_is_called_using_callable():

     result = agent.kickoff(messages="Top 1 best players in the world?")

-    assert all_events_received.wait(timeout=10), (
-        "Timeout waiting for all guardrail events"
-    )
+    with condition:
+        success = condition.wait_for(
+            lambda: len(guardrail_events["started"]) >= 1
+            and len(guardrail_events["completed"]) >= 1,
+            timeout=10,
+        )
+        assert success, "Timeout waiting for all guardrail events"
     assert len(guardrail_events["started"]) == 1
     assert len(guardrail_events["completed"]) == 1
     assert guardrail_events["completed"][0].success
     assert "Pelé - Santos, 1958" in result.raw


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_guardrail_reached_attempt_limit():
-    guardrail_events = defaultdict(list)
+    guardrail_events: dict[str, list] = defaultdict(list)
     from crewai.events.event_types import (
         LLMGuardrailCompletedEvent,
         LLMGuardrailStartedEvent,
     )

-    all_events_received = threading.Event()
+    condition = threading.Condition()

     @crewai_event_bus.on(LLMGuardrailStartedEvent)
     def capture_guardrail_started(source, event):
-        guardrail_events["started"].append(event)
-        if (
-            len(guardrail_events["started"]) == 3
-            and len(guardrail_events["completed"]) == 3
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["started"].append(event)
+            condition.notify()

     @crewai_event_bus.on(LLMGuardrailCompletedEvent)
     def capture_guardrail_completed(source, event):
-        guardrail_events["completed"].append(event)
-        if (
-            len(guardrail_events["started"]) == 3
-            and len(guardrail_events["completed"]) == 3
-        ):
-            all_events_received.set()
+        with condition:
+            guardrail_events["completed"].append(event)
+            condition.notify()

     agent = Agent(
         role="Sports Analyst",
@@ -498,9 +488,13 @@ def test_guardrail_reached_attempt_limit():
     ):
         agent.kickoff(messages="Top 10 best players in the world?")

-    assert all_events_received.wait(timeout=10), (
-        "Timeout waiting for all guardrail events"
-    )
+    with condition:
+        success = condition.wait_for(
+            lambda: len(guardrail_events["started"]) >= 3
+            and len(guardrail_events["completed"]) >= 3,
+            timeout=10,
+        )
+        assert success, "Timeout waiting for all guardrail events"
     assert len(guardrail_events["started"]) == 3  # 2 retries + 1 initial call
     assert len(guardrail_events["completed"]) == 3  # 2 retries + 1 initial call
     assert not guardrail_events["completed"][0].success
@@ -508,7 +502,7 @@ def test_guardrail_reached_attempt_limit():
     assert not guardrail_events["completed"][2].success


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_output_when_guardrail_returns_base_model():
     class Player(BaseModel):
         name: str
@@ -599,7 +593,7 @@ def test_lite_agent_with_custom_llm_and_guardrails():
     assert result2.raw == "Modified by guardrail"


-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_lite_agent_with_invalid_llm():
     """Test that LiteAgent raises proper error when create_llm returns None."""
     with patch("crewai.lite_agent.create_llm", return_value=None):
@@ -615,7 +609,7 @@ def test_lite_agent_with_invalid_llm():

 @patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
 @patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_kickoff_with_platform_tools(mock_get):
     """Test that Agent.kickoff() properly integrates platform tools with LiteAgent"""
     mock_response = Mock()
@@ -657,7 +651,7 @@ def test_agent_kickoff_with_platform_tools(mock_get):

 @patch.dict("os.environ", {"EXA_API_KEY": "test_exa_key"})
 @patch("crewai.agent.Agent._get_external_mcp_tools")
-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr()
 def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
     """Test that Agent.kickoff() properly integrates MCP tools with LiteAgent"""
     # Setup mock MCP tools - create a proper BaseTool instance
@@ -1,126 +0,0 @@
-interactions:
-- request:
-    body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
-      personal goal is: Test goal\nTo give my best complete final answer to the task
-      respond using the exact following format:\n\nThought: I now can give a great
-      answer\nFinal Answer: Your final answer must be the great and the most complete
-      as possible, it must be outcome described.\n\nI MUST use these formats, my job
-      depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
-      the world\n\nThis is the expected criteria for your final answer: hello world\nyou
-      MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
-      This is VERY important to you, use the tools available and give your best Final
-      Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
-      ["\nObservation:"]}'
-    headers:
-      accept:
-      - application/json
-      accept-encoding:
-      - gzip, deflate, zstd
-      connection:
-      - keep-alive
-      content-length:
-      - '825'
-      content-type:
-      - application/json
-      host:
-      - api.openai.com
-      user-agent:
-      - OpenAI/Python 1.93.0
-      x-stainless-arch:
-      - arm64
-      x-stainless-async:
-      - 'false'
-      x-stainless-lang:
-      - python
-      x-stainless-os:
-      - MacOS
-      x-stainless-package-version:
-      - 1.93.0
-      x-stainless-raw-response:
-      - 'true'
-      x-stainless-read-timeout:
-      - '600.0'
-      x-stainless-retry-count:
-      - '0'
-      x-stainless-runtime:
-      - CPython
-      x-stainless-runtime-version:
-      - 3.12.9
-    method: POST
-    uri: https://api.openai.com/v1/chat/completions
-  response:
-    body:
-      string: !!binary |
-        H4sIAAAAAAAAAwAAAP//jFJdi9swEHz3r1j0HBc7ZydXvx1HC0evhT6UUtrDKNLa1lXWqpJ8aTjy
-        34vsXOz0A/pi8M7OaGZ3nxMApiSrgImOB9Fbnd4Why/v79HeePeD37/Zfdx8evf5+rY4fNjdPbJV
-        ZNDuEUV4Yb0S1FuNQZGZYOGQB4yq+bYsr7J1nhcj0JNEHWmtDWlBaa+MStfZukizbZpfn9gdKYGe
-        VfA1AQB4Hr/Rp5H4k1WQrV4qPXrPW2TVuQmAOdKxwrj3ygduAlvNoCAT0IzW78DQHgQ30KonBA5t
-        tA3c+D06gG/mrTJcw834X0GHWhPsyWm5FHTYDJ7HUGbQegFwYyjwOJQxysMJOZ7Na2qto53/jcoa
-        ZZTvaofck4lGfSDLRvSYADyMQxoucjPrqLehDvQdx+fycjvpsXk3C/TqBAYKXC/q29NoL/VqiYEr
-        7RdjZoKLDuVMnXfCB6loASSL1H+6+Zv2lFyZ9n/kZ0AItAFlbR1KJS4Tz20O4+n+q+085dEw8+ie
-        lMA6KHRxExIbPujpoJg/+IB93SjTorNOTVfV2LrcZLzZYFm+Zskx+QUAAP//AwB1vYZ+YwMAAA==
-    headers:
-      CF-RAY:
-      - 96fc9f29dea3cf1f-SJC
-      Connection:
-      - keep-alive
-      Content-Encoding:
-      - gzip
-      Content-Type:
-      - application/json
-      Date:
-      - Fri, 15 Aug 2025 23:55:15 GMT
-      Server:
-      - cloudflare
-      Set-Cookie:
-      - __cf_bm=oA9oTa3cE0ZaEUDRf0hCpnarSAQKzrVUhl6qDS4j09w-1755302115-1.0.1.1-gUUDl4ZqvBQkg7244DTwOmSiDUT2z_AiQu0P1xUaABjaufSpZuIlI5G0H7OSnW.ldypvpxjj45NGWesJ62M_2U7r20tHz_gMmDFw6D5ZiNc;
-        path=/; expires=Sat, 16-Aug-25 00:25:15 GMT; domain=.api.openai.com; HttpOnly;
-        Secure; SameSite=None
-      - _cfuvid=ICenEGMmOE5jaOjwD30bAOwrF8.XRbSIKTBl1EyWs0o-1755302115700-0.0.1.1-604800000;
-        path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
-      Strict-Transport-Security:
-      - max-age=31536000; includeSubDomains; preload
-      Transfer-Encoding:
-      - chunked
-      X-Content-Type-Options:
-      - nosniff
-      access-control-expose-headers:
-      - X-Request-ID
-      alt-svc:
-      - h3=":443"; ma=86400
-      cf-cache-status:
-      - DYNAMIC
-      openai-organization:
-      - crewai-iuxna1
-      openai-processing-ms:
-      - '735'
-      openai-project:
-      - proj_xitITlrFeen7zjNSzML82h9x
-      openai-version:
-      - '2020-10-01'
-      x-envoy-upstream-service-time:
-      - '753'
-      x-ratelimit-limit-project-tokens:
-      - '150000000'
-      x-ratelimit-limit-requests:
-      - '30000'
-      x-ratelimit-limit-tokens:
-      - '150000000'
-      x-ratelimit-remaining-project-tokens:
-      - '149999830'
-      x-ratelimit-remaining-requests:
-      - '29999'
-      x-ratelimit-remaining-tokens:
-      - '149999827'
-      x-ratelimit-reset-project-tokens:
-      - 0s
-      x-ratelimit-reset-requests:
-      - 2ms
-      x-ratelimit-reset-tokens:
-      - 0s
-      x-request-id:
-      - req_212fde9d945a462ba0d89ea856131dce
-    status:
-      code: 200
-      message: OK
-version: 1