Gl/feat/a2a refactor (#3793)

* feat: agent metaclass, refactor a2a to wrappers

* feat: a2a schemas and utils

* chore: move agent class, update imports

* refactor: organize imports to avoid circularity, add a2a to console

* feat: pass response_model through call chain

* feat: add standard openapi spec serialization to tools and structured output

* feat: a2a events

* chore: add a2a to pyproject

* docs: minimal base for learn docs

* fix: adjust a2a conversation flow, allow llm to decide exit until max_retries

* fix: inject agent skills into initial prompt

* fix: format agent card as json in prompt

* refactor: simplify A2A agent prompt formatting and improve skill display

* chore: wide cleanup

* chore: cleanup logic, add auth cache, use json for messages in prompt

* chore: update docs

* fix: doc snippets formatting

* feat: optimize A2A agent card fetching and improve error reporting

* chore: move imports to top of file

* chore: refactor hasattr check

* chore: add httpx-auth, update lockfile

* feat: create base public api

* chore: cleanup modules, add docstrings, types

* fix: exclude extra fields in prompt

* chore: update docs

* tests: update to correct import

* chore: lint for ruff, add missing import

* fix: tweak openai streaming logic for response model

* tests: add reimport for test

* tests: add reimport for test

* fix: don't set a2a attr if not set

* fix: don't set a2a attr if not set

* chore: update cassettes

* tests: fix tests

* fix: use instructor and don't pass response_format for litellm

* chore: consolidate event listeners, add typing

* fix: address race condition in test, update cassettes

* tests: add correct mocks, rerun cassette for json

* tests: update cassette

* chore: regenerate cassette after new run

* fix: make token manager access-safe

* fix: make token manager access-safe

* merge

* chore: update test and cassette for output pydantic

* fix: tweak to disallow deadlock

* chore: linter

* fix: adjust event ordering for threading

* fix: use conditional for batch check

* tests: tweak for emission

* tests: simplify api + event check

* fix: ensure non-function calling llms see json formatted string

* tests: tweak message comparison

* fix: use internal instructor for litellm structure responses

---------

Co-authored-by: Mike Plachta <mike@crewai.com>
Author: Greyson LaLonde
Date: 2025-11-01 02:42:03 +01:00
Committed by: GitHub
Parent: e229ef4e19
Commit: e134e5305b
71 changed files with 9790 additions and 4592 deletions

View File

@@ -3,7 +3,7 @@ from typing import Any
 import urllib.request
 import warnings
-from crewai.agent import Agent
+from crewai.agent.core import Agent
 from crewai.crew import Crew
 from crewai.crews.crew_output import CrewOutput
 from crewai.flow.flow import Flow

View File

@@ -0,0 +1,6 @@
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""
from crewai.a2a.config import A2AConfig
__all__ = ["A2AConfig"]

View File

@@ -0,0 +1,20 @@
"""A2A authentication schemas."""
from crewai.a2a.auth.schemas import (
APIKeyAuth,
BearerTokenAuth,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
)
__all__ = [
"APIKeyAuth",
"BearerTokenAuth",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
]

View File

@@ -0,0 +1,392 @@
"""Authentication schemes for A2A protocol agents.
Supported authentication methods:
- Bearer tokens
- OAuth2 (Client Credentials, Authorization Code)
- API Keys (header, query, cookie)
- HTTP Basic authentication
- HTTP Digest authentication
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import base64
from collections.abc import Awaitable, Callable, MutableMapping
import time
from typing import Literal
import urllib.parse
import httpx
from httpx import DigestAuth
from pydantic import BaseModel, Field, PrivateAttr
class AuthScheme(ABC, BaseModel):
"""Base class for authentication schemes."""
@abstractmethod
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply authentication to request headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with authentication applied.
"""
...
class BearerTokenAuth(AuthScheme):
"""Bearer token authentication (Authorization: Bearer <token>).
Attributes:
token: Bearer token for authentication.
"""
token: str = Field(description="Bearer token")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply Bearer token to Authorization header.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Bearer token in Authorization header.
"""
headers["Authorization"] = f"Bearer {self.token}"
return headers
class HTTPBasicAuth(AuthScheme):
"""HTTP Basic authentication.
Attributes:
username: Username for Basic authentication.
password: Password for Basic authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply HTTP Basic authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Basic auth in Authorization header.
"""
credentials = f"{self.username}:{self.password}"
encoded = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded}"
return headers
class HTTPDigestAuth(AuthScheme):
"""HTTP Digest authentication.
Note: Uses httpx-auth library for digest implementation.
Attributes:
username: Username for Digest authentication.
password: Password for Digest authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Digest auth is handled by httpx auth flow, not headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Unchanged headers (Digest auth handled by httpx auth flow).
"""
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client with Digest auth.
Args:
client: HTTP client to configure with Digest authentication.
"""
client.auth = DigestAuth(self.username, self.password)
class APIKeyAuth(AuthScheme):
"""API Key authentication (header, query, or cookie).
Attributes:
api_key: API key value for authentication.
location: Where to send the API key (header, query, or cookie).
name: Parameter name for the API key (default: X-API-Key).
"""
api_key: str = Field(description="API key value")
location: Literal["header", "query", "cookie"] = Field(
default="header", description="Where to send the API key"
)
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply API key authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with API key (for header/cookie locations).
"""
if self.location == "header":
headers[self.name] = self.api_key
elif self.location == "cookie":
headers["Cookie"] = f"{self.name}={self.api_key}"
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client for query param API keys.
Args:
client: HTTP client to configure with query param API key hook.
"""
if self.location == "query":
async def _add_api_key_param(request: httpx.Request) -> None:
url = httpx.URL(request.url)
request.url = url.copy_add_param(self.name, self.api_key)
client.event_hooks["request"].append(_add_api_key_param)
class OAuth2ClientCredentials(AuthScheme):
"""OAuth2 Client Credentials flow authentication.
Attributes:
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
scopes: List of required OAuth2 scopes.
"""
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
"""
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
"""Fetch OAuth2 access token using client credentials flow.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token request fails.
"""
data = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
}
if self.scopes:
data["scope"] = " ".join(self.scopes)
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
class OAuth2AuthorizationCode(AuthScheme):
"""OAuth2 Authorization Code flow authentication.
Note: Requires interactive authorization.
Attributes:
authorization_url: OAuth2 authorization endpoint URL.
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
redirect_uri: OAuth2 redirect URI for callback.
scopes: List of required OAuth2 scopes.
"""
authorization_url: str = Field(description="OAuth2 authorization endpoint")
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
redirect_uri: str = Field(description="OAuth2 redirect URI")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_refresh_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
default=None
)
def set_authorization_callback(
self, callback: Callable[[str], Awaitable[str]] | None
) -> None:
"""Set callback to handle authorization URL.
Args:
callback: Async function that receives authorization URL and returns auth code.
"""
self._authorization_callback = callback
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
Raises:
ValueError: If authorization callback is not set.
"""
if self._access_token is None:
if self._authorization_callback is None:
msg = "Authorization callback not set. Use set_authorization_callback()"
raise ValueError(msg)
await self._fetch_initial_token(client)
elif self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
"""Fetch initial access token using authorization code flow.
Args:
client: HTTP client for making token request.
Raises:
ValueError: If authorization callback is not set.
httpx.HTTPStatusError: If token request fails.
"""
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(self.scopes),
}
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
if self._authorization_callback is None:
msg = "Authorization callback not set"
raise ValueError(msg)
auth_code = await self._authorization_callback(auth_url)
data = {
"grant_type": "authorization_code",
"code": auth_code,
"client_id": self.client_id,
"client_secret": self.client_secret,
"redirect_uri": self.redirect_uri,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
self._refresh_token = token_data.get("refresh_token")
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
"""Refresh the access token using refresh token.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token refresh request fails.
"""
if not self._refresh_token:
await self._fetch_initial_token(client)
return
data = {
"grant_type": "refresh_token",
"refresh_token": self._refresh_token,
"client_id": self.client_id,
"client_secret": self.client_secret,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
if "refresh_token" in token_data:
self._refresh_token = token_data["refresh_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
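
As a usage sketch (not part of this diff; the token, endpoint, and scope values are placeholders), the schemes above are applied to a headers mapping like so:

import asyncio

import httpx

from crewai.a2a.auth.schemas import BearerTokenAuth, OAuth2ClientCredentials


async def main() -> None:
    async with httpx.AsyncClient() as client:
        # Bearer tokens mutate the headers mapping directly.
        headers = await BearerTokenAuth(token="example-token").apply_auth(client, {})
        assert headers["Authorization"] == "Bearer example-token"

        # Client-credentials auth fetches a token on first use and caches it
        # until shortly before expiry; all values here are illustrative.
        oauth = OAuth2ClientCredentials(
            token_url="https://auth.example.com/token",
            client_id="my-client",
            client_secret="my-secret",
            scopes=["agent:invoke"],
        )
        headers = await oauth.apply_auth(client, {})  # performs the token POST


asyncio.run(main())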

View File

@@ -0,0 +1,236 @@
"""Authentication utilities for A2A protocol agent communication.
Provides validation and retry logic for various authentication schemes including
OAuth2, API keys, and HTTP authentication methods.
"""
import asyncio
from collections.abc import Awaitable, Callable, MutableMapping
import re
from typing import Final
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
APIKeySecurityScheme,
AgentCard,
HTTPAuthSecurityScheme,
OAuth2SecurityScheme,
)
from httpx import AsyncClient, Response
from crewai.a2a.auth.schemas import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
)
_auth_store: dict[int, AuthScheme | None] = {}
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
OAuth2SecurityScheme: (
OAuth2ClientCredentials,
OAuth2AuthorizationCode,
BearerTokenAuth,
),
APIKeySecurityScheme: (APIKeyAuth,),
}
_HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
"basic": HTTPBasicAuth,
"digest": HTTPDigestAuth,
"bearer": BearerTokenAuth,
}
def _raise_auth_mismatch(
expected_classes: type[AuthScheme] | tuple[type[AuthScheme], ...],
provided_auth: AuthScheme,
) -> None:
"""Raise authentication mismatch error.
Args:
expected_classes: Expected authentication class or tuple of classes.
provided_auth: Actually provided authentication instance.
Raises:
A2AClientHTTPError: Always raises with 401 status code.
"""
if isinstance(expected_classes, tuple):
if len(expected_classes) == 1:
required = expected_classes[0].__name__
else:
names = [cls.__name__ for cls in expected_classes]
required = f"one of ({', '.join(names)})"
else:
required = expected_classes.__name__
msg = (
f"AgentCard requires {required} authentication, "
f"but {type(provided_auth).__name__} was provided"
)
raise A2AClientHTTPError(401, msg)
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
"""Parse WWW-Authenticate header into auth challenges.
Args:
header_value: The WWW-Authenticate header value.
Returns:
Dictionary mapping auth scheme to its parameters.
Example: {"Bearer": {"realm": "api", "scope": "read write"}}
"""
if not header_value:
return {}
challenges: dict[str, dict[str, str]] = {}
for match in _SCHEME_PATTERN.finditer(header_value):
scheme = match.group(1)
params_str = match.group(2)
params: dict[str, str] = {}
for param_match in _PARAM_PATTERN.finditer(params_str):
key = param_match.group(1)
value = param_match.group(2) or param_match.group(3)
params[key] = value
challenges[scheme] = params
return challenges
def validate_auth_against_agent_card(
agent_card: AgentCard, auth: AuthScheme | None
) -> None:
"""Validate that provided auth matches AgentCard security requirements.
Args:
agent_card: The A2A AgentCard containing security requirements.
auth: User-provided authentication scheme (or None).
Raises:
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
"""
if not agent_card.security or not agent_card.security_schemes:
return
if not auth:
msg = "AgentCard requires authentication but no auth scheme provided"
raise A2AClientHTTPError(401, msg)
first_security_req = agent_card.security[0] if agent_card.security else {}
for scheme_name in first_security_req.keys():
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
if not security_scheme_wrapper:
continue
scheme = security_scheme_wrapper.root
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
if not isinstance(auth, allowed_classes):
_raise_auth_mismatch(allowed_classes, auth)
return
if isinstance(scheme, HTTPAuthSecurityScheme):
if required_class := _HTTP_SCHEME_MAPPING.get(scheme.scheme.lower()):
if not isinstance(auth, required_class):
_raise_auth_mismatch(required_class, auth)
return
msg = "Could not validate auth against AgentCard security requirements"
raise A2AClientHTTPError(401, msg)
async def retry_on_401(
request_func: Callable[[], Awaitable[Response]],
auth_scheme: AuthScheme | None,
client: AsyncClient,
headers: MutableMapping[str, str],
max_retries: int = 3,
) -> Response:
"""Retry a request on 401 authentication error.
Handles 401 errors by:
1. Parsing WWW-Authenticate header
2. Re-acquiring credentials
3. Retrying the request
Args:
request_func: Async function that makes the HTTP request.
auth_scheme: Authentication scheme to refresh credentials with.
client: HTTP client for making requests.
headers: Request headers to update with new auth.
max_retries: Maximum number of retry attempts (default: 3).
Returns:
HTTP response from the request.
Raises:
httpx.HTTPStatusError: If retries are exhausted or auth scheme is None.
"""
last_response: Response | None = None
last_challenges: dict[str, dict[str, str]] = {}
for attempt in range(max_retries):
response = await request_func()
if response.status_code != 401:
return response
last_response = response
if auth_scheme is None:
response.raise_for_status()
return response
www_authenticate = response.headers.get("WWW-Authenticate", "")
challenges = parse_www_authenticate(www_authenticate)
last_challenges = challenges
if attempt >= max_retries - 1:
break
backoff_time = 2**attempt
await asyncio.sleep(backoff_time)
await auth_scheme.apply_auth(client, headers)
if last_response:
last_response.raise_for_status()
return last_response
msg = "retry_on_401 failed without making any requests"
if last_challenges:
challenge_info = ", ".join(
f"{scheme} (realm={params.get('realm', 'N/A')})"
for scheme, params in last_challenges.items()
)
msg = f"{msg}. Server challenges: {challenge_info}"
raise RuntimeError(msg)
def configure_auth_client(
auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient
) -> None:
"""Configure HTTP client with auth-specific settings.
Only HTTPDigestAuth and APIKeyAuth need client configuration.
Args:
auth: Authentication scheme that requires client configuration.
client: HTTP client to configure.
"""
auth.configure_client(client)
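
A quick sketch of the header parsing above, run on a made-up challenge value (not from this diff):

from crewai.a2a.auth.utils import parse_www_authenticate

challenges = parse_www_authenticate(
    'Bearer realm="api", error="invalid_token", Basic realm="fallback"'
)
# challenges == {
#     "Bearer": {"realm": "api", "error": "invalid_token"},
#     "Basic": {"realm": "fallback"},
# }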

View File

@@ -0,0 +1,59 @@
"""A2A configuration types.
This module is separate from experimental.a2a to avoid circular imports.
"""
from __future__ import annotations
from typing import Annotated
from pydantic import (
BaseModel,
BeforeValidator,
Field,
HttpUrl,
TypeAdapter,
)
from crewai.a2a.auth.schemas import AuthScheme
http_url_adapter = TypeAdapter(HttpUrl)
Url = Annotated[
str,
BeforeValidator(
lambda value: str(http_url_adapter.validate_python(value, strict=True))
),
]
class A2AConfig(BaseModel):
"""Configuration for A2A protocol integration.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme (Bearer, OAuth2, API Key, HTTP Basic/Digest).
timeout: Request timeout in seconds (default: 120).
max_turns: Maximum conversation turns with A2A agent (default: 10).
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue (default: True).
"""
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: AuthScheme | None = Field(
default=None,
description="Authentication scheme (Bearer, OAuth2, API Key, HTTP Basic/Digest)",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses. When specified, the A2A agent is expected to return JSON matching this schema.",
)
fail_fast: bool = Field(
default=True,
description="If True, raise an error immediately when the A2A agent is unreachable. If False, skip the A2A agent and continue execution.",
)
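
A hedged example of constructing this config (the endpoint, token, and response model are placeholders, not part of this PR):

from pydantic import BaseModel

from crewai.a2a import A2AConfig
from crewai.a2a.auth.schemas import BearerTokenAuth


class Answer(BaseModel):
    summary: str


config = A2AConfig(
    endpoint="https://agents.example.com/.well-known/agent-card.json",
    auth=BearerTokenAuth(token="example-token"),
    max_turns=5,
    response_model=Answer,  # remote agent is asked to return JSON matching this schema
    fail_fast=False,        # skip this agent instead of raising if it is unreachable
)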

View File

@@ -0,0 +1,29 @@
"""String templates for A2A (Agent-to-Agent) protocol messaging and status."""
from string import Template
from typing import Final
AVAILABLE_AGENTS_TEMPLATE: Final[Template] = Template(
"\n<AVAILABLE_A2A_AGENTS>\n $available_a2a_agents\n</AVAILABLE_A2A_AGENTS>\n"
)
PREVIOUS_A2A_CONVERSATION_TEMPLATE: Final[Template] = Template(
"\n<PREVIOUS_A2A_CONVERSATION>\n"
" $previous_a2a_conversation"
"\n</PREVIOUS_A2A_CONVERSATION>\n"
)
CONVERSATION_TURN_INFO_TEMPLATE: Final[Template] = Template(
"\n<CONVERSATION_PROGRESS>\n"
' turn="$turn_count"\n'
' max_turns="$max_turns"\n'
" $warning"
"\n</CONVERSATION_PROGRESS>\n"
)
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE: Final[Template] = Template(
"\n<A2A_AGENTS_STATUS>\n"
" NOTE: A2A agents were configured but are currently unavailable.\n"
" You cannot delegate to remote agents for this task.\n\n"
" Unavailable Agents:\n"
" $unavailable_agents"
"\n</A2A_AGENTS_STATUS>\n"
)
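
For illustration (not part of the diff), substituting one of these templates with placeholder values:

from crewai.a2a.templates import CONVERSATION_TURN_INFO_TEMPLATE

print(
    CONVERSATION_TURN_INFO_TEMPLATE.substitute(
        turn_count=3, max_turns=10, warning=""
    )
)
# <CONVERSATION_PROGRESS>
#  turn="3"
#  max_turns="10"
#  ...
# </CONVERSATION_PROGRESS>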

View File

@@ -0,0 +1,38 @@
"""Type definitions for A2A protocol message parts."""
from typing import Any, Literal, Protocol, TypedDict, runtime_checkable
from typing_extensions import NotRequired
@runtime_checkable
class AgentResponseProtocol(Protocol):
"""Protocol for the dynamically created AgentResponse model."""
a2a_ids: tuple[str, ...]
message: str
is_a2a: bool
class PartsMetadataDict(TypedDict, total=False):
"""Metadata for A2A message parts.
Attributes:
mimeType: MIME type for the part content.
schema: JSON schema for the part content.
"""
mimeType: Literal["application/json"]
schema: dict[str, Any]
class PartsDict(TypedDict):
"""A2A message part containing text and optional metadata.
Attributes:
text: The text content of the message part.
metadata: Optional metadata describing the part content.
"""
text: str
metadata: NotRequired[PartsMetadataDict]
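
A sketch of building a message part with these types (the schema content is a placeholder):

from crewai.a2a.types import PartsDict, PartsMetadataDict

part: PartsDict = {
    "text": "Summarize the quarterly report.",
    "metadata": PartsMetadataDict(
        mimeType="application/json",
        schema={"type": "object", "properties": {"summary": {"type": "string"}}},
    ),
}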

View File

@@ -0,0 +1,755 @@
"""Utility functions for A2A (Agent-to-Agent) protocol delegation."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, MutableMapping
from contextlib import asynccontextmanager
from functools import lru_cache
import time
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client, ClientConfig, ClientFactory
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
TransportProtocol,
)
import httpx
from pydantic import BaseModel, Field, create_model
from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
configure_auth_client,
retry_on_401,
validate_auth_against_agent_card,
)
from crewai.a2a.config import A2AConfig
from crewai.a2a.types import PartsDict, PartsMetadataDict
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
A2AResponseReceivedEvent,
)
from crewai.types.utils import create_literals_from_strings
if TYPE_CHECKING:
from a2a.types import Message, Task as A2ATask
from crewai.a2a.auth.schemas import AuthScheme
@lru_cache()
def _fetch_agent_card_cached(
endpoint: str,
auth_hash: int,
timeout: int,
_ttl_hash: int,
) -> AgentCard:
"""Cached version of fetch_agent_card with auth support.
Args:
endpoint: A2A agent endpoint URL
auth_hash: Hash of the auth object
timeout: Request timeout
_ttl_hash: Time-based hash for cache invalidation (unused in body)
Returns:
Cached AgentCard
"""
auth = _auth_store.get(auth_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
_fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
def fetch_agent_card(
endpoint: str,
auth: AuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
cache_ttl: int = 300,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint with optional caching.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL)
auth: Optional AuthScheme for authentication
timeout: Request timeout in seconds
use_cache: Whether to use caching (default True)
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes)
Returns:
AgentCard object with agent capabilities and skills
Raises:
httpx.HTTPStatusError: If the request fails
A2AClientHTTPError: If authentication fails
"""
if use_cache:
auth_hash = hash((type(auth).__name__, id(auth))) if auth else 0
_auth_store[auth_hash] = auth
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
_fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
async def _fetch_agent_card_async(
endpoint: str,
auth: AuthScheme | None,
timeout: int,
) -> AgentCard:
"""Async implementation of AgentCard fetching.
Args:
endpoint: A2A agent endpoint URL
auth: Optional AuthScheme for authentication
timeout: Request timeout in seconds
Returns:
AgentCard object
"""
if "/.well-known/agent-card.json" in endpoint:
base_url = endpoint.replace("/.well-known/agent-card.json", "")
agent_card_path = "/.well-known/agent-card.json"
else:
url_parts = endpoint.split("/", 3)
base_url = f"{url_parts[0]}//{url_parts[2]}"
agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/"
headers: MutableMapping[str, str] = {}
if auth:
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_client)
agent_card_url = f"{base_url}{agent_card_path}"
async def _fetch_agent_card_request() -> httpx.Response:
return await temp_client.get(agent_card_url)
try:
response = await retry_on_401(
request_func=_fetch_agent_card_request,
auth_scheme=auth,
client=temp_client,
headers=temp_client.headers,
max_retries=2,
)
response.raise_for_status()
return AgentCard.model_validate(response.json())
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
error_details = ["Authentication failed"]
www_auth = e.response.headers.get("WWW-Authenticate")
if www_auth:
error_details.append(f"WWW-Authenticate: {www_auth}")
if not auth:
error_details.append("No auth scheme provided")
msg = " | ".join(error_details)
raise A2AClientHTTPError(401, msg) from e
raise
def execute_a2a_delegation(
endpoint: str,
auth: AuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
context_id: str | None = None,
task_id: str | None = None,
reference_task_ids: list[str] | None = None,
metadata: dict[str, Any] | None = None,
extensions: dict[str, Any] | None = None,
conversation_history: list[Message] | None = None,
agent_id: str | None = None,
agent_role: Role | None = None,
agent_branch: Any | None = None,
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
) -> dict[str, Any]:
"""Execute a task delegation to a remote A2A agent with multi-turn support.
Handles:
- AgentCard discovery
- Authentication setup
- Message creation and sending
- Response parsing
- Multi-turn conversations
Args:
endpoint: A2A agent endpoint URL (AgentCard URL)
auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest)
timeout: Request timeout in seconds
task_description: The task to delegate
context: Optional context information
context_id: Context ID for correlating messages/tasks
task_id: Specific task identifier
reference_task_ids: List of related task IDs
metadata: Additional metadata (external_id, request_id, etc.)
extensions: Protocol extensions for custom fields
conversation_history: Previous Message objects from conversation
agent_id: Agent identifier for logging
agent_role: Role of the CrewAI agent delegating the task
agent_branch: Optional agent tree branch for logging
response_model: Optional Pydantic model for structured outputs
turn_number: Optional turn number for multi-turn conversations
Returns:
Dictionary with:
- status: "completed", "input_required", "failed", etc.
- result: Result string (if completed)
- error: Error message (if failed)
- history: List of new Message objects from this exchange
Raises:
ImportError: If a2a-sdk is not installed
"""
is_multiturn = bool(conversation_history and len(conversation_history) > 0)
if turn_number is None:
turn_number = (
len([m for m in (conversation_history or []) if m.role == Role.user]) + 1
)
crewai_event_bus.emit(
agent_branch,
A2ADelegationStartedEvent(
endpoint=endpoint,
task_description=task_description,
agent_id=agent_id,
is_multiturn=is_multiturn,
turn_number=turn_number,
),
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
_execute_a2a_delegation_async(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history or [],
is_multiturn=is_multiturn,
turn_number=turn_number,
agent_branch=agent_branch,
agent_id=agent_id,
agent_role=agent_role,
response_model=response_model,
)
)
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status=result["status"],
result=result.get("result"),
error=result.get("error"),
is_multiturn=is_multiturn,
),
)
return result
finally:
loop.close()
async def _execute_a2a_delegation_async(
endpoint: str,
auth: AuthScheme | None,
timeout: int,
task_description: str,
context: str | None,
context_id: str | None,
task_id: str | None,
reference_task_ids: list[str] | None,
metadata: dict[str, Any] | None,
extensions: dict[str, Any] | None,
conversation_history: list[Message],
is_multiturn: bool = False,
turn_number: int = 1,
agent_branch: Any | None = None,
agent_id: str | None = None,
agent_role: str | None = None,
response_model: type[BaseModel] | None = None,
) -> dict[str, Any]:
"""Async implementation of A2A delegation with multi-turn support.
Args:
endpoint: A2A agent endpoint URL
auth: Optional AuthScheme for authentication
timeout: Request timeout in seconds
task_description: Task to delegate
context: Optional context
context_id: Context ID for correlation
task_id: Specific task identifier
reference_task_ids: Related task IDs
metadata: Additional metadata
extensions: Protocol extensions
conversation_history: Previous Message objects
is_multiturn: Whether this is a multi-turn conversation
turn_number: Current turn number
agent_branch: Agent tree branch for logging
agent_id: Agent identifier for logging
agent_role: Agent role for logging
response_model: Optional Pydantic model for structured outputs
Returns:
Dictionary with status, result/error, and new history
"""
agent_card = await _fetch_agent_card_async(endpoint, auth, timeout)
validate_auth_against_agent_card(agent_card, auth)
headers: MutableMapping[str, str] = {}
if auth:
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
a2a_agent_name = None
if agent_card.name:
a2a_agent_name = agent_card.name
if turn_number == 1:
agent_id_for_event = agent_id or endpoint
crewai_event_bus.emit(
agent_branch,
A2AConversationStartedEvent(
agent_id=agent_id_for_event,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
),
)
message_parts = []
if context:
message_parts.append(f"Context:\n{context}\n\n")
message_parts.append(f"{task_description}")
message_text = "".join(message_parts)
if is_multiturn and conversation_history and not task_id:
if first_task_id := conversation_history[0].task_id:
task_id = first_task_id
parts: PartsDict = {"text": message_text}
if response_model:
parts.update(
{
"metadata": PartsMetadataDict(
mimeType="application/json",
schema=response_model.model_json_schema(),
)
}
)
message = Message(
role=Role.user,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(**parts))],
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
)
transport_protocol = TransportProtocol("JSONRPC")
new_messages: list[Message] = [*conversation_history, message]
crewai_event_bus.emit(
None,
A2AMessageSentEvent(
message=message_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
),
)
async with _create_a2a_client(
agent_card=agent_card,
transport_protocol=transport_protocol,
timeout=timeout,
headers=headers,
streaming=True,
auth=auth,
) as client:
result_parts: list[str] = []
final_result: dict[str, Any] | None = None
event_stream = client.send_message(message)
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
elif isinstance(event, tuple):
a2a_task, update = event
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
is_final_update = False
if isinstance(update, TaskStatusUpdateEvent):
is_final_update = update.final
if not is_final_update and a2a_task.status.state not in [
TaskState.completed,
TaskState.input_required,
TaskState.failed,
TaskState.rejected,
TaskState.auth_required,
TaskState.canceled,
]:
continue
if a2a_task.status.state == TaskState.completed:
extracted_parts = _extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="completed",
agent_role=agent_role,
),
)
final_result = {
"status": "completed",
"result": response_text,
"history": new_messages,
"agent_card": agent_card,
}
break
if a2a_task.status.state == TaskState.input_required:
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = _extract_error_message(
a2a_task, "Additional input required"
)
if response_text and not a2a_task.history:
agent_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))],
context_id=a2a_task.context_id
if hasattr(a2a_task, "context_id")
else None,
task_id=a2a_task.task_id
if hasattr(a2a_task, "task_id")
else None,
)
new_messages.append(agent_message)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="input_required",
agent_role=agent_role,
),
)
final_result = {
"status": "input_required",
"error": response_text,
"history": new_messages,
"agent_card": agent_card,
}
break
if a2a_task.status.state in [TaskState.failed, TaskState.rejected]:
error_msg = _extract_error_message(
a2a_task, "Task failed without error message"
)
if a2a_task.history:
new_messages.extend(a2a_task.history)
final_result = {
"status": "failed",
"error": error_msg,
"history": new_messages,
}
break
if a2a_task.status.state == TaskState.auth_required:
error_msg = _extract_error_message(
a2a_task, "Authentication required"
)
final_result = {
"status": "auth_required",
"error": error_msg,
"history": new_messages,
}
break
if a2a_task.status.state == TaskState.canceled:
error_msg = _extract_error_message(
a2a_task, "Task was canceled"
)
final_result = {
"status": "canceled",
"error": error_msg,
"history": new_messages,
}
break
except Exception as e:
current_exception: Exception | BaseException | None = e
while current_exception:
if hasattr(current_exception, "response"):
response = current_exception.response
if hasattr(response, "text"):
break
if current_exception and hasattr(current_exception, "__cause__"):
current_exception = current_exception.__cause__
raise
finally:
if hasattr(event_stream, "aclose"):
await event_stream.aclose()
if final_result:
return final_result
return {
"status": "completed",
"result": " ".join(result_parts) if result_parts else "",
"history": new_messages,
}
@asynccontextmanager
async def _create_a2a_client(
agent_card: AgentCard,
transport_protocol: TransportProtocol,
timeout: int,
headers: MutableMapping[str, str],
streaming: bool,
auth: AuthScheme | None = None,
) -> AsyncIterator[Client]:
"""Create and configure an A2A client.
Args:
agent_card: The A2A agent card
transport_protocol: Transport protocol to use
timeout: Request timeout in seconds
headers: HTTP headers (already with auth applied)
streaming: Enable streaming responses
auth: Optional AuthScheme for client configuration
Yields:
Configured A2A client instance
"""
async with httpx.AsyncClient(
timeout=timeout,
headers=headers,
) as httpx_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, httpx_client)
config = ClientConfig(
httpx_client=httpx_client,
supported_transports=[str(transport_protocol.value)],
streaming=streaming,
accepted_output_modes=["application/json"],
)
factory = ClientFactory(config)
client = factory.create(agent_card)
yield client
def _extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
"""Extract result parts from A2A task history and artifacts.
Args:
a2a_task: A2A Task object with history and artifacts
Returns:
List of result text parts
"""
result_parts: list[str] = []
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
if history_msg.role == Role.agent:
result_parts.extend(
part.root.text
for part in history_msg.parts
if part.root.kind == "text"
)
break
if a2a_task.artifacts:
result_parts.extend(
part.root.text
for artifact in a2a_task.artifacts
for part in artifact.parts
if part.root.kind == "text"
)
return result_parts
def _extract_error_message(a2a_task: A2ATask, default: str) -> str:
"""Extract error message from A2A task.
Args:
a2a_task: A2A Task object
default: Default message if no error found
Returns:
Error message string
"""
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
if msg:
for part in msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return str(msg)
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
for part in history_msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return default
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel]:
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
Args:
agent_ids: List of available A2A agent IDs
Returns:
Dynamically created Pydantic model with Literal-constrained a2a_ids field
"""
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
return create_model(
"AgentResponse",
a2a_ids=(
tuple[DynamicLiteral, ...], # type: ignore[valid-type]
Field(
default_factory=tuple,
max_length=len(agent_ids),
description="A2A agent IDs to delegate to.",
),
),
message=(
str,
Field(
description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
),
),
is_a2a=(
bool,
Field(
description="Set to true to continue the conversation by sending this message to the A2A agent and awaiting their response. Set to false ONLY when you are completely done and providing your final answer (not when asking questions)."
),
),
__base__=BaseModel,
)
def extract_a2a_agent_ids_from_config(
a2a_config: list[A2AConfig] | A2AConfig | None,
) -> tuple[list[A2AConfig], tuple[str, ...]]:
"""Extract A2A agent IDs from A2A configuration.
Args:
a2a_config: A2A configuration
Returns:
List of A2A agent IDs
"""
if a2a_config is None:
return [], ()
if isinstance(a2a_config, A2AConfig):
a2a_agents = [a2a_config]
else:
a2a_agents = a2a_config
return a2a_agents, tuple(config.endpoint for config in a2a_agents)
def get_a2a_agents_and_response_model(
a2a_config: list[A2AConfig] | A2AConfig | None,
) -> tuple[list[A2AConfig], type[BaseModel]]:
"""Get A2A agent IDs and response model.
Args:
a2a_config: A2A configuration
Returns:
Tuple of A2A agent IDs and response model
"""
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
return a2a_agents, create_agent_response_model(agent_ids)
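
A hedged sketch of the dynamic response model above (endpoints are placeholders; this assumes create_literals_from_strings builds a Literal type over the given strings, which matches its use here):

from crewai.a2a.utils import create_agent_response_model

AgentResponse = create_agent_response_model(
    ("https://a.example.com", "https://b.example.com")
)
response = AgentResponse.model_validate_json(
    '{"a2a_ids": ["https://a.example.com"], "message": "Please review this draft.", "is_a2a": true}'
)
assert response.is_a2a
assert response.a2a_ids == ("https://a.example.com",)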

View File

@@ -0,0 +1,570 @@
"""A2A agent wrapping logic for metaclass integration.
Wraps agent classes with A2A delegation capabilities.
"""
from __future__ import annotations
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, cast
from a2a.types import Role
from pydantic import BaseModel, ValidationError
from crewai.a2a.config import A2AConfig
from crewai.a2a.templates import (
AVAILABLE_AGENTS_TEMPLATE,
CONVERSATION_TURN_INFO_TEMPLATE,
PREVIOUS_A2A_CONVERSATION_TEMPLATE,
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE,
)
from crewai.a2a.types import AgentResponseProtocol
from crewai.a2a.utils import (
execute_a2a_delegation,
fetch_agent_card,
get_a2a_agents_and_response_model,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationCompletedEvent,
A2AMessageSentEvent,
)
if TYPE_CHECKING:
from a2a.types import AgentCard, Message
from crewai.agent.core import Agent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
def wrap_agent_with_a2a_instance(agent: Agent) -> None:
"""Wrap an agent instance's execute_task method with A2A support.
This function modifies the agent instance by wrapping its execute_task
method to add A2A delegation capabilities. Should only be called when
the agent has a2a configuration set.
Args:
agent: The agent instance to wrap
"""
original_execute_task = agent.execute_task.__func__
@wraps(original_execute_task)
def execute_task_with_a2a(
self: Agent,
task: Task,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute task with A2A delegation support.
Args:
self: The agent instance
task: The task to execute
context: Optional context for task execution
tools: Optional tools available to the agent
Returns:
Task execution result
"""
if not self.a2a:
return original_execute_task(self, task, context, tools)
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
return _execute_task_with_a2a(
self=self,
a2a_agents=a2a_agents,
original_fn=original_execute_task,
task=task,
agent_response_model=agent_response_model,
context=context,
tools=tools,
)
object.__setattr__(agent, "execute_task", MethodType(execute_task_with_a2a, agent))
def _fetch_card_from_config(
config: A2AConfig,
) -> tuple[A2AConfig, AgentCard | Exception]:
"""Fetch agent card from A2A config.
Args:
config: A2A configuration
Returns:
Tuple of (config, card or exception)
"""
try:
card = fetch_agent_card(
endpoint=config.endpoint,
auth=config.auth,
timeout=config.timeout,
)
return config, card
except Exception as e:
return config, e
def _fetch_agent_cards_concurrently(
a2a_agents: list[A2AConfig],
) -> tuple[dict[str, AgentCard], dict[str, str]]:
"""Fetch agent cards concurrently for multiple A2A agents.
Args:
a2a_agents: List of A2A agent configurations
Returns:
Tuple of (agent_cards dict, failed_agents dict mapping endpoint to error message)
"""
agent_cards: dict[str, AgentCard] = {}
failed_agents: dict[str, str] = {}
with ThreadPoolExecutor(max_workers=len(a2a_agents)) as executor:
futures = {
executor.submit(_fetch_card_from_config, config): config
for config in a2a_agents
}
for future in as_completed(futures):
config, result = future.result()
if isinstance(result, Exception):
if config.fail_fast:
raise RuntimeError(
f"Failed to fetch agent card from {config.endpoint}. "
f"Ensure the A2A agent is running and accessible. Error: {result}"
) from result
failed_agents[config.endpoint] = str(result)
else:
agent_cards[config.endpoint] = result
return agent_cards, failed_agents
def _execute_task_with_a2a(
self: Agent,
a2a_agents: list[A2AConfig],
original_fn: Callable[..., str],
task: Task,
agent_response_model: type[BaseModel],
context: str | None,
tools: list[BaseTool] | None,
) -> str:
"""Wrap execute_task with A2A delegation logic.
Args:
self: The agent instance
a2a_agents: Dictionary of A2A agent configurations
original_fn: The original execute_task method
task: The task to execute
context: Optional context for task execution
tools: Optional tools available to the agent
agent_response_model: Optional agent response model
Returns:
Task execution result (either from LLM or A2A agent)
"""
original_description: str = task.description
original_output_pydantic = task.output_pydantic
original_response_model = task.response_model
agent_cards, failed_agents = _fetch_agent_cards_concurrently(a2a_agents)
if not agent_cards and a2a_agents and failed_agents:
unavailable_agents_text = ""
for endpoint, error in failed_agents.items():
unavailable_agents_text += f" - {endpoint}: {error}\n"
notice = UNAVAILABLE_AGENTS_NOTICE_TEMPLATE.substitute(
unavailable_agents=unavailable_agents_text
)
task.description = f"{original_description}{notice}"
try:
return original_fn(self, task, context, tools)
finally:
task.description = original_description
task.description = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_description,
agent_cards=agent_cards,
failed_agents=failed_agents,
)
task.response_model = agent_response_model
try:
raw_result = original_fn(self, task, context, tools)
agent_response = _parse_agent_response(
raw_result=raw_result, agent_response_model=agent_response_model
)
if isinstance(agent_response, BaseModel) and isinstance(
agent_response, AgentResponseProtocol
):
if agent_response.is_a2a:
return _delegate_to_a2a(
self,
agent_response=agent_response,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_cards=agent_cards,
original_task_description=original_description,
)
return str(agent_response.message)
return raw_result
finally:
task.description = original_description
task.output_pydantic = original_output_pydantic
task.response_model = original_response_model
def _augment_prompt_with_a2a(
a2a_agents: list[A2AConfig],
task_description: str,
agent_cards: dict[str, AgentCard],
conversation_history: list[Message] | None = None,
turn_num: int = 0,
max_turns: int | None = None,
failed_agents: dict[str, str] | None = None,
) -> str:
"""Add A2A delegation instructions to prompt.
Args:
a2a_agents: Dictionary of A2A agent configurations
task_description: Original task description
agent_cards: dictionary mapping agent IDs to AgentCards
conversation_history: Previous A2A Messages from conversation
turn_num: Current turn number (0-indexed)
max_turns: Maximum allowed turns (from config)
failed_agents: Dictionary mapping failed agent endpoints to error messages
Returns:
Augmented task description with A2A instructions
"""
if not agent_cards:
return task_description
agents_text = ""
for config in a2a_agents:
if config.endpoint in agent_cards:
card = agent_cards[config.endpoint]
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
failed_agents = failed_agents or {}
if failed_agents:
agents_text += "\n<!-- Unavailable Agents -->\n"
for endpoint, error in failed_agents.items():
agents_text += f"\n<!-- Agent: {endpoint}\n Status: Unavailable\n Error: {error} -->\n"
agents_text = AVAILABLE_AGENTS_TEMPLATE.substitute(available_a2a_agents=agents_text)
history_text = ""
if conversation_history:
for msg in conversation_history:
history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n"
history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute(
previous_a2a_conversation=history_text
)
turn_info = ""
if max_turns is not None and conversation_history:
turn_count = turn_num + 1
warning = ""
if turn_count >= max_turns:
warning = (
"CRITICAL: This is the FINAL turn. You MUST conclude the conversation now.\n"
"Set is_a2a=false and provide your final response to complete the task."
)
elif turn_count == max_turns - 1:
warning = "WARNING: Next turn will be the last. Consider wrapping up the conversation."
turn_info = CONVERSATION_TURN_INFO_TEMPLATE.substitute(
turn_count=turn_count,
max_turns=max_turns,
warning=warning,
)
return f"""{task_description}
IMPORTANT: You have the ability to delegate this task to remote A2A agents.
{agents_text}
{history_text}{turn_info}
"""
def _parse_agent_response(
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
) -> BaseModel | str:
"""Parse LLM output as AgentResponse or return raw agent response.
Args:
raw_result: Raw output from LLM
agent_response_model: The agent response model
Returns:
Parsed AgentResponse or string
"""
if agent_response_model:
try:
if isinstance(raw_result, str):
return agent_response_model.model_validate_json(raw_result)
if isinstance(raw_result, dict):
return agent_response_model.model_validate(raw_result)
except ValidationError:
return cast(str, raw_result)
return cast(str, raw_result)
def _handle_agent_response_and_continue(
self: Agent,
a2a_result: dict[str, Any],
agent_id: str,
agent_cards: dict[str, AgentCard] | None,
a2a_agents: list[A2AConfig],
original_task_description: str,
conversation_history: list[Message],
turn_num: int,
max_turns: int,
task: Task,
original_fn: Callable[..., str],
context: str | None,
tools: list[BaseTool] | None,
agent_response_model: type[BaseModel],
) -> tuple[str | None, str | None]:
"""Handle A2A result and get CrewAI agent's response.
Args:
self: The agent instance
a2a_result: Result from A2A delegation
agent_id: ID of the A2A agent
agent_cards: Pre-fetched agent cards
a2a_agents: List of A2A configurations
original_task_description: Original task description
conversation_history: Conversation history
turn_num: Current turn number
max_turns: Maximum turns allowed
task: The task being executed
original_fn: Original execute_task method
context: Optional context
tools: Optional tools
agent_response_model: Response model for parsing
Returns:
Tuple of (final_result, current_request) where:
- final_result is not None if conversation should end
- current_request is the next message to send if continuing
"""
agent_cards_dict = agent_cards or {}
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
agent_cards_dict[agent_id] = a2a_result["agent_card"]
task.description = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
agent_cards=agent_cards_dict,
)
raw_result = original_fn(self, task, context, tools)
llm_response = _parse_agent_response(
raw_result=raw_result, agent_response_model=agent_response_model
)
if isinstance(llm_response, BaseModel) and isinstance(
llm_response, AgentResponseProtocol
):
if not llm_response.is_a2a:
final_turn_number = turn_num + 1
crewai_event_bus.emit(
None,
A2AMessageSentEvent(
message=str(llm_response.message),
turn_number=final_turn_number,
is_multiturn=True,
agent_role=self.role,
),
)
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="completed",
final_result=str(llm_response.message),
error=None,
total_turns=final_turn_number,
),
)
return str(llm_response.message), None
return None, str(llm_response.message)
return str(raw_result), None
def _delegate_to_a2a(
self: Agent,
agent_response: AgentResponseProtocol,
task: Task,
original_fn: Callable[..., str],
context: str | None,
tools: list[BaseTool] | None,
agent_cards: dict[str, AgentCard] | None = None,
original_task_description: str | None = None,
) -> str:
"""Delegate to A2A agent with multi-turn conversation support.
Args:
self: The agent instance
agent_response: The AgentResponse indicating delegation
task: The task being executed (for extracting A2A fields)
original_fn: The original execute_task method for follow-ups
context: Optional context for task execution
tools: Optional tools available to the agent
agent_cards: Pre-fetched agent cards from _execute_task_with_a2a
original_task_description: The original task description before A2A augmentation
Returns:
Result from A2A agent
Raises:
ImportError: If a2a-sdk is not installed
"""
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
agent_ids = tuple(config.endpoint for config in a2a_agents)
current_request = str(agent_response.message)
agent_id = agent_response.a2a_ids[0]
if agent_id not in agent_ids:
raise ValueError(
f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}"
)
agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents))
task_config = task.config or {}
context_id = task_config.get("context_id")
task_id_config = task_config.get("task_id")
reference_task_ids = task_config.get("reference_task_ids")
metadata = task_config.get("metadata")
extensions = task_config.get("extensions")
if original_task_description is None:
original_task_description = task.description
conversation_history: list[Message] = []
max_turns = agent_config.max_turns
try:
for turn_num in range(max_turns):
console_formatter = getattr(crewai_event_bus, "_console", None)
agent_branch = None
if console_formatter:
agent_branch = getattr(
console_formatter, "current_agent_branch", None
) or getattr(console_formatter, "current_task_branch", None)
a2a_result = execute_a2a_delegation(
endpoint=agent_config.endpoint,
auth=agent_config.auth,
timeout=agent_config.timeout,
task_description=current_request,
context_id=context_id,
task_id=task_id_config,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_role=Role.user,
agent_branch=agent_branch,
response_model=agent_config.response_model,
turn_number=turn_num + 1,
)
conversation_history = a2a_result.get("history", [])
if a2a_result["status"] in ["completed", "input_required"]:
final_result, next_request = _handle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
)
if final_result is not None:
return final_result
if next_request is not None:
current_request = next_request
continue
error_msg = a2a_result.get("error", "Unknown error")
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="failed",
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
),
)
raise Exception(f"A2A delegation failed: {error_msg}")
if conversation_history:
for msg in reversed(conversation_history):
if msg.role == Role.agent:
text_parts = [
part.root.text for part in msg.parts if part.root.kind == "text"
]
final_message = (
" ".join(text_parts) if text_parts else "Conversation completed"
)
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="completed",
final_result=final_message,
error=None,
total_turns=max_turns,
),
)
return final_message
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="failed",
final_result=None,
error=f"Conversation exceeded maximum turns ({max_turns})",
total_turns=max_turns,
),
)
raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})")
finally:
task.description = original_task_description
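
A hedged sketch of the per-turn result contract the loop above depends on: execute_a2a_delegation is expected to return a dict that is read via a2a_result["status"], a2a_result.get("history", []), and a2a_result.get("error").

# Shape assumed by _delegate_to_a2a (illustrative values):
a2a_result = {
    "status": "input_required",  # or "completed" / "failed"
    "history": [],               # list[Message]; fed back in on the next turn
    "error": None,               # read only when the status is a failure
}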

View File

@@ -0,0 +1,5 @@
from crewai.agent.core import Agent
from crewai.utilities.training_handler import CrewTrainingHandler
__all__ = ["Agent", "CrewTrainingHandler"]

View File

@@ -2,27 +2,27 @@ from __future__ import annotations
import asyncio
from collections.abc import Sequence
import json
import shutil
import subprocess
import time
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
cast,
)
from urllib.parse import urlparse
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.a2a.config import A2AConfig
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
@@ -70,14 +70,14 @@ if TYPE_CHECKING:
# MCP Connection timeout constants (in seconds)
MCP_CONNECTION_TIMEOUT = 10
MCP_TOOL_EXECUTION_TIMEOUT = 30
MCP_DISCOVERY_TIMEOUT = 15
MCP_MAX_RETRIES = 3
MCP_CONNECTION_TIMEOUT: Final[int] = 10
MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30
MCP_DISCOVERY_TIMEOUT: Final[int] = 15
MCP_MAX_RETRIES: Final[int] = 3
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
_mcp_schema_cache = {}
_cache_ttl = 300 # 5 minutes
_mcp_schema_cache: dict[str, Any] = {}
_cache_ttl: Final[int] = 300 # 5 minutes
class Agent(BaseAgent):
@@ -197,6 +197,10 @@ class Agent(BaseAgent):
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
a2a: list[A2AConfig] | A2AConfig | None = Field(
default=None,
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. Can be a single A2AConfig or a dict mapping agent IDs to configs.",
)
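
The field accepts a single config or a list of configs; a hedged sketch (endpoint is the field the delegation code above uses as the agent's identity; other constructor arguments are omitted):

from crewai.a2a.config import A2AConfig

single = A2AConfig(endpoint="https://agent-a.example.com")
many = [
    A2AConfig(endpoint="https://agent-a.example.com"),
    A2AConfig(endpoint="https://agent-b.example.com"),
]
# Agent(..., a2a=single) or Agent(..., a2a=many); the endpoints double as
# the IDs the LLM chooses between when delegating (see _delegate_to_a2a).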
@model_validator(mode="before")
def validate_from_repository(cls, v: Any) -> dict[str, Any] | None | Any: # noqa: N805
@@ -305,17 +309,19 @@ class Agent(BaseAgent):
# If the task requires output in JSON or Pydantic format,
# append specific instructions to the task prompt to ensure
# that the final answer does not include any code block markers
if task.output_json or task.output_pydantic:
# Skip this if task.response_model is set, as native structured outputs handle schema automatically
if (task.output_json or task.output_pydantic) and not task.response_model:
# Generate the schema based on the output format
if task.output_json:
# schema = json.dumps(task.output_json, indent=2)
schema = generate_model_description(task.output_json)
schema_dict = generate_model_description(task.output_json)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
elif task.output_pydantic:
schema = generate_model_description(task.output_pydantic)
schema_dict = generate_model_description(task.output_pydantic)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
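
A hedged sketch of what this branch now produces, assuming generate_model_description returns an OpenAI-style response-format dict (as the schema_dict["json_schema"]["schema"] access implies):

from pydantic import BaseModel

class Review(BaseModel):  # illustrative model
    verdict: str
    score: int

schema_dict = generate_model_description(Review)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
# The raw JSON schema, rather than the old repr-style description, is what
# gets appended to the task prompt via formatted_task_instructions.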
@@ -438,6 +444,13 @@ class Agent(BaseAgent):
else:
task_prompt = self._use_trained_data(task_prompt=task_prompt)
# Import agent events locally to avoid circular imports
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
try:
crewai_event_bus.emit(
self,
@@ -618,6 +631,7 @@ class Agent(BaseAgent):
self._rpm_controller.check_or_wait if self._rpm_controller else None
),
callbacks=[TokenCalcHandler(self._token_process)],
response_model=task.response_model if task else None,
)
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
@@ -709,7 +723,7 @@ class Agent(BaseAgent):
f"Specific tool '{specific_tool}' not found on MCP server: {server_url}",
)
return tools
return cast(list[BaseTool], tools)
except Exception as e:
self._logger.log(
@@ -739,9 +753,9 @@ class Agent(BaseAgent):
return tools
def _extract_server_name(self, server_url: str) -> str:
@staticmethod
def _extract_server_name(server_url: str) -> str:
"""Extract clean server name from URL for tool prefixing."""
from urllib.parse import urlparse
parsed = urlparse(server_url)
domain = parsed.netloc.replace(".", "_")
@@ -778,7 +792,9 @@ class Agent(BaseAgent):
)
return {}
async def _get_mcp_tool_schemas_async(self, server_params: dict) -> dict[str, dict]:
async def _get_mcp_tool_schemas_async(
self, server_params: dict[str, Any]
) -> dict[str, dict]:
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
server_url = server_params["url"]
return await self._retry_mcp_discovery(
@@ -787,7 +803,7 @@ class Agent(BaseAgent):
async def _retry_mcp_discovery(
self, operation_func, server_url: str
) -> dict[str, dict]:
) -> dict[str, dict[str, Any]]:
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
last_error = None
@@ -815,9 +831,10 @@ class Agent(BaseAgent):
f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}"
)
@staticmethod
async def _attempt_mcp_discovery(
self, operation_func, server_url: str
) -> tuple[dict[str, dict] | None, str, bool]:
operation_func, server_url: str
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
try:
result = await operation_func(server_url)
@@ -851,13 +868,13 @@ class Agent(BaseAgent):
async def _discover_mcp_tools_with_timeout(
self, server_url: str
) -> dict[str, dict]:
) -> dict[str, dict[str, Any]]:
"""Discover MCP tools with timeout wrapper."""
return await asyncio.wait_for(
self._discover_mcp_tools(server_url), timeout=MCP_DISCOVERY_TIMEOUT
)
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict]:
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]:
"""Discover tools from MCP server with proper timeout handling."""
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
@@ -889,7 +906,9 @@ class Agent(BaseAgent):
}
return schemas
def _json_schema_to_pydantic(self, tool_name: str, json_schema: dict) -> type:
def _json_schema_to_pydantic(
self, tool_name: str, json_schema: dict[str, Any]
) -> type:
"""Convert JSON Schema to Pydantic model for tool arguments.
Args:
@@ -926,7 +945,7 @@ class Agent(BaseAgent):
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
return create_model(model_name, **field_definitions)
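
A hedged illustration of the create_model call above, with assumed field definitions of the (type, default) form pydantic expects:

from pydantic import create_model

field_definitions = {"query": (str, ...), "limit": (int, 10)}
SearchSchema = create_model("search_toolSchema", **field_definitions)
SearchSchema(query="docs").limit  # -> 10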
def _json_type_to_python(self, field_schema: dict) -> type:
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
"""Convert JSON Schema type to Python type.
Args:
@@ -935,7 +954,6 @@ class Agent(BaseAgent):
Returns:
Python type
"""
from typing import Any
json_type = field_schema.get("type")
@@ -965,13 +983,15 @@ class Agent(BaseAgent):
return type_mapping.get(json_type, Any)
def _fetch_amp_mcp_servers(self, mcp_name: str) -> list[dict]:
@staticmethod
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
"""Fetch MCP server configurations from CrewAI AMP API."""
# TODO: Implement AMP API call to "integrations/mcps" endpoint
# Should return list of server configs with URLs
return []
def get_multimodal_tools(self) -> Sequence[BaseTool]:
@staticmethod
def get_multimodal_tools() -> Sequence[BaseTool]:
from crewai.tools.agent_tools.add_image_tool import AddImageTool
return [AddImageTool()]
@@ -991,8 +1011,9 @@ class Agent(BaseAgent):
)
return []
@staticmethod
def get_output_converter(
self, llm: BaseLLM, text: str, model: type[BaseModel], instructions: str
llm: BaseLLM, text: str, model: type[BaseModel], instructions: str
) -> Converter:
return Converter(llm=llm, text=text, model=model, instructions=instructions)
@@ -1022,7 +1043,8 @@ class Agent(BaseAgent):
)
return task_prompt
def _render_text_description(self, tools: list[Any]) -> str:
@staticmethod
def _render_text_description(tools: list[Any]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:

View File

@@ -0,0 +1,76 @@
"""Generic metaclass for agent extensions.
This metaclass enables extension capabilities for agents by wrapping
post-init validation so that extension fields set on an instance are
detected and the appropriate wrappers are applied.
"""
import warnings
from functools import wraps
from typing import Any
from pydantic import model_validator
from pydantic._internal._model_construction import ModelMetaclass
class AgentMeta(ModelMetaclass):
"""Generic metaclass for agent extensions.
Detects extension fields (like 'a2a') on initialized instances and applies
the appropriate wrapper logic to enable extension functionality.
"""
def __new__(
mcs,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> type:
"""Create a new class with extension support.
Args:
name: The name of the class being created
bases: Base classes
namespace: Class namespace dictionary
**kwargs: Additional keyword arguments
Returns:
The newly created class with extension support if applicable
"""
orig_post_init_setup = namespace.get("post_init_setup")
if orig_post_init_setup is not None:
original_func = (
orig_post_init_setup.wrapped
if hasattr(orig_post_init_setup, "wrapped")
else orig_post_init_setup
)
def post_init_setup_with_extensions(self: Any) -> Any:
"""Wrap post_init_setup to apply extensions after initialization.
Args:
self: The agent instance
Returns:
The agent instance
"""
result = original_func(self)
a2a_value = getattr(self, "a2a", None)
if a2a_value is not None:
from crewai.a2a.wrapper import wrap_agent_with_a2a_instance
wrap_agent_with_a2a_instance(self)
return result
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message=".*overrides an existing Pydantic.*"
)
namespace["post_init_setup"] = model_validator(mode="after")(
post_init_setup_with_extensions
)
return super().__new__(mcs, name, bases, namespace, **kwargs)
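
A hedged sketch of the hook this metaclass installs: a subclass that defines post_init_setup gets it re-wrapped as a pydantic after-validator that also applies A2A wrapping when the instance carries an a2a value (abstract members omitted for brevity):

from crewai.agents.agent_builder.base_agent import BaseAgent

class MyAgent(BaseAgent):  # BaseAgent uses metaclass=AgentMeta (see below)
    def post_init_setup(self):
        # original setup logic runs first
        return self

# After validation, the installed wrapper checks getattr(self, "a2a", None)
# and, when set, calls wrap_agent_with_a2a_instance(self).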

View File

@@ -18,6 +18,7 @@ from pydantic import (
from pydantic_core import PydanticCustomError
from typing_extensions import Self
from crewai.agent.internal.meta import AgentMeta
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
@@ -56,7 +57,7 @@ PlatformApp = Literal[
PlatformAppOrAction = PlatformApp | str
class BaseAgent(BaseModel, ABC):
class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
"""Abstract Base Class for all third party agents compatible with CrewAI.
Attributes:

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic import GetCoreSchemaHandler
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
@@ -65,7 +65,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
def __init__(
self,
llm: BaseLLM | Any,
llm: BaseLLM | Any | None,
task: Task,
crew: Crew,
agent: Agent,
@@ -82,6 +82,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window: bool = False,
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None,
) -> None:
"""Initialize executor.
@@ -103,6 +104,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
respect_context_window: Respect context limits.
request_within_rpm_limit: RPM limit check function.
callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs.
"""
self._i18n: I18N = I18N()
self.llm = llm
@@ -123,6 +125,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.response_model = response_model
self.ask_for_human_input = False
self.messages: list[LLMMessage] = []
self.iterations = 0
@@ -221,6 +224,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
)
formatted_answer = process_llm_response(answer, self.use_stop_words)
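
The new parameter threads a task-level structured-output model down to the LLM call; a hedged sketch of the flow, with an illustrative model:

from pydantic import BaseModel

class Answer(BaseModel):  # illustrative
    text: str

# task = Task(description="...", response_model=Answer)
# create_agent_executor passes task.response_model into the executor, which
# forwards it with every LLM call made above.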

View File

@@ -3,10 +3,17 @@ import json
import os
from pathlib import Path
import sys
from typing import BinaryIO, cast
from cryptography.fernet import Fernet
if sys.platform == "win32":
import msvcrt
else:
import fcntl
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
@@ -18,21 +25,74 @@ class TokenManager:
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
@staticmethod
def _acquire_lock(file_handle: BinaryIO) -> None:
"""
Acquire an exclusive lock on a file handle.
Args:
file_handle: Open file handle to lock.
"""
if sys.platform == "win32":
msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
else:
fcntl.flock(file_handle.fileno(), fcntl.LOCK_EX)
@staticmethod
def _release_lock(file_handle: BinaryIO) -> None:
"""
Release the lock on a file handle.
Args:
file_handle: Open file handle to unlock.
"""
if sys.platform == "win32":
msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
else:
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key.
Get or create the encryption key with file locking to prevent race conditions.
:return: The encryption key.
Returns:
The encryption key.
"""
key_filename = "secret.key"
key = self.read_secure_file(key_filename)
storage_path = self.get_secure_storage_path()
if key is not None:
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
lock_file_path = storage_path / f"{key_filename}.lock"
try:
lock_file_path.touch()
with open(lock_file_path, "r+b") as lock_file:
self._acquire_lock(lock_file)
try:
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
finally:
try:
self._release_lock(lock_file)
except OSError:
pass
except OSError:
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
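
The len(key) == 44 check matches the length of a urlsafe-base64-encoded 32-byte Fernet key. The pattern itself is double-checked creation under an exclusive file lock, with a lockless fallback if locking fails; a self-contained sketch of the POSIX branch:

from collections.abc import Callable
import fcntl  # POSIX; the Windows branch uses msvcrt.locking instead
from pathlib import Path

def get_or_create(path: Path, make: Callable[[], bytes]) -> bytes:
    if path.exists():                      # fast path, no lock taken
        return path.read_bytes()
    lock_path = Path(f"{path}.lock")
    lock_path.touch()
    with open(lock_path, "r+b") as lock:
        fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
        try:
            if path.exists():              # re-check: another process won
                return path.read_bytes()
            value = make()
            path.write_bytes(value)
            return value
        finally:
            fcntl.flock(lock.fileno(), fcntl.LOCK_UN)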
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
@@ -59,14 +119,14 @@ class TokenManager:
if encrypted_data is None:
return None
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
decrypted_data = self.fernet.decrypt(encrypted_data)
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]
return cast(str | None, data["access_token"])
def clear_tokens(self) -> None:
"""
@@ -74,20 +134,18 @@ class TokenManager:
"""
self.delete_secure_file(self.file_path)
def get_secure_storage_path(self) -> Path:
@staticmethod
def get_secure_storage_path() -> Path:
"""
Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
if sys.platform == "win32":
# Windows: Use %LOCALAPPDATA%
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
# macOS: Use ~/Library/Application Support
base_path = os.path.expanduser("~/Library/Application Support")
else:
# Linux and other Unix-like: Use ~/.local/share
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
@@ -110,7 +168,6 @@ class TokenManager:
with open(file_path, "wb") as f:
f.write(content)
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> bytes | None:

View File

@@ -8,21 +8,15 @@ This module provides the event infrastructure that allows users to:
- Declare handler dependencies for ordered execution
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.depends import Depends
from crewai.events.event_bus import crewai_event_bus
from crewai.events.handler_graph import CircularDependencyError
from crewai.events.types.agent_events import (
AgentEvaluationCompletedEvent,
AgentEvaluationFailedEvent,
AgentEvaluationStartedEvent,
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
@@ -100,6 +94,20 @@ from crewai.events.types.tool_usage_events import (
)
if TYPE_CHECKING:
from crewai.events.types.agent_events import (
AgentEvaluationCompletedEvent,
AgentEvaluationFailedEvent,
AgentEvaluationStartedEvent,
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
__all__ = [
"AgentEvaluationCompletedEvent",
"AgentEvaluationFailedEvent",
@@ -170,3 +178,27 @@ __all__ = [
"ToolValidateInputErrorEvent",
"crewai_event_bus",
]
_AGENT_EVENT_MAPPING = {
"AgentEvaluationCompletedEvent": "crewai.events.types.agent_events",
"AgentEvaluationFailedEvent": "crewai.events.types.agent_events",
"AgentEvaluationStartedEvent": "crewai.events.types.agent_events",
"AgentExecutionCompletedEvent": "crewai.events.types.agent_events",
"AgentExecutionErrorEvent": "crewai.events.types.agent_events",
"AgentExecutionStartedEvent": "crewai.events.types.agent_events",
"LiteAgentExecutionCompletedEvent": "crewai.events.types.agent_events",
"LiteAgentExecutionErrorEvent": "crewai.events.types.agent_events",
"LiteAgentExecutionStartedEvent": "crewai.events.types.agent_events",
}
def __getattr__(name: str):
"""Lazy import for agent events to avoid circular imports."""
if name in _AGENT_EVENT_MAPPING:
import importlib
module_path = _AGENT_EVENT_MAPPING[name]
module = importlib.import_module(module_path)
return getattr(module, name)
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
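
This is module-level __getattr__ (PEP 562): the agent-event names stay importable from crewai.events, but the underlying module is only imported on first attribute access, breaking the import cycle. A short sketch of the effect:

# Both forms work; the second resolves lazily through __getattr__.
from crewai.events.types.agent_events import AgentExecutionStartedEvent

import crewai.events
crewai.events.AgentExecutionStartedEvent  # triggers the lazy import above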

View File

@@ -1,16 +1,26 @@
"""Base event listener for CrewAI event system."""
from abc import ABC, abstractmethod
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
class BaseEventListener(ABC):
"""Abstract base class for event listeners."""
verbose: bool = False
def __init__(self):
def __init__(self) -> None:
"""Initialize the event listener and register handlers."""
super().__init__()
self.setup_listeners(crewai_event_bus)
crewai_event_bus.validate_dependencies()
@abstractmethod
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
"""Setup event listeners on the event bus.
Args:
crewai_event_bus: The event bus to register listeners on.
"""
pass
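
A minimal listener sketch against this base class, using the decorator registration style seen throughout this change:

from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.types.crew_events import CrewKickoffStartedEvent

class MyListener(BaseEventListener):
    def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
        @crewai_event_bus.on(CrewKickoffStartedEvent)
        def on_kickoff(source, event: CrewKickoffStartedEvent) -> None:
            print(f"crew started: {event.crew_name}")

MyListener()  # __init__ registers the handlers and validates dependencies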

View File

@@ -1,12 +1,21 @@
from __future__ import annotations
from io import StringIO
from typing import Any
import threading
from typing import TYPE_CHECKING, Any
from pydantic import Field, PrivateAttr
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.listeners.memory_listener import MemoryListener
from crewai.events.listeners.tracing.trace_listener import TraceCollectionListener
from crewai.events.types.a2a_events import (
A2AConversationCompletedEvent,
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
A2AResponseReceivedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
@@ -79,6 +88,10 @@ from crewai.utilities import Logger
from crewai.utilities.constants import EMITTER_COLOR
if TYPE_CHECKING:
from crewai.events.event_bus import CrewAIEventsBus
class EventListener(BaseEventListener):
_instance = None
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
@@ -105,19 +118,24 @@ class EventListener(BaseEventListener):
self.method_branches = {}
self._initialized = True
self.formatter = ConsoleFormatter(verbose=True)
self._crew_tree_lock = threading.Condition()
MemoryListener(formatter=self.formatter)
# Initialize trace listener with formatter for memory event handling
trace_listener = TraceCollectionListener()
trace_listener.formatter = self.formatter
# ----------- CREW EVENTS -----------
def setup_listeners(self, crewai_event_bus):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
@crewai_event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source, event: CrewKickoffStartedEvent):
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
self._telemetry.crew_execution_span(source, event.inputs)
def on_crew_started(source, event: CrewKickoffStartedEvent) -> None:
with self._crew_tree_lock:
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
self._telemetry.crew_execution_span(source, event.inputs)
self._crew_tree_lock.notify_all()
@crewai_event_bus.on(CrewKickoffCompletedEvent)
def on_crew_completed(source, event: CrewKickoffCompletedEvent):
def on_crew_completed(source, event: CrewKickoffCompletedEvent) -> None:
# Handle telemetry
final_string_output = event.output.raw
self._telemetry.end_crew(source, final_string_output)
@@ -131,7 +149,7 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source, event: CrewKickoffFailedEvent):
def on_crew_failed(source, event: CrewKickoffFailedEvent) -> None:
self.formatter.update_crew_tree(
self.formatter.current_crew_tree,
event.crew_name or "Crew",
@@ -140,23 +158,23 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(CrewTrainStartedEvent)
def on_crew_train_started(source, event: CrewTrainStartedEvent):
def on_crew_train_started(source, event: CrewTrainStartedEvent) -> None:
self.formatter.handle_crew_train_started(
event.crew_name or "Crew", str(event.timestamp)
)
@crewai_event_bus.on(CrewTrainCompletedEvent)
def on_crew_train_completed(source, event: CrewTrainCompletedEvent):
def on_crew_train_completed(source, event: CrewTrainCompletedEvent) -> None:
self.formatter.handle_crew_train_completed(
event.crew_name or "Crew", str(event.timestamp)
)
@crewai_event_bus.on(CrewTrainFailedEvent)
def on_crew_train_failed(source, event: CrewTrainFailedEvent):
def on_crew_train_failed(source, event: CrewTrainFailedEvent) -> None:
self.formatter.handle_crew_train_failed(event.crew_name or "Crew")
@crewai_event_bus.on(CrewTestResultEvent)
def on_crew_test_result(source, event: CrewTestResultEvent):
def on_crew_test_result(source, event: CrewTestResultEvent) -> None:
self._telemetry.individual_test_result_span(
source.crew,
event.quality,
@@ -167,14 +185,22 @@ class EventListener(BaseEventListener):
# ----------- TASK EVENTS -----------
@crewai_event_bus.on(TaskStartedEvent)
def on_task_started(source, event: TaskStartedEvent):
def on_task_started(source, event: TaskStartedEvent) -> None:
span = self._telemetry.task_started(crew=source.agent.crew, task=source)
self.execution_spans[source] = span
# Pass both task ID and task name (if set)
task_name = source.name if hasattr(source, "name") and source.name else None
self.formatter.create_task_branch(
self.formatter.current_crew_tree, source.id, task_name
)
with self._crew_tree_lock:
self._crew_tree_lock.wait_for(
lambda: self.formatter.current_crew_tree is not None, timeout=5.0
)
if self.formatter.current_crew_tree is not None:
task_name = (
source.name if hasattr(source, "name") and source.name else None
)
self.formatter.create_task_branch(
self.formatter.current_crew_tree, source.id, task_name
)
@crewai_event_bus.on(TaskCompletedEvent)
def on_task_completed(source, event: TaskCompletedEvent):
@@ -533,5 +559,61 @@ class EventListener(BaseEventListener):
event.verbose,
)
@crewai_event_bus.on(A2ADelegationStartedEvent)
def on_a2a_delegation_started(source, event: A2ADelegationStartedEvent):
self.formatter.handle_a2a_delegation_started(
event.endpoint,
event.task_description,
event.agent_id,
event.is_multiturn,
event.turn_number,
)
@crewai_event_bus.on(A2ADelegationCompletedEvent)
def on_a2a_delegation_completed(source, event: A2ADelegationCompletedEvent):
self.formatter.handle_a2a_delegation_completed(
event.status,
event.result,
event.error,
event.is_multiturn,
)
@crewai_event_bus.on(A2AConversationStartedEvent)
def on_a2a_conversation_started(source, event: A2AConversationStartedEvent):
# Store A2A agent name for display in conversation tree
if event.a2a_agent_name:
self.formatter._current_a2a_agent_name = event.a2a_agent_name
self.formatter.handle_a2a_conversation_started(
event.agent_id,
event.endpoint,
)
@crewai_event_bus.on(A2AMessageSentEvent)
def on_a2a_message_sent(source, event: A2AMessageSentEvent):
self.formatter.handle_a2a_message_sent(
event.message,
event.turn_number,
event.agent_role,
)
@crewai_event_bus.on(A2AResponseReceivedEvent)
def on_a2a_response_received(source, event: A2AResponseReceivedEvent):
self.formatter.handle_a2a_response_received(
event.response,
event.turn_number,
event.status,
event.agent_role,
)
@crewai_event_bus.on(A2AConversationCompletedEvent)
def on_a2a_conversation_completed(source, event: A2AConversationCompletedEvent):
self.formatter.handle_a2a_conversation_completed(
event.status,
event.final_result,
event.error,
event.total_turns,
)
event_listener = EventListener()
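
The Condition added above is a cross-thread hand-off: the task handler blocks (bounded at 5 seconds) until the crew handler has built the tree; a distilled sketch of the same pattern:

import threading

cond = threading.Condition()
tree = None

def on_crew_started() -> None:       # producer
    global tree
    with cond:
        tree = object()              # stands in for create_crew_tree(...)
        cond.notify_all()

def on_task_started() -> None:       # consumer, possibly another thread
    with cond:
        cond.wait_for(lambda: tree is not None, timeout=5.0)
    if tree is not None:
        pass                         # stands in for create_task_branch(...)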

View File

@@ -1,106 +0,0 @@
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
class MemoryListener(BaseEventListener):
def __init__(self, formatter):
super().__init__()
self.formatter = formatter
self.memory_retrieval_in_progress = False
self.memory_save_in_progress = False
def setup_listeners(self, crewai_event_bus):
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
def on_memory_retrieval_started(source, event: MemoryRetrievalStartedEvent):
if self.memory_retrieval_in_progress:
return
self.memory_retrieval_in_progress = True
self.formatter.handle_memory_retrieval_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
def on_memory_retrieval_completed(source, event: MemoryRetrievalCompletedEvent):
if not self.memory_retrieval_in_progress:
return
self.memory_retrieval_in_progress = False
self.formatter.handle_memory_retrieval_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.memory_content,
event.retrieval_time_ms,
)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_memory_query_completed(source, event: MemoryQueryCompletedEvent):
if not self.memory_retrieval_in_progress:
return
self.formatter.handle_memory_query_completed(
self.formatter.current_agent_branch,
event.source_type,
event.query_time_ms,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(MemoryQueryFailedEvent)
def on_memory_query_failed(source, event: MemoryQueryFailedEvent):
if not self.memory_retrieval_in_progress:
return
self.formatter.handle_memory_query_failed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.error,
event.source_type,
)
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_memory_save_started(source, event: MemorySaveStartedEvent):
if self.memory_save_in_progress:
return
self.memory_save_in_progress = True
self.formatter.handle_memory_save_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_memory_save_completed(source, event: MemorySaveCompletedEvent):
if not self.memory_save_in_progress:
return
self.memory_save_in_progress = False
self.formatter.handle_memory_save_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.save_time_ms,
event.source_type,
)
@crewai_event_bus.on(MemorySaveFailedEvent)
def on_memory_save_failed(source, event: MemorySaveFailedEvent):
if not self.memory_save_in_progress:
return
self.formatter.handle_memory_save_failed(
self.formatter.current_agent_branch,
event.error,
event.source_type,
self.formatter.current_crew_tree,
)

View File

@@ -73,15 +73,19 @@ class FirstTimeTraceHandler:
self.is_first_time = should_auto_collect_first_time_traces()
return self.is_first_time
def set_batch_manager(self, batch_manager: TraceBatchManager):
"""Set reference to batch manager for sending events."""
def set_batch_manager(self, batch_manager: TraceBatchManager) -> None:
"""Set reference to batch manager for sending events.
Args:
batch_manager: The trace batch manager instance.
"""
self.batch_manager = batch_manager
def mark_events_collected(self):
def mark_events_collected(self) -> None:
"""Mark that events have been collected during execution."""
self.collected_events = True
def handle_execution_completion(self):
def handle_execution_completion(self) -> None:
"""Handle the completion flow as shown in your diagram."""
if not self.is_first_time or not self.collected_events:
return

View File

@@ -44,6 +44,7 @@ class TraceBatchManager:
def __init__(self) -> None:
self._init_lock = Lock()
self._batch_ready_cv = Condition(self._init_lock)
self._pending_events_lock = Lock()
self._pending_events_cv = Condition(self._pending_events_lock)
self._pending_events_count = 0
@@ -94,6 +95,8 @@ class TraceBatchManager:
)
self.backend_initialized = True
self._batch_ready_cv.notify_all()
return self.current_batch
def _initialize_backend_batch(
@@ -161,13 +164,13 @@ class TraceBatchManager:
f"Error initializing trace batch: {e}. Continuing without tracing."
)
def begin_event_processing(self):
"""Mark that an event handler started processing (for synchronization)"""
def begin_event_processing(self) -> None:
"""Mark that an event handler started processing (for synchronization)."""
with self._pending_events_lock:
self._pending_events_count += 1
def end_event_processing(self):
"""Mark that an event handler finished processing (for synchronization)"""
def end_event_processing(self) -> None:
"""Mark that an event handler finished processing (for synchronization)."""
with self._pending_events_cv:
self._pending_events_count -= 1
if self._pending_events_count == 0:
@@ -385,6 +388,22 @@ class TraceBatchManager:
"""Check if batch is initialized"""
return self.current_batch is not None
def wait_for_batch_initialization(self, timeout: float = 2.0) -> bool:
"""Wait for batch to be initialized.
Args:
timeout: Maximum time to wait in seconds (default: 2.0)
Returns:
True if batch was initialized, False if timeout occurred
"""
with self._batch_ready_cv:
if self.current_batch is not None:
return True
return self._batch_ready_cv.wait_for(
lambda: self.current_batch is not None, timeout=timeout
)
def record_start_time(self, key: str):
"""Record start time for duration calculation"""
self.execution_start_times[key] = datetime.now(timezone.utc)
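
Two synchronization points are added here: a Condition-guarded counter of in-flight event handlers, and wait_for_batch_initialization for threads that need the batch before proceeding. A distilled sketch of the counter (the drain helper is hypothetical, showing how a finalizer could wait for quiescence):

import threading

pending = 0
cv = threading.Condition()

def begin() -> None:                 # begin_event_processing
    global pending
    with cv:
        pending += 1

def end() -> None:                   # end_event_processing
    global pending
    with cv:
        pending -= 1
        if pending == 0:
            cv.notify_all()

def drain(timeout: float = 2.0) -> bool:  # hypothetical consumer
    with cv:
        return cv.wait_for(lambda: pending == 0, timeout=timeout)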

View File

@@ -1,10 +1,16 @@
"""Trace collection listener for orchestrating trace collection."""
import os
from typing import Any, ClassVar
import uuid
from typing_extensions import Self
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.listeners.tracing.first_time_trace_handler import (
FirstTimeTraceHandler,
)
@@ -53,6 +59,8 @@ from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
@@ -75,9 +83,7 @@ from crewai.events.types.tool_usage_events import (
class TraceCollectionListener(BaseEventListener):
"""
Trace collection listener that orchestrates trace collection
"""
"""Trace collection listener that orchestrates trace collection."""
complex_events: ClassVar[list[str]] = [
"task_started",
@@ -88,11 +94,12 @@ class TraceCollectionListener(BaseEventListener):
"agent_execution_completed",
]
_instance = None
_initialized = False
_listeners_setup = False
_instance: Self | None = None
_initialized: bool = False
_listeners_setup: bool = False
def __new__(cls, batch_manager: TraceBatchManager | None = None):
def __new__(cls, batch_manager: TraceBatchManager | None = None) -> Self:
"""Create or return singleton instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@@ -100,7 +107,14 @@ class TraceCollectionListener(BaseEventListener):
def __init__(
self,
batch_manager: TraceBatchManager | None = None,
):
formatter: ConsoleFormatter | None = None,
) -> None:
"""Initialize trace collection listener.
Args:
batch_manager: Optional trace batch manager instance.
formatter: Optional console formatter for output.
"""
if self._initialized:
return
@@ -108,19 +122,22 @@ class TraceCollectionListener(BaseEventListener):
self.batch_manager = batch_manager or TraceBatchManager()
self._initialized = True
self.first_time_handler = FirstTimeTraceHandler()
self.formatter = formatter
self.memory_retrieval_in_progress = False
self.memory_save_in_progress = False
if self.first_time_handler.initialize_for_first_time_user():
self.first_time_handler.set_batch_manager(self.batch_manager)
def _check_authenticated(self) -> bool:
"""Check if tracing should be enabled"""
"""Check if tracing should be enabled."""
try:
return bool(get_auth_token())
except AuthError:
return False
def _get_user_context(self) -> dict[str, str]:
"""Extract user context for tracing"""
"""Extract user context for tracing."""
return {
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
"organization_id": os.getenv("CREWAI_ORG_ID", ""),
@@ -128,9 +145,12 @@ class TraceCollectionListener(BaseEventListener):
"trace_id": str(uuid.uuid4()),
}
def setup_listeners(self, crewai_event_bus):
"""Setup event listeners - delegates to specific handlers"""
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
"""Setup event listeners - delegates to specific handlers.
Args:
crewai_event_bus: The event bus to register listeners on.
"""
if self._listeners_setup:
return
@@ -140,50 +160,52 @@ class TraceCollectionListener(BaseEventListener):
self._listeners_setup = True
def _register_flow_event_handlers(self, event_bus):
"""Register handlers for flow events"""
def _register_flow_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for flow events."""
@event_bus.on(FlowCreatedEvent)
def on_flow_created(source, event):
def on_flow_created(source: Any, event: FlowCreatedEvent) -> None:
pass
@event_bus.on(FlowStartedEvent)
def on_flow_started(source, event):
def on_flow_started(source: Any, event: FlowStartedEvent) -> None:
if not self.batch_manager.is_batch_initialized():
self._initialize_flow_batch(source, event)
self._handle_trace_event("flow_started", source, event)
@event_bus.on(MethodExecutionStartedEvent)
def on_method_started(source, event):
def on_method_started(source: Any, event: MethodExecutionStartedEvent) -> None:
self._handle_trace_event("method_execution_started", source, event)
@event_bus.on(MethodExecutionFinishedEvent)
def on_method_finished(source, event):
def on_method_finished(
source: Any, event: MethodExecutionFinishedEvent
) -> None:
self._handle_trace_event("method_execution_finished", source, event)
@event_bus.on(MethodExecutionFailedEvent)
def on_method_failed(source, event):
def on_method_failed(source: Any, event: MethodExecutionFailedEvent) -> None:
self._handle_trace_event("method_execution_failed", source, event)
@event_bus.on(FlowFinishedEvent)
def on_flow_finished(source, event):
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
self._handle_trace_event("flow_finished", source, event)
@event_bus.on(FlowPlotEvent)
def on_flow_plot(source, event):
def on_flow_plot(source: Any, event: FlowPlotEvent) -> None:
self._handle_action_event("flow_plot", source, event)
def _register_context_event_handlers(self, event_bus):
"""Register handlers for context events (start/end)"""
def _register_context_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for context events (start/end)."""
@event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source, event):
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
if not self.batch_manager.is_batch_initialized():
self._initialize_crew_batch(source, event)
self._handle_trace_event("crew_kickoff_started", source, event)
@event_bus.on(CrewKickoffCompletedEvent)
def on_crew_completed(source, event):
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
self._handle_trace_event("crew_kickoff_completed", source, event)
if self.batch_manager.batch_owner_type == "crew":
if self.first_time_handler.is_first_time:
@@ -193,7 +215,7 @@ class TraceCollectionListener(BaseEventListener):
self.batch_manager.finalize_batch()
@event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source, event):
def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
self._handle_trace_event("crew_kickoff_failed", source, event)
if self.first_time_handler.is_first_time:
self.first_time_handler.mark_events_collected()
@@ -202,134 +224,245 @@ class TraceCollectionListener(BaseEventListener):
self.batch_manager.finalize_batch()
@event_bus.on(TaskStartedEvent)
def on_task_started(source, event):
def on_task_started(source: Any, event: TaskStartedEvent) -> None:
self._handle_trace_event("task_started", source, event)
@event_bus.on(TaskCompletedEvent)
def on_task_completed(source, event):
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
self._handle_trace_event("task_completed", source, event)
@event_bus.on(TaskFailedEvent)
def on_task_failed(source, event):
def on_task_failed(source: Any, event: TaskFailedEvent) -> None:
self._handle_trace_event("task_failed", source, event)
@event_bus.on(AgentExecutionStartedEvent)
def on_agent_started(source, event):
def on_agent_started(source: Any, event: AgentExecutionStartedEvent) -> None:
self._handle_trace_event("agent_execution_started", source, event)
@event_bus.on(AgentExecutionCompletedEvent)
def on_agent_completed(source, event):
def on_agent_completed(
source: Any, event: AgentExecutionCompletedEvent
) -> None:
self._handle_trace_event("agent_execution_completed", source, event)
@event_bus.on(LiteAgentExecutionStartedEvent)
def on_lite_agent_started(source, event):
def on_lite_agent_started(
source: Any, event: LiteAgentExecutionStartedEvent
) -> None:
self._handle_trace_event("lite_agent_execution_started", source, event)
@event_bus.on(LiteAgentExecutionCompletedEvent)
def on_lite_agent_completed(source, event):
def on_lite_agent_completed(
source: Any, event: LiteAgentExecutionCompletedEvent
) -> None:
self._handle_trace_event("lite_agent_execution_completed", source, event)
@event_bus.on(LiteAgentExecutionErrorEvent)
def on_lite_agent_error(source, event):
def on_lite_agent_error(
source: Any, event: LiteAgentExecutionErrorEvent
) -> None:
self._handle_trace_event("lite_agent_execution_error", source, event)
@event_bus.on(AgentExecutionErrorEvent)
def on_agent_error(source, event):
def on_agent_error(source: Any, event: AgentExecutionErrorEvent) -> None:
self._handle_trace_event("agent_execution_error", source, event)
@event_bus.on(LLMGuardrailStartedEvent)
def on_guardrail_started(source, event):
def on_guardrail_started(source: Any, event: LLMGuardrailStartedEvent) -> None:
self._handle_trace_event("llm_guardrail_started", source, event)
@event_bus.on(LLMGuardrailCompletedEvent)
def on_guardrail_completed(source, event):
def on_guardrail_completed(
source: Any, event: LLMGuardrailCompletedEvent
) -> None:
self._handle_trace_event("llm_guardrail_completed", source, event)
def _register_action_event_handlers(self, event_bus):
"""Register handlers for action events (LLM calls, tool usage)"""
def _register_action_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for action events (LLM calls, tool usage)."""
@event_bus.on(LLMCallStartedEvent)
def on_llm_call_started(source, event):
def on_llm_call_started(source: Any, event: LLMCallStartedEvent) -> None:
self._handle_action_event("llm_call_started", source, event)
@event_bus.on(LLMCallCompletedEvent)
def on_llm_call_completed(source, event):
def on_llm_call_completed(source: Any, event: LLMCallCompletedEvent) -> None:
self._handle_action_event("llm_call_completed", source, event)
@event_bus.on(LLMCallFailedEvent)
def on_llm_call_failed(source, event):
def on_llm_call_failed(source: Any, event: LLMCallFailedEvent) -> None:
self._handle_action_event("llm_call_failed", source, event)
@event_bus.on(ToolUsageStartedEvent)
def on_tool_started(source, event):
def on_tool_started(source: Any, event: ToolUsageStartedEvent) -> None:
self._handle_action_event("tool_usage_started", source, event)
@event_bus.on(ToolUsageFinishedEvent)
def on_tool_finished(source, event):
def on_tool_finished(source: Any, event: ToolUsageFinishedEvent) -> None:
self._handle_action_event("tool_usage_finished", source, event)
@event_bus.on(ToolUsageErrorEvent)
def on_tool_error(source, event):
def on_tool_error(source: Any, event: ToolUsageErrorEvent) -> None:
self._handle_action_event("tool_usage_error", source, event)
@event_bus.on(MemoryQueryStartedEvent)
def on_memory_query_started(source, event):
def on_memory_query_started(
source: Any, event: MemoryQueryStartedEvent
) -> None:
self._handle_action_event("memory_query_started", source, event)
@event_bus.on(MemoryQueryCompletedEvent)
def on_memory_query_completed(source, event):
def on_memory_query_completed(
source: Any, event: MemoryQueryCompletedEvent
) -> None:
self._handle_action_event("memory_query_completed", source, event)
if self.formatter and self.memory_retrieval_in_progress:
self.formatter.handle_memory_query_completed(
self.formatter.current_agent_branch,
event.source_type or "memory",
event.query_time_ms,
self.formatter.current_crew_tree,
)
@event_bus.on(MemoryQueryFailedEvent)
def on_memory_query_failed(source, event):
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
self._handle_action_event("memory_query_failed", source, event)
if self.formatter and self.memory_retrieval_in_progress:
self.formatter.handle_memory_query_failed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.error,
event.source_type or "memory",
)
@event_bus.on(MemorySaveStartedEvent)
def on_memory_save_started(source, event):
def on_memory_save_started(source: Any, event: MemorySaveStartedEvent) -> None:
self._handle_action_event("memory_save_started", source, event)
if self.formatter:
if self.memory_save_in_progress:
return
self.memory_save_in_progress = True
self.formatter.handle_memory_save_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@event_bus.on(MemorySaveCompletedEvent)
def on_memory_save_completed(source, event):
def on_memory_save_completed(
source: Any, event: MemorySaveCompletedEvent
) -> None:
self._handle_action_event("memory_save_completed", source, event)
if self.formatter:
if not self.memory_save_in_progress:
return
self.memory_save_in_progress = False
self.formatter.handle_memory_save_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.save_time_ms,
event.source_type or "memory",
)
@event_bus.on(MemorySaveFailedEvent)
def on_memory_save_failed(source, event):
def on_memory_save_failed(source: Any, event: MemorySaveFailedEvent) -> None:
self._handle_action_event("memory_save_failed", source, event)
if self.formatter and self.memory_save_in_progress:
self.formatter.handle_memory_save_failed(
self.formatter.current_agent_branch,
event.error,
event.source_type or "memory",
self.formatter.current_crew_tree,
)
@event_bus.on(MemoryRetrievalStartedEvent)
def on_memory_retrieval_started(
source: Any, event: MemoryRetrievalStartedEvent
) -> None:
if self.formatter:
if self.memory_retrieval_in_progress:
return
self.memory_retrieval_in_progress = True
self.formatter.handle_memory_retrieval_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@event_bus.on(MemoryRetrievalCompletedEvent)
def on_memory_retrieval_completed(
source: Any, event: MemoryRetrievalCompletedEvent
) -> None:
if self.formatter:
if not self.memory_retrieval_in_progress:
return
self.memory_retrieval_in_progress = False
self.formatter.handle_memory_retrieval_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.memory_content,
event.retrieval_time_ms,
)
@event_bus.on(AgentReasoningStartedEvent)
def on_agent_reasoning_started(source, event):
def on_agent_reasoning_started(
source: Any, event: AgentReasoningStartedEvent
) -> None:
self._handle_action_event("agent_reasoning_started", source, event)
@event_bus.on(AgentReasoningCompletedEvent)
def on_agent_reasoning_completed(source, event):
def on_agent_reasoning_completed(
source: Any, event: AgentReasoningCompletedEvent
) -> None:
self._handle_action_event("agent_reasoning_completed", source, event)
@event_bus.on(AgentReasoningFailedEvent)
def on_agent_reasoning_failed(source, event):
def on_agent_reasoning_failed(
source: Any, event: AgentReasoningFailedEvent
) -> None:
self._handle_action_event("agent_reasoning_failed", source, event)
@event_bus.on(KnowledgeRetrievalStartedEvent)
def on_knowledge_retrieval_started(source, event):
def on_knowledge_retrieval_started(
source: Any, event: KnowledgeRetrievalStartedEvent
) -> None:
self._handle_action_event("knowledge_retrieval_started", source, event)
@event_bus.on(KnowledgeRetrievalCompletedEvent)
def on_knowledge_retrieval_completed(source, event):
def on_knowledge_retrieval_completed(
source: Any, event: KnowledgeRetrievalCompletedEvent
) -> None:
self._handle_action_event("knowledge_retrieval_completed", source, event)
@event_bus.on(KnowledgeQueryStartedEvent)
def on_knowledge_query_started(source, event):
def on_knowledge_query_started(
source: Any, event: KnowledgeQueryStartedEvent
) -> None:
self._handle_action_event("knowledge_query_started", source, event)
@event_bus.on(KnowledgeQueryCompletedEvent)
def on_knowledge_query_completed(source, event):
def on_knowledge_query_completed(
source: Any, event: KnowledgeQueryCompletedEvent
) -> None:
self._handle_action_event("knowledge_query_completed", source, event)
@event_bus.on(KnowledgeQueryFailedEvent)
def on_knowledge_query_failed(source, event):
def on_knowledge_query_failed(
source: Any, event: KnowledgeQueryFailedEvent
) -> None:
self._handle_action_event("knowledge_query_failed", source, event)
def _initialize_crew_batch(self, source: Any, event: Any):
"""Initialize trace batch"""
def _initialize_crew_batch(self, source: Any, event: Any) -> None:
"""Initialize trace batch.
Args:
source: Source object that triggered the event.
event: Event object containing crew information.
"""
user_context = self._get_user_context()
execution_metadata = {
"crew_name": getattr(event, "crew_name", "Unknown Crew"),
@@ -342,8 +475,13 @@ class TraceCollectionListener(BaseEventListener):
self._initialize_batch(user_context, execution_metadata)
def _initialize_flow_batch(self, source: Any, event: Any):
"""Initialize trace batch for Flow execution"""
def _initialize_flow_batch(self, source: Any, event: Any) -> None:
"""Initialize trace batch for Flow execution.
Args:
source: Source object that triggered the event.
event: Event object containing flow information.
"""
user_context = self._get_user_context()
execution_metadata = {
"flow_name": getattr(event, "flow_name", "Unknown Flow"),
@@ -359,21 +497,32 @@ class TraceCollectionListener(BaseEventListener):
def _initialize_batch(
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
):
"""Initialize trace batch - auto-enable ephemeral for first-time users."""
) -> None:
"""Initialize trace batch - auto-enable ephemeral for first-time users.
Args:
user_context: User context information.
execution_metadata: Metadata about the execution.
"""
if self.first_time_handler.is_first_time:
return self.batch_manager.initialize_batch(
self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=True
)
return
use_ephemeral = not self._check_authenticated()
return self.batch_manager.initialize_batch(
self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=use_ephemeral
)
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
"""Generic handler for context end events"""
def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None:
"""Generic handler for context end events.
Args:
event_type: Type of the event.
source: Source object that triggered the event.
event: Event object.
"""
self.batch_manager.begin_event_processing()
try:
trace_event = self._create_trace_event(event_type, source, event)
@@ -381,9 +530,14 @@ class TraceCollectionListener(BaseEventListener):
finally:
self.batch_manager.end_event_processing()
def _handle_action_event(self, event_type: str, source: Any, event: Any):
"""Generic handler for action events (LLM calls, tool usage)"""
def _handle_action_event(self, event_type: str, source: Any, event: Any) -> None:
"""Generic handler for action events (LLM calls, tool usage).
Args:
event_type: Type of the event.
source: Source object that triggered the event.
event: Event object.
"""
if not self.batch_manager.is_batch_initialized():
user_context = self._get_user_context()
execution_metadata = {

View File

@@ -0,0 +1,141 @@
"""Events for A2A (Agent-to-Agent) delegation.
This module defines events emitted during A2A protocol delegation,
including both single-turn and multiturn conversation flows.
"""
from typing import Any, Literal
from crewai.events.base_events import BaseEvent
class A2AEventBase(BaseEvent):
"""Base class for A2A events with task/agent context."""
from_task: Any | None = None
from_agent: Any | None = None
def __init__(self, **data):
"""Initialize A2A event, extracting task and agent metadata."""
if data.get("from_task"):
task = data["from_task"]
data["task_id"] = str(task.id)
data["task_name"] = task.name or task.description
data["from_task"] = None
if data.get("from_agent"):
agent = data["from_agent"]
data["agent_id"] = str(agent.id)
data["agent_role"] = agent.role
data["from_agent"] = None
super().__init__(**data)
class A2ADelegationStartedEvent(A2AEventBase):
"""Event emitted when A2A delegation starts.
Attributes:
endpoint: A2A agent endpoint URL (AgentCard URL)
task_description: Task being delegated to the A2A agent
agent_id: A2A agent identifier
is_multiturn: Whether this is part of a multiturn conversation
turn_number: Current turn number (1-indexed, 1 for single-turn)
"""
type: str = "a2a_delegation_started"
endpoint: str
task_description: str
agent_id: str
is_multiturn: bool = False
turn_number: int = 1
class A2ADelegationCompletedEvent(A2AEventBase):
"""Event emitted when A2A delegation completes.
Attributes:
status: Completion status (completed, input_required, failed, etc.)
result: Result message if status is completed
error: Error message when status is failed, or the agent's response when input is required
is_multiturn: Whether this is part of a multiturn conversation
"""
type: str = "a2a_delegation_completed"
status: str
result: str | None = None
error: str | None = None
is_multiturn: bool = False
class A2AConversationStartedEvent(A2AEventBase):
"""Event emitted when a multiturn A2A conversation starts.
This is emitted once at the beginning of a multiturn conversation,
before the first message exchange.
Attributes:
agent_id: A2A agent identifier
endpoint: A2A agent endpoint URL
a2a_agent_name: Name of the A2A agent from agent card
"""
type: str = "a2a_conversation_started"
agent_id: str
endpoint: str
a2a_agent_name: str | None = None
class A2AMessageSentEvent(A2AEventBase):
"""Event emitted when a message is sent to the A2A agent.
Attributes:
message: Message content sent to the A2A agent
turn_number: Current turn number (1-indexed)
is_multiturn: Whether this is part of a multiturn conversation
agent_role: Role of the CrewAI agent sending the message
"""
type: str = "a2a_message_sent"
message: str
turn_number: int
is_multiturn: bool = False
agent_role: str | None = None
class A2AResponseReceivedEvent(A2AEventBase):
"""Event emitted when a response is received from the A2A agent.
Attributes:
response: Response content from the A2A agent
turn_number: Current turn number (1-indexed)
is_multiturn: Whether this is part of a multiturn conversation
status: Response status (input_required, completed, etc.)
agent_role: Role of the CrewAI agent (for display)
"""
type: str = "a2a_response_received"
response: str
turn_number: int
is_multiturn: bool = False
status: str
agent_role: str | None = None
class A2AConversationCompletedEvent(A2AEventBase):
"""Event emitted when a multiturn A2A conversation completes.
This is emitted once at the end of a multiturn conversation.
Attributes:
status: Final status ("completed" or "failed")
final_result: Final result if completed successfully
error: Error message if failed
total_turns: Total number of turns in the conversation
"""
type: str = "a2a_conversation_completed"
status: Literal["completed", "failed"]
final_result: str | None = None
error: str | None = None
total_turns: int
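
A hedged sketch of consuming these events, using the same bus decorator API as the listeners elsewhere in this change:

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
    A2AConversationCompletedEvent,
    A2ADelegationStartedEvent,
)

@crewai_event_bus.on(A2ADelegationStartedEvent)
def on_started(source, event: A2ADelegationStartedEvent) -> None:
    print(f"turn {event.turn_number} -> {event.endpoint}")

@crewai_event_bus.on(A2AConversationCompletedEvent)
def on_done(source, event: A2AConversationCompletedEvent) -> None:
    print(event.status, event.total_turns)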

View File

@@ -17,9 +17,16 @@ class ConsoleFormatter:
current_method_branch: Tree | None = None
current_lite_agent_branch: Tree | None = None
tool_usage_counts: ClassVar[dict[str, int]] = {}
current_reasoning_branch: Tree | None = None # Track reasoning status
current_reasoning_branch: Tree | None = None
_live_paused: bool = False
current_llm_tool_tree: Tree | None = None
current_a2a_conversation_branch: Tree | None = None
current_a2a_turn_count: int = 0
_pending_a2a_message: str | None = None
_pending_a2a_agent_role: str | None = None
_pending_a2a_turn_number: int | None = None
_a2a_turn_branches: ClassVar[dict[int, Tree]] = {}
_current_a2a_agent_name: str | None = None
def __init__(self, verbose: bool = False):
self.console = Console(width=None)
@@ -192,7 +199,12 @@ class ConsoleFormatter:
style,
ID=source_id,
)
content.append(f"Final Output: {final_string_output}\n", style="white")
if status == "failed" and final_string_output:
content.append("Error:\n", style="white bold")
content.append(f"{final_string_output}\n", style="red")
else:
content.append(f"Final Output: {final_string_output}\n", style="white")
self.print_panel(content, title, style)
@@ -1474,22 +1486,37 @@ class ConsoleFormatter:
self.print()
elif isinstance(formatted_answer, AgentFinish):
# Create content for the finish panel
content = Text()
content.append("Agent: ", style="white")
content.append(f"{agent_role}\n\n", style="bright_green bold")
content.append("Final Answer:\n", style="white")
content.append(f"{formatted_answer.output}", style="bright_green")
is_a2a_delegation = False
try:
output_data = json.loads(formatted_answer.output)
if isinstance(output_data, dict):
if output_data.get("is_a2a") is True:
is_a2a_delegation = True
elif "output" in output_data:
nested_output = output_data["output"]
if (
isinstance(nested_output, dict)
and nested_output.get("is_a2a") is True
):
is_a2a_delegation = True
except (json.JSONDecodeError, TypeError, ValueError):
pass
# Create and display the finish panel
finish_panel = Panel(
content,
title="✅ Agent Final Answer",
border_style="green",
padding=(1, 2),
)
self.print(finish_panel)
self.print()
if not is_a2a_delegation:
content = Text()
content.append("Agent: ", style="white")
content.append(f"{agent_role}\n\n", style="bright_green bold")
content.append("Final Answer:\n", style="white")
content.append(f"{formatted_answer.output}", style="bright_green")
finish_panel = Panel(
content,
title="✅ Agent Final Answer",
border_style="green",
padding=(1, 2),
)
self.print(finish_panel)
self.print()
def handle_memory_retrieval_started(
self,
@@ -1789,3 +1816,435 @@ class ConsoleFormatter:
Attempts=f"{retry_count + 1}",
)
self.print_panel(content, "🛡️ Guardrail Failed", "red")
def handle_a2a_delegation_started(
self,
endpoint: str,
task_description: str,
agent_id: str,
is_multiturn: bool = False,
turn_number: int = 1,
) -> Tree | None:
"""Handle A2A delegation started event.
Args:
endpoint: A2A agent endpoint URL
task_description: Task being delegated
agent_id: A2A agent identifier
is_multiturn: Whether this is part of a multiturn conversation
turn_number: Current turn number in conversation (1-indexed)
Returns:
The tree branch created for this delegation, or None when no tree view is available.
"""
branch_to_use = self.current_lite_agent_branch or self.current_task_branch
tree_to_use = self.current_crew_tree or branch_to_use
a2a_branch: Tree | None = None
if is_multiturn:
if self.current_a2a_turn_count == 0 and not isinstance(
self.current_a2a_conversation_branch, Tree
):
if branch_to_use is not None and tree_to_use is not None:
self.current_a2a_conversation_branch = branch_to_use.add("")
self.update_tree_label(
self.current_a2a_conversation_branch,
"💬",
f"Multiturn A2A Conversation ({agent_id})",
"cyan",
)
self.print(tree_to_use)
self.print()
else:
self.current_a2a_conversation_branch = "MULTITURN_NO_TREE"
content = Text()
content.append(
"Multiturn A2A Conversation Started\n\n", style="cyan bold"
)
content.append("Agent ID: ", style="white")
content.append(f"{agent_id}\n", style="cyan")
content.append("Note: ", style="white dim")
content.append(
"Conversation will be tracked in tree view", style="cyan dim"
)
panel = self.create_panel(
content, "💬 Multiturn Conversation", "cyan"
)
self.print(panel)
self.print()
self.current_a2a_turn_count = turn_number
return (
self.current_a2a_conversation_branch
if isinstance(self.current_a2a_conversation_branch, Tree)
else None
)
if branch_to_use is not None and tree_to_use is not None:
a2a_branch = branch_to_use.add("")
self.update_tree_label(
a2a_branch,
"🔗",
f"Delegating to A2A Agent ({agent_id})",
"cyan",
)
self.print(tree_to_use)
self.print()
content = Text()
content.append("A2A Delegation Started\n\n", style="cyan bold")
content.append("Agent ID: ", style="white")
content.append(f"{agent_id}\n", style="cyan")
content.append("Endpoint: ", style="white")
content.append(f"{endpoint}\n\n", style="cyan dim")
content.append("Task Description:\n", style="white")
task_preview = (
task_description
if len(task_description) <= 200
else task_description[:197] + "..."
)
content.append(task_preview, style="cyan")
panel = self.create_panel(content, "🔗 A2A Delegation", "cyan")
self.print(panel)
self.print()
return a2a_branch
def handle_a2a_delegation_completed(
self,
status: str,
result: str | None = None,
error: str | None = None,
is_multiturn: bool = False,
) -> None:
"""Handle A2A delegation completed event.
Args:
status: Completion status
result: Optional result message
error: Optional error message (or response for input_required)
is_multiturn: Whether this is part of a multiturn conversation
"""
tree_to_use = self.current_crew_tree or self.current_task_branch
a2a_branch = None
if is_multiturn and self.current_a2a_conversation_branch:
has_tree = isinstance(self.current_a2a_conversation_branch, Tree)
if status == "input_required" and error:
pass
elif status == "completed":
if has_tree:
final_turn = self.current_a2a_conversation_branch.add("")
self.update_tree_label(
final_turn,
"",
"Conversation Completed",
"green",
)
if tree_to_use:
self.print(tree_to_use)
self.print()
self.current_a2a_conversation_branch = None
self.current_a2a_turn_count = 0
elif status == "failed":
if has_tree:
error_turn = self.current_a2a_conversation_branch.add("")
error_msg = (
error[:150] + "..." if error and len(error) > 150 else error
)
self.update_tree_label(
error_turn,
"",
f"Failed: {error_msg}" if error else "Conversation Failed",
"red",
)
if tree_to_use:
self.print(tree_to_use)
self.print()
self.current_a2a_conversation_branch = None
self.current_a2a_turn_count = 0
return
if a2a_branch and tree_to_use:
if status == "completed":
self.update_tree_label(
a2a_branch,
"",
"A2A Delegation Completed",
"green",
)
elif status == "failed":
self.update_tree_label(
a2a_branch,
"",
"A2A Delegation Failed",
"red",
)
else:
self.update_tree_label(
a2a_branch,
"⚠️",
f"A2A Delegation {status.replace('_', ' ').title()}",
"yellow",
)
self.print(tree_to_use)
self.print()
if status == "completed" and result:
content = Text()
content.append("A2A Delegation Completed\n\n", style="green bold")
content.append("Result:\n", style="white")
result_preview = result if len(result) <= 500 else result[:497] + "..."
content.append(result_preview, style="green")
panel = self.create_panel(content, "✅ A2A Success", "green")
self.print(panel)
self.print()
elif status == "input_required" and error:
content = Text()
content.append("A2A Response\n\n", style="cyan bold")
content.append("Message:\n", style="white")
response_preview = error if len(error) <= 500 else error[:497] + "..."
content.append(response_preview, style="cyan")
panel = self.create_panel(content, "💬 A2A Response", "cyan")
self.print(panel)
self.print()
elif error:
content = Text()
content.append(
"A2A Delegation Issue\n\n",
style="red bold" if status == "failed" else "yellow bold",
)
content.append("Status: ", style="white")
content.append(
f"{status}\n\n", style="red" if status == "failed" else "yellow"
)
content.append("Message:\n", style="white")
content.append(error, style="red" if status == "failed" else "yellow")
panel_style = "red" if status == "failed" else "yellow"
panel_title = "❌ A2A Failed" if status == "failed" else "⚠️ A2A Status"
panel = self.create_panel(content, panel_title, panel_style)
self.print(panel)
self.print()
def handle_a2a_conversation_started(
self,
agent_id: str,
endpoint: str,
) -> None:
"""Handle A2A conversation started event.
Args:
agent_id: A2A agent identifier
endpoint: A2A agent endpoint URL
"""
branch_to_use = self.current_lite_agent_branch or self.current_task_branch
tree_to_use = self.current_crew_tree or branch_to_use
if not isinstance(self.current_a2a_conversation_branch, Tree):
if branch_to_use is not None and tree_to_use is not None:
self.current_a2a_conversation_branch = branch_to_use.add("")
self.update_tree_label(
self.current_a2a_conversation_branch,
"💬",
f"Multiturn A2A Conversation ({agent_id})",
"cyan",
)
self.print(tree_to_use)
self.print()
else:
self.current_a2a_conversation_branch = "MULTITURN_NO_TREE"
def handle_a2a_message_sent(
self,
message: str,
turn_number: int,
agent_role: str | None = None,
) -> None:
"""Handle A2A message sent event.
Args:
message: Message content sent to the A2A agent
turn_number: Current turn number
agent_role: Role of the CrewAI agent sending the message
"""
self._pending_a2a_message = message
self._pending_a2a_agent_role = agent_role
self._pending_a2a_turn_number = turn_number
def handle_a2a_response_received(
self,
response: str,
turn_number: int,
status: str,
agent_role: str | None = None,
) -> None:
"""Handle A2A response received event.
Args:
response: Response content from the A2A agent
turn_number: Current turn number
status: Response status (input_required, completed, etc.)
agent_role: Role of the CrewAI agent (for display)
"""
if self.current_a2a_conversation_branch and isinstance(
self.current_a2a_conversation_branch, Tree
):
if turn_number in self._a2a_turn_branches:
turn_branch = self._a2a_turn_branches[turn_number]
else:
turn_branch = self.current_a2a_conversation_branch.add("")
self.update_tree_label(
turn_branch,
"💬",
f"Turn {turn_number}",
"cyan",
)
self._a2a_turn_branches[turn_number] = turn_branch
crewai_agent_role = self._pending_a2a_agent_role or agent_role or "User"
message_content = self._pending_a2a_message or "sent message"
message_preview = (
message_content[:100] + "..."
if len(message_content) > 100
else message_content
)
user_node = turn_branch.add("")
self.update_tree_label(
user_node,
f"{crewai_agent_role} 👤 : ",
f'"{message_preview}"',
"blue",
)
agent_node = turn_branch.add("")
response_preview = (
response[:100] + "..." if len(response) > 100 else response
)
a2a_agent_display = f"{self._current_a2a_agent_name} \U0001f916: "
if status == "completed":
response_color = "green"
status_indicator = ""
elif status == "input_required":
response_color = "yellow"
status_indicator = ""
elif status == "failed":
response_color = "red"
status_indicator = ""
elif status == "auth_required":
response_color = "magenta"
status_indicator = "🔒"
elif status == "canceled":
response_color = "dim"
status_indicator = ""
else:
response_color = "cyan"
status_indicator = ""
label = f'"{response_preview}"'
if status_indicator:
label = f"{status_indicator} {label}"
self.update_tree_label(
agent_node,
a2a_agent_display,
label,
response_color,
)
self._pending_a2a_message = None
self._pending_a2a_agent_role = None
self._pending_a2a_turn_number = None
tree_to_use = self.current_crew_tree or self.current_task_branch
if tree_to_use:
self.print(tree_to_use)
self.print()
def handle_a2a_conversation_completed(
self,
status: str,
final_result: str | None,
error: str | None,
total_turns: int,
) -> None:
"""Handle A2A conversation completed event.
Args:
status: Final status (completed, failed, etc.)
final_result: Final result if completed successfully
error: Error message if failed
total_turns: Total number of turns in the conversation
"""
if self.current_a2a_conversation_branch and isinstance(
self.current_a2a_conversation_branch, Tree
):
if status == "completed":
if self._pending_a2a_message and self._pending_a2a_agent_role:
if total_turns in self._a2a_turn_branches:
turn_branch = self._a2a_turn_branches[total_turns]
else:
turn_branch = self.current_a2a_conversation_branch.add("")
self.update_tree_label(
turn_branch,
"💬",
f"Turn {total_turns}",
"cyan",
)
self._a2a_turn_branches[total_turns] = turn_branch
crewai_agent_role = self._pending_a2a_agent_role
message_content = self._pending_a2a_message
message_preview = (
message_content[:100] + "..."
if len(message_content) > 100
else message_content
)
user_node = turn_branch.add("")
self.update_tree_label(
user_node,
f"{crewai_agent_role} 👤 : ",
f'"{message_preview}"',
"green",
)
self._pending_a2a_message = None
self._pending_a2a_agent_role = None
self._pending_a2a_turn_number = None
elif status == "failed":
error_turn = self.current_a2a_conversation_branch.add("")
error_msg = error[:150] + "..." if error and len(error) > 150 else error
self.update_tree_label(
error_turn,
"",
f"Failed: {error_msg}" if error else "Conversation Failed",
"red",
)
tree_to_use = self.current_crew_tree or self.current_task_branch
if tree_to_use:
self.print(tree_to_use)
self.print()
self.current_a2a_conversation_branch = None
self.current_a2a_turn_count = 0
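Taken together, these handlers buffer the outbound message in `handle_a2a_message_sent` and only render it once the matching response arrives, so each turn appears as a single branch with a user node and an agent node. An illustrative call sequence with hypothetical values:

```python
# Illustrative sequence (hypothetical values): one multiturn exchange
# rendered as a turn branch under the conversation branch.
formatter = ConsoleFormatter(verbose=True)
formatter.handle_a2a_conversation_started(
    agent_id="servicenow", endpoint="https://a2a.example.com"
)
formatter.handle_a2a_message_sent(
    message="Create an incident for the VPN outage",
    turn_number=1,
    agent_role="IT Manager",
)
formatter.handle_a2a_response_received(
    response="What priority should the incident have?",
    turn_number=1,
    status="input_required",
    agent_role="IT Manager",
)
```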

View File

@@ -1,65 +0,0 @@
"""A2A (Agent-to-Agent) Protocol adapter for CrewAI.
This module provides integration with A2A protocol-compliant agents,
enabling CrewAI to orchestrate external agents like ServiceNow, Bedrock Agents,
Glean, and other A2A-compliant systems.
Example:
```python
from crewai.experimental.a2a import A2AAgentAdapter
# Create A2A agent
servicenow_agent = A2AAgentAdapter(
agent_card_url="https://servicenow.example.com/.well-known/agent-card.json",
auth_token="your-token",
role="ServiceNow Incident Manager",
goal="Create and manage IT incidents",
backstory="Expert at incident management",
)
# Use in crew
crew = Crew(agents=[servicenow_agent], tasks=[task])
```
"""
from crewai.experimental.a2a.a2a_adapter import A2AAgentAdapter
from crewai.experimental.a2a.auth import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
create_auth_from_agent_card,
)
from crewai.experimental.a2a.exceptions import (
A2AAuthenticationError,
A2AConfigurationError,
A2AConnectionError,
A2AError,
A2AInputRequiredError,
A2ATaskCanceledError,
A2ATaskFailedError,
)
__all__ = [
"A2AAgentAdapter",
"A2AAuthenticationError",
"A2AConfigurationError",
"A2AConnectionError",
"A2AError",
"A2AInputRequiredError",
"A2ATaskCanceledError",
"A2ATaskFailedError",
"APIKeyAuth",
# Authentication
"AuthScheme",
"BearerTokenAuth",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
"create_auth_from_agent_card",
]

File diff suppressed because it is too large

View File

@@ -1,424 +0,0 @@
"""Authentication schemes for A2A protocol agents.
This module provides support for various authentication methods:
- Bearer tokens (existing)
- OAuth2 (Client Credentials, Authorization Code)
- API Keys (header, query, cookie)
- HTTP Basic authentication
- HTTP Digest authentication
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import base64
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, Literal
import httpx
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from a2a.types import AgentCard
class AuthScheme(ABC, BaseModel):
"""Base class for authentication schemes."""
@abstractmethod
async def apply_auth(
self, client: httpx.AsyncClient, headers: dict[str, str]
) -> dict[str, str]:
"""Apply authentication to request headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with authentication applied.
"""
...
@abstractmethod
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure the HTTP client for this auth scheme.
Args:
client: HTTP client to configure.
"""
...
class BearerTokenAuth(AuthScheme):
"""Bearer token authentication (Authorization: Bearer <token>)."""
token: str = Field(description="Bearer token")
async def apply_auth(
self, client: httpx.AsyncClient, headers: dict[str, str]
) -> dict[str, str]:
"""Apply Bearer token to Authorization header."""
headers["Authorization"] = f"Bearer {self.token}"
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""No client configuration needed for Bearer tokens."""
class HTTPBasicAuth(AuthScheme):
"""HTTP Basic authentication."""
username: str = Field(description="Username")
password: str = Field(description="Password")
async def apply_auth(
self, client: httpx.AsyncClient, headers: dict[str, str]
) -> dict[str, str]:
"""Apply HTTP Basic authentication."""
credentials = f"{self.username}:{self.password}"
encoded = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded}"
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""No client configuration needed for Basic auth."""
class HTTPDigestAuth(AuthScheme):
"""HTTP Digest authentication.
Note: Uses httpx-auth library for proper digest implementation.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
async def apply_auth(
self, client: httpx.AsyncClient, headers: dict[str, str]
) -> dict[str, str]:
"""Digest auth is handled by httpx auth flow, not headers."""
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client with Digest auth."""
try:
from httpx_auth import DigestAuth # type: ignore[import-not-found]
client.auth = DigestAuth(self.username, self.password) # type: ignore[import-not-found]
except ImportError as e:
msg = "httpx-auth required for Digest authentication. Install with: pip install httpx-auth"
raise ImportError(msg) from e
class APIKeyAuth(AuthScheme):
"""API Key authentication (header, query, or cookie)."""
api_key: str = Field(description="API key value")
location: Literal["header", "query", "cookie"] = Field(
default="header", description="Where to send the API key"
)
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
async def apply_auth(
self, client: httpx.AsyncClient, headers: dict[str, str]
) -> dict[str, str]:
"""Apply API key authentication."""
if self.location == "header":
headers[self.name] = self.api_key
elif self.location == "cookie":
headers["Cookie"] = f"{self.name}={self.api_key}"
# Query params are handled in configure_client via event hooks
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client for query param API keys."""
if self.location == "query":
# Add API key to all requests via event hook
async def add_api_key_param(request: httpx.Request) -> None:
url = httpx.URL(request.url)
request.url = url.copy_add_param(self.name, self.api_key)
client.event_hooks["request"].append(add_api_key_param)
class OAuth2ClientCredentials(AuthScheme):
"""OAuth2 Client Credentials flow authentication."""
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = None
_token_expires_at: float | None = None
async def apply_auth(
self, client: httpx.AsyncClient, headers: dict[str, str]
) -> dict[str, str]:
"""Apply OAuth2 access token to Authorization header."""
# Get or refresh token if needed
import time
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
"""Fetch OAuth2 access token using client credentials flow."""
import time
data = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
}
if self.scopes:
data["scope"] = " ".join(self.scopes)
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
# Calculate expiration time (default to 3600 seconds if not provided)
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60 # 60s buffer
def configure_client(self, client: httpx.AsyncClient) -> None:
"""No client configuration needed for OAuth2."""
class OAuth2AuthorizationCode(AuthScheme):
"""OAuth2 Authorization Code flow authentication.
Note: This requires interactive authorization and is typically used
for user-facing applications. For server-to-server, use ClientCredentials.
"""
authorization_url: str = Field(description="OAuth2 authorization endpoint")
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
redirect_uri: str = Field(description="OAuth2 redirect URI")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = None
_refresh_token: str | None = None
_token_expires_at: float | None = None
_authorization_callback: Callable[[str], Awaitable[str]] | None = None
def set_authorization_callback(
self, callback: Callable[[str], Awaitable[str]] | None
) -> None:
"""Set callback to handle authorization URL.
The callback receives the authorization URL and should return
the authorization code after user completes the flow.
"""
self._authorization_callback = callback
async def apply_auth(
self, client: httpx.AsyncClient, headers: dict[str, str]
) -> dict[str, str]:
"""Apply OAuth2 access token to Authorization header."""
import time
# Get or refresh token if needed
if self._access_token is None:
if self._authorization_callback is None:
msg = "Authorization callback not set. Use set_authorization_callback()"
raise ValueError(msg)
await self._fetch_initial_token(client)
elif self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
"""Fetch initial access token using authorization code flow."""
import time
import urllib.parse
# Build authorization URL
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(self.scopes),
}
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
# Get authorization code from callback
if self._authorization_callback is None:
msg = "Authorization callback not set"
raise ValueError(msg)
auth_code = await self._authorization_callback(auth_url)
# Exchange code for token
data = {
"grant_type": "authorization_code",
"code": auth_code,
"client_id": self.client_id,
"client_secret": self.client_secret,
"redirect_uri": self.redirect_uri,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
self._refresh_token = token_data.get("refresh_token")
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
"""Refresh the access token using refresh token."""
import time
if not self._refresh_token:
# Re-authorize if no refresh token
await self._fetch_initial_token(client)
return
data = {
"grant_type": "refresh_token",
"refresh_token": self._refresh_token,
"client_id": self.client_id,
"client_secret": self.client_secret,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
if "refresh_token" in token_data:
self._refresh_token = token_data["refresh_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
def configure_client(self, client: httpx.AsyncClient) -> None:
"""No client configuration needed for OAuth2."""
def create_auth_from_agent_card(
agent_card: AgentCard, credentials: dict[str, Any]
) -> AuthScheme | None:
"""Create an appropriate authentication scheme from AgentCard security config.
Args:
agent_card: The A2A AgentCard containing security requirements.
credentials: User-provided credentials (passwords, tokens, keys, etc.).
Returns:
Configured AuthScheme, or None if no authentication required.
Example:
```python
# For OAuth2
credentials = {
"client_id": "my-app",
"client_secret": "secret123",
}
auth = create_auth_from_agent_card(agent_card, credentials)
# For API Key
credentials = {"api_key": "key-12345"}
auth = create_auth_from_agent_card(agent_card, credentials)
# For HTTP Basic
credentials = {"username": "user", "password": "pass"}
auth = create_auth_from_agent_card(agent_card, credentials)
```
"""
if not agent_card.security or not agent_card.security_schemes:
return None
# Get the first required security scheme
first_security_req = agent_card.security[0] if agent_card.security else {}
for scheme_name, _scopes in first_security_req.items():
security_scheme_obj = agent_card.security_schemes.get(scheme_name)
if not security_scheme_obj:
continue
# SecurityScheme is a dict-like object
security_scheme = dict(security_scheme_obj) # type: ignore[arg-type]
scheme_type = str(security_scheme.get("type", "")).lower()
# OAuth2
if scheme_type == "oauth2":
flows = security_scheme.get("flows", {})
if "clientCredentials" in flows:
flow = flows["clientCredentials"]
return OAuth2ClientCredentials(
token_url=str(flow["tokenUrl"]),
client_id=str(credentials.get("client_id", "")),
client_secret=str(credentials.get("client_secret", "")),
scopes=list(flow.get("scopes", {}).keys()),
)
if "authorizationCode" in flows:
flow = flows["authorizationCode"]
return OAuth2AuthorizationCode(
authorization_url=str(flow["authorizationUrl"]),
token_url=str(flow["tokenUrl"]),
client_id=str(credentials.get("client_id", "")),
client_secret=str(credentials.get("client_secret", "")),
redirect_uri=str(credentials.get("redirect_uri", "")),
scopes=list(flow.get("scopes", {}).keys()),
)
# API Key
elif scheme_type == "apikey":
location = str(security_scheme.get("in", "header"))
name = str(security_scheme.get("name", "X-API-Key"))
return APIKeyAuth(
api_key=str(credentials.get("api_key", "")),
location=location, # type: ignore[arg-type]
name=name,
)
# HTTP Auth
elif scheme_type == "http":
http_scheme = str(security_scheme.get("scheme", "")).lower()
if http_scheme == "basic":
return HTTPBasicAuth(
username=str(credentials.get("username", "")),
password=str(credentials.get("password", "")),
)
if http_scheme == "digest":
return HTTPDigestAuth(
username=str(credentials.get("username", "")),
password=str(credentials.get("password", "")),
)
if http_scheme == "bearer":
return BearerTokenAuth(token=str(credentials.get("token", "")))
return None
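For reference, the client-credentials scheme above can be exercised directly; a minimal sketch with an assumed token endpoint and credentials (the token is fetched lazily on first use and cached until roughly 60 seconds before expiry):

```python
# Minimal sketch (assumed endpoint/credentials): the first apply_auth
# call fetches and caches a token; later calls reuse it until ~60s
# before expiry.
import asyncio
import httpx

async def main() -> None:
    auth = OAuth2ClientCredentials(
        token_url="https://auth.example.com/oauth/token",
        client_id="my-app",
        client_secret="secret123",
        scopes=["agent.read"],
    )
    async with httpx.AsyncClient() as client:
        headers = await auth.apply_auth(client, {})
        # headers now includes "Authorization: Bearer <access_token>"

asyncio.run(main())
```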

View File

@@ -1,56 +0,0 @@
"""Custom exceptions for A2A Agent Adapter."""
class A2AError(Exception):
"""Base exception for A2A adapter errors."""
class A2ATaskFailedError(A2AError):
"""Raised when A2A agent task fails or is rejected.
This exception is raised when the A2A agent reports a task
in the 'failed' or 'rejected' state.
"""
class A2AInputRequiredError(A2AError):
"""Raised when A2A agent requires additional input.
This exception is raised when the A2A agent reports a task
in the 'input_required' state, indicating that it needs more
information to complete the task.
"""
class A2AConfigurationError(A2AError):
"""Raised when A2A adapter configuration is invalid.
This exception is raised during initialization or setup when
the adapter configuration is invalid or incompatible.
"""
class A2AConnectionError(A2AError):
"""Raised when connection to A2A agent fails.
This exception is raised when the adapter cannot establish
a connection to the A2A agent or when network errors occur.
"""
class A2AAuthenticationError(A2AError):
"""Raised when A2A agent requires authentication.
This exception is raised when the A2A agent reports a task
in the 'auth_required' state, indicating that authentication
is needed before the task can continue.
"""
class A2ATaskCanceledError(A2AError):
"""Raised when A2A task is canceled.
This exception is raised when the A2A agent reports a task
in the 'canceled' state, indicating the task was canceled
either by the user or the system.
"""

View File

@@ -1,56 +0,0 @@
"""Type protocols for A2A SDK components.
These protocols define the expected interfaces for A2A SDK types,
allowing for type checking without requiring the SDK to be installed.
"""
from collections.abc import AsyncIterator
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class AgentCardProtocol(Protocol):
"""Protocol for A2A AgentCard."""
name: str
version: str
description: str
skills: list[Any]
capabilities: Any
@runtime_checkable
class ClientProtocol(Protocol):
"""Protocol for A2A Client."""
async def send_message(self, message: Any) -> AsyncIterator[Any]:
"""Send message to A2A agent."""
...
async def get_card(self) -> AgentCardProtocol:
"""Get agent card."""
...
async def close(self) -> None:
"""Close client connection."""
...
@runtime_checkable
class MessageProtocol(Protocol):
"""Protocol for A2A Message."""
role: Any
message_id: str
parts: list[Any]
@runtime_checkable
class TaskProtocol(Protocol):
"""Protocol for A2A Task."""
id: str
context_id: str
status: Any
history: list[Any] | None
artifacts: list[Any] | None
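Since the protocols are `@runtime_checkable`, structural `isinstance` checks work without the SDK installed; a small sketch with a stand-in object:

```python
# Sketch: any object exposing the expected attributes satisfies the
# protocol structurally; FakeCard is a stand-in, not an SDK type.
class FakeCard:
    name = "demo-agent"
    version = "1.0"
    description = "A test card"
    skills: list = []
    capabilities = None

assert isinstance(FakeCard(), AgentCardProtocol)
```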

View File

@@ -2,7 +2,7 @@ from collections.abc import Sequence
import threading
from typing import Any
from crewai.agent import Agent
from crewai.agent.core import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (

View File

@@ -21,6 +21,7 @@ from typing import (
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from typing_extensions import Self
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
@@ -36,30 +37,34 @@ from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
)
from crewai.llms.base_llm import BaseLLM
from crewai.utilities import InternalInstructor
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.logger_utils import suppress_warnings
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from litellm import Choices
from litellm.exceptions import ContextWindowExceededError
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
)
from litellm.types.utils import ChatCompletionDeltaToolCall, ModelResponse
from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
from litellm.utils import supports_response_schema
from crewai.agent.core import Agent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.types import LLMMessage
try:
import litellm
from litellm import Choices, CustomLogger
from litellm.exceptions import ContextWindowExceededError
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
)
from litellm.types.utils import ChatCompletionDeltaToolCall, ModelResponse
from litellm.types.utils import ChatCompletionDeltaToolCall, Choices, ModelResponse
from litellm.utils import supports_response_schema
LITELLM_AVAILABLE = True
@@ -72,6 +77,7 @@ except ImportError:
ChatCompletionDeltaToolCall = None # type: ignore
ModelResponse = None # type: ignore
supports_response_schema = None # type: ignore
CustomLogger = None # type: ignore
load_dotenv()
@@ -104,11 +110,13 @@ class FilteredStream(io.TextIOBase):
return self._original_stream.write(s)
def flush(self):
with self._lock:
return self._original_stream.flush()
def flush(self) -> None:
if self._lock:
with self._lock:
return self._original_stream.flush()
return None
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
"""Delegate attribute access to the wrapped original stream.
This ensures compatibility with libraries (e.g., Rich) that rely on
@@ -122,16 +130,16 @@ class FilteredStream(io.TextIOBase):
# confuses Rich). These explicit pass-throughs ensure the wrapped Console
# still sees a fully-featured stream.
@property
def encoding(self):
def encoding(self) -> str | Any: # type: ignore[override]
return getattr(self._original_stream, "encoding", "utf-8")
def isatty(self):
def isatty(self) -> bool:
return self._original_stream.isatty()
def fileno(self):
def fileno(self) -> int:
return self._original_stream.fileno()
def writable(self):
def writable(self) -> bool:
return True
@@ -312,7 +320,7 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: float | None = None
def __new__(cls, model: str, is_litellm: bool = False, **kwargs) -> LLM:
def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
"""Factory method that routes to native SDK or falls back to LiteLLM."""
if not model or not isinstance(model, str):
raise ValueError("Model must be a non-empty string")
@@ -323,7 +331,9 @@ class LLM(BaseLLM):
if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS:
try:
model_string = model.partition("/")[2] if "/" in model else model
return native_class(model=model_string, provider=provider, **kwargs)
return cast(
Self, native_class(model=model_string, provider=provider, **kwargs)
)
except Exception as e:
raise ImportError(f"Error importing native provider: {e}") from e
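Illustrative routing behavior under this factory, assuming the provider SDK is installed (model names are examples):

```python
# Sketch of the routing: a supported provider prefix selects the native
# class; is_litellm=True forces the LiteLLM-backed fallback.
native = LLM(model="openai/gpt-4o")                      # native OpenAI path
fallback = LLM(model="openai/gpt-4o", is_litellm=True)   # LiteLLM fallback
```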
@@ -393,13 +403,21 @@ class LLM(BaseLLM):
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
**kwargs,
):
**kwargs: Any,
) -> None:
"""Initialize LLM instance.
Note: This __init__ method is only called for fallback instances.
Native provider instances handle their own initialization in their respective classes.
"""
super().__init__(
model=model,
temperature=temperature,
api_key=api_key,
base_url=base_url,
timeout=timeout,
**kwargs,
)
self.model = model
self.timeout = timeout
self.temperature = temperature
@@ -454,7 +472,7 @@ class LLM(BaseLLM):
def _prepare_completion_params(
self,
messages: str | list[LLMMessage],
tools: list[dict] | None = None,
tools: list[dict[str, BaseTool]] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for the completion call.
@@ -505,9 +523,10 @@ class LLM(BaseLLM):
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> Any:
"""Handle a streaming response from the LLM.
Args:
@@ -516,6 +535,7 @@ class LLM(BaseLLM):
available_functions: Dict of available functions
from_task: Optional task object
from_agent: Optional agent object
response_model: Optional response model
Returns:
str: The complete response text
@@ -716,14 +736,30 @@ class LLM(BaseLLM):
tool_calls = message.tool_calls
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
if not tool_calls or not available_functions:
# Track token usage and log callbacks if available in streaming mode
if usage_info:
self._track_token_usage_internal(usage_info)
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
# Emit completion event and return response
if response_model and self.is_litellm:
instructor_instance = InternalInstructor(
content=full_response,
model=response_model,
llm=self,
)
result = instructor_instance.to_pydantic()
structured_response = result.model_dump_json()
self._handle_emit_call_events(
response=structured_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_response
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
@@ -784,9 +820,9 @@ class LLM(BaseLLM):
tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None | str:
from_task: Task | None = None,
from_agent: Agent | None = None,
) -> Any:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -869,8 +905,9 @@ class LLM(BaseLLM):
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle a non-streaming response from the LLM.
@@ -880,23 +917,69 @@ class LLM(BaseLLM):
available_functions: Dict of available functions
from_task: Optional Task that invoked the LLM
from_agent: Optional Agent that invoked the LLM
response_model: Optional Pydantic response model
Returns:
str: The response text
"""
# --- 1) Make the completion call
# --- 1) Handle response_model with InternalInstructor for LiteLLM
if response_model and self.is_litellm:
from crewai.utilities.internal_instructor import InternalInstructor
messages = params.get("messages", [])
if not messages:
raise ValueError("Messages are required when using response_model")
# Combine all message content for InternalInstructor
combined_content = "\n\n".join(
f"{msg['role'].upper()}: {msg['content']}" for msg in messages
)
instructor_instance = InternalInstructor(
content=combined_content,
model=response_model,
llm=self,
)
result = instructor_instance.to_pydantic()
structured_response = result.model_dump_json()
self._handle_emit_call_events(
response=structured_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_response
try:
# Attempt to make the completion call, but catch context window errors
# and convert them to our own exception type for consistent handling
# across the codebase. This allows CrewAgentExecutor to handle context
# length issues appropriately.
if response_model:
params["response_model"] = response_model
response = litellm.completion(**params)
except ContextWindowExceededError as e:
# Convert litellm's context window error to our own exception type
# for consistent handling in the rest of the codebase
raise LLMContextLengthExceededError(str(e)) from e
# --- 2) Extract response message and content
# --- 2) Handle structured output response (when response_model is provided)
if response_model is not None:
# When using instructor/response_model, litellm returns a Pydantic model instance
if isinstance(response, BaseModel):
structured_response = response.model_dump_json()
self._handle_emit_call_events(
response=structured_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_response
# --- 3) Extract response message and content (standard response)
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
@@ -951,9 +1034,9 @@ class LLM(BaseLLM):
self,
tool_calls: list[Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | None:
from_task: Task | None = None,
from_agent: Agent | None = None,
) -> Any:
"""Handle a tool call from the LLM.
Args:
@@ -1039,11 +1122,12 @@ class LLM(BaseLLM):
def call(
self,
messages: str | list[LLMMessage],
tools: list[dict] | None = None,
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""High-level LLM call method.
@@ -1060,6 +1144,7 @@ class LLM(BaseLLM):
that can be invoked by the LLM.
from_task: Optional Task that invoked the LLM
from_agent: Optional Agent that invoked the LLM
response_model: Optional Pydantic model class used to request structured output.
Returns:
Union[str, Any]: Either a text response from the LLM (str) or
@@ -1105,11 +1190,21 @@ class LLM(BaseLLM):
# --- 7) Make the completion call and handle response
if self.stream:
return self._handle_streaming_response(
params, callbacks, available_functions, from_task, from_agent
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
@@ -1141,6 +1236,7 @@ class LLM(BaseLLM):
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
crewai_event_bus.emit(
@@ -1155,10 +1251,10 @@ class LLM(BaseLLM):
self,
response: Any,
call_type: LLMCallType,
from_task: Any | None = None,
from_agent: Any | None = None,
messages: str | list[dict[str, Any]] | None = None,
):
from_task: Task | None = None,
from_agent: Agent | None = None,
messages: str | list[LLMMessage] | None = None,
) -> None:
"""Handle the events for the LLM call.
Args:
@@ -1324,7 +1420,7 @@ class LLM(BaseLLM):
return self.context_window_size
@staticmethod
def set_callbacks(callbacks: list[Any]):
def set_callbacks(callbacks: list[Any]) -> None:
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
@@ -1377,7 +1473,7 @@ class LLM(BaseLLM):
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
def __copy__(self):
def __copy__(self) -> LLM:
"""Create a shallow copy of the LLM instance."""
# Filter out parameters that are already explicitly passed to avoid conflicts
filtered_params = {
@@ -1437,7 +1533,7 @@ class LLM(BaseLLM):
**filtered_params,
)
def __deepcopy__(self, memo):
def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
"""Create a deep copy of the LLM instance."""
import copy

View File

@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
from datetime import datetime
import json
import logging
import re
from typing import TYPE_CHECKING, Any, Final
from pydantic import BaseModel
@@ -31,11 +32,15 @@ from crewai.types.usage_metrics import UsageMetrics
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.types import LLMMessage
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
class BaseLLM(ABC):
@@ -65,9 +70,8 @@ class BaseLLM(ABC):
temperature: float | None = None,
api_key: str | None = None,
base_url: str | None = None,
timeout: float | None = None,
provider: str | None = None,
**kwargs,
**kwargs: Any,
) -> None:
"""Initialize the BaseLLM with default attributes.
@@ -93,8 +97,10 @@ class BaseLLM(ABC):
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
elif isinstance(stop, list):
self.stop = stop
else:
self.stop = []
self._token_usage = {
"total_tokens": 0,
@@ -118,11 +124,12 @@ class BaseLLM(ABC):
def call(
self,
messages: str | list[LLMMessage],
tools: list[dict] | None = None,
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call the LLM with the given messages.
@@ -139,6 +146,7 @@ class BaseLLM(ABC):
that can be invoked by the LLM.
from_task: Optional task caller to be used for the LLM call.
from_agent: Optional agent caller to be used for the LLM call.
response_model: Optional response model to be used for the LLM call.
Returns:
Either a text response from the LLM (str) or
@@ -150,7 +158,9 @@ class BaseLLM(ABC):
RuntimeError: If the LLM request fails for other reasons.
"""
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
def _convert_tools_for_interference(
self, tools: list[dict[str, BaseTool]]
) -> list[dict[str, BaseTool]]:
"""Convert tools to a format that can be used for interference.
Args:
@@ -237,11 +247,11 @@ class BaseLLM(ABC):
def _emit_call_started_event(
self,
messages: str | list[LLMMessage],
tools: list[dict] | None = None,
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
) -> None:
"""Emit LLM call started event."""
if not hasattr(crewai_event_bus, "emit"):
@@ -264,8 +274,8 @@ class BaseLLM(ABC):
self,
response: Any,
call_type: LLMCallType,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
messages: str | list[dict[str, Any]] | None = None,
) -> None:
"""Emit LLM call completed event."""
@@ -284,8 +294,8 @@ class BaseLLM(ABC):
def _emit_call_failed_event(
self,
error: str,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
) -> None:
"""Emit LLM call failed event."""
if not hasattr(crewai_event_bus, "emit"):
@@ -303,8 +313,8 @@ class BaseLLM(ABC):
def _emit_stream_chunk_event(
self,
chunk: str,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
tool_call: dict[str, Any] | None = None,
) -> None:
"""Emit stream chunk event."""
@@ -326,8 +336,8 @@ class BaseLLM(ABC):
function_name: str,
function_args: dict[str, Any],
available_functions: dict[str, Any],
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
) -> str | None:
"""Handle tool execution with proper event emission.
@@ -443,10 +453,10 @@ class BaseLLM(ABC):
f"Message at index {i} must have 'role' and 'content' keys"
)
return messages # type: ignore[return-value]
return messages
@staticmethod
def _validate_structured_output(
self,
response: str,
response_format: type[BaseModel] | None,
) -> str | BaseModel:
@@ -471,10 +481,7 @@ class BaseLLM(ABC):
data = json.loads(response)
return response_format.model_validate(data)
# Try to extract JSON from response
import re
json_match = re.search(r"\{.*\}", response, re.DOTALL)
json_match = _JSON_EXTRACTION_PATTERN.search(response)
if json_match:
data = json.loads(json_match.group())
return response_format.model_validate(data)
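The precompiled pattern greedily captures from the first `{` to the last `}` across newlines, so a JSON object survives conversational wrapping; for example:

```python
# Runnable sketch of what _JSON_EXTRACTION_PATTERN recovers.
import re

pattern = re.compile(r"\{.*}", re.DOTALL)
response = 'Sure! Here it is:\n{"title": "VPN outage", "priority": 2}\nAnything else?'
match = pattern.search(response)
assert match is not None
print(match.group())  # {"title": "VPN outage", "priority": 2}
```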
@@ -487,7 +494,8 @@ class BaseLLM(ABC):
f"Failed to parse response into {response_format.__name__}: {e}"
) from e
def _extract_provider(self, model: str) -> str:
@staticmethod
def _extract_provider(model: str) -> str:
"""Extract provider from model string.
Args:

View File

@@ -1,7 +1,13 @@
from __future__ import annotations
import json
import logging
import os
from typing import Any, cast
from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.agent_utils import is_context_length_exceeded
@@ -109,6 +115,7 @@ class AnthropicCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call Anthropic messages API.
@@ -147,11 +154,19 @@ class AnthropicCompletion(BaseLLM):
# Handle streaming vs non-streaming
if self.stream:
return self._handle_streaming_completion(
completion_params, available_functions, from_task, from_agent
completion_params,
available_functions,
from_task,
from_agent,
response_model,
)
return self._handle_completion(
completion_params, available_functions, from_task, from_agent
completion_params,
available_functions,
from_task,
from_agent,
response_model,
)
except Exception as e:
@@ -290,8 +305,19 @@ class AnthropicCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try:
response: Message = self.client.messages.create(**params)
@@ -304,6 +330,24 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(response)
self._track_token_usage_internal(usage)
if response_model and response.content:
tool_uses = [
block for block in response.content if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
# Check if Claude wants to use tools
if response.content and available_functions:
tool_uses = [
@@ -349,8 +393,19 @@ class AnthropicCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
full_response = ""
# Remove 'stream' parameter as messages.stream() doesn't accept it
@@ -374,6 +429,26 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage)
if response_model and final_message.content:
tool_uses = [
block
for block in final_message.content
if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if final_message.content and available_functions:
tool_uses = [
block
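Because the Anthropic API exposes no `response_format` parameter, both paths above coerce structured output by forcing a single `structured_output` tool whose input schema is the Pydantic schema. A sketch of the resulting request fragment (the model is illustrative):

```python
# Sketch of the forced-tool request fragment built for structured output.
from pydantic import BaseModel

class Incident(BaseModel):
    title: str
    priority: int

params = {
    "tools": [{
        "name": "structured_output",
        "description": "Returns structured data according to the schema",
        "input_schema": Incident.model_json_schema(),
    }],
    "tool_choice": {"type": "tool", "name": "structured_output"},
}
```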

View File

@@ -1,7 +1,11 @@
from __future__ import annotations
import json
import logging
import os
from typing import Any
from typing import Any, TYPE_CHECKING
from pydantic import BaseModel
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -9,6 +13,9 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
)
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
try:
from azure.ai.inference import ( # type: ignore[import-not-found]
@@ -157,11 +164,12 @@ class AzureCompletion(BaseLLM):
def call(
self,
messages: str | list[LLMMessage],
tools: list[dict] | None = None,
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call Azure AI Inference chat completions API.
@@ -192,17 +200,25 @@ class AzureCompletion(BaseLLM):
# Prepare completion parameters
completion_params = self._prepare_completion_params(
formatted_messages, tools
formatted_messages, tools, response_model
)
# Handle streaming vs non-streaming
if self.stream:
return self._handle_streaming_completion(
completion_params, available_functions, from_task, from_agent
completion_params,
available_functions,
from_task,
from_agent,
response_model,
)
return self._handle_completion(
completion_params, available_functions, from_task, from_agent
completion_params,
available_functions,
from_task,
from_agent,
response_model,
)
except HttpResponseError as e:
@@ -234,12 +250,14 @@ class AzureCompletion(BaseLLM):
self,
messages: list[LLMMessage],
tools: list[dict] | None = None,
response_model: type[BaseModel] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for Azure AI Inference chat completion.
Args:
messages: Formatted messages for Azure
tools: Tool definitions
response_model: Pydantic model for structured output
Returns:
Parameters dictionary for Azure API
@@ -249,6 +267,15 @@ class AzureCompletion(BaseLLM):
"stream": self.stream,
}
if response_model and self.is_openai_model:
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": response_model.__name__,
"schema": response_model.model_json_schema(),
},
}
# Only include model parameter for non-Azure OpenAI endpoints
# Azure OpenAI endpoints have the deployment name in the URL
if not self.is_azure_openai_endpoint:
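On the response side, the handler validates the returned content against the model and re-serializes it before emitting the completion event; a small sketch of that round-trip (values illustrative):

```python
# Sketch of the validation round-trip performed on the returned content.
from pydantic import BaseModel

class Incident(BaseModel):
    title: str
    priority: int

content = '{"title": "VPN outage", "priority": 2}'
structured = Incident.model_validate_json(content)
print(structured.model_dump_json())  # what the caller receives
```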
@@ -334,6 +361,7 @@ class AzureCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming chat completion."""
# Make API call
@@ -350,6 +378,26 @@ class AzureCompletion(BaseLLM):
usage = self._extract_azure_token_usage(response)
self._track_token_usage_internal(usage)
if response_model and self.is_openai_model:
content = message.content or ""
try:
structured_data = response_model.model_validate_json(content)
structured_json = structured_data.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
except Exception as e:
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)
raise ValueError(error_msg) from e
# Handle tool calls
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0] # Handle first tool call
@@ -409,6 +457,7 @@ class AzureCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming chat completion."""
full_response = ""

View File

@@ -5,6 +5,7 @@ import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict, cast
from pydantic import BaseModel
from typing_extensions import Required
from crewai.events.types.llm_events import LLMCallType
@@ -240,6 +241,7 @@ class BedrockCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call AWS Bedrock Converse API."""
try:

View File

@@ -1,7 +1,10 @@
import json
import logging
import os
from typing import Any, cast
from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.agent_utils import is_context_length_exceeded
@@ -173,6 +176,7 @@ class GeminiCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call Google Gemini generate content API.
@@ -202,7 +206,9 @@ class GeminiCompletion(BaseLLM):
messages # type: ignore[arg-type]
)
config = self._prepare_generation_config(system_instruction, tools)
config = self._prepare_generation_config(
system_instruction, tools, response_model
)
if self.stream:
return self._handle_streaming_completion(
@@ -211,6 +217,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
)
return self._handle_completion(
@@ -220,6 +227,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
response_model,
)
except APIError as e:
@@ -241,12 +249,14 @@ class GeminiCompletion(BaseLLM):
self,
system_instruction: str | None = None,
tools: list[dict] | None = None,
response_model: type[BaseModel] | None = None,
) -> types.GenerateContentConfig:
"""Prepare generation config for Google Gemini API.
Args:
system_instruction: System instruction for the model
tools: Tool definitions
response_model: Pydantic model for structured output
Returns:
GenerateContentConfig object for Gemini API
@@ -274,6 +284,10 @@ class GeminiCompletion(BaseLLM):
if self.stop_sequences:
config_params["stop_sequences"] = self.stop_sequences
if response_model:
config_params["response_mime_type"] = "application/json"
config_params["response_schema"] = response_model.model_json_schema()
# Handle tools for supported models
if tools and self.supports_tools:
config_params["tools"] = self._convert_tools_for_interference(tools)
@@ -358,6 +372,7 @@ class GeminiCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming content generation."""
api_params = {
@@ -423,6 +438,7 @@ class GeminiCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming content generation."""
full_response = ""

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
from collections.abc import Iterator
import json
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any
from openai import APIConnectionError, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
@@ -19,6 +21,12 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
class OpenAICompletion(BaseLLM):
"""OpenAI native completion implementation.
@@ -51,8 +59,8 @@ class OpenAICompletion(BaseLLM):
top_logprobs: int | None = None,
reasoning_effort: str | None = None,
provider: str | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
"""Initialize OpenAI chat completion client."""
if provider is None:
@@ -129,11 +137,12 @@ class OpenAICompletion(BaseLLM):
def call(
self,
messages: str | list[LLMMessage],
tools: list[dict] | None = None,
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call OpenAI chat completion API.
@@ -144,13 +153,14 @@ class OpenAICompletion(BaseLLM):
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Response model for structured output.
Returns:
Chat completion response or tool call result
"""
try:
self._emit_call_started_event(
messages=messages, # type: ignore[arg-type]
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
@@ -158,19 +168,27 @@ class OpenAICompletion(BaseLLM):
from_agent=from_agent,
)
formatted_messages = self._format_messages(messages) # type: ignore[arg-type]
formatted_messages = self._format_messages(messages)
completion_params = self._prepare_completion_params(
formatted_messages, tools
messages=formatted_messages, tools=tools
)
if self.stream:
return self._handle_streaming_completion(
completion_params, available_functions, from_task, from_agent
params=completion_params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
return self._handle_completion(
completion_params, available_functions, from_task, from_agent
params=completion_params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
except Exception as e:
@@ -182,14 +200,15 @@ class OpenAICompletion(BaseLLM):
raise
def _prepare_completion_params(
self, messages: list[LLMMessage], tools: list[dict] | None = None
self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
) -> dict[str, Any]:
"""Prepare parameters for OpenAI chat completion."""
params = {
params: dict[str, Any] = {
"model": self.model,
"messages": messages,
"stream": self.stream,
}
if self.stream:
params["stream"] = self.stream
params.update(self.additional_params)
@@ -216,22 +235,6 @@ class OpenAICompletion(BaseLLM):
if self.is_o1_model and self.reasoning_effort:
params["reasoning_effort"] = self.reasoning_effort
# Handle response format for structured outputs
if self.response_format:
if isinstance(self.response_format, type) and issubclass(
self.response_format, BaseModel
):
# Convert Pydantic model to OpenAI response format
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": self.response_format.__name__,
"schema": self.response_format.model_json_schema(),
},
}
else:
params["response_format"] = self.response_format
if tools:
params["tools"] = self._convert_tools_for_interference(tools)
params["tool_choice"] = "auto"
@@ -251,7 +254,9 @@ class OpenAICompletion(BaseLLM):
return {k: v for k, v in params.items() if k not in crewai_specific_params}
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
def _convert_tools_for_interference(
self, tools: list[dict[str, BaseTool]]
) -> list[dict[str, Any]]:
"""Convert CrewAI tool format to OpenAI function calling format."""
from crewai.llms.providers.utils.common import safe_tool_conversion
@@ -283,9 +288,35 @@ class OpenAICompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming chat completion."""
try:
if response_model:
parsed_response = self.client.beta.chat.completions.parse(
**params,
response_format=response_model,
)
parsed_message = parsed_response.choices[0].message
if parsed_message.refusal:
    # Model refused to produce structured output; fall through and
    # return whatever parsed/raw content is available below.
    pass
usage = self._extract_openai_token_usage(parsed_response)
self._track_token_usage_internal(usage)
parsed_object = parsed_response.choices[0].message.parsed
if parsed_object:
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
response: ChatCompletion = self.client.chat.completions.create(**params)
usage = self._extract_openai_token_usage(response)
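A standalone sketch of the beta parse flow above (model name and messages are illustrative, not from this diff):

from openai import OpenAI
from pydantic import BaseModel

class Step(BaseModel):
    explanation: str
    output: str

client = OpenAI()
parsed = client.beta.chat.completions.parse(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Solve 2x + 3 = 11"}],
    response_format=Step,
)
message = parsed.choices[0].message
result = message.parsed if not message.refusal else None  # Step | None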
@@ -380,12 +411,57 @@ class OpenAICompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming chat completion."""
full_response = ""
tool_calls = {}
# Make streaming API call
if response_model:
completion_stream: Iterator[ChatCompletionChunk] = (
self.client.chat.completions.create(**params)
)
accumulated_content = ""
for chunk in completion_stream:
if not chunk.choices:
continue
choice = chunk.choices[0]
delta: ChoiceDelta = choice.delta
if delta.content:
accumulated_content += delta.content
self._emit_stream_chunk_event(
chunk=delta.content,
from_task=from_task,
from_agent=from_agent,
)
try:
parsed_object = response_model.model_validate_json(accumulated_content)
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
except Exception as e:
logging.error(f"Failed to parse structured output from stream: {e}")
self._emit_call_completed_event(
response=accumulated_content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return accumulated_content
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
**params
)
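A minimal sketch of the streaming structured-output path above: buffer the text deltas, then validate the whole buffer against the model (model name and prompt are illustrative):

from openai import OpenAI
from pydantic import BaseModel

class Answer(BaseModel):
    text: str

client = OpenAI()
buffer = ""
for chunk in client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": 'Reply only with JSON like {"text": "..."}'}],
    stream=True,
):
    # Each chunk carries an incremental content delta (may be None)
    if chunk.choices and chunk.choices[0].delta.content:
        buffer += chunk.choices[0].delta.content
answer = Answer.model_validate_json(buffer)  # raises ValidationError if malformed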
@@ -395,20 +471,18 @@ class OpenAICompletion(BaseLLM):
continue
choice = chunk.choices[0]
delta: ChoiceDelta = choice.delta
chunk_delta: ChoiceDelta = choice.delta
# Handle content streaming
if delta.content:
full_response += delta.content
if chunk_delta.content:
full_response += chunk_delta.content
self._emit_stream_chunk_event(
chunk=delta.content,
chunk=chunk_delta.content,
from_task=from_task,
from_agent=from_agent,
)
# Handle tool call streaming
if delta.tool_calls:
for tool_call in delta.tool_calls:
if chunk_delta.tool_calls:
for tool_call in chunk_delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
@@ -454,10 +528,8 @@ class OpenAICompletion(BaseLLM):
if result is not None:
return result
# Apply stop words to full response
full_response = self._apply_stop_words(full_response)
# Emit completion event and return full response
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
@@ -523,12 +595,9 @@ class OpenAICompletion(BaseLLM):
}
return {"total_tokens": 0}
def _format_messages( # type: ignore[override]
self, messages: str | list[LLMMessage]
) -> list[LLMMessage]:
def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
"""Format messages for OpenAI API."""
# Use base class formatting first
base_formatted = super()._format_messages(messages) # type: ignore[arg-type]
base_formatted = super()._format_messages(messages)
# Apply OpenAI-specific formatting
formatted_messages: list[LLMMessage] = []

View File

@@ -32,7 +32,7 @@ from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_types import (
from crewai.events.types.task_events import (
TaskCompletedEvent,
TaskFailedEvent,
TaskStartedEvent,
@@ -123,6 +123,10 @@ class Task(BaseModel):
description="A Pydantic model to be used to create a Pydantic output.",
default=None,
)
response_model: type[BaseModel] | None = Field(
description="A Pydantic model for structured LLM outputs using native provider features.",
default=None,
)
output_file: str | None = Field(
description="A file path to be used to create a file output.",
default=None,
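A hypothetical usage of the new Task.response_model field (task strings are illustrative):

from pydantic import BaseModel
from crewai import Task

class ReportSummary(BaseModel):
    title: str
    key_points: list[str]

task = Task(
    description="Summarize the quarterly report",
    expected_output="A structured summary",
    response_model=ReportSummary,  # provider-native structured output
)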

View File

@@ -0,0 +1,25 @@
"""Utilities for creating and manipulating types."""
from typing import Annotated, Final, Literal
from typing_extensions import TypeAliasType
_DYNAMIC_LITERAL_ALIAS: Final[Literal["DynamicLiteral"]] = "DynamicLiteral"
def create_literals_from_strings(
values: Annotated[
tuple[str, ...], "Should contain unique strings; duplicates will be removed"
],
) -> type:
"""Create a Literal type for each A2A agent ID.
Args:
values: a tuple of the A2A agent IDs
Returns:
Literal type for each A2A agent ID
"""
unique_values: tuple[str, ...] = tuple(dict.fromkeys(values))
return Literal.__getitem__(unique_values)
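A quick check of the dedup-then-Literal behavior (IDs are hypothetical):

from typing import Literal, get_args

ids = ("agent-a", "agent-b", "agent-a")
unique = tuple(dict.fromkeys(ids))  # preserves first-seen order
DynamicLiteral = Literal.__getitem__(unique)  # Literal["agent-a", "agent-b"]
assert get_args(DynamicLiteral) == ("agent-a", "agent-b")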

View File

@@ -5,6 +5,7 @@ import json
import re
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
from pydantic import BaseModel
from rich.console import Console
from crewai.agents.constants import FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE
@@ -226,6 +227,7 @@ def get_llm_response(
printer: Printer,
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Call the LLM and return the response, handling any invalid responses.
@@ -236,6 +238,7 @@ def get_llm_response(
printer: Printer instance for output
from_task: Optional task context for the LLM call
from_agent: Optional agent context for the LLM call
response_model: Optional Pydantic model for structured outputs
Returns:
The response from the LLM as a string
@@ -250,6 +253,7 @@ def get_llm_response(
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
except Exception as e:
raise e
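A sketch of what the response_model plumbing enables at a call site (model name is illustrative; assumes the structured path returns a JSON string, as in the provider diffs above):

from pydantic import BaseModel
from crewai import LLM

class Decision(BaseModel):
    action: str

llm = LLM(model="gpt-4o-mini")
raw = llm.call(
    messages=[{"role": "user", "content": "Pick an action"}],
    response_model=Decision,  # forwarded down the call chain
)
decision = Decision.model_validate_json(raw)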

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
from collections.abc import Callable
from copy import deepcopy
import json
import re
from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin
from typing import TYPE_CHECKING, Any, Final, TypedDict
from pydantic import BaseModel, ValidationError
from typing_extensions import Unpack
@@ -53,7 +55,14 @@ class Converter(OutputConverter):
"""
try:
if self.llm.supports_function_calling():
result = self._create_instructor().to_pydantic()
response = self.llm.call(
messages=[
{"role": "system", "content": self.instructions},
{"role": "user", "content": self.text},
],
response_model=self.model,
)
result = self.model.model_validate_json(response)
else:
response = self.llm.call(
[
@@ -66,7 +75,7 @@ class Converter(OutputConverter):
result = self.model.model_validate_json(response)
except ValidationError:
# If direct validation fails, attempt to extract valid JSON
result = handle_partial_json(
result = handle_partial_json( # type: ignore[assignment]
result=response,
model=self.model,
is_json_output=False,
@@ -131,7 +140,7 @@ class Converter(OutputConverter):
return self.to_json(current_attempt + 1)
return ConverterError(f"Failed to convert text into JSON, error: {e}.")
def _create_instructor(self) -> InternalInstructor:
def _create_instructor(self) -> InternalInstructor[Any]:
"""Create an instructor."""
return InternalInstructor(
@@ -264,7 +273,7 @@ def convert_with_instructions(
is_json_output: bool,
agent: Agent | BaseAgent | None,
converter_cls: type[Converter] | None = None,
) -> dict | BaseModel | str:
) -> dict[str, Any] | BaseModel | str:
"""Convert a result string to a Pydantic model or JSON using instructions.
Args:
@@ -336,13 +345,14 @@ def get_conversion_instructions(
model_schema = PydanticSchemaParser(model=model).get_schema()
instructions += (
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
f"The JSON must follow this schema exactly:\n```json\n{model_schema}\n```"
f"Use this format exactly:\n```json\n{model_schema}\n```"
)
else:
model_description = generate_model_description(model)
schema_json = json.dumps(model_description["json_schema"]["schema"], indent=2)
instructions += (
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
f"The JSON must follow this format exactly:\n{model_description}"
f"Use this format exactly:\n```json\n{schema_json}\n```"
)
return instructions
@@ -399,57 +409,222 @@ def create_converter(
if not converter:
raise Exception("No output converter found or set.")
return converter
return converter # type: ignore[no-any-return]
def generate_model_description(model: type[BaseModel]) -> str:
    """Generate a string description of a Pydantic model's fields and their types.
    This function takes a Pydantic model class and returns a string that describes
    the model's fields and their respective types. The description includes handling
    of complex types such as `Optional`, `List`, and `Dict`, as well as nested Pydantic
    models.
def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
    """Recursively resolve all local $refs in the given JSON Schema using $defs as the source.
    This is needed because Pydantic generates $ref-based schemas that
    some consumers (e.g. LLMs, tool frameworks) don't handle well.
    Args:
        schema: JSON Schema dict that may contain "$ref" and "$defs".
    Returns:
        A new schema dictionary with all local $refs replaced by their definitions.
"""
defs = schema.get("$defs", {})
schema_copy = deepcopy(schema)
def _resolve(node: Any) -> Any:
if isinstance(node, dict):
ref = node.get("$ref")
if isinstance(ref, str) and ref.startswith("#/$defs/"):
def_name = ref.replace("#/$defs/", "")
if def_name in defs:
return _resolve(deepcopy(defs[def_name]))
raise KeyError(f"Definition '{def_name}' not found in $defs.")
return {k: _resolve(v) for k, v in node.items()}
if isinstance(node, list):
return [_resolve(i) for i in node]
return node
return _resolve(schema_copy) # type: ignore[no-any-return]
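A worked example for resolve_refs (hypothetical schema): the local $ref is replaced by its inlined definition from $defs.

schema = {
    "$defs": {"Item": {"type": "object", "properties": {"id": {"type": "string"}}}},
    "type": "object",
    "properties": {"item": {"$ref": "#/$defs/Item"}},
}
resolved = resolve_refs(schema)
assert resolved["properties"]["item"]["properties"]["id"] == {"type": "string"}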
def add_key_in_dict_recursively(
d: dict[str, Any], key: str, value: Any, criteria: Callable[[dict[str, Any]], bool]
) -> dict[str, Any]:
"""Recursively adds a key/value pair to all nested dicts matching `criteria`."""
if isinstance(d, dict):
if criteria(d) and key not in d:
d[key] = value
for v in d.values():
add_key_in_dict_recursively(v, key, value, criteria)
elif isinstance(d, list):
for i in d:
add_key_in_dict_recursively(i, key, value, criteria)
return d
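Usage sketch, mirroring the call generate_model_description makes below: close every object node in place.

doc = {"type": "object", "properties": {"inner": {"type": "object", "properties": {}}}}
add_key_in_dict_recursively(
    doc,
    key="additionalProperties",
    value=False,
    criteria=lambda d: d.get("type") == "object" and "additionalProperties" not in d,
)
assert doc["additionalProperties"] is False
assert doc["properties"]["inner"]["additionalProperties"] is False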
def fix_discriminator_mappings(schema: dict[str, Any]) -> dict[str, Any]:
"""Replace '#/$defs/...' references in discriminator.mapping with just the model name."""
output = schema.get("properties", {}).get("output")
if not output:
return schema
disc = output.get("discriminator")
if not disc or "mapping" not in disc:
return schema
disc["mapping"] = {k: v.split("/")[-1] for k, v in disc["mapping"].items()}
return schema
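A worked example (hypothetical mapping): '#/$defs/Cat' collapses to 'Cat'.

schema = {
    "properties": {
        "output": {
            "discriminator": {
                "propertyName": "kind",
                "mapping": {"cat": "#/$defs/Cat", "dog": "#/$defs/Dog"},
            }
        }
    }
}
fixed = fix_discriminator_mappings(schema)
assert fixed["properties"]["output"]["discriminator"]["mapping"] == {"cat": "Cat", "dog": "Dog"}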
def add_const_to_oneof_variants(schema: dict[str, Any]) -> dict[str, Any]:
"""Add const fields to oneOf variants for discriminated unions.
The json_schema_to_pydantic library requires each oneOf variant to have
a const field for the discriminator property. This function adds those
const fields based on the discriminator mapping.
Args:
schema: JSON Schema dict that may contain discriminated unions
Returns:
Modified schema with const fields added to oneOf variants
"""
def _process_oneof(node: dict[str, Any]) -> dict[str, Any]:
"""Process a single node that might contain a oneOf with discriminator."""
if not isinstance(node, dict):
return node
if "oneOf" in node and "discriminator" in node:
discriminator = node["discriminator"]
property_name = discriminator.get("propertyName")
mapping = discriminator.get("mapping", {})
if property_name and mapping:
one_of_variants = node.get("oneOf", [])
for variant in one_of_variants:
if isinstance(variant, dict) and "properties" in variant:
variant_title = variant.get("title", "")
matched_disc_value = None
for disc_value, schema_name in mapping.items():
if variant_title == schema_name or variant_title.endswith(
schema_name
):
matched_disc_value = disc_value
break
if matched_disc_value is not None:
props = variant["properties"]
if property_name in props:
props[property_name]["const"] = matched_disc_value
for key, value in node.items():
if isinstance(value, dict):
node[key] = _process_oneof(value)
elif isinstance(value, list):
node[key] = [
_process_oneof(item) if isinstance(item, dict) else item
for item in value
]
return node
return _process_oneof(deepcopy(schema))
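A worked example (hypothetical union): the variant whose title matches the mapping gains a const on the discriminator property.

schema = {
    "output": {
        "discriminator": {"propertyName": "kind", "mapping": {"cat": "Cat"}},
        "oneOf": [{"title": "Cat", "properties": {"kind": {"type": "string"}}}],
    }
}
patched = add_const_to_oneof_variants(schema)
assert patched["output"]["oneOf"][0]["properties"]["kind"]["const"] == "cat"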
def convert_oneof_to_anyof(schema: dict[str, Any]) -> dict[str, Any]:
"""Convert oneOf to anyOf for OpenAI compatibility.
OpenAI's Structured Outputs support anyOf better than oneOf.
This recursively converts all oneOf occurrences to anyOf.
Args:
schema: JSON schema dictionary.
Returns:
Modified schema with anyOf instead of oneOf.
"""
if isinstance(schema, dict):
if "oneOf" in schema:
schema["anyOf"] = schema.pop("oneOf")
for value in schema.values():
if isinstance(value, dict):
convert_oneof_to_anyof(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
convert_oneof_to_anyof(item)
return schema
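Quick check that nested oneOf keys become anyOf in place:

s = {"properties": {"pet": {"oneOf": [{"type": "string"}, {"type": "integer"}]}}}
convert_oneof_to_anyof(s)
assert "anyOf" in s["properties"]["pet"] and "oneOf" not in s["properties"]["pet"]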
def ensure_all_properties_required(schema: dict[str, Any]) -> dict[str, Any]:
"""Ensure all properties are in the required array for OpenAI strict mode.
OpenAI's strict structured outputs require all properties to be listed
in the required array. This recursively updates all objects to include
all their properties in required.
Args:
schema: JSON schema dictionary.
Returns:
Modified schema with all properties marked as required.
"""
if isinstance(schema, dict):
if schema.get("type") == "object" and "properties" in schema:
properties = schema["properties"]
if properties:
schema["required"] = list(properties.keys())
for value in schema.values():
if isinstance(value, dict):
ensure_all_properties_required(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
ensure_all_properties_required(item)
return schema
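Quick check for the strict-mode required expansion:

s = {"type": "object", "properties": {"a": {"type": "string"}, "b": {"type": "integer"}}}
ensure_all_properties_required(s)
assert s["required"] == ["a", "b"]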
def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
"""Generate JSON schema description of a Pydantic model.
This function takes a Pydantic model class and returns its JSON schema,
which includes full type information, discriminators, and all metadata.
The schema is dereferenced to inline all $ref references for better LLM understanding.
Args:
model: A Pydantic model class.
Returns:
A string representation of the model's fields and types.
A JSON schema dictionary representation of the model.
"""
    def describe_field(field_type: Any) -> str:
        """Recursively describe a field's type.
        Args:
            field_type: The type of the field to describe.
        Returns:
            A string representation of the field's type.
        """
        origin = get_origin(field_type)
        args = get_args(field_type)
        if origin is Union or (origin is None and len(args) > 0):
            # Handle both Union and the new '|' syntax
            non_none_args = [arg for arg in args if arg is not type(None)]
            if len(non_none_args) == 1:
                return f"Optional[{describe_field(non_none_args[0])}]"
            return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
        if origin is list:
            return f"List[{describe_field(args[0])}]"
        if origin is dict:
            key_type = describe_field(args[0])
            value_type = describe_field(args[1])
            return f"Dict[{key_type}, {value_type}]"
        if isinstance(field_type, type) and issubclass(field_type, BaseModel):
            return generate_model_description(field_type)
        if hasattr(field_type, "__name__"):
            return field_type.__name__
        return str(field_type)
    fields = model.model_fields
    field_descriptions = [
        f'"{name}": {describe_field(field.annotation)}'
        for name, field in fields.items()
    ]
    return "{\n " + ",\n ".join(field_descriptions) + "\n}"
    json_schema = model.model_json_schema(ref_template="#/$defs/{model}")
    json_schema = add_key_in_dict_recursively(
        json_schema,
        key="additionalProperties",
        value=False,
        criteria=lambda d: d.get("type") == "object"
        and "additionalProperties" not in d,
    )
    json_schema = resolve_refs(json_schema)
    json_schema.pop("$defs", None)
    json_schema = fix_discriminator_mappings(json_schema)
    json_schema = convert_oneof_to_anyof(json_schema)
    json_schema = ensure_all_properties_required(json_schema)
return {
"type": "json_schema",
"json_schema": {
"name": model.__name__,
"strict": True,
"schema": json_schema,
},
}
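End-to-end sketch (model is illustrative): the returned dict is shaped for OpenAI's strict json_schema response_format.

from pydantic import BaseModel

class Ticket(BaseModel):
    id: str
    priority: int

fmt = generate_model_description(Ticket)
assert fmt["type"] == "json_schema"
assert fmt["json_schema"]["strict"] is True
assert set(fmt["json_schema"]["schema"]["required"]) == {"id", "priority"}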