mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Compare commits
32 Commits
1.2.1
...
gl/chore/p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e75055c3e | ||
|
|
b9bc5714c3 | ||
|
|
996c53228b | ||
|
|
d0d22038f5 | ||
|
|
2ba0434eb8 | ||
|
|
cad3251721 | ||
|
|
19f4ef0982 | ||
|
|
54f8cc685f | ||
|
|
0f1c173d02 | ||
|
|
19c5b9a35e | ||
|
|
1ed307b58c | ||
|
|
d29867bbb6 | ||
|
|
b2c278ed22 | ||
|
|
f6aed9798b | ||
|
|
40a2d387a1 | ||
|
|
6f36d7003b | ||
|
|
9e5906c52f | ||
|
|
fc521839e4 | ||
|
|
e4cc9a664c | ||
|
|
7e6171d5bc | ||
|
|
61ad1fb112 | ||
|
|
54710a8711 | ||
|
|
5abf976373 | ||
|
|
329567153b | ||
|
|
60332e0b19 | ||
|
|
40932af3fa | ||
|
|
e134e5305b | ||
|
|
e229ef4e19 | ||
|
|
2e9eb8c32d | ||
|
|
4ebb5114ed | ||
|
|
70b083945f | ||
|
|
410db1ff39 |
18
.github/workflows/tests.yml
vendored
18
.github/workflows/tests.yml
vendored
@@ -5,18 +5,6 @@ on: [pull_request]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: fake-api-key
|
||||
PYTHONUNBUFFERED: 1
|
||||
BRAVE_API_KEY: fake-brave-key
|
||||
SNOWFLAKE_USER: fake-snowflake-user
|
||||
SNOWFLAKE_PASSWORD: fake-snowflake-password
|
||||
SNOWFLAKE_ACCOUNT: fake-snowflake-account
|
||||
SNOWFLAKE_WAREHOUSE: fake-snowflake-warehouse
|
||||
SNOWFLAKE_DATABASE: fake-snowflake-database
|
||||
SNOWFLAKE_SCHEMA: fake-snowflake-schema
|
||||
EMBEDCHAIN_DB_URI: sqlite:///test.db
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
name: tests (${{ matrix.python-version }})
|
||||
@@ -84,26 +72,20 @@ jobs:
|
||||
# fi
|
||||
|
||||
cd lib/crewai && uv run pytest \
|
||||
--block-network \
|
||||
--timeout=30 \
|
||||
-vv \
|
||||
--splits 8 \
|
||||
--group ${{ matrix.group }} \
|
||||
$DURATIONS_ARG \
|
||||
--durations=10 \
|
||||
-n auto \
|
||||
--maxfail=3
|
||||
|
||||
- name: Run tool tests (group ${{ matrix.group }} of 8)
|
||||
run: |
|
||||
cd lib/crewai-tools && uv run pytest \
|
||||
--block-network \
|
||||
--timeout=30 \
|
||||
-vv \
|
||||
--splits 8 \
|
||||
--group ${{ matrix.group }} \
|
||||
--durations=10 \
|
||||
-n auto \
|
||||
--maxfail=3
|
||||
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ repos:
|
||||
language: system
|
||||
pass_filenames: true
|
||||
types: [python]
|
||||
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/)
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
rev: 0.9.3
|
||||
hooks:
|
||||
|
||||
138
conftest.py
Normal file
138
conftest.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Pytest configuration for crewAI workspace."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Load .env.test for consistent test environment
|
||||
# env_test_path = Path(__file__).parent / ".env.test"
|
||||
# load_dotenv(env_test_path, override=True)
|
||||
# load_dotenv(override=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def setup_test_environment() -> None: # type: ignore
|
||||
"""Set up test environment with a temporary directory for SQLite storage."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create the directory with proper permissions
|
||||
storage_dir = Path(temp_dir) / "crewai_test_storage"
|
||||
storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Validate that the directory was created successfully
|
||||
if not storage_dir.exists() or not storage_dir.is_dir():
|
||||
raise RuntimeError(
|
||||
f"Failed to create test storage directory: {storage_dir}"
|
||||
)
|
||||
|
||||
# Verify directory permissions
|
||||
try:
|
||||
# Try to create a test file to verify write permissions
|
||||
test_file = storage_dir / ".permissions_test"
|
||||
test_file.touch()
|
||||
test_file.unlink()
|
||||
except (OSError, IOError) as e:
|
||||
raise RuntimeError(
|
||||
f"Test storage directory {storage_dir} is not writable: {e}"
|
||||
) from e
|
||||
|
||||
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
|
||||
os.environ["CREWAI_TESTING"] = "true"
|
||||
yield
|
||||
|
||||
os.environ.pop("CREWAI_TESTING", None)
|
||||
# Cleanup is handled automatically when tempfile context exits
|
||||
|
||||
|
||||
HEADERS_TO_FILTER = {
|
||||
"authorization": "AUTHORIZATION-XXX",
|
||||
"content-security-policy": "CSP-FILTERED",
|
||||
"cookie": "COOKIE-XXX",
|
||||
"set-cookie": "SET-COOKIE-XXX",
|
||||
"permissions-policy": "PERMISSIONS-POLICY-XXX",
|
||||
"referrer-policy": "REFERRER-POLICY-XXX",
|
||||
"strict-transport-security": "STS-XXX",
|
||||
"x-content-type-options": "X-CONTENT-TYPE-XXX",
|
||||
"x-frame-options": "X-FRAME-OPTIONS-XXX",
|
||||
"x-permitted-cross-domain-policies": "X-PERMITTED-XXX",
|
||||
"x-request-id": "X-REQUEST-ID-XXX",
|
||||
"x-runtime": "X-RUNTIME-XXX",
|
||||
"x-xss-protection": "X-XSS-PROTECTION-XXX",
|
||||
"x-stainless-arch": "X-STAINLESS-ARCH-XXX",
|
||||
"x-stainless-os": "X-STAINLESS-OS-XXX",
|
||||
"x-stainless-read-timeout": "X-STAINLESS-READ-TIMEOUT-XXX",
|
||||
"cf-ray": "CF-RAY-XXX",
|
||||
"etag": "ETAG-XXX",
|
||||
"Strict-Transport-Security": "STS-XXX",
|
||||
"access-control-expose-headers": "ACCESS-CONTROL-XXX",
|
||||
"openai-organization": "OPENAI-ORG-XXX",
|
||||
"openai-project": "OPENAI-PROJECT-XXX",
|
||||
"x-ratelimit-limit-requests": "X-RATELIMIT-LIMIT-REQUESTS-XXX",
|
||||
"x-ratelimit-limit-tokens": "X-RATELIMIT-LIMIT-TOKENS-XXX",
|
||||
"x-ratelimit-remaining-requests": "X-RATELIMIT-REMAINING-REQUESTS-XXX",
|
||||
"x-ratelimit-remaining-tokens": "X-RATELIMIT-REMAINING-TOKENS-XXX",
|
||||
"x-ratelimit-reset-requests": "X-RATELIMIT-RESET-REQUESTS-XXX",
|
||||
"x-ratelimit-reset-tokens": "X-RATELIMIT-RESET-TOKENS-XXX",
|
||||
}
|
||||
|
||||
|
||||
def _filter_response_headers(response) -> dict[str, Any]: # type: ignore
|
||||
"""Filter sensitive headers from response before recording."""
|
||||
for header_name, replacement in HEADERS_TO_FILTER.items():
|
||||
for variant in [header_name, header_name.upper(), header_name.title()]:
|
||||
if variant in response["headers"]:
|
||||
response["headers"][variant] = [replacement]
|
||||
return response # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_cassette_dir(request: Any) -> str:
|
||||
"""Generate cassette directory path based on test module location.
|
||||
|
||||
Organizes cassettes to mirror test directory structure within each package:
|
||||
lib/crewai/tests/llms/google/test_google.py -> lib/crewai/tests/cassettes/llms/google/
|
||||
lib/crewai-tools/tests/tools/test_search.py -> lib/crewai-tools/tests/cassettes/tools/
|
||||
"""
|
||||
test_file = Path(request.fspath)
|
||||
|
||||
# Find the package root (lib/crewai or lib/crewai-tools)
|
||||
for parent in test_file.parents:
|
||||
if parent.name in ("crewai", "crewai-tools") and parent.parent.name == "lib":
|
||||
package_root = parent
|
||||
break
|
||||
else:
|
||||
# Fallback to old behavior if we can't find package root
|
||||
package_root = test_file.parent
|
||||
|
||||
tests_root = package_root / "tests"
|
||||
test_dir = test_file.parent
|
||||
|
||||
# Get relative path from tests root
|
||||
if test_dir != tests_root:
|
||||
relative_path = test_dir.relative_to(tests_root)
|
||||
cassette_dir = tests_root / "cassettes" / relative_path
|
||||
else:
|
||||
cassette_dir = tests_root / "cassettes"
|
||||
|
||||
cassette_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return str(cassette_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config(vcr_cassette_dir: str) -> dict[str, Any]:
|
||||
"""Configure VCR with organized cassette storage."""
|
||||
config = {
|
||||
"cassette_library_dir": vcr_cassette_dir,
|
||||
"record_mode": os.getenv("PYTEST_VCR_RECORD_MODE") or "once",
|
||||
"filter_headers": [(k, v) for k, v in HEADERS_TO_FILTER.items()],
|
||||
"before_record_response": _filter_response_headers,
|
||||
}
|
||||
|
||||
if os.getenv("GITHUB_ACTIONS") == "true":
|
||||
config["record_mode"] = "none"
|
||||
|
||||
return config
|
||||
@@ -1200,6 +1200,52 @@ Learn how to get the most out of your LLM configuration:
|
||||
)
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Transport Interceptors">
|
||||
CrewAI provides message interceptors for several providers, allowing you to hook into request/response cycles at the transport layer.
|
||||
|
||||
**Supported Providers:**
|
||||
- ✅ OpenAI
|
||||
- ✅ Anthropic
|
||||
|
||||
**Basic Usage:**
|
||||
```python
|
||||
import httpx
|
||||
from crewai import LLM
|
||||
from crewai.llms.hooks import BaseInterceptor
|
||||
|
||||
class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
||||
"""Custom interceptor to modify requests and responses."""
|
||||
|
||||
def on_outbound(self, request: httpx.Request) -> httpx.Request:
|
||||
"""Print request before sending to the LLM provider."""
|
||||
print(request)
|
||||
return request
|
||||
|
||||
def on_inbound(self, response: httpx.Response) -> httpx.Response:
|
||||
"""Process response after receiving from the LLM provider."""
|
||||
print(f"Status: {response.status_code}")
|
||||
print(f"Response time: {response.elapsed}")
|
||||
return response
|
||||
|
||||
# Use the interceptor with an LLM
|
||||
llm = LLM(
|
||||
model="openai/gpt-4o",
|
||||
interceptor=CustomInterceptor()
|
||||
)
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- Both methods must return the received object or type of object.
|
||||
- Modifying received objects may result in unexpected behavior or application crashes.
|
||||
- Not all providers support interceptors - check the supported providers list above
|
||||
|
||||
<Info>
|
||||
Interceptors operate at the transport layer. This is particularly useful for:
|
||||
- Message transformation and filtering
|
||||
- Debugging API interactions
|
||||
</Info>
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
291
docs/en/learn/a2a-agent-delegation.mdx
Normal file
291
docs/en/learn/a2a-agent-delegation.mdx
Normal file
@@ -0,0 +1,291 @@
|
||||
---
|
||||
title: Agent-to-Agent (A2A) Protocol
|
||||
description: Enable CrewAI agents to delegate tasks to remote A2A-compliant agents for specialized handling
|
||||
icon: network-wired
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
## A2A Agent Delegation
|
||||
|
||||
CrewAI supports the Agent-to-Agent (A2A) protocol, allowing agents to delegate tasks to remote specialized agents. The agent's LLM automatically decides whether to handle a task directly or delegate to an A2A agent based on the task requirements.
|
||||
|
||||
<Note>
|
||||
A2A delegation requires the `a2a-sdk` package. Install with: `uv add 'crewai[a2a]'` or `pip install 'crewai[a2a]'`
|
||||
</Note>
|
||||
|
||||
## How It Works
|
||||
|
||||
When an agent is configured with A2A capabilities:
|
||||
|
||||
1. The LLM analyzes each task
|
||||
2. It decides to either:
|
||||
- Handle the task directly using its own capabilities
|
||||
- Delegate to a remote A2A agent for specialized handling
|
||||
3. If delegating, the agent communicates with the remote A2A agent through the protocol
|
||||
4. Results are returned to the CrewAI workflow
|
||||
|
||||
## Basic Configuration
|
||||
|
||||
Configure an agent for A2A delegation by setting the `a2a` parameter:
|
||||
|
||||
```python Code
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.a2a import A2AConfig
|
||||
|
||||
agent = Agent(
|
||||
role="Research Coordinator",
|
||||
goal="Coordinate research tasks efficiently",
|
||||
backstory="Expert at delegating to specialized research agents",
|
||||
llm="gpt-4o",
|
||||
a2a=A2AConfig(
|
||||
endpoint="https://example.com/.well-known/agent-card.json",
|
||||
timeout=120,
|
||||
max_turns=10
|
||||
)
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Research the latest developments in quantum computing",
|
||||
expected_output="A comprehensive research report",
|
||||
agent=agent
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task], verbose=True)
|
||||
result = crew.kickoff()
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
The `A2AConfig` class accepts the following parameters:
|
||||
|
||||
<ParamField path="endpoint" type="str" required>
|
||||
The A2A agent endpoint URL (typically points to `.well-known/agent-card.json`)
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="auth" type="AuthScheme" default="None">
|
||||
Authentication scheme for the A2A agent. Supports Bearer tokens, OAuth2, API keys, and HTTP authentication.
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="timeout" type="int" default="120">
|
||||
Request timeout in seconds
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="max_turns" type="int" default="10">
|
||||
Maximum number of conversation turns with the A2A agent
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="response_model" type="type[BaseModel]" default="None">
|
||||
Optional Pydantic model for requesting structured output from an A2A agent. A2A protocol does not
|
||||
enforce this, so an A2A agent does not need to honor this request.
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="fail_fast" type="bool" default="True">
|
||||
Whether to raise an error immediately if agent connection fails. When `False`, the agent continues with available agents and informs the LLM about unavailable ones.
|
||||
</ParamField>
|
||||
|
||||
## Authentication
|
||||
|
||||
For A2A agents that require authentication, use one of the provided auth schemes:
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Bearer Token">
|
||||
```python Code
|
||||
from crewai.a2a import A2AConfig
|
||||
from crewai.a2a.auth import BearerTokenAuth
|
||||
|
||||
agent = Agent(
|
||||
role="Secure Coordinator",
|
||||
goal="Coordinate tasks with secured agents",
|
||||
backstory="Manages secure agent communications",
|
||||
llm="gpt-4o",
|
||||
a2a=A2AConfig(
|
||||
endpoint="https://secure-agent.example.com/.well-known/agent-card.json",
|
||||
auth=BearerTokenAuth(token="your-bearer-token"),
|
||||
timeout=120
|
||||
)
|
||||
)
|
||||
```
|
||||
</Tab>
|
||||
|
||||
<Tab title="API Key">
|
||||
```python Code
|
||||
from crewai.a2a import A2AConfig
|
||||
from crewai.a2a.auth import APIKeyAuth
|
||||
|
||||
agent = Agent(
|
||||
role="API Coordinator",
|
||||
goal="Coordinate with API-based agents",
|
||||
backstory="Manages API-authenticated communications",
|
||||
llm="gpt-4o",
|
||||
a2a=A2AConfig(
|
||||
endpoint="https://api-agent.example.com/.well-known/agent-card.json",
|
||||
auth=APIKeyAuth(
|
||||
api_key="your-api-key",
|
||||
location="header", # or "query" or "cookie"
|
||||
name="X-API-Key"
|
||||
),
|
||||
timeout=120
|
||||
)
|
||||
)
|
||||
```
|
||||
</Tab>
|
||||
|
||||
<Tab title="OAuth2">
|
||||
```python Code
|
||||
from crewai.a2a import A2AConfig
|
||||
from crewai.a2a.auth import OAuth2ClientCredentials
|
||||
|
||||
agent = Agent(
|
||||
role="OAuth Coordinator",
|
||||
goal="Coordinate with OAuth-secured agents",
|
||||
backstory="Manages OAuth-authenticated communications",
|
||||
llm="gpt-4o",
|
||||
a2a=A2AConfig(
|
||||
endpoint="https://oauth-agent.example.com/.well-known/agent-card.json",
|
||||
auth=OAuth2ClientCredentials(
|
||||
token_url="https://auth.example.com/oauth/token",
|
||||
client_id="your-client-id",
|
||||
client_secret="your-client-secret",
|
||||
scopes=["read", "write"]
|
||||
),
|
||||
timeout=120
|
||||
)
|
||||
)
|
||||
```
|
||||
</Tab>
|
||||
|
||||
<Tab title="HTTP Basic">
|
||||
```python Code
|
||||
from crewai.a2a import A2AConfig
|
||||
from crewai.a2a.auth import HTTPBasicAuth
|
||||
|
||||
agent = Agent(
|
||||
role="Basic Auth Coordinator",
|
||||
goal="Coordinate with basic auth agents",
|
||||
backstory="Manages basic authentication communications",
|
||||
llm="gpt-4o",
|
||||
a2a=A2AConfig(
|
||||
endpoint="https://basic-agent.example.com/.well-known/agent-card.json",
|
||||
auth=HTTPBasicAuth(
|
||||
username="your-username",
|
||||
password="your-password"
|
||||
),
|
||||
timeout=120
|
||||
)
|
||||
)
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
## Multiple A2A Agents
|
||||
|
||||
Configure multiple A2A agents for delegation by passing a list:
|
||||
|
||||
```python Code
|
||||
from crewai.a2a import A2AConfig
|
||||
from crewai.a2a.auth import BearerTokenAuth
|
||||
|
||||
agent = Agent(
|
||||
role="Multi-Agent Coordinator",
|
||||
goal="Coordinate with multiple specialized agents",
|
||||
backstory="Expert at delegating to the right specialist",
|
||||
llm="gpt-4o",
|
||||
a2a=[
|
||||
A2AConfig(
|
||||
endpoint="https://research.example.com/.well-known/agent-card.json",
|
||||
timeout=120
|
||||
),
|
||||
A2AConfig(
|
||||
endpoint="https://data.example.com/.well-known/agent-card.json",
|
||||
auth=BearerTokenAuth(token="data-token"),
|
||||
timeout=90
|
||||
)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
The LLM will automatically choose which A2A agent to delegate to based on the task requirements.
|
||||
|
||||
## Error Handling
|
||||
|
||||
Control how agent connection failures are handled using the `fail_fast` parameter:
|
||||
|
||||
```python Code
|
||||
from crewai.a2a import A2AConfig
|
||||
|
||||
# Fail immediately on connection errors (default)
|
||||
agent = Agent(
|
||||
role="Research Coordinator",
|
||||
goal="Coordinate research tasks",
|
||||
backstory="Expert at delegation",
|
||||
llm="gpt-4o",
|
||||
a2a=A2AConfig(
|
||||
endpoint="https://research.example.com/.well-known/agent-card.json",
|
||||
fail_fast=True
|
||||
)
|
||||
)
|
||||
|
||||
# Continue with available agents
|
||||
agent = Agent(
|
||||
role="Multi-Agent Coordinator",
|
||||
goal="Coordinate with multiple agents",
|
||||
backstory="Expert at working with available resources",
|
||||
llm="gpt-4o",
|
||||
a2a=[
|
||||
A2AConfig(
|
||||
endpoint="https://primary.example.com/.well-known/agent-card.json",
|
||||
fail_fast=False
|
||||
),
|
||||
A2AConfig(
|
||||
endpoint="https://backup.example.com/.well-known/agent-card.json",
|
||||
fail_fast=False
|
||||
)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
When `fail_fast=False`:
|
||||
- If some agents fail, the LLM is informed which agents are unavailable and can delegate to working agents
|
||||
- If all agents fail, the LLM receives a notice about unavailable agents and handles the task directly
|
||||
- Connection errors are captured and included in the context for better decision-making
|
||||
|
||||
## Best Practices
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Set Appropriate Timeouts" icon="clock">
|
||||
Configure timeouts based on expected A2A agent response times. Longer-running tasks may need higher timeout values.
|
||||
</Card>
|
||||
|
||||
<Card title="Limit Conversation Turns" icon="comments">
|
||||
Use `max_turns` to prevent excessive back-and-forth. The agent will automatically conclude conversations before hitting the limit.
|
||||
</Card>
|
||||
|
||||
<Card title="Use Resilient Error Handling" icon="shield-check">
|
||||
Set `fail_fast=False` for production environments with multiple agents to gracefully handle connection failures and maintain workflow continuity.
|
||||
</Card>
|
||||
|
||||
<Card title="Secure Your Credentials" icon="lock">
|
||||
Store authentication tokens and credentials as environment variables, not in code.
|
||||
</Card>
|
||||
|
||||
<Card title="Monitor Delegation Decisions" icon="eye">
|
||||
Use verbose mode to observe when the LLM chooses to delegate versus handle tasks directly.
|
||||
</Card>
|
||||
</CardGroup>
|
||||
|
||||
## Supported Authentication Methods
|
||||
|
||||
- **Bearer Token** - Simple token-based authentication
|
||||
- **OAuth2 Client Credentials** - OAuth2 flow for machine-to-machine communication
|
||||
- **OAuth2 Authorization Code** - OAuth2 flow requiring user authorization
|
||||
- **API Key** - Key-based authentication (header, query param, or cookie)
|
||||
- **HTTP Basic** - Username/password authentication
|
||||
- **HTTP Digest** - Digest authentication (requires `httpx-auth` package)
|
||||
|
||||
## Learn More
|
||||
|
||||
For more information about the A2A protocol and reference implementations:
|
||||
|
||||
- [A2A Protocol Documentation](https://a2a-protocol.org)
|
||||
- [A2A Sample Implementations](https://github.com/a2aproject/a2a-samples)
|
||||
- [A2A Python SDK](https://github.com/a2aproject/a2a-python)
|
||||
@@ -11,9 +11,13 @@ The [Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP)
|
||||
|
||||
CrewAI offers **two approaches** for MCP integration:
|
||||
|
||||
### Simple DSL Integration** (Recommended)
|
||||
### 🚀 **Simple DSL Integration** (Recommended)
|
||||
|
||||
Use the `mcps` field directly on agents for seamless MCP tool integration:
|
||||
Use the `mcps` field directly on agents for seamless MCP tool integration. The DSL supports both **string references** (for quick setup) and **structured configurations** (for full control).
|
||||
|
||||
#### String-Based References (Quick Setup)
|
||||
|
||||
Perfect for remote HTTPS servers and CrewAI AMP marketplace:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
@@ -32,6 +36,46 @@ agent = Agent(
|
||||
# MCP tools are now automatically available to your agent!
|
||||
```
|
||||
|
||||
#### Structured Configurations (Full Control)
|
||||
|
||||
For complete control over connection settings, tool filtering, and all transport types:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE
|
||||
from crewai.mcp.filters import create_static_tool_filter
|
||||
|
||||
agent = Agent(
|
||||
role="Advanced Research Analyst",
|
||||
goal="Research with full control over MCP connections",
|
||||
backstory="Expert researcher with advanced tool access",
|
||||
mcps=[
|
||||
# Stdio transport for local servers
|
||||
MCPServerStdio(
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||
env={"API_KEY": "your_key"},
|
||||
tool_filter=create_static_tool_filter(
|
||||
allowed_tool_names=["read_file", "list_directory"]
|
||||
),
|
||||
cache_tools_list=True,
|
||||
),
|
||||
# HTTP/Streamable HTTP transport for remote servers
|
||||
MCPServerHTTP(
|
||||
url="https://api.example.com/mcp",
|
||||
headers={"Authorization": "Bearer your_token"},
|
||||
streamable=True,
|
||||
cache_tools_list=True,
|
||||
),
|
||||
# SSE transport for real-time streaming
|
||||
MCPServerSSE(
|
||||
url="https://stream.example.com/mcp/sse",
|
||||
headers={"Authorization": "Bearer your_token"},
|
||||
),
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### 🔧 **Advanced: MCPServerAdapter** (For Complex Scenarios)
|
||||
|
||||
For advanced use cases requiring manual connection management, the `crewai-tools` library provides the `MCPServerAdapter` class.
|
||||
@@ -68,12 +112,14 @@ uv pip install 'crewai-tools[mcp]'
|
||||
|
||||
## Quick Start: Simple DSL Integration
|
||||
|
||||
The easiest way to integrate MCP servers is using the `mcps` field on your agents:
|
||||
The easiest way to integrate MCP servers is using the `mcps` field on your agents. You can use either string references or structured configurations.
|
||||
|
||||
### Quick Start with String References
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew
|
||||
|
||||
# Create agent with MCP tools
|
||||
# Create agent with MCP tools using string references
|
||||
research_agent = Agent(
|
||||
role="Research Analyst",
|
||||
goal="Find and analyze information using advanced search tools",
|
||||
@@ -96,13 +142,53 @@ crew = Crew(agents=[research_agent], tasks=[research_task])
|
||||
result = crew.kickoff()
|
||||
```
|
||||
|
||||
### Quick Start with Structured Configurations
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew
|
||||
from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE
|
||||
|
||||
# Create agent with structured MCP configurations
|
||||
research_agent = Agent(
|
||||
role="Research Analyst",
|
||||
goal="Find and analyze information using advanced search tools",
|
||||
backstory="Expert researcher with access to multiple data sources",
|
||||
mcps=[
|
||||
# Local stdio server
|
||||
MCPServerStdio(
|
||||
command="python",
|
||||
args=["local_server.py"],
|
||||
env={"API_KEY": "your_key"},
|
||||
),
|
||||
# Remote HTTP server
|
||||
MCPServerHTTP(
|
||||
url="https://api.research.com/mcp",
|
||||
headers={"Authorization": "Bearer your_token"},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Create task
|
||||
research_task = Task(
|
||||
description="Research the latest developments in AI agent frameworks",
|
||||
expected_output="Comprehensive research report with citations",
|
||||
agent=research_agent
|
||||
)
|
||||
|
||||
# Create and run crew
|
||||
crew = Crew(agents=[research_agent], tasks=[research_task])
|
||||
result = crew.kickoff()
|
||||
```
|
||||
|
||||
That's it! The MCP tools are automatically discovered and available to your agent.
|
||||
|
||||
## MCP Reference Formats
|
||||
|
||||
The `mcps` field supports various reference formats for maximum flexibility:
|
||||
The `mcps` field supports both **string references** (for quick setup) and **structured configurations** (for full control). You can mix both formats in the same list.
|
||||
|
||||
### External MCP Servers
|
||||
### String-Based References
|
||||
|
||||
#### External MCP Servers
|
||||
|
||||
```python
|
||||
mcps=[
|
||||
@@ -117,7 +203,7 @@ mcps=[
|
||||
]
|
||||
```
|
||||
|
||||
### CrewAI AMP Marketplace
|
||||
#### CrewAI AMP Marketplace
|
||||
|
||||
```python
|
||||
mcps=[
|
||||
@@ -133,17 +219,166 @@ mcps=[
|
||||
]
|
||||
```
|
||||
|
||||
### Mixed References
|
||||
### Structured Configurations
|
||||
|
||||
#### Stdio Transport (Local Servers)
|
||||
|
||||
Perfect for local MCP servers that run as processes:
|
||||
|
||||
```python
|
||||
from crewai.mcp import MCPServerStdio
|
||||
from crewai.mcp.filters import create_static_tool_filter
|
||||
|
||||
mcps=[
|
||||
"https://external-api.com/mcp", # External server
|
||||
"https://weather.service.com/mcp#forecast", # Specific external tool
|
||||
"crewai-amp:financial-insights", # AMP service
|
||||
"crewai-amp:data-analysis#sentiment_tool" # Specific AMP tool
|
||||
MCPServerStdio(
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||
env={"API_KEY": "your_key"},
|
||||
tool_filter=create_static_tool_filter(
|
||||
allowed_tool_names=["read_file", "write_file"]
|
||||
),
|
||||
cache_tools_list=True,
|
||||
),
|
||||
# Python-based server
|
||||
MCPServerStdio(
|
||||
command="python",
|
||||
args=["path/to/server.py"],
|
||||
env={"UV_PYTHON": "3.12", "API_KEY": "your_key"},
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
#### HTTP/Streamable HTTP Transport (Remote Servers)
|
||||
|
||||
For remote MCP servers over HTTP/HTTPS:
|
||||
|
||||
```python
|
||||
from crewai.mcp import MCPServerHTTP
|
||||
|
||||
mcps=[
|
||||
# Streamable HTTP (default)
|
||||
MCPServerHTTP(
|
||||
url="https://api.example.com/mcp",
|
||||
headers={"Authorization": "Bearer your_token"},
|
||||
streamable=True,
|
||||
cache_tools_list=True,
|
||||
),
|
||||
# Standard HTTP
|
||||
MCPServerHTTP(
|
||||
url="https://api.example.com/mcp",
|
||||
headers={"Authorization": "Bearer your_token"},
|
||||
streamable=False,
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
#### SSE Transport (Real-Time Streaming)
|
||||
|
||||
For remote servers using Server-Sent Events:
|
||||
|
||||
```python
|
||||
from crewai.mcp import MCPServerSSE
|
||||
|
||||
mcps=[
|
||||
MCPServerSSE(
|
||||
url="https://stream.example.com/mcp/sse",
|
||||
headers={"Authorization": "Bearer your_token"},
|
||||
cache_tools_list=True,
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
### Mixed References
|
||||
|
||||
You can combine string references and structured configurations:
|
||||
|
||||
```python
|
||||
from crewai.mcp import MCPServerStdio, MCPServerHTTP
|
||||
|
||||
mcps=[
|
||||
# String references
|
||||
"https://external-api.com/mcp", # External server
|
||||
"crewai-amp:financial-insights", # AMP service
|
||||
|
||||
# Structured configurations
|
||||
MCPServerStdio(
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||
),
|
||||
MCPServerHTTP(
|
||||
url="https://api.example.com/mcp",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
### Tool Filtering
|
||||
|
||||
Structured configurations support advanced tool filtering:
|
||||
|
||||
```python
|
||||
from crewai.mcp import MCPServerStdio
|
||||
from crewai.mcp.filters import create_static_tool_filter, create_dynamic_tool_filter, ToolFilterContext
|
||||
|
||||
# Static filtering (allow/block lists)
|
||||
static_filter = create_static_tool_filter(
|
||||
allowed_tool_names=["read_file", "write_file"],
|
||||
blocked_tool_names=["delete_file"],
|
||||
)
|
||||
|
||||
# Dynamic filtering (context-aware)
|
||||
def dynamic_filter(context: ToolFilterContext, tool: dict) -> bool:
|
||||
# Block dangerous tools for certain agent roles
|
||||
if context.agent.role == "Code Reviewer":
|
||||
if "delete" in tool.get("name", "").lower():
|
||||
return False
|
||||
return True
|
||||
|
||||
mcps=[
|
||||
MCPServerStdio(
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||
tool_filter=static_filter, # or dynamic_filter
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
## Configuration Parameters
|
||||
|
||||
Each transport type supports specific configuration options:
|
||||
|
||||
### MCPServerStdio Parameters
|
||||
|
||||
- **`command`** (required): Command to execute (e.g., `"python"`, `"node"`, `"npx"`, `"uvx"`)
|
||||
- **`args`** (optional): List of command arguments (e.g., `["server.py"]` or `["-y", "@mcp/server"]`)
|
||||
- **`env`** (optional): Dictionary of environment variables to pass to the process
|
||||
- **`tool_filter`** (optional): Tool filter function for filtering available tools
|
||||
- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`)
|
||||
|
||||
### MCPServerHTTP Parameters
|
||||
|
||||
- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp"`)
|
||||
- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes
|
||||
- **`streamable`** (optional): Whether to use streamable HTTP transport (default: `True`)
|
||||
- **`tool_filter`** (optional): Tool filter function for filtering available tools
|
||||
- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`)
|
||||
|
||||
### MCPServerSSE Parameters
|
||||
|
||||
- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp/sse"`)
|
||||
- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes
|
||||
- **`tool_filter`** (optional): Tool filter function for filtering available tools
|
||||
- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`)
|
||||
|
||||
### Common Parameters
|
||||
|
||||
All transport types support:
|
||||
- **`tool_filter`**: Filter function to control which tools are available. Can be:
|
||||
- `None` (default): All tools are available
|
||||
- Static filter: Created with `create_static_tool_filter()` for allow/block lists
|
||||
- Dynamic filter: Created with `create_dynamic_tool_filter()` for context-aware filtering
|
||||
- **`cache_tools_list`**: When `True`, caches the tool list after first discovery to improve performance on subsequent connections
|
||||
|
||||
## Key Features
|
||||
|
||||
- 🔄 **Automatic Tool Discovery**: Tools are automatically discovered and integrated
|
||||
@@ -152,26 +387,47 @@ mcps=[
|
||||
- 🛡️ **Error Resilience**: Graceful handling of unavailable servers
|
||||
- ⏱️ **Timeout Protection**: Built-in timeouts prevent hanging connections
|
||||
- 📊 **Transparent Integration**: Works seamlessly with existing CrewAI features
|
||||
- 🔧 **Full Transport Support**: Stdio, HTTP/Streamable HTTP, and SSE transports
|
||||
- 🎯 **Advanced Filtering**: Static and dynamic tool filtering capabilities
|
||||
- 🔐 **Flexible Authentication**: Support for headers, environment variables, and query parameters
|
||||
|
||||
## Error Handling
|
||||
|
||||
The MCP DSL integration is designed to be resilient:
|
||||
The MCP DSL integration is designed to be resilient and handles failures gracefully:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai.mcp import MCPServerStdio, MCPServerHTTP
|
||||
|
||||
agent = Agent(
|
||||
role="Resilient Agent",
|
||||
goal="Continue working despite server issues",
|
||||
backstory="Agent that handles failures gracefully",
|
||||
mcps=[
|
||||
# String references
|
||||
"https://reliable-server.com/mcp", # Will work
|
||||
"https://unreachable-server.com/mcp", # Will be skipped gracefully
|
||||
"https://slow-server.com/mcp", # Will timeout gracefully
|
||||
"crewai-amp:working-service" # Will work
|
||||
"crewai-amp:working-service", # Will work
|
||||
|
||||
# Structured configs
|
||||
MCPServerStdio(
|
||||
command="python",
|
||||
args=["reliable_server.py"], # Will work
|
||||
),
|
||||
MCPServerHTTP(
|
||||
url="https://slow-server.com/mcp", # Will timeout gracefully
|
||||
),
|
||||
]
|
||||
)
|
||||
# Agent will use tools from working servers and log warnings for failing ones
|
||||
```
|
||||
|
||||
All connection errors are handled gracefully:
|
||||
- **Connection failures**: Logged as warnings, agent continues with available tools
|
||||
- **Timeout errors**: Connections timeout after 30 seconds (configurable)
|
||||
- **Authentication errors**: Logged clearly for debugging
|
||||
- **Invalid configurations**: Validation errors are raised at agent creation time
|
||||
|
||||
## Advanced: MCPServerAdapter
|
||||
|
||||
For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server.
|
||||
|
||||
@@ -93,11 +93,15 @@ After running the application, you can view the traces in [Datadog LLM Observabi
|
||||
|
||||
Clicking on a trace will show you the details of the trace, including total tokens used, number of LLM calls, models used, and estimated cost. Clicking into a specific span will narrow down these details, and show related input, output, and metadata.
|
||||
|
||||

|
||||
<Frame>
|
||||
<img src="/images/datadog-llm-observability-1.png" alt="Datadog LLM Observability Trace View" />
|
||||
</Frame>
|
||||
|
||||
Additionally, you can view the execution graph view of the trace, which shows the control and data flow of the trace, which will scale with larger agents to show handoffs and relationships between LLM calls, tool calls, and agent interactions.
|
||||
|
||||

|
||||
<Frame>
|
||||
<img src="/images/datadog-llm-observability-2.png" alt="Datadog LLM Observability Agent Execution Flow View" />
|
||||
</Frame>
|
||||
|
||||
## References
|
||||
|
||||
|
||||
@@ -23,13 +23,15 @@ Here's a minimal example of how to use the tool:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai_tools import QdrantVectorSearchTool
|
||||
from crewai_tools import QdrantVectorSearchTool, QdrantConfig
|
||||
|
||||
# Initialize the tool
|
||||
# Initialize the tool with QdrantConfig
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_url="your_qdrant_url",
|
||||
qdrant_api_key="your_qdrant_api_key",
|
||||
collection_name="your_collection"
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url="your_qdrant_url",
|
||||
qdrant_api_key="your_qdrant_api_key",
|
||||
collection_name="your_collection"
|
||||
)
|
||||
)
|
||||
|
||||
# Create an agent that uses the tool
|
||||
@@ -82,7 +84,7 @@ def extract_text_from_pdf(pdf_path):
|
||||
def get_openai_embedding(text):
|
||||
response = client.embeddings.create(
|
||||
input=text,
|
||||
model="text-embedding-3-small"
|
||||
model="text-embedding-3-large"
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@@ -90,13 +92,13 @@ def get_openai_embedding(text):
|
||||
def load_pdf_to_qdrant(pdf_path, qdrant, collection_name):
|
||||
# Extract text from PDF
|
||||
text_chunks = extract_text_from_pdf(pdf_path)
|
||||
|
||||
|
||||
# Create Qdrant collection
|
||||
if qdrant.collection_exists(collection_name):
|
||||
qdrant.delete_collection(collection_name)
|
||||
qdrant.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(size=1536, distance=Distance.COSINE)
|
||||
vectors_config=VectorParams(size=3072, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
# Store embeddings
|
||||
@@ -120,19 +122,23 @@ pdf_path = "path/to/your/document.pdf"
|
||||
load_pdf_to_qdrant(pdf_path, qdrant, collection_name)
|
||||
|
||||
# Initialize Qdrant search tool
|
||||
from crewai_tools import QdrantConfig
|
||||
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_url=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
collection_name=collection_name,
|
||||
limit=3,
|
||||
score_threshold=0.35
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
collection_name=collection_name,
|
||||
limit=3,
|
||||
score_threshold=0.35
|
||||
)
|
||||
)
|
||||
|
||||
# Create CrewAI agents
|
||||
search_agent = Agent(
|
||||
role="Senior Semantic Search Agent",
|
||||
goal="Find and analyze documents based on semantic search",
|
||||
backstory="""You are an expert research assistant who can find relevant
|
||||
backstory="""You are an expert research assistant who can find relevant
|
||||
information using semantic search in a Qdrant database.""",
|
||||
tools=[qdrant_tool],
|
||||
verbose=True
|
||||
@@ -141,7 +147,7 @@ search_agent = Agent(
|
||||
answer_agent = Agent(
|
||||
role="Senior Answer Assistant",
|
||||
goal="Generate answers to questions based on the context provided",
|
||||
backstory="""You are an expert answer assistant who can generate
|
||||
backstory="""You are an expert answer assistant who can generate
|
||||
answers to questions based on the context provided.""",
|
||||
tools=[qdrant_tool],
|
||||
verbose=True
|
||||
@@ -180,21 +186,82 @@ print(result)
|
||||
## Tool Parameters
|
||||
|
||||
### Required Parameters
|
||||
- `qdrant_url` (str): The URL of your Qdrant server
|
||||
- `qdrant_api_key` (str): API key for authentication with Qdrant
|
||||
- `collection_name` (str): Name of the Qdrant collection to search
|
||||
- `qdrant_config` (QdrantConfig): Configuration object containing all Qdrant settings
|
||||
|
||||
### Optional Parameters
|
||||
### QdrantConfig Parameters
|
||||
- `qdrant_url` (str): The URL of your Qdrant server
|
||||
- `qdrant_api_key` (str, optional): API key for authentication with Qdrant
|
||||
- `collection_name` (str): Name of the Qdrant collection to search
|
||||
- `limit` (int): Maximum number of results to return (default: 3)
|
||||
- `score_threshold` (float): Minimum similarity score threshold (default: 0.35)
|
||||
- `filter` (Any, optional): Qdrant Filter instance for advanced filtering (default: None)
|
||||
|
||||
### Optional Tool Parameters
|
||||
- `custom_embedding_fn` (Callable[[str], list[float]]): Custom function for text vectorization
|
||||
- `qdrant_package` (str): Base package path for Qdrant (default: "qdrant_client")
|
||||
- `client` (Any): Pre-initialized Qdrant client (optional)
|
||||
|
||||
## Advanced Filtering
|
||||
|
||||
The QdrantVectorSearchTool supports powerful filtering capabilities to refine your search results:
|
||||
|
||||
### Dynamic Filtering
|
||||
Use `filter_by` and `filter_value` parameters in your search to filter results on-the-fly:
|
||||
|
||||
```python
|
||||
# Agent will use these parameters when calling the tool
|
||||
# The tool schema accepts filter_by and filter_value
|
||||
# Example: search with category filter
|
||||
# Results will be filtered where category == "technology"
|
||||
```
|
||||
|
||||
### Preset Filters with QdrantConfig
|
||||
For complex filtering, use Qdrant Filter instances in your configuration:
|
||||
|
||||
```python
|
||||
from qdrant_client.http import models as qmodels
|
||||
from crewai_tools import QdrantVectorSearchTool, QdrantConfig
|
||||
|
||||
# Create a filter for specific conditions
|
||||
preset_filter = qmodels.Filter(
|
||||
must=[
|
||||
qmodels.FieldCondition(
|
||||
key="category",
|
||||
match=qmodels.MatchValue(value="research")
|
||||
),
|
||||
qmodels.FieldCondition(
|
||||
key="year",
|
||||
match=qmodels.MatchValue(value=2024)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Initialize tool with preset filter
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection",
|
||||
filter=preset_filter # Preset filter applied to all searches
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Combining Filters
|
||||
The tool automatically combines preset filters from `QdrantConfig` with dynamic filters from `filter_by` and `filter_value`:
|
||||
|
||||
```python
|
||||
# If QdrantConfig has a preset filter for category="research"
|
||||
# And the search uses filter_by="year", filter_value=2024
|
||||
# Both filters will be combined (AND logic)
|
||||
```
|
||||
|
||||
## Search Parameters
|
||||
|
||||
The tool accepts these parameters in its schema:
|
||||
- `query` (str): The search query to find similar documents
|
||||
- `filter_by` (str, optional): Metadata field to filter on
|
||||
- `filter_value` (str, optional): Value to filter by
|
||||
- `filter_value` (Any, optional): Value to filter by
|
||||
|
||||
## Return Format
|
||||
|
||||
@@ -214,7 +281,7 @@ The tool returns results in JSON format:
|
||||
|
||||
## Default Embedding
|
||||
|
||||
By default, the tool uses OpenAI's `text-embedding-3-small` model for vectorization. This requires:
|
||||
By default, the tool uses OpenAI's `text-embedding-3-large` model for vectorization. This requires:
|
||||
- OpenAI API key set in environment: `OPENAI_API_KEY`
|
||||
|
||||
## Custom Embeddings
|
||||
@@ -240,18 +307,22 @@ def custom_embeddings(text: str) -> list[float]:
|
||||
# Tokenize and get model outputs
|
||||
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
||||
outputs = model(**inputs)
|
||||
|
||||
|
||||
# Use mean pooling to get text embedding
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
|
||||
# Convert to list of floats and return
|
||||
return embeddings[0].tolist()
|
||||
|
||||
# Use custom embeddings with the tool
|
||||
from crewai_tools import QdrantConfig
|
||||
|
||||
tool = QdrantVectorSearchTool(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection",
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection"
|
||||
),
|
||||
custom_embedding_fn=custom_embeddings # Pass your custom function
|
||||
)
|
||||
```
|
||||
@@ -269,4 +340,4 @@ Required environment variables:
|
||||
```bash
|
||||
export QDRANT_URL="your_qdrant_url" # If not provided in constructor
|
||||
export QDRANT_API_KEY="your_api_key" # If not provided in constructor
|
||||
export OPENAI_API_KEY="your_openai_key" # If using default embeddings
|
||||
export OPENAI_API_KEY="your_openai_key" # If using default embeddings
|
||||
|
||||
@@ -54,25 +54,25 @@ The following parameters can be used to customize the `CSVSearchTool`'s behavior
|
||||
By default, the tool uses OpenAI for both embeddings and summarization. To customize the model, you can use a config dictionary as follows:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = CSVSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # or "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -46,23 +46,25 @@ tool = DirectorySearchTool(directory='/path/to/directory')
|
||||
The DirectorySearchTool uses OpenAI for embeddings and summarization by default. Customization options for these settings include changing the model provider and configuration, enhancing flexibility for advanced users.
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = DirectorySearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # Options include ollama, google, anthropic, llama2, and more
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# Additional configurations here
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # or "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -56,25 +56,25 @@ The following parameters can be used to customize the `DOCXSearchTool`'s behavio
|
||||
By default, the tool uses OpenAI for both embeddings and summarization. To customize the model, you can use a config dictionary as follows:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = DOCXSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # or "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -48,27 +48,25 @@ tool = MDXSearchTool(mdx='path/to/your/document.mdx')
|
||||
The tool defaults to using OpenAI for embeddings and summarization. For customization, utilize a configuration dictionary as shown below:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = MDXSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # Options include google, openai, anthropic, llama2, etc.
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# Optional parameters can be included here.
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# Optional title for the embeddings can be added here.
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # or "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -45,28 +45,64 @@ tool = PDFSearchTool(pdf='path/to/your/document.pdf')
|
||||
|
||||
## Custom model and embeddings
|
||||
|
||||
By default, the tool uses OpenAI for both embeddings and summarization. To customize the model, you can use a config dictionary as follows:
|
||||
By default, the tool uses OpenAI for both embeddings and summarization. To customize the model, you can use a config dictionary as follows. Note: a vector database is required because generated embeddings must be stored and queried from a vectordb.
|
||||
|
||||
```python Code
|
||||
from crewai_tools import PDFSearchTool
|
||||
|
||||
# - embedding_model (required): choose provider + provider-specific config
|
||||
# - vectordb (required): choose vector DB and pass its config
|
||||
|
||||
tool = PDFSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
# Supported providers: "openai", "azure", "google-generativeai", "google-vertex",
|
||||
# "voyageai", "cohere", "huggingface", "jina", "sentence-transformer",
|
||||
# "text2vec", "ollama", "openclip", "instructor", "onnx", "roboflow", "watsonx", "custom"
|
||||
"provider": "openai", # or: "google-generativeai", "cohere", "ollama", ...
|
||||
"config": {
|
||||
# Model identifier for the chosen provider. "model" will be auto-mapped to "model_name" internally.
|
||||
"model": "text-embedding-3-small",
|
||||
# Optional: API key. If omitted, the tool will use provider-specific env vars when available
|
||||
# (e.g., OPENAI_API_KEY for provider="openai").
|
||||
# "api_key": "sk-...",
|
||||
|
||||
# Provider-specific examples:
|
||||
# --- Google Generative AI ---
|
||||
# (Set provider="google-generativeai" above)
|
||||
# "model": "models/embedding-001",
|
||||
# "task_type": "retrieval_document",
|
||||
# "title": "Embeddings",
|
||||
|
||||
# --- Cohere ---
|
||||
# (Set provider="cohere" above)
|
||||
# "model": "embed-english-v3.0",
|
||||
|
||||
# --- Ollama (local) ---
|
||||
# (Set provider="ollama" above)
|
||||
# "model": "nomic-embed-text",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # or "qdrant"
|
||||
"config": {
|
||||
# For ChromaDB: pass "settings" (chromadb.config.Settings) or rely on defaults.
|
||||
# Example (uncomment and import):
|
||||
# from chromadb.config import Settings
|
||||
# "settings": Settings(
|
||||
# persist_directory="/content/chroma",
|
||||
# allow_reset=True,
|
||||
# is_persistent=True,
|
||||
# ),
|
||||
|
||||
# For Qdrant: pass "vectors_config" (qdrant_client.models.VectorParams).
|
||||
# Example (uncomment and import):
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
|
||||
# Note: collection name is controlled by the tool (default: "rag_tool_collection"), not set here.
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -57,25 +57,41 @@ By default, the tool uses OpenAI for both embeddings and summarization.
|
||||
To customize the model, you can use a config dictionary as follows:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = TXTSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
# Required: embeddings provider + config
|
||||
"embedding_model": {
|
||||
"provider": "openai", # or google-generativeai, cohere, ollama, ...
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...", # optional if env var is set
|
||||
# Provider examples:
|
||||
# Google → model: "models/embedding-001", task_type: "retrieval_document"
|
||||
# Cohere → model: "embed-english-v3.0"
|
||||
# Ollama → model: "nomic-embed-text"
|
||||
},
|
||||
},
|
||||
|
||||
# Required: vector database config
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # or "qdrant"
|
||||
"config": {
|
||||
# Chroma settings (optional persistence)
|
||||
# "settings": Settings(
|
||||
# persist_directory="/content/chroma",
|
||||
# allow_reset=True,
|
||||
# is_persistent=True,
|
||||
# ),
|
||||
|
||||
# Qdrant vector params example:
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
|
||||
# Note: collection name is controlled by the tool (default: "rag_tool_collection").
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -54,25 +54,25 @@ It is an optional parameter during the tool's initialization but must be provide
|
||||
By default, the tool uses OpenAI for both embeddings and summarization. To customize the model, you can use a config dictionary as follows:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = XMLSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # or "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -93,11 +93,15 @@ ddtrace-run python crewai_agent.py
|
||||
|
||||
트레이스를 클릭하면 사용된 총 토큰, LLM 호출 수, 사용된 모델, 예상 비용 등 트레이스에 대한 세부 정보가 표시됩니다. 특정 스팬(span)을 클릭하면 이러한 세부 정보의 범위가 좁혀지고 관련 입력, 출력 및 메타데이터가 표시됩니다.
|
||||
|
||||

|
||||
<Frame>
|
||||
<img src="/images/datadog-llm-observability-1.png" alt="Datadog LLM 옵저버빌리티 추적 보기" />
|
||||
</Frame>
|
||||
|
||||
또한, 트레이스의 제어 및 데이터 흐름을 보여주는 트레이스의 실행 그래프 보기를 볼 수 있으며, 이는 더 큰 에이전트로 확장하여 LLM 호출, 도구 호출 및 에이전트 상호 작용 간의 핸드오프와 관계를 보여줍니다.
|
||||
|
||||

|
||||
<Frame>
|
||||
<img src="/images/datadog-llm-observability-2.png" alt="Datadog LLM Observability 에이전트 실행 흐름 보기" />
|
||||
</Frame>
|
||||
|
||||
## 참조
|
||||
|
||||
|
||||
@@ -23,13 +23,15 @@ uv add qdrant-client
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai_tools import QdrantVectorSearchTool
|
||||
from crewai_tools import QdrantVectorSearchTool, QdrantConfig
|
||||
|
||||
# Initialize the tool
|
||||
# QdrantConfig로 도구 초기화
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_url="your_qdrant_url",
|
||||
qdrant_api_key="your_qdrant_api_key",
|
||||
collection_name="your_collection"
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url="your_qdrant_url",
|
||||
qdrant_api_key="your_qdrant_api_key",
|
||||
collection_name="your_collection"
|
||||
)
|
||||
)
|
||||
|
||||
# Create an agent that uses the tool
|
||||
@@ -82,7 +84,7 @@ def extract_text_from_pdf(pdf_path):
|
||||
def get_openai_embedding(text):
|
||||
response = client.embeddings.create(
|
||||
input=text,
|
||||
model="text-embedding-3-small"
|
||||
model="text-embedding-3-large"
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@@ -90,13 +92,13 @@ def get_openai_embedding(text):
|
||||
def load_pdf_to_qdrant(pdf_path, qdrant, collection_name):
|
||||
# Extract text from PDF
|
||||
text_chunks = extract_text_from_pdf(pdf_path)
|
||||
|
||||
|
||||
# Create Qdrant collection
|
||||
if qdrant.collection_exists(collection_name):
|
||||
qdrant.delete_collection(collection_name)
|
||||
qdrant.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(size=1536, distance=Distance.COSINE)
|
||||
vectors_config=VectorParams(size=3072, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
# Store embeddings
|
||||
@@ -120,19 +122,23 @@ pdf_path = "path/to/your/document.pdf"
|
||||
load_pdf_to_qdrant(pdf_path, qdrant, collection_name)
|
||||
|
||||
# Initialize Qdrant search tool
|
||||
from crewai_tools import QdrantConfig
|
||||
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_url=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
collection_name=collection_name,
|
||||
limit=3,
|
||||
score_threshold=0.35
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
collection_name=collection_name,
|
||||
limit=3,
|
||||
score_threshold=0.35
|
||||
)
|
||||
)
|
||||
|
||||
# Create CrewAI agents
|
||||
search_agent = Agent(
|
||||
role="Senior Semantic Search Agent",
|
||||
goal="Find and analyze documents based on semantic search",
|
||||
backstory="""You are an expert research assistant who can find relevant
|
||||
backstory="""You are an expert research assistant who can find relevant
|
||||
information using semantic search in a Qdrant database.""",
|
||||
tools=[qdrant_tool],
|
||||
verbose=True
|
||||
@@ -141,7 +147,7 @@ search_agent = Agent(
|
||||
answer_agent = Agent(
|
||||
role="Senior Answer Assistant",
|
||||
goal="Generate answers to questions based on the context provided",
|
||||
backstory="""You are an expert answer assistant who can generate
|
||||
backstory="""You are an expert answer assistant who can generate
|
||||
answers to questions based on the context provided.""",
|
||||
tools=[qdrant_tool],
|
||||
verbose=True
|
||||
@@ -180,21 +186,82 @@ print(result)
|
||||
## 도구 매개변수
|
||||
|
||||
### 필수 파라미터
|
||||
- `qdrant_url` (str): Qdrant 서버의 URL
|
||||
- `qdrant_api_key` (str): Qdrant 인증을 위한 API 키
|
||||
- `collection_name` (str): 검색할 Qdrant 컬렉션의 이름
|
||||
- `qdrant_config` (QdrantConfig): 모든 Qdrant 설정을 포함하는 구성 객체
|
||||
|
||||
### 선택적 매개변수
|
||||
### QdrantConfig 매개변수
|
||||
- `qdrant_url` (str): Qdrant 서버의 URL
|
||||
- `qdrant_api_key` (str, 선택 사항): Qdrant 인증을 위한 API 키
|
||||
- `collection_name` (str): 검색할 Qdrant 컬렉션의 이름
|
||||
- `limit` (int): 반환할 최대 결과 수 (기본값: 3)
|
||||
- `score_threshold` (float): 최소 유사도 점수 임계값 (기본값: 0.35)
|
||||
- `filter` (Any, 선택 사항): 고급 필터링을 위한 Qdrant Filter 인스턴스 (기본값: None)
|
||||
|
||||
### 선택적 도구 매개변수
|
||||
- `custom_embedding_fn` (Callable[[str], list[float]]): 텍스트 벡터화를 위한 사용자 지정 함수
|
||||
- `qdrant_package` (str): Qdrant의 기본 패키지 경로 (기본값: "qdrant_client")
|
||||
- `client` (Any): 사전 초기화된 Qdrant 클라이언트 (선택 사항)
|
||||
|
||||
## 고급 필터링
|
||||
|
||||
QdrantVectorSearchTool은 검색 결과를 세밀하게 조정할 수 있는 강력한 필터링 기능을 지원합니다:
|
||||
|
||||
### 동적 필터링
|
||||
검색 시 `filter_by` 및 `filter_value` 매개변수를 사용하여 즉석에서 결과를 필터링할 수 있습니다:
|
||||
|
||||
```python
|
||||
# 에이전트는 도구를 호출할 때 이러한 매개변수를 사용합니다
|
||||
# 도구 스키마는 filter_by 및 filter_value를 허용합니다
|
||||
# 예시: 카테고리 필터를 사용한 검색
|
||||
# 결과는 category == "기술"인 항목으로 필터링됩니다
|
||||
```
|
||||
|
||||
### QdrantConfig를 사용한 사전 설정 필터
|
||||
복잡한 필터링의 경우 구성에서 Qdrant Filter 인스턴스를 사용하세요:
|
||||
|
||||
```python
|
||||
from qdrant_client.http import models as qmodels
|
||||
from crewai_tools import QdrantVectorSearchTool, QdrantConfig
|
||||
|
||||
# 특정 조건에 대한 필터 생성
|
||||
preset_filter = qmodels.Filter(
|
||||
must=[
|
||||
qmodels.FieldCondition(
|
||||
key="category",
|
||||
match=qmodels.MatchValue(value="research")
|
||||
),
|
||||
qmodels.FieldCondition(
|
||||
key="year",
|
||||
match=qmodels.MatchValue(value=2024)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# 사전 설정 필터로 도구 초기화
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection",
|
||||
filter=preset_filter # 모든 검색에 적용되는 사전 설정 필터
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### 필터 결합
|
||||
도구는 `QdrantConfig`의 사전 설정 필터와 `filter_by` 및 `filter_value`의 동적 필터를 자동으로 결합합니다:
|
||||
|
||||
```python
|
||||
# QdrantConfig에 category="research"에 대한 사전 설정 필터가 있고
|
||||
# 검색에서 filter_by="year", filter_value=2024를 사용하는 경우
|
||||
# 두 필터가 모두 결합됩니다 (AND 논리)
|
||||
```
|
||||
|
||||
## 검색 매개변수
|
||||
|
||||
이 도구는 스키마에서 다음과 같은 매개변수를 허용합니다:
|
||||
- `query` (str): 유사한 문서를 찾기 위한 검색 쿼리
|
||||
- `filter_by` (str, 선택 사항): 필터링할 메타데이터 필드
|
||||
- `filter_value` (str, 선택 사항): 필터 기준 값
|
||||
- `filter_value` (Any, 선택 사항): 필터 기준 값
|
||||
|
||||
## 반환 형식
|
||||
|
||||
@@ -214,7 +281,7 @@ print(result)
|
||||
|
||||
## 기본 임베딩
|
||||
|
||||
기본적으로, 이 도구는 벡터화를 위해 OpenAI의 `text-embedding-3-small` 모델을 사용합니다. 이를 위해서는 다음이 필요합니다:
|
||||
기본적으로, 이 도구는 벡터화를 위해 OpenAI의 `text-embedding-3-large` 모델을 사용합니다. 이를 위해서는 다음이 필요합니다:
|
||||
- 환경변수에 설정된 OpenAI API 키: `OPENAI_API_KEY`
|
||||
|
||||
## 커스텀 임베딩
|
||||
@@ -240,18 +307,22 @@ def custom_embeddings(text: str) -> list[float]:
|
||||
# Tokenize and get model outputs
|
||||
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
||||
outputs = model(**inputs)
|
||||
|
||||
|
||||
# Use mean pooling to get text embedding
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
|
||||
# Convert to list of floats and return
|
||||
return embeddings[0].tolist()
|
||||
|
||||
# Use custom embeddings with the tool
|
||||
from crewai_tools import QdrantConfig
|
||||
|
||||
tool = QdrantVectorSearchTool(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection",
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection"
|
||||
),
|
||||
custom_embedding_fn=custom_embeddings # Pass your custom function
|
||||
)
|
||||
```
|
||||
@@ -270,4 +341,4 @@ tool = QdrantVectorSearchTool(
|
||||
export QDRANT_URL="your_qdrant_url" # If not provided in constructor
|
||||
export QDRANT_API_KEY="your_api_key" # If not provided in constructor
|
||||
export OPENAI_API_KEY="your_openai_key" # If using default embeddings
|
||||
```
|
||||
```
|
||||
|
||||
@@ -54,25 +54,25 @@ tool = CSVSearchTool()
|
||||
기본적으로 이 도구는 임베딩과 요약 모두에 OpenAI를 사용합니다. 모델을 사용자 지정하려면 다음과 같이 config 딕셔너리를 사용할 수 있습니다:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = CSVSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # 또는 "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -46,23 +46,25 @@ tool = DirectorySearchTool(directory='/path/to/directory')
|
||||
DirectorySearchTool은 기본적으로 OpenAI를 사용하여 임베딩 및 요약을 수행합니다. 이 설정의 커스터마이즈 옵션에는 모델 공급자 및 구성을 변경하는 것이 포함되어 있어, 고급 사용자를 위한 유연성을 향상시킵니다.
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = DirectorySearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # Options include ollama, google, anthropic, llama2, and more
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# Additional configurations here
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # 또는 "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -56,25 +56,25 @@ tool = DOCXSearchTool(docx='path/to/your/document.docx')
|
||||
기본적으로 이 도구는 임베딩과 요약 모두에 OpenAI를 사용합니다. 모델을 커스터마이즈하려면 다음과 같이 config 딕셔너리를 사용할 수 있습니다:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = DOCXSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # 또는 "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -48,27 +48,25 @@ tool = MDXSearchTool(mdx='path/to/your/document.mdx')
|
||||
이 도구는 기본적으로 임베딩과 요약을 위해 OpenAI를 사용합니다. 커스터마이징을 위해 아래와 같이 설정 딕셔너리를 사용할 수 있습니다.
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = MDXSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # 옵션에는 google, openai, anthropic, llama2 등이 있습니다.
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# 선택적 파라미터를 여기에 포함할 수 있습니다.
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # 또는 openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# 임베딩에 대한 선택적 제목을 여기에 추가할 수 있습니다.
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # 또는 "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -45,28 +45,60 @@ tool = PDFSearchTool(pdf='path/to/your/document.pdf')
|
||||
|
||||
## 커스텀 모델 및 임베딩
|
||||
|
||||
기본적으로 이 도구는 임베딩과 요약 모두에 OpenAI를 사용합니다. 모델을 커스터마이즈하려면 다음과 같이 config 딕셔너리를 사용할 수 있습니다:
|
||||
기본적으로 이 도구는 임베딩과 요약 모두에 OpenAI를 사용합니다. 모델을 커스터마이즈하려면 다음과 같이 config 딕셔너리를 사용할 수 있습니다. 참고: 임베딩은 벡터DB에 저장되어야 하므로 vectordb 설정이 필요합니다.
|
||||
|
||||
```python Code
|
||||
from crewai_tools import PDFSearchTool
|
||||
from chromadb.config import Settings # Chroma 영속성 설정
|
||||
|
||||
tool = PDFSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
# 필수: 임베딩 제공자와 설정
|
||||
"embedding_model": {
|
||||
# 사용 가능 공급자: "openai", "azure", "google-generativeai", "google-vertex",
|
||||
# "voyageai", "cohere", "huggingface", "jina", "sentence-transformer",
|
||||
# "text2vec", "ollama", "openclip", "instructor", "onnx", "roboflow", "watsonx", "custom"
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
# "model" 키는 내부적으로 "model_name"으로 매핑됩니다.
|
||||
"model": "text-embedding-3-small",
|
||||
# 선택: API 키 (미설정 시 환경변수 사용)
|
||||
# "api_key": "sk-...",
|
||||
|
||||
# 공급자별 예시
|
||||
# --- Google ---
|
||||
# (provider를 "google-generativeai"로 설정)
|
||||
# "model": "models/embedding-001",
|
||||
# "task_type": "retrieval_document",
|
||||
|
||||
# --- Cohere ---
|
||||
# (provider를 "cohere"로 설정)
|
||||
# "model": "embed-english-v3.0",
|
||||
|
||||
# --- Ollama(로컬) ---
|
||||
# (provider를 "ollama"로 설정)
|
||||
# "model": "nomic-embed-text",
|
||||
},
|
||||
},
|
||||
|
||||
# 필수: 벡터DB 설정
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # 또는 "qdrant"
|
||||
"config": {
|
||||
# Chroma 설정 예시
|
||||
# "settings": Settings(
|
||||
# persist_directory="/content/chroma",
|
||||
# allow_reset=True,
|
||||
# is_persistent=True,
|
||||
# ),
|
||||
|
||||
# Qdrant 설정 예시
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
|
||||
# 참고: 컬렉션 이름은 도구에서 관리합니다(기본값: "rag_tool_collection").
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -57,25 +57,34 @@ tool = TXTSearchTool(txt='path/to/text/file.txt')
|
||||
모델을 커스터마이징하려면 다음과 같이 config 딕셔너리를 사용할 수 있습니다:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = TXTSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
# 필수: 임베딩 제공자 + 설정
|
||||
"embedding_model": {
|
||||
"provider": "openai", # 또는 google-generativeai, cohere, ollama 등
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...", # 환경변수 사용 시 생략 가능
|
||||
# 공급자별 예시: Google → model: "models/embedding-001", task_type: "retrieval_document"
|
||||
},
|
||||
},
|
||||
|
||||
# 필수: 벡터DB 설정
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # 또는 "qdrant"
|
||||
"config": {
|
||||
# Chroma 설정(영속성 예시)
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
|
||||
# Qdrant 벡터 파라미터 예시:
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
|
||||
# 참고: 컬렉션 이름은 도구에서 관리합니다(기본값: "rag_tool_collection").
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -54,25 +54,25 @@ tool = XMLSearchTool(xml='path/to/your/xmlfile.xml')
|
||||
기본적으로 이 도구는 임베딩과 요약 모두에 OpenAI를 사용합니다. 모델을 커스터마이징하려면 다음과 같이 config 딕셔너리를 사용할 수 있습니다.
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = XMLSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # 또는 "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -93,11 +93,14 @@ Depois de executar o aplicativo, você pode visualizar os traços na [Datadog LL
|
||||
|
||||
Ao clicar em um rastreamento, você verá os detalhes do rastreamento, incluindo o total de tokens usados, o número de chamadas LLM, os modelos usados e o custo estimado. Clicar em um intervalo específico reduzirá esses detalhes e mostrará a entrada, a saída e os metadados relacionados.
|
||||
|
||||

|
||||
|
||||
<Frame>
|
||||
<img src="/images/datadog-llm-observability-1.png" alt="Visualização do rastreamento de observabilidade do Datadog LLM" />
|
||||
</Frame>
|
||||
Além disso, você pode visualizar a visualização do gráfico de execução do rastreamento, que mostra o controle e o fluxo de dados do rastreamento, que será dimensionado com agentes maiores para mostrar transferências e relacionamentos entre chamadas LLM, chamadas de ferramentas e interações de agentes.
|
||||
|
||||

|
||||
<Frame>
|
||||
<img src="/images/datadog-llm-observability-2.png" alt="Visualização do fluxo de execução do agente de observabilidade do Datadog LLM" />
|
||||
</Frame>
|
||||
|
||||
## Referências
|
||||
|
||||
|
||||
@@ -23,13 +23,15 @@ Veja um exemplo mínimo de como utilizar a ferramenta:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai_tools import QdrantVectorSearchTool
|
||||
from crewai_tools import QdrantVectorSearchTool, QdrantConfig
|
||||
|
||||
# Inicialize a ferramenta
|
||||
# Inicialize a ferramenta com QdrantConfig
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_url="your_qdrant_url",
|
||||
qdrant_api_key="your_qdrant_api_key",
|
||||
collection_name="your_collection"
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url="your_qdrant_url",
|
||||
qdrant_api_key="your_qdrant_api_key",
|
||||
collection_name="your_collection"
|
||||
)
|
||||
)
|
||||
|
||||
# Crie um agente que utiliza a ferramenta
|
||||
@@ -82,7 +84,7 @@ def extract_text_from_pdf(pdf_path):
|
||||
def get_openai_embedding(text):
|
||||
response = client.embeddings.create(
|
||||
input=text,
|
||||
model="text-embedding-3-small"
|
||||
model="text-embedding-3-large"
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@@ -90,13 +92,13 @@ def get_openai_embedding(text):
|
||||
def load_pdf_to_qdrant(pdf_path, qdrant, collection_name):
|
||||
# Extrair texto do PDF
|
||||
text_chunks = extract_text_from_pdf(pdf_path)
|
||||
|
||||
|
||||
# Criar coleção no Qdrant
|
||||
if qdrant.collection_exists(collection_name):
|
||||
qdrant.delete_collection(collection_name)
|
||||
qdrant.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=VectorParams(size=1536, distance=Distance.COSINE)
|
||||
vectors_config=VectorParams(size=3072, distance=Distance.COSINE)
|
||||
)
|
||||
|
||||
# Armazenar embeddings
|
||||
@@ -120,19 +122,23 @@ pdf_path = "path/to/your/document.pdf"
|
||||
load_pdf_to_qdrant(pdf_path, qdrant, collection_name)
|
||||
|
||||
# Inicializar ferramenta de busca Qdrant
|
||||
from crewai_tools import QdrantConfig
|
||||
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_url=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
collection_name=collection_name,
|
||||
limit=3,
|
||||
score_threshold=0.35
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
collection_name=collection_name,
|
||||
limit=3,
|
||||
score_threshold=0.35
|
||||
)
|
||||
)
|
||||
|
||||
# Criar agentes CrewAI
|
||||
search_agent = Agent(
|
||||
role="Senior Semantic Search Agent",
|
||||
goal="Find and analyze documents based on semantic search",
|
||||
backstory="""You are an expert research assistant who can find relevant
|
||||
backstory="""You are an expert research assistant who can find relevant
|
||||
information using semantic search in a Qdrant database.""",
|
||||
tools=[qdrant_tool],
|
||||
verbose=True
|
||||
@@ -141,7 +147,7 @@ search_agent = Agent(
|
||||
answer_agent = Agent(
|
||||
role="Senior Answer Assistant",
|
||||
goal="Generate answers to questions based on the context provided",
|
||||
backstory="""You are an expert answer assistant who can generate
|
||||
backstory="""You are an expert answer assistant who can generate
|
||||
answers to questions based on the context provided.""",
|
||||
tools=[qdrant_tool],
|
||||
verbose=True
|
||||
@@ -180,21 +186,82 @@ print(result)
|
||||
## Parâmetros da Ferramenta
|
||||
|
||||
### Parâmetros Obrigatórios
|
||||
- `qdrant_url` (str): URL do seu servidor Qdrant
|
||||
- `qdrant_api_key` (str): Chave de API para autenticação com o Qdrant
|
||||
- `collection_name` (str): Nome da coleção Qdrant a ser pesquisada
|
||||
- `qdrant_config` (QdrantConfig): Objeto de configuração contendo todas as configurações do Qdrant
|
||||
|
||||
### Parâmetros Opcionais
|
||||
### Parâmetros do QdrantConfig
|
||||
- `qdrant_url` (str): URL do seu servidor Qdrant
|
||||
- `qdrant_api_key` (str, opcional): Chave de API para autenticação com o Qdrant
|
||||
- `collection_name` (str): Nome da coleção Qdrant a ser pesquisada
|
||||
- `limit` (int): Número máximo de resultados a serem retornados (padrão: 3)
|
||||
- `score_threshold` (float): Limite mínimo de similaridade (padrão: 0.35)
|
||||
- `filter` (Any, opcional): Instância de Filter do Qdrant para filtragem avançada (padrão: None)
|
||||
|
||||
### Parâmetros Opcionais da Ferramenta
|
||||
- `custom_embedding_fn` (Callable[[str], list[float]]): Função personalizada para vetorização de textos
|
||||
- `qdrant_package` (str): Caminho base do pacote Qdrant (padrão: "qdrant_client")
|
||||
- `client` (Any): Cliente Qdrant pré-inicializado (opcional)
|
||||
|
||||
## Filtragem Avançada
|
||||
|
||||
A ferramenta QdrantVectorSearchTool oferece recursos poderosos de filtragem para refinar os resultados da busca:
|
||||
|
||||
### Filtragem Dinâmica
|
||||
Use os parâmetros `filter_by` e `filter_value` na sua busca para filtrar resultados dinamicamente:
|
||||
|
||||
```python
|
||||
# O agente usará esses parâmetros ao chamar a ferramenta
|
||||
# O schema da ferramenta aceita filter_by e filter_value
|
||||
# Exemplo: busca com filtro de categoria
|
||||
# Os resultados serão filtrados onde categoria == "tecnologia"
|
||||
```
|
||||
|
||||
### Filtros Pré-definidos com QdrantConfig
|
||||
Para filtragens complexas, use instâncias de Filter do Qdrant na sua configuração:
|
||||
|
||||
```python
|
||||
from qdrant_client.http import models as qmodels
|
||||
from crewai_tools import QdrantVectorSearchTool, QdrantConfig
|
||||
|
||||
# Criar um filtro para condições específicas
|
||||
preset_filter = qmodels.Filter(
|
||||
must=[
|
||||
qmodels.FieldCondition(
|
||||
key="categoria",
|
||||
match=qmodels.MatchValue(value="pesquisa")
|
||||
),
|
||||
qmodels.FieldCondition(
|
||||
key="ano",
|
||||
match=qmodels.MatchValue(value=2024)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Inicializar ferramenta com filtro pré-definido
|
||||
qdrant_tool = QdrantVectorSearchTool(
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection",
|
||||
filter=preset_filter # Filtro pré-definido aplicado a todas as buscas
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Combinando Filtros
|
||||
A ferramenta combina automaticamente os filtros pré-definidos do `QdrantConfig` com os filtros dinâmicos de `filter_by` e `filter_value`:
|
||||
|
||||
```python
|
||||
# Se QdrantConfig tem um filtro pré-definido para categoria="pesquisa"
|
||||
# E a busca usa filter_by="ano", filter_value=2024
|
||||
# Ambos os filtros serão combinados (lógica AND)
|
||||
```
|
||||
|
||||
## Parâmetros de Busca
|
||||
|
||||
A ferramenta aceita estes parâmetros em seu schema:
|
||||
- `query` (str): Consulta de busca para encontrar documentos similares
|
||||
- `filter_by` (str, opcional): Campo de metadado para filtrar
|
||||
- `filter_value` (str, opcional): Valor para filtrar
|
||||
- `filter_value` (Any, opcional): Valor para filtrar
|
||||
|
||||
## Formato de Retorno
|
||||
|
||||
@@ -214,7 +281,7 @@ A ferramenta retorna resultados no formato JSON:
|
||||
|
||||
## Embedding Padrão
|
||||
|
||||
Por padrão, a ferramenta utiliza o modelo `text-embedding-3-small` da OpenAI para vetorização. Isso requer:
|
||||
Por padrão, a ferramenta utiliza o modelo `text-embedding-3-large` da OpenAI para vetorização. Isso requer:
|
||||
- Chave de API da OpenAI definida na variável de ambiente: `OPENAI_API_KEY`
|
||||
|
||||
## Embeddings Personalizados
|
||||
@@ -240,18 +307,22 @@ def custom_embeddings(text: str) -> list[float]:
|
||||
# Tokenizar e obter saídas do modelo
|
||||
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
||||
outputs = model(**inputs)
|
||||
|
||||
|
||||
# Usar mean pooling para obter o embedding do texto
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
|
||||
# Converter para lista de floats e retornar
|
||||
return embeddings[0].tolist()
|
||||
|
||||
# Usar embeddings personalizados com a ferramenta
|
||||
from crewai_tools import QdrantConfig
|
||||
|
||||
tool = QdrantVectorSearchTool(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection",
|
||||
qdrant_config=QdrantConfig(
|
||||
qdrant_url="your_url",
|
||||
qdrant_api_key="your_key",
|
||||
collection_name="your_collection"
|
||||
),
|
||||
custom_embedding_fn=custom_embeddings # Passe sua função personalizada
|
||||
)
|
||||
```
|
||||
@@ -270,4 +341,4 @@ Variáveis de ambiente obrigatórias:
|
||||
export QDRANT_URL="your_qdrant_url" # Se não for informado no construtor
|
||||
export QDRANT_API_KEY="your_api_key" # Se não for informado no construtor
|
||||
export OPENAI_API_KEY="your_openai_key" # Se estiver usando embeddings padrão
|
||||
```
|
||||
```
|
||||
|
||||
@@ -46,23 +46,25 @@ tool = DirectorySearchTool(directory='/path/to/directory')
|
||||
O DirectorySearchTool utiliza OpenAI para embeddings e sumarização por padrão. As opções de personalização dessas configurações incluem a alteração do provedor de modelo e configurações, ampliando a flexibilidade para usuários avançados.
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = DirectorySearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # As opções incluem ollama, google, anthropic, llama2 e mais
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# Configurações adicionais aqui
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # ou openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # ou "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -56,25 +56,25 @@ Os seguintes parâmetros podem ser usados para customizar o comportamento da `DO
|
||||
Por padrão, a ferramenta utiliza o OpenAI tanto para embeddings quanto para sumarização. Para customizar o modelo, você pode usar um dicionário de configuração como no exemplo:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = DOCXSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # ou google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # ou openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # ou "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -48,27 +48,25 @@ tool = MDXSearchTool(mdx='path/to/your/document.mdx')
|
||||
A ferramenta utiliza, por padrão, o OpenAI para embeddings e sumarização. Para personalizar, utilize um dicionário de configuração conforme exemplo abaixo:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = MDXSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # As opções incluem google, openai, anthropic, llama2, etc.
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# Parâmetros opcionais podem ser incluídos aqui.
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # ou openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# Um título opcional para os embeddings pode ser adicionado aqui.
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # ou "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -45,28 +45,60 @@ tool = PDFSearchTool(pdf='path/to/your/document.pdf')
|
||||
|
||||
## Modelo e embeddings personalizados
|
||||
|
||||
Por padrão, a ferramenta utiliza OpenAI tanto para embeddings quanto para sumarização. Para personalizar o modelo, você pode usar um dicionário de configuração como no exemplo abaixo:
|
||||
Por padrão, a ferramenta utiliza OpenAI para embeddings e sumarização. Para personalizar, use um dicionário de configuração conforme abaixo. Observação: um banco vetorial (vectordb) é necessário, pois os embeddings gerados precisam ser armazenados e consultados.
|
||||
|
||||
```python Code
|
||||
from crewai_tools import PDFSearchTool
|
||||
from chromadb.config import Settings # Persistência no Chroma
|
||||
|
||||
tool = PDFSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # ou google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # ou openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
# Obrigatório: provedor de embeddings + configuração
|
||||
"embedding_model": {
|
||||
# Provedores suportados: "openai", "azure", "google-generativeai", "google-vertex",
|
||||
# "voyageai", "cohere", "huggingface", "jina", "sentence-transformer",
|
||||
# "text2vec", "ollama", "openclip", "instructor", "onnx", "roboflow", "watsonx", "custom"
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
# "model" é mapeado internamente para "model_name".
|
||||
"model": "text-embedding-3-small",
|
||||
# Opcional: chave da API (se ausente, usa variáveis de ambiente do provedor)
|
||||
# "api_key": "sk-...",
|
||||
|
||||
# Exemplos específicos por provedor
|
||||
# --- Google ---
|
||||
# (defina provider="google-generativeai")
|
||||
# "model": "models/embedding-001",
|
||||
# "task_type": "retrieval_document",
|
||||
|
||||
# --- Cohere ---
|
||||
# (defina provider="cohere")
|
||||
# "model": "embed-english-v3.0",
|
||||
|
||||
# --- Ollama (local) ---
|
||||
# (defina provider="ollama")
|
||||
# "model": "nomic-embed-text",
|
||||
},
|
||||
},
|
||||
|
||||
# Obrigatório: configuração do banco vetorial
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # ou "qdrant"
|
||||
"config": {
|
||||
# Exemplo Chroma:
|
||||
# "settings": Settings(
|
||||
# persist_directory="/content/chroma",
|
||||
# allow_reset=True,
|
||||
# is_persistent=True,
|
||||
# ),
|
||||
|
||||
# Exemplo Qdrant:
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
|
||||
# Observação: o nome da coleção é controlado pela ferramenta (padrão: "rag_tool_collection").
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -57,25 +57,39 @@ Por padrão, a ferramenta utiliza o OpenAI tanto para embeddings quanto para sum
|
||||
Para personalizar o modelo, você pode usar um dicionário de configuração como o exemplo a seguir:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = TXTSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # ou google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # ou openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
# Obrigatório: provedor de embeddings + configuração
|
||||
"embedding_model": {
|
||||
"provider": "openai", # ou google-generativeai, cohere, ollama, ...
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...", # opcional se variável de ambiente estiver definida
|
||||
# Exemplos por provedor:
|
||||
# Google → model: "models/embedding-001", task_type: "retrieval_document"
|
||||
},
|
||||
},
|
||||
|
||||
# Obrigatório: configuração do banco vetorial
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # ou "qdrant"
|
||||
"config": {
|
||||
# Configurações do Chroma (persistência opcional)
|
||||
# "settings": Settings(
|
||||
# persist_directory="/content/chroma",
|
||||
# allow_reset=True,
|
||||
# is_persistent=True,
|
||||
# ),
|
||||
|
||||
# Exemplo de parâmetros de vetor do Qdrant:
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
|
||||
# Observação: o nome da coleção é controlado pela ferramenta (padrão: "rag_tool_collection").
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -54,25 +54,25 @@ Este parâmetro é opcional durante a inicialização da ferramenta, mas deve se
|
||||
Por padrão, a ferramenta utiliza a OpenAI tanto para embeddings quanto para sumarização. Para personalizar o modelo, você pode usar um dicionário de configuração conforme o exemplo a seguir:
|
||||
|
||||
```python Code
|
||||
from chromadb.config import Settings
|
||||
|
||||
tool = XMLSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # ou google, openai, anthropic, llama2, ...
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=true,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # ou openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
config={
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...",
|
||||
},
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chromadb", # ou "qdrant"
|
||||
"config": {
|
||||
# "settings": Settings(persist_directory="/content/chroma", allow_reset=True, is_persistent=True),
|
||||
# from qdrant_client.models import VectorParams, Distance
|
||||
# "vectors_config": VectorParams(size=384, distance=Distance.COSINE),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
```
|
||||
@@ -12,7 +12,7 @@ dependencies = [
|
||||
"pytube>=15.0.0",
|
||||
"requests>=2.32.5",
|
||||
"docker>=7.1.0",
|
||||
"crewai==1.2.1",
|
||||
"crewai==1.4.1",
|
||||
"lancedb>=0.5.4",
|
||||
"tiktoken>=0.8.0",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
|
||||
@@ -287,4 +287,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.2.1"
|
||||
__version__ = "1.4.1"
|
||||
|
||||
@@ -229,6 +229,7 @@ class CrewAIRagAdapter(Adapter):
|
||||
continue
|
||||
else:
|
||||
metadata: dict[str, Any] = base_metadata.copy()
|
||||
source_content = SourceContent(source_ref)
|
||||
|
||||
if data_type in [
|
||||
DataType.PDF_FILE,
|
||||
@@ -239,13 +240,12 @@ class CrewAIRagAdapter(Adapter):
|
||||
DataType.XML,
|
||||
DataType.MDX,
|
||||
]:
|
||||
if not os.path.isfile(source_ref):
|
||||
if not source_content.is_url() and not source_content.path_exists():
|
||||
raise FileNotFoundError(f"File does not exist: {source_ref}")
|
||||
|
||||
loader = data_type.get_loader()
|
||||
chunker = data_type.get_chunker()
|
||||
|
||||
source_content = SourceContent(source_ref)
|
||||
loader_result: LoaderResult = loader.load(source_content)
|
||||
|
||||
chunks = chunker.chunk(loader_result.content)
|
||||
|
||||
@@ -7,11 +7,12 @@ from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
from crewai.tools.base_tool import BaseTool, EnvVar
|
||||
from crewai_tools import tools
|
||||
from pydantic import BaseModel
|
||||
from pydantic.json_schema import GenerateJsonSchema
|
||||
from pydantic_core import PydanticOmit
|
||||
|
||||
from crewai_tools import tools
|
||||
|
||||
|
||||
class SchemaGenerator(GenerateJsonSchema):
|
||||
def handle_invalid_for_json_schema(self, schema, error_info):
|
||||
@@ -91,7 +91,7 @@ class EmbeddingService:
|
||||
"azure": "AZURE_OPENAI_API_KEY",
|
||||
"amazon-bedrock": "AWS_ACCESS_KEY_ID", # or AWS_PROFILE
|
||||
"cohere": "COHERE_API_KEY",
|
||||
"google-generativeai": "GOOGLE_API_KEY",
|
||||
"google-generativeai": "GEMINI_API_KEY",
|
||||
"google-vertex": "GOOGLE_APPLICATION_CREDENTIALS",
|
||||
"huggingface": "HUGGINGFACE_API_KEY",
|
||||
"jina": "JINA_API_KEY",
|
||||
|
||||
@@ -22,22 +22,23 @@ class FirecrawlCrawlWebsiteToolSchema(BaseModel):
|
||||
|
||||
|
||||
class FirecrawlCrawlWebsiteTool(BaseTool):
|
||||
"""Tool for crawling websites using Firecrawl. To run this tool, you need to have a Firecrawl API key.
|
||||
"""Tool for crawling websites using Firecrawl v2 API. To run this tool, you need to have a Firecrawl API key.
|
||||
|
||||
Args:
|
||||
api_key (str): Your Firecrawl API key.
|
||||
config (dict): Optional. It contains Firecrawl API parameters.
|
||||
config (dict): Optional. It contains Firecrawl v2 API parameters.
|
||||
|
||||
Default configuration options:
|
||||
max_depth (int): Maximum depth to crawl. Default: 2
|
||||
Default configuration options (Firecrawl v2 API):
|
||||
max_discovery_depth (int): Maximum depth for discovering pages. Default: 2
|
||||
ignore_sitemap (bool): Whether to ignore sitemap. Default: True
|
||||
limit (int): Maximum number of pages to crawl. Default: 100
|
||||
allow_backward_links (bool): Allow crawling backward links. Default: False
|
||||
limit (int): Maximum number of pages to crawl. Default: 10
|
||||
allow_external_links (bool): Allow crawling external links. Default: False
|
||||
scrape_options (ScrapeOptions): Options for scraping content
|
||||
- formats (list[str]): Content formats to return. Default: ["markdown", "screenshot", "links"]
|
||||
allow_subdomains (bool): Allow crawling subdomains. Default: False
|
||||
delay (int): Delay between requests in milliseconds. Default: None
|
||||
scrape_options (dict): Options for scraping content
|
||||
- formats (list[str]): Content formats to return. Default: ["markdown"]
|
||||
- only_main_content (bool): Only return main content. Default: True
|
||||
- timeout (int): Timeout in milliseconds. Default: 30000
|
||||
- timeout (int): Timeout in milliseconds. Default: 10000
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -49,14 +50,15 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
|
||||
api_key: str | None = None
|
||||
config: dict[str, Any] | None = Field(
|
||||
default_factory=lambda: {
|
||||
"maxDepth": 2,
|
||||
"ignoreSitemap": True,
|
||||
"max_discovery_depth": 2,
|
||||
"ignore_sitemap": True,
|
||||
"limit": 10,
|
||||
"allowBackwardLinks": False,
|
||||
"allowExternalLinks": False,
|
||||
"scrapeOptions": {
|
||||
"formats": ["markdown", "screenshot", "links"],
|
||||
"onlyMainContent": True,
|
||||
"allow_external_links": False,
|
||||
"allow_subdomains": False,
|
||||
"delay": None,
|
||||
"scrape_options": {
|
||||
"formats": ["markdown"],
|
||||
"only_main_content": True,
|
||||
"timeout": 10000,
|
||||
},
|
||||
}
|
||||
@@ -107,7 +109,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
|
||||
if not self._firecrawl:
|
||||
raise RuntimeError("FirecrawlApp not properly initialized")
|
||||
|
||||
return self._firecrawl.crawl_url(url, poll_interval=2, params=self.config)
|
||||
return self._firecrawl.crawl(url=url, poll_interval=2, **self.config)
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@@ -22,20 +22,27 @@ class FirecrawlScrapeWebsiteToolSchema(BaseModel):
|
||||
|
||||
|
||||
class FirecrawlScrapeWebsiteTool(BaseTool):
|
||||
"""Tool for scraping webpages using Firecrawl. To run this tool, you need to have a Firecrawl API key.
|
||||
"""Tool for scraping webpages using Firecrawl v2 API. To run this tool, you need to have a Firecrawl API key.
|
||||
|
||||
Args:
|
||||
api_key (str): Your Firecrawl API key.
|
||||
config (dict): Optional. It contains Firecrawl API parameters.
|
||||
config (dict): Optional. It contains Firecrawl v2 API parameters.
|
||||
|
||||
Default configuration options:
|
||||
Default configuration options (Firecrawl v2 API):
|
||||
formats (list[str]): Content formats to return. Default: ["markdown"]
|
||||
onlyMainContent (bool): Only return main content. Default: True
|
||||
includeTags (list[str]): Tags to include. Default: []
|
||||
excludeTags (list[str]): Tags to exclude. Default: []
|
||||
headers (dict): Headers to include. Default: {}
|
||||
waitFor (int): Time to wait for page to load in ms. Default: 0
|
||||
json_options (dict): Options for JSON extraction. Default: None
|
||||
only_main_content (bool): Only return main content excluding headers, navs, footers, etc. Default: True
|
||||
include_tags (list[str]): Tags to include in the output. Default: []
|
||||
exclude_tags (list[str]): Tags to exclude from the output. Default: []
|
||||
max_age (int): Returns cached version if younger than this age in milliseconds. Default: 172800000 (2 days)
|
||||
headers (dict): Headers to send with the request (e.g., cookies, user-agent). Default: {}
|
||||
wait_for (int): Delay in milliseconds before fetching content. Default: 0
|
||||
mobile (bool): Emulate scraping from a mobile device. Default: False
|
||||
skip_tls_verification (bool): Skip TLS certificate verification. Default: True
|
||||
timeout (int): Request timeout in milliseconds. Default: None
|
||||
remove_base64_images (bool): Remove base64 images from output. Default: True
|
||||
block_ads (bool): Enable ad-blocking and cookie popup blocking. Default: True
|
||||
proxy (str): Proxy type ("basic", "stealth", "auto"). Default: "auto"
|
||||
store_in_cache (bool): Store page in Firecrawl index and cache. Default: True
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -48,11 +55,18 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
|
||||
config: dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"formats": ["markdown"],
|
||||
"onlyMainContent": True,
|
||||
"includeTags": [],
|
||||
"excludeTags": [],
|
||||
"only_main_content": True,
|
||||
"include_tags": [],
|
||||
"exclude_tags": [],
|
||||
"max_age": 172800000, # 2 days cache
|
||||
"headers": {},
|
||||
"waitFor": 0,
|
||||
"wait_for": 0,
|
||||
"mobile": False,
|
||||
"skip_tls_verification": True,
|
||||
"remove_base64_images": True,
|
||||
"block_ads": True,
|
||||
"proxy": "auto",
|
||||
"store_in_cache": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -95,7 +109,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
|
||||
if not self._firecrawl:
|
||||
raise RuntimeError("FirecrawlApp not properly initialized")
|
||||
|
||||
return self._firecrawl.scrape_url(url, params=self.config)
|
||||
return self._firecrawl.scrape(url=url, **self.config)
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@@ -23,19 +23,24 @@ class FirecrawlSearchToolSchema(BaseModel):
|
||||
|
||||
|
||||
class FirecrawlSearchTool(BaseTool):
|
||||
"""Tool for searching webpages using Firecrawl. To run this tool, you need to have a Firecrawl API key.
|
||||
"""Tool for searching webpages using Firecrawl v2 API. To run this tool, you need to have a Firecrawl API key.
|
||||
|
||||
Args:
|
||||
api_key (str): Your Firecrawl API key.
|
||||
config (dict): Optional. It contains Firecrawl API parameters.
|
||||
config (dict): Optional. It contains Firecrawl v2 API parameters.
|
||||
|
||||
Default configuration options:
|
||||
limit (int): Maximum number of pages to crawl. Default: 5
|
||||
tbs (str): Time before search. Default: None
|
||||
lang (str): Language. Default: "en"
|
||||
country (str): Country. Default: "us"
|
||||
location (str): Location. Default: None
|
||||
timeout (int): Timeout in milliseconds. Default: 60000
|
||||
Default configuration options (Firecrawl v2 API):
|
||||
limit (int): Maximum number of search results to return. Default: 5
|
||||
tbs (str): Time-based search filter (e.g., "qdr:d" for past day). Default: None
|
||||
location (str): Location for search results. Default: None
|
||||
timeout (int): Request timeout in milliseconds. Default: None
|
||||
scrape_options (dict): Options for scraping the search results. Default: {"formats": ["markdown"]}
|
||||
- formats (list[str]): Content formats to return. Default: ["markdown"]
|
||||
- only_main_content (bool): Only return main content. Default: True
|
||||
- include_tags (list[str]): Tags to include. Default: []
|
||||
- exclude_tags (list[str]): Tags to exclude. Default: []
|
||||
- wait_for (int): Delay before fetching content in ms. Default: 0
|
||||
- timeout (int): Request timeout in milliseconds. Default: None
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -49,10 +54,15 @@ class FirecrawlSearchTool(BaseTool):
|
||||
default_factory=lambda: {
|
||||
"limit": 5,
|
||||
"tbs": None,
|
||||
"lang": "en",
|
||||
"country": "us",
|
||||
"location": None,
|
||||
"timeout": 60000,
|
||||
"timeout": None,
|
||||
"scrape_options": {
|
||||
"formats": ["markdown"],
|
||||
"only_main_content": True,
|
||||
"include_tags": [],
|
||||
"exclude_tags": [],
|
||||
"wait_for": 0,
|
||||
},
|
||||
}
|
||||
)
|
||||
_firecrawl: FirecrawlApp | None = PrivateAttr(None)
|
||||
@@ -106,7 +116,7 @@ class FirecrawlSearchTool(BaseTool):
|
||||
|
||||
return self._firecrawl.search(
|
||||
query=query,
|
||||
params=self.config,
|
||||
**self.config,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
@@ -12,9 +12,13 @@ from pydantic.types import ImportString
|
||||
|
||||
|
||||
class QdrantToolSchema(BaseModel):
|
||||
query: str = Field(..., description="Query to search in Qdrant DB.")
|
||||
filter_by: str | None = None
|
||||
filter_value: str | None = None
|
||||
query: str = Field(..., description="Query to search in Qdrant DB")
|
||||
filter_by: str | None = Field(
|
||||
default=None, description="Parameter to filter the search by."
|
||||
)
|
||||
filter_value: Any | None = Field(
|
||||
default=None, description="Value to filter the search by."
|
||||
)
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
@@ -25,7 +29,9 @@ class QdrantConfig(BaseModel):
|
||||
collection_name: str
|
||||
limit: int = 3
|
||||
score_threshold: float = 0.35
|
||||
filter_conditions: list[tuple[str, Any]] = Field(default_factory=list)
|
||||
filter: Any | None = Field(
|
||||
default=None, description="Qdrant Filter instance for advanced filtering."
|
||||
)
|
||||
|
||||
|
||||
class QdrantVectorSearchTool(BaseTool):
|
||||
@@ -76,23 +82,26 @@ class QdrantVectorSearchTool(BaseTool):
|
||||
filter_value: Any | None = None,
|
||||
) -> str:
|
||||
"""Perform vector similarity search."""
|
||||
filter_ = self.qdrant_package.http.models.Filter
|
||||
field_condition = self.qdrant_package.http.models.FieldCondition
|
||||
match_value = self.qdrant_package.http.models.MatchValue
|
||||
conditions = self.qdrant_config.filter_conditions.copy()
|
||||
if filter_by and filter_value is not None:
|
||||
conditions.append((filter_by, filter_value))
|
||||
|
||||
search_filter = (
|
||||
filter_(
|
||||
must=[
|
||||
field_condition(key=k, match=match_value(value=v))
|
||||
for k, v in conditions
|
||||
]
|
||||
)
|
||||
if conditions
|
||||
else None
|
||||
self.qdrant_config.filter.model_copy()
|
||||
if self.qdrant_config.filter is not None
|
||||
else self.qdrant_package.http.models.Filter(must=[])
|
||||
)
|
||||
if filter_by and filter_value is not None:
|
||||
if not hasattr(search_filter, "must") or not isinstance(
|
||||
search_filter.must, list
|
||||
):
|
||||
search_filter.must = []
|
||||
search_filter.must.append(
|
||||
self.qdrant_package.http.models.FieldCondition(
|
||||
key=filter_by,
|
||||
match=self.qdrant_package.http.models.MatchValue(
|
||||
value=filter_value
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
query_vector = (
|
||||
self.custom_embedding_fn(query)
|
||||
if self.custom_embedding_fn
|
||||
|
||||
@@ -247,7 +247,7 @@ class StagehandTool(BaseTool):
|
||||
if "claude" in model_str.lower() or "anthropic" in model_str.lower():
|
||||
return self.model_api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
if "gemini" in model_str.lower():
|
||||
return self.model_api_key or os.getenv("GOOGLE_API_KEY")
|
||||
return self.model_api_key or os.getenv("GEMINI_API_KEY")
|
||||
# Default to trying OpenAI, then Anthropic
|
||||
return (
|
||||
self.model_api_key
|
||||
@@ -313,7 +313,7 @@ class StagehandTool(BaseTool):
|
||||
|
||||
if not model_api_key:
|
||||
raise ValueError(
|
||||
"No appropriate API key found for model. Please set OPENAI_API_KEY, ANTHROPIC_API_KEY, or GOOGLE_API_KEY"
|
||||
"No appropriate API key found for model. Please set OPENAI_API_KEY, ANTHROPIC_API_KEY, or GEMINI_API_KEY"
|
||||
)
|
||||
|
||||
# Build the StagehandConfig with proper parameter names
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Register custom markers."""
|
||||
config.addinivalue_line("markers", "integration: mark test as an integration test")
|
||||
config.addinivalue_line("markers", "asyncio: mark test as an async test")
|
||||
|
||||
# Set the asyncio loop scope through ini configuration
|
||||
config.inicfg["asyncio_mode"] = "auto"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for each test case."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
yield loop
|
||||
loop.close()
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
from unittest import mock
|
||||
|
||||
from crewai.tools.base_tool import BaseTool, EnvVar
|
||||
from generate_tool_specs import ToolSpecExtractor
|
||||
from crewai_tools.generate_tool_specs import ToolSpecExtractor
|
||||
from pydantic import BaseModel, Field
|
||||
import pytest
|
||||
|
||||
@@ -61,8 +61,8 @@ def test_unwrap_schema(extractor):
|
||||
@pytest.fixture
|
||||
def mock_tool_extractor(extractor):
|
||||
with (
|
||||
mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
|
||||
mock.patch("generate_tool_specs.getattr", return_value=MockTool),
|
||||
mock.patch("crewai_tools.generate_tool_specs.dir", return_value=["MockTool"]),
|
||||
mock.patch("crewai_tools.generate_tool_specs.getattr", return_value=MockTool),
|
||||
):
|
||||
extractor.extract_all_tools()
|
||||
assert len(extractor.tools_spec) == 1
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,18 @@
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.firecrawl_crawl_website_tool.firecrawl_crawl_website_tool import (
|
||||
FirecrawlCrawlWebsiteTool,
|
||||
)
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_firecrawl_crawl_tool_integration():
|
||||
tool = FirecrawlCrawlWebsiteTool(config={
|
||||
"limit": 2,
|
||||
"max_discovery_depth": 1,
|
||||
"scrape_options": {"formats": ["markdown"]}
|
||||
})
|
||||
result = tool.run(url="https://firecrawl.dev")
|
||||
|
||||
assert result is not None
|
||||
assert hasattr(result, 'status')
|
||||
assert result.status in ["completed", "scraping"]
|
||||
@@ -0,0 +1,15 @@
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
|
||||
FirecrawlScrapeWebsiteTool,
|
||||
)
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_firecrawl_scrape_tool_integration():
|
||||
tool = FirecrawlScrapeWebsiteTool()
|
||||
result = tool.run(url="https://firecrawl.dev")
|
||||
|
||||
assert result is not None
|
||||
assert hasattr(result, 'markdown')
|
||||
assert len(result.markdown) > 0
|
||||
assert "Firecrawl" in result.markdown or "firecrawl" in result.markdown.lower()
|
||||
12
lib/crewai-tools/tests/tools/firecrawl_search_tool_test.py
Normal file
12
lib/crewai-tools/tests/tools/firecrawl_search_tool_test.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_firecrawl_search_tool_integration():
|
||||
tool = FirecrawlSearchTool()
|
||||
result = tool.run(query="firecrawl")
|
||||
|
||||
assert result is not None
|
||||
assert hasattr(result, 'web') or hasattr(result, 'news') or hasattr(result, 'images')
|
||||
0
lib/crewai-tools/tests/tools/rag/__init__.py
Normal file
0
lib/crewai-tools/tests/tools/rag/__init__.py
Normal file
@@ -23,7 +23,6 @@ dependencies = [
|
||||
"chromadb~=1.1.0",
|
||||
"tokenizers>=0.20.3",
|
||||
"openpyxl>=3.1.5",
|
||||
"pyvis>=0.3.2",
|
||||
# Authentication and Security
|
||||
"python-dotenv>=1.1.1",
|
||||
"pyjwt>=2.9.0",
|
||||
@@ -49,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.2.1",
|
||||
"crewai-tools==1.4.1",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
@@ -94,10 +93,11 @@ azure-ai-inference = [
|
||||
anthropic = [
|
||||
"anthropic>=0.69.0",
|
||||
]
|
||||
# a2a = [
|
||||
# "a2a-sdk~=0.3.9",
|
||||
# "httpx-sse>=0.4.0",
|
||||
# ]
|
||||
a2a = [
|
||||
"a2a-sdk~=0.3.10",
|
||||
"httpx-auth>=0.23.1",
|
||||
"httpx-sse>=0.4.0",
|
||||
]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any
|
||||
import urllib.request
|
||||
import warnings
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow.flow import Flow
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.2.1"
|
||||
__version__ = "1.4.1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
6
lib/crewai/src/crewai/a2a/__init__.py
Normal file
6
lib/crewai/src/crewai/a2a/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""
|
||||
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
|
||||
__all__ = ["A2AConfig"]
|
||||
20
lib/crewai/src/crewai/a2a/auth/__init__.py
Normal file
20
lib/crewai/src/crewai/a2a/auth/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""A2A authentication schemas."""
|
||||
|
||||
from crewai.a2a.auth.schemas import (
|
||||
APIKeyAuth,
|
||||
BearerTokenAuth,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"APIKeyAuth",
|
||||
"BearerTokenAuth",
|
||||
"HTTPBasicAuth",
|
||||
"HTTPDigestAuth",
|
||||
"OAuth2AuthorizationCode",
|
||||
"OAuth2ClientCredentials",
|
||||
]
|
||||
392
lib/crewai/src/crewai/a2a/auth/schemas.py
Normal file
392
lib/crewai/src/crewai/a2a/auth/schemas.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""Authentication schemes for A2A protocol agents.
|
||||
|
||||
Supported authentication methods:
|
||||
- Bearer tokens
|
||||
- OAuth2 (Client Credentials, Authorization Code)
|
||||
- API Keys (header, query, cookie)
|
||||
- HTTP Basic authentication
|
||||
- HTTP Digest authentication
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import base64
|
||||
from collections.abc import Awaitable, Callable, MutableMapping
|
||||
import time
|
||||
from typing import Literal
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from httpx import DigestAuth
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
|
||||
class AuthScheme(ABC, BaseModel):
|
||||
"""Base class for authentication schemes."""
|
||||
|
||||
@abstractmethod
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply authentication to request headers.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with authentication applied.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BearerTokenAuth(AuthScheme):
|
||||
"""Bearer token authentication (Authorization: Bearer <token>).
|
||||
|
||||
Attributes:
|
||||
token: Bearer token for authentication.
|
||||
"""
|
||||
|
||||
token: str = Field(description="Bearer token")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply Bearer token to Authorization header.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with Bearer token in Authorization header.
|
||||
"""
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
return headers
|
||||
|
||||
|
||||
class HTTPBasicAuth(AuthScheme):
|
||||
"""HTTP Basic authentication.
|
||||
|
||||
Attributes:
|
||||
username: Username for Basic authentication.
|
||||
password: Password for Basic authentication.
|
||||
"""
|
||||
|
||||
username: str = Field(description="Username")
|
||||
password: str = Field(description="Password")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply HTTP Basic authentication.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with Basic auth in Authorization header.
|
||||
"""
|
||||
credentials = f"{self.username}:{self.password}"
|
||||
encoded = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded}"
|
||||
return headers
|
||||
|
||||
|
||||
class HTTPDigestAuth(AuthScheme):
|
||||
"""HTTP Digest authentication.
|
||||
|
||||
Note: Uses httpx-auth library for digest implementation.
|
||||
|
||||
Attributes:
|
||||
username: Username for Digest authentication.
|
||||
password: Password for Digest authentication.
|
||||
"""
|
||||
|
||||
username: str = Field(description="Username")
|
||||
password: str = Field(description="Password")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Digest auth is handled by httpx auth flow, not headers.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Unchanged headers (Digest auth handled by httpx auth flow).
|
||||
"""
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client with Digest auth.
|
||||
|
||||
Args:
|
||||
client: HTTP client to configure with Digest authentication.
|
||||
"""
|
||||
client.auth = DigestAuth(self.username, self.password)
|
||||
|
||||
|
||||
class APIKeyAuth(AuthScheme):
|
||||
"""API Key authentication (header, query, or cookie).
|
||||
|
||||
Attributes:
|
||||
api_key: API key value for authentication.
|
||||
location: Where to send the API key (header, query, or cookie).
|
||||
name: Parameter name for the API key (default: X-API-Key).
|
||||
"""
|
||||
|
||||
api_key: str = Field(description="API key value")
|
||||
location: Literal["header", "query", "cookie"] = Field(
|
||||
default="header", description="Where to send the API key"
|
||||
)
|
||||
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply API key authentication.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with API key (for header/cookie locations).
|
||||
"""
|
||||
if self.location == "header":
|
||||
headers[self.name] = self.api_key
|
||||
elif self.location == "cookie":
|
||||
headers["Cookie"] = f"{self.name}={self.api_key}"
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client for query param API keys.
|
||||
|
||||
Args:
|
||||
client: HTTP client to configure with query param API key hook.
|
||||
"""
|
||||
if self.location == "query":
|
||||
|
||||
async def _add_api_key_param(request: httpx.Request) -> None:
|
||||
url = httpx.URL(request.url)
|
||||
request.url = url.copy_add_param(self.name, self.api_key)
|
||||
|
||||
client.event_hooks["request"].append(_add_api_key_param)
|
||||
|
||||
|
||||
class OAuth2ClientCredentials(AuthScheme):
|
||||
"""OAuth2 Client Credentials flow authentication.
|
||||
|
||||
Attributes:
|
||||
token_url: OAuth2 token endpoint URL.
|
||||
client_id: OAuth2 client identifier.
|
||||
client_secret: OAuth2 client secret.
|
||||
scopes: List of required OAuth2 scopes.
|
||||
"""
|
||||
|
||||
token_url: str = Field(description="OAuth2 token endpoint")
|
||||
client_id: str = Field(description="OAuth2 client ID")
|
||||
client_secret: str = Field(description="OAuth2 client secret")
|
||||
scopes: list[str] = Field(
|
||||
default_factory=list, description="Required OAuth2 scopes"
|
||||
)
|
||||
|
||||
_access_token: str | None = PrivateAttr(default=None)
|
||||
_token_expires_at: float | None = PrivateAttr(default=None)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with OAuth2 access token in Authorization header.
|
||||
"""
|
||||
if (
|
||||
self._access_token is None
|
||||
or self._token_expires_at is None
|
||||
or time.time() >= self._token_expires_at
|
||||
):
|
||||
await self._fetch_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Fetch OAuth2 access token using client credentials flow.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token request.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If token request fails.
|
||||
"""
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
}
|
||||
|
||||
if self.scopes:
|
||||
data["scope"] = " ".join(self.scopes)
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(AuthScheme):
|
||||
"""OAuth2 Authorization Code flow authentication.
|
||||
|
||||
Note: Requires interactive authorization.
|
||||
|
||||
Attributes:
|
||||
authorization_url: OAuth2 authorization endpoint URL.
|
||||
token_url: OAuth2 token endpoint URL.
|
||||
client_id: OAuth2 client identifier.
|
||||
client_secret: OAuth2 client secret.
|
||||
redirect_uri: OAuth2 redirect URI for callback.
|
||||
scopes: List of required OAuth2 scopes.
|
||||
"""
|
||||
|
||||
authorization_url: str = Field(description="OAuth2 authorization endpoint")
|
||||
token_url: str = Field(description="OAuth2 token endpoint")
|
||||
client_id: str = Field(description="OAuth2 client ID")
|
||||
client_secret: str = Field(description="OAuth2 client secret")
|
||||
redirect_uri: str = Field(description="OAuth2 redirect URI")
|
||||
scopes: list[str] = Field(
|
||||
default_factory=list, description="Required OAuth2 scopes"
|
||||
)
|
||||
|
||||
_access_token: str | None = PrivateAttr(default=None)
|
||||
_refresh_token: str | None = PrivateAttr(default=None)
|
||||
_token_expires_at: float | None = PrivateAttr(default=None)
|
||||
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
|
||||
default=None
|
||||
)
|
||||
|
||||
def set_authorization_callback(
|
||||
self, callback: Callable[[str], Awaitable[str]] | None
|
||||
) -> None:
|
||||
"""Set callback to handle authorization URL.
|
||||
|
||||
Args:
|
||||
callback: Async function that receives authorization URL and returns auth code.
|
||||
"""
|
||||
self._authorization_callback = callback
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with OAuth2 access token in Authorization header.
|
||||
|
||||
Raises:
|
||||
ValueError: If authorization callback is not set.
|
||||
"""
|
||||
|
||||
if self._access_token is None:
|
||||
if self._authorization_callback is None:
|
||||
msg = "Authorization callback not set. Use set_authorization_callback()"
|
||||
raise ValueError(msg)
|
||||
await self._fetch_initial_token(client)
|
||||
elif self._token_expires_at and time.time() >= self._token_expires_at:
|
||||
await self._refresh_access_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Fetch initial access token using authorization code flow.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token request.
|
||||
|
||||
Raises:
|
||||
ValueError: If authorization callback is not set.
|
||||
httpx.HTTPStatusError: If token request fails.
|
||||
"""
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": " ".join(self.scopes),
|
||||
}
|
||||
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
if self._authorization_callback is None:
|
||||
msg = "Authorization callback not set"
|
||||
raise ValueError(msg)
|
||||
auth_code = await self._authorization_callback(auth_url)
|
||||
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
self._refresh_token = token_data.get("refresh_token")
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
|
||||
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Refresh the access token using refresh token.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token request.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If token refresh request fails.
|
||||
"""
|
||||
if not self._refresh_token:
|
||||
await self._fetch_initial_token(client)
|
||||
return
|
||||
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self._refresh_token,
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
}
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
if "refresh_token" in token_data:
|
||||
self._refresh_token = token_data["refresh_token"]
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
236
lib/crewai/src/crewai/a2a/auth/utils.py
Normal file
236
lib/crewai/src/crewai/a2a/auth/utils.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Authentication utilities for A2A protocol agent communication.
|
||||
|
||||
Provides validation and retry logic for various authentication schemes including
|
||||
OAuth2, API keys, and HTTP authentication methods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable, MutableMapping
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
APIKeySecurityScheme,
|
||||
AgentCard,
|
||||
HTTPAuthSecurityScheme,
|
||||
OAuth2SecurityScheme,
|
||||
)
|
||||
from httpx import AsyncClient, Response
|
||||
|
||||
from crewai.a2a.auth.schemas import (
|
||||
APIKeyAuth,
|
||||
AuthScheme,
|
||||
BearerTokenAuth,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
)
|
||||
|
||||
|
||||
_auth_store: dict[int, AuthScheme | None] = {}
|
||||
|
||||
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
|
||||
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
|
||||
|
||||
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
|
||||
OAuth2SecurityScheme: (
|
||||
OAuth2ClientCredentials,
|
||||
OAuth2AuthorizationCode,
|
||||
BearerTokenAuth,
|
||||
),
|
||||
APIKeySecurityScheme: (APIKeyAuth,),
|
||||
}
|
||||
|
||||
_HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
|
||||
"basic": HTTPBasicAuth,
|
||||
"digest": HTTPDigestAuth,
|
||||
"bearer": BearerTokenAuth,
|
||||
}
|
||||
|
||||
|
||||
def _raise_auth_mismatch(
|
||||
expected_classes: type[AuthScheme] | tuple[type[AuthScheme], ...],
|
||||
provided_auth: AuthScheme,
|
||||
) -> None:
|
||||
"""Raise authentication mismatch error.
|
||||
|
||||
Args:
|
||||
expected_classes: Expected authentication class or tuple of classes.
|
||||
provided_auth: Actually provided authentication instance.
|
||||
|
||||
Raises:
|
||||
A2AClientHTTPError: Always raises with 401 status code.
|
||||
"""
|
||||
if isinstance(expected_classes, tuple):
|
||||
if len(expected_classes) == 1:
|
||||
required = expected_classes[0].__name__
|
||||
else:
|
||||
names = [cls.__name__ for cls in expected_classes]
|
||||
required = f"one of ({', '.join(names)})"
|
||||
else:
|
||||
required = expected_classes.__name__
|
||||
|
||||
msg = (
|
||||
f"AgentCard requires {required} authentication, "
|
||||
f"but {type(provided_auth).__name__} was provided"
|
||||
)
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
|
||||
|
||||
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
|
||||
"""Parse WWW-Authenticate header into auth challenges.
|
||||
|
||||
Args:
|
||||
header_value: The WWW-Authenticate header value.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping auth scheme to its parameters.
|
||||
Example: {"Bearer": {"realm": "api", "scope": "read write"}}
|
||||
"""
|
||||
if not header_value:
|
||||
return {}
|
||||
|
||||
challenges: dict[str, dict[str, str]] = {}
|
||||
|
||||
for match in _SCHEME_PATTERN.finditer(header_value):
|
||||
scheme = match.group(1)
|
||||
params_str = match.group(2)
|
||||
|
||||
params: dict[str, str] = {}
|
||||
|
||||
for param_match in _PARAM_PATTERN.finditer(params_str):
|
||||
key = param_match.group(1)
|
||||
value = param_match.group(2) or param_match.group(3)
|
||||
params[key] = value
|
||||
|
||||
challenges[scheme] = params
|
||||
|
||||
return challenges
|
||||
|
||||
|
||||
def validate_auth_against_agent_card(
|
||||
agent_card: AgentCard, auth: AuthScheme | None
|
||||
) -> None:
|
||||
"""Validate that provided auth matches AgentCard security requirements.
|
||||
|
||||
Args:
|
||||
agent_card: The A2A AgentCard containing security requirements.
|
||||
auth: User-provided authentication scheme (or None).
|
||||
|
||||
Raises:
|
||||
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
|
||||
"""
|
||||
|
||||
if not agent_card.security or not agent_card.security_schemes:
|
||||
return
|
||||
|
||||
if not auth:
|
||||
msg = "AgentCard requires authentication but no auth scheme provided"
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
|
||||
first_security_req = agent_card.security[0] if agent_card.security else {}
|
||||
|
||||
for scheme_name in first_security_req.keys():
|
||||
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
|
||||
if not security_scheme_wrapper:
|
||||
continue
|
||||
|
||||
scheme = security_scheme_wrapper.root
|
||||
|
||||
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
|
||||
if not isinstance(auth, allowed_classes):
|
||||
_raise_auth_mismatch(allowed_classes, auth)
|
||||
return
|
||||
|
||||
if isinstance(scheme, HTTPAuthSecurityScheme):
|
||||
if required_class := _HTTP_SCHEME_MAPPING.get(scheme.scheme.lower()):
|
||||
if not isinstance(auth, required_class):
|
||||
_raise_auth_mismatch(required_class, auth)
|
||||
return
|
||||
|
||||
msg = "Could not validate auth against AgentCard security requirements"
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
|
||||
|
||||
async def retry_on_401(
|
||||
request_func: Callable[[], Awaitable[Response]],
|
||||
auth_scheme: AuthScheme | None,
|
||||
client: AsyncClient,
|
||||
headers: MutableMapping[str, str],
|
||||
max_retries: int = 3,
|
||||
) -> Response:
|
||||
"""Retry a request on 401 authentication error.
|
||||
|
||||
Handles 401 errors by:
|
||||
1. Parsing WWW-Authenticate header
|
||||
2. Re-acquiring credentials
|
||||
3. Retrying the request
|
||||
|
||||
Args:
|
||||
request_func: Async function that makes the HTTP request.
|
||||
auth_scheme: Authentication scheme to refresh credentials with.
|
||||
client: HTTP client for making requests.
|
||||
headers: Request headers to update with new auth.
|
||||
max_retries: Maximum number of retry attempts (default: 3).
|
||||
|
||||
Returns:
|
||||
HTTP response from the request.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If retries are exhausted or auth scheme is None.
|
||||
"""
|
||||
last_response: Response | None = None
|
||||
last_challenges: dict[str, dict[str, str]] = {}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
response = await request_func()
|
||||
|
||||
if response.status_code != 401:
|
||||
return response
|
||||
|
||||
last_response = response
|
||||
|
||||
if auth_scheme is None:
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
www_authenticate = response.headers.get("WWW-Authenticate", "")
|
||||
challenges = parse_www_authenticate(www_authenticate)
|
||||
last_challenges = challenges
|
||||
|
||||
if attempt >= max_retries - 1:
|
||||
break
|
||||
|
||||
backoff_time = 2**attempt
|
||||
await asyncio.sleep(backoff_time)
|
||||
|
||||
await auth_scheme.apply_auth(client, headers)
|
||||
|
||||
if last_response:
|
||||
last_response.raise_for_status()
|
||||
return last_response
|
||||
|
||||
msg = "retry_on_401 failed without making any requests"
|
||||
if last_challenges:
|
||||
challenge_info = ", ".join(
|
||||
f"{scheme} (realm={params.get('realm', 'N/A')})"
|
||||
for scheme, params in last_challenges.items()
|
||||
)
|
||||
msg = f"{msg}. Server challenges: {challenge_info}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
def configure_auth_client(
|
||||
auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient
|
||||
) -> None:
|
||||
"""Configure HTTP client with auth-specific settings.
|
||||
|
||||
Only HTTPDigestAuth and APIKeyAuth need client configuration.
|
||||
|
||||
Args:
|
||||
auth: Authentication scheme that requires client configuration.
|
||||
client: HTTP client to configure.
|
||||
"""
|
||||
auth.configure_client(client)
|
||||
59
lib/crewai/src/crewai/a2a/config.py
Normal file
59
lib/crewai/src/crewai/a2a/config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""A2A configuration types.
|
||||
|
||||
This module is separate from experimental.a2a to avoid circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
HttpUrl,
|
||||
TypeAdapter,
|
||||
)
|
||||
|
||||
from crewai.a2a.auth.schemas import AuthScheme
|
||||
|
||||
|
||||
http_url_adapter = TypeAdapter(HttpUrl)
|
||||
|
||||
Url = Annotated[
|
||||
str,
|
||||
BeforeValidator(
|
||||
lambda value: str(http_url_adapter.validate_python(value, strict=True))
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class A2AConfig(BaseModel):
|
||||
"""Configuration for A2A protocol integration.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Authentication scheme (Bearer, OAuth2, API Key, HTTP Basic/Digest).
|
||||
timeout: Request timeout in seconds (default: 120).
|
||||
max_turns: Maximum conversation turns with A2A agent (default: 10).
|
||||
response_model: Optional Pydantic model for structured A2A agent responses.
|
||||
fail_fast: If True, raise error when agent unreachable; if False, skip and continue (default: True).
|
||||
"""
|
||||
|
||||
endpoint: Url = Field(description="A2A agent endpoint URL")
|
||||
auth: AuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme (Bearer, OAuth2, API Key, HTTP Basic/Digest)",
|
||||
)
|
||||
timeout: int = Field(default=120, description="Request timeout in seconds")
|
||||
max_turns: int = Field(
|
||||
default=10, description="Maximum conversation turns with A2A agent"
|
||||
)
|
||||
response_model: type[BaseModel] | None = Field(
|
||||
default=None,
|
||||
description="Optional Pydantic model for structured A2A agent responses. When specified, the A2A agent is expected to return JSON matching this schema.",
|
||||
)
|
||||
fail_fast: bool = Field(
|
||||
default=True,
|
||||
description="If True, raise an error immediately when the A2A agent is unreachable. If False, skip the A2A agent and continue execution.",
|
||||
)
|
||||
29
lib/crewai/src/crewai/a2a/templates.py
Normal file
29
lib/crewai/src/crewai/a2a/templates.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""String templates for A2A (Agent-to-Agent) protocol messaging and status."""
|
||||
|
||||
from string import Template
|
||||
from typing import Final
|
||||
|
||||
|
||||
AVAILABLE_AGENTS_TEMPLATE: Final[Template] = Template(
|
||||
"\n<AVAILABLE_A2A_AGENTS>\n $available_a2a_agents\n</AVAILABLE_A2A_AGENTS>\n"
|
||||
)
|
||||
PREVIOUS_A2A_CONVERSATION_TEMPLATE: Final[Template] = Template(
|
||||
"\n<PREVIOUS_A2A_CONVERSATION>\n"
|
||||
" $previous_a2a_conversation"
|
||||
"\n</PREVIOUS_A2A_CONVERSATION>\n"
|
||||
)
|
||||
CONVERSATION_TURN_INFO_TEMPLATE: Final[Template] = Template(
|
||||
"\n<CONVERSATION_PROGRESS>\n"
|
||||
' turn="$turn_count"\n'
|
||||
' max_turns="$max_turns"\n'
|
||||
" $warning"
|
||||
"\n</CONVERSATION_PROGRESS>\n"
|
||||
)
|
||||
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE: Final[Template] = Template(
|
||||
"\n<A2A_AGENTS_STATUS>\n"
|
||||
" NOTE: A2A agents were configured but are currently unavailable.\n"
|
||||
" You cannot delegate to remote agents for this task.\n\n"
|
||||
" Unavailable Agents:\n"
|
||||
" $unavailable_agents"
|
||||
"\n</A2A_AGENTS_STATUS>\n"
|
||||
)
|
||||
38
lib/crewai/src/crewai/a2a/types.py
Normal file
38
lib/crewai/src/crewai/a2a/types.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Type definitions for A2A protocol message parts."""
|
||||
|
||||
from typing import Any, Literal, Protocol, TypedDict, runtime_checkable
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentResponseProtocol(Protocol):
|
||||
"""Protocol for the dynamically created AgentResponse model."""
|
||||
|
||||
a2a_ids: tuple[str, ...]
|
||||
message: str
|
||||
is_a2a: bool
|
||||
|
||||
|
||||
class PartsMetadataDict(TypedDict, total=False):
|
||||
"""Metadata for A2A message parts.
|
||||
|
||||
Attributes:
|
||||
mimeType: MIME type for the part content.
|
||||
schema: JSON schema for the part content.
|
||||
"""
|
||||
|
||||
mimeType: Literal["application/json"]
|
||||
schema: dict[str, Any]
|
||||
|
||||
|
||||
class PartsDict(TypedDict):
|
||||
"""A2A message part containing text and optional metadata.
|
||||
|
||||
Attributes:
|
||||
text: The text content of the message part.
|
||||
metadata: Optional metadata describing the part content.
|
||||
"""
|
||||
|
||||
text: str
|
||||
metadata: NotRequired[PartsMetadataDict]
|
||||
755
lib/crewai/src/crewai/a2a/utils.py
Normal file
755
lib/crewai/src/crewai/a2a/utils.py
Normal file
@@ -0,0 +1,755 @@
|
||||
"""Utility functions for A2A (Agent-to-Agent) protocol delegation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, MutableMapping
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import lru_cache
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client, ClientConfig, ClientFactory
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
TransportProtocol,
|
||||
)
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai.a2a.auth.utils import (
|
||||
_auth_store,
|
||||
configure_auth_client,
|
||||
retry_on_401,
|
||||
validate_auth_against_agent_card,
|
||||
)
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.a2a.types import PartsDict, PartsMetadataDict
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConversationStartedEvent,
|
||||
A2ADelegationCompletedEvent,
|
||||
A2ADelegationStartedEvent,
|
||||
A2AMessageSentEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
)
|
||||
from crewai.types.utils import create_literals_from_strings
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Message, Task as A2ATask
|
||||
|
||||
from crewai.a2a.auth.schemas import AuthScheme
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _fetch_agent_card_cached(
|
||||
endpoint: str,
|
||||
auth_hash: int,
|
||||
timeout: int,
|
||||
_ttl_hash: int,
|
||||
) -> AgentCard:
|
||||
"""Cached version of fetch_agent_card with auth support.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL
|
||||
auth_hash: Hash of the auth object
|
||||
timeout: Request timeout
|
||||
_ttl_hash: Time-based hash for cache invalidation (unused in body)
|
||||
|
||||
Returns:
|
||||
Cached AgentCard
|
||||
"""
|
||||
auth = _auth_store.get(auth_hash)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
_fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def fetch_agent_card(
|
||||
endpoint: str,
|
||||
auth: AuthScheme | None = None,
|
||||
timeout: int = 30,
|
||||
use_cache: bool = True,
|
||||
cache_ttl: int = 300,
|
||||
) -> AgentCard:
|
||||
"""Fetch AgentCard from an A2A endpoint with optional caching.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL)
|
||||
auth: Optional AuthScheme for authentication
|
||||
timeout: Request timeout in seconds
|
||||
use_cache: Whether to use caching (default True)
|
||||
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes)
|
||||
|
||||
Returns:
|
||||
AgentCard object with agent capabilities and skills
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails
|
||||
A2AClientHTTPError: If authentication fails
|
||||
"""
|
||||
if use_cache:
|
||||
auth_hash = hash((type(auth).__name__, id(auth))) if auth else 0
|
||||
_auth_store[auth_hash] = auth
|
||||
ttl_hash = int(time.time() // cache_ttl)
|
||||
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
_fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _fetch_agent_card_async(
|
||||
endpoint: str,
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Async implementation of AgentCard fetching.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL
|
||||
auth: Optional AuthScheme for authentication
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
AgentCard object
|
||||
"""
|
||||
if "/.well-known/agent-card.json" in endpoint:
|
||||
base_url = endpoint.replace("/.well-known/agent-card.json", "")
|
||||
agent_card_path = "/.well-known/agent-card.json"
|
||||
else:
|
||||
url_parts = endpoint.split("/", 3)
|
||||
base_url = f"{url_parts[0]}//{url_parts[2]}"
|
||||
agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/"
|
||||
|
||||
headers: MutableMapping[str, str] = {}
|
||||
if auth:
|
||||
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
|
||||
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_auth_client)
|
||||
headers = await auth.apply_auth(temp_auth_client, {})
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client:
|
||||
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_client)
|
||||
|
||||
agent_card_url = f"{base_url}{agent_card_path}"
|
||||
|
||||
async def _fetch_agent_card_request() -> httpx.Response:
|
||||
return await temp_client.get(agent_card_url)
|
||||
|
||||
try:
|
||||
response = await retry_on_401(
|
||||
request_func=_fetch_agent_card_request,
|
||||
auth_scheme=auth,
|
||||
client=temp_client,
|
||||
headers=temp_client.headers,
|
||||
max_retries=2,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return AgentCard.model_validate(response.json())
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
error_details = ["Authentication failed"]
|
||||
www_auth = e.response.headers.get("WWW-Authenticate")
|
||||
if www_auth:
|
||||
error_details.append(f"WWW-Authenticate: {www_auth}")
|
||||
if not auth:
|
||||
error_details.append("No auth scheme provided")
|
||||
msg = " | ".join(error_details)
|
||||
raise A2AClientHTTPError(401, msg) from e
|
||||
raise
|
||||
|
||||
|
||||
def execute_a2a_delegation(
|
||||
endpoint: str,
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None = None,
|
||||
context_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
reference_task_ids: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
conversation_history: list[Message] | None = None,
|
||||
agent_id: str | None = None,
|
||||
agent_role: Role | None = None,
|
||||
agent_branch: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
turn_number: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a task delegation to a remote A2A agent with multi-turn support.
|
||||
|
||||
Handles:
|
||||
- AgentCard discovery
|
||||
- Authentication setup
|
||||
- Message creation and sending
|
||||
- Response parsing
|
||||
- Multi-turn conversations
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL)
|
||||
auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest)
|
||||
timeout: Request timeout in seconds
|
||||
task_description: The task to delegate
|
||||
context: Optional context information
|
||||
context_id: Context ID for correlating messages/tasks
|
||||
task_id: Specific task identifier
|
||||
reference_task_ids: List of related task IDs
|
||||
metadata: Additional metadata (external_id, request_id, etc.)
|
||||
extensions: Protocol extensions for custom fields
|
||||
conversation_history: Previous Message objects from conversation
|
||||
agent_id: Agent identifier for logging
|
||||
agent_role: Role of the CrewAI agent delegating the task
|
||||
agent_branch: Optional agent tree branch for logging
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
turn_number: Optional turn number for multi-turn conversations
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "completed", "input_required", "failed", etc.
|
||||
- result: Result string (if completed)
|
||||
- error: Error message (if failed)
|
||||
- history: List of new Message objects from this exchange
|
||||
|
||||
Raises:
|
||||
ImportError: If a2a-sdk is not installed
|
||||
"""
|
||||
is_multiturn = bool(conversation_history and len(conversation_history) > 0)
|
||||
if turn_number is None:
|
||||
turn_number = (
|
||||
len([m for m in (conversation_history or []) if m.role == Role.user]) + 1
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationStartedEvent(
|
||||
endpoint=endpoint,
|
||||
task_description=task_description,
|
||||
agent_id=agent_id,
|
||||
is_multiturn=is_multiturn,
|
||||
turn_number=turn_number,
|
||||
),
|
||||
)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
_execute_a2a_delegation_async(
|
||||
endpoint=endpoint,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
task_description=task_description,
|
||||
context=context,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=extensions,
|
||||
conversation_history=conversation_history or [],
|
||||
is_multiturn=is_multiturn,
|
||||
turn_number=turn_number,
|
||||
agent_branch=agent_branch,
|
||||
agent_id=agent_id,
|
||||
agent_role=agent_role,
|
||||
response_model=response_model,
|
||||
)
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationCompletedEvent(
|
||||
status=result["status"],
|
||||
result=result.get("result"),
|
||||
error=result.get("error"),
|
||||
is_multiturn=is_multiturn,
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _execute_a2a_delegation_async(
|
||||
endpoint: str,
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None,
|
||||
context_id: str | None,
|
||||
task_id: str | None,
|
||||
reference_task_ids: list[str] | None,
|
||||
metadata: dict[str, Any] | None,
|
||||
extensions: dict[str, Any] | None,
|
||||
conversation_history: list[Message],
|
||||
is_multiturn: bool = False,
|
||||
turn_number: int = 1,
|
||||
agent_branch: Any | None = None,
|
||||
agent_id: str | None = None,
|
||||
agent_role: str | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Async implementation of A2A delegation with multi-turn support.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL
|
||||
auth: Optional AuthScheme for authentication
|
||||
timeout: Request timeout in seconds
|
||||
task_description: Task to delegate
|
||||
context: Optional context
|
||||
context_id: Context ID for correlation
|
||||
task_id: Specific task identifier
|
||||
reference_task_ids: Related task IDs
|
||||
metadata: Additional metadata
|
||||
extensions: Protocol extensions
|
||||
conversation_history: Previous Message objects
|
||||
is_multiturn: Whether this is a multi-turn conversation
|
||||
turn_number: Current turn number
|
||||
agent_branch: Agent tree branch for logging
|
||||
agent_id: Agent identifier for logging
|
||||
agent_role: Agent role for logging
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and new history
|
||||
"""
|
||||
agent_card = await _fetch_agent_card_async(endpoint, auth, timeout)
|
||||
|
||||
validate_auth_against_agent_card(agent_card, auth)
|
||||
|
||||
headers: MutableMapping[str, str] = {}
|
||||
if auth:
|
||||
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
|
||||
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_auth_client)
|
||||
headers = await auth.apply_auth(temp_auth_client, {})
|
||||
|
||||
a2a_agent_name = None
|
||||
if agent_card.name:
|
||||
a2a_agent_name = agent_card.name
|
||||
|
||||
if turn_number == 1:
|
||||
agent_id_for_event = agent_id or endpoint
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConversationStartedEvent(
|
||||
agent_id=agent_id_for_event,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
),
|
||||
)
|
||||
|
||||
message_parts = []
|
||||
|
||||
if context:
|
||||
message_parts.append(f"Context:\n{context}\n\n")
|
||||
message_parts.append(f"{task_description}")
|
||||
message_text = "".join(message_parts)
|
||||
|
||||
if is_multiturn and conversation_history and not task_id:
|
||||
if first_task_id := conversation_history[0].task_id:
|
||||
task_id = first_task_id
|
||||
|
||||
parts: PartsDict = {"text": message_text}
|
||||
if response_model:
|
||||
parts.update(
|
||||
{
|
||||
"metadata": PartsMetadataDict(
|
||||
mimeType="application/json",
|
||||
schema=response_model.model_json_schema(),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
message = Message(
|
||||
role=Role.user,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(**parts))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
transport_protocol = TransportProtocol("JSONRPC")
|
||||
new_messages: list[Message] = [*conversation_history, message]
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AMessageSentEvent(
|
||||
message=message_text,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
),
|
||||
)
|
||||
|
||||
async with _create_a2a_client(
|
||||
agent_card=agent_card,
|
||||
transport_protocol=transport_protocol,
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
streaming=True,
|
||||
auth=auth,
|
||||
) as client:
|
||||
result_parts: list[str] = []
|
||||
final_result: dict[str, Any] | None = None
|
||||
event_stream = client.send_message(message)
|
||||
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
for part in event.parts:
|
||||
if part.root.kind == "text":
|
||||
text = part.root.text
|
||||
result_parts.append(text)
|
||||
|
||||
elif isinstance(event, tuple):
|
||||
a2a_task, update = event
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
is_final_update = False
|
||||
if isinstance(update, TaskStatusUpdateEvent):
|
||||
is_final_update = update.final
|
||||
|
||||
if not is_final_update and a2a_task.status.state not in [
|
||||
TaskState.completed,
|
||||
TaskState.input_required,
|
||||
TaskState.failed,
|
||||
TaskState.rejected,
|
||||
TaskState.auth_required,
|
||||
TaskState.canceled,
|
||||
]:
|
||||
continue
|
||||
|
||||
if a2a_task.status.state == TaskState.completed:
|
||||
extracted_parts = _extract_task_result_parts(a2a_task)
|
||||
result_parts.extend(extracted_parts)
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
status="completed",
|
||||
agent_role=agent_role,
|
||||
),
|
||||
)
|
||||
|
||||
final_result = {
|
||||
"status": "completed",
|
||||
"result": response_text,
|
||||
"history": new_messages,
|
||||
"agent_card": agent_card,
|
||||
}
|
||||
break
|
||||
|
||||
if a2a_task.status.state == TaskState.input_required:
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
|
||||
response_text = _extract_error_message(
|
||||
a2a_task, "Additional input required"
|
||||
)
|
||||
if response_text and not a2a_task.history:
|
||||
agent_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=response_text))],
|
||||
context_id=a2a_task.context_id
|
||||
if hasattr(a2a_task, "context_id")
|
||||
else None,
|
||||
task_id=a2a_task.task_id
|
||||
if hasattr(a2a_task, "task_id")
|
||||
else None,
|
||||
)
|
||||
new_messages.append(agent_message)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
status="input_required",
|
||||
agent_role=agent_role,
|
||||
),
|
||||
)
|
||||
|
||||
final_result = {
|
||||
"status": "input_required",
|
||||
"error": response_text,
|
||||
"history": new_messages,
|
||||
"agent_card": agent_card,
|
||||
}
|
||||
break
|
||||
|
||||
if a2a_task.status.state in [TaskState.failed, TaskState.rejected]:
|
||||
error_msg = _extract_error_message(
|
||||
a2a_task, "Task failed without error message"
|
||||
)
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
final_result = {
|
||||
"status": "failed",
|
||||
"error": error_msg,
|
||||
"history": new_messages,
|
||||
}
|
||||
break
|
||||
|
||||
if a2a_task.status.state == TaskState.auth_required:
|
||||
error_msg = _extract_error_message(
|
||||
a2a_task, "Authentication required"
|
||||
)
|
||||
final_result = {
|
||||
"status": "auth_required",
|
||||
"error": error_msg,
|
||||
"history": new_messages,
|
||||
}
|
||||
break
|
||||
|
||||
if a2a_task.status.state == TaskState.canceled:
|
||||
error_msg = _extract_error_message(
|
||||
a2a_task, "Task was canceled"
|
||||
)
|
||||
final_result = {
|
||||
"status": "canceled",
|
||||
"error": error_msg,
|
||||
"history": new_messages,
|
||||
}
|
||||
break
|
||||
except Exception as e:
|
||||
current_exception: Exception | BaseException | None = e
|
||||
while current_exception:
|
||||
if hasattr(current_exception, "response"):
|
||||
response = current_exception.response
|
||||
if hasattr(response, "text"):
|
||||
break
|
||||
if current_exception and hasattr(current_exception, "__cause__"):
|
||||
current_exception = current_exception.__cause__
|
||||
raise
|
||||
finally:
|
||||
if hasattr(event_stream, "aclose"):
|
||||
await event_stream.aclose()
|
||||
|
||||
if final_result:
|
||||
return final_result
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"result": " ".join(result_parts) if result_parts else "",
|
||||
"history": new_messages,
|
||||
}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_a2a_client(
|
||||
agent_card: AgentCard,
|
||||
transport_protocol: TransportProtocol,
|
||||
timeout: int,
|
||||
headers: MutableMapping[str, str],
|
||||
streaming: bool,
|
||||
auth: AuthScheme | None = None,
|
||||
) -> AsyncIterator[Client]:
|
||||
"""Create and configure an A2A client.
|
||||
|
||||
Args:
|
||||
agent_card: The A2A agent card
|
||||
transport_protocol: Transport protocol to use
|
||||
timeout: Request timeout in seconds
|
||||
headers: HTTP headers (already with auth applied)
|
||||
streaming: Enable streaming responses
|
||||
auth: Optional AuthScheme for client configuration
|
||||
|
||||
Yields:
|
||||
Configured A2A client instance
|
||||
"""
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
) as httpx_client:
|
||||
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, httpx_client)
|
||||
|
||||
config = ClientConfig(
|
||||
httpx_client=httpx_client,
|
||||
supported_transports=[str(transport_protocol.value)],
|
||||
streaming=streaming,
|
||||
accepted_output_modes=["application/json"],
|
||||
)
|
||||
|
||||
factory = ClientFactory(config)
|
||||
client = factory.create(agent_card)
|
||||
yield client
|
||||
|
||||
|
||||
def _extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
|
||||
"""Extract result parts from A2A task history and artifacts.
|
||||
|
||||
Args:
|
||||
a2a_task: A2A Task object with history and artifacts
|
||||
|
||||
Returns:
|
||||
List of result text parts
|
||||
"""
|
||||
|
||||
result_parts: list[str] = []
|
||||
|
||||
if a2a_task.history:
|
||||
for history_msg in reversed(a2a_task.history):
|
||||
if history_msg.role == Role.agent:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in history_msg.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
break
|
||||
|
||||
if a2a_task.artifacts:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for artifact in a2a_task.artifacts
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
return result_parts
|
||||
|
||||
|
||||
def _extract_error_message(a2a_task: A2ATask, default: str) -> str:
|
||||
"""Extract error message from A2A task.
|
||||
|
||||
Args:
|
||||
a2a_task: A2A Task object
|
||||
default: Default message if no error found
|
||||
|
||||
Returns:
|
||||
Error message string
|
||||
"""
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
msg = a2a_task.status.message
|
||||
if msg:
|
||||
for part in msg.parts:
|
||||
if part.root.kind == "text":
|
||||
return str(part.root.text)
|
||||
return str(msg)
|
||||
|
||||
if a2a_task.history:
|
||||
for history_msg in reversed(a2a_task.history):
|
||||
for part in history_msg.parts:
|
||||
if part.root.kind == "text":
|
||||
return str(part.root.text)
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel]:
|
||||
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
|
||||
|
||||
Args:
|
||||
agent_ids: List of available A2A agent IDs
|
||||
|
||||
Returns:
|
||||
Dynamically created Pydantic model with Literal-constrained a2a_ids field
|
||||
"""
|
||||
|
||||
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
|
||||
|
||||
return create_model(
|
||||
"AgentResponse",
|
||||
a2a_ids=(
|
||||
tuple[DynamicLiteral, ...], # type: ignore[valid-type]
|
||||
Field(
|
||||
default_factory=tuple,
|
||||
max_length=len(agent_ids),
|
||||
description="A2A agent IDs to delegate to.",
|
||||
),
|
||||
),
|
||||
message=(
|
||||
str,
|
||||
Field(
|
||||
description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
|
||||
),
|
||||
),
|
||||
is_a2a=(
|
||||
bool,
|
||||
Field(
|
||||
description="Set to true to continue the conversation by sending this message to the A2A agent and awaiting their response. Set to false ONLY when you are completely done and providing your final answer (not when asking questions)."
|
||||
),
|
||||
),
|
||||
__base__=BaseModel,
|
||||
)
|
||||
|
||||
|
||||
def extract_a2a_agent_ids_from_config(
|
||||
a2a_config: list[A2AConfig] | A2AConfig | None,
|
||||
) -> tuple[list[A2AConfig], tuple[str, ...]]:
|
||||
"""Extract A2A agent IDs from A2A configuration.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration
|
||||
|
||||
Returns:
|
||||
List of A2A agent IDs
|
||||
"""
|
||||
if a2a_config is None:
|
||||
return [], ()
|
||||
|
||||
if isinstance(a2a_config, A2AConfig):
|
||||
a2a_agents = [a2a_config]
|
||||
else:
|
||||
a2a_agents = a2a_config
|
||||
return a2a_agents, tuple(config.endpoint for config in a2a_agents)
|
||||
|
||||
|
||||
def get_a2a_agents_and_response_model(
|
||||
a2a_config: list[A2AConfig] | A2AConfig | None,
|
||||
) -> tuple[list[A2AConfig], type[BaseModel]]:
|
||||
"""Get A2A agent IDs and response model.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration
|
||||
|
||||
Returns:
|
||||
Tuple of A2A agent IDs and response model
|
||||
"""
|
||||
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
|
||||
return a2a_agents, create_agent_response_model(agent_ids)
|
||||
570
lib/crewai/src/crewai/a2a/wrapper.py
Normal file
570
lib/crewai/src/crewai/a2a/wrapper.py
Normal file
@@ -0,0 +1,570 @@
|
||||
"""A2A agent wrapping logic for metaclass integration.
|
||||
|
||||
Wraps agent classes with A2A delegation capabilities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from functools import wraps
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from a2a.types import Role
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.a2a.templates import (
|
||||
AVAILABLE_AGENTS_TEMPLATE,
|
||||
CONVERSATION_TURN_INFO_TEMPLATE,
|
||||
PREVIOUS_A2A_CONVERSATION_TEMPLATE,
|
||||
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE,
|
||||
)
|
||||
from crewai.a2a.types import AgentResponseProtocol
|
||||
from crewai.a2a.utils import (
|
||||
execute_a2a_delegation,
|
||||
fetch_agent_card,
|
||||
get_a2a_agents_and_response_model,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConversationCompletedEvent,
|
||||
A2AMessageSentEvent,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard, Message
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
def wrap_agent_with_a2a_instance(agent: Agent) -> None:
|
||||
"""Wrap an agent instance's execute_task method with A2A support.
|
||||
|
||||
This function modifies the agent instance by wrapping its execute_task
|
||||
method to add A2A delegation capabilities. Should only be called when
|
||||
the agent has a2a configuration set.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to wrap
|
||||
"""
|
||||
original_execute_task = agent.execute_task.__func__
|
||||
|
||||
@wraps(original_execute_task)
|
||||
def execute_task_with_a2a(
|
||||
self: Agent,
|
||||
task: Task,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
"""Execute task with A2A delegation support.
|
||||
|
||||
Args:
|
||||
self: The agent instance
|
||||
task: The task to execute
|
||||
context: Optional context for task execution
|
||||
tools: Optional tools available to the agent
|
||||
|
||||
Returns:
|
||||
Task execution result
|
||||
"""
|
||||
if not self.a2a:
|
||||
return original_execute_task(self, task, context, tools)
|
||||
|
||||
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
|
||||
|
||||
return _execute_task_with_a2a(
|
||||
self=self,
|
||||
a2a_agents=a2a_agents,
|
||||
original_fn=original_execute_task,
|
||||
task=task,
|
||||
agent_response_model=agent_response_model,
|
||||
context=context,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
object.__setattr__(agent, "execute_task", MethodType(execute_task_with_a2a, agent))
|
||||
|
||||
|
||||
def _fetch_card_from_config(
|
||||
config: A2AConfig,
|
||||
) -> tuple[A2AConfig, AgentCard | Exception]:
|
||||
"""Fetch agent card from A2A config.
|
||||
|
||||
Args:
|
||||
config: A2A configuration
|
||||
|
||||
Returns:
|
||||
Tuple of (config, card or exception)
|
||||
"""
|
||||
try:
|
||||
card = fetch_agent_card(
|
||||
endpoint=config.endpoint,
|
||||
auth=config.auth,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
return config, card
|
||||
except Exception as e:
|
||||
return config, e
|
||||
|
||||
|
||||
def _fetch_agent_cards_concurrently(
|
||||
a2a_agents: list[A2AConfig],
|
||||
) -> tuple[dict[str, AgentCard], dict[str, str]]:
|
||||
"""Fetch agent cards concurrently for multiple A2A agents.
|
||||
|
||||
Args:
|
||||
a2a_agents: List of A2A agent configurations
|
||||
|
||||
Returns:
|
||||
Tuple of (agent_cards dict, failed_agents dict mapping endpoint to error message)
|
||||
"""
|
||||
agent_cards: dict[str, AgentCard] = {}
|
||||
failed_agents: dict[str, str] = {}
|
||||
|
||||
with ThreadPoolExecutor(max_workers=len(a2a_agents)) as executor:
|
||||
futures = {
|
||||
executor.submit(_fetch_card_from_config, config): config
|
||||
for config in a2a_agents
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
config, result = future.result()
|
||||
if isinstance(result, Exception):
|
||||
if config.fail_fast:
|
||||
raise RuntimeError(
|
||||
f"Failed to fetch agent card from {config.endpoint}. "
|
||||
f"Ensure the A2A agent is running and accessible. Error: {result}"
|
||||
) from result
|
||||
failed_agents[config.endpoint] = str(result)
|
||||
else:
|
||||
agent_cards[config.endpoint] = result
|
||||
|
||||
return agent_cards, failed_agents
|
||||
|
||||
|
||||
def _execute_task_with_a2a(
|
||||
self: Agent,
|
||||
a2a_agents: list[A2AConfig],
|
||||
original_fn: Callable[..., str],
|
||||
task: Task,
|
||||
agent_response_model: type[BaseModel],
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
) -> str:
|
||||
"""Wrap execute_task with A2A delegation logic.
|
||||
|
||||
Args:
|
||||
self: The agent instance
|
||||
a2a_agents: Dictionary of A2A agent configurations
|
||||
original_fn: The original execute_task method
|
||||
task: The task to execute
|
||||
context: Optional context for task execution
|
||||
tools: Optional tools available to the agent
|
||||
agent_response_model: Optional agent response model
|
||||
|
||||
Returns:
|
||||
Task execution result (either from LLM or A2A agent)
|
||||
"""
|
||||
original_description: str = task.description
|
||||
original_output_pydantic = task.output_pydantic
|
||||
original_response_model = task.response_model
|
||||
|
||||
agent_cards, failed_agents = _fetch_agent_cards_concurrently(a2a_agents)
|
||||
|
||||
if not agent_cards and a2a_agents and failed_agents:
|
||||
unavailable_agents_text = ""
|
||||
for endpoint, error in failed_agents.items():
|
||||
unavailable_agents_text += f" - {endpoint}: {error}\n"
|
||||
|
||||
notice = UNAVAILABLE_AGENTS_NOTICE_TEMPLATE.substitute(
|
||||
unavailable_agents=unavailable_agents_text
|
||||
)
|
||||
task.description = f"{original_description}{notice}"
|
||||
|
||||
try:
|
||||
return original_fn(self, task, context, tools)
|
||||
finally:
|
||||
task.description = original_description
|
||||
|
||||
task.description = _augment_prompt_with_a2a(
|
||||
a2a_agents=a2a_agents,
|
||||
task_description=original_description,
|
||||
agent_cards=agent_cards,
|
||||
failed_agents=failed_agents,
|
||||
)
|
||||
task.response_model = agent_response_model
|
||||
|
||||
try:
|
||||
raw_result = original_fn(self, task, context, tools)
|
||||
agent_response = _parse_agent_response(
|
||||
raw_result=raw_result, agent_response_model=agent_response_model
|
||||
)
|
||||
|
||||
if isinstance(agent_response, BaseModel) and isinstance(
|
||||
agent_response, AgentResponseProtocol
|
||||
):
|
||||
if agent_response.is_a2a:
|
||||
return _delegate_to_a2a(
|
||||
self,
|
||||
agent_response=agent_response,
|
||||
task=task,
|
||||
original_fn=original_fn,
|
||||
context=context,
|
||||
tools=tools,
|
||||
agent_cards=agent_cards,
|
||||
original_task_description=original_description,
|
||||
)
|
||||
return str(agent_response.message)
|
||||
|
||||
return raw_result
|
||||
finally:
|
||||
task.description = original_description
|
||||
task.output_pydantic = original_output_pydantic
|
||||
task.response_model = original_response_model
|
||||
|
||||
|
||||
def _augment_prompt_with_a2a(
|
||||
a2a_agents: list[A2AConfig],
|
||||
task_description: str,
|
||||
agent_cards: dict[str, AgentCard],
|
||||
conversation_history: list[Message] | None = None,
|
||||
turn_num: int = 0,
|
||||
max_turns: int | None = None,
|
||||
failed_agents: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Add A2A delegation instructions to prompt.
|
||||
|
||||
Args:
|
||||
a2a_agents: Dictionary of A2A agent configurations
|
||||
task_description: Original task description
|
||||
agent_cards: dictionary mapping agent IDs to AgentCards
|
||||
conversation_history: Previous A2A Messages from conversation
|
||||
turn_num: Current turn number (0-indexed)
|
||||
max_turns: Maximum allowed turns (from config)
|
||||
failed_agents: Dictionary mapping failed agent endpoints to error messages
|
||||
|
||||
Returns:
|
||||
Augmented task description with A2A instructions
|
||||
"""
|
||||
|
||||
if not agent_cards:
|
||||
return task_description
|
||||
|
||||
agents_text = ""
|
||||
|
||||
for config in a2a_agents:
|
||||
if config.endpoint in agent_cards:
|
||||
card = agent_cards[config.endpoint]
|
||||
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
|
||||
|
||||
failed_agents = failed_agents or {}
|
||||
if failed_agents:
|
||||
agents_text += "\n<!-- Unavailable Agents -->\n"
|
||||
for endpoint, error in failed_agents.items():
|
||||
agents_text += f"\n<!-- Agent: {endpoint}\n Status: Unavailable\n Error: {error} -->\n"
|
||||
|
||||
agents_text = AVAILABLE_AGENTS_TEMPLATE.substitute(available_a2a_agents=agents_text)
|
||||
|
||||
history_text = ""
|
||||
if conversation_history:
|
||||
for msg in conversation_history:
|
||||
history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n"
|
||||
|
||||
history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute(
|
||||
previous_a2a_conversation=history_text
|
||||
)
|
||||
turn_info = ""
|
||||
|
||||
if max_turns is not None and conversation_history:
|
||||
turn_count = turn_num + 1
|
||||
warning = ""
|
||||
if turn_count >= max_turns:
|
||||
warning = (
|
||||
"CRITICAL: This is the FINAL turn. You MUST conclude the conversation now.\n"
|
||||
"Set is_a2a=false and provide your final response to complete the task."
|
||||
)
|
||||
elif turn_count == max_turns - 1:
|
||||
warning = "WARNING: Next turn will be the last. Consider wrapping up the conversation."
|
||||
|
||||
turn_info = CONVERSATION_TURN_INFO_TEMPLATE.substitute(
|
||||
turn_count=turn_count,
|
||||
max_turns=max_turns,
|
||||
warning=warning,
|
||||
)
|
||||
|
||||
return f"""{task_description}
|
||||
|
||||
IMPORTANT: You have the ability to delegate this task to remote A2A agents.
|
||||
|
||||
{agents_text}
|
||||
{history_text}{turn_info}
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _parse_agent_response(
|
||||
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
|
||||
) -> BaseModel | str:
|
||||
"""Parse LLM output as AgentResponse or return raw agent response.
|
||||
|
||||
Args:
|
||||
raw_result: Raw output from LLM
|
||||
agent_response_model: The agent response model
|
||||
|
||||
Returns:
|
||||
Parsed AgentResponse or string
|
||||
"""
|
||||
if agent_response_model:
|
||||
try:
|
||||
if isinstance(raw_result, str):
|
||||
return agent_response_model.model_validate_json(raw_result)
|
||||
if isinstance(raw_result, dict):
|
||||
return agent_response_model.model_validate(raw_result)
|
||||
except ValidationError:
|
||||
return cast(str, raw_result)
|
||||
return cast(str, raw_result)
|
||||
|
||||
|
||||
def _handle_agent_response_and_continue(
|
||||
self: Agent,
|
||||
a2a_result: dict[str, Any],
|
||||
agent_id: str,
|
||||
agent_cards: dict[str, AgentCard] | None,
|
||||
a2a_agents: list[A2AConfig],
|
||||
original_task_description: str,
|
||||
conversation_history: list[Message],
|
||||
turn_num: int,
|
||||
max_turns: int,
|
||||
task: Task,
|
||||
original_fn: Callable[..., str],
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
agent_response_model: type[BaseModel],
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Handle A2A result and get CrewAI agent's response.
|
||||
|
||||
Args:
|
||||
self: The agent instance
|
||||
a2a_result: Result from A2A delegation
|
||||
agent_id: ID of the A2A agent
|
||||
agent_cards: Pre-fetched agent cards
|
||||
a2a_agents: List of A2A configurations
|
||||
original_task_description: Original task description
|
||||
conversation_history: Conversation history
|
||||
turn_num: Current turn number
|
||||
max_turns: Maximum turns allowed
|
||||
task: The task being executed
|
||||
original_fn: Original execute_task method
|
||||
context: Optional context
|
||||
tools: Optional tools
|
||||
agent_response_model: Response model for parsing
|
||||
|
||||
Returns:
|
||||
Tuple of (final_result, current_request) where:
|
||||
- final_result is not None if conversation should end
|
||||
- current_request is the next message to send if continuing
|
||||
"""
|
||||
agent_cards_dict = agent_cards or {}
|
||||
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
|
||||
agent_cards_dict[agent_id] = a2a_result["agent_card"]
|
||||
|
||||
task.description = _augment_prompt_with_a2a(
|
||||
a2a_agents=a2a_agents,
|
||||
task_description=original_task_description,
|
||||
conversation_history=conversation_history,
|
||||
turn_num=turn_num,
|
||||
max_turns=max_turns,
|
||||
agent_cards=agent_cards_dict,
|
||||
)
|
||||
|
||||
raw_result = original_fn(self, task, context, tools)
|
||||
llm_response = _parse_agent_response(
|
||||
raw_result=raw_result, agent_response_model=agent_response_model
|
||||
)
|
||||
|
||||
if isinstance(llm_response, BaseModel) and isinstance(
|
||||
llm_response, AgentResponseProtocol
|
||||
):
|
||||
if not llm_response.is_a2a:
|
||||
final_turn_number = turn_num + 1
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AMessageSentEvent(
|
||||
message=str(llm_response.message),
|
||||
turn_number=final_turn_number,
|
||||
is_multiturn=True,
|
||||
agent_role=self.role,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConversationCompletedEvent(
|
||||
status="completed",
|
||||
final_result=str(llm_response.message),
|
||||
error=None,
|
||||
total_turns=final_turn_number,
|
||||
),
|
||||
)
|
||||
return str(llm_response.message), None
|
||||
return None, str(llm_response.message)
|
||||
|
||||
return str(raw_result), None
|
||||
|
||||
|
||||
def _delegate_to_a2a(
|
||||
self: Agent,
|
||||
agent_response: AgentResponseProtocol,
|
||||
task: Task,
|
||||
original_fn: Callable[..., str],
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
agent_cards: dict[str, AgentCard] | None = None,
|
||||
original_task_description: str | None = None,
|
||||
) -> str:
|
||||
"""Delegate to A2A agent with multi-turn conversation support.
|
||||
|
||||
Args:
|
||||
self: The agent instance
|
||||
agent_response: The AgentResponse indicating delegation
|
||||
task: The task being executed (for extracting A2A fields)
|
||||
original_fn: The original execute_task method for follow-ups
|
||||
context: Optional context for task execution
|
||||
tools: Optional tools available to the agent
|
||||
agent_cards: Pre-fetched agent cards from _execute_task_with_a2a
|
||||
original_task_description: The original task description before A2A augmentation
|
||||
|
||||
Returns:
|
||||
Result from A2A agent
|
||||
|
||||
Raises:
|
||||
ImportError: If a2a-sdk is not installed
|
||||
"""
|
||||
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
|
||||
agent_ids = tuple(config.endpoint for config in a2a_agents)
|
||||
current_request = str(agent_response.message)
|
||||
agent_id = agent_response.a2a_ids[0]
|
||||
|
||||
if agent_id not in agent_ids:
|
||||
raise ValueError(
|
||||
f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}"
|
||||
)
|
||||
|
||||
agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents))
|
||||
task_config = task.config or {}
|
||||
context_id = task_config.get("context_id")
|
||||
task_id_config = task_config.get("task_id")
|
||||
reference_task_ids = task_config.get("reference_task_ids")
|
||||
metadata = task_config.get("metadata")
|
||||
extensions = task_config.get("extensions")
|
||||
|
||||
if original_task_description is None:
|
||||
original_task_description = task.description
|
||||
|
||||
conversation_history: list[Message] = []
|
||||
max_turns = agent_config.max_turns
|
||||
|
||||
try:
|
||||
for turn_num in range(max_turns):
|
||||
console_formatter = getattr(crewai_event_bus, "_console", None)
|
||||
agent_branch = None
|
||||
if console_formatter:
|
||||
agent_branch = getattr(
|
||||
console_formatter, "current_agent_branch", None
|
||||
) or getattr(console_formatter, "current_task_branch", None)
|
||||
|
||||
a2a_result = execute_a2a_delegation(
|
||||
endpoint=agent_config.endpoint,
|
||||
auth=agent_config.auth,
|
||||
timeout=agent_config.timeout,
|
||||
task_description=current_request,
|
||||
context_id=context_id,
|
||||
task_id=task_id_config,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=extensions,
|
||||
conversation_history=conversation_history,
|
||||
agent_id=agent_id,
|
||||
agent_role=Role.user,
|
||||
agent_branch=agent_branch,
|
||||
response_model=agent_config.response_model,
|
||||
turn_number=turn_num + 1,
|
||||
)
|
||||
|
||||
conversation_history = a2a_result.get("history", [])
|
||||
|
||||
if a2a_result["status"] in ["completed", "input_required"]:
|
||||
final_result, next_request = _handle_agent_response_and_continue(
|
||||
self=self,
|
||||
a2a_result=a2a_result,
|
||||
agent_id=agent_id,
|
||||
agent_cards=agent_cards,
|
||||
a2a_agents=a2a_agents,
|
||||
original_task_description=original_task_description,
|
||||
conversation_history=conversation_history,
|
||||
turn_num=turn_num,
|
||||
max_turns=max_turns,
|
||||
task=task,
|
||||
original_fn=original_fn,
|
||||
context=context,
|
||||
tools=tools,
|
||||
agent_response_model=agent_response_model,
|
||||
)
|
||||
|
||||
if final_result is not None:
|
||||
return final_result
|
||||
|
||||
if next_request is not None:
|
||||
current_request = next_request
|
||||
|
||||
continue
|
||||
|
||||
error_msg = a2a_result.get("error", "Unknown error")
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConversationCompletedEvent(
|
||||
status="failed",
|
||||
final_result=None,
|
||||
error=error_msg,
|
||||
total_turns=turn_num + 1,
|
||||
),
|
||||
)
|
||||
raise Exception(f"A2A delegation failed: {error_msg}")
|
||||
|
||||
if conversation_history:
|
||||
for msg in reversed(conversation_history):
|
||||
if msg.role == Role.agent:
|
||||
text_parts = [
|
||||
part.root.text for part in msg.parts if part.root.kind == "text"
|
||||
]
|
||||
final_message = (
|
||||
" ".join(text_parts) if text_parts else "Conversation completed"
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConversationCompletedEvent(
|
||||
status="completed",
|
||||
final_result=final_message,
|
||||
error=None,
|
||||
total_turns=max_turns,
|
||||
),
|
||||
)
|
||||
return final_message
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConversationCompletedEvent(
|
||||
status="failed",
|
||||
final_result=None,
|
||||
error=f"Conversation exceeded maximum turns ({max_turns})",
|
||||
total_turns=max_turns,
|
||||
),
|
||||
)
|
||||
raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})")
|
||||
|
||||
finally:
|
||||
task.description = original_task_description
|
||||
5
lib/crewai/src/crewai/agent/__init__.py
Normal file
5
lib/crewai/src/crewai/agent/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
|
||||
__all__ = ["Agent", "CrewTrainingHandler"]
|
||||
@@ -2,27 +2,27 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Final,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
@@ -40,6 +40,16 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.mcp import (
|
||||
MCPClient,
|
||||
MCPServerConfig,
|
||||
MCPServerHTTP,
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
@@ -70,14 +80,14 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
# MCP Connection timeout constants (in seconds)
|
||||
MCP_CONNECTION_TIMEOUT = 10
|
||||
MCP_TOOL_EXECUTION_TIMEOUT = 30
|
||||
MCP_DISCOVERY_TIMEOUT = 15
|
||||
MCP_MAX_RETRIES = 3
|
||||
MCP_CONNECTION_TIMEOUT: Final[int] = 10
|
||||
MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30
|
||||
MCP_DISCOVERY_TIMEOUT: Final[int] = 15
|
||||
MCP_MAX_RETRIES: Final[int] = 3
|
||||
|
||||
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
||||
_mcp_schema_cache = {}
|
||||
_cache_ttl = 300 # 5 minutes
|
||||
_mcp_schema_cache: dict[str, Any] = {}
|
||||
_cache_ttl: Final[int] = 300 # 5 minutes
|
||||
|
||||
|
||||
class Agent(BaseAgent):
|
||||
@@ -108,6 +118,7 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
_mcp_clients: list[Any] = PrivateAttr(default_factory=list)
|
||||
max_execution_time: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum execution time for an agent to execute a task",
|
||||
@@ -197,6 +208,10 @@ class Agent(BaseAgent):
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
a2a: list[A2AConfig] | A2AConfig | None = Field(
|
||||
default=None,
|
||||
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. Can be a single A2AConfig or a dict mapping agent IDs to configs.",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_from_repository(cls, v: Any) -> dict[str, Any] | None | Any: # noqa: N805
|
||||
@@ -305,17 +320,19 @@ class Agent(BaseAgent):
|
||||
# If the task requires output in JSON or Pydantic format,
|
||||
# append specific instructions to the task prompt to ensure
|
||||
# that the final answer does not include any code block markers
|
||||
if task.output_json or task.output_pydantic:
|
||||
# Skip this if task.response_model is set, as native structured outputs handle schema automatically
|
||||
if (task.output_json or task.output_pydantic) and not task.response_model:
|
||||
# Generate the schema based on the output format
|
||||
if task.output_json:
|
||||
# schema = json.dumps(task.output_json, indent=2)
|
||||
schema = generate_model_description(task.output_json)
|
||||
schema_dict = generate_model_description(task.output_json)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
elif task.output_pydantic:
|
||||
schema = generate_model_description(task.output_pydantic)
|
||||
schema_dict = generate_model_description(task.output_pydantic)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
@@ -438,6 +455,13 @@ class Agent(BaseAgent):
|
||||
else:
|
||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
# Import agent events locally to avoid circular imports
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -513,6 +537,9 @@ class Agent(BaseAgent):
|
||||
self,
|
||||
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||
)
|
||||
|
||||
self._cleanup_mcp_clients()
|
||||
|
||||
return result
|
||||
|
||||
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any:
|
||||
@@ -618,6 +645,7 @@ class Agent(BaseAgent):
|
||||
self._rpm_controller.check_or_wait if self._rpm_controller else None
|
||||
),
|
||||
callbacks=[TokenCalcHandler(self._token_process)],
|
||||
response_model=task.response_model if task else None,
|
||||
)
|
||||
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
@@ -635,30 +663,70 @@ class Agent(BaseAgent):
|
||||
self._logger.log("error", f"Error getting platform tools: {e!s}")
|
||||
return []
|
||||
|
||||
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
|
||||
"""Convert MCP server references to CrewAI tools."""
|
||||
def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
|
||||
"""Convert MCP server references/configs to CrewAI tools.
|
||||
|
||||
Supports both string references (backwards compatible) and structured
|
||||
configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE).
|
||||
|
||||
Args:
|
||||
mcps: List of MCP server references (strings) or configurations.
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances from MCP servers.
|
||||
"""
|
||||
all_tools = []
|
||||
clients = []
|
||||
|
||||
for mcp_ref in mcps:
|
||||
try:
|
||||
if mcp_ref.startswith("crewai-amp:"):
|
||||
tools = self._get_amp_mcp_tools(mcp_ref)
|
||||
elif mcp_ref.startswith("https://"):
|
||||
tools = self._get_external_mcp_tools(mcp_ref)
|
||||
else:
|
||||
continue
|
||||
for mcp_config in mcps:
|
||||
if isinstance(mcp_config, str):
|
||||
tools = self._get_mcp_tools_from_string(mcp_config)
|
||||
else:
|
||||
tools, client = self._get_native_mcp_tools(mcp_config)
|
||||
if client:
|
||||
clients.append(client)
|
||||
|
||||
all_tools.extend(tools)
|
||||
self._logger.log(
|
||||
"info", f"Successfully loaded {len(tools)} tools from {mcp_ref}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log("warning", f"Skipping MCP {mcp_ref} due to error: {e}")
|
||||
continue
|
||||
all_tools.extend(tools)
|
||||
|
||||
# Store clients for cleanup
|
||||
self._mcp_clients.extend(clients)
|
||||
return all_tools
|
||||
|
||||
def _cleanup_mcp_clients(self) -> None:
|
||||
"""Cleanup MCP client connections after task execution."""
|
||||
if not self._mcp_clients:
|
||||
return
|
||||
|
||||
async def _disconnect_all() -> None:
|
||||
for client in self._mcp_clients:
|
||||
if client and hasattr(client, "connected") and client.connected:
|
||||
await client.disconnect()
|
||||
|
||||
try:
|
||||
asyncio.run(_disconnect_all())
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during MCP client cleanup: {e}")
|
||||
finally:
|
||||
self._mcp_clients.clear()
|
||||
|
||||
def _get_mcp_tools_from_string(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from legacy string-based MCP references.
|
||||
|
||||
This method maintains backwards compatibility with string-based
|
||||
MCP references (https://... and crewai-amp:...).
|
||||
|
||||
Args:
|
||||
mcp_ref: String reference to MCP server.
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances.
|
||||
"""
|
||||
if mcp_ref.startswith("crewai-amp:"):
|
||||
return self._get_amp_mcp_tools(mcp_ref)
|
||||
if mcp_ref.startswith("https://"):
|
||||
return self._get_external_mcp_tools(mcp_ref)
|
||||
return []
|
||||
|
||||
def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from external HTTPS MCP server with graceful error handling."""
|
||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||
@@ -709,7 +777,7 @@ class Agent(BaseAgent):
|
||||
f"Specific tool '{specific_tool}' not found on MCP server: {server_url}",
|
||||
)
|
||||
|
||||
return tools
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
@@ -717,6 +785,164 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return []
|
||||
|
||||
def _get_native_mcp_tools(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], Any | None]:
|
||||
"""Get tools from MCP server using structured configuration.
|
||||
|
||||
This method creates an MCP client based on the configuration type,
|
||||
connects to the server, discovers tools, applies filtering, and
|
||||
returns wrapped tools along with the client instance for cleanup.
|
||||
|
||||
Args:
|
||||
mcp_config: MCP server configuration (MCPServerStdio, MCPServerHTTP, or MCPServerSSE).
|
||||
|
||||
Returns:
|
||||
Tuple of (list of BaseTool instances, MCPClient instance for cleanup).
|
||||
"""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
args=mcp_config.args,
|
||||
env=mcp_config.env,
|
||||
)
|
||||
server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}"
|
||||
elif isinstance(mcp_config, MCPServerHTTP):
|
||||
transport = HTTPTransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
streamable=mcp_config.streamable,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
elif isinstance(mcp_config, MCPServerSSE):
|
||||
transport = SSETransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
|
||||
|
||||
client = MCPClient(
|
||||
transport=transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
|
||||
"""Async helper to connect and list tools in same event loop."""
|
||||
|
||||
try:
|
||||
if not client.connected:
|
||||
await client.connect()
|
||||
|
||||
tools_list = await client.list_tools()
|
||||
|
||||
try:
|
||||
await client.disconnect()
|
||||
# Small delay to allow background tasks to finish cleanup
|
||||
# This helps prevent "cancel scope in different task" errors
|
||||
# when asyncio.run() closes the event loop
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during disconnect: {e}")
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(
|
||||
f"Error during setup client and list tools: {e}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, _setup_client_and_list_tools()
|
||||
)
|
||||
tools_list = future.result()
|
||||
except RuntimeError:
|
||||
try:
|
||||
tools_list = asyncio.run(_setup_client_and_list_tools())
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e).lower()
|
||||
if "cancel scope" in error_msg or "task" in error_msg:
|
||||
raise ConnectionError(
|
||||
"MCP connection failed due to event loop cleanup issues. "
|
||||
"This may be due to authentication errors or server unavailability."
|
||||
) from e
|
||||
except asyncio.CancelledError as e:
|
||||
raise ConnectionError(
|
||||
"MCP connection was cancelled. This may indicate an authentication "
|
||||
"error or server unavailability."
|
||||
) from e
|
||||
|
||||
if mcp_config.tool_filter:
|
||||
filtered_tools = []
|
||||
for tool in tools_list:
|
||||
if callable(mcp_config.tool_filter):
|
||||
try:
|
||||
from crewai.mcp.filters import ToolFilterContext
|
||||
|
||||
context = ToolFilterContext(
|
||||
agent=self,
|
||||
server_name=server_name,
|
||||
run_context=None,
|
||||
)
|
||||
if mcp_config.tool_filter(context, tool):
|
||||
filtered_tools.append(tool)
|
||||
except (TypeError, AttributeError):
|
||||
if mcp_config.tool_filter(tool):
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
# Not callable - include tool
|
||||
filtered_tools.append(tool)
|
||||
tools_list = filtered_tools
|
||||
|
||||
tools = []
|
||||
for tool_def in tools_list:
|
||||
tool_name = tool_def.get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
# Convert inputSchema to Pydantic model if present
|
||||
args_schema = None
|
||||
if tool_def.get("inputSchema"):
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
tool_name, tool_def["inputSchema"]
|
||||
)
|
||||
|
||||
tool_schema = {
|
||||
"description": tool_def.get("description", ""),
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
|
||||
try:
|
||||
native_tool = MCPNativeTool(
|
||||
mcp_client=client,
|
||||
tool_name=tool_name,
|
||||
tool_schema=tool_schema,
|
||||
server_name=server_name,
|
||||
)
|
||||
tools.append(native_tool)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Failed to create native MCP tool: {e}")
|
||||
continue
|
||||
|
||||
return cast(list[BaseTool], tools), client
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
asyncio.run(client.disconnect())
|
||||
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from CrewAI AMP MCP marketplace."""
|
||||
# Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"
|
||||
@@ -739,9 +965,9 @@ class Agent(BaseAgent):
|
||||
|
||||
return tools
|
||||
|
||||
def _extract_server_name(self, server_url: str) -> str:
|
||||
@staticmethod
|
||||
def _extract_server_name(server_url: str) -> str:
|
||||
"""Extract clean server name from URL for tool prefixing."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(server_url)
|
||||
domain = parsed.netloc.replace(".", "_")
|
||||
@@ -778,7 +1004,9 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return {}
|
||||
|
||||
async def _get_mcp_tool_schemas_async(self, server_params: dict) -> dict[str, dict]:
|
||||
async def _get_mcp_tool_schemas_async(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict]:
|
||||
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
|
||||
server_url = server_params["url"]
|
||||
return await self._retry_mcp_discovery(
|
||||
@@ -787,7 +1015,7 @@ class Agent(BaseAgent):
|
||||
|
||||
async def _retry_mcp_discovery(
|
||||
self, operation_func, server_url: str
|
||||
) -> dict[str, dict]:
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
|
||||
last_error = None
|
||||
|
||||
@@ -815,9 +1043,10 @@ class Agent(BaseAgent):
|
||||
f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _attempt_mcp_discovery(
|
||||
self, operation_func, server_url: str
|
||||
) -> tuple[dict[str, dict] | None, str, bool]:
|
||||
operation_func, server_url: str
|
||||
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
||||
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
|
||||
try:
|
||||
result = await operation_func(server_url)
|
||||
@@ -851,13 +1080,13 @@ class Agent(BaseAgent):
|
||||
|
||||
async def _discover_mcp_tools_with_timeout(
|
||||
self, server_url: str
|
||||
) -> dict[str, dict]:
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Discover MCP tools with timeout wrapper."""
|
||||
return await asyncio.wait_for(
|
||||
self._discover_mcp_tools(server_url), timeout=MCP_DISCOVERY_TIMEOUT
|
||||
)
|
||||
|
||||
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict]:
|
||||
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]:
|
||||
"""Discover tools from MCP server with proper timeout handling."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
@@ -889,7 +1118,9 @@ class Agent(BaseAgent):
|
||||
}
|
||||
return schemas
|
||||
|
||||
def _json_schema_to_pydantic(self, tool_name: str, json_schema: dict) -> type:
|
||||
def _json_schema_to_pydantic(
|
||||
self, tool_name: str, json_schema: dict[str, Any]
|
||||
) -> type:
|
||||
"""Convert JSON Schema to Pydantic model for tool arguments.
|
||||
|
||||
Args:
|
||||
@@ -926,7 +1157,7 @@ class Agent(BaseAgent):
|
||||
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
||||
return create_model(model_name, **field_definitions)
|
||||
|
||||
def _json_type_to_python(self, field_schema: dict) -> type:
|
||||
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
|
||||
"""Convert JSON Schema type to Python type.
|
||||
|
||||
Args:
|
||||
@@ -935,7 +1166,6 @@ class Agent(BaseAgent):
|
||||
Returns:
|
||||
Python type
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
json_type = field_schema.get("type")
|
||||
|
||||
@@ -965,13 +1195,15 @@ class Agent(BaseAgent):
|
||||
|
||||
return type_mapping.get(json_type, Any)
|
||||
|
||||
def _fetch_amp_mcp_servers(self, mcp_name: str) -> list[dict]:
|
||||
@staticmethod
|
||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
|
||||
"""Fetch MCP server configurations from CrewAI AMP API."""
|
||||
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
||||
# Should return list of server configs with URLs
|
||||
return []
|
||||
|
||||
def get_multimodal_tools(self) -> Sequence[BaseTool]:
|
||||
@staticmethod
|
||||
def get_multimodal_tools() -> Sequence[BaseTool]:
|
||||
from crewai.tools.agent_tools.add_image_tool import AddImageTool
|
||||
|
||||
return [AddImageTool()]
|
||||
@@ -991,8 +1223,9 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_output_converter(
|
||||
self, llm: BaseLLM, text: str, model: type[BaseModel], instructions: str
|
||||
llm: BaseLLM, text: str, model: type[BaseModel], instructions: str
|
||||
) -> Converter:
|
||||
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||
|
||||
@@ -1022,7 +1255,8 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def _render_text_description(self, tools: list[Any]) -> str:
|
||||
@staticmethod
|
||||
def _render_text_description(tools: list[Any]) -> str:
|
||||
"""Render the tool name and description in plain text.
|
||||
|
||||
Output will be in the format of:
|
||||
0
lib/crewai/src/crewai/agent/internal/__init__.py
Normal file
0
lib/crewai/src/crewai/agent/internal/__init__.py
Normal file
76
lib/crewai/src/crewai/agent/internal/meta.py
Normal file
76
lib/crewai/src/crewai/agent/internal/meta.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Generic metaclass for agent extensions.
|
||||
|
||||
This metaclass enables extension capabilities for agents by detecting
|
||||
extension fields in class annotations and applying appropriate wrappers.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic._internal._model_construction import ModelMetaclass
|
||||
|
||||
|
||||
class AgentMeta(ModelMetaclass):
|
||||
"""Generic metaclass for agent extensions.
|
||||
|
||||
Detects extension fields (like 'a2a') in class annotations and applies
|
||||
the appropriate wrapper logic to enable extension functionality.
|
||||
"""
|
||||
|
||||
def __new__(
|
||||
mcs,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> type:
|
||||
"""Create a new class with extension support.
|
||||
|
||||
Args:
|
||||
name: The name of the class being created
|
||||
bases: Base classes
|
||||
namespace: Class namespace dictionary
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
The newly created class with extension support if applicable
|
||||
"""
|
||||
orig_post_init_setup = namespace.get("post_init_setup")
|
||||
|
||||
if orig_post_init_setup is not None:
|
||||
original_func = (
|
||||
orig_post_init_setup.wrapped
|
||||
if hasattr(orig_post_init_setup, "wrapped")
|
||||
else orig_post_init_setup
|
||||
)
|
||||
|
||||
def post_init_setup_with_extensions(self: Any) -> Any:
|
||||
"""Wrap post_init_setup to apply extensions after initialization.
|
||||
|
||||
Args:
|
||||
self: The agent instance
|
||||
|
||||
Returns:
|
||||
The agent instance
|
||||
"""
|
||||
result = original_func(self)
|
||||
|
||||
a2a_value = getattr(self, "a2a", None)
|
||||
if a2a_value is not None:
|
||||
from crewai.a2a.wrapper import wrap_agent_with_a2a_instance
|
||||
|
||||
wrap_agent_with_a2a_instance(self)
|
||||
|
||||
return result
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", message=".*overrides an existing Pydantic.*"
|
||||
)
|
||||
namespace["post_init_setup"] = model_validator(mode="after")(
|
||||
post_init_setup_with_extensions
|
||||
)
|
||||
|
||||
return super().__new__(mcs, name, bases, namespace, **kwargs)
|
||||
@@ -7,7 +7,7 @@ output conversion for OpenAI agents, supporting JSON and Pydantic model formats.
|
||||
from typing import Any
|
||||
|
||||
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
|
||||
|
||||
class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
@@ -59,7 +59,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
return base_prompt
|
||||
|
||||
output_schema: str = (
|
||||
I18N()
|
||||
get_i18n()
|
||||
.slice("formatted_task_instructions")
|
||||
.format(output_format=self._schema)
|
||||
)
|
||||
|
||||
@@ -18,17 +18,19 @@ from pydantic import (
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agent.internal.meta import AgentMeta
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.mcp.config import MCPServerConfig
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
@@ -56,7 +58,7 @@ PlatformApp = Literal[
|
||||
PlatformAppOrAction = PlatformApp | str
|
||||
|
||||
|
||||
class BaseAgent(BaseModel, ABC):
|
||||
class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
"""Abstract Base Class for all third party agents compatible with CrewAI.
|
||||
|
||||
Attributes:
|
||||
@@ -106,7 +108,7 @@ class BaseAgent(BaseModel, ABC):
|
||||
Set private attributes.
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
__hash__ = object.__hash__
|
||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
||||
_rpm_controller: RPMController | None = PrivateAttr(default=None)
|
||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||
@@ -149,7 +151,7 @@ class BaseAgent(BaseModel, ABC):
|
||||
)
|
||||
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
|
||||
i18n: I18N = Field(
|
||||
default_factory=I18N, description="Internationalization settings."
|
||||
default_factory=get_i18n, description="Internationalization settings."
|
||||
)
|
||||
cache_handler: CacheHandler | None = Field(
|
||||
default=None, description="An instance of the CacheHandler class."
|
||||
@@ -179,8 +181,8 @@ class BaseAgent(BaseModel, ABC):
|
||||
default_factory=SecurityConfig,
|
||||
description="Security configuration for the agent, including fingerprinting.",
|
||||
)
|
||||
callbacks: list[Callable] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
callbacks: list[Callable[[Any], Any]] = Field(
|
||||
default_factory=list, description="Callbacks to be used for the agent"
|
||||
)
|
||||
adapted_agent: bool = Field(
|
||||
default=False, description="Whether the agent is adapted"
|
||||
@@ -193,14 +195,14 @@ class BaseAgent(BaseModel, ABC):
|
||||
default=None,
|
||||
description="List of applications or application/action combinations that the agent can access through CrewAI Platform. Can contain app names (e.g., 'gmail') or specific actions (e.g., 'gmail/send_email')",
|
||||
)
|
||||
mcps: list[str] | None = Field(
|
||||
mcps: list[str | MCPServerConfig] | None = Field(
|
||||
default=None,
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def process_model_config(cls, values):
|
||||
def process_model_config(cls, values: Any) -> dict[str, Any]:
|
||||
return process_config(values, cls)
|
||||
|
||||
@field_validator("tools")
|
||||
@@ -252,23 +254,39 @@ class BaseAgent(BaseModel, ABC):
|
||||
|
||||
@field_validator("mcps")
|
||||
@classmethod
|
||||
def validate_mcps(cls, mcps: list[str] | None) -> list[str] | None:
|
||||
def validate_mcps(
|
||||
cls, mcps: list[str | MCPServerConfig] | None
|
||||
) -> list[str | MCPServerConfig] | None:
|
||||
"""Validate MCP server references and configurations.
|
||||
|
||||
Supports both string references (for backwards compatibility) and
|
||||
structured configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE).
|
||||
"""
|
||||
if not mcps:
|
||||
return mcps
|
||||
|
||||
validated_mcps = []
|
||||
for mcp in mcps:
|
||||
if mcp.startswith(("https://", "crewai-amp:")):
|
||||
if isinstance(mcp, str):
|
||||
if mcp.startswith(("https://", "crewai-amp:")):
|
||||
validated_mcps.append(mcp)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid MCP reference: {mcp}. "
|
||||
"String references must start with 'https://' or 'crewai-amp:'"
|
||||
)
|
||||
|
||||
elif isinstance(mcp, (MCPServerConfig)):
|
||||
validated_mcps.append(mcp)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid MCP reference: {mcp}. Must start with 'https://' or 'crewai-amp:'"
|
||||
f"Invalid MCP configuration: {type(mcp)}. "
|
||||
"Must be a string reference or MCPServerConfig instance."
|
||||
)
|
||||
|
||||
return list(set(validated_mcps))
|
||||
return validated_mcps
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_and_set_attributes(self):
|
||||
def validate_and_set_attributes(self) -> Self:
|
||||
# Validate required fields
|
||||
for field in ["role", "goal", "backstory"]:
|
||||
if getattr(self, field) is None:
|
||||
@@ -300,7 +318,7 @@ class BaseAgent(BaseModel, ABC):
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self):
|
||||
def set_private_attrs(self) -> Self:
|
||||
"""Set private attributes."""
|
||||
self._logger = Logger(verbose=self.verbose)
|
||||
if self.max_rpm and not self._rpm_controller:
|
||||
@@ -312,7 +330,7 @@ class BaseAgent(BaseModel, ABC):
|
||||
return self
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
def key(self) -> str:
|
||||
source = [
|
||||
self._original_role or self.role,
|
||||
self._original_goal or self.goal,
|
||||
@@ -330,7 +348,7 @@ class BaseAgent(BaseModel, ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_agent_executor(self, tools=None) -> None:
|
||||
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -342,7 +360,7 @@ class BaseAgent(BaseModel, ABC):
|
||||
"""Get platform tools for the specified list of applications and/or application/action combinations."""
|
||||
|
||||
@abstractmethod
|
||||
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
|
||||
def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
|
||||
"""Get MCP tools for the specified list of MCP server references."""
|
||||
|
||||
def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
|
||||
@@ -442,5 +460,5 @@ class BaseAgent(BaseModel, ABC):
|
||||
self._rpm_controller = rpm_controller
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
|
||||
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None) -> None:
|
||||
pass
|
||||
|
||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
@@ -37,7 +37,7 @@ from crewai.utilities.agent_utils import (
|
||||
process_llm_response,
|
||||
)
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
@@ -65,7 +65,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM | Any,
|
||||
llm: BaseLLM,
|
||||
task: Task,
|
||||
crew: Crew,
|
||||
agent: Agent,
|
||||
@@ -82,6 +82,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
respect_context_window: bool = False,
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> None:
|
||||
"""Initialize executor.
|
||||
|
||||
@@ -103,8 +104,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
respect_context_window: Respect context limits.
|
||||
request_within_rpm_limit: RPM limit check function.
|
||||
callbacks: Optional callbacks list.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
"""
|
||||
self._i18n: I18N = I18N()
|
||||
self._i18n: I18N = get_i18n()
|
||||
self.llm = llm
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
@@ -119,23 +121,34 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.tools_handler = tools_handler
|
||||
self.original_tools = original_tools or []
|
||||
self.step_callback = step_callback
|
||||
self.use_stop_words = self.llm.supports_stop_words()
|
||||
self.tools_description = tools_description
|
||||
self.function_calling_llm = function_calling_llm
|
||||
self.respect_context_window = respect_context_window
|
||||
self.request_within_rpm_limit = request_within_rpm_limit
|
||||
self.response_model = response_model
|
||||
self.ask_for_human_input = False
|
||||
self.messages: list[LLMMessage] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
existing_stop = getattr(self.llm, "stop", [])
|
||||
self.llm.stop = list(
|
||||
set(
|
||||
existing_stop + self.stop
|
||||
if isinstance(existing_stop, list)
|
||||
else self.stop
|
||||
if self.llm:
|
||||
# This may be mutating the shared llm object and needs further evaluation
|
||||
existing_stop = getattr(self.llm, "stop", [])
|
||||
self.llm.stop = list(
|
||||
set(
|
||||
existing_stop + self.stop
|
||||
if isinstance(existing_stop, list)
|
||||
else self.stop
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def use_stop_words(self) -> bool:
|
||||
"""Check to determine if stop words are being used.
|
||||
|
||||
Returns:
|
||||
bool: True if tool should be used or not.
|
||||
"""
|
||||
return self.llm.supports_stop_words() if self.llm else False
|
||||
|
||||
def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the agent with given inputs.
|
||||
@@ -201,6 +214,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
break
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
@@ -211,8 +225,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
printer=self._printer,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
# Extract agent fingerprint if available
|
||||
@@ -244,11 +259,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
formatted_answer, tool_result
|
||||
)
|
||||
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(formatted_answer.text)
|
||||
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
|
||||
self._append_message(formatted_answer.text) # type: ignore[union-attr,attr-defined]
|
||||
|
||||
except OutputParserError as e: # noqa: PERF203
|
||||
formatted_answer = handle_output_parser_exception(
|
||||
except OutputParserError as e:
|
||||
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
|
||||
e=e,
|
||||
messages=self.messages,
|
||||
iterations=self.iterations,
|
||||
|
||||
@@ -18,10 +18,10 @@ from crewai.agents.constants import (
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
UNABLE_TO_REPAIR_JSON_RESULTS,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
|
||||
|
||||
_I18N = I18N()
|
||||
_I18N = get_i18n()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -3,10 +3,17 @@ import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import BinaryIO, cast
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
else:
|
||||
import fcntl
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
@@ -18,21 +25,74 @@ class TokenManager:
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
@staticmethod
|
||||
def _acquire_lock(file_handle: BinaryIO) -> None:
|
||||
"""
|
||||
Acquire an exclusive lock on a file handle.
|
||||
|
||||
Args:
|
||||
file_handle: Open file handle to lock.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
|
||||
else:
|
||||
fcntl.flock(file_handle.fileno(), fcntl.LOCK_EX)
|
||||
|
||||
@staticmethod
|
||||
def _release_lock(file_handle: BinaryIO) -> None:
|
||||
"""
|
||||
Release the lock on a file handle.
|
||||
|
||||
Args:
|
||||
file_handle: Open file handle to unlock.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
else:
|
||||
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key.
|
||||
Get or create the encryption key with file locking to prevent race conditions.
|
||||
|
||||
:return: The encryption key.
|
||||
Returns:
|
||||
The encryption key.
|
||||
"""
|
||||
key_filename = "secret.key"
|
||||
key = self.read_secure_file(key_filename)
|
||||
storage_path = self.get_secure_storage_path()
|
||||
|
||||
if key is not None:
|
||||
key = self.read_secure_file(key_filename)
|
||||
if key is not None and len(key) == 44:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
lock_file_path = storage_path / f"{key_filename}.lock"
|
||||
|
||||
try:
|
||||
lock_file_path.touch()
|
||||
|
||||
with open(lock_file_path, "r+b") as lock_file:
|
||||
self._acquire_lock(lock_file)
|
||||
try:
|
||||
key = self.read_secure_file(key_filename)
|
||||
if key is not None and len(key) == 44:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
finally:
|
||||
try:
|
||||
self._release_lock(lock_file)
|
||||
except OSError:
|
||||
pass
|
||||
except OSError:
|
||||
key = self.read_secure_file(key_filename)
|
||||
if key is not None and len(key) == 44:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""
|
||||
@@ -59,14 +119,14 @@ class TokenManager:
|
||||
if encrypted_data is None:
|
||||
return None
|
||||
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data)
|
||||
data = json.loads(decrypted_data)
|
||||
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return data["access_token"]
|
||||
return cast(str | None, data["access_token"])
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
"""
|
||||
@@ -74,20 +134,18 @@ class TokenManager:
|
||||
"""
|
||||
self.delete_secure_file(self.file_path)
|
||||
|
||||
def get_secure_storage_path(self) -> Path:
|
||||
@staticmethod
|
||||
def get_secure_storage_path() -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
# Windows: Use %LOCALAPPDATA%
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
elif sys.platform == "darwin":
|
||||
# macOS: Use ~/Library/Application Support
|
||||
base_path = os.path.expanduser("~/Library/Application Support")
|
||||
else:
|
||||
# Linux and other Unix-like: Use ~/.local/share
|
||||
base_path = os.path.expanduser("~/.local/share")
|
||||
|
||||
app_name = "crewai/credentials"
|
||||
@@ -110,7 +168,6 @@ class TokenManager:
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Set appropriate permissions (read/write for owner only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
def read_secure_file(self, filename: str) -> bytes | None:
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.2.1"
|
||||
"crewai[tools]==1.4.1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.2.1"
|
||||
"crewai[tools]==1.4.1"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -27,6 +27,7 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
@@ -70,7 +71,7 @@ from crewai.task import Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
@@ -81,7 +82,7 @@ from crewai.utilities.formatter import (
|
||||
aggregate_raw_outputs_from_task_outputs,
|
||||
aggregate_raw_outputs_from_tasks,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.i18n import get_i18n
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.planning_handler import CrewPlanner
|
||||
@@ -195,7 +196,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Json | dict[str, Any] | None = Field(default=None)
|
||||
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
share_crew: bool | None = Field(default=False)
|
||||
step_callback: Any | None = Field(
|
||||
@@ -294,7 +295,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
|
||||
def check_config_type(
|
||||
cls, v: Json[dict[str, Any]] | dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Validates that the config is a valid type.
|
||||
Args:
|
||||
v: The config to be validated.
|
||||
@@ -310,7 +313,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""set private attributes."""
|
||||
|
||||
self._cache_handler = CacheHandler()
|
||||
event_listener = EventListener()
|
||||
event_listener = EventListener() # type: ignore[no-untyped-call]
|
||||
|
||||
if (
|
||||
is_tracing_enabled()
|
||||
@@ -330,13 +333,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
def _initialize_default_memories(self):
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory()
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory(
|
||||
def _initialize_default_memories(self) -> None:
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call]
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call]
|
||||
crew=self,
|
||||
embedder_config=self.embedder,
|
||||
)
|
||||
self._entity_memory = self.entity_memory or EntityMemory(
|
||||
self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call]
|
||||
crew=self, embedder_config=self.embedder
|
||||
)
|
||||
|
||||
@@ -380,7 +383,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_manager_llm(self):
|
||||
def check_manager_llm(self) -> Self:
|
||||
"""Validates that the language model is set when using hierarchical process."""
|
||||
if self.process == Process.hierarchical:
|
||||
if not self.manager_llm and not self.manager_agent:
|
||||
@@ -405,7 +408,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_config(self):
|
||||
def check_config(self) -> Self:
|
||||
"""Validates that the crew is properly configured with agents and tasks."""
|
||||
if not self.config and not self.tasks and not self.agents:
|
||||
raise PydanticCustomError(
|
||||
@@ -426,23 +429,20 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tasks(self):
|
||||
def validate_tasks(self) -> Self:
|
||||
if self.process == Process.sequential:
|
||||
for task in self.tasks:
|
||||
if task.agent is None:
|
||||
raise PydanticCustomError(
|
||||
"missing_agent_in_task",
|
||||
(
|
||||
f"Sequential process error: Agent is missing in the task "
|
||||
f"with the following description: {task.description}"
|
||||
), # type: ignore # Dynamic string in error message
|
||||
{},
|
||||
"Sequential process error: Agent is missing in the task with the following description: {description}",
|
||||
{"description": task.description},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_end_with_at_most_one_async_task(self):
|
||||
def validate_end_with_at_most_one_async_task(self) -> Self:
|
||||
"""Validates that the crew ends with at most one asynchronous task."""
|
||||
final_async_task_count = 0
|
||||
|
||||
@@ -505,7 +505,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_async_task_cannot_include_sequential_async_tasks_in_context(self):
|
||||
def validate_async_task_cannot_include_sequential_async_tasks_in_context(
|
||||
self,
|
||||
) -> Self:
|
||||
"""
|
||||
Validates that if a task is set to be executed asynchronously,
|
||||
it cannot include other asynchronous tasks in its context unless
|
||||
@@ -527,7 +529,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_context_no_future_tasks(self):
|
||||
def validate_context_no_future_tasks(self) -> Self:
|
||||
"""Validates that a task's context does not include future tasks."""
|
||||
task_indices = {id(task): i for i, task in enumerate(self.tasks)}
|
||||
|
||||
@@ -561,7 +563,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
return self.security_config.fingerprint
|
||||
|
||||
def _setup_from_config(self):
|
||||
def _setup_from_config(self) -> None:
|
||||
"""Initializes agents and tasks from the provided config."""
|
||||
if self.config is None:
|
||||
raise ValueError("Config should not be None.")
|
||||
@@ -628,12 +630,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
for agent in train_crew.agents:
|
||||
if training_data.get(str(agent.id)):
|
||||
result = TaskEvaluator(agent).evaluate_training_data(
|
||||
result = TaskEvaluator(agent).evaluate_training_data( # type: ignore[arg-type]
|
||||
training_data=training_data, agent_id=str(agent.id)
|
||||
)
|
||||
CrewTrainingHandler(filename).save_trained_data(
|
||||
agent_id=str(agent.role),
|
||||
trained_data=result.model_dump(), # type: ignore[arg-type]
|
||||
trained_data=result.model_dump(),
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -684,12 +686,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._set_tasks_callbacks()
|
||||
self._set_allow_crewai_trigger_context_for_first_task()
|
||||
|
||||
i18n = I18N(prompt_file=self.prompt_file)
|
||||
|
||||
for agent in self.agents:
|
||||
agent.i18n = i18n
|
||||
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
|
||||
agent.crew = self # type: ignore[attr-defined]
|
||||
agent.crew = self
|
||||
agent.set_knowledge(crew_embedder=self.embedder)
|
||||
# TODO: Create an AgentFunctionCalling protocol for future refactoring
|
||||
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
@@ -753,10 +751,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
inputs = inputs or {}
|
||||
return await asyncio.to_thread(self.kickoff, inputs)
|
||||
|
||||
async def kickoff_for_each_async(self, inputs: list[dict]) -> list[CrewOutput]:
|
||||
async def kickoff_for_each_async(
|
||||
self, inputs: list[dict[str, Any]]
|
||||
) -> list[CrewOutput]:
|
||||
crew_copies = [self.copy() for _ in inputs]
|
||||
|
||||
async def run_crew(crew, input_data):
|
||||
async def run_crew(crew: Self, input_data: Any) -> CrewOutput:
|
||||
return await crew.kickoff_async(inputs=input_data)
|
||||
|
||||
tasks = [
|
||||
@@ -775,7 +775,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._task_output_handler.reset()
|
||||
return results
|
||||
|
||||
def _handle_crew_planning(self):
|
||||
def _handle_crew_planning(self) -> None:
|
||||
"""Handles the Crew planning."""
|
||||
self._logger.log("info", "Planning the crew execution")
|
||||
result = CrewPlanner(
|
||||
@@ -793,7 +793,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
output: TaskOutput,
|
||||
task_index: int,
|
||||
was_replayed: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
if self._inputs:
|
||||
inputs = self._inputs
|
||||
else:
|
||||
@@ -825,19 +825,21 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._create_manager_agent()
|
||||
return self._execute_tasks(self.tasks)
|
||||
|
||||
def _create_manager_agent(self):
|
||||
i18n = I18N(prompt_file=self.prompt_file)
|
||||
def _create_manager_agent(self) -> None:
|
||||
if self.manager_agent is not None:
|
||||
self.manager_agent.allow_delegation = True
|
||||
manager = self.manager_agent
|
||||
if manager.tools is not None and len(manager.tools) > 0:
|
||||
self._logger.log(
|
||||
"warning", "Manager agent should not have tools", color="orange"
|
||||
"warning",
|
||||
"Manager agent should not have tools",
|
||||
color="bold_yellow",
|
||||
)
|
||||
manager.tools = []
|
||||
raise Exception("Manager agent should not have tools")
|
||||
else:
|
||||
self.manager_llm = create_llm(self.manager_llm)
|
||||
i18n = get_i18n(prompt_file=self.prompt_file)
|
||||
manager = Agent(
|
||||
role=i18n.retrieve("hierarchical_manager_agent", "role"),
|
||||
goal=i18n.retrieve("hierarchical_manager_agent", "goal"),
|
||||
@@ -895,7 +897,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
cast(list[Tool] | list[BaseTool], tools_for_task),
|
||||
tools_for_task,
|
||||
)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
@@ -915,7 +917,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(list[BaseTool], tools_for_task),
|
||||
tools=tools_for_task,
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
@@ -927,7 +929,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(list[BaseTool], tools_for_task),
|
||||
tools=tools_for_task,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
@@ -965,7 +967,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return None
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
self, agent: BaseAgent, task: Task, tools: list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
# Add delegation tools if agent allows delegation
|
||||
if hasattr(agent, "allow_delegation") and getattr(
|
||||
@@ -1002,21 +1004,21 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._add_mcp_tools(task, tools)
|
||||
|
||||
# Return a list[BaseTool] compatible with Task.execute_sync and execute_async
|
||||
return cast(list[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _get_agent_to_use(self, task: Task) -> BaseAgent | None:
|
||||
if self.process == Process.hierarchical:
|
||||
return self.manager_agent
|
||||
return task.agent
|
||||
|
||||
@staticmethod
|
||||
def _merge_tools(
|
||||
self,
|
||||
existing_tools: list[Tool] | list[BaseTool],
|
||||
new_tools: list[Tool] | list[BaseTool],
|
||||
existing_tools: list[BaseTool],
|
||||
new_tools: list[BaseTool],
|
||||
) -> list[BaseTool]:
|
||||
"""Merge new tools into existing tools list, avoiding duplicates."""
|
||||
if not new_tools:
|
||||
return cast(list[BaseTool], existing_tools)
|
||||
return existing_tools
|
||||
|
||||
# Create mapping of tool names to new tools
|
||||
new_tool_map = {tool.name: tool for tool in new_tools}
|
||||
@@ -1027,63 +1029,62 @@ class Crew(FlowTrackable, BaseModel):
|
||||
# Add all new tools
|
||||
tools.extend(new_tools)
|
||||
|
||||
return cast(list[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _inject_delegation_tools(
|
||||
self,
|
||||
tools: list[Tool] | list[BaseTool],
|
||||
tools: list[BaseTool],
|
||||
task_agent: BaseAgent,
|
||||
agents: list[BaseAgent],
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
# Cast delegation_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(list[BaseTool], delegation_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return self._merge_tools(tools, delegation_tools)
|
||||
return tools
|
||||
|
||||
def _inject_platform_tools(
|
||||
self,
|
||||
tools: list[Tool] | list[BaseTool],
|
||||
tools: list[BaseTool],
|
||||
task_agent: BaseAgent,
|
||||
) -> list[BaseTool]:
|
||||
apps = getattr(task_agent, "apps", None) or []
|
||||
|
||||
if hasattr(task_agent, "get_platform_tools") and apps:
|
||||
platform_tools = task_agent.get_platform_tools(apps=apps)
|
||||
return self._merge_tools(tools, cast(list[BaseTool], platform_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return self._merge_tools(tools, platform_tools)
|
||||
return tools
|
||||
|
||||
def _inject_mcp_tools(
|
||||
self,
|
||||
tools: list[Tool] | list[BaseTool],
|
||||
tools: list[BaseTool],
|
||||
task_agent: BaseAgent,
|
||||
) -> list[BaseTool]:
|
||||
mcps = getattr(task_agent, "mcps", None) or []
|
||||
if hasattr(task_agent, "get_mcp_tools") and mcps:
|
||||
mcp_tools = task_agent.get_mcp_tools(mcps=mcps)
|
||||
return self._merge_tools(tools, cast(list[BaseTool], mcp_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return self._merge_tools(tools, mcp_tools)
|
||||
return tools
|
||||
|
||||
def _add_multimodal_tools(
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
self, agent: BaseAgent, tools: list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(agent, "get_multimodal_tools"):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
# Cast multimodal_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _add_code_execution_tools(
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
self, agent: BaseAgent, tools: list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(agent, "get_code_execution_tools"):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
# Cast code_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _add_delegation_tools(
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
self, task: Task, tools: list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
||||
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
||||
@@ -1092,25 +1093,21 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, task.agent, agents_for_delegation
|
||||
)
|
||||
return cast(list[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _add_platform_tools(
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
def _add_platform_tools(self, task: Task, tools: list[BaseTool]) -> list[BaseTool]:
|
||||
if task.agent:
|
||||
tools = self._inject_platform_tools(tools, task.agent)
|
||||
|
||||
return cast(list[BaseTool], tools or [])
|
||||
return tools or []
|
||||
|
||||
def _add_mcp_tools(
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
def _add_mcp_tools(self, task: Task, tools: list[BaseTool]) -> list[BaseTool]:
|
||||
if task.agent:
|
||||
tools = self._inject_mcp_tools(tools, task.agent)
|
||||
|
||||
return cast(list[BaseTool], tools or [])
|
||||
return tools or []
|
||||
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
def _log_task_start(self, task: Task, role: str = "None") -> None:
|
||||
if self.output_log_file:
|
||||
self._file_handler.log(
|
||||
task_name=task.name, # type: ignore[arg-type]
|
||||
@@ -1120,7 +1117,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
def _update_manager_tools(
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
self, task: Task, tools: list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
if self.manager_agent:
|
||||
if task.agent:
|
||||
@@ -1129,7 +1126,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, self.manager_agent, self.agents
|
||||
)
|
||||
return cast(list[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
if not task.context:
|
||||
@@ -1280,7 +1277,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return required_inputs
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> Crew: # type: ignore[override]
|
||||
"""
|
||||
Creates a deep copy of the Crew instance.
|
||||
|
||||
@@ -1311,7 +1308,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
manager_agent = self.manager_agent.copy() if self.manager_agent else None
|
||||
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None
|
||||
|
||||
task_mapping = {}
|
||||
task_mapping: dict[str, Any] = {}
|
||||
|
||||
cloned_tasks = []
|
||||
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
|
||||
@@ -1373,7 +1370,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
for task in self.tasks
|
||||
]
|
||||
# type: ignore # "interpolate_inputs" of "Agent" does not return a value (it only ever returns None)
|
||||
for agent in self.agents:
|
||||
agent.interpolate_inputs(inputs)
|
||||
|
||||
@@ -1463,7 +1459,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Crew(id={self.id}, process={self.process}, "
|
||||
f"number_of_agents={len(self.agents)}, "
|
||||
@@ -1520,7 +1516,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if (system := config.get("system")) is not None:
|
||||
name = config.get("name")
|
||||
try:
|
||||
reset_fn: Callable = cast(Callable, config.get("reset"))
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
@@ -1551,7 +1549,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise RuntimeError(f"{name} memory system is not initialized")
|
||||
|
||||
try:
|
||||
reset_fn: Callable = cast(Callable, config.get("reset"))
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
@@ -1564,7 +1564,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
|
||||
def _get_memory_systems(self):
|
||||
def _get_memory_systems(self) -> dict[str, Any]:
|
||||
"""Get all available memory systems with their configuration.
|
||||
|
||||
Returns:
|
||||
@@ -1572,10 +1572,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
display names.
|
||||
"""
|
||||
|
||||
def default_reset(memory):
|
||||
def default_reset(memory: Any) -> Any:
|
||||
return memory.reset()
|
||||
|
||||
def knowledge_reset(memory):
|
||||
def knowledge_reset(memory: Any) -> Any:
|
||||
return self.reset_knowledge(memory)
|
||||
|
||||
# Get knowledge for agents
|
||||
@@ -1635,7 +1635,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
for ks in knowledges:
|
||||
ks.reset()
|
||||
|
||||
def _set_allow_crewai_trigger_context_for_first_task(self):
|
||||
def _set_allow_crewai_trigger_context_for_first_task(self) -> None:
|
||||
crewai_trigger_payload = self._inputs and self._inputs.get(
|
||||
"crewai_trigger_payload"
|
||||
)
|
||||
|
||||
@@ -8,21 +8,14 @@ This module provides the event infrastructure that allows users to:
|
||||
- Declare handler dependencies for ordered execution
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.depends import Depends
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.handler_graph import CircularDependencyError
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentEvaluationCompletedEvent,
|
||||
AgentEvaluationFailedEvent,
|
||||
AgentEvaluationStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
@@ -67,6 +60,14 @@ from crewai.events.types.logging_events import (
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.mcp_events import (
|
||||
MCPConnectionCompletedEvent,
|
||||
MCPConnectionFailedEvent,
|
||||
MCPConnectionStartedEvent,
|
||||
MCPToolExecutionCompletedEvent,
|
||||
MCPToolExecutionFailedEvent,
|
||||
MCPToolExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
@@ -100,6 +101,20 @@ from crewai.events.types.tool_usage_events import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentEvaluationCompletedEvent,
|
||||
AgentEvaluationFailedEvent,
|
||||
AgentEvaluationStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentEvaluationCompletedEvent",
|
||||
"AgentEvaluationFailedEvent",
|
||||
@@ -145,6 +160,12 @@ __all__ = [
|
||||
"LiteAgentExecutionCompletedEvent",
|
||||
"LiteAgentExecutionErrorEvent",
|
||||
"LiteAgentExecutionStartedEvent",
|
||||
"MCPConnectionCompletedEvent",
|
||||
"MCPConnectionFailedEvent",
|
||||
"MCPConnectionStartedEvent",
|
||||
"MCPToolExecutionCompletedEvent",
|
||||
"MCPToolExecutionFailedEvent",
|
||||
"MCPToolExecutionStartedEvent",
|
||||
"MemoryQueryCompletedEvent",
|
||||
"MemoryQueryFailedEvent",
|
||||
"MemoryQueryStartedEvent",
|
||||
@@ -170,3 +191,27 @@ __all__ = [
|
||||
"ToolValidateInputErrorEvent",
|
||||
"crewai_event_bus",
|
||||
]
|
||||
|
||||
_AGENT_EVENT_MAPPING = {
|
||||
"AgentEvaluationCompletedEvent": "crewai.events.types.agent_events",
|
||||
"AgentEvaluationFailedEvent": "crewai.events.types.agent_events",
|
||||
"AgentEvaluationStartedEvent": "crewai.events.types.agent_events",
|
||||
"AgentExecutionCompletedEvent": "crewai.events.types.agent_events",
|
||||
"AgentExecutionErrorEvent": "crewai.events.types.agent_events",
|
||||
"AgentExecutionStartedEvent": "crewai.events.types.agent_events",
|
||||
"LiteAgentExecutionCompletedEvent": "crewai.events.types.agent_events",
|
||||
"LiteAgentExecutionErrorEvent": "crewai.events.types.agent_events",
|
||||
"LiteAgentExecutionStartedEvent": "crewai.events.types.agent_events",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Lazy import for agent events to avoid circular imports."""
|
||||
if name in _AGENT_EVENT_MAPPING:
|
||||
import importlib
|
||||
|
||||
module_path = _AGENT_EVENT_MAPPING[name]
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, name)
|
||||
msg = f"module {__name__!r} has no attribute {name!r}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
"""Base event listener for CrewAI event system."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
|
||||
|
||||
|
||||
class BaseEventListener(ABC):
|
||||
"""Abstract base class for event listeners."""
|
||||
|
||||
verbose: bool = False
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the event listener and register handlers."""
|
||||
super().__init__()
|
||||
self.setup_listeners(crewai_event_bus)
|
||||
crewai_event_bus.validate_dependencies()
|
||||
|
||||
@abstractmethod
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus):
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
"""Setup event listeners on the event bus.
|
||||
|
||||
Args:
|
||||
crewai_event_bus: The event bus to register listeners on.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from io import StringIO
|
||||
from typing import Any
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.listeners.memory_listener import MemoryListener
|
||||
from crewai.events.listeners.tracing.trace_listener import TraceCollectionListener
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConversationCompletedEvent,
|
||||
A2AConversationStartedEvent,
|
||||
A2ADelegationCompletedEvent,
|
||||
A2ADelegationStartedEvent,
|
||||
A2AMessageSentEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
)
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
@@ -56,6 +65,14 @@ from crewai.events.types.logging_events import (
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.mcp_events import (
|
||||
MCPConnectionCompletedEvent,
|
||||
MCPConnectionFailedEvent,
|
||||
MCPConnectionStartedEvent,
|
||||
MCPToolExecutionCompletedEvent,
|
||||
MCPToolExecutionFailedEvent,
|
||||
MCPToolExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
@@ -79,6 +96,10 @@ from crewai.utilities import Logger
|
||||
from crewai.utilities.constants import EMITTER_COLOR
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
|
||||
|
||||
class EventListener(BaseEventListener):
|
||||
_instance = None
|
||||
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
||||
@@ -88,6 +109,7 @@ class EventListener(BaseEventListener):
|
||||
text_stream = StringIO()
|
||||
knowledge_retrieval_in_progress = False
|
||||
knowledge_query_in_progress = False
|
||||
method_branches: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
@@ -101,21 +123,27 @@ class EventListener(BaseEventListener):
|
||||
self._telemetry = Telemetry()
|
||||
self._telemetry.set_tracer()
|
||||
self.execution_spans = {}
|
||||
self.method_branches = {}
|
||||
self._initialized = True
|
||||
self.formatter = ConsoleFormatter(verbose=True)
|
||||
self._crew_tree_lock = threading.Condition()
|
||||
|
||||
MemoryListener(formatter=self.formatter)
|
||||
# Initialize trace listener with formatter for memory event handling
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.formatter = self.formatter
|
||||
|
||||
# ----------- CREW EVENTS -----------
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source, event: CrewKickoffStartedEvent):
|
||||
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
|
||||
self._telemetry.crew_execution_span(source, event.inputs)
|
||||
def on_crew_started(source, event: CrewKickoffStartedEvent) -> None:
|
||||
with self._crew_tree_lock:
|
||||
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
|
||||
self._telemetry.crew_execution_span(source, event.inputs)
|
||||
self._crew_tree_lock.notify_all()
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source, event: CrewKickoffCompletedEvent):
|
||||
def on_crew_completed(source, event: CrewKickoffCompletedEvent) -> None:
|
||||
# Handle telemetry
|
||||
final_string_output = event.output.raw
|
||||
self._telemetry.end_crew(source, final_string_output)
|
||||
@@ -129,7 +157,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffFailedEvent)
|
||||
def on_crew_failed(source, event: CrewKickoffFailedEvent):
|
||||
def on_crew_failed(source, event: CrewKickoffFailedEvent) -> None:
|
||||
self.formatter.update_crew_tree(
|
||||
self.formatter.current_crew_tree,
|
||||
event.crew_name or "Crew",
|
||||
@@ -138,23 +166,23 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTrainStartedEvent)
|
||||
def on_crew_train_started(source, event: CrewTrainStartedEvent):
|
||||
def on_crew_train_started(source, event: CrewTrainStartedEvent) -> None:
|
||||
self.formatter.handle_crew_train_started(
|
||||
event.crew_name or "Crew", str(event.timestamp)
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTrainCompletedEvent)
|
||||
def on_crew_train_completed(source, event: CrewTrainCompletedEvent):
|
||||
def on_crew_train_completed(source, event: CrewTrainCompletedEvent) -> None:
|
||||
self.formatter.handle_crew_train_completed(
|
||||
event.crew_name or "Crew", str(event.timestamp)
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTrainFailedEvent)
|
||||
def on_crew_train_failed(source, event: CrewTrainFailedEvent):
|
||||
def on_crew_train_failed(source, event: CrewTrainFailedEvent) -> None:
|
||||
self.formatter.handle_crew_train_failed(event.crew_name or "Crew")
|
||||
|
||||
@crewai_event_bus.on(CrewTestResultEvent)
|
||||
def on_crew_test_result(source, event: CrewTestResultEvent):
|
||||
def on_crew_test_result(source, event: CrewTestResultEvent) -> None:
|
||||
self._telemetry.individual_test_result_span(
|
||||
source.crew,
|
||||
event.quality,
|
||||
@@ -165,14 +193,22 @@ class EventListener(BaseEventListener):
|
||||
# ----------- TASK EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(TaskStartedEvent)
|
||||
def on_task_started(source, event: TaskStartedEvent):
|
||||
def on_task_started(source, event: TaskStartedEvent) -> None:
|
||||
span = self._telemetry.task_started(crew=source.agent.crew, task=source)
|
||||
self.execution_spans[source] = span
|
||||
# Pass both task ID and task name (if set)
|
||||
task_name = source.name if hasattr(source, "name") and source.name else None
|
||||
self.formatter.create_task_branch(
|
||||
self.formatter.current_crew_tree, source.id, task_name
|
||||
)
|
||||
|
||||
with self._crew_tree_lock:
|
||||
self._crew_tree_lock.wait_for(
|
||||
lambda: self.formatter.current_crew_tree is not None, timeout=5.0
|
||||
)
|
||||
|
||||
if self.formatter.current_crew_tree is not None:
|
||||
task_name = (
|
||||
source.name if hasattr(source, "name") and source.name else None
|
||||
)
|
||||
self.formatter.create_task_branch(
|
||||
self.formatter.current_crew_tree, source.id, task_name
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
def on_task_completed(source, event: TaskCompletedEvent):
|
||||
@@ -263,7 +299,8 @@ class EventListener(BaseEventListener):
|
||||
@crewai_event_bus.on(FlowCreatedEvent)
|
||||
def on_flow_created(source, event: FlowCreatedEvent):
|
||||
self._telemetry.flow_creation_span(event.flow_name)
|
||||
self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
|
||||
tree = self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
|
||||
self.formatter.current_flow_tree = tree
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def on_flow_started(source, event: FlowStartedEvent):
|
||||
@@ -280,30 +317,36 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def on_method_execution_started(source, event: MethodExecutionStartedEvent):
|
||||
self.formatter.update_method_status(
|
||||
self.formatter.current_method_branch,
|
||||
method_branch = self.method_branches.get(event.method_name)
|
||||
updated_branch = self.formatter.update_method_status(
|
||||
method_branch,
|
||||
self.formatter.current_flow_tree,
|
||||
event.method_name,
|
||||
"running",
|
||||
)
|
||||
self.method_branches[event.method_name] = updated_branch
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def on_method_execution_finished(source, event: MethodExecutionFinishedEvent):
|
||||
self.formatter.update_method_status(
|
||||
self.formatter.current_method_branch,
|
||||
method_branch = self.method_branches.get(event.method_name)
|
||||
updated_branch = self.formatter.update_method_status(
|
||||
method_branch,
|
||||
self.formatter.current_flow_tree,
|
||||
event.method_name,
|
||||
"completed",
|
||||
)
|
||||
self.method_branches[event.method_name] = updated_branch
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFailedEvent)
|
||||
def on_method_execution_failed(source, event: MethodExecutionFailedEvent):
|
||||
self.formatter.update_method_status(
|
||||
self.formatter.current_method_branch,
|
||||
method_branch = self.method_branches.get(event.method_name)
|
||||
updated_branch = self.formatter.update_method_status(
|
||||
method_branch,
|
||||
self.formatter.current_flow_tree,
|
||||
event.method_name,
|
||||
"failed",
|
||||
)
|
||||
self.method_branches[event.method_name] = updated_branch
|
||||
|
||||
# ----------- TOOL USAGE EVENTS -----------
|
||||
|
||||
@@ -524,5 +567,123 @@ class EventListener(BaseEventListener):
|
||||
event.verbose,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2ADelegationStartedEvent)
|
||||
def on_a2a_delegation_started(source, event: A2ADelegationStartedEvent):
|
||||
self.formatter.handle_a2a_delegation_started(
|
||||
event.endpoint,
|
||||
event.task_description,
|
||||
event.agent_id,
|
||||
event.is_multiturn,
|
||||
event.turn_number,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2ADelegationCompletedEvent)
|
||||
def on_a2a_delegation_completed(source, event: A2ADelegationCompletedEvent):
|
||||
self.formatter.handle_a2a_delegation_completed(
|
||||
event.status,
|
||||
event.result,
|
||||
event.error,
|
||||
event.is_multiturn,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2AConversationStartedEvent)
|
||||
def on_a2a_conversation_started(source, event: A2AConversationStartedEvent):
|
||||
# Store A2A agent name for display in conversation tree
|
||||
if event.a2a_agent_name:
|
||||
self.formatter._current_a2a_agent_name = event.a2a_agent_name
|
||||
|
||||
self.formatter.handle_a2a_conversation_started(
|
||||
event.agent_id,
|
||||
event.endpoint,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2AMessageSentEvent)
|
||||
def on_a2a_message_sent(source, event: A2AMessageSentEvent):
|
||||
self.formatter.handle_a2a_message_sent(
|
||||
event.message,
|
||||
event.turn_number,
|
||||
event.agent_role,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2AResponseReceivedEvent)
|
||||
def on_a2a_response_received(source, event: A2AResponseReceivedEvent):
|
||||
self.formatter.handle_a2a_response_received(
|
||||
event.response,
|
||||
event.turn_number,
|
||||
event.status,
|
||||
event.agent_role,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(A2AConversationCompletedEvent)
|
||||
def on_a2a_conversation_completed(source, event: A2AConversationCompletedEvent):
|
||||
self.formatter.handle_a2a_conversation_completed(
|
||||
event.status,
|
||||
event.final_result,
|
||||
event.error,
|
||||
event.total_turns,
|
||||
)
|
||||
|
||||
# ----------- MCP EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(MCPConnectionStartedEvent)
|
||||
def on_mcp_connection_started(source, event: MCPConnectionStartedEvent):
|
||||
self.formatter.handle_mcp_connection_started(
|
||||
event.server_name,
|
||||
event.server_url,
|
||||
event.transport_type,
|
||||
event.is_reconnect,
|
||||
event.connect_timeout,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPConnectionCompletedEvent)
|
||||
def on_mcp_connection_completed(source, event: MCPConnectionCompletedEvent):
|
||||
self.formatter.handle_mcp_connection_completed(
|
||||
event.server_name,
|
||||
event.server_url,
|
||||
event.transport_type,
|
||||
event.connection_duration_ms,
|
||||
event.is_reconnect,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPConnectionFailedEvent)
|
||||
def on_mcp_connection_failed(source, event: MCPConnectionFailedEvent):
|
||||
self.formatter.handle_mcp_connection_failed(
|
||||
event.server_name,
|
||||
event.server_url,
|
||||
event.transport_type,
|
||||
event.error,
|
||||
event.error_type,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPToolExecutionStartedEvent)
|
||||
def on_mcp_tool_execution_started(source, event: MCPToolExecutionStartedEvent):
|
||||
self.formatter.handle_mcp_tool_execution_started(
|
||||
event.server_name,
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPToolExecutionCompletedEvent)
|
||||
def on_mcp_tool_execution_completed(
|
||||
source, event: MCPToolExecutionCompletedEvent
|
||||
):
|
||||
self.formatter.handle_mcp_tool_execution_completed(
|
||||
event.server_name,
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
event.result,
|
||||
event.execution_duration_ms,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPToolExecutionFailedEvent)
|
||||
def on_mcp_tool_execution_failed(source, event: MCPToolExecutionFailedEvent):
|
||||
self.formatter.handle_mcp_tool_execution_failed(
|
||||
event.server_name,
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
event.error,
|
||||
event.error_type,
|
||||
)
|
||||
|
||||
|
||||
event_listener = EventListener()
|
||||
|
||||
@@ -40,6 +40,14 @@ from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.mcp_events import (
|
||||
MCPConnectionCompletedEvent,
|
||||
MCPConnectionFailedEvent,
|
||||
MCPConnectionStartedEvent,
|
||||
MCPToolExecutionCompletedEvent,
|
||||
MCPToolExecutionFailedEvent,
|
||||
MCPToolExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
@@ -115,4 +123,10 @@ EventTypes = (
|
||||
| MemoryQueryFailedEvent
|
||||
| MemoryRetrievalStartedEvent
|
||||
| MemoryRetrievalCompletedEvent
|
||||
| MCPConnectionStartedEvent
|
||||
| MCPConnectionCompletedEvent
|
||||
| MCPConnectionFailedEvent
|
||||
| MCPToolExecutionStartedEvent
|
||||
| MCPToolExecutionCompletedEvent
|
||||
| MCPToolExecutionFailedEvent
|
||||
)
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
class MemoryListener(BaseEventListener):
|
||||
def __init__(self, formatter):
|
||||
super().__init__()
|
||||
self.formatter = formatter
|
||||
self.memory_retrieval_in_progress = False
|
||||
self.memory_save_in_progress = False
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
|
||||
def on_memory_retrieval_started(source, event: MemoryRetrievalStartedEvent):
|
||||
if self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
self.memory_retrieval_in_progress = True
|
||||
|
||||
self.formatter.handle_memory_retrieval_started(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
|
||||
def on_memory_retrieval_completed(source, event: MemoryRetrievalCompletedEvent):
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
self.memory_retrieval_in_progress = False
|
||||
self.formatter.handle_memory_retrieval_completed(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
event.memory_content,
|
||||
event.retrieval_time_ms,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_memory_query_completed(source, event: MemoryQueryCompletedEvent):
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
self.formatter.handle_memory_query_completed(
|
||||
self.formatter.current_agent_branch,
|
||||
event.source_type,
|
||||
event.query_time_ms,
|
||||
self.formatter.current_crew_tree,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryFailedEvent)
|
||||
def on_memory_query_failed(source, event: MemoryQueryFailedEvent):
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
self.formatter.handle_memory_query_failed(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
event.error,
|
||||
event.source_type,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_memory_save_started(source, event: MemorySaveStartedEvent):
|
||||
if self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
self.memory_save_in_progress = True
|
||||
|
||||
self.formatter.handle_memory_save_started(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_memory_save_completed(source, event: MemorySaveCompletedEvent):
|
||||
if not self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
self.memory_save_in_progress = False
|
||||
|
||||
self.formatter.handle_memory_save_completed(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
event.save_time_ms,
|
||||
event.source_type,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveFailedEvent)
|
||||
def on_memory_save_failed(source, event: MemorySaveFailedEvent):
|
||||
if not self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
self.formatter.handle_memory_save_failed(
|
||||
self.formatter.current_agent_branch,
|
||||
event.error,
|
||||
event.source_type,
|
||||
self.formatter.current_crew_tree,
|
||||
)
|
||||
@@ -73,15 +73,19 @@ class FirstTimeTraceHandler:
|
||||
self.is_first_time = should_auto_collect_first_time_traces()
|
||||
return self.is_first_time
|
||||
|
||||
def set_batch_manager(self, batch_manager: TraceBatchManager):
|
||||
"""Set reference to batch manager for sending events."""
|
||||
def set_batch_manager(self, batch_manager: TraceBatchManager) -> None:
|
||||
"""Set reference to batch manager for sending events.
|
||||
|
||||
Args:
|
||||
batch_manager: The trace batch manager instance.
|
||||
"""
|
||||
self.batch_manager = batch_manager
|
||||
|
||||
def mark_events_collected(self):
|
||||
def mark_events_collected(self) -> None:
|
||||
"""Mark that events have been collected during execution."""
|
||||
self.collected_events = True
|
||||
|
||||
def handle_execution_completion(self):
|
||||
def handle_execution_completion(self) -> None:
|
||||
"""Handle the completion flow as shown in your diagram."""
|
||||
if not self.is_first_time or not self.collected_events:
|
||||
return
|
||||
|
||||
@@ -44,6 +44,7 @@ class TraceBatchManager:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._init_lock = Lock()
|
||||
self._batch_ready_cv = Condition(self._init_lock)
|
||||
self._pending_events_lock = Lock()
|
||||
self._pending_events_cv = Condition(self._pending_events_lock)
|
||||
self._pending_events_count = 0
|
||||
@@ -94,6 +95,8 @@ class TraceBatchManager:
|
||||
)
|
||||
self.backend_initialized = True
|
||||
|
||||
self._batch_ready_cv.notify_all()
|
||||
|
||||
return self.current_batch
|
||||
|
||||
def _initialize_backend_batch(
|
||||
@@ -161,13 +164,13 @@ class TraceBatchManager:
|
||||
f"Error initializing trace batch: {e}. Continuing without tracing."
|
||||
)
|
||||
|
||||
def begin_event_processing(self):
|
||||
"""Mark that an event handler started processing (for synchronization)"""
|
||||
def begin_event_processing(self) -> None:
|
||||
"""Mark that an event handler started processing (for synchronization)."""
|
||||
with self._pending_events_lock:
|
||||
self._pending_events_count += 1
|
||||
|
||||
def end_event_processing(self):
|
||||
"""Mark that an event handler finished processing (for synchronization)"""
|
||||
def end_event_processing(self) -> None:
|
||||
"""Mark that an event handler finished processing (for synchronization)."""
|
||||
with self._pending_events_cv:
|
||||
self._pending_events_count -= 1
|
||||
if self._pending_events_count == 0:
|
||||
@@ -385,6 +388,22 @@ class TraceBatchManager:
|
||||
"""Check if batch is initialized"""
|
||||
return self.current_batch is not None
|
||||
|
||||
def wait_for_batch_initialization(self, timeout: float = 2.0) -> bool:
|
||||
"""Wait for batch to be initialized.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds (default: 2.0)
|
||||
|
||||
Returns:
|
||||
True if batch was initialized, False if timeout occurred
|
||||
"""
|
||||
with self._batch_ready_cv:
|
||||
if self.current_batch is not None:
|
||||
return True
|
||||
return self._batch_ready_cv.wait_for(
|
||||
lambda: self.current_batch is not None, timeout=timeout
|
||||
)
|
||||
|
||||
def record_start_time(self, key: str):
|
||||
"""Record start time for duration calculation"""
|
||||
self.execution_start_times[key] = datetime.now(timezone.utc)
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
"""Trace collection listener for orchestrating trace collection."""
|
||||
|
||||
import os
|
||||
from typing import Any, ClassVar
|
||||
import uuid
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.events.listeners.tracing.first_time_trace_handler import (
|
||||
FirstTimeTraceHandler,
|
||||
)
|
||||
@@ -53,6 +59,8 @@ from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
@@ -75,9 +83,7 @@ from crewai.events.types.tool_usage_events import (
|
||||
|
||||
|
||||
class TraceCollectionListener(BaseEventListener):
|
||||
"""
|
||||
Trace collection listener that orchestrates trace collection
|
||||
"""
|
||||
"""Trace collection listener that orchestrates trace collection."""
|
||||
|
||||
complex_events: ClassVar[list[str]] = [
|
||||
"task_started",
|
||||
@@ -88,11 +94,12 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"agent_execution_completed",
|
||||
]
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
_listeners_setup = False
|
||||
_instance: Self | None = None
|
||||
_initialized: bool = False
|
||||
_listeners_setup: bool = False
|
||||
|
||||
def __new__(cls, batch_manager: TraceBatchManager | None = None):
|
||||
def __new__(cls, batch_manager: TraceBatchManager | None = None) -> Self:
|
||||
"""Create or return singleton instance."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
@@ -100,7 +107,14 @@ class TraceCollectionListener(BaseEventListener):
|
||||
def __init__(
|
||||
self,
|
||||
batch_manager: TraceBatchManager | None = None,
|
||||
):
|
||||
formatter: ConsoleFormatter | None = None,
|
||||
) -> None:
|
||||
"""Initialize trace collection listener.
|
||||
|
||||
Args:
|
||||
batch_manager: Optional trace batch manager instance.
|
||||
formatter: Optional console formatter for output.
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
@@ -108,19 +122,22 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self.batch_manager = batch_manager or TraceBatchManager()
|
||||
self._initialized = True
|
||||
self.first_time_handler = FirstTimeTraceHandler()
|
||||
self.formatter = formatter
|
||||
self.memory_retrieval_in_progress = False
|
||||
self.memory_save_in_progress = False
|
||||
|
||||
if self.first_time_handler.initialize_for_first_time_user():
|
||||
self.first_time_handler.set_batch_manager(self.batch_manager)
|
||||
|
||||
def _check_authenticated(self) -> bool:
|
||||
"""Check if tracing should be enabled"""
|
||||
"""Check if tracing should be enabled."""
|
||||
try:
|
||||
return bool(get_auth_token())
|
||||
except AuthError:
|
||||
return False
|
||||
|
||||
def _get_user_context(self) -> dict[str, str]:
|
||||
"""Extract user context for tracing"""
|
||||
"""Extract user context for tracing."""
|
||||
return {
|
||||
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
|
||||
"organization_id": os.getenv("CREWAI_ORG_ID", ""),
|
||||
@@ -128,9 +145,12 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"trace_id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
"""Setup event listeners - delegates to specific handlers"""
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
"""Setup event listeners - delegates to specific handlers.
|
||||
|
||||
Args:
|
||||
crewai_event_bus: The event bus to register listeners on.
|
||||
"""
|
||||
if self._listeners_setup:
|
||||
return
|
||||
|
||||
@@ -140,50 +160,52 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
self._listeners_setup = True
|
||||
|
||||
def _register_flow_event_handlers(self, event_bus):
|
||||
"""Register handlers for flow events"""
|
||||
def _register_flow_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for flow events."""
|
||||
|
||||
@event_bus.on(FlowCreatedEvent)
|
||||
def on_flow_created(source, event):
|
||||
def on_flow_created(source: Any, event: FlowCreatedEvent) -> None:
|
||||
pass
|
||||
|
||||
@event_bus.on(FlowStartedEvent)
|
||||
def on_flow_started(source, event):
|
||||
def on_flow_started(source: Any, event: FlowStartedEvent) -> None:
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
self._initialize_flow_batch(source, event)
|
||||
self._handle_trace_event("flow_started", source, event)
|
||||
|
||||
@event_bus.on(MethodExecutionStartedEvent)
|
||||
def on_method_started(source, event):
|
||||
def on_method_started(source: Any, event: MethodExecutionStartedEvent) -> None:
|
||||
self._handle_trace_event("method_execution_started", source, event)
|
||||
|
||||
@event_bus.on(MethodExecutionFinishedEvent)
|
||||
def on_method_finished(source, event):
|
||||
def on_method_finished(
|
||||
source: Any, event: MethodExecutionFinishedEvent
|
||||
) -> None:
|
||||
self._handle_trace_event("method_execution_finished", source, event)
|
||||
|
||||
@event_bus.on(MethodExecutionFailedEvent)
|
||||
def on_method_failed(source, event):
|
||||
def on_method_failed(source: Any, event: MethodExecutionFailedEvent) -> None:
|
||||
self._handle_trace_event("method_execution_failed", source, event)
|
||||
|
||||
@event_bus.on(FlowFinishedEvent)
|
||||
def on_flow_finished(source, event):
|
||||
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
|
||||
self._handle_trace_event("flow_finished", source, event)
|
||||
|
||||
@event_bus.on(FlowPlotEvent)
|
||||
def on_flow_plot(source, event):
|
||||
def on_flow_plot(source: Any, event: FlowPlotEvent) -> None:
|
||||
self._handle_action_event("flow_plot", source, event)
|
||||
|
||||
def _register_context_event_handlers(self, event_bus):
|
||||
"""Register handlers for context events (start/end)"""
|
||||
def _register_context_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for context events (start/end)."""
|
||||
|
||||
@event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source, event):
|
||||
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
self._initialize_crew_batch(source, event)
|
||||
self._handle_trace_event("crew_kickoff_started", source, event)
|
||||
|
||||
@event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source, event):
|
||||
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
|
||||
self._handle_trace_event("crew_kickoff_completed", source, event)
|
||||
if self.batch_manager.batch_owner_type == "crew":
|
||||
if self.first_time_handler.is_first_time:
|
||||
@@ -193,7 +215,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(CrewKickoffFailedEvent)
|
||||
def on_crew_failed(source, event):
|
||||
def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
|
||||
self._handle_trace_event("crew_kickoff_failed", source, event)
|
||||
if self.first_time_handler.is_first_time:
|
||||
self.first_time_handler.mark_events_collected()
|
||||
@@ -202,134 +224,245 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(TaskStartedEvent)
|
||||
def on_task_started(source, event):
|
||||
def on_task_started(source: Any, event: TaskStartedEvent) -> None:
|
||||
self._handle_trace_event("task_started", source, event)
|
||||
|
||||
@event_bus.on(TaskCompletedEvent)
|
||||
def on_task_completed(source, event):
|
||||
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
|
||||
self._handle_trace_event("task_completed", source, event)
|
||||
|
||||
@event_bus.on(TaskFailedEvent)
|
||||
def on_task_failed(source, event):
|
||||
def on_task_failed(source: Any, event: TaskFailedEvent) -> None:
|
||||
self._handle_trace_event("task_failed", source, event)
|
||||
|
||||
@event_bus.on(AgentExecutionStartedEvent)
|
||||
def on_agent_started(source, event):
|
||||
def on_agent_started(source: Any, event: AgentExecutionStartedEvent) -> None:
|
||||
self._handle_trace_event("agent_execution_started", source, event)
|
||||
|
||||
@event_bus.on(AgentExecutionCompletedEvent)
|
||||
def on_agent_completed(source, event):
|
||||
def on_agent_completed(
|
||||
source: Any, event: AgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
self._handle_trace_event("agent_execution_completed", source, event)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def on_lite_agent_started(source, event):
|
||||
def on_lite_agent_started(
|
||||
source: Any, event: LiteAgentExecutionStartedEvent
|
||||
) -> None:
|
||||
self._handle_trace_event("lite_agent_execution_started", source, event)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionCompletedEvent)
|
||||
def on_lite_agent_completed(source, event):
|
||||
def on_lite_agent_completed(
|
||||
source: Any, event: LiteAgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
self._handle_trace_event("lite_agent_execution_completed", source, event)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionErrorEvent)
|
||||
def on_lite_agent_error(source, event):
|
||||
def on_lite_agent_error(
|
||||
source: Any, event: LiteAgentExecutionErrorEvent
|
||||
) -> None:
|
||||
self._handle_trace_event("lite_agent_execution_error", source, event)
|
||||
|
||||
@event_bus.on(AgentExecutionErrorEvent)
|
||||
def on_agent_error(source, event):
|
||||
def on_agent_error(source: Any, event: AgentExecutionErrorEvent) -> None:
|
||||
self._handle_trace_event("agent_execution_error", source, event)
|
||||
|
||||
@event_bus.on(LLMGuardrailStartedEvent)
|
||||
def on_guardrail_started(source, event):
|
||||
def on_guardrail_started(source: Any, event: LLMGuardrailStartedEvent) -> None:
|
||||
self._handle_trace_event("llm_guardrail_started", source, event)
|
||||
|
||||
@event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def on_guardrail_completed(source, event):
|
||||
def on_guardrail_completed(
|
||||
source: Any, event: LLMGuardrailCompletedEvent
|
||||
) -> None:
|
||||
self._handle_trace_event("llm_guardrail_completed", source, event)
|
||||
|
||||
def _register_action_event_handlers(self, event_bus):
|
||||
"""Register handlers for action events (LLM calls, tool usage)"""
|
||||
def _register_action_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for action events (LLM calls, tool usage)."""
|
||||
|
||||
@event_bus.on(LLMCallStartedEvent)
|
||||
def on_llm_call_started(source, event):
|
||||
def on_llm_call_started(source: Any, event: LLMCallStartedEvent) -> None:
|
||||
self._handle_action_event("llm_call_started", source, event)
|
||||
|
||||
@event_bus.on(LLMCallCompletedEvent)
|
||||
def on_llm_call_completed(source, event):
|
||||
def on_llm_call_completed(source: Any, event: LLMCallCompletedEvent) -> None:
|
||||
self._handle_action_event("llm_call_completed", source, event)
|
||||
|
||||
@event_bus.on(LLMCallFailedEvent)
|
||||
def on_llm_call_failed(source, event):
|
||||
def on_llm_call_failed(source: Any, event: LLMCallFailedEvent) -> None:
|
||||
self._handle_action_event("llm_call_failed", source, event)
|
||||
|
||||
@event_bus.on(ToolUsageStartedEvent)
|
||||
def on_tool_started(source, event):
|
||||
def on_tool_started(source: Any, event: ToolUsageStartedEvent) -> None:
|
||||
self._handle_action_event("tool_usage_started", source, event)
|
||||
|
||||
@event_bus.on(ToolUsageFinishedEvent)
|
||||
def on_tool_finished(source, event):
|
||||
def on_tool_finished(source: Any, event: ToolUsageFinishedEvent) -> None:
|
||||
self._handle_action_event("tool_usage_finished", source, event)
|
||||
|
||||
@event_bus.on(ToolUsageErrorEvent)
|
||||
def on_tool_error(source, event):
|
||||
def on_tool_error(source: Any, event: ToolUsageErrorEvent) -> None:
|
||||
self._handle_action_event("tool_usage_error", source, event)
|
||||
|
||||
@event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_memory_query_started(source, event):
|
||||
def on_memory_query_started(
|
||||
source: Any, event: MemoryQueryStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("memory_query_started", source, event)
|
||||
|
||||
@event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_memory_query_completed(source, event):
|
||||
def on_memory_query_completed(
|
||||
source: Any, event: MemoryQueryCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("memory_query_completed", source, event)
|
||||
if self.formatter and self.memory_retrieval_in_progress:
|
||||
self.formatter.handle_memory_query_completed(
|
||||
self.formatter.current_agent_branch,
|
||||
event.source_type or "memory",
|
||||
event.query_time_ms,
|
||||
self.formatter.current_crew_tree,
|
||||
)
|
||||
|
||||
@event_bus.on(MemoryQueryFailedEvent)
|
||||
def on_memory_query_failed(source, event):
|
||||
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
|
||||
self._handle_action_event("memory_query_failed", source, event)
|
||||
if self.formatter and self.memory_retrieval_in_progress:
|
||||
self.formatter.handle_memory_query_failed(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
event.error,
|
||||
event.source_type or "memory",
|
||||
)
|
||||
|
||||
@event_bus.on(MemorySaveStartedEvent)
|
||||
def on_memory_save_started(source, event):
|
||||
def on_memory_save_started(source: Any, event: MemorySaveStartedEvent) -> None:
|
||||
self._handle_action_event("memory_save_started", source, event)
|
||||
if self.formatter:
|
||||
if self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
self.memory_save_in_progress = True
|
||||
|
||||
self.formatter.handle_memory_save_started(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
)
|
||||
|
||||
@event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_memory_save_completed(source, event):
|
||||
def on_memory_save_completed(
|
||||
source: Any, event: MemorySaveCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("memory_save_completed", source, event)
|
||||
if self.formatter:
|
||||
if not self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
self.memory_save_in_progress = False
|
||||
|
||||
self.formatter.handle_memory_save_completed(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
event.save_time_ms,
|
||||
event.source_type or "memory",
|
||||
)
|
||||
|
||||
@event_bus.on(MemorySaveFailedEvent)
|
||||
def on_memory_save_failed(source, event):
|
||||
def on_memory_save_failed(source: Any, event: MemorySaveFailedEvent) -> None:
|
||||
self._handle_action_event("memory_save_failed", source, event)
|
||||
if self.formatter and self.memory_save_in_progress:
|
||||
self.formatter.handle_memory_save_failed(
|
||||
self.formatter.current_agent_branch,
|
||||
event.error,
|
||||
event.source_type or "memory",
|
||||
self.formatter.current_crew_tree,
|
||||
)
|
||||
|
||||
@event_bus.on(MemoryRetrievalStartedEvent)
|
||||
def on_memory_retrieval_started(
|
||||
source: Any, event: MemoryRetrievalStartedEvent
|
||||
) -> None:
|
||||
if self.formatter:
|
||||
if self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
self.memory_retrieval_in_progress = True
|
||||
|
||||
self.formatter.handle_memory_retrieval_started(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
)
|
||||
|
||||
@event_bus.on(MemoryRetrievalCompletedEvent)
|
||||
def on_memory_retrieval_completed(
|
||||
source: Any, event: MemoryRetrievalCompletedEvent
|
||||
) -> None:
|
||||
if self.formatter:
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
self.memory_retrieval_in_progress = False
|
||||
self.formatter.handle_memory_retrieval_completed(
|
||||
self.formatter.current_agent_branch,
|
||||
self.formatter.current_crew_tree,
|
||||
event.memory_content,
|
||||
event.retrieval_time_ms,
|
||||
)
|
||||
|
||||
@event_bus.on(AgentReasoningStartedEvent)
|
||||
def on_agent_reasoning_started(source, event):
|
||||
def on_agent_reasoning_started(
|
||||
source: Any, event: AgentReasoningStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("agent_reasoning_started", source, event)
|
||||
|
||||
@event_bus.on(AgentReasoningCompletedEvent)
|
||||
def on_agent_reasoning_completed(source, event):
|
||||
def on_agent_reasoning_completed(
|
||||
source: Any, event: AgentReasoningCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("agent_reasoning_completed", source, event)
|
||||
|
||||
@event_bus.on(AgentReasoningFailedEvent)
|
||||
def on_agent_reasoning_failed(source, event):
|
||||
def on_agent_reasoning_failed(
|
||||
source: Any, event: AgentReasoningFailedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("agent_reasoning_failed", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeRetrievalStartedEvent)
|
||||
def on_knowledge_retrieval_started(source, event):
|
||||
def on_knowledge_retrieval_started(
|
||||
source: Any, event: KnowledgeRetrievalStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("knowledge_retrieval_started", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeRetrievalCompletedEvent)
|
||||
def on_knowledge_retrieval_completed(source, event):
|
||||
def on_knowledge_retrieval_completed(
|
||||
source: Any, event: KnowledgeRetrievalCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("knowledge_retrieval_completed", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeQueryStartedEvent)
|
||||
def on_knowledge_query_started(source, event):
|
||||
def on_knowledge_query_started(
|
||||
source: Any, event: KnowledgeQueryStartedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("knowledge_query_started", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeQueryCompletedEvent)
|
||||
def on_knowledge_query_completed(source, event):
|
||||
def on_knowledge_query_completed(
|
||||
source: Any, event: KnowledgeQueryCompletedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("knowledge_query_completed", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeQueryFailedEvent)
|
||||
def on_knowledge_query_failed(source, event):
|
||||
def on_knowledge_query_failed(
|
||||
source: Any, event: KnowledgeQueryFailedEvent
|
||||
) -> None:
|
||||
self._handle_action_event("knowledge_query_failed", source, event)
|
||||
|
||||
def _initialize_crew_batch(self, source: Any, event: Any):
|
||||
"""Initialize trace batch"""
|
||||
def _initialize_crew_batch(self, source: Any, event: Any) -> None:
|
||||
"""Initialize trace batch.
|
||||
|
||||
Args:
|
||||
source: Source object that triggered the event.
|
||||
event: Event object containing crew information.
|
||||
"""
|
||||
user_context = self._get_user_context()
|
||||
execution_metadata = {
|
||||
"crew_name": getattr(event, "crew_name", "Unknown Crew"),
|
||||
@@ -342,8 +475,13 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _initialize_flow_batch(self, source: Any, event: Any):
|
||||
"""Initialize trace batch for Flow execution"""
|
||||
def _initialize_flow_batch(self, source: Any, event: Any) -> None:
|
||||
"""Initialize trace batch for Flow execution.
|
||||
|
||||
Args:
|
||||
source: Source object that triggered the event.
|
||||
event: Event object containing flow information.
|
||||
"""
|
||||
user_context = self._get_user_context()
|
||||
execution_metadata = {
|
||||
"flow_name": getattr(event, "flow_name", "Unknown Flow"),
|
||||
@@ -359,21 +497,32 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
def _initialize_batch(
|
||||
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
|
||||
):
|
||||
"""Initialize trace batch - auto-enable ephemeral for first-time users."""
|
||||
) -> None:
|
||||
"""Initialize trace batch - auto-enable ephemeral for first-time users.
|
||||
|
||||
Args:
|
||||
user_context: User context information.
|
||||
execution_metadata: Metadata about the execution.
|
||||
"""
|
||||
if self.first_time_handler.is_first_time:
|
||||
return self.batch_manager.initialize_batch(
|
||||
self.batch_manager.initialize_batch(
|
||||
user_context, execution_metadata, use_ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
use_ephemeral = not self._check_authenticated()
|
||||
return self.batch_manager.initialize_batch(
|
||||
self.batch_manager.initialize_batch(
|
||||
user_context, execution_metadata, use_ephemeral=use_ephemeral
|
||||
)
|
||||
|
||||
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
|
||||
"""Generic handler for context end events"""
|
||||
def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None:
|
||||
"""Generic handler for context end events.
|
||||
|
||||
Args:
|
||||
event_type: Type of the event.
|
||||
source: Source object that triggered the event.
|
||||
event: Event object.
|
||||
"""
|
||||
self.batch_manager.begin_event_processing()
|
||||
try:
|
||||
trace_event = self._create_trace_event(event_type, source, event)
|
||||
@@ -381,9 +530,14 @@ class TraceCollectionListener(BaseEventListener):
|
||||
finally:
|
||||
self.batch_manager.end_event_processing()
|
||||
|
||||
def _handle_action_event(self, event_type: str, source: Any, event: Any):
|
||||
"""Generic handler for action events (LLM calls, tool usage)"""
|
||||
def _handle_action_event(self, event_type: str, source: Any, event: Any) -> None:
|
||||
"""Generic handler for action events (LLM calls, tool usage).
|
||||
|
||||
Args:
|
||||
event_type: Type of the event.
|
||||
source: Source object that triggered the event.
|
||||
event: Event object.
|
||||
"""
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
user_context = self._get_user_context()
|
||||
execution_metadata = {
|
||||
|
||||
141
lib/crewai/src/crewai/events/types/a2a_events.py
Normal file
141
lib/crewai/src/crewai/events/types/a2a_events.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Events for A2A (Agent-to-Agent) delegation.
|
||||
|
||||
This module defines events emitted during A2A protocol delegation,
|
||||
including both single-turn and multiturn conversation flows.
|
||||
"""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class A2AEventBase(BaseEvent):
|
||||
"""Base class for A2A events with task/agent context."""
|
||||
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
"""Initialize A2A event, extracting task and agent metadata."""
|
||||
if data.get("from_task"):
|
||||
task = data["from_task"]
|
||||
data["task_id"] = str(task.id)
|
||||
data["task_name"] = task.name or task.description
|
||||
data["from_task"] = None
|
||||
|
||||
if data.get("from_agent"):
|
||||
agent = data["from_agent"]
|
||||
data["agent_id"] = str(agent.id)
|
||||
data["agent_role"] = agent.role
|
||||
data["from_agent"] = None
|
||||
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class A2ADelegationStartedEvent(A2AEventBase):
|
||||
"""Event emitted when A2A delegation starts.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL)
|
||||
task_description: Task being delegated to the A2A agent
|
||||
agent_id: A2A agent identifier
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
turn_number: Current turn number (1-indexed, 1 for single-turn)
|
||||
"""
|
||||
|
||||
type: str = "a2a_delegation_started"
|
||||
endpoint: str
|
||||
task_description: str
|
||||
agent_id: str
|
||||
is_multiturn: bool = False
|
||||
turn_number: int = 1
|
||||
|
||||
|
||||
class A2ADelegationCompletedEvent(A2AEventBase):
|
||||
"""Event emitted when A2A delegation completes.
|
||||
|
||||
Attributes:
|
||||
status: Completion status (completed, input_required, failed, etc.)
|
||||
result: Result message if status is completed
|
||||
error: Error/response message (error for failed, response for input_required)
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
"""
|
||||
|
||||
type: str = "a2a_delegation_completed"
|
||||
status: str
|
||||
result: str | None = None
|
||||
error: str | None = None
|
||||
is_multiturn: bool = False
|
||||
|
||||
|
||||
class A2AConversationStartedEvent(A2AEventBase):
|
||||
"""Event emitted when a multiturn A2A conversation starts.
|
||||
|
||||
This is emitted once at the beginning of a multiturn conversation,
|
||||
before the first message exchange.
|
||||
|
||||
Attributes:
|
||||
agent_id: A2A agent identifier
|
||||
endpoint: A2A agent endpoint URL
|
||||
a2a_agent_name: Name of the A2A agent from agent card
|
||||
"""
|
||||
|
||||
type: str = "a2a_conversation_started"
|
||||
agent_id: str
|
||||
endpoint: str
|
||||
a2a_agent_name: str | None = None
|
||||
|
||||
|
||||
class A2AMessageSentEvent(A2AEventBase):
|
||||
"""Event emitted when a message is sent to the A2A agent.
|
||||
|
||||
Attributes:
|
||||
message: Message content sent to the A2A agent
|
||||
turn_number: Current turn number (1-indexed)
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
agent_role: Role of the CrewAI agent sending the message
|
||||
"""
|
||||
|
||||
type: str = "a2a_message_sent"
|
||||
message: str
|
||||
turn_number: int
|
||||
is_multiturn: bool = False
|
||||
agent_role: str | None = None
|
||||
|
||||
|
||||
class A2AResponseReceivedEvent(A2AEventBase):
|
||||
"""Event emitted when a response is received from the A2A agent.
|
||||
|
||||
Attributes:
|
||||
response: Response content from the A2A agent
|
||||
turn_number: Current turn number (1-indexed)
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
status: Response status (input_required, completed, etc.)
|
||||
agent_role: Role of the CrewAI agent (for display)
|
||||
"""
|
||||
|
||||
type: str = "a2a_response_received"
|
||||
response: str
|
||||
turn_number: int
|
||||
is_multiturn: bool = False
|
||||
status: str
|
||||
agent_role: str | None = None
|
||||
|
||||
|
||||
class A2AConversationCompletedEvent(A2AEventBase):
|
||||
"""Event emitted when a multiturn A2A conversation completes.
|
||||
|
||||
This is emitted once at the end of a multiturn conversation.
|
||||
|
||||
Attributes:
|
||||
status: Final status (completed, failed, etc.)
|
||||
final_result: Final result if completed successfully
|
||||
error: Error message if failed
|
||||
total_turns: Total number of turns in the conversation
|
||||
"""
|
||||
|
||||
type: str = "a2a_conversation_completed"
|
||||
status: Literal["completed", "failed"]
|
||||
final_result: str | None = None
|
||||
error: str | None = None
|
||||
total_turns: int
|
||||
85
lib/crewai/src/crewai/events/types/mcp_events.py
Normal file
85
lib/crewai/src/crewai/events/types/mcp_events.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class MCPEvent(BaseEvent):
|
||||
"""Base event for MCP operations."""
|
||||
|
||||
server_name: str
|
||||
server_url: str | None = None
|
||||
transport_type: str | None = None # "stdio", "http", "sse"
|
||||
agent_id: str | None = None
|
||||
agent_role: str | None = None
|
||||
from_agent: Any | None = None
|
||||
from_task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
|
||||
class MCPConnectionStartedEvent(MCPEvent):
|
||||
"""Event emitted when starting to connect to an MCP server."""
|
||||
|
||||
type: str = "mcp_connection_started"
|
||||
connect_timeout: int | None = None
|
||||
is_reconnect: bool = (
|
||||
False # True if this is a reconnection, False for first connection
|
||||
)
|
||||
|
||||
|
||||
class MCPConnectionCompletedEvent(MCPEvent):
|
||||
"""Event emitted when successfully connected to an MCP server."""
|
||||
|
||||
type: str = "mcp_connection_completed"
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
connection_duration_ms: float | None = None
|
||||
is_reconnect: bool = (
|
||||
False # True if this was a reconnection, False for first connection
|
||||
)
|
||||
|
||||
|
||||
class MCPConnectionFailedEvent(MCPEvent):
|
||||
"""Event emitted when connection to an MCP server fails."""
|
||||
|
||||
type: str = "mcp_connection_failed"
|
||||
error: str
|
||||
error_type: str | None = None # "timeout", "authentication", "network", etc.
|
||||
started_at: datetime | None = None
|
||||
failed_at: datetime | None = None
|
||||
|
||||
|
||||
class MCPToolExecutionStartedEvent(MCPEvent):
|
||||
"""Event emitted when starting to execute an MCP tool."""
|
||||
|
||||
type: str = "mcp_tool_execution_started"
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class MCPToolExecutionCompletedEvent(MCPEvent):
|
||||
"""Event emitted when MCP tool execution completes."""
|
||||
|
||||
type: str = "mcp_tool_execution_completed"
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any] | None = None
|
||||
result: Any | None = None
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
execution_duration_ms: float | None = None
|
||||
|
||||
|
||||
class MCPToolExecutionFailedEvent(MCPEvent):
|
||||
"""Event emitted when MCP tool execution fails."""
|
||||
|
||||
type: str = "mcp_tool_execution_failed"
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any] | None = None
|
||||
error: str
|
||||
error_type: str | None = None # "timeout", "validation", "server_error", etc.
|
||||
started_at: datetime | None = None
|
||||
failed_at: datetime | None = None
|
||||
@@ -17,9 +17,16 @@ class ConsoleFormatter:
|
||||
current_method_branch: Tree | None = None
|
||||
current_lite_agent_branch: Tree | None = None
|
||||
tool_usage_counts: ClassVar[dict[str, int]] = {}
|
||||
current_reasoning_branch: Tree | None = None # Track reasoning status
|
||||
current_reasoning_branch: Tree | None = None
|
||||
_live_paused: bool = False
|
||||
current_llm_tool_tree: Tree | None = None
|
||||
current_a2a_conversation_branch: Tree | None = None
|
||||
current_a2a_turn_count: int = 0
|
||||
_pending_a2a_message: str | None = None
|
||||
_pending_a2a_agent_role: str | None = None
|
||||
_pending_a2a_turn_number: int | None = None
|
||||
_a2a_turn_branches: ClassVar[dict[int, Tree]] = {}
|
||||
_current_a2a_agent_name: str | None = None
|
||||
|
||||
def __init__(self, verbose: bool = False):
|
||||
self.console = Console(width=None)
|
||||
@@ -192,7 +199,12 @@ class ConsoleFormatter:
|
||||
style,
|
||||
ID=source_id,
|
||||
)
|
||||
content.append(f"Final Output: {final_string_output}\n", style="white")
|
||||
|
||||
if status == "failed" and final_string_output:
|
||||
content.append("Error:\n", style="white bold")
|
||||
content.append(f"{final_string_output}\n", style="red")
|
||||
else:
|
||||
content.append(f"Final Output: {final_string_output}\n", style="white")
|
||||
|
||||
self.print_panel(content, title, style)
|
||||
|
||||
@@ -357,7 +369,14 @@ class ConsoleFormatter:
|
||||
return flow_tree
|
||||
|
||||
def start_flow(self, flow_name: str, flow_id: str) -> Tree | None:
|
||||
"""Initialize a flow execution tree."""
|
||||
"""Initialize or update a flow execution tree."""
|
||||
if self.current_flow_tree is not None:
|
||||
for child in self.current_flow_tree.children:
|
||||
if "Starting Flow" in str(child.label):
|
||||
child.label = Text("🚀 Flow Started", style="green")
|
||||
break
|
||||
return self.current_flow_tree
|
||||
|
||||
flow_tree = Tree("")
|
||||
flow_label = Text()
|
||||
flow_label.append("🌊 Flow: ", style="blue bold")
|
||||
@@ -436,27 +455,38 @@ class ConsoleFormatter:
|
||||
prefix, style = "🔄 Running:", "yellow"
|
||||
elif status == "completed":
|
||||
prefix, style = "✅ Completed:", "green"
|
||||
# Update initialization node when a method completes successfully
|
||||
for child in flow_tree.children:
|
||||
if "Starting Flow" in str(child.label):
|
||||
child.label = Text("Flow Method Step", style="white")
|
||||
break
|
||||
else:
|
||||
prefix, style = "❌ Failed:", "red"
|
||||
# Update initialization node on failure
|
||||
for child in flow_tree.children:
|
||||
if "Starting Flow" in str(child.label):
|
||||
child.label = Text("❌ Flow Step Failed", style="red")
|
||||
break
|
||||
|
||||
if not method_branch:
|
||||
# Find or create method branch
|
||||
for branch in flow_tree.children:
|
||||
if method_name in str(branch.label):
|
||||
method_branch = branch
|
||||
break
|
||||
if not method_branch:
|
||||
method_branch = flow_tree.add("")
|
||||
if method_branch is not None:
|
||||
if method_branch in flow_tree.children:
|
||||
method_branch.label = Text(prefix, style=f"{style} bold") + Text(
|
||||
f" {method_name}", style=style
|
||||
)
|
||||
self.print(flow_tree)
|
||||
self.print()
|
||||
return method_branch
|
||||
|
||||
for branch in flow_tree.children:
|
||||
label_str = str(branch.label)
|
||||
if f" {method_name}" in label_str and (
|
||||
"Running:" in label_str
|
||||
or "Completed:" in label_str
|
||||
or "Failed:" in label_str
|
||||
):
|
||||
method_branch = branch
|
||||
break
|
||||
|
||||
if method_branch is None:
|
||||
method_branch = flow_tree.add("")
|
||||
|
||||
method_branch.label = Text(prefix, style=f"{style} bold") + Text(
|
||||
f" {method_name}", style=style
|
||||
@@ -464,6 +494,7 @@ class ConsoleFormatter:
|
||||
|
||||
self.print(flow_tree)
|
||||
self.print()
|
||||
|
||||
return method_branch
|
||||
|
||||
def get_llm_tree(self, tool_name: str):
|
||||
@@ -1455,22 +1486,37 @@ class ConsoleFormatter:
|
||||
self.print()
|
||||
|
||||
elif isinstance(formatted_answer, AgentFinish):
|
||||
# Create content for the finish panel
|
||||
content = Text()
|
||||
content.append("Agent: ", style="white")
|
||||
content.append(f"{agent_role}\n\n", style="bright_green bold")
|
||||
content.append("Final Answer:\n", style="white")
|
||||
content.append(f"{formatted_answer.output}", style="bright_green")
|
||||
is_a2a_delegation = False
|
||||
try:
|
||||
output_data = json.loads(formatted_answer.output)
|
||||
if isinstance(output_data, dict):
|
||||
if output_data.get("is_a2a") is True:
|
||||
is_a2a_delegation = True
|
||||
elif "output" in output_data:
|
||||
nested_output = output_data["output"]
|
||||
if (
|
||||
isinstance(nested_output, dict)
|
||||
and nested_output.get("is_a2a") is True
|
||||
):
|
||||
is_a2a_delegation = True
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# Create and display the finish panel
|
||||
finish_panel = Panel(
|
||||
content,
|
||||
title="✅ Agent Final Answer",
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
self.print(finish_panel)
|
||||
self.print()
|
||||
if not is_a2a_delegation:
|
||||
content = Text()
|
||||
content.append("Agent: ", style="white")
|
||||
content.append(f"{agent_role}\n\n", style="bright_green bold")
|
||||
content.append("Final Answer:\n", style="white")
|
||||
content.append(f"{formatted_answer.output}", style="bright_green")
|
||||
|
||||
finish_panel = Panel(
|
||||
content,
|
||||
title="✅ Agent Final Answer",
|
||||
border_style="green",
|
||||
padding=(1, 2),
|
||||
)
|
||||
self.print(finish_panel)
|
||||
self.print()
|
||||
|
||||
def handle_memory_retrieval_started(
|
||||
self,
|
||||
@@ -1770,3 +1816,635 @@ class ConsoleFormatter:
|
||||
Attempts=f"{retry_count + 1}",
|
||||
)
|
||||
self.print_panel(content, "🛡️ Guardrail Failed", "red")
|
||||
|
||||
def handle_a2a_delegation_started(
|
||||
self,
|
||||
endpoint: str,
|
||||
task_description: str,
|
||||
agent_id: str,
|
||||
is_multiturn: bool = False,
|
||||
turn_number: int = 1,
|
||||
) -> None:
|
||||
"""Handle A2A delegation started event.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL
|
||||
task_description: Task being delegated
|
||||
agent_id: A2A agent identifier
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
turn_number: Current turn number in conversation (1-indexed)
|
||||
"""
|
||||
branch_to_use = self.current_lite_agent_branch or self.current_task_branch
|
||||
tree_to_use = self.current_crew_tree or branch_to_use
|
||||
a2a_branch: Tree | None = None
|
||||
|
||||
if is_multiturn:
|
||||
if self.current_a2a_turn_count == 0 and not isinstance(
|
||||
self.current_a2a_conversation_branch, Tree
|
||||
):
|
||||
if branch_to_use is not None and tree_to_use is not None:
|
||||
self.current_a2a_conversation_branch = branch_to_use.add("")
|
||||
self.update_tree_label(
|
||||
self.current_a2a_conversation_branch,
|
||||
"💬",
|
||||
f"Multiturn A2A Conversation ({agent_id})",
|
||||
"cyan",
|
||||
)
|
||||
self.print(tree_to_use)
|
||||
self.print()
|
||||
else:
|
||||
self.current_a2a_conversation_branch = "MULTITURN_NO_TREE"
|
||||
|
||||
content = Text()
|
||||
content.append(
|
||||
"Multiturn A2A Conversation Started\n\n", style="cyan bold"
|
||||
)
|
||||
content.append("Agent ID: ", style="white")
|
||||
content.append(f"{agent_id}\n", style="cyan")
|
||||
content.append("Note: ", style="white dim")
|
||||
content.append(
|
||||
"Conversation will be tracked in tree view", style="cyan dim"
|
||||
)
|
||||
|
||||
panel = self.create_panel(
|
||||
content, "💬 Multiturn Conversation", "cyan"
|
||||
)
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
self.current_a2a_turn_count = turn_number
|
||||
|
||||
return (
|
||||
self.current_a2a_conversation_branch
|
||||
if isinstance(self.current_a2a_conversation_branch, Tree)
|
||||
else None
|
||||
)
|
||||
|
||||
if branch_to_use is not None and tree_to_use is not None:
|
||||
a2a_branch = branch_to_use.add("")
|
||||
self.update_tree_label(
|
||||
a2a_branch,
|
||||
"🔗",
|
||||
f"Delegating to A2A Agent ({agent_id})",
|
||||
"cyan",
|
||||
)
|
||||
|
||||
self.print(tree_to_use)
|
||||
self.print()
|
||||
|
||||
content = Text()
|
||||
content.append("A2A Delegation Started\n\n", style="cyan bold")
|
||||
content.append("Agent ID: ", style="white")
|
||||
content.append(f"{agent_id}\n", style="cyan")
|
||||
content.append("Endpoint: ", style="white")
|
||||
content.append(f"{endpoint}\n\n", style="cyan dim")
|
||||
content.append("Task Description:\n", style="white")
|
||||
|
||||
task_preview = (
|
||||
task_description
|
||||
if len(task_description) <= 200
|
||||
else task_description[:197] + "..."
|
||||
)
|
||||
content.append(task_preview, style="cyan")
|
||||
|
||||
panel = self.create_panel(content, "🔗 A2A Delegation", "cyan")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
return a2a_branch
|
||||
|
||||
def handle_a2a_delegation_completed(
|
||||
self,
|
||||
status: str,
|
||||
result: str | None = None,
|
||||
error: str | None = None,
|
||||
is_multiturn: bool = False,
|
||||
) -> None:
|
||||
"""Handle A2A delegation completed event.
|
||||
|
||||
Args:
|
||||
status: Completion status
|
||||
result: Optional result message
|
||||
error: Optional error message (or response for input_required)
|
||||
is_multiturn: Whether this is part of a multiturn conversation
|
||||
"""
|
||||
tree_to_use = self.current_crew_tree or self.current_task_branch
|
||||
a2a_branch = None
|
||||
|
||||
if is_multiturn and self.current_a2a_conversation_branch:
|
||||
has_tree = isinstance(self.current_a2a_conversation_branch, Tree)
|
||||
|
||||
if status == "input_required" and error:
|
||||
pass
|
||||
elif status == "completed":
|
||||
if has_tree:
|
||||
final_turn = self.current_a2a_conversation_branch.add("")
|
||||
self.update_tree_label(
|
||||
final_turn,
|
||||
"✅",
|
||||
"Conversation Completed",
|
||||
"green",
|
||||
)
|
||||
|
||||
if tree_to_use:
|
||||
self.print(tree_to_use)
|
||||
self.print()
|
||||
|
||||
self.current_a2a_conversation_branch = None
|
||||
self.current_a2a_turn_count = 0
|
||||
elif status == "failed":
|
||||
if has_tree:
|
||||
error_turn = self.current_a2a_conversation_branch.add("")
|
||||
error_msg = (
|
||||
error[:150] + "..." if error and len(error) > 150 else error
|
||||
)
|
||||
self.update_tree_label(
|
||||
error_turn,
|
||||
"❌",
|
||||
f"Failed: {error_msg}" if error else "Conversation Failed",
|
||||
"red",
|
||||
)
|
||||
|
||||
if tree_to_use:
|
||||
self.print(tree_to_use)
|
||||
self.print()
|
||||
|
||||
self.current_a2a_conversation_branch = None
|
||||
self.current_a2a_turn_count = 0
|
||||
|
||||
return
|
||||
|
||||
if a2a_branch and tree_to_use:
|
||||
if status == "completed":
|
||||
self.update_tree_label(
|
||||
a2a_branch,
|
||||
"✅",
|
||||
"A2A Delegation Completed",
|
||||
"green",
|
||||
)
|
||||
elif status == "failed":
|
||||
self.update_tree_label(
|
||||
a2a_branch,
|
||||
"❌",
|
||||
"A2A Delegation Failed",
|
||||
"red",
|
||||
)
|
||||
else:
|
||||
self.update_tree_label(
|
||||
a2a_branch,
|
||||
"⚠️",
|
||||
f"A2A Delegation {status.replace('_', ' ').title()}",
|
||||
"yellow",
|
||||
)
|
||||
|
||||
self.print(tree_to_use)
|
||||
self.print()
|
||||
|
||||
if status == "completed" and result:
|
||||
content = Text()
|
||||
content.append("A2A Delegation Completed\n\n", style="green bold")
|
||||
content.append("Result:\n", style="white")
|
||||
|
||||
result_preview = result if len(result) <= 500 else result[:497] + "..."
|
||||
content.append(result_preview, style="green")
|
||||
|
||||
panel = self.create_panel(content, "✅ A2A Success", "green")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
elif status == "input_required" and error:
|
||||
content = Text()
|
||||
content.append("A2A Response\n\n", style="cyan bold")
|
||||
content.append("Message:\n", style="white")
|
||||
|
||||
response_preview = error if len(error) <= 500 else error[:497] + "..."
|
||||
content.append(response_preview, style="cyan")
|
||||
|
||||
panel = self.create_panel(content, "💬 A2A Response", "cyan")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
elif error:
|
||||
content = Text()
|
||||
content.append(
|
||||
"A2A Delegation Issue\n\n",
|
||||
style="red bold" if status == "failed" else "yellow bold",
|
||||
)
|
||||
content.append("Status: ", style="white")
|
||||
content.append(
|
||||
f"{status}\n\n", style="red" if status == "failed" else "yellow"
|
||||
)
|
||||
content.append("Message:\n", style="white")
|
||||
content.append(error, style="red" if status == "failed" else "yellow")
|
||||
|
||||
panel_style = "red" if status == "failed" else "yellow"
|
||||
panel_title = "❌ A2A Failed" if status == "failed" else "⚠️ A2A Status"
|
||||
panel = self.create_panel(content, panel_title, panel_style)
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_a2a_conversation_started(
|
||||
self,
|
||||
agent_id: str,
|
||||
endpoint: str,
|
||||
) -> None:
|
||||
"""Handle A2A conversation started event.
|
||||
|
||||
Args:
|
||||
agent_id: A2A agent identifier
|
||||
endpoint: A2A agent endpoint URL
|
||||
"""
|
||||
branch_to_use = self.current_lite_agent_branch or self.current_task_branch
|
||||
tree_to_use = self.current_crew_tree or branch_to_use
|
||||
|
||||
if not isinstance(self.current_a2a_conversation_branch, Tree):
|
||||
if branch_to_use is not None and tree_to_use is not None:
|
||||
self.current_a2a_conversation_branch = branch_to_use.add("")
|
||||
self.update_tree_label(
|
||||
self.current_a2a_conversation_branch,
|
||||
"💬",
|
||||
f"Multiturn A2A Conversation ({agent_id})",
|
||||
"cyan",
|
||||
)
|
||||
self.print(tree_to_use)
|
||||
self.print()
|
||||
else:
|
||||
self.current_a2a_conversation_branch = "MULTITURN_NO_TREE"
|
||||
|
||||
def handle_a2a_message_sent(
|
||||
self,
|
||||
message: str,
|
||||
turn_number: int,
|
||||
agent_role: str | None = None,
|
||||
) -> None:
|
||||
"""Handle A2A message sent event.
|
||||
|
||||
Args:
|
||||
message: Message content sent to the A2A agent
|
||||
turn_number: Current turn number
|
||||
agent_role: Role of the CrewAI agent sending the message
|
||||
"""
|
||||
self._pending_a2a_message = message
|
||||
self._pending_a2a_agent_role = agent_role
|
||||
self._pending_a2a_turn_number = turn_number
|
||||
|
||||
def handle_a2a_response_received(
|
||||
self,
|
||||
response: str,
|
||||
turn_number: int,
|
||||
status: str,
|
||||
agent_role: str | None = None,
|
||||
) -> None:
|
||||
"""Handle A2A response received event.
|
||||
|
||||
Args:
|
||||
response: Response content from the A2A agent
|
||||
turn_number: Current turn number
|
||||
status: Response status (input_required, completed, etc.)
|
||||
agent_role: Role of the CrewAI agent (for display)
|
||||
"""
|
||||
if self.current_a2a_conversation_branch and isinstance(
|
||||
self.current_a2a_conversation_branch, Tree
|
||||
):
|
||||
if turn_number in self._a2a_turn_branches:
|
||||
turn_branch = self._a2a_turn_branches[turn_number]
|
||||
else:
|
||||
turn_branch = self.current_a2a_conversation_branch.add("")
|
||||
self.update_tree_label(
|
||||
turn_branch,
|
||||
"💬",
|
||||
f"Turn {turn_number}",
|
||||
"cyan",
|
||||
)
|
||||
self._a2a_turn_branches[turn_number] = turn_branch
|
||||
|
||||
crewai_agent_role = self._pending_a2a_agent_role or agent_role or "User"
|
||||
message_content = self._pending_a2a_message or "sent message"
|
||||
|
||||
message_preview = (
|
||||
message_content[:100] + "..."
|
||||
if len(message_content) > 100
|
||||
else message_content
|
||||
)
|
||||
|
||||
user_node = turn_branch.add("")
|
||||
self.update_tree_label(
|
||||
user_node,
|
||||
f"{crewai_agent_role} 👤 : ",
|
||||
f'"{message_preview}"',
|
||||
"blue",
|
||||
)
|
||||
|
||||
agent_node = turn_branch.add("")
|
||||
response_preview = (
|
||||
response[:100] + "..." if len(response) > 100 else response
|
||||
)
|
||||
|
||||
a2a_agent_display = f"{self._current_a2a_agent_name} \U0001f916: "
|
||||
|
||||
if status == "completed":
|
||||
response_color = "green"
|
||||
status_indicator = "✓"
|
||||
elif status == "input_required":
|
||||
response_color = "yellow"
|
||||
status_indicator = "❓"
|
||||
elif status == "failed":
|
||||
response_color = "red"
|
||||
status_indicator = "✗"
|
||||
elif status == "auth_required":
|
||||
response_color = "magenta"
|
||||
status_indicator = "🔒"
|
||||
elif status == "canceled":
|
||||
response_color = "dim"
|
||||
status_indicator = "⊘"
|
||||
else:
|
||||
response_color = "cyan"
|
||||
status_indicator = ""
|
||||
|
||||
label = f'"{response_preview}"'
|
||||
if status_indicator:
|
||||
label = f"{status_indicator} {label}"
|
||||
|
||||
self.update_tree_label(
|
||||
agent_node,
|
||||
a2a_agent_display,
|
||||
label,
|
||||
response_color,
|
||||
)
|
||||
|
||||
self._pending_a2a_message = None
|
||||
self._pending_a2a_agent_role = None
|
||||
self._pending_a2a_turn_number = None
|
||||
|
||||
tree_to_use = self.current_crew_tree or self.current_task_branch
|
||||
if tree_to_use:
|
||||
self.print(tree_to_use)
|
||||
self.print()
|
||||
|
||||
def handle_a2a_conversation_completed(
|
||||
self,
|
||||
status: str,
|
||||
final_result: str | None,
|
||||
error: str | None,
|
||||
total_turns: int,
|
||||
) -> None:
|
||||
"""Handle A2A conversation completed event.
|
||||
|
||||
Args:
|
||||
status: Final status (completed, failed, etc.)
|
||||
final_result: Final result if completed successfully
|
||||
error: Error message if failed
|
||||
total_turns: Total number of turns in the conversation
|
||||
"""
|
||||
if self.current_a2a_conversation_branch and isinstance(
|
||||
self.current_a2a_conversation_branch, Tree
|
||||
):
|
||||
if status == "completed":
|
||||
if self._pending_a2a_message and self._pending_a2a_agent_role:
|
||||
if total_turns in self._a2a_turn_branches:
|
||||
turn_branch = self._a2a_turn_branches[total_turns]
|
||||
else:
|
||||
turn_branch = self.current_a2a_conversation_branch.add("")
|
||||
self.update_tree_label(
|
||||
turn_branch,
|
||||
"💬",
|
||||
f"Turn {total_turns}",
|
||||
"cyan",
|
||||
)
|
||||
self._a2a_turn_branches[total_turns] = turn_branch
|
||||
|
||||
crewai_agent_role = self._pending_a2a_agent_role
|
||||
message_content = self._pending_a2a_message
|
||||
|
||||
message_preview = (
|
||||
message_content[:100] + "..."
|
||||
if len(message_content) > 100
|
||||
else message_content
|
||||
)
|
||||
|
||||
user_node = turn_branch.add("")
|
||||
self.update_tree_label(
|
||||
user_node,
|
||||
f"{crewai_agent_role} 👤 : ",
|
||||
f'"{message_preview}"',
|
||||
"green",
|
||||
)
|
||||
|
||||
self._pending_a2a_message = None
|
||||
self._pending_a2a_agent_role = None
|
||||
self._pending_a2a_turn_number = None
|
||||
elif status == "failed":
|
||||
error_turn = self.current_a2a_conversation_branch.add("")
|
||||
error_msg = error[:150] + "..." if error and len(error) > 150 else error
|
||||
self.update_tree_label(
|
||||
error_turn,
|
||||
"❌",
|
||||
f"Failed: {error_msg}" if error else "Conversation Failed",
|
||||
"red",
|
||||
)
|
||||
|
||||
tree_to_use = self.current_crew_tree or self.current_task_branch
|
||||
if tree_to_use:
|
||||
self.print(tree_to_use)
|
||||
self.print()
|
||||
|
||||
self.current_a2a_conversation_branch = None
|
||||
self.current_a2a_turn_count = 0
|
||||
|
||||
# ----------- MCP EVENTS -----------
|
||||
|
||||
def handle_mcp_connection_started(
|
||||
self,
|
||||
server_name: str,
|
||||
server_url: str | None = None,
|
||||
transport_type: str | None = None,
|
||||
is_reconnect: bool = False,
|
||||
connect_timeout: int | None = None,
|
||||
) -> None:
|
||||
"""Handle MCP connection started event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = Text()
|
||||
reconnect_text = " (Reconnecting)" if is_reconnect else ""
|
||||
content.append(f"MCP Connection Started{reconnect_text}\n\n", style="cyan bold")
|
||||
content.append("Server: ", style="white")
|
||||
content.append(f"{server_name}\n", style="cyan")
|
||||
|
||||
if server_url:
|
||||
content.append("URL: ", style="white")
|
||||
content.append(f"{server_url}\n", style="cyan dim")
|
||||
|
||||
if transport_type:
|
||||
content.append("Transport: ", style="white")
|
||||
content.append(f"{transport_type}\n", style="cyan")
|
||||
|
||||
if connect_timeout:
|
||||
content.append("Timeout: ", style="white")
|
||||
content.append(f"{connect_timeout}s\n", style="cyan")
|
||||
|
||||
panel = self.create_panel(content, "🔌 MCP Connection", "cyan")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_mcp_connection_completed(
|
||||
self,
|
||||
server_name: str,
|
||||
server_url: str | None = None,
|
||||
transport_type: str | None = None,
|
||||
connection_duration_ms: float | None = None,
|
||||
is_reconnect: bool = False,
|
||||
) -> None:
|
||||
"""Handle MCP connection completed event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = Text()
|
||||
reconnect_text = " (Reconnected)" if is_reconnect else ""
|
||||
content.append(
|
||||
f"MCP Connection Completed{reconnect_text}\n\n", style="green bold"
|
||||
)
|
||||
content.append("Server: ", style="white")
|
||||
content.append(f"{server_name}\n", style="green")
|
||||
|
||||
if server_url:
|
||||
content.append("URL: ", style="white")
|
||||
content.append(f"{server_url}\n", style="green dim")
|
||||
|
||||
if transport_type:
|
||||
content.append("Transport: ", style="white")
|
||||
content.append(f"{transport_type}\n", style="green")
|
||||
|
||||
if connection_duration_ms is not None:
|
||||
content.append("Duration: ", style="white")
|
||||
content.append(f"{connection_duration_ms:.2f}ms\n", style="green")
|
||||
|
||||
panel = self.create_panel(content, "✅ MCP Connected", "green")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_mcp_connection_failed(
|
||||
self,
|
||||
server_name: str,
|
||||
server_url: str | None = None,
|
||||
transport_type: str | None = None,
|
||||
error: str = "",
|
||||
error_type: str | None = None,
|
||||
) -> None:
|
||||
"""Handle MCP connection failed event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = Text()
|
||||
content.append("MCP Connection Failed\n\n", style="red bold")
|
||||
content.append("Server: ", style="white")
|
||||
content.append(f"{server_name}\n", style="red")
|
||||
|
||||
if server_url:
|
||||
content.append("URL: ", style="white")
|
||||
content.append(f"{server_url}\n", style="red dim")
|
||||
|
||||
if transport_type:
|
||||
content.append("Transport: ", style="white")
|
||||
content.append(f"{transport_type}\n", style="red")
|
||||
|
||||
if error_type:
|
||||
content.append("Error Type: ", style="white")
|
||||
content.append(f"{error_type}\n", style="red")
|
||||
|
||||
if error:
|
||||
content.append("\nError: ", style="white bold")
|
||||
error_preview = error[:500] + "..." if len(error) > 500 else error
|
||||
content.append(f"{error_preview}\n", style="red")
|
||||
|
||||
panel = self.create_panel(content, "❌ MCP Connection Failed", "red")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_mcp_tool_execution_started(
|
||||
self,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Handle MCP tool execution started event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = self.create_status_content(
|
||||
"MCP Tool Execution Started",
|
||||
tool_name,
|
||||
"yellow",
|
||||
tool_args=tool_args or {},
|
||||
Server=server_name,
|
||||
)
|
||||
|
||||
panel = self.create_panel(content, "🔧 MCP Tool", "yellow")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_mcp_tool_execution_completed(
|
||||
self,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
result: Any | None = None,
|
||||
execution_duration_ms: float | None = None,
|
||||
) -> None:
|
||||
"""Handle MCP tool execution completed event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = self.create_status_content(
|
||||
"MCP Tool Execution Completed",
|
||||
tool_name,
|
||||
"green",
|
||||
tool_args=tool_args or {},
|
||||
Server=server_name,
|
||||
)
|
||||
|
||||
if execution_duration_ms is not None:
|
||||
content.append("Duration: ", style="white")
|
||||
content.append(f"{execution_duration_ms:.2f}ms\n", style="green")
|
||||
|
||||
if result is not None:
|
||||
result_str = str(result)
|
||||
if len(result_str) > 500:
|
||||
result_str = result_str[:497] + "..."
|
||||
content.append("\nResult: ", style="white bold")
|
||||
content.append(f"{result_str}\n", style="green")
|
||||
|
||||
panel = self.create_panel(content, "✅ MCP Tool Completed", "green")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_mcp_tool_execution_failed(
|
||||
self,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
error: str = "",
|
||||
error_type: str | None = None,
|
||||
) -> None:
|
||||
"""Handle MCP tool execution failed event."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = self.create_status_content(
|
||||
"MCP Tool Execution Failed",
|
||||
tool_name,
|
||||
"red",
|
||||
tool_args=tool_args or {},
|
||||
Server=server_name,
|
||||
)
|
||||
|
||||
if error_type:
|
||||
content.append("Error Type: ", style="white")
|
||||
content.append(f"{error_type}\n", style="red")
|
||||
|
||||
if error:
|
||||
content.append("\nError: ", style="white bold")
|
||||
error_preview = error[:500] + "..." if len(error) > 500 else error
|
||||
content.append(f"{error_preview}\n", style="red")
|
||||
|
||||
panel = self.create_panel(content, "❌ MCP Tool Failed", "red")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
"""A2A (Agent-to-Agent) Protocol adapter for CrewAI.
|
||||
|
||||
This module provides integration with A2A protocol-compliant agents,
|
||||
enabling CrewAI to orchestrate external agents like ServiceNow, Bedrock Agents,
|
||||
Glean, and other A2A-compliant systems.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from crewai.experimental.a2a import A2AAgentAdapter
|
||||
|
||||
# Create A2A agent
|
||||
servicenow_agent = A2AAgentAdapter(
|
||||
agent_card_url="https://servicenow.example.com/.well-known/agent-card.json",
|
||||
auth_token="your-token",
|
||||
role="ServiceNow Incident Manager",
|
||||
goal="Create and manage IT incidents",
|
||||
backstory="Expert at incident management",
|
||||
)
|
||||
|
||||
# Use in crew
|
||||
crew = Crew(agents=[servicenow_agent], tasks=[task])
|
||||
```
|
||||
"""
|
||||
|
||||
from crewai.experimental.a2a.a2a_adapter import A2AAgentAdapter
|
||||
from crewai.experimental.a2a.auth import (
|
||||
APIKeyAuth,
|
||||
AuthScheme,
|
||||
BearerTokenAuth,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
create_auth_from_agent_card,
|
||||
)
|
||||
from crewai.experimental.a2a.exceptions import (
|
||||
A2AAuthenticationError,
|
||||
A2AConfigurationError,
|
||||
A2AConnectionError,
|
||||
A2AError,
|
||||
A2AInputRequiredError,
|
||||
A2ATaskCanceledError,
|
||||
A2ATaskFailedError,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"A2AAgentAdapter",
|
||||
"A2AAuthenticationError",
|
||||
"A2AConfigurationError",
|
||||
"A2AConnectionError",
|
||||
"A2AError",
|
||||
"A2AInputRequiredError",
|
||||
"A2ATaskCanceledError",
|
||||
"A2ATaskFailedError",
|
||||
"APIKeyAuth",
|
||||
# Authentication
|
||||
"AuthScheme",
|
||||
"BearerTokenAuth",
|
||||
"HTTPBasicAuth",
|
||||
"HTTPDigestAuth",
|
||||
"OAuth2AuthorizationCode",
|
||||
"OAuth2ClientCredentials",
|
||||
"create_auth_from_agent_card",
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,424 +0,0 @@
|
||||
"""Authentication schemes for A2A protocol agents.
|
||||
|
||||
This module provides support for various authentication methods:
|
||||
- Bearer tokens (existing)
|
||||
- OAuth2 (Client Credentials, Authorization Code)
|
||||
- API Keys (header, query, cookie)
|
||||
- HTTP Basic authentication
|
||||
- HTTP Digest authentication
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import base64
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard
|
||||
|
||||
|
||||
class AuthScheme(ABC, BaseModel):
|
||||
"""Base class for authentication schemes."""
|
||||
|
||||
@abstractmethod
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
"""Apply authentication to request headers.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with authentication applied.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure the HTTP client for this auth scheme.
|
||||
|
||||
Args:
|
||||
client: HTTP client to configure.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BearerTokenAuth(AuthScheme):
|
||||
"""Bearer token authentication (Authorization: Bearer <token>)."""
|
||||
|
||||
token: str = Field(description="Bearer token")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
"""Apply Bearer token to Authorization header."""
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""No client configuration needed for Bearer tokens."""
|
||||
|
||||
|
||||
class HTTPBasicAuth(AuthScheme):
|
||||
"""HTTP Basic authentication."""
|
||||
|
||||
username: str = Field(description="Username")
|
||||
password: str = Field(description="Password")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
"""Apply HTTP Basic authentication."""
|
||||
credentials = f"{self.username}:{self.password}"
|
||||
encoded = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded}"
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""No client configuration needed for Basic auth."""
|
||||
|
||||
|
||||
class HTTPDigestAuth(AuthScheme):
|
||||
"""HTTP Digest authentication.
|
||||
|
||||
Note: Uses httpx-auth library for proper digest implementation.
|
||||
"""
|
||||
|
||||
username: str = Field(description="Username")
|
||||
password: str = Field(description="Password")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
"""Digest auth is handled by httpx auth flow, not headers."""
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client with Digest auth."""
|
||||
try:
|
||||
from httpx_auth import DigestAuth # type: ignore[import-not-found]
|
||||
|
||||
client.auth = DigestAuth(self.username, self.password) # type: ignore[import-not-found]
|
||||
except ImportError as e:
|
||||
msg = "httpx-auth required for Digest authentication. Install with: pip install httpx-auth"
|
||||
raise ImportError(msg) from e
|
||||
|
||||
|
||||
class APIKeyAuth(AuthScheme):
|
||||
"""API Key authentication (header, query, or cookie)."""
|
||||
|
||||
api_key: str = Field(description="API key value")
|
||||
location: Literal["header", "query", "cookie"] = Field(
|
||||
default="header", description="Where to send the API key"
|
||||
)
|
||||
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
"""Apply API key authentication."""
|
||||
if self.location == "header":
|
||||
headers[self.name] = self.api_key
|
||||
elif self.location == "cookie":
|
||||
headers["Cookie"] = f"{self.name}={self.api_key}"
|
||||
# Query params are handled in configure_client via event hooks
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client for query param API keys."""
|
||||
if self.location == "query":
|
||||
# Add API key to all requests via event hook
|
||||
async def add_api_key_param(request: httpx.Request) -> None:
|
||||
url = httpx.URL(request.url)
|
||||
request.url = url.copy_add_param(self.name, self.api_key)
|
||||
|
||||
client.event_hooks["request"].append(add_api_key_param)
|
||||
|
||||
|
||||
class OAuth2ClientCredentials(AuthScheme):
|
||||
"""OAuth2 Client Credentials flow authentication."""
|
||||
|
||||
token_url: str = Field(description="OAuth2 token endpoint")
|
||||
client_id: str = Field(description="OAuth2 client ID")
|
||||
client_secret: str = Field(description="OAuth2 client secret")
|
||||
scopes: list[str] = Field(
|
||||
default_factory=list, description="Required OAuth2 scopes"
|
||||
)
|
||||
|
||||
_access_token: str | None = None
|
||||
_token_expires_at: float | None = None
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header."""
|
||||
# Get or refresh token if needed
|
||||
import time
|
||||
|
||||
if (
|
||||
self._access_token is None
|
||||
or self._token_expires_at is None
|
||||
or time.time() >= self._token_expires_at
|
||||
):
|
||||
await self._fetch_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Fetch OAuth2 access token using client credentials flow."""
|
||||
import time
|
||||
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
}
|
||||
|
||||
if self.scopes:
|
||||
data["scope"] = " ".join(self.scopes)
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
|
||||
# Calculate expiration time (default to 3600 seconds if not provided)
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60 # 60s buffer
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""No client configuration needed for OAuth2."""
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(AuthScheme):
|
||||
"""OAuth2 Authorization Code flow authentication.
|
||||
|
||||
Note: This requires interactive authorization and is typically used
|
||||
for user-facing applications. For server-to-server, use ClientCredentials.
|
||||
"""
|
||||
|
||||
authorization_url: str = Field(description="OAuth2 authorization endpoint")
|
||||
token_url: str = Field(description="OAuth2 token endpoint")
|
||||
client_id: str = Field(description="OAuth2 client ID")
|
||||
client_secret: str = Field(description="OAuth2 client secret")
|
||||
redirect_uri: str = Field(description="OAuth2 redirect URI")
|
||||
scopes: list[str] = Field(
|
||||
default_factory=list, description="Required OAuth2 scopes"
|
||||
)
|
||||
|
||||
_access_token: str | None = None
|
||||
_refresh_token: str | None = None
|
||||
_token_expires_at: float | None = None
|
||||
_authorization_callback: Callable[[str], Awaitable[str]] | None = None
|
||||
|
||||
def set_authorization_callback(
|
||||
self, callback: Callable[[str], Awaitable[str]] | None
|
||||
) -> None:
|
||||
"""Set callback to handle authorization URL.
|
||||
|
||||
The callback receives the authorization URL and should return
|
||||
the authorization code after user completes the flow.
|
||||
"""
|
||||
self._authorization_callback = callback
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header."""
|
||||
import time
|
||||
|
||||
# Get or refresh token if needed
|
||||
if self._access_token is None:
|
||||
if self._authorization_callback is None:
|
||||
msg = "Authorization callback not set. Use set_authorization_callback()"
|
||||
raise ValueError(msg)
|
||||
await self._fetch_initial_token(client)
|
||||
elif self._token_expires_at and time.time() >= self._token_expires_at:
|
||||
await self._refresh_access_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Fetch initial access token using authorization code flow."""
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
# Build authorization URL
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": " ".join(self.scopes),
|
||||
}
|
||||
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
# Get authorization code from callback
|
||||
if self._authorization_callback is None:
|
||||
msg = "Authorization callback not set"
|
||||
raise ValueError(msg)
|
||||
auth_code = await self._authorization_callback(auth_url)
|
||||
|
||||
# Exchange code for token
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
self._refresh_token = token_data.get("refresh_token")
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
|
||||
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Refresh the access token using refresh token."""
|
||||
import time
|
||||
|
||||
if not self._refresh_token:
|
||||
# Re-authorize if no refresh token
|
||||
await self._fetch_initial_token(client)
|
||||
return
|
||||
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self._refresh_token,
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
}
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
if "refresh_token" in token_data:
|
||||
self._refresh_token = token_data["refresh_token"]
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""No client configuration needed for OAuth2."""
|
||||
|
||||
|
||||
def create_auth_from_agent_card(
|
||||
agent_card: AgentCard, credentials: dict[str, Any]
|
||||
) -> AuthScheme | None:
|
||||
"""Create an appropriate authentication scheme from AgentCard security config.
|
||||
|
||||
Args:
|
||||
agent_card: The A2A AgentCard containing security requirements.
|
||||
credentials: User-provided credentials (passwords, tokens, keys, etc.).
|
||||
|
||||
Returns:
|
||||
Configured AuthScheme, or None if no authentication required.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# For OAuth2
|
||||
credentials = {
|
||||
"client_id": "my-app",
|
||||
"client_secret": "secret123",
|
||||
}
|
||||
auth = create_auth_from_agent_card(agent_card, credentials)
|
||||
|
||||
# For API Key
|
||||
credentials = {"api_key": "key-12345"}
|
||||
auth = create_auth_from_agent_card(agent_card, credentials)
|
||||
|
||||
# For HTTP Basic
|
||||
credentials = {"username": "user", "password": "pass"}
|
||||
auth = create_auth_from_agent_card(agent_card, credentials)
|
||||
```
|
||||
"""
|
||||
if not agent_card.security or not agent_card.security_schemes:
|
||||
return None
|
||||
|
||||
# Get the first required security scheme
|
||||
first_security_req = agent_card.security[0] if agent_card.security else {}
|
||||
|
||||
for scheme_name, _scopes in first_security_req.items():
|
||||
security_scheme_obj = agent_card.security_schemes.get(scheme_name)
|
||||
if not security_scheme_obj:
|
||||
continue
|
||||
|
||||
# SecurityScheme is a dict-like object
|
||||
security_scheme = dict(security_scheme_obj) # type: ignore[arg-type]
|
||||
scheme_type = str(security_scheme.get("type", "")).lower()
|
||||
|
||||
# OAuth2
|
||||
if scheme_type == "oauth2":
|
||||
flows = security_scheme.get("flows", {})
|
||||
|
||||
if "clientCredentials" in flows:
|
||||
flow = flows["clientCredentials"]
|
||||
return OAuth2ClientCredentials(
|
||||
token_url=str(flow["tokenUrl"]),
|
||||
client_id=str(credentials.get("client_id", "")),
|
||||
client_secret=str(credentials.get("client_secret", "")),
|
||||
scopes=list(flow.get("scopes", {}).keys()),
|
||||
)
|
||||
|
||||
if "authorizationCode" in flows:
|
||||
flow = flows["authorizationCode"]
|
||||
return OAuth2AuthorizationCode(
|
||||
authorization_url=str(flow["authorizationUrl"]),
|
||||
token_url=str(flow["tokenUrl"]),
|
||||
client_id=str(credentials.get("client_id", "")),
|
||||
client_secret=str(credentials.get("client_secret", "")),
|
||||
redirect_uri=str(credentials.get("redirect_uri", "")),
|
||||
scopes=list(flow.get("scopes", {}).keys()),
|
||||
)
|
||||
|
||||
# API Key
|
||||
elif scheme_type == "apikey":
|
||||
location = str(security_scheme.get("in", "header"))
|
||||
name = str(security_scheme.get("name", "X-API-Key"))
|
||||
return APIKeyAuth(
|
||||
api_key=str(credentials.get("api_key", "")),
|
||||
location=location, # type: ignore[arg-type]
|
||||
name=name,
|
||||
)
|
||||
|
||||
# HTTP Auth
|
||||
elif scheme_type == "http":
|
||||
http_scheme = str(security_scheme.get("scheme", "")).lower()
|
||||
|
||||
if http_scheme == "basic":
|
||||
return HTTPBasicAuth(
|
||||
username=str(credentials.get("username", "")),
|
||||
password=str(credentials.get("password", "")),
|
||||
)
|
||||
|
||||
if http_scheme == "digest":
|
||||
return HTTPDigestAuth(
|
||||
username=str(credentials.get("username", "")),
|
||||
password=str(credentials.get("password", "")),
|
||||
)
|
||||
|
||||
if http_scheme == "bearer":
|
||||
return BearerTokenAuth(token=str(credentials.get("token", "")))
|
||||
|
||||
return None
|
||||
@@ -1,56 +0,0 @@
|
||||
"""Custom exceptions for A2A Agent Adapter."""
|
||||
|
||||
|
||||
class A2AError(Exception):
|
||||
"""Base exception for A2A adapter errors."""
|
||||
|
||||
|
||||
class A2ATaskFailedError(A2AError):
|
||||
"""Raised when A2A agent task fails or is rejected.
|
||||
|
||||
This exception is raised when the A2A agent reports a task
|
||||
in the 'failed' or 'rejected' state.
|
||||
"""
|
||||
|
||||
|
||||
class A2AInputRequiredError(A2AError):
|
||||
"""Raised when A2A agent requires additional input.
|
||||
|
||||
This exception is raised when the A2A agent reports a task
|
||||
in the 'input_required' state, indicating that it needs more
|
||||
information to complete the task.
|
||||
"""
|
||||
|
||||
|
||||
class A2AConfigurationError(A2AError):
|
||||
"""Raised when A2A adapter configuration is invalid.
|
||||
|
||||
This exception is raised during initialization or setup when
|
||||
the adapter configuration is invalid or incompatible.
|
||||
"""
|
||||
|
||||
|
||||
class A2AConnectionError(A2AError):
|
||||
"""Raised when connection to A2A agent fails.
|
||||
|
||||
This exception is raised when the adapter cannot establish
|
||||
a connection to the A2A agent or when network errors occur.
|
||||
"""
|
||||
|
||||
|
||||
class A2AAuthenticationError(A2AError):
|
||||
"""Raised when A2A agent requires authentication.
|
||||
|
||||
This exception is raised when the A2A agent reports a task
|
||||
in the 'auth_required' state, indicating that authentication
|
||||
is needed before the task can continue.
|
||||
"""
|
||||
|
||||
|
||||
class A2ATaskCanceledError(A2AError):
|
||||
"""Raised when A2A task is canceled.
|
||||
|
||||
This exception is raised when the A2A agent reports a task
|
||||
in the 'canceled' state, indicating the task was canceled
|
||||
either by the user or the system.
|
||||
"""
|
||||
@@ -1,56 +0,0 @@
|
||||
"""Type protocols for A2A SDK components.
|
||||
|
||||
These protocols define the expected interfaces for A2A SDK types,
|
||||
allowing for type checking without requiring the SDK to be installed.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentCardProtocol(Protocol):
|
||||
"""Protocol for A2A AgentCard."""
|
||||
|
||||
name: str
|
||||
version: str
|
||||
description: str
|
||||
skills: list[Any]
|
||||
capabilities: Any
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ClientProtocol(Protocol):
|
||||
"""Protocol for A2A Client."""
|
||||
|
||||
async def send_message(self, message: Any) -> AsyncIterator[Any]:
|
||||
"""Send message to A2A agent."""
|
||||
...
|
||||
|
||||
async def get_card(self) -> AgentCardProtocol:
|
||||
"""Get agent card."""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close client connection."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MessageProtocol(Protocol):
|
||||
"""Protocol for A2A Message."""
|
||||
|
||||
role: Any
|
||||
message_id: str
|
||||
parts: list[Any]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TaskProtocol(Protocol):
|
||||
"""Protocol for A2A Task."""
|
||||
|
||||
id: str
|
||||
context_id: str
|
||||
status: Any
|
||||
history: list[Any] | None
|
||||
artifacts: list[Any] | None
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Sequence
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
|
||||
@@ -1,5 +1,21 @@
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.flow.persistence import persist
|
||||
from crewai.flow.visualization import (
|
||||
FlowStructure,
|
||||
build_flow_structure,
|
||||
visualize_flow_structure,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]
|
||||
__all__ = [
|
||||
"Flow",
|
||||
"FlowStructure",
|
||||
"and_",
|
||||
"build_flow_structure",
|
||||
"listen",
|
||||
"or_",
|
||||
"persist",
|
||||
"router",
|
||||
"start",
|
||||
"visualize_flow_structure",
|
||||
]
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>{{ title }}</title>
|
||||
<script
|
||||
src="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/vis-network.min.js"
|
||||
integrity="sha512-LnvoEWDFrqGHlHmDD2101OrLcbsfkrzoSpvtSQtxK3RMnRV0eOkhhBN2dXHKRrUU8p2DGRTk35n4O8nWSVe1mQ=="
|
||||
crossorigin="anonymous"
|
||||
referrerpolicy="no-referrer"
|
||||
></script>
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/dist/vis-network.min.css"
|
||||
integrity="sha512-WgxfT5LWjfszlPHXRmBWHkV2eceiWTOBvrKCNbdgDYTHrT2AeLCGbF4sZlZw3UMN3WtL0tGUoIAKsu8mllg/XA=="
|
||||
crossorigin="anonymous"
|
||||
referrerpolicy="no-referrer"
|
||||
/>
|
||||
<style type="text/css">
|
||||
body {
|
||||
font-family: verdana;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
.container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100vh;
|
||||
}
|
||||
#mynetwork {
|
||||
flex-grow: 1;
|
||||
width: 100%;
|
||||
height: 750px;
|
||||
background-color: #ffffff;
|
||||
}
|
||||
.card {
|
||||
border: none;
|
||||
}
|
||||
.legend-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 10px;
|
||||
background-color: #f8f9fa;
|
||||
position: fixed; /* Make the legend fixed */
|
||||
bottom: 0; /* Position it at the bottom */
|
||||
width: 100%; /* Make it span the full width */
|
||||
}
|
||||
.legend-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-right: 20px;
|
||||
}
|
||||
.legend-color-box {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
margin-right: 5px;
|
||||
}
|
||||
.logo {
|
||||
height: 50px;
|
||||
margin-right: 20px;
|
||||
}
|
||||
.legend-dashed {
|
||||
border-bottom: 2px dashed #666666;
|
||||
width: 20px;
|
||||
height: 0;
|
||||
margin-right: 5px;
|
||||
}
|
||||
.legend-solid {
|
||||
border-bottom: 2px solid #666666;
|
||||
width: 20px;
|
||||
height: 0;
|
||||
margin-right: 5px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="card" style="width: 100%">
|
||||
<div id="mynetwork" class="card-body"></div>
|
||||
</div>
|
||||
<div class="legend-container">
|
||||
<img
|
||||
src="data:image/svg+xml;base64,{{ logo_svg_base64 }}"
|
||||
alt="CrewAI logo"
|
||||
class="logo"
|
||||
/>
|
||||
<!-- LEGEND_ITEMS_PLACEHOLDER -->
|
||||
</div>
|
||||
</div>
|
||||
{{ network_content }}
|
||||
</body>
|
||||
</html>
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 27 KiB |
4
lib/crewai/src/crewai/flow/constants.py
Normal file
4
lib/crewai/src/crewai/flow/constants.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from typing import Final, Literal
|
||||
|
||||
AND_CONDITION: Final[Literal["AND"]] = "AND"
|
||||
OR_CONDITION: Final[Literal["OR"]] = "OR"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user