Compare commits

10 Commits

Author SHA1 Message Date
Greyson LaLonde
6da4b13702 feat: add a2a delegation support to LiteAgent 2026-01-29 13:12:05 -05:00
Greyson LaLonde
75c68204ae Revert "chore: remove unused _AuthStore class and related imports"
This reverts commit 3f066f2b0f.
2026-01-29 12:15:18 -05:00
Greyson LaLonde
3f066f2b0f chore: remove unused _AuthStore class and related imports 2026-01-29 12:10:53 -05:00
Greyson LaLonde
4543c66697 feat: add transport negotiation and content type handling
- add transport negotiation logic with fallback support
- add content type parser and encoder utilities
- add transport configuration models (client and server)
- add transport types and enums
- enhance config with transport settings
- add negotiation events for transport and content type
2026-01-29 05:13:42 -05:00
Greyson LaLonde
e4be1329a0 feat: add server-side auth schemes and protocol extensions
- add server auth scheme base class and implementations (API key, bearer token, basic/digest auth, mTLS)
- add server-side extension system for A2A protocol extensions
- add extensions middleware for X-A2A-Extensions header management
- add extension validation and registry utilities
- enhance auth utilities with server-side support
- add async intercept method to match client call interceptor protocol
- fix TYPE_CHECKING import to resolve mypy errors with A2AConfig
2026-01-29 04:52:36 -05:00
Lorenze Jay
e291a97bdd chore: update version to 1.9.2 across all relevant files (#4299)
2026-01-28 17:11:44 -08:00
Lorenze Jay
2d05e59223 Lorenze/improve tool response pt2 (#4297)
* no need for post-tool reflection on native tools

* refactor: update prompt generation to prevent thought leakage

- Modified the prompt structure to ensure agents without tools use a simplified format, avoiding ReAct instructions.
- Introduced a new 'task_no_tools' slice for agents lacking tools, ensuring clean output without Thought: prefixes.
- Enhanced test coverage to verify that prompts do not encourage thought leakage, ensuring outputs remain focused and direct.
- Added integration tests to validate that real LLM calls produce clean outputs without internal reasoning artifacts.

* don't forget the cassettes
2026-01-28 16:53:19 -08:00
Greyson LaLonde
a731efac8d fix: improve structured output handling across providers and agents
- add Gemini 2.0 schema support using response_json_schema with propertyOrdering while retaining backward compatibility for earlier models
- refactor LLM completions to return validated Pydantic models when a response_model is provided, updating hooks, types, and tests for consistent structured outputs
- extend AgentFinish and executors to support BaseModel outputs, improve Anthropic structured parsing, and clean up schema utilities, tests, and original_json handling
2026-01-28 16:59:55 -05:00
Greyson LaLonde
1e27cf3f0f fix: ensure verbosity flag is applied
2026-01-28 11:52:47 -05:00
Lorenze Jay
381ad3a9a8 chore: update version to 1.9.1
2026-01-27 20:08:53 -05:00
88 changed files with 6932 additions and 1202 deletions

View File

@@ -152,4 +152,4 @@ __all__ = [
"wrap_file_source",
]
__version__ = "1.9.0"
__version__ = "1.9.2"

View File

@@ -12,7 +12,7 @@ dependencies = [
"pytube~=15.0.0",
"requests~=2.32.5",
"docker~=7.1.0",
"crewai==1.9.0",
"crewai==1.9.2",
"lancedb~=0.5.4",
"tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4",

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.9.0"
__version__ = "1.9.2"

View File

@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.9.0",
"crewai-tools==1.9.2",
]
embeddings = [
"tiktoken~=0.8.0"

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.9.0"
__version__ = "1.9.2"
_telemetry_submitted = False

View File

@@ -1,20 +1,36 @@
"""A2A authentication schemas."""
from crewai.a2a.auth.schemas import (
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
TLSConfig,
)
from crewai.a2a.auth.server_schemes import (
AuthenticatedUser,
OIDCAuth,
ServerAuthScheme,
SimpleTokenAuth,
)
__all__ = [
"APIKeyAuth",
"AuthScheme",
"AuthenticatedUser",
"BearerTokenAuth",
"ClientAuthScheme",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
"OIDCAuth",
"ServerAuthScheme",
"SimpleTokenAuth",
"TLSConfig",
]
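A minimal usage sketch of the reorganized exports, assuming the re-exports above (the token field name on BearerTokenAuth is truncated in this diff and assumed here):

from crewai.a2a.auth.schemas import BearerTokenAuth, SimpleTokenAuth, TLSConfig

# Client side: bearer token, optionally over mTLS. Paths are placeholders;
# TLSConfig's FilePath fields require the files to exist at validation time.
client_auth = BearerTokenAuth(
    token="example-token",  # assumed field name
    tls=TLSConfig(client_cert_path="client.pem", client_key_path="client.key"),
)

# Server side: static token validation (falls back to the AUTH_TOKEN env var).
server_auth = SimpleTokenAuth(token="expected-token")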

View File

@@ -1,4 +1,4 @@
"""Authentication schemes for A2A protocol agents.
"""Authentication schemes for A2A protocol clients.
Supported authentication methods:
- Bearer tokens
@@ -6,24 +6,135 @@ Supported authentication methods:
- API Keys (header, query, cookie)
- HTTP Basic authentication
- HTTP Digest authentication
- mTLS (mutual TLS) client certificate authentication
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
import base64
from collections.abc import Awaitable, Callable, MutableMapping
from pathlib import Path
import ssl
import time
from typing import Literal
from typing import TYPE_CHECKING, ClassVar, Literal
import urllib.parse
import httpx
from httpx import DigestAuth
from pydantic import BaseModel, Field, PrivateAttr
from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr
from typing_extensions import deprecated
class AuthScheme(ABC, BaseModel):
"""Base class for authentication schemes."""
if TYPE_CHECKING:
import grpc # type: ignore[import-untyped]
class TLSConfig(BaseModel):
"""TLS/mTLS configuration for secure client connections.
Supports mutual TLS (mTLS) where the client presents a certificate to the server,
and standard TLS with custom CA verification.
Attributes:
client_cert_path: Path to client certificate file (PEM format) for mTLS.
client_key_path: Path to client private key file (PEM format) for mTLS.
ca_cert_path: Path to CA certificate bundle for server verification.
verify: Whether to verify server certificates. Set False only for development.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
client_cert_path: FilePath | None = Field(
default=None,
description="Path to client certificate file (PEM format) for mTLS",
)
client_key_path: FilePath | None = Field(
default=None,
description="Path to client private key file (PEM format) for mTLS",
)
ca_cert_path: FilePath | None = Field(
default=None,
description="Path to CA certificate bundle for server verification",
)
verify: bool = Field(
default=True,
description="Whether to verify server certificates. Set False only for development.",
)
def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
"""Build SSL context for httpx client.
Returns:
SSL context if certificates configured, True for default verification,
False if verification disabled, or path to CA bundle.
"""
if not self.verify:
return False
if self.client_cert_path and self.client_key_path:
context = ssl.create_default_context()
if self.ca_cert_path:
context.load_verify_locations(cafile=str(self.ca_cert_path))
context.load_cert_chain(
certfile=str(self.client_cert_path),
keyfile=str(self.client_key_path),
)
return context
if self.ca_cert_path:
return str(self.ca_cert_path)
return True
def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported]
"""Build gRPC channel credentials for secure connections.
Returns:
gRPC SSL credentials if certificates configured, None otherwise.
"""
try:
import grpc
except ImportError:
return None
if not self.verify and not self.client_cert_path:
return None
root_certs: bytes | None = None
private_key: bytes | None = None
certificate_chain: bytes | None = None
if self.ca_cert_path:
root_certs = Path(self.ca_cert_path).read_bytes()
if self.client_cert_path and self.client_key_path:
private_key = Path(self.client_key_path).read_bytes()
certificate_chain = Path(self.client_cert_path).read_bytes()
return grpc.ssl_channel_credentials(
root_certificates=root_certs,
private_key=private_key,
certificate_chain=certificate_chain,
)
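A short sketch of wiring TLSConfig into an httpx client, following the get_httpx_ssl_context contract above (certificate paths are placeholders and must exist for FilePath validation):

import httpx

tls = TLSConfig(
    client_cert_path="client.pem",
    client_key_path="client.key",
    ca_cert_path="ca.pem",
)

# Both cert and key are set, so get_httpx_ssl_context returns an ssl.SSLContext;
# httpx accepts an SSLContext, a CA-bundle path, or a bool via `verify`.
client = httpx.AsyncClient(verify=tls.get_httpx_ssl_context())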
class ClientAuthScheme(ABC, BaseModel):
"""Base class for client-side authentication schemes.
Client auth schemes apply credentials to outgoing requests.
Attributes:
tls: Optional TLS/mTLS configuration for secure connections.
"""
tls: TLSConfig | None = Field(
default=None,
description="TLS/mTLS configuration for secure connections",
)
@abstractmethod
async def apply_auth(
@@ -41,7 +152,12 @@ class AuthScheme(ABC, BaseModel):
...
class BearerTokenAuth(AuthScheme):
@deprecated("Use ClientAuthScheme instead", category=FutureWarning)
class AuthScheme(ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme instead."""
class BearerTokenAuth(ClientAuthScheme):
"""Bearer token authentication (Authorization: Bearer <token>).
Attributes:
@@ -66,7 +182,7 @@ class BearerTokenAuth(AuthScheme):
return headers
class HTTPBasicAuth(AuthScheme):
class HTTPBasicAuth(ClientAuthScheme):
"""HTTP Basic authentication.
Attributes:
@@ -95,7 +211,7 @@ class HTTPBasicAuth(AuthScheme):
return headers
class HTTPDigestAuth(AuthScheme):
class HTTPDigestAuth(ClientAuthScheme):
"""HTTP Digest authentication.
Note: Uses httpx-auth library for digest implementation.
@@ -108,6 +224,8 @@ class HTTPDigestAuth(AuthScheme):
username: str = Field(description="Username")
password: str = Field(description="Password")
_configured_client_id: int | None = PrivateAttr(default=None)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
@@ -125,13 +243,21 @@ class HTTPDigestAuth(AuthScheme):
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client with Digest auth.
Idempotent: Only configures the client once. Subsequent calls on the same
client instance are no-ops to prevent overwriting auth configuration.
Args:
client: HTTP client to configure with Digest authentication.
"""
client_id = id(client)
if self._configured_client_id == client_id:
return
client.auth = DigestAuth(self.username, self.password)
self._configured_client_id = client_id
class APIKeyAuth(AuthScheme):
class APIKeyAuth(ClientAuthScheme):
"""API Key authentication (header, query, or cookie).
Attributes:
@@ -146,6 +272,8 @@ class APIKeyAuth(AuthScheme):
)
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
_configured_client_ids: set[int] = PrivateAttr(default_factory=set)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
@@ -167,21 +295,31 @@ class APIKeyAuth(AuthScheme):
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client for query param API keys.
Idempotent: Only adds the request hook once per client instance.
Subsequent calls on the same client are no-ops to prevent hook accumulation.
Args:
client: HTTP client to configure with query param API key hook.
"""
if self.location == "query":
client_id = id(client)
if client_id in self._configured_client_ids:
return
async def _add_api_key_param(request: httpx.Request) -> None:
url = httpx.URL(request.url)
request.url = url.copy_add_param(self.name, self.api_key)
client.event_hooks["request"].append(_add_api_key_param)
self._configured_client_ids.add(client_id)
class OAuth2ClientCredentials(AuthScheme):
class OAuth2ClientCredentials(ClientAuthScheme):
"""OAuth2 Client Credentials flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches
when multiple requests share the same auth instance.
Attributes:
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
@@ -198,12 +336,17 @@ class OAuth2ClientCredentials(AuthScheme):
_access_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time,
preventing race conditions when multiple concurrent requests use the same
auth instance.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
@@ -216,7 +359,13 @@ class OAuth2ClientCredentials(AuthScheme):
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
async with self._lock:
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
@@ -250,9 +399,11 @@ class OAuth2ClientCredentials(AuthScheme):
self._token_expires_at = time.time() + expires_in - 60
class OAuth2AuthorizationCode(AuthScheme):
class OAuth2AuthorizationCode(ClientAuthScheme):
"""OAuth2 Authorization Code flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token operations.
Note: Requires interactive authorization.
Attributes:
@@ -279,6 +430,7 @@ class OAuth2AuthorizationCode(AuthScheme):
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
default=None
)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
def set_authorization_callback(
self, callback: Callable[[str], Awaitable[str]] | None
@@ -295,6 +447,9 @@ class OAuth2AuthorizationCode(AuthScheme):
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine handles token operations
(initial fetch or refresh) at a time.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
@@ -305,14 +460,17 @@ class OAuth2AuthorizationCode(AuthScheme):
Raises:
ValueError: If authorization callback is not set.
"""
if self._access_token is None:
if self._authorization_callback is None:
msg = "Authorization callback not set. Use set_authorization_callback()"
raise ValueError(msg)
await self._fetch_initial_token(client)
async with self._lock:
if self._access_token is None:
await self._fetch_initial_token(client)
elif self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
async with self._lock:
if self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"

View File

@@ -0,0 +1,739 @@
"""Server-side authentication schemes for A2A protocol.
These schemes validate incoming requests to A2A server endpoints.
Supported authentication methods:
- Simple token validation with static bearer tokens
- OpenID Connect with JWT validation using JWKS
- OAuth2 with JWT validation or token introspection
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
import os
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
import jwt
from jwt import PyJWKClient
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
HttpUrl,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import Self
if TYPE_CHECKING:
from a2a.types import OAuth2SecurityScheme
logger = logging.getLogger(__name__)
try:
from fastapi import HTTPException, status as http_status
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
except ImportError:
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
"""Fallback HTTPException when FastAPI is not installed."""
def __init__(
self,
status_code: int,
detail: str | None = None,
headers: dict[str, str] | None = None,
) -> None:
self.status_code = status_code
self.detail = detail
self.headers = headers
super().__init__(detail)
HTTP_401_UNAUTHORIZED = 401
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_503_SERVICE_UNAVAILABLE = 503
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
"""Coerce string to SecretStr."""
if v is None or isinstance(v, SecretStr):
return v
return SecretStr(v)
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
JWTAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
@dataclass
class AuthenticatedUser:
"""Result of successful authentication.
Attributes:
token: The original token that was validated.
scheme: Name of the authentication scheme used.
claims: JWT claims from OIDC or OAuth2 authentication.
"""
token: str
scheme: str
claims: dict[str, Any] | None = None
class ServerAuthScheme(ABC, BaseModel):
"""Base class for server-side authentication schemes.
Each scheme validates incoming requests and returns an AuthenticatedUser
on success, or raises HTTPException on failure.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@abstractmethod
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate the provided token.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
...
class SimpleTokenAuth(ServerAuthScheme):
"""Simple bearer token authentication.
Validates tokens against a configured static token or AUTH_TOKEN env var.
Attributes:
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
"""
token: CoercedSecretStr | None = Field(
default=None,
description="Expected token. Falls back to AUTH_TOKEN env var.",
)
def _get_expected_token(self) -> str | None:
"""Get the expected token value."""
if self.token:
return self.token.get_secret_value()
return os.environ.get("AUTH_TOKEN")
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using simple token comparison.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
expected = self._get_expected_token()
if expected is None:
logger.warning(
"Simple token authentication failed",
extra={"reason": "no_token_configured"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Authentication not configured",
)
if token != expected:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
)
return AuthenticatedUser(
token=token,
scheme="simple_token",
)
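A sketch exercising the comparison above; HTTPException comes from FastAPI when installed, otherwise from the module's fallback class:

import asyncio

from fastapi import HTTPException  # or the fallback defined in this module

auth = SimpleTokenAuth(token="s3cret")

async def check() -> None:
    user = await auth.authenticate("s3cret")
    assert user.scheme == "simple_token"
    try:
        await auth.authenticate("wrong")
    except HTTPException as exc:
        assert exc.status_code == 401  # invalid credentials

asyncio.run(check())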
class OIDCAuth(ServerAuthScheme):
"""OpenID Connect authentication.
Validates JWTs using JWKS with caching support via PyJWT.
Attributes:
issuer: The OpenID Connect issuer URL.
audience: The expected audience claim.
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
algorithms: List of allowed signing algorithms.
required_claims: List of claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
issuer: HttpUrl = Field(
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
)
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
jwks_url: HttpUrl | None = Field(
default=None,
description="Explicit JWKS URL. Derived from issuer if not set.",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="List of allowed signing algorithms (RS256, ES256, etc.)",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
description="List of claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _init_jwk_client(self) -> Self:
"""Initialize the JWK client after model creation."""
jwks_url = (
str(self.jwks_url)
if self.jwks_url
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
)
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OIDC JWT validation.
Args:
token: The JWT to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OIDC not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=str(self.issuer).rstrip("/"),
leeway=self.clock_skew_seconds,
options={
"require": self.required_claims,
},
)
return AuthenticatedUser(
token=token,
scheme="oidc",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "token_expired", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_audience", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OIDC authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oidc",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
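A configuration sketch for the OIDC flow above (URLs and audience are placeholders; constructing the model builds a PyJWKClient against the derived .well-known/jwks.json URL, and authenticate raises HTTPException on any validation failure):

oidc = OIDCAuth(
    issuer="https://auth.example.com",
    audience="api://my-agent",
    algorithms=["RS256"],
    clock_skew_seconds=30.0,
)
# user = await oidc.authenticate(jwt_from_request)  # claims available as user.claims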
class OAuth2ServerAuth(ServerAuthScheme):
"""OAuth2 authentication for A2A server.
Declares OAuth2 security scheme in AgentCard and validates tokens using
either JWKS for JWT tokens or token introspection for opaque tokens.
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
Attributes:
token_url: OAuth2 token endpoint URL for client_credentials flow.
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
refresh_url: Optional refresh token endpoint URL.
scopes: Available OAuth2 scopes with descriptions.
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
introspection_client_id: Client ID for introspection endpoint authentication.
introspection_client_secret: Client secret for introspection endpoint.
audience: Expected audience claim for JWT validation.
issuer: Expected issuer claim for JWT validation.
algorithms: Allowed JWT signing algorithms.
required_claims: Claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
token_url: HttpUrl = Field(
description="OAuth2 token endpoint URL",
)
authorization_url: HttpUrl | None = Field(
default=None,
description="OAuth2 authorization endpoint URL for authorization_code flow",
)
refresh_url: HttpUrl | None = Field(
default=None,
description="OAuth2 refresh token endpoint URL",
)
scopes: dict[str, str] = Field(
default_factory=dict,
description="Available OAuth2 scopes with descriptions",
)
jwks_url: HttpUrl | None = Field(
default=None,
description="JWKS URL for JWT validation. Required if not using introspection.",
)
introspection_url: HttpUrl | None = Field(
default=None,
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
)
introspection_client_id: str | None = Field(
default=None,
description="Client ID for introspection endpoint authentication",
)
introspection_client_secret: CoercedSecretStr | None = Field(
default=None,
description="Client secret for introspection endpoint authentication",
)
audience: str | None = Field(
default=None,
description="Expected audience claim for JWT validation",
)
issuer: str | None = Field(
default=None,
description="Expected issuer claim for JWT validation",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="Allowed JWT signing algorithms",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat"],
description="Claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _validate_and_init(self) -> Self:
"""Validate configuration and initialize JWKS client if needed."""
if not self.jwks_url and not self.introspection_url:
raise ValueError(
"Either jwks_url or introspection_url must be provided for token validation"
)
if self.introspection_url:
if not self.introspection_client_id or not self.introspection_client_secret:
raise ValueError(
"introspection_client_id and introspection_client_secret are required "
"when using token introspection"
)
if self.jwks_url:
self._jwk_client = PyJWKClient(
str(self.jwks_url), lifespan=self.jwks_cache_ttl
)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token validation.
Uses JWKS validation if jwks_url is configured, otherwise falls back
to token introspection.
Args:
token: The OAuth2 access token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client:
return await self._authenticate_jwt(token)
return await self._authenticate_introspection(token)
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
"""Authenticate using JWKS JWT validation."""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 JWKS not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
decode_options: dict[str, Any] = {
"require": self.required_claims,
}
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=self.issuer,
leeway=self.clock_skew_seconds,
options=decode_options,
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_expired", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_audience", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OAuth2 authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oauth2",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
import httpx
if not self.introspection_url:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 introspection not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.post(
str(self.introspection_url),
data={"token": token},
auth=(
self.introspection_client_id or "",
self.introspection_client_secret.get_secret_value()
if self.introspection_client_secret
else "",
),
)
response.raise_for_status()
introspection_result = response.json()
except httpx.HTTPStatusError as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "http_error", "status_code": e.response.status_code},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection service unavailable",
) from None
except Exception as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "unexpected_error", "error": str(e)},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection failed",
) from None
if not introspection_result.get("active", False):
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_not_active", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token is not active",
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=introspection_result,
)
def to_security_scheme(self) -> OAuth2SecurityScheme:
"""Generate OAuth2SecurityScheme for AgentCard declaration.
Creates an OAuth2SecurityScheme with appropriate flows based on
the configured URLs. Includes client_credentials flow if token_url
is set, and authorization_code flow if authorization_url is set.
Returns:
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
"""
from a2a.types import (
AuthorizationCodeOAuthFlow,
ClientCredentialsOAuthFlow,
OAuth2SecurityScheme,
OAuthFlows,
)
client_credentials = None
authorization_code = None
if self.token_url:
client_credentials = ClientCredentialsOAuthFlow(
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
if self.authorization_url:
authorization_code = AuthorizationCodeOAuthFlow(
authorization_url=str(self.authorization_url),
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
return OAuth2SecurityScheme(
flows=OAuthFlows(
client_credentials=client_credentials,
authorization_code=authorization_code,
),
description="OAuth2 authentication",
)
class APIKeyServerAuth(ServerAuthScheme):
"""API Key authentication for A2A server.
Validates requests using an API key in a header, query parameter, or cookie.
Attributes:
name: The name of the API key parameter (default: X-API-Key).
location: Where to look for the API key (header, query, or cookie).
api_key: The expected API key value.
"""
name: str = Field(
default="X-API-Key",
description="Name of the API key parameter",
)
location: Literal["header", "query", "cookie"] = Field(
default="header",
description="Where to look for the API key",
)
api_key: CoercedSecretStr = Field(
description="Expected API key value",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using API key comparison.
Args:
token: The API key to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if token != self.api_key.get_secret_value():
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return AuthenticatedUser(
token=token,
scheme="api_key",
)
class MTLSServerAuth(ServerAuthScheme):
"""Mutual TLS authentication marker for AgentCard declaration.
This scheme is primarily for AgentCard security_schemes declaration.
Actual mTLS verification happens at the TLS/transport layer, not
at the application layer via token validation.
When configured, this signals to clients that the server requires
client certificates for authentication.
"""
description: str = Field(
default="Mutual TLS certificate authentication",
description="Description for the security scheme",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Return authenticated user for mTLS.
mTLS verification happens at the transport layer before this is called.
If we reach this point, the TLS handshake with client cert succeeded.
Args:
token: Certificate subject or identifier (from TLS layer).
Returns:
AuthenticatedUser indicating mTLS authentication.
"""
return AuthenticatedUser(
token=token or "mtls-verified",
scheme="mtls",
)
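All five server schemes above share one contract, so a dispatch point can stay scheme-agnostic; a minimal sketch (the request-extraction step is illustrative and lives elsewhere in this changeset):

async def authenticate_request(
    scheme: ServerAuthScheme, raw_token: str
) -> AuthenticatedUser:
    # SimpleTokenAuth, OIDCAuth, OAuth2ServerAuth, APIKeyServerAuth, and
    # MTLSServerAuth all return AuthenticatedUser or raise HTTPException.
    return await scheme.authenticate(raw_token)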

View File

@@ -6,8 +6,10 @@ OAuth2, API keys, and HTTP authentication methods.
import asyncio
from collections.abc import Awaitable, Callable, MutableMapping
import hashlib
import re
from typing import Final
import threading
from typing import Final, Literal, cast
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
@@ -18,10 +20,10 @@ from a2a.types import (
)
from httpx import AsyncClient, Response
from crewai.a2a.auth.schemas import (
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
@@ -29,12 +31,44 @@ from crewai.a2a.auth.schemas import (
)
_auth_store: dict[int, AuthScheme | None] = {}
class _AuthStore:
"""Store for authentication schemes with safe concurrent access."""
def __init__(self) -> None:
self._store: dict[str, ClientAuthScheme | None] = {}
self._lock = threading.RLock()
@staticmethod
def compute_key(auth_type: str, auth_data: str) -> str:
"""Compute a collision-resistant key using SHA-256."""
content = f"{auth_type}:{auth_data}"
return hashlib.sha256(content.encode()).hexdigest()
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
"""Store an auth scheme."""
with self._lock:
self._store[key] = auth
def get(self, key: str) -> ClientAuthScheme | None:
"""Retrieve an auth scheme by key."""
with self._lock:
return self._store.get(key)
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
with self._lock:
self._store[key] = value
def __getitem__(self, key: str) -> ClientAuthScheme | None:
with self._lock:
return self._store[key]
_auth_store = _AuthStore()
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
OAuth2SecurityScheme: (
OAuth2ClientCredentials,
OAuth2AuthorizationCode,
@@ -43,7 +77,9 @@ _SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
APIKeySecurityScheme: (APIKeyAuth,),
}
_HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
"basic": HTTPBasicAuth,
"digest": HTTPDigestAuth,
"bearer": BearerTokenAuth,
@@ -51,8 +87,8 @@ _HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
def _raise_auth_mismatch(
expected_classes: type[AuthScheme] | tuple[type[AuthScheme], ...],
provided_auth: AuthScheme,
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
provided_auth: ClientAuthScheme,
) -> None:
"""Raise authentication mismatch error.
@@ -111,7 +147,7 @@ def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
def validate_auth_against_agent_card(
agent_card: AgentCard, auth: AuthScheme | None
agent_card: AgentCard, auth: ClientAuthScheme | None
) -> None:
"""Validate that provided auth matches AgentCard security requirements.
@@ -145,7 +181,8 @@ def validate_auth_against_agent_card(
return
if isinstance(scheme, HTTPAuthSecurityScheme):
if required_class := _HTTP_SCHEME_MAPPING.get(scheme.scheme.lower()):
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
if not isinstance(auth, required_class):
_raise_auth_mismatch(required_class, auth)
return
@@ -156,7 +193,7 @@ def validate_auth_against_agent_card(
async def retry_on_401(
request_func: Callable[[], Awaitable[Response]],
auth_scheme: AuthScheme | None,
auth_scheme: ClientAuthScheme | None,
client: AsyncClient,
headers: MutableMapping[str, str],
max_retries: int = 3,
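A sketch of the SHA-256 keying used by the new _AuthStore above, in contrast to the old id()-keyed dict it replaces (inputs are illustrative; HTTPDigestAuth's username/password fields are shown in the client_schemes diff):

key = _AuthStore.compute_key("digest", "https://agent.example.com")
_auth_store.set(key, HTTPDigestAuth(username="user", password="pass"))
assert isinstance(_auth_store.get(key), HTTPDigestAuth)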

View File

@@ -5,14 +5,25 @@ This module is separate from experimental.a2a to avoid circular imports.
from __future__ import annotations
from importlib.metadata import version
from typing import Any, ClassVar, Literal
from pathlib import Path
from typing import Any, ClassVar, Literal, cast
import warnings
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import deprecated
from pydantic import (
BaseModel,
ConfigDict,
Field,
FilePath,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import Self, deprecated
from crewai.a2a.auth.schemas import AuthScheme
from crewai.a2a.types import TransportType, Url
from crewai.a2a.auth.client_schemes import ClientAuthScheme
from crewai.a2a.auth.server_schemes import ServerAuthScheme
from crewai.a2a.extensions.base import ValidatedA2AExtension
from crewai.a2a.types import ProtocolVersion, TransportType, Url
try:
@@ -25,16 +36,17 @@ try:
SecurityScheme,
)
from crewai.a2a.extensions.server import ServerExtension
from crewai.a2a.updates import UpdateConfig
except ImportError:
UpdateConfig = Any
AgentCapabilities = Any
AgentCardSignature = Any
AgentInterface = Any
AgentProvider = Any
SecurityScheme = Any
AgentSkill = Any
UpdateConfig = Any # type: ignore[misc,assignment]
UpdateConfig: Any = Any # type: ignore[no-redef]
AgentCapabilities: Any = Any # type: ignore[no-redef]
AgentCardSignature: Any = Any # type: ignore[no-redef]
AgentInterface: Any = Any # type: ignore[no-redef]
AgentProvider: Any = Any # type: ignore[no-redef]
SecurityScheme: Any = Any # type: ignore[no-redef]
AgentSkill: Any = Any # type: ignore[no-redef]
ServerExtension: Any = Any # type: ignore[no-redef]
def _get_default_update_config() -> UpdateConfig:
@@ -43,6 +55,309 @@ def _get_default_update_config() -> UpdateConfig:
return StreamingConfig()
SigningAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
class AgentCardSigningConfig(BaseModel):
"""Configuration for AgentCard JWS signing.
Provides the private key and algorithm settings for signing AgentCards.
Either private_key_path or private_key_pem must be provided, but not both.
Attributes:
private_key_path: Path to a PEM-encoded private key file.
private_key_pem: PEM-encoded private key as a secret string.
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
private_key_path: FilePath | None = Field(
default=None,
description="Path to PEM-encoded private key file",
)
private_key_pem: SecretStr | None = Field(
default=None,
description="PEM-encoded private key",
)
key_id: str | None = Field(
default=None,
description="Key identifier for JWS header (kid claim)",
)
algorithm: SigningAlgorithm = Field(
default="RS256",
description="Signing algorithm (RS256, ES256, PS256, etc.)",
)
@model_validator(mode="after")
def _validate_key_source(self) -> Self:
"""Ensure exactly one key source is provided."""
has_path = self.private_key_path is not None
has_pem = self.private_key_pem is not None
if not has_path and not has_pem:
raise ValueError(
"Either private_key_path or private_key_pem must be provided"
)
if has_path and has_pem:
raise ValueError(
"Only one of private_key_path or private_key_pem should be provided"
)
return self
def get_private_key(self) -> str:
"""Get the private key content.
Returns:
The PEM-encoded private key as a string.
"""
if self.private_key_pem:
return self.private_key_pem.get_secret_value()
if self.private_key_path:
return Path(self.private_key_path).read_text()
raise ValueError("No private key configured")
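A sketch of the exactly-one-key-source rule enforced by the validator above (the PEM body is a placeholder):

from pydantic import SecretStr

signing = AgentCardSigningConfig(
    private_key_pem=SecretStr(
        "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----"  # placeholder
    ),
    algorithm="ES256",
)
pem = signing.get_private_key()  # returns the secret PEM string

# Supplying both private_key_path and private_key_pem (or neither) raises ValueError.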
class GRPCServerConfig(BaseModel):
"""gRPC server transport configuration.
Presence of this config in ServerTransportConfig.grpc enables gRPC transport.
Attributes:
host: Hostname to advertise in agent cards (default: localhost).
Use docker service name (e.g., 'web') for docker-compose setups.
port: Port for the gRPC server.
tls_cert_path: Path to TLS certificate file for gRPC.
tls_key_path: Path to TLS private key file for gRPC.
max_workers: Maximum number of workers for the gRPC thread pool.
reflection_enabled: Whether to enable gRPC reflection for debugging.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
host: str = Field(
default="localhost",
description="Hostname to advertise in agent cards for gRPC connections",
)
port: int = Field(
default=50051,
description="Port for the gRPC server",
)
tls_cert_path: str | None = Field(
default=None,
description="Path to TLS certificate file for gRPC",
)
tls_key_path: str | None = Field(
default=None,
description="Path to TLS private key file for gRPC",
)
max_workers: int = Field(
default=10,
description="Maximum number of workers for the gRPC thread pool",
)
reflection_enabled: bool = Field(
default=False,
description="Whether to enable gRPC reflection for debugging",
)
class GRPCClientConfig(BaseModel):
"""gRPC client transport configuration.
Attributes:
max_send_message_length: Maximum size for outgoing messages in bytes.
max_receive_message_length: Maximum size for incoming messages in bytes.
keepalive_time_ms: Time between keepalive pings in milliseconds.
keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_send_message_length: int | None = Field(
default=None,
description="Maximum size for outgoing messages in bytes",
)
max_receive_message_length: int | None = Field(
default=None,
description="Maximum size for incoming messages in bytes",
)
keepalive_time_ms: int | None = Field(
default=None,
description="Time between keepalive pings in milliseconds",
)
keepalive_timeout_ms: int | None = Field(
default=None,
description="Timeout for keepalive ping response in milliseconds",
)
class JSONRPCServerConfig(BaseModel):
"""JSON-RPC server transport configuration.
Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport.
Attributes:
rpc_path: URL path for the JSON-RPC endpoint.
agent_card_path: URL path for the agent card endpoint.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
rpc_path: str = Field(
default="/a2a",
description="URL path for the JSON-RPC endpoint",
)
agent_card_path: str = Field(
default="/.well-known/agent-card.json",
description="URL path for the agent card endpoint",
)
class JSONRPCClientConfig(BaseModel):
"""JSON-RPC client transport configuration.
Attributes:
max_request_size: Maximum request body size in bytes.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_request_size: int | None = Field(
default=None,
description="Maximum request body size in bytes",
)
class HTTPJSONConfig(BaseModel):
"""HTTP+JSON transport configuration.
Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
class ServerPushNotificationConfig(BaseModel):
"""Configuration for outgoing webhook push notifications.
Controls how the server signs and delivers push notifications to clients.
Attributes:
signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
signature_secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC-SHA256 signing of outgoing push notifications",
)
class ServerTransportConfig(BaseModel):
"""Transport configuration for A2A server.
Groups all transport-related settings including preferred transport
and protocol-specific configurations.
Attributes:
preferred: Transport protocol for the preferred endpoint.
jsonrpc: JSON-RPC server transport configuration.
grpc: gRPC server transport configuration.
http_json: HTTP+JSON transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType = Field(
default="JSONRPC",
description="Transport protocol for the preferred endpoint",
)
jsonrpc: JSONRPCServerConfig = Field(
default_factory=JSONRPCServerConfig,
description="JSON-RPC server transport configuration",
)
grpc: GRPCServerConfig | None = Field(
default=None,
description="gRPC server transport configuration",
)
http_json: HTTPJSONConfig | None = Field(
default=None,
description="HTTP+JSON transport configuration",
)
def _migrate_client_transport_fields(
transport: ClientTransportConfig,
transport_protocol: TransportType | None,
supported_transports: list[TransportType] | None,
) -> None:
"""Migrate deprecated transport fields to new config."""
if transport_protocol is not None:
warnings.warn(
"transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "preferred", transport_protocol)
if supported_transports is not None:
warnings.warn(
"supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "supported", supported_transports)
class ClientTransportConfig(BaseModel):
"""Transport configuration for A2A client.
Groups all client transport-related settings including preferred transport,
supported transports for negotiation, and protocol-specific configurations.
Transport negotiation logic:
1. If `preferred` is set and server supports it → use client's preferred
2. Otherwise, if server's preferred is in client's `supported` → use server's preferred
3. Otherwise, find first match from client's `supported` in server's interfaces
Attributes:
preferred: Client's preferred transport. If set, client preference takes priority.
supported: Transports the client can use, in order of preference.
jsonrpc: JSON-RPC client transport configuration.
grpc: gRPC client transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType | None = Field(
default=None,
description="Client's preferred transport. If set, takes priority over server preference.",
)
supported: list[TransportType] = Field(
default_factory=lambda: cast(list[TransportType], ["JSONRPC"]),
description="Transports the client can use, in order of preference",
)
jsonrpc: JSONRPCClientConfig = Field(
default_factory=JSONRPCClientConfig,
description="JSON-RPC client transport configuration",
)
grpc: GRPCClientConfig = Field(
default_factory=GRPCClientConfig,
description="gRPC client transport configuration",
)
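An illustrative pure-Python rendering of the three-step negotiation order documented in ClientTransportConfig above; negotiate_transport is a hypothetical helper for exposition, not the library's API:

def negotiate_transport(
    client: ClientTransportConfig,
    server_preferred: str,
    server_supported: list[str],
) -> str | None:
    # 1. Client preference wins if the server supports it.
    if client.preferred and client.preferred in server_supported:
        return client.preferred
    # 2. Otherwise defer to the server's preferred transport if the client can use it.
    if server_preferred in client.supported:
        return server_preferred
    # 3. Otherwise take the first mutual match, in the client's order of preference.
    for transport in client.supported:
        if transport in server_supported:
            return transport
    return None  # no compatible transport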
@deprecated(
"""
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
@@ -65,13 +380,14 @@ class A2AConfig(BaseModel):
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
transport_protocol: A2A transport protocol (grpc, jsonrpc, http+json).
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: AuthScheme | None = Field(
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
@@ -95,10 +411,48 @@ class A2AConfig(BaseModel):
default_factory=_get_default_update_config,
description="Update mechanism config",
)
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
default="JSONRPC",
description="Specified mode of A2A transport protocol",
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
use_client_preference: bool | None = Field(
default=None,
description="Deprecated: Set transport.preferred to enable client preference",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
if self.use_client_preference is not None:
warnings.warn(
"use_client_preference is deprecated, set transport.preferred to enable client preference",
FutureWarning,
stacklevel=4,
)
if self.use_client_preference and self.transport.supported:
object.__setattr__(
self.transport, "preferred", self.transport.supported[0]
)
return self
class A2AClientConfig(BaseModel):
@@ -114,15 +468,15 @@ class A2AClientConfig(BaseModel):
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
accepted_output_modes: Media types the client can accept in responses.
supported_transports: Ordered list of transport protocols the client supports.
use_client_preference: Whether to prioritize client transport preferences over server.
extensions: Extension URIs the client supports.
extensions: Extension URIs the client supports (A2A protocol extensions).
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: AuthScheme | None = Field(
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
@@ -150,22 +504,37 @@ class A2AClientConfig(BaseModel):
default_factory=lambda: ["application/json"],
description="Media types the client can accept in responses",
)
supported_transports: list[str] = Field(
default_factory=lambda: ["JSONRPC"],
description="Ordered list of transport protocols the client supports",
)
use_client_preference: bool = Field(
default=False,
description="Whether to prioritize client transport preferences over server",
)
extensions: list[str] = Field(
default_factory=list,
description="Extension URIs the client supports",
)
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
default="JSONRPC",
description="Specified mode of A2A transport protocol",
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
return self
class A2AServerConfig(BaseModel):
@@ -182,7 +551,6 @@ class A2AServerConfig(BaseModel):
default_input_modes: Default supported input MIME types.
default_output_modes: Default supported output MIME types.
capabilities: Declaration of optional capabilities.
preferred_transport: Transport protocol for the preferred endpoint.
protocol_version: A2A protocol version this agent supports.
provider: Information about the agent's service provider.
documentation_url: URL to the agent's documentation.
@@ -192,7 +560,12 @@ class A2AServerConfig(BaseModel):
security_schemes: Security schemes available to authorize requests.
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
url: Preferred endpoint URL for the agent.
signatures: JSON Web Signatures for the AgentCard.
signing_config: Configuration for signing the AgentCard with JWS.
signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
push_notifications: Configuration for outgoing push notifications.
transport: Transport configuration (preferred transport, gRPC, REST settings).
auth: Authentication scheme for A2A endpoints.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@@ -228,12 +601,8 @@ class A2AServerConfig(BaseModel):
),
description="Declaration of optional capabilities supported by the agent",
)
preferred_transport: TransportType = Field(
default="JSONRPC",
description="Transport protocol for the preferred endpoint",
)
protocol_version: str = Field(
default_factory=lambda: version("a2a-sdk"),
protocol_version: ProtocolVersion = Field(
default="0.3.0",
description="A2A protocol version this agent supports",
)
provider: AgentProvider | None = Field(
@@ -250,7 +619,7 @@ class A2AServerConfig(BaseModel):
)
additional_interfaces: list[AgentInterface] = Field(
default_factory=list,
description="Additional supported interfaces (transport and URL combinations)",
description="Additional supported interfaces.",
)
security: list[dict[str, list[str]]] = Field(
default_factory=list,
@@ -268,7 +637,54 @@ class A2AServerConfig(BaseModel):
default=None,
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
)
signatures: list[AgentCardSignature] = Field(
default_factory=list,
description="JSON Web Signatures for the AgentCard",
signing_config: AgentCardSigningConfig | None = Field(
default=None,
description="Configuration for signing the AgentCard with JWS",
)
signatures: list[AgentCardSignature] | None = Field(
default=None,
description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
exclude=True,
deprecated=True,
)
server_extensions: list[ServerExtension] = Field(
default_factory=list,
description="Server-side A2A protocol extensions that modify agent behavior",
)
push_notifications: ServerPushNotificationConfig | None = Field(
default=None,
description="Configuration for outgoing push notifications",
)
transport: ServerTransportConfig = Field(
default_factory=ServerTransportConfig,
description="Transport configuration (preferred transport, gRPC, REST settings)",
)
preferred_transport: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
deprecated=True,
)
auth: ServerAuthScheme | None = Field(
default=None,
description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
)
@model_validator(mode="after")
def _migrate_deprecated_fields(self) -> Self:
"""Migrate deprecated fields to new config."""
if self.preferred_transport is not None:
warnings.warn(
"preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=4,
)
object.__setattr__(self.transport, "preferred", self.preferred_transport)
if self.signatures is not None:
warnings.warn(
"signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
"The signatures field will be removed in v2.0.0.",
FutureWarning,
stacklevel=4,
)
return self
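A minimal sketch of the migration path, using only fields shown in this diff (values are hypothetical):

# Sketch: prefer the new transport config over the deprecated field.
config = A2AServerConfig(
    protocol_version="0.3.0",
    transport=ServerTransportConfig(preferred="GRPC"),
)

# The deprecated spelling still validates but emits a FutureWarning
# and is migrated into transport.preferred by the validator above.
legacy = A2AServerConfig(preferred_transport="GRPC")
assert legacy.transport.preferred == "GRPC"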

View File

@@ -1,7 +1,491 @@
"""A2A protocol error types."""
"""A2A error codes and error response utilities.
This module provides a centralized mapping of all A2A protocol error codes
as defined in the A2A specification, plus custom CrewAI extensions.
Error codes follow JSON-RPC 2.0 conventions:
- -32700 to -32600: Standard JSON-RPC errors
- -32099 to -32000: Server errors (A2A-specific)
- -32768 to -32100: Reserved for implementation-defined errors
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any
from a2a.client.errors import A2AClientTimeoutError
class A2APollingTimeoutError(A2AClientTimeoutError):
"""Raised when polling exceeds the configured timeout."""
class A2AErrorCode(IntEnum):
"""A2A protocol error codes.
Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
"""
# JSON-RPC 2.0 Standard Errors (-32700 to -32600)
JSON_PARSE_ERROR = -32700
"""Invalid JSON was received by the server."""
INVALID_REQUEST = -32600
"""The JSON sent is not a valid Request object."""
METHOD_NOT_FOUND = -32601
"""The method does not exist / is not available."""
INVALID_PARAMS = -32602
"""Invalid method parameter(s)."""
INTERNAL_ERROR = -32603
"""Internal JSON-RPC error."""
# A2A-Specific Errors (-32099 to -32000)
TASK_NOT_FOUND = -32001
"""The specified task was not found."""
TASK_NOT_CANCELABLE = -32002
"""The task cannot be canceled (already completed/failed)."""
PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
"""Push notifications are not supported by this agent."""
UNSUPPORTED_OPERATION = -32004
"""The requested operation is not supported."""
CONTENT_TYPE_NOT_SUPPORTED = -32005
"""Incompatible content types between client and server."""
INVALID_AGENT_RESPONSE = -32006
"""The agent produced an invalid response."""
# CrewAI Custom Extensions (-32009 to -32018, within the A2A server error range)
UNSUPPORTED_VERSION = -32009
"""The requested A2A protocol version is not supported."""
UNSUPPORTED_EXTENSION = -32010
"""Client does not support required protocol extensions."""
AUTHENTICATION_REQUIRED = -32011
"""Authentication is required for this operation."""
AUTHORIZATION_FAILED = -32012
"""Authorization check failed (insufficient permissions)."""
RATE_LIMIT_EXCEEDED = -32013
"""Rate limit exceeded for this client/operation."""
TASK_TIMEOUT = -32014
"""Task execution timed out."""
TRANSPORT_NEGOTIATION_FAILED = -32015
"""Failed to negotiate a compatible transport protocol."""
CONTEXT_NOT_FOUND = -32016
"""The specified context was not found."""
SKILL_NOT_FOUND = -32017
"""The specified skill was not found."""
ARTIFACT_NOT_FOUND = -32018
"""The specified artifact was not found."""
# Error code to default message mapping
ERROR_MESSAGES: dict[int, str] = {
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
A2AErrorCode.INVALID_PARAMS: "Invalid params",
A2AErrorCode.INTERNAL_ERROR: "Internal error",
A2AErrorCode.TASK_NOT_FOUND: "Task not found",
A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
}
@dataclass
class A2AError(Exception):
"""Base exception for A2A protocol errors.
Attributes:
code: The A2A/JSON-RPC error code.
message: Human-readable error message.
data: Optional additional error data.
"""
code: int
message: str | None = None
data: Any = None
def __post_init__(self) -> None:
if self.message is None:
self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
super().__init__(self.message)
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-RPC error object format."""
error: dict[str, Any] = {
"code": self.code,
"message": self.message,
}
if self.data is not None:
error["data"] = self.data
return error
def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
"""Convert to full JSON-RPC error response."""
return {
"jsonrpc": "2.0",
"error": self.to_dict(),
"id": request_id,
}
@dataclass
class JSONParseError(A2AError):
"""Invalid JSON was received."""
code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)
@dataclass
class InvalidRequestError(A2AError):
"""The JSON sent is not a valid Request object."""
code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)
@dataclass
class MethodNotFoundError(A2AError):
"""The method does not exist / is not available."""
code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
method: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.method:
self.message = f"Method not found: {self.method}"
super().__post_init__()
@dataclass
class InvalidParamsError(A2AError):
"""Invalid method parameter(s)."""
code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
param: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.param and self.reason:
self.message = f"Invalid parameter '{self.param}': {self.reason}"
elif self.param:
self.message = f"Invalid parameter: {self.param}"
super().__post_init__()
@dataclass
class InternalError(A2AError):
"""Internal JSON-RPC error."""
code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False)
@dataclass
class TaskNotFoundError(A2AError):
"""The specified task was not found."""
code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
task_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.task_id:
self.message = f"Task not found: {self.task_id}"
super().__post_init__()
@dataclass
class TaskNotCancelableError(A2AError):
"""The task cannot be canceled."""
code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
task_id: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.reason:
self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
elif self.task_id:
self.message = f"Task {self.task_id} cannot be canceled"
super().__post_init__()
@dataclass
class PushNotificationNotSupportedError(A2AError):
"""Push notifications are not supported."""
code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)
@dataclass
class UnsupportedOperationError(A2AError):
"""The requested operation is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
operation: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.operation:
self.message = f"Operation not supported: {self.operation}"
super().__post_init__()
@dataclass
class ContentTypeNotSupportedError(A2AError):
"""Incompatible content types."""
code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
requested_types: list[str] | None = None
supported_types: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_types and self.supported_types:
self.message = (
f"Content type not supported. Requested: {self.requested_types}, "
f"Supported: {self.supported_types}"
)
super().__post_init__()
@dataclass
class InvalidAgentResponseError(A2AError):
"""The agent produced an invalid response."""
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
@dataclass
class UnsupportedVersionError(A2AError):
"""The requested A2A version is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
requested_version: str | None = None
supported_versions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_version:
msg = f"Unsupported A2A version: {self.requested_version}"
if self.supported_versions:
msg += f". Supported versions: {', '.join(self.supported_versions)}"
self.message = msg
super().__post_init__()
@dataclass
class UnsupportedExtensionError(A2AError):
"""Client does not support required extensions."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
required_extensions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_extensions:
self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
super().__post_init__()
@dataclass
class AuthenticationRequiredError(A2AError):
"""Authentication is required."""
code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)
@dataclass
class AuthorizationFailedError(A2AError):
"""Authorization check failed."""
code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
required_scope: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_scope:
self.message = (
f"Authorization failed. Required scope: {self.required_scope}"
)
super().__post_init__()
@dataclass
class RateLimitExceededError(A2AError):
"""Rate limit exceeded."""
code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False)
retry_after: int | None = None
def __post_init__(self) -> None:
if self.message is None and self.retry_after:
self.message = (
f"Rate limit exceeded. Retry after {self.retry_after} seconds"
)
if self.retry_after:
self.data = {"retry_after": self.retry_after}
super().__post_init__()
@dataclass
class TaskTimeoutError(A2AError):
"""Task execution timed out."""
code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False)
task_id: str | None = None
timeout_seconds: float | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.timeout_seconds:
self.message = (
f"Task {self.task_id} timed out after {self.timeout_seconds}s"
)
elif self.task_id:
self.message = f"Task {self.task_id} timed out"
super().__post_init__()
@dataclass
class TransportNegotiationFailedError(A2AError):
"""Failed to negotiate a compatible transport protocol."""
code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False)
client_transports: list[str] | None = None
server_transports: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.client_transports and self.server_transports:
self.message = (
f"Transport negotiation failed. Client: {self.client_transports}, "
f"Server: {self.server_transports}"
)
super().__post_init__()
@dataclass
class ContextNotFoundError(A2AError):
"""The specified context was not found."""
code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
context_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.context_id:
self.message = f"Context not found: {self.context_id}"
super().__post_init__()
@dataclass
class SkillNotFoundError(A2AError):
"""The specified skill was not found."""
code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
skill_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.skill_id:
self.message = f"Skill not found: {self.skill_id}"
super().__post_init__()
@dataclass
class ArtifactNotFoundError(A2AError):
"""The specified artifact was not found."""
code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
artifact_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.artifact_id:
self.message = f"Artifact not found: {self.artifact_id}"
super().__post_init__()
def create_error_response(
code: int | A2AErrorCode,
message: str | None = None,
data: Any = None,
request_id: str | int | None = None,
) -> dict[str, Any]:
"""Create a JSON-RPC error response.
Args:
code: Error code (A2AErrorCode or int).
message: Optional error message (uses default if not provided).
data: Optional additional error data.
request_id: Request ID for correlation.
Returns:
Dict in JSON-RPC error response format.
"""
error = A2AError(code=int(code), message=message, data=data)
return error.to_response(request_id)
def is_retryable_error(code: int) -> bool:
"""Check if an error is potentially retryable.
Args:
code: Error code to check.
Returns:
True if the error might be resolved by retrying.
"""
retryable_codes = {
A2AErrorCode.INTERNAL_ERROR,
A2AErrorCode.RATE_LIMIT_EXCEEDED,
A2AErrorCode.TASK_TIMEOUT,
}
return code in retryable_codes
def is_client_error(code: int) -> bool:
"""Check if an error is a client-side error.
Args:
code: Error code to check.
Returns:
True if the error is due to client request issues.
"""
client_error_codes = {
A2AErrorCode.JSON_PARSE_ERROR,
A2AErrorCode.INVALID_REQUEST,
A2AErrorCode.METHOD_NOT_FOUND,
A2AErrorCode.INVALID_PARAMS,
A2AErrorCode.TASK_NOT_FOUND,
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
A2AErrorCode.UNSUPPORTED_VERSION,
A2AErrorCode.UNSUPPORTED_EXTENSION,
A2AErrorCode.CONTEXT_NOT_FOUND,
A2AErrorCode.SKILL_NOT_FOUND,
A2AErrorCode.ARTIFACT_NOT_FOUND,
}
return code in client_error_codes
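A short sketch of how the typed errors compose into a JSON-RPC response, using only the helpers defined above:

# Sketch: build a JSON-RPC error response from a typed A2A error.
error = TaskNotFoundError(task_id="task-123")
assert error.code == A2AErrorCode.TASK_NOT_FOUND
response = error.to_response(request_id="req-1")
# {"jsonrpc": "2.0", "error": {"code": -32001, "message": "Task not found: task-123"}, "id": "req-1"}
assert is_client_error(error.code)
assert not is_retryable_error(error.code)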

View File

@@ -1,4 +1,37 @@
"""A2A Protocol Extensions for CrewAI.
This module contains extensions to the A2A (Agent-to-Agent) protocol.
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
processes requests and responses during delegation to remote agents. These provide
hooks for tool injection, prompt augmentation, and response processing.
**Server-side extensions** (ServerExtension) allow agents to offer additional
functionality beyond the core A2A specification. Clients activate extensions
via the X-A2A-Extensions header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from crewai.a2a.extensions.base import (
A2AExtension,
ConversationState,
ExtensionRegistry,
ValidatedA2AExtension,
)
from crewai.a2a.extensions.server import (
ExtensionContext,
ServerExtension,
ServerExtensionRegistry,
)
__all__ = [
"A2AExtension",
"ConversationState",
"ExtensionContext",
"ExtensionRegistry",
"ServerExtension",
"ServerExtensionRegistry",
"ValidatedA2AExtension",
]

View File

@@ -1,14 +1,20 @@
"""Base extension interface for A2A wrapper integrations.
"""Base extension interface for CrewAI A2A wrapper processing hooks.
This module defines the protocol for extending CrewAI's A2A wrapper functionality
with custom logic for tool injection, prompt augmentation, and response processing.
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
A2A protocol extensions are capability declarations using AgentExtension objects
in AgentCard.capabilities.extensions, activated via the X-A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
from pydantic import BeforeValidator
if TYPE_CHECKING:
@@ -17,6 +23,20 @@ if TYPE_CHECKING:
from crewai.agent.core import Agent
def _validate_a2a_extension(v: Any) -> Any:
"""Validate that value implements A2AExtension protocol."""
if not isinstance(v, A2AExtension):
raise ValueError(
f"Value must implement A2AExtension protocol. "
f"Got {type(v).__name__} which is missing required methods."
)
return v
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
@runtime_checkable
class ConversationState(Protocol):
"""Protocol for extension-specific conversation state.
@@ -33,11 +53,36 @@ class ConversationState(Protocol):
...
@runtime_checkable
class A2AExtension(Protocol):
"""Protocol for A2A wrapper extensions.
Extensions can implement this protocol to inject custom logic into
the A2A conversation flow at various integration points.
Example:
class MyExtension:
def inject_tools(self, agent: Agent) -> None:
# Add custom tools to the agent
pass
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
# Extract state from conversation
return None
def augment_prompt(
self, base_prompt: str, conversation_state: ConversationState | None
) -> str:
# Add custom instructions
return base_prompt
def process_response(
self, agent_response: Any, conversation_state: ConversationState | None
) -> Any:
# Modify response if needed
return agent_response
"""
def inject_tools(self, agent: Agent) -> None:

View File

@@ -1,34 +1,170 @@
"""Extension registry factory for A2A configurations.
"""A2A Protocol extension utilities.
This module provides utilities for working with A2A protocol extensions as
defined in the A2A specification. Extensions are capability declarations in
AgentCard.capabilities.extensions using AgentExtension objects, activated
via the X-A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from typing import Any
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
)
from a2a.types import AgentCard, AgentExtension
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
def get_extensions_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> list[str]:
"""Extract extension URIs from A2A configuration.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Deduplicated list of extension URIs from all configs.
"""
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[str] = set()
result: list[str] = []
for config in configs:
if not isinstance(config, A2AClientConfig):
continue
for uri in config.extensions:
if uri not in seen:
seen.add(uri)
result.append(uri)
return result
class ExtensionsMiddleware(ClientCallInterceptor):
"""Middleware to add X-A2A-Extensions header to requests.
This middleware adds the extensions header to all outgoing requests,
declaring which A2A protocol extensions the client supports.
"""
def __init__(self, extensions: list[str]) -> None:
"""Initialize with extension URIs.
Args:
extensions: List of extension URIs the client supports.
"""
self._extensions = extensions
async def intercept(
self,
method_name: str,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any],
agent_card: AgentCard | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Add extensions header to the request.
Args:
method_name: The A2A method being called.
request_payload: The JSON-RPC request payload.
http_kwargs: HTTP request kwargs (headers, etc).
agent_card: The target agent's card.
context: Optional call context.
Returns:
Tuple of (request_payload, modified_http_kwargs).
"""
if self._extensions:
headers = http_kwargs.setdefault("headers", {})
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
return request_payload, http_kwargs
def validate_required_extensions(
agent_card: AgentCard,
client_extensions: list[str] | None,
) -> list[AgentExtension]:
"""Validate that client supports all required extensions from agent.
Args:
agent_card: The agent's card with declared extensions.
client_extensions: Extension URIs the client supports.
Returns:
List of unsupported required extensions.
Note:
This function does not raise; callers decide how to handle any unsupported extensions it returns.
"""
unsupported: list[AgentExtension] = []
client_set = set(client_extensions or [])
if not agent_card.capabilities or not agent_card.capabilities.extensions:
return unsupported
unsupported.extend(
ext
for ext in agent_card.capabilities.extensions
if ext.required and ext.uri not in client_set
)
return unsupported
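A sketch of pairing the middleware with the validator during client setup; the extension URI is hypothetical and agent_card is assumed to have been fetched already:

# Sketch: declare supported extensions, then verify the agent's requirements.
client_extensions = ["urn:example:ext/logging/v1"]  # hypothetical URI
middleware = ExtensionsMiddleware(client_extensions)
unsupported = validate_required_extensions(agent_card, client_extensions)
if unsupported:
    raise RuntimeError(
        f"Agent requires unsupported extensions: {[e.uri for e in unsupported]}"
    )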
def create_extension_registry_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> ExtensionRegistry:
"""Create an extension registry from A2A configuration.
"""Create an extension registry from A2A client configuration.
Extracts client_extensions from each A2AClientConfig and registers them
with the ExtensionRegistry. These extensions provide CrewAI-specific
processing hooks (tool injection, prompt augmentation, response processing).
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Extension registry with all client_extensions registered.
Example:
class LoggingExtension:
def inject_tools(self, agent): pass
def extract_state_from_history(self, history): return None
def augment_prompt(self, prompt, state): return prompt
def process_response(self, response, state):
print(f"Response: {response}")
return response
config = A2AClientConfig(
endpoint="https://agent.example.com",
client_extensions=[LoggingExtension()],
)
registry = create_extension_registry_from_config(config)
"""
registry = ExtensionRegistry()
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[int] = set()
for config in configs:
if isinstance(config, (A2AConfig, A2AClientConfig)):
client_exts = getattr(config, "client_extensions", [])
for extension in client_exts:
ext_id = id(extension)
if ext_id not in seen:
seen.add(ext_id)
registry.register(extension)
return registry

View File

@@ -0,0 +1,305 @@
"""A2A protocol server extensions for CrewAI agents.
This module provides the base class and context for implementing A2A protocol
extensions on the server side. Extensions allow agents to offer additional
functionality beyond the core A2A specification.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
from typing import TYPE_CHECKING, Annotated, Any
from a2a.types import AgentExtension
from pydantic_core import CoreSchema, core_schema
if TYPE_CHECKING:
from a2a.server.context import ServerCallContext
from pydantic import GetCoreSchemaHandler
logger = logging.getLogger(__name__)
@dataclass
class ExtensionContext:
"""Context passed to extension hooks during request processing.
Provides access to request metadata, client extensions, and shared state
that extensions can read from and write to.
Attributes:
metadata: Request metadata dict, includes extension-namespaced keys.
client_extensions: Set of extension URIs the client declared support for.
state: Mutable dict for extensions to share data during request lifecycle.
server_context: The underlying A2A server call context.
"""
metadata: dict[str, Any]
client_extensions: set[str]
state: dict[str, Any] = field(default_factory=dict)
server_context: ServerCallContext | None = None
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
"""Get extension-specific metadata value.
Extension metadata uses namespaced keys in the format:
"{extension_uri}/{key}"
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
Returns:
The metadata value, or None if not present.
"""
full_key = f"{uri}/{key}"
return self.metadata.get(full_key)
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
"""Set extension-specific metadata value.
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
value: The value to set.
"""
full_key = f"{uri}/{key}"
self.metadata[full_key] = value
class ServerExtension(ABC):
"""Base class for A2A protocol server extensions.
Subclass this to create custom extensions that modify agent behavior
when clients activate them. Extensions are identified by URI and can
be marked as required.
Example:
class SamplingExtension(ServerExtension):
uri = "urn:crewai:ext:sampling/v1"
required = True
def __init__(self, max_tokens: int = 4096):
self.max_tokens = max_tokens
@property
def params(self) -> dict[str, Any]:
return {"max_tokens": self.max_tokens}
async def on_request(self, context: ExtensionContext) -> None:
limit = context.get_extension_metadata(self.uri, "limit")
if limit:
context.state["token_limit"] = int(limit)
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
return result
"""
uri: Annotated[str, "Extension URI identifier. Must be unique."]
required: Annotated[bool, "Whether clients must support this extension."] = False
description: Annotated[
str | None, "Human-readable description of the extension."
] = None
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
"""Tell Pydantic how to validate ServerExtension instances."""
return core_schema.is_instance_schema(cls)
@property
def params(self) -> dict[str, Any] | None:
"""Extension parameters to advertise in AgentCard.
Override this property to expose configuration that clients can read.
Returns:
Dict of parameter names to values, or None.
"""
return None
def agent_extension(self) -> AgentExtension:
"""Generate the AgentExtension object for the AgentCard.
Returns:
AgentExtension with this extension's URI, required flag, and params.
"""
return AgentExtension(
uri=self.uri,
required=self.required if self.required else None,
description=self.description,
params=self.params,
)
def is_active(self, context: ExtensionContext) -> bool:
"""Check if this extension is active for the current request.
An extension is active if the client declared support for it.
Args:
context: The extension context for the current request.
Returns:
True if the client supports this extension.
"""
return self.uri in context.client_extensions
@abstractmethod
async def on_request(self, context: ExtensionContext) -> None:
"""Called before agent execution if extension is active.
Use this hook to:
- Read extension-specific metadata from the request
- Set up state for the execution
- Modify execution parameters via context.state
Args:
context: The extension context with request metadata and state.
"""
...
@abstractmethod
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Called after agent execution if extension is active.
Use this hook to:
- Modify or enhance the result
- Add extension-specific metadata to the response
- Clean up any resources
Args:
context: The extension context with request metadata and state.
result: The agent execution result.
Returns:
The result, potentially modified.
"""
...
class ServerExtensionRegistry:
"""Registry for managing server-side A2A protocol extensions.
Collects extensions and provides methods to generate AgentCapabilities
and invoke extension hooks during request processing.
"""
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
"""Initialize the registry with optional extensions.
Args:
extensions: Initial list of extensions to register.
"""
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
self._by_uri: dict[str, ServerExtension] = {
ext.uri: ext for ext in self._extensions
}
def register(self, extension: ServerExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
Raises:
ValueError: If an extension with the same URI is already registered.
"""
if extension.uri in self._by_uri:
raise ValueError(f"Extension already registered: {extension.uri}")
self._extensions.append(extension)
self._by_uri[extension.uri] = extension
def get_agent_extensions(self) -> list[AgentExtension]:
"""Get AgentExtension objects for all registered extensions.
Returns:
List of AgentExtension objects for the AgentCard.
"""
return [ext.agent_extension() for ext in self._extensions]
def get_extension(self, uri: str) -> ServerExtension | None:
"""Get an extension by URI.
Args:
uri: The extension URI.
Returns:
The extension, or None if not found.
"""
return self._by_uri.get(uri)
@staticmethod
def create_context(
metadata: dict[str, Any],
client_extensions: set[str],
server_context: ServerCallContext | None = None,
) -> ExtensionContext:
"""Create an ExtensionContext for a request.
Args:
metadata: Request metadata dict.
client_extensions: Set of extension URIs from client.
server_context: Optional server call context.
Returns:
ExtensionContext for use in hooks.
"""
return ExtensionContext(
metadata=metadata,
client_extensions=client_extensions,
server_context=server_context,
)
async def invoke_on_request(self, context: ExtensionContext) -> None:
"""Invoke on_request hooks for all active extensions.
Tracks activated extensions and isolates errors from individual hooks.
Args:
context: The extension context for the request.
"""
for extension in self._extensions:
if extension.is_active(context):
try:
await extension.on_request(context)
if context.server_context is not None:
context.server_context.activated_extensions.add(extension.uri)
except Exception:
logger.exception(
"Extension on_request hook failed",
extra={"extension": extension.uri},
)
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Invoke on_response hooks for all active extensions.
Isolates errors from individual hooks to prevent one failing extension
from breaking the entire response.
Args:
context: The extension context for the request.
result: The agent execution result.
Returns:
The result after all extensions have processed it.
"""
processed = result
for extension in self._extensions:
if extension.is_active(context):
try:
processed = await extension.on_response(context, processed)
except Exception:
logger.exception(
"Extension on_response hook failed",
extra={"extension": extension.uri},
)
return processed
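A sketch of the request lifecycle with the registry; EchoExtension is hypothetical and the awaits are assumed to run inside an async request handler:

from typing import Any

class EchoExtension(ServerExtension):
    uri = "urn:example:ext/echo/v1"  # hypothetical URI

    async def on_request(self, context: ExtensionContext) -> None:
        context.state["seen"] = True  # share state with on_response

    async def on_response(self, context: ExtensionContext, result: Any) -> Any:
        return result

registry = ServerExtensionRegistry([EchoExtension()])
context = ServerExtensionRegistry.create_context(
    metadata={}, client_extensions={"urn:example:ext/echo/v1"}
)
await registry.invoke_on_request(context)
result = await registry.invoke_on_response(context, result="ok")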

View File

@@ -51,6 +51,13 @@ ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
}
)
PENDING_STATES: frozenset[TaskState] = frozenset(
{
TaskState.submitted,
TaskState.working,
}
)
class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state."""
@@ -272,6 +279,9 @@ def process_task_state(
history=new_messages,
)
if a2a_task.status.state in PENDING_STATES:
return None
return None

View File

@@ -38,3 +38,18 @@ You MUST now:
DO NOT send another request - the task is already done.
</REMOTE_AGENT_STATUS>
"""
REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """
<REMOTE_AGENT_STATUS>
STATUS: RESPONSE_RECEIVED
The remote agent has responded. Their response is in the conversation history above.
You MUST now:
1. Set is_a2a=false (the remote task is complete and cannot receive more messages)
2. Provide YOUR OWN response to the original task based on the information received
IMPORTANT: Your response should be addressed to the USER who gave you the original task.
Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that...").
Do NOT address the remote agent directly or use "you" to refer to them.
</REMOTE_AGENT_STATUS>
"""

View File

@@ -36,6 +36,17 @@ except ImportError:
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
ProtocolVersion = Literal[
"0.2.0",
"0.2.1",
"0.2.2",
"0.2.3",
"0.2.4",
"0.2.5",
"0.2.6",
"0.3.0",
"0.4.0",
]
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
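A small sketch of what the Literal buys: Pydantic rejects versions outside the list (TypeAdapter is already imported in this module):

# Sketch: validate protocol versions against the Literal.
version_adapter: TypeAdapter[ProtocolVersion] = TypeAdapter(ProtocolVersion)
version_adapter.validate_python("0.3.0")  # accepted
# version_adapter.validate_python("9.9.9")  # would raise ValidationError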

View File

@@ -2,12 +2,28 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
class CommonParams(NamedTuple):
"""Common parameters shared across all update handlers.
Groups the frequently-passed parameters to reduce duplication.
"""
turn_number: int
is_multiturn: bool
agent_role: str | None
endpoint: str
a2a_agent_name: str | None
context_id: str | None
from_task: Any
from_agent: Any
if TYPE_CHECKING:
from a2a.client import Client
from a2a.types import AgentCard, Message, Task
@@ -63,8 +79,8 @@ class PushNotificationResultStore(Protocol):
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
return core_schema.any_schema()
@@ -130,3 +146,31 @@ class UpdateHandler(Protocol):
Result dictionary with status, result/error, and history.
"""
...
def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams:
"""Extract common parameters from handler kwargs.
Args:
kwargs: Handler kwargs dict.
Returns:
CommonParams with extracted values.
Raises:
ValueError: If endpoint is not provided.
"""
endpoint = kwargs.get("endpoint")
if endpoint is None:
raise ValueError("endpoint is required for update handlers")
return CommonParams(
turn_number=kwargs.get("turn_number", 0),
is_multiturn=kwargs.get("is_multiturn", False),
agent_role=kwargs.get("agent_role"),
endpoint=endpoint,
a2a_agent_name=kwargs.get("a2a_agent_name"),
context_id=kwargs.get("context_id"),
from_task=kwargs.get("from_task"),
from_agent=kwargs.get("from_agent"),
)
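A minimal sketch of the intended call pattern, assuming a handler kwargs dict shaped like BaseHandlerKwargs (the endpoint value is hypothetical):

# Sketch: collapse handler kwargs into a single CommonParams value.
kwargs = {
    "endpoint": "https://agent.example.com",  # hypothetical endpoint
    "turn_number": 2,
    "is_multiturn": True,
}
params = extract_common_params(kwargs)
assert params.endpoint == "https://agent.example.com"
assert params.agent_role is None  # absent keys fall back to defaults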

View File

@@ -94,7 +94,7 @@ async def _poll_task_until_complete(
A2APollingStatusEvent(
task_id=task_id,
context_id=effective_context_id,
state=str(task.status.state.value),
elapsed_seconds=elapsed,
poll_count=poll_count,
endpoint=endpoint,
@@ -325,7 +325,7 @@ class PollingHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,

View File

@@ -2,10 +2,30 @@
from __future__ import annotations
from typing import Annotated
from a2a.types import PushNotificationAuthenticationInfo
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
from crewai.a2a.updates.base import PushNotificationResultStore
from crewai.a2a.updates.push_notifications.signature import WebhookSignatureConfig
def _coerce_signature(
value: str | WebhookSignatureConfig | None,
) -> WebhookSignatureConfig | None:
"""Convert string secret to WebhookSignatureConfig."""
if value is None:
return None
if isinstance(value, str):
return WebhookSignatureConfig.hmac_sha256(secret=value)
return value
SignatureInput = Annotated[
WebhookSignatureConfig | None,
BeforeValidator(_coerce_signature),
]
class PushNotificationConfig(BaseModel):
@@ -19,6 +39,8 @@ class PushNotificationConfig(BaseModel):
timeout: Max seconds to wait for task completion.
interval: Seconds between result polling attempts.
result_store: Store for receiving push notification results.
signature: HMAC signature config. Pass a string (secret) for defaults,
or WebhookSignatureConfig for custom settings.
"""
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
@@ -36,3 +58,8 @@ class PushNotificationConfig(BaseModel):
result_store: PushNotificationResultStore | None = Field(
default=None, description="Result store for push notification handling"
)
signature: SignatureInput = Field(
default=None,
description="HMAC signature config. Pass a string (secret) for simple usage, "
"or WebhookSignatureConfig for custom headers/tolerance.",
)
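A sketch of the coercion in practice, assuming WebhookSignatureMode is imported from the signature module (URL and secret are hypothetical):

# Sketch: a bare string secret becomes an HMAC-SHA256 signature config.
config = PushNotificationConfig(
    url="https://example.com/a2a/webhook",  # hypothetical callback URL
    signature="my-shared-secret",
)
assert config.signature is not None
assert config.signature.mode is WebhookSignatureMode.HMAC_SHA256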

View File

@@ -24,8 +24,10 @@ from crewai.a2a.task_helpers import (
send_message_and_get_task_id,
)
from crewai.a2a.updates.base import (
CommonParams,
PushNotificationHandlerKwargs,
PushNotificationResultStore,
extract_common_params,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
@@ -39,10 +41,81 @@ from crewai.events.types.a2a_events import (
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
logger = logging.getLogger(__name__)
def _handle_push_error(
error: Exception,
error_msg: str,
error_type: str,
new_messages: list[Message],
agent_branch: Any | None,
params: CommonParams,
task_id: str | None,
status_code: int | None = None,
) -> TaskStateResult:
"""Handle push notification errors with consistent event emission.
Args:
error: The exception that occurred.
error_msg: Formatted error message for the result.
error_type: Type of error for the event.
new_messages: List to append error message to.
agent_branch: Agent tree branch for events.
params: Common handler parameters.
task_id: A2A task ID.
status_code: HTTP status code if applicable.
Returns:
TaskStateResult with failed status.
"""
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(error),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
async def _wait_for_push_result(
task_id: str,
result_store: PushNotificationResultStore,
@@ -126,15 +199,8 @@ class PushNotificationHandler:
polling_timeout = kwargs.get("polling_timeout", 300.0)
polling_interval = kwargs.get("polling_interval", 2.0)
agent_branch = kwargs.get("agent_branch")
task_id = kwargs.get("task_id")
params = extract_common_params(kwargs)
if config is None:
error_msg = (
@@ -143,15 +209,15 @@ class PushNotificationHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
@@ -167,15 +233,15 @@ class PushNotificationHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
@@ -189,14 +255,14 @@ class PushNotificationHandler:
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=params.context_id,
)
if not isinstance(result_or_task_id, str):
@@ -208,12 +274,12 @@ class PushNotificationHandler:
agent_branch,
A2APushNotificationRegisteredEvent(
task_id=task_id,
context_id=params.context_id,
callback_url=str(config.url),
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
@@ -229,11 +295,11 @@ class PushNotificationHandler:
timeout=polling_timeout,
poll_interval=polling_interval,
agent_branch=agent_branch,
from_task=params.from_task,
from_agent=params.from_agent,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
)
if final_task is None:
@@ -247,13 +313,13 @@ class PushNotificationHandler:
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if result:
return result
@@ -265,98 +331,24 @@ class PushNotificationHandler:
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
return _handle_push_error(
error=e,
error_msg=f"HTTP Error {e.status_code}: {e!s}",
error_type="http_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
)
except Exception as e:
error_msg = f"Unexpected error during push notification: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
return _handle_push_error(
error=e,
error_msg=f"Unexpected error during push notification: {e!s}",
error_type="unexpected_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
)

View File

@@ -0,0 +1,87 @@
"""Webhook signature configuration for push notifications."""
from __future__ import annotations
from enum import Enum
import secrets
from pydantic import BaseModel, Field, SecretStr
class WebhookSignatureMode(str, Enum):
"""Signature mode for webhook push notifications."""
NONE = "none"
HMAC_SHA256 = "hmac_sha256"
class WebhookSignatureConfig(BaseModel):
"""Configuration for webhook signature verification.
Provides cryptographic integrity verification and replay attack protection
for A2A push notifications.
Attributes:
mode: Signature mode (none or hmac_sha256).
secret: Shared secret for HMAC computation (required for hmac_sha256 mode).
timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection.
header_name: HTTP header name for the signature.
timestamp_header_name: HTTP header name for the timestamp.
"""
mode: WebhookSignatureMode = Field(
default=WebhookSignatureMode.NONE,
description="Signature verification mode",
)
secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC computation",
)
timestamp_tolerance_seconds: int = Field(
default=300,
ge=0,
description="Max allowed timestamp age in seconds (5 min default)",
)
header_name: str = Field(
default="X-A2A-Signature",
description="HTTP header name for the signature",
)
timestamp_header_name: str = Field(
default="X-A2A-Signature-Timestamp",
description="HTTP header name for the timestamp",
)
@classmethod
def generate_secret(cls, length: int = 32) -> str:
"""Generate a cryptographically secure random secret.
Args:
length: Number of random bytes to generate (default 32).
Returns:
URL-safe base64-encoded secret string.
"""
return secrets.token_urlsafe(length)
@classmethod
def hmac_sha256(
cls,
secret: str | SecretStr,
timestamp_tolerance_seconds: int = 300,
) -> WebhookSignatureConfig:
"""Create an HMAC-SHA256 signature configuration.
Args:
secret: Shared secret for HMAC computation.
timestamp_tolerance_seconds: Max allowed timestamp age in seconds.
Returns:
Configured WebhookSignatureConfig for HMAC-SHA256.
"""
if isinstance(secret, str):
secret = SecretStr(secret)
return cls(
mode=WebhookSignatureMode.HMAC_SHA256,
secret=secret,
timestamp_tolerance_seconds=timestamp_tolerance_seconds,
)
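One plausible receiver-side sketch built on this config; the exact signing payload (timestamp joined to the body) is an assumption, not part of this diff:

import hashlib
import hmac
import time

def verify_webhook(
    body: bytes, headers: dict[str, str], config: WebhookSignatureConfig
) -> bool:
    # Hypothetical verifier: reject missing fields, stale timestamps, bad MACs.
    ts = headers.get(config.timestamp_header_name)
    sig = headers.get(config.header_name)
    if config.secret is None or ts is None or sig is None:
        return False
    try:
        if abs(time.time() - float(ts)) > config.timestamp_tolerance_seconds:
            return False  # stale timestamp: possible replay
    except ValueError:
        return False
    expected = hmac.new(
        config.secret.get_secret_value().encode(),
        ts.encode() + b"." + body,
        hashlib.sha256,
    ).hexdigest()
    return hmac.compare_digest(expected, sig)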

View File

@@ -2,6 +2,9 @@
from __future__ import annotations
import asyncio
import logging
from typing import Final
import uuid
from a2a.client import Client
@@ -11,7 +14,10 @@ from a2a.types import (
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskQueryParams,
TaskState,
TaskStatusUpdateEvent,
TextPart,
@@ -24,7 +30,10 @@ from crewai.a2a.task_helpers import (
TaskStateResult,
process_task_state,
)
from crewai.a2a.updates.base import StreamingHandlerKwargs, extract_common_params
from crewai.a2a.updates.streaming.params import (
process_status_update,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AArtifactReceivedEvent,
@@ -35,9 +44,194 @@ from crewai.events.types.a2a_events import (
)
logger = logging.getLogger(__name__)
MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
class StreamingHandler:
"""SSE streaming-based update handler."""
@staticmethod
async def _try_recover_from_interruption( # type: ignore[misc]
client: Client,
task_id: str,
new_messages: list[Message],
agent_card: AgentCard,
result_parts: list[str],
**kwargs: Unpack[StreamingHandlerKwargs],
) -> TaskStateResult | None:
"""Attempt to recover from a stream interruption by checking task state.
If the task completed while we were disconnected, returns the result.
If the task is still running, attempts to resubscribe and continue.
Args:
client: A2A client instance.
task_id: The task ID to recover.
new_messages: List of collected messages.
agent_card: The agent card.
result_parts: Accumulated result text parts.
**kwargs: Handler parameters.
Returns:
TaskStateResult if recovery succeeded (task finished or resubscribe worked).
None if recovery not possible (caller should handle failure).
Note:
When None is returned, recovery failed and the original exception should
be handled by the caller. All recovery attempts are logged.
"""
params = extract_common_params(kwargs) # type: ignore[arg-type]
try:
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
if a2a_task.status.state in TERMINAL_STATES:
logger.info(
"Task completed during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if a2a_task.status.state in ACTIONABLE_STATES:
logger.info(
"Task in actionable state during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=False,
)
logger.info(
"Task still running, attempting resubscribe",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS):
try:
backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt)
if attempt > 0:
await asyncio.sleep(backoff)
event_stream = client.resubscribe(TaskIdParams(id=task_id))
async for event in event_stream:
if isinstance(event, tuple):
resubscribed_task, update = event
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
if (
is_final_update
or resubscribed_task.status.state
in TERMINAL_STATES | ACTIONABLE_STATES
):
return process_task_state(
a2a_task=resubscribed_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
elif isinstance(event, Message):
new_messages.append(event)
result_parts.extend(
part.root.text
for part in event.parts
if part.root.kind == "text"
)
final_task = await client.get_task(TaskQueryParams(id=task_id))
return process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
except Exception as resubscribe_error: # noqa: PERF203
logger.warning(
"Resubscribe attempt failed",
extra={
"task_id": task_id,
"attempt": attempt + 1,
"max_attempts": MAX_RESUBSCRIBE_ATTEMPTS,
"error": str(resubscribe_error),
},
)
if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1:
return None
except Exception as e:
logger.warning(
"Failed to recover from stream interruption due to unexpected error",
extra={
"task_id": task_id,
"error": str(e),
"error_type": type(e).__name__,
},
exc_info=True,
)
return None
logger.warning(
"Recovery exhausted all resubscribe attempts without success",
extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS},
)
return None
@staticmethod
async def execute(
client: Client,
@@ -58,42 +252,40 @@ class StreamingHandler:
Returns:
Dictionary with status, result/error, and history.
"""
task_id = kwargs.get("task_id")
agent_branch = kwargs.get("agent_branch")
params = extract_common_params(kwargs)
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
chunk_index = 0
current_task_id: str | None = task_id
crewai_event_bus.emit(
agent_branch,
A2AStreamingStartedEvent(
task_id=task_id,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
try:
async for event in event_stream:
if isinstance(event, tuple):
a2a_task, _ = event
current_task_id = a2a_task.id
if isinstance(event, Message):
new_messages.append(event)
message_context_id = event.context_id or params.context_id
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
@@ -105,12 +297,12 @@ class StreamingHandler:
context_id=message_context_id,
chunk=text,
chunk_index=chunk_index,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
chunk_index += 1
@@ -128,12 +320,12 @@ class StreamingHandler:
artifact_size = None
if artifact.parts:
artifact_size = sum(
len(p.root.text.encode("utf-8"))
len(p.root.text.encode())
if p.root.kind == "text"
else len(getattr(p.root, "data", b""))
for p in artifact.parts
)
effective_context_id = a2a_task.context_id or params.context_id
crewai_event_bus.emit(
agent_branch,
A2AArtifactReceivedEvent(
@@ -147,29 +339,21 @@ class StreamingHandler:
size_bytes=artifact_size,
append=update.append or False,
last_chunk=update.last_chunk or False,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=effective_context_id,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if (
not is_final_update
@@ -182,27 +366,68 @@ class StreamingHandler:
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
if final_result:
break
except A2AClientHTTPError as e:
if current_task_id:
logger.info(
"Stream interrupted with HTTP error, attempting recovery",
extra={
"task_id": current_task_id,
"error": str(e),
"status_code": e.status_code,
},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
"Successfully recovered task after HTTP error",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
"Failed to recover from HTTP error, returning failure",
extra={
"task_id": current_task_id,
"status_code": e.status_code,
"original_error": str(e),
},
)
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_type = "http_error"
status_code = e.status_code
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
- context_id=context_id,
+ context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
@@ -210,32 +435,118 @@ class StreamingHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
endpoint=params.endpoint,
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
- context_id=context_id,
+ context_id=params.context_id,
task_id=task_id,
- from_task=from_task,
- from_agent=from_agent,
+ from_task=params.from_task,
+ from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
- turn_number=turn_number,
- context_id=context_id,
- is_multiturn=is_multiturn,
+ turn_number=params.turn_number,
+ context_id=params.context_id,
+ is_multiturn=params.is_multiturn,
status="failed",
final=True,
- agent_role=agent_role,
- endpoint=endpoint,
- a2a_agent_name=a2a_agent_name,
- from_task=from_task,
- from_agent=from_agent,
+ agent_role=params.agent_role,
+ endpoint=params.endpoint,
+ a2a_agent_name=params.a2a_agent_name,
+ from_task=params.from_task,
+ from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
error_type = type(e).__name__.lower()
if current_task_id:
logger.info(
f"Stream interrupted with {error_type}, attempting recovery",
extra={"task_id": current_task_id, "error": str(e)},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
f"Successfully recovered task after {error_type}",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
f"Failed to recover from {error_type}, returning failure",
extra={
"task_id": current_task_id,
"error_type": error_type,
"original_error": str(e),
},
)
error_msg = f"Connection error during streaming: {e!s}"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
@@ -245,13 +556,23 @@ class StreamingHandler:
)
except Exception as e:
error_msg = f"Unexpected error during streaming: {e!s}"
logger.exception(
"Unexpected error during streaming",
extra={
"task_id": current_task_id,
"error_type": type(e).__name__,
"endpoint": params.endpoint,
},
)
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
error_type = "unexpected_error"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
- context_id=context_id,
+ context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
@@ -259,31 +580,32 @@ class StreamingHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
endpoint=params.endpoint,
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
- context_id=context_id,
+ context_id=params.context_id,
task_id=task_id,
- from_task=from_task,
- from_agent=from_agent,
+ from_task=params.from_task,
+ from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
- turn_number=turn_number,
- context_id=context_id,
- is_multiturn=is_multiturn,
+ turn_number=params.turn_number,
+ context_id=params.context_id,
+ is_multiturn=params.is_multiturn,
status="failed",
final=True,
- agent_role=agent_role,
- endpoint=endpoint,
- a2a_agent_name=a2a_agent_name,
- from_task=from_task,
- from_agent=from_agent,
+ agent_role=params.agent_role,
+ endpoint=params.endpoint,
+ a2a_agent_name=params.a2a_agent_name,
+ from_task=params.from_task,
+ from_agent=params.from_agent,
),
)
return TaskStateResult(
@@ -301,15 +623,15 @@ class StreamingHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
endpoint=params.endpoint,
error=str(close_error),
error_type="stream_close_error",
- a2a_agent_name=a2a_agent_name,
+ a2a_agent_name=params.a2a_agent_name,
operation="stream_close",
- context_id=context_id,
+ context_id=params.context_id,
task_id=task_id,
- from_task=from_task,
- from_agent=from_agent,
+ from_task=params.from_task,
+ from_agent=params.from_agent,
),
)

View File

@@ -0,0 +1,28 @@
"""Common parameter extraction for streaming handlers."""
from __future__ import annotations
from a2a.types import TaskStatusUpdateEvent
def process_status_update(
update: TaskStatusUpdateEvent,
result_parts: list[str],
) -> bool:
"""Process a status update event and extract text parts.
Args:
update: The status update event.
result_parts: List to append text parts to (modified in place).
Returns:
True if this is a final update, False otherwise.
"""
is_final = update.final
if update.status and update.status.message and update.status.message.parts:
result_parts.extend(
part.root.text
for part in update.status.message.parts
if part.root.kind == "text" and part.root.text
)
return is_final
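A minimal usage sketch for the extracted helper (the `updates` list is assumed for illustration; only process_status_update itself comes from this file):

    from a2a.types import TaskStatusUpdateEvent

    def drain_status_updates(updates: list[TaskStatusUpdateEvent]) -> tuple[str, bool]:
        # Accumulate text from each status update until one reports final=True,
        # mirroring how the streaming loop above consumes the helper.
        result_parts: list[str] = []
        for update in updates:
            if process_status_update(update, result_parts):
                return "".join(result_parts), True
        return "".join(result_parts), False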

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
from collections.abc import MutableMapping
from functools import lru_cache
import ssl
import time
from types import MethodType
from typing import TYPE_CHECKING
@@ -15,7 +16,7 @@ from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
import httpx
- from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
+ from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
configure_auth_client,
@@ -32,11 +33,51 @@ from crewai.events.types.a2a_events import (
if TYPE_CHECKING:
- from crewai.a2a.auth.schemas import AuthScheme
+ from crewai.a2a.auth.client_schemes import ClientAuthScheme
from crewai.agent import Agent
from crewai.task import Task
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
"""Get TLS verify parameter from auth scheme.
Args:
auth: Optional authentication scheme with TLS config.
Returns:
SSL context, CA cert path, True for default verification,
or False if verification disabled.
"""
if auth and auth.tls:
return auth.tls.get_httpx_ssl_context()
return True
async def _prepare_auth_headers(
auth: ClientAuthScheme | None,
timeout: int,
) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]:
"""Prepare authentication headers and TLS verification settings.
Args:
auth: Optional authentication scheme.
timeout: Request timeout in seconds.
Returns:
Tuple of (headers dict, TLS verify setting).
"""
headers: MutableMapping[str, str] = {}
verify = _get_tls_verify(auth)
if auth:
async with httpx.AsyncClient(
timeout=timeout, verify=verify
) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
return headers, verify
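A short sketch of how a caller might combine the two helpers; `auth` is assumed to be any ClientAuthScheme instance and the URL is a placeholder:

    import httpx

    async def fetch_protected_resource(url: str, auth, timeout: int = 30) -> httpx.Response:
        # Headers come from the auth scheme; verify carries its TLS settings
        # (an SSLContext, a CA bundle path, or True for default verification).
        headers, verify = await _prepare_auth_headers(auth, timeout)
        async with httpx.AsyncClient(timeout=timeout, headers=headers, verify=verify) as client:
            return await client.get(url)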
def _get_server_config(agent: Agent) -> A2AServerConfig | None:
"""Get A2AServerConfig from an agent's a2a configuration.
@@ -59,7 +100,7 @@ def _get_server_config(agent: Agent) -> A2AServerConfig | None:
def fetch_agent_card(
endpoint: str,
- auth: AuthScheme | None = None,
+ auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
cache_ttl: int = 300,
@@ -68,7 +109,7 @@ def fetch_agent_card(
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
- auth: Optional AuthScheme for authentication.
+ auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).
@@ -90,10 +131,10 @@ def fetch_agent_card(
"_authorization_callback",
}
)
- auth_hash = hash((type(auth).__name__, auth_data))
+ auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
- auth_hash = 0
- _auth_store[auth_hash] = auth
+ auth_hash = _auth_store.compute_key("none", "")
+ _auth_store.set(auth_hash, auth)
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
@@ -109,7 +150,7 @@ def fetch_agent_card(
async def afetch_agent_card(
endpoint: str,
- auth: AuthScheme | None = None,
+ auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
) -> AgentCard:
@@ -119,7 +160,7 @@ async def afetch_agent_card(
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
- auth: Optional AuthScheme for authentication.
+ auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
@@ -140,10 +181,10 @@ async def afetch_agent_card(
"_authorization_callback",
}
)
- auth_hash = hash((type(auth).__name__, auth_data))
+ auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
- auth_hash = 0
- _auth_store[auth_hash] = auth
+ auth_hash = _auth_store.compute_key("none", "")
+ _auth_store.set(auth_hash, auth)
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)
@@ -155,7 +196,7 @@ async def afetch_agent_card(
@lru_cache()
def _fetch_agent_card_cached(
endpoint: str,
- auth_hash: int,
+ auth_hash: str,
timeout: int,
_ttl_hash: int,
) -> AgentCard:
@@ -175,7 +216,7 @@ def _fetch_agent_card_cached(
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
async def _afetch_agent_card_cached(
endpoint: str,
- auth_hash: int,
+ auth_hash: str,
timeout: int,
) -> AgentCard:
"""Cached async implementation of AgentCard fetching."""
@@ -185,7 +226,7 @@ async def _afetch_agent_card_cached(
async def _afetch_agent_card_impl(
endpoint: str,
- auth: AuthScheme | None,
+ auth: ClientAuthScheme | None,
timeout: int,
) -> AgentCard:
"""Internal async implementation of AgentCard fetching."""
@@ -197,16 +238,17 @@ async def _afetch_agent_card_impl(
else:
url_parts = endpoint.split("/", 3)
base_url = f"{url_parts[0]}//{url_parts[2]}"
agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/"
agent_card_path = (
f"/{url_parts[3]}"
if len(url_parts) > 3 and url_parts[3]
else "/.well-known/agent-card.json"
)
- headers: MutableMapping[str, str] = {}
- if auth:
- async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
- if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
- configure_auth_client(auth, temp_auth_client)
- headers = await auth.apply_auth(temp_auth_client, {})
+ headers, verify = await _prepare_auth_headers(auth, timeout)
- async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client:
+ async with httpx.AsyncClient(
+ timeout=timeout, headers=headers, verify=verify
+ ) as temp_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_client)
@@ -434,6 +476,7 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
"""Generate an A2A AgentCard from an Agent instance.
Uses A2AServerConfig values when available, falling back to agent properties.
If signing_config is provided, the card will be signed with JWS.
Args:
agent: The Agent instance to generate a card for.
@@ -442,6 +485,8 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
Returns:
AgentCard describing the agent's capabilities.
"""
from crewai.a2a.utils.agent_card_signing import sign_agent_card
server_config = _get_server_config(agent) or A2AServerConfig()
name = server_config.name or agent.role
@@ -472,15 +517,31 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
)
)
- return AgentCard(
+ capabilities = server_config.capabilities
+ if server_config.server_extensions:
+ from crewai.a2a.extensions.server import ServerExtensionRegistry
+ registry = ServerExtensionRegistry(server_config.server_extensions)
+ ext_list = registry.get_agent_extensions()
+ existing_exts = list(capabilities.extensions) if capabilities.extensions else []
+ existing_uris = {e.uri for e in existing_exts}
+ for ext in ext_list:
+ if ext.uri not in existing_uris:
+ existing_exts.append(ext)
+ capabilities = capabilities.model_copy(update={"extensions": existing_exts})
+ card = AgentCard(
name=name,
description=description,
url=server_config.url or url,
version=server_config.version,
- capabilities=server_config.capabilities,
+ capabilities=capabilities,
default_input_modes=server_config.default_input_modes,
default_output_modes=server_config.default_output_modes,
skills=skills,
preferred_transport=server_config.transport.preferred,
protocol_version=server_config.protocol_version,
provider=server_config.provider,
documentation_url=server_config.documentation_url,
@@ -489,9 +550,21 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
security=server_config.security,
security_schemes=server_config.security_schemes,
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
- signatures=server_config.signatures,
)
if server_config.signing_config:
signature = sign_agent_card(
card,
private_key=server_config.signing_config.get_private_key(),
key_id=server_config.signing_config.key_id,
algorithm=server_config.signing_config.algorithm,
)
card = card.model_copy(update={"signatures": [signature]})
elif server_config.signatures:
card = card.model_copy(update={"signatures": server_config.signatures})
return card
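The extension merge above boils down to dedupe-by-URI while preserving order; a standalone sketch with a stand-in type (constructing a full AgentExtension is beside the point here):

    from dataclasses import dataclass

    @dataclass
    class FakeExtension:  # stand-in for a2a.types.AgentExtension
        uri: str

    def merge_by_uri(existing: list[FakeExtension], new: list[FakeExtension]) -> list[FakeExtension]:
        merged = list(existing)
        seen = {e.uri for e in merged}
        for ext in new:
            if ext.uri not in seen:
                merged.append(ext)
                seen.add(ext.uri)
        return merged

    # Duplicates are dropped, order is preserved:
    assert [e.uri for e in merge_by_uri([FakeExtension("a")], [FakeExtension("a"), FakeExtension("b")])] == ["a", "b"]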
def inject_a2a_server_methods(agent: Agent) -> None:
"""Inject A2A server methods onto an Agent instance.

View File

@@ -0,0 +1,236 @@
"""AgentCard JWS signing utilities.
This module provides functions for signing and verifying AgentCards using
JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients
to verify the authenticity and integrity of agent card information.
Example:
>>> from crewai.a2a.utils.agent_card_signing import sign_agent_card
>>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1")
>>> card_with_sig = card.model_copy(update={"signatures": [signature]})
"""
from __future__ import annotations
import base64
import json
import logging
from typing import Any, Literal
from a2a.types import AgentCard, AgentCardSignature
import jwt
from pydantic import SecretStr
logger = logging.getLogger(__name__)
SigningAlgorithm = Literal[
"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512"
]
def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes:
"""Normalize private key to bytes format.
Args:
private_key: PEM-encoded private key as string, bytes, or SecretStr.
Returns:
Private key as bytes.
"""
if isinstance(private_key, SecretStr):
private_key = private_key.get_secret_value()
if isinstance(private_key, str):
private_key = private_key.encode()
return private_key
def _serialize_agent_card(agent_card: AgentCard) -> str:
"""Serialize AgentCard to canonical JSON for signing.
Excludes the signatures field to avoid circular reference during signing.
Uses sorted keys and compact separators for deterministic output.
Args:
agent_card: The AgentCard to serialize.
Returns:
Canonical JSON string representation.
"""
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))
def _base64url_encode(data: bytes | str) -> str:
"""Encode data to URL-safe base64 without padding.
Args:
data: Data to encode.
Returns:
URL-safe base64 encoded string without padding.
"""
if isinstance(data, str):
data = data.encode()
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def sign_agent_card(
agent_card: AgentCard,
private_key: str | bytes | SecretStr,
key_id: str | None = None,
algorithm: SigningAlgorithm = "RS256",
) -> AgentCardSignature:
"""Sign an AgentCard using JWS (RFC 7515).
Creates a detached JWS signature for the AgentCard. The signature covers
all fields except the signatures field itself.
Args:
agent_card: The AgentCard to sign.
private_key: PEM-encoded private key (RSA, EC, or RSA-PSS).
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
Returns:
AgentCardSignature with protected header and signature.
Raises:
jwt.exceptions.InvalidKeyError: If the private key is invalid.
ValueError: If the algorithm is not supported for the key type.
Example:
>>> signature = sign_agent_card(
... agent_card,
... private_key_pem="-----BEGIN PRIVATE KEY-----...",
... key_id="my-key-id",
... )
"""
key_bytes = _normalize_private_key(private_key)
payload = _serialize_agent_card(agent_card)
protected_header: dict[str, Any] = {"typ": "JWS"}
if key_id:
protected_header["kid"] = key_id
jws_token = jwt.api_jws.encode(
payload.encode(),
key_bytes,
algorithm=algorithm,
headers=protected_header,
)
parts = jws_token.split(".")
protected_b64 = parts[0]
signature_b64 = parts[2]
header: dict[str, Any] | None = None
if key_id:
header = {"kid": key_id}
return AgentCardSignature(
protected=protected_b64,
signature=signature_b64,
header=header,
)
def verify_agent_card_signature(
agent_card: AgentCard,
signature: AgentCardSignature,
public_key: str | bytes,
algorithms: list[str] | None = None,
) -> bool:
"""Verify an AgentCard JWS signature.
Validates that the signature was created with the corresponding private key
and that the AgentCard content has not been modified.
Args:
agent_card: The AgentCard to verify.
signature: The AgentCardSignature to validate.
public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms.
Returns:
True if signature is valid, False otherwise.
Example:
>>> is_valid = verify_agent_card_signature(
... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..."
... )
"""
if algorithms is None:
algorithms = [
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
if isinstance(public_key, str):
public_key = public_key.encode()
payload = _serialize_agent_card(agent_card)
payload_b64 = _base64url_encode(payload)
jws_token = f"{signature.protected}.{payload_b64}.{signature.signature}"
try:
jwt.api_jws.decode(
jws_token,
public_key,
algorithms=algorithms,
)
return True
except jwt.InvalidSignatureError:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "invalid_signature"},
)
return False
except jwt.DecodeError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "decode_error", "error": str(e)},
)
return False
except jwt.InvalidAlgorithmError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "algorithm_error", "error": str(e)},
)
return False
def get_key_id_from_signature(signature: AgentCardSignature) -> str | None:
"""Extract the key ID (kid) from an AgentCardSignature.
Checks both the unprotected header and the protected header for the kid claim.
Args:
signature: The AgentCardSignature to extract from.
Returns:
The key ID if present, None otherwise.
"""
if signature.header and "kid" in signature.header:
kid: str = signature.header["kid"]
return kid
try:
protected = signature.protected
padding_needed = 4 - (len(protected) % 4)
if padding_needed != 4:
protected += "=" * padding_needed
protected_json = base64.urlsafe_b64decode(protected).decode()
protected_header: dict[str, Any] = json.loads(protected_json)
return protected_header.get("kid")
except (ValueError, json.JSONDecodeError):
return None
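A hedged round-trip sketch: the throwaway RSA key is generated with the cryptography package (required anyway for PyJWT's RS256 support), and agent_card is assumed to be an AgentCard built elsewhere:

    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    private_pem = key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    )
    public_pem = key.public_key().public_bytes(
        serialization.Encoding.PEM,
        serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    signature = sign_agent_card(agent_card, private_pem, key_id="demo-key")
    assert verify_agent_card_signature(agent_card, signature, public_pem)
    assert get_key_id_from_signature(signature) == "demo-key"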

View File

@@ -0,0 +1,339 @@
"""Content type negotiation for A2A protocol.
This module handles negotiation of input/output MIME types between A2A clients
and servers based on AgentCard capabilities.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Final, Literal, cast
from a2a.types import Part
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent
if TYPE_CHECKING:
from a2a.types import AgentCard, AgentSkill
TEXT_PLAIN: Literal["text/plain"] = "text/plain"
APPLICATION_JSON: Literal["application/json"] = "application/json"
IMAGE_PNG: Literal["image/png"] = "image/png"
IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg"
IMAGE_WILDCARD: Literal["image/*"] = "image/*"
APPLICATION_PDF: Literal["application/pdf"] = "application/pdf"
APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = (
"application/octet-stream"
)
DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
@dataclass
class NegotiatedContentTypes:
"""Result of content type negotiation."""
input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"]
output_modes: Annotated[
list[str], "Negotiated output MIME types the server will produce"
]
effective_input_modes: Annotated[list[str], "Server's effective input modes"]
effective_output_modes: Annotated[list[str], "Server's effective output modes"]
skill_name: Annotated[
str | None, "Skill name if negotiation was skill-specific"
] = None
class ContentTypeNegotiationError(Exception):
"""Raised when no compatible content types can be negotiated."""
def __init__(
self,
client_input_modes: list[str],
client_output_modes: list[str],
server_input_modes: list[str],
server_output_modes: list[str],
direction: str = "both",
message: str | None = None,
) -> None:
self.client_input_modes = client_input_modes
self.client_output_modes = client_output_modes
self.server_input_modes = server_input_modes
self.server_output_modes = server_output_modes
self.direction = direction
if message is None:
if direction == "input":
message = (
f"No compatible input content types. "
f"Client supports: {client_input_modes}, "
f"Server accepts: {server_input_modes}"
)
elif direction == "output":
message = (
f"No compatible output content types. "
f"Client accepts: {client_output_modes}, "
f"Server produces: {server_output_modes}"
)
else:
message = (
f"No compatible content types. "
f"Input - Client: {client_input_modes}, Server: {server_input_modes}. "
f"Output - Client: {client_output_modes}, Server: {server_output_modes}"
)
super().__init__(message)
def _normalize_mime_type(mime_type: str) -> str:
"""Normalize MIME type for comparison (lowercase, strip whitespace)."""
return mime_type.lower().strip()
def _mime_types_compatible(client_type: str, server_type: str) -> bool:
"""Check if two MIME types are compatible.
Handles wildcards like image/* matching image/png.
"""
client_normalized = _normalize_mime_type(client_type)
server_normalized = _normalize_mime_type(server_type)
if client_normalized == server_normalized:
return True
if "*" in client_normalized or "*" in server_normalized:
client_parts = client_normalized.split("/")
server_parts = server_normalized.split("/")
if len(client_parts) == 2 and len(server_parts) == 2:
type_match = (
client_parts[0] == server_parts[0]
or client_parts[0] == "*"
or server_parts[0] == "*"
)
subtype_match = (
client_parts[1] == server_parts[1]
or client_parts[1] == "*"
or server_parts[1] == "*"
)
return type_match and subtype_match
return False
def _find_compatible_modes(
client_modes: list[str], server_modes: list[str]
) -> list[str]:
"""Find compatible MIME types between client and server.
Returns modes in client preference order.
"""
compatible = []
for client_mode in client_modes:
for server_mode in server_modes:
if _mime_types_compatible(client_mode, server_mode):
if "*" in client_mode and "*" not in server_mode:
if server_mode not in compatible:
compatible.append(server_mode)
else:
if client_mode not in compatible:
compatible.append(client_mode)
break
return compatible
def _get_effective_modes(
agent_card: AgentCard,
skill_name: str | None = None,
) -> tuple[list[str], list[str], AgentSkill | None]:
"""Get effective input/output modes from agent card.
If skill_name is provided and the skill has custom modes, those are used.
Otherwise, falls back to agent card defaults.
"""
skill: AgentSkill | None = None
if skill_name and agent_card.skills:
for s in agent_card.skills:
if s.name == skill_name or s.id == skill_name:
skill = s
break
if skill:
input_modes = (
skill.input_modes if skill.input_modes else agent_card.default_input_modes
)
output_modes = (
skill.output_modes
if skill.output_modes
else agent_card.default_output_modes
)
else:
input_modes = agent_card.default_input_modes
output_modes = agent_card.default_output_modes
return input_modes, output_modes, skill
def negotiate_content_types(
agent_card: AgentCard,
client_input_modes: list[str] | None = None,
client_output_modes: list[str] | None = None,
skill_name: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
strict: bool = False,
) -> NegotiatedContentTypes:
"""Negotiate content types between client and server.
Args:
agent_card: The remote agent's card with capability info.
client_input_modes: MIME types the client can send. Defaults to text/plain and application/json.
client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json.
skill_name: Optional skill to use for mode lookup.
emit_event: Whether to emit a content type negotiation event.
endpoint: Agent endpoint (for event metadata).
a2a_agent_name: Agent name (for event metadata).
strict: If True, raises error when no compatible types found.
If False, returns empty lists for incompatible directions.
Returns:
NegotiatedContentTypes with compatible input and output modes.
Raises:
ContentTypeNegotiationError: If strict=True and no compatible types found.
"""
if client_input_modes is None:
client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy())
if client_output_modes is None:
client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy())
server_input_modes, server_output_modes, skill = _get_effective_modes(
agent_card, skill_name
)
compatible_input = _find_compatible_modes(client_input_modes, server_input_modes)
compatible_output = _find_compatible_modes(client_output_modes, server_output_modes)
if strict:
if not compatible_input and not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
)
if not compatible_input:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="input",
)
if not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="output",
)
result = NegotiatedContentTypes(
input_modes=compatible_input,
output_modes=compatible_output,
effective_input_modes=server_input_modes,
effective_output_modes=server_output_modes,
skill_name=skill.name if skill else None,
)
if emit_event:
crewai_event_bus.emit(
None,
A2AContentTypeNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
skill_name=skill_name,
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
negotiated_input_modes=compatible_input,
negotiated_output_modes=compatible_output,
negotiation_success=bool(compatible_input and compatible_output),
),
)
return result
def validate_content_type(
content_type: str,
allowed_modes: list[str],
) -> bool:
"""Validate that a content type is allowed by a list of modes.
Args:
content_type: The MIME type to validate.
allowed_modes: List of allowed MIME types (may include wildcards).
Returns:
True if content_type is compatible with any allowed mode.
"""
for mode in allowed_modes:
if _mime_types_compatible(content_type, mode):
return True
return False
def get_part_content_type(part: Part) -> str:
"""Extract MIME type from an A2A Part.
Args:
part: A Part object containing TextPart, DataPart, or FilePart.
Returns:
The MIME type string for this part.
"""
root = part.root
if root.kind == "text":
return TEXT_PLAIN
if root.kind == "data":
return APPLICATION_JSON
if root.kind == "file":
return root.file.mime_type or APPLICATION_OCTET_STREAM
return APPLICATION_OCTET_STREAM
def validate_message_parts(
parts: list[Part],
allowed_modes: list[str],
) -> list[str]:
"""Validate that all message parts have allowed content types.
Args:
parts: List of Parts from the incoming message.
allowed_modes: List of allowed MIME types (from default_input_modes).
Returns:
List of invalid content types found (empty if all valid).
"""
invalid_types: list[str] = []
for part in parts:
content_type = get_part_content_type(part)
if not validate_content_type(content_type, allowed_modes):
if content_type not in invalid_types:
invalid_types.append(content_type)
return invalid_types
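A few illustrative checks of the matching rules defined above:

    # Wildcards match within the same top-level type.
    assert _mime_types_compatible("image/*", "image/png")
    assert not _mime_types_compatible("image/*", "application/json")

    # Client preference order is kept; a client wildcard resolves to the
    # server's concrete type.
    assert _find_compatible_modes(
        ["text/plain", "image/*"], ["image/png", "text/plain"]
    ) == ["text/plain", "image/png"]

    assert validate_content_type("image/jpeg", ["image/*"])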

View File

@@ -3,9 +3,10 @@
from __future__ import annotations
import asyncio
- from collections.abc import AsyncIterator, MutableMapping
+ from collections.abc import AsyncIterator, Callable, MutableMapping
from contextlib import asynccontextmanager
- from typing import TYPE_CHECKING, Any, Literal
+ import logging
+ from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
from a2a.client import Client, ClientConfig, ClientFactory
@@ -20,18 +21,24 @@ from a2a.types import (
import httpx
from pydantic import BaseModel
- from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
+ from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
configure_auth_client,
validate_auth_against_agent_card,
)
from crewai.a2a.config import ClientTransportConfig, GRPCClientConfig
from crewai.a2a.extensions.registry import (
ExtensionsMiddleware,
validate_required_extensions,
)
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.types import (
HANDLER_REGISTRY,
HandlerType,
PartsDict,
PartsMetadataDict,
TransportType,
)
from crewai.a2a.updates import (
PollingConfig,
@@ -39,7 +46,20 @@ from crewai.a2a.updates import (
StreamingHandler,
UpdateConfig,
)
from crewai.a2a.utils.agent_card import _afetch_agent_card_cached
from crewai.a2a.utils.agent_card import (
_afetch_agent_card_cached,
_get_tls_verify,
_prepare_auth_headers,
)
from crewai.a2a.utils.content_type import (
DEFAULT_CLIENT_OUTPUT_MODES,
negotiate_content_types,
)
from crewai.a2a.utils.transport import (
NegotiatedTransport,
TransportNegotiationError,
negotiate_transport,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationStartedEvent,
@@ -49,10 +69,16 @@ from crewai.events.types.a2a_events import (
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from a2a.types import Message
- from crewai.a2a.auth.schemas import AuthScheme
+ from crewai.a2a.auth.client_schemes import ClientAuthScheme
_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC"
def get_handler(config: UpdateConfig | None) -> HandlerType:
@@ -71,8 +97,7 @@ def get_handler(config: UpdateConfig | None) -> HandlerType:
def execute_a2a_delegation(
endpoint: str,
- transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
- auth: AuthScheme | None,
+ auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
@@ -91,32 +116,23 @@ def execute_a2a_delegation(
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent synchronously.
- This is the sync wrapper around aexecute_a2a_delegation. For async contexts,
- use aexecute_a2a_delegation directly.
- WARNING: This function blocks the entire thread by creating and running a new
- event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts
- for better performance and resource efficiency.
+ This is a synchronous wrapper around aexecute_a2a_delegation that creates a
+ new event loop to run the async implementation. It is provided for compatibility
+ with synchronous code paths only.
Args:
- endpoint: A2A agent endpoint URL (AgentCard URL)
- transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
- auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest)
- timeout: Request timeout in seconds
- task_description: The task to delegate
- context: Optional context information
- context_id: Context ID for correlating messages/tasks
- task_id: Specific task identifier
- reference_task_ids: List of related task IDs
- metadata: Additional metadata (external_id, request_id, etc.)
- extensions: Protocol extensions for custom fields
- conversation_history: Previous Message objects from conversation
- agent_id: Agent identifier for logging
- agent_role: Role of the CrewAI agent delegating the task
- agent_branch: Optional agent tree branch for logging
- response_model: Optional Pydantic model for structured outputs
- turn_number: Optional turn number for multi-turn conversations
- endpoint: A2A agent endpoint URL.
- auth: Optional AuthScheme for authentication.
+ endpoint: A2A agent endpoint URL (AgentCard URL).
+ auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
task_description: The task to delegate.
context: Optional context information.
@@ -135,10 +151,26 @@ def execute_a2a_delegation(
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
Raises:
RuntimeError: If called from an async context with a running event loop.
"""
try:
asyncio.get_running_loop()
raise RuntimeError(
"execute_a2a_delegation() cannot be called from an async context. "
"Use 'await aexecute_a2a_delegation()' instead."
)
except RuntimeError as e:
if "no running event loop" not in str(e).lower():
raise
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
@@ -159,12 +191,14 @@ def execute_a2a_delegation(
agent_role=agent_role,
agent_branch=agent_branch,
response_model=response_model,
- transport_protocol=transport_protocol,
turn_number=turn_number,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
+ client_extensions=client_extensions,
+ transport=transport,
+ accepted_output_modes=accepted_output_modes,
)
)
finally:
@@ -176,8 +210,7 @@ def execute_a2a_delegation(
async def aexecute_a2a_delegation(
endpoint: str,
- transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
- auth: AuthScheme | None,
+ auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
@@ -196,6 +229,9 @@ async def aexecute_a2a_delegation(
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent asynchronously.
@@ -203,25 +239,8 @@ async def aexecute_a2a_delegation(
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
Args:
- endpoint: A2A agent endpoint URL
- transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
- auth: Optional AuthScheme for authentication
- timeout: Request timeout in seconds
- task_description: Task to delegate
- context: Optional context
- context_id: Context ID for correlation
- task_id: Specific task identifier
- reference_task_ids: Related task IDs
- metadata: Additional metadata
- extensions: Protocol extensions
- conversation_history: Previous Message objects
- turn_number: Current turn number
- agent_branch: Agent tree branch for logging
- agent_id: Agent identifier for logging
- agent_role: Agent role for logging
- response_model: Optional Pydantic model for structured outputs
- auth: Optional AuthScheme for authentication.
+ endpoint: A2A agent endpoint URL.
+ auth: Optional ClientAuthScheme for authentication.
+ timeout: Request timeout in seconds.
+ task_description: The task to delegate.
+ context: Optional context information.
@@ -240,6 +259,9 @@ async def aexecute_a2a_delegation(
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
@@ -271,10 +293,12 @@ async def aexecute_a2a_delegation(
agent_role=agent_role,
response_model=response_model,
updates=updates,
- transport_protocol=transport_protocol,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
+ client_extensions=client_extensions,
+ transport=transport,
+ accepted_output_modes=accepted_output_modes,
)
except Exception as e:
crewai_event_bus.emit(
@@ -294,7 +318,7 @@ async def aexecute_a2a_delegation(
)
raise
- agent_card_data: dict[str, Any] = result.get("agent_card") or {}
+ agent_card_data = result.get("agent_card")
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
@@ -306,7 +330,7 @@ async def aexecute_a2a_delegation(
endpoint=endpoint,
a2a_agent_name=result.get("a2a_agent_name"),
agent_card=agent_card_data,
provider=agent_card_data.get("provider"),
provider=agent_card_data.get("provider") if agent_card_data else None,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
@@ -319,8 +343,7 @@ async def aexecute_a2a_delegation(
async def _aexecute_a2a_delegation_impl(
endpoint: str,
- transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
- auth: AuthScheme | None,
+ auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None,
@@ -340,8 +363,13 @@ async def _aexecute_a2a_delegation_impl(
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
) -> TaskStateResult:
"""Internal async implementation of A2A delegation."""
if transport is None:
transport = ClientTransportConfig()
if auth:
auth_data = auth.model_dump_json(
exclude={
@@ -351,22 +379,70 @@ async def _aexecute_a2a_delegation_impl(
"_authorization_callback",
}
)
- auth_hash = hash((type(auth).__name__, auth_data))
+ auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
- auth_hash = 0
- _auth_store[auth_hash] = auth
+ auth_hash = _auth_store.compute_key("none", endpoint)
+ _auth_store.set(auth_hash, auth)
agent_card = await _afetch_agent_card_cached(
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
)
validate_auth_against_agent_card(agent_card, auth)
- headers: MutableMapping[str, str] = {}
- if auth:
- async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
- if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
- configure_auth_client(auth, temp_auth_client)
- headers = await auth.apply_auth(temp_auth_client, {})
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
if unsupported_exts:
ext_uris = [ext.uri for ext in unsupported_exts]
raise ValueError(
f"Agent requires extensions not supported by client: {ext_uris}"
)
negotiated: NegotiatedTransport | None = None
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
effective_url = endpoint
client_transports: list[str] = (
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
)
try:
negotiated = negotiate_transport(
agent_card=agent_card,
client_supported_transports=client_transports,
client_preferred_transport=transport.preferred,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
effective_transport = negotiated.transport # type: ignore[assignment]
effective_url = negotiated.url
except TransportNegotiationError as e:
logger.warning(
"Transport negotiation failed, using fallback",
extra={
"error": str(e),
"fallback_transport": effective_transport,
"fallback_url": effective_url,
"endpoint": endpoint,
"client_transports": client_transports,
"server_transports": [
iface.transport for iface in agent_card.additional_interfaces or []
]
+ [agent_card.preferred_transport or "JSONRPC"],
},
)
effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy()
content_negotiated = negotiate_content_types(
agent_card=agent_card,
client_output_modes=accepted_output_modes,
skill_name=skill_id,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
if content_negotiated.output_modes:
effective_output_modes = content_negotiated.output_modes
headers, _ = await _prepare_auth_headers(auth, timeout)
a2a_agent_name = None
if agent_card.name:
@@ -513,15 +589,22 @@ async def _aexecute_a2a_delegation_impl(
use_streaming = not use_polling and push_config_for_client is None
client_agent_card = agent_card
if effective_url != agent_card.url:
client_agent_card = agent_card.model_copy(update={"url": effective_url})
async with _create_a2a_client(
- agent_card=agent_card,
- transport_protocol=transport_protocol,
+ agent_card=client_agent_card,
+ transport_protocol=effective_transport,
timeout=timeout,
headers=headers,
streaming=use_streaming,
auth=auth,
use_polling=use_polling,
push_notification_config=push_config_for_client,
client_extensions=client_extensions,
accepted_output_modes=effective_output_modes, # type: ignore[arg-type]
grpc_config=transport.grpc,
) as client:
result = await handler.execute(
client=client,
@@ -535,6 +618,245 @@ async def _aexecute_a2a_delegation_impl(
return result
def _normalize_grpc_metadata(
metadata: tuple[tuple[str, str], ...] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Lowercase all gRPC metadata keys.
gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK)
use mixed-case headers like 'X-A2A-Extensions'. This normalizes them.
"""
if metadata is None:
return None
return tuple((key.lower(), value) for key, value in metadata)
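For example, the SDK's mixed-case extensions header becomes:

    md = (("X-A2A-Extensions", "https://example.com/ext/v1"), ("Authorization", "Bearer t"))
    assert _normalize_grpc_metadata(md) == (
        ("x-a2a-extensions", "https://example.com/ext/v1"),
        ("authorization", "Bearer t"),
    )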
def _create_grpc_interceptors(
auth_metadata: list[tuple[str, str]] | None = None,
) -> list[Any]:
"""Create gRPC interceptors for metadata normalization and auth injection.
Args:
auth_metadata: Optional auth metadata to inject into all calls.
Used for insecure channels that need auth (non-localhost without TLS).
Returns a list of interceptors that lowercase metadata keys for gRPC
compatibility. Must be called after grpc is imported.
"""
import grpc.aio # type: ignore[import-untyped]
def _merge_metadata(
existing: tuple[tuple[str, str], ...] | None,
auth: list[tuple[str, str]] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Merge existing metadata with auth metadata and normalize keys."""
merged: list[tuple[str, str]] = []
if existing:
merged.extend(existing)
if auth:
merged.extend(auth)
if not merged:
return None
return tuple((key.lower(), value) for key, value in merged)
def _inject_metadata(client_call_details: Any) -> Any:
"""Inject merged metadata into call details."""
return client_call_details._replace(
metadata=_merge_metadata(client_call_details.metadata, auth_metadata)
)
class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-unary calls that injects auth metadata."""
async def intercept_unary_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-unary call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-stream calls that injects auth metadata."""
async def intercept_unary_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-stream call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-unary calls that injects auth metadata."""
async def intercept_stream_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-unary call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-stream calls that injects auth metadata."""
async def intercept_stream_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-stream call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
return [
MetadataUnaryUnary(),
MetadataUnaryStream(),
MetadataStreamUnary(),
MetadataStreamStream(),
]
def _create_grpc_channel_factory(
grpc_config: GRPCClientConfig,
auth: ClientAuthScheme | None = None,
) -> Callable[[str], Any]:
"""Create a gRPC channel factory with the given configuration.
Args:
grpc_config: gRPC client configuration with channel options.
auth: Optional ClientAuthScheme for TLS and auth configuration.
Returns:
A callable that creates gRPC channels from URLs.
"""
try:
import grpc
except ImportError as e:
raise ImportError(
"gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
) from e
auth_metadata: list[tuple[str, str]] = []
if auth is not None:
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
BearerTokenAuth,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
)
if isinstance(auth, HTTPDigestAuth):
raise ValueError(
"HTTPDigestAuth is not supported with gRPC transport. "
"Digest authentication requires HTTP challenge-response flow. "
"Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
)
if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
raise ValueError(
f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
"gRPC only supports header-based authentication. "
"Use APIKeyAuth with location='header' instead."
)
if isinstance(auth, BearerTokenAuth):
auth_metadata.append(("authorization", f"Bearer {auth.token}"))
elif isinstance(auth, HTTPBasicAuth):
import base64
basic_credentials = f"{auth.username}:{auth.password}"
encoded = base64.b64encode(basic_credentials.encode()).decode()
auth_metadata.append(("authorization", f"Basic {encoded}"))
elif isinstance(auth, APIKeyAuth) and auth.location == "header":
header_name = auth.name.lower()
auth_metadata.append((header_name, auth.api_key))
elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
if auth._access_token:
auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))
def factory(url: str) -> Any:
"""Create a gRPC channel for the given URL."""
target = url
use_tls = False
if url.startswith("grpcs://"):
target = url[8:]
use_tls = True
elif url.startswith("grpc://"):
target = url[7:]
elif url.startswith("https://"):
target = url[8:]
use_tls = True
elif url.startswith("http://"):
target = url[7:]
options: list[tuple[str, Any]] = []
if grpc_config.max_send_message_length is not None:
options.append(
("grpc.max_send_message_length", grpc_config.max_send_message_length)
)
if grpc_config.max_receive_message_length is not None:
options.append(
(
"grpc.max_receive_message_length",
grpc_config.max_receive_message_length,
)
)
if grpc_config.keepalive_time_ms is not None:
options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
if grpc_config.keepalive_timeout_ms is not None:
options.append(
("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
)
channel_credentials = None
if auth and hasattr(auth, "tls") and auth.tls:
channel_credentials = auth.tls.get_grpc_credentials()
elif use_tls:
channel_credentials = grpc.ssl_channel_credentials()
if channel_credentials and auth_metadata:
class AuthMetadataPlugin(grpc.AuthMetadataPlugin): # type: ignore[misc,no-any-unimported]
"""gRPC auth metadata plugin that adds auth headers as metadata."""
def __init__(self, metadata: list[tuple[str, str]]) -> None:
self._metadata = tuple(metadata)
def __call__( # type: ignore[no-any-unimported]
self,
context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
) -> None:
callback(self._metadata, None)
call_creds = grpc.metadata_call_credentials(
AuthMetadataPlugin(auth_metadata)
)
credentials = grpc.composite_channel_credentials(
channel_credentials, call_creds
)
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target, credentials, options=options or None, interceptors=interceptors
)
if channel_credentials:
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target,
channel_credentials,
options=options or None,
interceptors=interceptors,
)
interceptors = _create_grpc_interceptors(
auth_metadata=auth_metadata if auth_metadata else None
)
return grpc.aio.insecure_channel(
target, options=options or None, interceptors=interceptors
)
return factory
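Hypothetical wiring (the target address and keepalive value are placeholders; requires grpcio):

    factory = _create_grpc_channel_factory(
        GRPCClientConfig(keepalive_time_ms=30_000),
        auth=None,
    )
    # grpc:// (or plain host:port) yields an insecure channel with the
    # metadata-normalizing interceptors installed.
    channel = factory("grpc://localhost:50051")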
@asynccontextmanager
async def _create_a2a_client(
agent_card: AgentCard,
@@ -542,9 +864,12 @@ async def _create_a2a_client(
timeout: int,
headers: MutableMapping[str, str],
streaming: bool,
- auth: AuthScheme | None = None,
+ auth: ClientAuthScheme | None = None,
use_polling: bool = False,
push_notification_config: PushNotificationConfig | None = None,
client_extensions: list[str] | None = None,
accepted_output_modes: list[str] | None = None,
grpc_config: GRPCClientConfig | None = None,
) -> AsyncIterator[Client]:
"""Create and configure an A2A client.
@@ -554,16 +879,21 @@ async def _create_a2a_client(
timeout: Request timeout in seconds.
headers: HTTP headers (already with auth applied).
streaming: Enable streaming responses.
- auth: Optional AuthScheme for client configuration.
+ auth: Optional ClientAuthScheme for client configuration.
use_polling: Enable polling mode.
push_notification_config: Optional push notification config.
client_extensions: A2A protocol extension URIs to declare support for.
accepted_output_modes: MIME types the client can accept in responses.
grpc_config: Optional gRPC client configuration.
Yields:
Configured A2A client instance.
"""
verify = _get_tls_verify(auth)
async with httpx.AsyncClient(
timeout=timeout,
headers=headers,
verify=verify,
) as httpx_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, httpx_client)
@@ -579,15 +909,27 @@ async def _create_a2a_client(
)
)
grpc_channel_factory = None
if transport_protocol == "GRPC":
grpc_channel_factory = _create_grpc_channel_factory(
grpc_config or GRPCClientConfig(),
auth=auth,
)
config = ClientConfig(
httpx_client=httpx_client,
supported_transports=[transport_protocol],
streaming=streaming and not use_polling,
polling=use_polling,
accepted_output_modes=["application/json"],
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
push_notification_configs=push_configs,
grpc_channel_factory=grpc_channel_factory,
)
factory = ClientFactory(config)
client = factory.create(agent_card)
if client_extensions:
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
yield client

View File

@@ -0,0 +1,131 @@
"""Structured JSON logging utilities for A2A module."""
from __future__ import annotations
from contextvars import ContextVar
from datetime import datetime, timezone
import json
import logging
from typing import Any
_log_context: ContextVar[dict[str, Any] | None] = ContextVar(
"log_context", default=None
)
class JSONFormatter(logging.Formatter):
"""JSON formatter for structured logging.
Outputs logs as JSON with consistent fields for log aggregators.
"""
def format(self, record: logging.LogRecord) -> str:
"""Format log record as JSON string."""
log_data: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
if record.exc_info:
log_data["exception"] = self.formatException(record.exc_info)
context = _log_context.get()
if context is not None:
log_data.update(context)
if hasattr(record, "task_id"):
log_data["task_id"] = record.task_id
if hasattr(record, "context_id"):
log_data["context_id"] = record.context_id
if hasattr(record, "agent"):
log_data["agent"] = record.agent
if hasattr(record, "endpoint"):
log_data["endpoint"] = record.endpoint
if hasattr(record, "extension"):
log_data["extension"] = record.extension
if hasattr(record, "error"):
log_data["error"] = record.error
for key, value in record.__dict__.items():
if key.startswith("_") or key in (
"name",
"msg",
"args",
"created",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"exc_info",
"exc_text",
"thread",
"threadName",
"taskName",
"message",
):
continue
if key not in log_data:
log_data[key] = value
return json.dumps(log_data, default=str)
class LogContext:
"""Context manager for adding fields to all logs within a scope.
Example:
with LogContext(task_id="abc", context_id="xyz"):
logger.info("Processing task") # Includes task_id and context_id
"""
def __init__(self, **fields: Any) -> None:
self._fields = fields
self._token: Any = None
def __enter__(self) -> LogContext:
current = _log_context.get() or {}
new_context = {**current, **self._fields}
self._token = _log_context.set(new_context)
return self
def __exit__(self, *args: Any) -> None:
_log_context.reset(self._token)
def configure_json_logging(logger_name: str = "crewai.a2a") -> None:
"""Configure JSON logging for the A2A module.
Args:
logger_name: Logger name to configure.
"""
logger = logging.getLogger(logger_name)
for handler in logger.handlers[:]:
logger.removeHandler(handler)
handler = logging.StreamHandler()
handler.setFormatter(JSONFormatter())
logger.addHandler(handler)
def get_logger(name: str) -> logging.Logger:
"""Get a logger configured for structured JSON output.
Args:
name: Logger name.
Returns:
Configured logger instance.
"""
return logging.getLogger(name)
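
A minimal usage sketch of the utilities above (the endpoint value is illustrative):

import logging

configure_json_logging("crewai.a2a")
logging.getLogger("crewai.a2a").setLevel(logging.INFO)  # ensure INFO records are emitted
logger = get_logger("crewai.a2a.server")  # child logger; records propagate to the JSON handler

with LogContext(task_id="abc", context_id="xyz"):
    # The emitted JSON line carries task_id, context_id, and the extra endpoint field.
    logger.info("Processing task", extra={"endpoint": "https://agent.example/a2a"})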

View File

@@ -7,26 +7,37 @@ import base64
from collections.abc import Callable, Coroutine
from datetime import datetime
from functools import wraps
import json
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast
from urllib.parse import urlparse
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import (
Artifact,
InternalError,
InvalidParamsError,
Message,
Part,
Task as A2ATask,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils import new_agent_text_message, new_text_artifact
from a2a.utils import (
get_data_parts,
new_agent_text_message,
new_data_artifact,
new_text_artifact,
)
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from pydantic import BaseModel
from crewai.a2a.utils.agent_card import _get_server_config
from crewai.a2a.utils.content_type import validate_message_parts
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AServerTaskCanceledEvent,
@@ -35,9 +46,11 @@ from crewai.events.types.a2a_events import (
A2AServerTaskStartedEvent,
)
from crewai.task import Task
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
if TYPE_CHECKING:
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
from crewai.agent import Agent
@@ -47,7 +60,17 @@ P = ParamSpec("P")
T = TypeVar("T")
def _parse_redis_url(url: str) -> dict[str, Any]:
class RedisCacheConfig(TypedDict, total=False):
"""Configuration for aiocache Redis backend."""
cache: str
endpoint: str
port: int
db: int
password: str
def _parse_redis_url(url: str) -> RedisCacheConfig:
"""Parse a Redis URL into aiocache configuration.
Args:
@@ -56,9 +79,8 @@ def _parse_redis_url(url: str) -> dict[str, Any]:
Returns:
Configuration dict for aiocache.RedisCache.
"""
parsed = urlparse(url)
config: dict[str, Any] = {
config: RedisCacheConfig = {
"cache": "aiocache.RedisCache",
"endpoint": parsed.hostname or "localhost",
"port": parsed.port or 6379,
@@ -138,7 +160,10 @@ def cancellable(
if message["type"] == "message":
return True
except (OSError, ConnectionError) as e:
logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
logger.warning(
"Cancel watcher Redis error, falling back to polling",
extra={"task_id": task_id, "error": str(e)},
)
return await poll_for_cancel()
return False
@@ -166,7 +191,67 @@ def cancellable(
return wrapper
@cancellable
def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
"""Extract response schema from message parts metadata.
The client may include a JSON schema in TextPart metadata to specify
the expected response format (see delegation.py line 463).
Args:
parts: List of message parts.
Returns:
JSON schema dict if found, None otherwise.
"""
for part in parts:
if part.root.kind == "text" and part.root.metadata:
schema = part.root.metadata.get("schema")
if schema and isinstance(schema, dict):
return schema # type: ignore[no-any-return]
return None
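
A sketch of the client-side convention this helper reads; the exact part shape is assumed from the attribute checks above:

# The delegating client attaches a JSON schema to a text part's metadata, e.g.:
#   Part(root=TextPart(text="Summarize...", metadata={"schema": {"title": "Answer", "type": "object"}}))
# _extract_response_schema returns that dict, and execute() below turns it into a
# pydantic model via create_model_from_schema.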
def _create_result_artifact(
result: Any,
task_id: str,
) -> Artifact:
"""Create artifact from task result, using DataPart for structured data.
Args:
result: The task execution result.
task_id: The task ID for naming the artifact.
Returns:
Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings).
"""
artifact_name = f"result_{task_id}"
if isinstance(result, dict):
return new_data_artifact(artifact_name, result)
if isinstance(result, BaseModel):
return new_data_artifact(artifact_name, result.model_dump())
return new_text_artifact(artifact_name, str(result))
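
In brief, the dispatch above (Answer is a hypothetical pydantic model):

# _create_result_artifact({"total": 3}, "t1")     -> DataPart artifact "result_t1"
# _create_result_artifact(Answer(score=7), "t1")  -> DataPart from model_dump()
# _create_result_artifact("plain text", "t1")     -> TextPart artifact "result_t1"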
def _build_task_description(
user_message: str,
structured_inputs: list[dict[str, Any]],
) -> str:
"""Build task description including structured data if present.
Args:
user_message: The original user message text.
structured_inputs: List of structured data from DataParts.
Returns:
Task description with structured data appended if present.
"""
if not structured_inputs:
return user_message
structured_json = json.dumps(structured_inputs, indent=2)
return f"{user_message}\n\nStructured Data:\n{structured_json}"
async def execute(
agent: Agent,
context: RequestContext,
@@ -178,15 +263,52 @@ async def execute(
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
TODOs:
* Implement both structured-output and file-input ingestion; depends on `file_inputs`
support in `crewai.task.Task`. Pass the two values below through to Task; both
helpers live in `a2a.utils.parts`:
* structured inputs: `structured_inputs = get_data_parts(parts=context.message.parts)`
* file inputs: `file_inputs = get_file_parts(parts=context.message.parts)`
"""
await _execute_impl(agent, context, event_queue, None, None)
@cancellable
async def _execute_impl(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry | None,
extension_context: ExtensionContext | None,
) -> None:
"""Internal implementation for task execution with optional extensions."""
server_config = _get_server_config(agent)
if context.message and context.message.parts and server_config:
allowed_modes = server_config.default_input_modes
invalid_types = validate_message_parts(context.message.parts, allowed_modes)
if invalid_types:
raise ServerError(
InvalidParamsError(
message=f"Unsupported content type(s): {', '.join(invalid_types)}. "
f"Supported: {', '.join(allowed_modes)}"
)
)
if extension_registry and extension_context:
await extension_registry.invoke_on_request(extension_context)
user_message = context.get_user_input()
response_model: type[BaseModel] | None = None
structured_inputs: list[dict[str, Any]] = []
if context.message and context.message.parts:
schema = _extract_response_schema(context.message.parts)
if schema:
try:
response_model = create_model_from_schema(schema)
except Exception as e:
logger.debug(
"Failed to create response model from schema",
extra={"error": str(e), "schema_title": schema.get("title")},
)
structured_inputs = get_data_parts(context.message.parts)
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
@@ -203,9 +325,10 @@ async def execute(
raise ServerError(InvalidParamsError(message=msg)) from None
task = Task(
description=user_message,
description=_build_task_description(user_message, structured_inputs),
expected_output="Response to the user's request",
agent=agent,
response_model=response_model,
)
crewai_event_bus.emit(
@@ -220,6 +343,10 @@ async def execute(
try:
result = await agent.aexecute_task(task=task, tools=agent.tools)
if extension_registry and extension_context:
result = await extension_registry.invoke_on_response(
extension_context, result
)
result_str = str(result)
history: list[Message] = [context.message] if context.message else []
history.append(new_agent_text_message(result_str, context_id, task_id))
@@ -227,8 +354,8 @@ async def execute(
A2ATask(
id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.input_required),
artifacts=[new_text_artifact(result_str, f"result_{task_id}")],
status=TaskStatus(state=TaskState.completed),
artifacts=[_create_result_artifact(result, task_id)],
history=history,
)
)
@@ -269,6 +396,27 @@ async def execute(
) from e
async def execute_with_extensions(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry,
extension_context: ExtensionContext,
) -> None:
"""Execute an A2A task with extension hooks.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
extension_registry: Registry of server extensions.
extension_context: Context for extension hooks.
"""
await _execute_impl(
agent, context, event_queue, extension_registry, extension_context
)
async def cancel(
context: RequestContext,
event_queue: EventQueue,

View File

@@ -0,0 +1,215 @@
"""Transport negotiation utilities for A2A protocol.
This module provides functionality for negotiating the transport protocol
between an A2A client and server based on their respective capabilities
and preferences.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Final, Literal
from a2a.types import AgentCard, AgentInterface
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"]
JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC"
GRPC_TRANSPORT: Literal["GRPC"] = "GRPC"
HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON"
DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [
JSONRPC_TRANSPORT,
GRPC_TRANSPORT,
HTTP_JSON_TRANSPORT,
]
@dataclass
class NegotiatedTransport:
"""Result of transport negotiation.
Attributes:
transport: The negotiated transport protocol.
url: The URL to use for this transport.
source: How the transport was selected ('client_preferred', 'server_preferred', 'fallback').
"""
transport: str
url: str
source: NegotiationSource
class TransportNegotiationError(Exception):
"""Raised when no compatible transport can be negotiated."""
def __init__(
self,
client_transports: list[str],
server_transports: list[str],
message: str | None = None,
) -> None:
"""Initialize the error with negotiation details.
Args:
client_transports: Transports supported by the client.
server_transports: Transports supported by the server.
message: Optional custom error message.
"""
self.client_transports = client_transports
self.server_transports = server_transports
if message is None:
message = (
f"No compatible transport found. "
f"Client supports: {client_transports}. "
f"Server supports: {server_transports}."
)
super().__init__(message)
def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
"""Extract all available interfaces from an AgentCard.
Creates a unified list of interfaces including the primary URL and
any additional interfaces declared by the agent.
Args:
agent_card: The agent's card containing transport information.
Returns:
List of AgentInterface objects representing all available endpoints.
"""
interfaces: list[AgentInterface] = []
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
interfaces.append(
AgentInterface(
transport=primary_transport,
url=agent_card.url,
)
)
if agent_card.additional_interfaces:
for interface in agent_card.additional_interfaces:
is_duplicate = any(
i.url == interface.url and i.transport == interface.transport
for i in interfaces
)
if not is_duplicate:
interfaces.append(interface)
return interfaces
def negotiate_transport(
agent_card: AgentCard,
client_supported_transports: list[str] | None = None,
client_preferred_transport: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> NegotiatedTransport:
"""Negotiate the transport protocol between client and server.
Compares the client's supported transports with the server's available
interfaces to find a compatible transport and URL.
Negotiation logic:
1. If client_preferred_transport is set and server supports it → use it
2. Otherwise, if server's preferred is in client's supported → use server's
3. Otherwise, find first match from client's supported in server's interfaces
Args:
agent_card: The server's AgentCard with transport information.
client_supported_transports: Transports the client can use.
Defaults to ["JSONRPC"] if not specified.
client_preferred_transport: Client's preferred transport. If set and
server supports it, takes priority over server preference.
emit_event: Whether to emit a transport negotiation event.
endpoint: Original endpoint URL for event metadata.
a2a_agent_name: Agent name for event metadata.
Returns:
NegotiatedTransport with the selected transport, URL, and source.
Raises:
TransportNegotiationError: If no compatible transport is found.
"""
if client_supported_transports is None:
client_supported_transports = [JSONRPC_TRANSPORT]
client_transports = [t.upper() for t in client_supported_transports]
client_preferred = (
client_preferred_transport.upper() if client_preferred_transport else None
)
server_interfaces = _get_server_interfaces(agent_card)
server_transports = [i.transport.upper() for i in server_interfaces]
transport_to_interface: dict[str, AgentInterface] = {}
for interface in server_interfaces:
transport_upper = interface.transport.upper()
if transport_upper not in transport_to_interface:
transport_to_interface[transport_upper] = interface
result: NegotiatedTransport | None = None
if client_preferred and client_preferred in transport_to_interface:
interface = transport_to_interface[client_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="client_preferred",
)
else:
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
if (
server_preferred in client_transports
and server_preferred in transport_to_interface
):
interface = transport_to_interface[server_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="server_preferred",
)
else:
for transport in client_transports:
if transport in transport_to_interface:
interface = transport_to_interface[transport]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="fallback",
)
break
if result is None:
raise TransportNegotiationError(
client_transports=client_transports,
server_transports=server_transports,
)
if emit_event:
crewai_event_bus.emit(
None,
A2ATransportNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
negotiated_transport=result.transport,
negotiated_url=result.url,
source=result.source,
client_supported_transports=client_transports,
server_supported_transports=server_transports,
server_preferred_transport=agent_card.preferred_transport
or JSONRPC_TRANSPORT,
client_preferred_transport=client_preferred,
),
)
return result
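
A sketch of the negotiation order, using a stand-in card rather than a full a2a.types.AgentCard (which requires more fields than are relevant here); the endpoint URLs are hypothetical:

from dataclasses import dataclass, field

@dataclass
class StubInterface:
    transport: str
    url: str

@dataclass
class StubCard:  # minimal stand-in for AgentCard
    url: str = "https://agent.example/jsonrpc"
    name: str = "demo-agent"
    preferred_transport: str | None = "GRPC"
    additional_interfaces: list[StubInterface] = field(
        default_factory=lambda: [StubInterface("HTTP+JSON", "https://agent.example/rest")]
    )

# With client_supported_transports=["HTTP+JSON", "GRPC"] and no client preference,
# step 2 applies: the server's preferred GRPC is in the client's set, so the result
# would be NegotiatedTransport(transport="GRPC", url=card.url, source="server_preferred").
# With client_preferred_transport="HTTP+JSON", step 1 applies instead and the
# additional interface's URL is selected with source="client_preferred".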

View File

@@ -11,19 +11,23 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
import json
from types import MethodType
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NamedTuple
from a2a.types import Role, TaskState
from pydantic import BaseModel, ValidationError
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
from crewai.a2a.extensions.base import (
A2AExtension,
ConversationState,
ExtensionRegistry,
)
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.templates import (
AVAILABLE_AGENTS_TEMPLATE,
CONVERSATION_TURN_INFO_TEMPLATE,
PREVIOUS_A2A_CONVERSATION_TEMPLATE,
REMOTE_AGENT_COMPLETED_NOTICE,
REMOTE_AGENT_RESPONSE_NOTICE,
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE,
)
from crewai.a2a.types import AgentResponseProtocol
@@ -52,6 +56,42 @@ if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
class DelegationContext(NamedTuple):
"""Context prepared for A2A delegation.
Groups all the values needed to execute a delegation to a remote A2A agent.
"""
a2a_agents: list[A2AConfig | A2AClientConfig]
agent_response_model: type[BaseModel] | None
current_request: str
agent_id: str
agent_config: A2AConfig | A2AClientConfig
context_id: str | None
task_id: str | None
metadata: dict[str, Any] | None
extensions: dict[str, Any] | None
reference_task_ids: list[str]
original_task_description: str
max_turns: int
class DelegationState(NamedTuple):
"""Mutable state for A2A delegation loop.
Groups values that may change during delegation turns.
"""
current_request: str
context_id: str | None
task_id: str | None
reference_task_ids: list[str]
conversation_history: list[Message]
agent_card: AgentCard | None
agent_card_dict: dict[str, Any] | None
agent_name: str | None
def wrap_agent_with_a2a_instance(
agent: Agent, extension_registry: ExtensionRegistry | None = None
) -> None:
@@ -165,7 +205,11 @@ def _fetch_agent_cards_concurrently(
agent_cards: dict[str, AgentCard] = {}
failed_agents: dict[str, str] = {}
with ThreadPoolExecutor(max_workers=len(a2a_agents)) as executor:
if not a2a_agents:
return agent_cards, failed_agents
max_workers = min(len(a2a_agents), 10)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(_fetch_card_from_config, config): config
for config in a2a_agents
@@ -231,7 +275,7 @@ def _execute_task_with_a2a(
finally:
task.description = original_description
task.description, _ = _augment_prompt_with_a2a(
task.description, _, extension_states = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_description,
agent_cards=agent_cards,
@@ -248,7 +292,7 @@ def _execute_task_with_a2a(
if extension_registry and isinstance(agent_response, BaseModel):
agent_response = extension_registry.process_response_with_all(
agent_response, {}
agent_response, extension_states
)
if isinstance(agent_response, BaseModel) and isinstance(
@@ -264,7 +308,7 @@ def _execute_task_with_a2a(
tools=tools,
agent_cards=agent_cards,
original_task_description=original_description,
extension_registry=extension_registry,
_extension_registry=extension_registry,
)
return str(agent_response.message)
@@ -284,8 +328,8 @@ def _augment_prompt_with_a2a(
max_turns: int | None = None,
failed_agents: dict[str, str] | None = None,
extension_registry: ExtensionRegistry | None = None,
remote_task_completed: bool = False,
) -> tuple[str, bool]:
remote_status_notice: str = "",
) -> tuple[str, bool, dict[type[A2AExtension], ConversationState]]:
"""Add A2A delegation instructions to prompt.
Args:
@@ -297,13 +341,14 @@ def _augment_prompt_with_a2a(
max_turns: Maximum allowed turns (from config)
failed_agents: Dictionary mapping failed agent endpoints to error messages
extension_registry: Optional registry of A2A extensions
remote_status_notice: Optional notice about remote agent status to append
Returns:
Tuple of (augmented prompt, disable_structured_output flag)
Tuple of (augmented prompt, disable_structured_output flag, extension_states dict)
"""
if not agent_cards:
return task_description, False
return task_description, False, {}
agents_text = ""
@@ -365,15 +410,11 @@ def _augment_prompt_with_a2a(
warning=warning,
)
completion_notice = ""
if remote_task_completed and conversation_history:
completion_notice = REMOTE_AGENT_COMPLETED_NOTICE
augmented_prompt = f"""{task_description}
IMPORTANT: You have the ability to delegate this task to remote A2A agents.
{agents_text}
{history_text}{turn_info}{completion_notice}
{history_text}{turn_info}{remote_status_notice}
"""
@@ -382,7 +423,7 @@ IMPORTANT: You have the ability to delegate this task to remote A2A agents.
augmented_prompt, extension_states
)
return augmented_prompt, disable_structured_output
return augmented_prompt, disable_structured_output, extension_states
def _parse_agent_response(
@@ -461,12 +502,41 @@ def _handle_max_turns_exceeded(
raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})")
def _emit_delegation_failed(
error_msg: str,
turn_num: int,
from_task: Any | None,
from_agent: Any | None,
endpoint: str | None,
a2a_agent_name: str | None,
agent_card: dict[str, Any] | None,
) -> str:
"""Emit failure event and return formatted error message."""
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="failed",
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
),
)
return f"A2A delegation failed: {error_msg}"
def _process_response_result(
raw_result: str,
disable_structured_output: bool,
turn_num: int,
agent_role: str,
agent_response_model: type[BaseModel] | None,
extension_registry: ExtensionRegistry | None = None,
extension_states: dict[type[A2AExtension], ConversationState] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
@@ -516,6 +586,11 @@ def _process_response_result(
raw_result=raw_result, agent_response_model=agent_response_model
)
if extension_registry and isinstance(llm_response, BaseModel):
llm_response = extension_registry.process_response_with_all(
llm_response, extension_states or {}
)
if isinstance(llm_response, BaseModel) and isinstance(
llm_response, AgentResponseProtocol
):
@@ -571,72 +646,103 @@ def _prepare_agent_cards_dict(
return agent_cards_dict
def _init_delegation_state(
ctx: DelegationContext,
agent_cards: dict[str, AgentCard] | None,
) -> DelegationState:
"""Initialize delegation state from context and agent cards.
Args:
ctx: Delegation context with config and settings.
agent_cards: Pre-fetched agent cards.
Returns:
Initial delegation state for the conversation loop.
"""
current_agent_card = agent_cards.get(ctx.agent_id) if agent_cards else None
return DelegationState(
current_request=ctx.current_request,
context_id=ctx.context_id,
task_id=ctx.task_id,
reference_task_ids=list(ctx.reference_task_ids),
conversation_history=[],
agent_card=current_agent_card,
agent_card_dict=current_agent_card.model_dump() if current_agent_card else None,
agent_name=current_agent_card.name if current_agent_card else None,
)
def _get_turn_context(
agent_config: A2AConfig | A2AClientConfig,
) -> tuple[Any | None, list[str] | None]:
"""Get context for a delegation turn.
Returns:
Tuple of (agent_branch, accepted_output_modes).
"""
console_formatter = getattr(crewai_event_bus, "_console", None)
agent_branch = None
if console_formatter:
agent_branch = getattr(
console_formatter, "current_agent_branch", None
) or getattr(console_formatter, "current_task_branch", None)
accepted_output_modes = None
if isinstance(agent_config, A2AClientConfig):
accepted_output_modes = agent_config.accepted_output_modes
return agent_branch, accepted_output_modes
def _prepare_delegation_context(
self: Agent,
agent_response: AgentResponseProtocol,
task: Task,
original_task_description: str | None,
) -> tuple[
list[A2AConfig | A2AClientConfig],
type[BaseModel] | None,
str,
str,
A2AConfig | A2AClientConfig,
str | None,
str | None,
dict[str, Any] | None,
dict[str, Any] | None,
list[str],
str,
int,
]:
) -> DelegationContext:
"""Prepare delegation context from agent response and task.
Shared logic for both sync and async delegation.
Returns:
Tuple containing all the context values needed for delegation.
DelegationContext with all values needed for delegation.
"""
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
agent_ids = tuple(config.endpoint for config in a2a_agents)
current_request = str(agent_response.message)
if hasattr(agent_response, "a2a_ids") and agent_response.a2a_ids:
if not a2a_agents:
raise ValueError("No A2A agents configured for delegation")
if isinstance(agent_response, AgentResponseProtocol) and agent_response.a2a_ids:
agent_id = agent_response.a2a_ids[0]
else:
agent_id = agent_ids[0] if agent_ids else ""
agent_id = agent_ids[0]
if agent_id and agent_id not in agent_ids:
raise ValueError(
f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}"
)
if agent_id not in agent_ids:
raise ValueError(f"Unknown A2A agent ID: {agent_id} not in {agent_ids}")
agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents))
agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents), None)
if agent_config is None:
raise ValueError(f"Agent configuration not found for endpoint: {agent_id}")
task_config = task.config or {}
context_id = task_config.get("context_id")
task_id_config = task_config.get("task_id")
metadata = task_config.get("metadata")
extensions = task_config.get("extensions")
reference_task_ids = task_config.get("reference_task_ids", [])
if original_task_description is None:
original_task_description = task.description
max_turns = agent_config.max_turns
return (
a2a_agents,
agent_response_model,
current_request,
agent_id,
agent_config,
context_id,
task_id_config,
metadata,
extensions,
reference_task_ids,
original_task_description,
max_turns,
return DelegationContext(
a2a_agents=a2a_agents,
agent_response_model=agent_response_model,
current_request=current_request,
agent_id=agent_id,
agent_config=agent_config,
context_id=task_config.get("context_id"),
task_id=task_config.get("task_id"),
metadata=task_config.get("metadata"),
extensions=task_config.get("extensions"),
reference_task_ids=task_config.get("reference_task_ids", []),
original_task_description=original_task_description,
max_turns=agent_config.max_turns,
)
@@ -652,20 +758,50 @@ def _handle_task_completion(
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> tuple[str | None, str | None, list[str]]:
) -> tuple[str | None, str | None, list[str], str]:
"""Handle task completion state including reference task updates.
When a remote task completes, this function:
1. Adds the completed task_id to reference_task_ids (if not already present)
2. Clears task_id_config to signal that a new task ID should be generated for next turn
3. Updates task.config with the reference list for subsequent A2A calls
The reference_task_ids list tracks all completed tasks in this conversation chain,
allowing the remote agent to maintain context across multi-turn interactions.
Shared logic for both sync and async delegation.
Args:
a2a_result: Result from A2A delegation containing task status.
task: CrewAI Task object to update with reference IDs.
task_id_config: Current task ID (will be added to references if task completed).
reference_task_ids: Mutable list of completed task IDs (updated in place).
agent_config: A2A configuration with trust settings.
turn_num: Current turn number.
from_task: Optional CrewAI Task for event metadata.
from_agent: Optional CrewAI Agent for event metadata.
endpoint: A2A endpoint URL.
a2a_agent_name: Name of remote A2A agent.
agent_card: Agent card dict for event metadata.
Returns:
Tuple of (result_if_trusted, updated_task_id, updated_reference_task_ids).
Tuple of (result_if_trusted, updated_task_id, updated_reference_task_ids, remote_notice).
- result_if_trusted: Final result if trust_remote_completion_status=True, else None
- updated_task_id: None (cleared to generate new ID for next turn)
- updated_reference_task_ids: The mutated list with completed task added
- remote_notice: Template notice about remote agent response
"""
remote_notice = ""
if a2a_result["status"] == TaskState.completed:
remote_notice = REMOTE_AGENT_RESPONSE_NOTICE
if task_id_config is not None and task_id_config not in reference_task_ids:
reference_task_ids.append(task_id_config)
if task.config is None:
task.config = {}
task.config["reference_task_ids"] = reference_task_ids
task.config["reference_task_ids"] = list(reference_task_ids)
task_id_config = None
if agent_config.trust_remote_completion_status:
@@ -685,9 +821,9 @@ def _handle_task_completion(
agent_card=agent_card,
),
)
return str(result_text), task_id_config, reference_task_ids
return str(result_text), task_id_config, reference_task_ids, remote_notice
return None, task_id_config, reference_task_ids
return None, task_id_config, reference_task_ids, remote_notice
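
A sketch of the bookkeeping described above, with a hypothetical task id:

reference_task_ids: list[str] = []
task_id_config: str | None = "t-42"  # hypothetical completed task id

if task_id_config is not None and task_id_config not in reference_task_ids:
    reference_task_ids.append(task_id_config)  # archive the completed task
task_id_config = None  # next turn generates a fresh task id

assert reference_task_ids == ["t-42"] and task_id_config is None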
def _handle_agent_response_and_continue(
@@ -705,7 +841,8 @@ def _handle_agent_response_and_continue(
context: str | None,
tools: list[BaseTool] | None,
agent_response_model: type[BaseModel] | None,
remote_task_completed: bool = False,
extension_registry: ExtensionRegistry | None = None,
remote_status_notice: str = "",
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
@@ -735,14 +872,18 @@ def _handle_agent_response_and_continue(
"""
agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards)
task.description, disable_structured_output = _augment_prompt_with_a2a(
(
task.description,
disable_structured_output,
extension_states,
) = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
agent_cards=agent_cards_dict,
remote_task_completed=remote_task_completed,
remote_status_notice=remote_status_notice,
)
original_response_model = task.response_model
@@ -760,6 +901,8 @@ def _handle_agent_response_and_continue(
turn_num=turn_num,
agent_role=self.role,
agent_response_model=agent_response_model,
extension_registry=extension_registry,
extension_states=extension_states,
from_task=task,
from_agent=self,
endpoint=endpoint,
@@ -777,7 +920,7 @@ def _delegate_to_a2a(
tools: list[BaseTool] | None,
agent_cards: dict[str, AgentCard] | None = None,
original_task_description: str | None = None,
extension_registry: ExtensionRegistry | None = None,
_extension_registry: ExtensionRegistry | None = None,
) -> str:
"""Delegate to A2A agent with multi-turn conversation support.
@@ -790,7 +933,7 @@ def _delegate_to_a2a(
tools: Optional tools available to the agent
agent_cards: Pre-fetched agent cards from _execute_task_with_a2a
original_task_description: The original task description before A2A augmentation
extension_registry: Optional registry of A2A extensions
_extension_registry: Optional registry of A2A extensions, forwarded to the per-turn response handling
Returns:
Result from A2A agent
@@ -798,60 +941,42 @@ def _delegate_to_a2a(
Raises:
ImportError: If a2a-sdk is not installed
"""
(
a2a_agents,
agent_response_model,
current_request,
agent_id,
agent_config,
context_id,
task_id_config,
metadata,
extensions,
reference_task_ids,
original_task_description,
max_turns,
) = _prepare_delegation_context(
ctx = _prepare_delegation_context(
self, agent_response, task, original_task_description
)
conversation_history: list[Message] = []
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
current_agent_card_dict = (
current_agent_card.model_dump() if current_agent_card else None
)
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
state = _init_delegation_state(ctx, agent_cards)
current_request = state.current_request
context_id = state.context_id
task_id = state.task_id
reference_task_ids = state.reference_task_ids
conversation_history = state.conversation_history
try:
for turn_num in range(max_turns):
console_formatter = getattr(crewai_event_bus, "_console", None)
agent_branch = None
if console_formatter:
agent_branch = getattr(
console_formatter, "current_agent_branch", None
) or getattr(console_formatter, "current_task_branch", None)
for turn_num in range(ctx.max_turns):
agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config)
a2a_result = execute_a2a_delegation(
endpoint=agent_config.endpoint,
auth=agent_config.auth,
timeout=agent_config.timeout,
endpoint=ctx.agent_config.endpoint,
auth=ctx.agent_config.auth,
timeout=ctx.agent_config.timeout,
task_description=current_request,
context_id=context_id,
task_id=task_id_config,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
metadata=ctx.metadata,
extensions=ctx.extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_role=Role.user,
agent_branch=agent_branch,
response_model=agent_config.response_model,
response_model=ctx.agent_config.response_model,
turn_number=turn_num + 1,
updates=agent_config.updates,
transport_protocol=agent_config.transport_protocol,
updates=ctx.agent_config.updates,
transport=ctx.agent_config.transport,
from_task=task,
from_agent=self,
client_extensions=getattr(ctx.agent_config, "extensions", None),
accepted_output_modes=accepted_output_modes,
)
conversation_history = a2a_result.get("history", [])
@@ -859,24 +984,24 @@ def _delegate_to_a2a(
if conversation_history:
latest_message = conversation_history[-1]
if latest_message.task_id is not None:
task_id_config = latest_message.task_id
task_id = latest_message.task_id
if latest_message.context_id is not None:
context_id = latest_message.context_id
if a2a_result["status"] in [TaskState.completed, TaskState.input_required]:
trusted_result, task_id_config, reference_task_ids = (
trusted_result, task_id, reference_task_ids, remote_notice = (
_handle_task_completion(
a2a_result,
task,
task_id_config,
task_id,
reference_task_ids,
agent_config,
ctx.agent_config,
turn_num,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
)
if trusted_result is not None:
@@ -885,22 +1010,23 @@ def _delegate_to_a2a(
final_result, next_request = _handle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
a2a_agents=ctx.a2a_agents,
original_task_description=ctx.original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
max_turns=ctx.max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=(a2a_result["status"] == TaskState.completed),
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
agent_response_model=ctx.agent_response_model,
extension_registry=_extension_registry,
remote_status_notice=remote_notice,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
if final_result is not None:
@@ -916,22 +1042,22 @@ def _delegate_to_a2a(
final_result, next_request = _handle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
a2a_agents=ctx.a2a_agents,
original_task_description=ctx.original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
max_turns=ctx.max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=False,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
agent_response_model=ctx.agent_response_model,
extension_registry=_extension_registry,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
if final_result is not None:
@@ -941,34 +1067,28 @@ def _delegate_to_a2a(
current_request = next_request
continue
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="failed",
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
),
return _emit_delegation_failed(
error_msg,
turn_num,
task,
self,
ctx.agent_config.endpoint,
state.agent_name,
state.agent_card_dict,
)
return f"A2A delegation failed: {error_msg}"
return _handle_max_turns_exceeded(
conversation_history,
max_turns,
ctx.max_turns,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
finally:
task.description = original_task_description
task.description = ctx.original_task_description
async def _afetch_card_from_config(
@@ -993,6 +1113,9 @@ async def _afetch_agent_cards_concurrently(
agent_cards: dict[str, AgentCard] = {}
failed_agents: dict[str, str] = {}
if not a2a_agents:
return agent_cards, failed_agents
tasks = [_afetch_card_from_config(config) for config in a2a_agents]
results = await asyncio.gather(*tasks)
@@ -1042,7 +1165,7 @@ async def _aexecute_task_with_a2a(
finally:
task.description = original_description
task.description, _ = _augment_prompt_with_a2a(
task.description, _, extension_states = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_description,
agent_cards=agent_cards,
@@ -1059,7 +1182,7 @@ async def _aexecute_task_with_a2a(
if extension_registry and isinstance(agent_response, BaseModel):
agent_response = extension_registry.process_response_with_all(
agent_response, {}
agent_response, extension_states
)
if isinstance(agent_response, BaseModel) and isinstance(
@@ -1075,7 +1198,7 @@ async def _aexecute_task_with_a2a(
tools=tools,
agent_cards=agent_cards,
original_task_description=original_description,
extension_registry=extension_registry,
_extension_registry=extension_registry,
)
return str(agent_response.message)
@@ -1101,7 +1224,8 @@ async def _ahandle_agent_response_and_continue(
context: str | None,
tools: list[BaseTool] | None,
agent_response_model: type[BaseModel] | None,
remote_task_completed: bool = False,
extension_registry: ExtensionRegistry | None = None,
remote_status_notice: str = "",
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
@@ -1109,14 +1233,18 @@ async def _ahandle_agent_response_and_continue(
"""Async version of _handle_agent_response_and_continue."""
agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards)
task.description, disable_structured_output = _augment_prompt_with_a2a(
(
task.description,
disable_structured_output,
extension_states,
) = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
agent_cards=agent_cards_dict,
remote_task_completed=remote_task_completed,
remote_status_notice=remote_status_notice,
)
original_response_model = task.response_model
@@ -1134,6 +1262,8 @@ async def _ahandle_agent_response_and_continue(
turn_num=turn_num,
agent_role=self.role,
agent_response_model=agent_response_model,
extension_registry=extension_registry,
extension_states=extension_states,
from_task=task,
from_agent=self,
endpoint=endpoint,
@@ -1151,63 +1281,45 @@ async def _adelegate_to_a2a(
tools: list[BaseTool] | None,
agent_cards: dict[str, AgentCard] | None = None,
original_task_description: str | None = None,
extension_registry: ExtensionRegistry | None = None,
_extension_registry: ExtensionRegistry | None = None,
) -> str:
"""Async version of _delegate_to_a2a."""
(
a2a_agents,
agent_response_model,
current_request,
agent_id,
agent_config,
context_id,
task_id_config,
metadata,
extensions,
reference_task_ids,
original_task_description,
max_turns,
) = _prepare_delegation_context(
ctx = _prepare_delegation_context(
self, agent_response, task, original_task_description
)
conversation_history: list[Message] = []
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
current_agent_card_dict = (
current_agent_card.model_dump() if current_agent_card else None
)
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
state = _init_delegation_state(ctx, agent_cards)
current_request = state.current_request
context_id = state.context_id
task_id = state.task_id
reference_task_ids = state.reference_task_ids
conversation_history = state.conversation_history
try:
for turn_num in range(max_turns):
console_formatter = getattr(crewai_event_bus, "_console", None)
agent_branch = None
if console_formatter:
agent_branch = getattr(
console_formatter, "current_agent_branch", None
) or getattr(console_formatter, "current_task_branch", None)
for turn_num in range(ctx.max_turns):
agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config)
a2a_result = await aexecute_a2a_delegation(
endpoint=agent_config.endpoint,
auth=agent_config.auth,
timeout=agent_config.timeout,
endpoint=ctx.agent_config.endpoint,
auth=ctx.agent_config.auth,
timeout=ctx.agent_config.timeout,
task_description=current_request,
context_id=context_id,
task_id=task_id_config,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
metadata=ctx.metadata,
extensions=ctx.extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_role=Role.user,
agent_branch=agent_branch,
response_model=agent_config.response_model,
response_model=ctx.agent_config.response_model,
turn_number=turn_num + 1,
transport_protocol=agent_config.transport_protocol,
updates=agent_config.updates,
transport=ctx.agent_config.transport,
updates=ctx.agent_config.updates,
from_task=task,
from_agent=self,
client_extensions=getattr(ctx.agent_config, "extensions", None),
accepted_output_modes=accepted_output_modes,
)
conversation_history = a2a_result.get("history", [])
@@ -1215,24 +1327,24 @@ async def _adelegate_to_a2a(
if conversation_history:
latest_message = conversation_history[-1]
if latest_message.task_id is not None:
task_id_config = latest_message.task_id
task_id = latest_message.task_id
if latest_message.context_id is not None:
context_id = latest_message.context_id
if a2a_result["status"] in [TaskState.completed, TaskState.input_required]:
trusted_result, task_id_config, reference_task_ids = (
trusted_result, task_id, reference_task_ids, remote_notice = (
_handle_task_completion(
a2a_result,
task,
task_id_config,
task_id,
reference_task_ids,
agent_config,
ctx.agent_config,
turn_num,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
)
if trusted_result is not None:
@@ -1241,22 +1353,23 @@ async def _adelegate_to_a2a(
final_result, next_request = await _ahandle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
a2a_agents=ctx.a2a_agents,
original_task_description=ctx.original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
max_turns=ctx.max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=(a2a_result["status"] == TaskState.completed),
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
agent_response_model=ctx.agent_response_model,
extension_registry=_extension_registry,
remote_status_notice=remote_notice,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
if final_result is not None:
@@ -1272,21 +1385,22 @@ async def _adelegate_to_a2a(
final_result, next_request = await _ahandle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
a2a_agents=ctx.a2a_agents,
original_task_description=ctx.original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
max_turns=ctx.max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
agent_response_model=ctx.agent_response_model,
extension_registry=_extension_registry,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
if final_result is not None:
@@ -1296,31 +1410,25 @@ async def _adelegate_to_a2a(
current_request = next_request
continue
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="failed",
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
),
return _emit_delegation_failed(
error_msg,
turn_num,
task,
self,
ctx.agent_config.endpoint,
state.agent_name,
state.agent_card_dict,
)
return f"A2A delegation failed: {error_msg}"
return _handle_max_turns_exceeded(
conversation_history,
max_turns,
ctx.max_turns,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
finally:
task.description = original_task_description
task.description = ctx.original_task_description

View File

@@ -37,7 +37,8 @@ class CrewAgentExecutorMixin:
self.crew
and self.agent
and self.task
and f"Action: {sanitize_tool_name('Delegate work to coworker')}" not in output.text
and f"Action: {sanitize_tool_name('Delegate work to coworker')}"
not in output.text
):
try:
if (
@@ -132,10 +133,11 @@ class CrewAgentExecutorMixin:
and self.crew._long_term_memory
and self.crew._entity_memory is None
):
self._printer.print(
content="Long term memory is enabled, but entity memory is not enabled. Please configure entity memory or set memory=True to automatically enable it.",
color="bold_yellow",
)
if self.agent and self.agent.verbose:
self._printer.print(
content="Long term memory is enabled, but entity memory is not enabled. Please configure entity memory or set memory=True to automatically enable it.",
color="bold_yellow",
)
def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging.

View File

@@ -206,13 +206,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
try:
formatted_answer = self._invoke_loop()
except AssertionError:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
if self.agent.verbose:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
raise
except Exception as e:
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise
if self.ask_for_human_input:
@@ -327,6 +328,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
break
@@ -341,22 +343,41 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
verbose=self.agent.verbose,
)
if self.response_model is not None:
try:
self.response_model.model_validate_json(answer)
formatted_answer = AgentFinish(
thought="",
output=answer,
text=answer,
)
if isinstance(answer, BaseModel):
output_json = answer.model_dump_json()
formatted_answer = AgentFinish(
thought="",
output=answer,
text=output_json,
)
else:
self.response_model.model_validate_json(answer)
formatted_answer = AgentFinish(
thought="",
output=answer,
text=answer,
)
except ValidationError:
# If validation fails, convert BaseModel to JSON string for parsing
answer_str = (
answer.model_dump_json()
if isinstance(answer, BaseModel)
else str(answer)
)
formatted_answer = process_llm_response(
answer, self.use_stop_words
answer_str, self.use_stop_words
) # type: ignore[assignment]
else:
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
# When no response_model, answer should be a string
answer_str = str(answer) if not isinstance(answer, str) else answer
formatted_answer = process_llm_response(
answer_str, self.use_stop_words
) # type: ignore[assignment]
if isinstance(formatted_answer, AgentAction):
# Extract agent fingerprint if available
@@ -399,6 +420,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
iterations=self.iterations,
log_error_after=self.log_error_after,
printer=self._printer,
verbose=self.agent.verbose,
)
except Exception as e:
@@ -413,9 +435,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
)
continue
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise e
finally:
self.iterations += 1
@@ -461,6 +484,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
self._show_logs(formatted_answer)
return formatted_answer
@@ -482,6 +506,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
verbose=self.agent.verbose,
)
# Check if the response is a list of tool calls
@@ -513,6 +538,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
if isinstance(answer, BaseModel):
output_json = answer.model_dump_json()
formatted_answer = AgentFinish(
thought="",
output=answer,
text=output_json,
)
self._invoke_step_callback(formatted_answer)
self._append_message(output_json)
self._show_logs(formatted_answer)
return formatted_answer
# Unexpected response type, treat as final answer
formatted_answer = AgentFinish(
thought="",
@@ -535,9 +572,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
)
continue
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise e
finally:
self.iterations += 1
@@ -559,13 +597,23 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
verbose=self.agent.verbose,
)
formatted_answer = AgentFinish(
thought="",
output=str(answer),
text=str(answer),
)
if isinstance(answer, BaseModel):
output_json = answer.model_dump_json()
formatted_answer = AgentFinish(
thought="",
output=answer,
text=output_json,
)
else:
answer_str = answer if isinstance(answer, str) else str(answer)
formatted_answer = AgentFinish(
thought="",
output=answer_str,
text=answer_str,
)
self._show_logs(formatted_answer)
return formatted_answer
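
The fallback these hunks repeat, in a standalone sketch (Answer is a hypothetical response model):

from pydantic import BaseModel, ValidationError

class Answer(BaseModel):
    score: int

raw = "Thought: done\nFinal Answer: 7"  # not valid JSON for Answer
try:
    Answer.model_validate_json(raw)
except ValidationError:
    # Mirrors the executor fallback above: when structured validation fails, the
    # raw string is handed to process_llm_response for ReAct-style parsing instead.
    formatted = raw  # stand-in for process_llm_response(raw, use_stop_words)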
@@ -755,10 +803,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
track_delegation_if_needed(func_name, args_dict, self.task)
# Find the structured tool for hook context
structured_tool = None
for tool in self.tools or []:
if sanitize_tool_name(tool.name) == func_name:
structured_tool = tool
structured_tool: CrewStructuredTool | None = None
for structured in self.tools or []:
if sanitize_tool_name(structured.name) == func_name:
structured_tool = structured
break
# Execute before_tool_call hooks
@@ -779,10 +827,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
hook_blocked = True
break
except Exception as hook_error:
self._printer.print(
content=f"Error in before_tool_call hook: {hook_error}",
color="red",
)
if self.agent.verbose:
self._printer.print(
content=f"Error in before_tool_call hook: {hook_error}",
color="red",
)
# If hook blocked execution, set result and skip tool execution
if hook_blocked:
@@ -848,15 +897,16 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
after_hooks = get_after_tool_call_hooks()
try:
for after_hook in after_hooks:
hook_result = after_hook(after_hook_context)
if hook_result is not None:
result = hook_result
after_hook_result = after_hook(after_hook_context)
if after_hook_result is not None:
result = after_hook_result
after_hook_context.tool_result = result
except Exception as hook_error:
self._printer.print(
content=f"Error in after_tool_call hook: {hook_error}",
color="red",
)
if self.agent.verbose:
self._printer.print(
content=f"Error in after_tool_call hook: {hook_error}",
color="red",
)
# Emit tool usage finished event
crewai_event_bus.emit(
@@ -942,13 +992,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
try:
formatted_answer = await self._ainvoke_loop()
except AssertionError:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
if self.agent.verbose:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
raise
except Exception as e:
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise
if self.ask_for_human_input:
@@ -999,6 +1050,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
break
@@ -1013,22 +1065,41 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
verbose=self.agent.verbose,
)
if self.response_model is not None:
try:
self.response_model.model_validate_json(answer)
formatted_answer = AgentFinish(
thought="",
output=answer,
text=answer,
)
if isinstance(answer, BaseModel):
output_json = answer.model_dump_json()
formatted_answer = AgentFinish(
thought="",
output=answer,
text=output_json,
)
else:
self.response_model.model_validate_json(answer)
formatted_answer = AgentFinish(
thought="",
output=answer,
text=answer,
)
except ValidationError:
# If validation fails, convert BaseModel to JSON string for parsing
answer_str = (
answer.model_dump_json()
if isinstance(answer, BaseModel)
else str(answer)
)
formatted_answer = process_llm_response(
answer, self.use_stop_words
answer_str, self.use_stop_words
) # type: ignore[assignment]
else:
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
# When no response_model, answer should be a string
answer_str = str(answer) if not isinstance(answer, str) else answer
formatted_answer = process_llm_response(
answer_str, self.use_stop_words
) # type: ignore[assignment]
if isinstance(formatted_answer, AgentAction):
fingerprint_context = {}
@@ -1070,6 +1141,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
iterations=self.iterations,
log_error_after=self.log_error_after,
printer=self._printer,
verbose=self.agent.verbose,
)
except Exception as e:
@@ -1083,9 +1155,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
)
continue
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise e
finally:
self.iterations += 1
@@ -1125,6 +1198,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
self._show_logs(formatted_answer)
return formatted_answer
@@ -1146,6 +1220,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
verbose=self.agent.verbose,
)
# Check if the response is a list of tool calls
if (
@@ -1176,6 +1251,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
if isinstance(answer, BaseModel):
output_json = answer.model_dump_json()
formatted_answer = AgentFinish(
thought="",
output=answer,
text=output_json,
)
self._invoke_step_callback(formatted_answer)
self._append_message(output_json)
self._show_logs(formatted_answer)
return formatted_answer
# Unexpected response type, treat as final answer
formatted_answer = AgentFinish(
thought="",
@@ -1198,9 +1285,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
)
continue
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise e
finally:
self.iterations += 1
@@ -1222,13 +1310,23 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
verbose=self.agent.verbose,
)
formatted_answer = AgentFinish(
thought="",
output=str(answer),
text=str(answer),
)
if isinstance(answer, BaseModel):
output_json = answer.model_dump_json()
formatted_answer = AgentFinish(
thought="",
output=answer,
text=output_json,
)
else:
answer_str = answer if isinstance(answer, str) else str(answer)
formatted_answer = AgentFinish(
thought="",
output=answer_str,
text=answer_str,
)
self._show_logs(formatted_answer)
return formatted_answer
@@ -1339,10 +1437,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
if train_iteration is None or not isinstance(train_iteration, int):
self._printer.print(
content="Invalid or missing train iteration. Cannot save training data.",
color="red",
)
if self.agent.verbose:
self._printer.print(
content="Invalid or missing train iteration. Cannot save training data.",
color="red",
)
return
training_handler = CrewTrainingHandler(TRAINING_DATA_FILE)
@@ -1362,13 +1461,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if train_iteration in agent_training_data:
agent_training_data[train_iteration]["improved_output"] = result.output
else:
self._printer.print(
content=(
f"No existing training data for agent {agent_id} and iteration "
f"{train_iteration}. Cannot save improved output."
),
color="red",
)
if self.agent.verbose:
self._printer.print(
content=(
f"No existing training data for agent {agent_id} and iteration "
f"{train_iteration}. Cannot save improved output."
),
color="red",
)
return
# Update the training data and save
@@ -1399,7 +1499,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
Final answer after feedback.
"""
human_feedback = self._ask_human_input(formatted_answer.output)
output_str = (
formatted_answer.output
if isinstance(formatted_answer.output, str)
else formatted_answer.output.model_dump_json()
)
human_feedback = self._ask_human_input(output_str)
if self._is_training_mode():
return self._handle_training_feedback(formatted_answer, human_feedback)
@@ -1458,7 +1563,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.ask_for_human_input = False
else:
answer = self._process_feedback_iteration(feedback)
feedback = self._ask_human_input(answer.output)
output_str = (
answer.output
if isinstance(answer.output, str)
else answer.output.model_dump_json()
)
feedback = self._ask_human_input(output_str)
return answer
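
Both feedback paths above repeat the same str | BaseModel normalization before prompting the human. A small helper (hypothetical, not part of this diff) captures the pattern:

from pydantic import BaseModel

def output_to_text(output: str | BaseModel) -> str:
    # AgentFinish.output may now carry a validated model; serialize it
    # before handing it to the human-input prompt
    return output if isinstance(output, str) else output.model_dump_json()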

View File

@@ -8,6 +8,7 @@ AgentAction or AgentFinish objects.
from dataclasses import dataclass
from json_repair import repair_json # type: ignore[import-untyped]
from pydantic import BaseModel
from crewai.agents.constants import (
ACTION_INPUT_ONLY_REGEX,
@@ -40,7 +41,7 @@ class AgentFinish:
"""Represents the final answer from an agent."""
thought: str
output: str
output: str | BaseModel
text: str
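
With output widened to str | BaseModel, a structured final answer keeps the validated model in output while text holds its serialized form, mirroring the executor hunks above. A minimal sketch with a hypothetical response model:

from pydantic import BaseModel

class Report(BaseModel):  # hypothetical response model
    summary: str

answer = Report(summary="done")
finish = AgentFinish(thought="", output=answer, text=answer.model_dump_json())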

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.9.0"
"crewai[tools]==1.9.2"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.9.0"
"crewai[tools]==1.9.2"
]
[project.scripts]

View File

@@ -1,7 +1,3 @@
from typing import Annotated
from pydantic import Field
from crewai.events.types.a2a_events import (
A2AAgentCardFetchedEvent,
A2AArtifactReceivedEvent,
@@ -106,7 +102,7 @@ from crewai.events.types.tool_usage_events import (
)
EventTypes = Annotated[
EventTypes = (
A2AAgentCardFetchedEvent
| A2AArtifactReceivedEvent
| A2AAuthenticationFailedEvent
@@ -184,6 +180,5 @@ EventTypes = Annotated[
| MCPConnectionFailedEvent
| MCPToolExecutionStartedEvent
| MCPToolExecutionCompletedEvent
| MCPToolExecutionFailedEvent,
Field(discriminator="type"),
]
| MCPToolExecutionFailedEvent
)
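
Dropping the Field(discriminator="type") annotation, together with loosening each event's type field from Literal[...] to str in the files below, means EventTypes is no longer validated as a tagged union, so new event classes can declare arbitrary type strings without extending the union. A sketch of what that permits, assuming A2AEventBase is exported from the module shown in this diff (the custom event name is hypothetical):

from crewai.events.types.a2a_events import A2AEventBase

class VendorPingEvent(A2AEventBase):
    # with type: str there is no Literal discriminator constraining the value
    type: str = "vendor_ping"
    endpoint: str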

View File

@@ -73,7 +73,7 @@ class A2ADelegationStartedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: Literal["a2a_delegation_started"] = "a2a_delegation_started"
type: str = "a2a_delegation_started"
endpoint: str
task_description: str
agent_id: str
@@ -106,7 +106,7 @@ class A2ADelegationCompletedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: Literal["a2a_delegation_completed"] = "a2a_delegation_completed"
type: str = "a2a_delegation_completed"
status: str
result: str | None = None
error: str | None = None
@@ -140,7 +140,7 @@ class A2AConversationStartedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: Literal["a2a_conversation_started"] = "a2a_conversation_started"
type: str = "a2a_conversation_started"
agent_id: str
endpoint: str
context_id: str | None = None
@@ -171,7 +171,7 @@ class A2AMessageSentEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: Literal["a2a_message_sent"] = "a2a_message_sent"
type: str = "a2a_message_sent"
message: str
turn_number: int
context_id: str | None = None
@@ -203,7 +203,7 @@ class A2AResponseReceivedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: Literal["a2a_response_received"] = "a2a_response_received"
type: str = "a2a_response_received"
response: str
turn_number: int
context_id: str | None = None
@@ -237,7 +237,7 @@ class A2AConversationCompletedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: Literal["a2a_conversation_completed"] = "a2a_conversation_completed"
type: str = "a2a_conversation_completed"
status: Literal["completed", "failed"]
final_result: str | None = None
error: str | None = None
@@ -263,7 +263,7 @@ class A2APollingStartedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_polling_started"] = "a2a_polling_started"
type: str = "a2a_polling_started"
task_id: str
context_id: str | None = None
polling_interval: float
@@ -286,7 +286,7 @@ class A2APollingStatusEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_polling_status"] = "a2a_polling_status"
type: str = "a2a_polling_status"
task_id: str
context_id: str | None = None
state: str
@@ -309,7 +309,7 @@ class A2APushNotificationRegisteredEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_push_notification_registered"] = "a2a_push_notification_registered"
type: str = "a2a_push_notification_registered"
task_id: str
context_id: str | None = None
callback_url: str
@@ -334,7 +334,7 @@ class A2APushNotificationReceivedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_push_notification_received"] = "a2a_push_notification_received"
type: str = "a2a_push_notification_received"
task_id: str
context_id: str | None = None
state: str
@@ -359,7 +359,7 @@ class A2APushNotificationSentEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_push_notification_sent"] = "a2a_push_notification_sent"
type: str = "a2a_push_notification_sent"
task_id: str
context_id: str | None = None
callback_url: str
@@ -381,7 +381,7 @@ class A2APushNotificationTimeoutEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_push_notification_timeout"] = "a2a_push_notification_timeout"
type: str = "a2a_push_notification_timeout"
task_id: str
context_id: str | None = None
timeout_seconds: float
@@ -405,7 +405,7 @@ class A2AStreamingStartedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: Literal["a2a_streaming_started"] = "a2a_streaming_started"
type: str = "a2a_streaming_started"
task_id: str | None = None
context_id: str | None = None
endpoint: str
@@ -434,7 +434,7 @@ class A2AStreamingChunkEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: Literal["a2a_streaming_chunk"] = "a2a_streaming_chunk"
type: str = "a2a_streaming_chunk"
task_id: str | None = None
context_id: str | None = None
chunk: str
@@ -462,7 +462,7 @@ class A2AAgentCardFetchedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_agent_card_fetched"] = "a2a_agent_card_fetched"
type: str = "a2a_agent_card_fetched"
endpoint: str
a2a_agent_name: str | None = None
agent_card: dict[str, Any] | None = None
@@ -486,7 +486,7 @@ class A2AAuthenticationFailedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_authentication_failed"] = "a2a_authentication_failed"
type: str = "a2a_authentication_failed"
endpoint: str
auth_type: str | None = None
error: str
@@ -517,7 +517,7 @@ class A2AArtifactReceivedEvent(A2AEventBase):
extensions: List of A2A extension URIs in use.
"""
type: Literal["a2a_artifact_received"] = "a2a_artifact_received"
type: str = "a2a_artifact_received"
task_id: str
artifact_id: str
artifact_name: str | None = None
@@ -550,7 +550,7 @@ class A2AConnectionErrorEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_connection_error"] = "a2a_connection_error"
type: str = "a2a_connection_error"
endpoint: str
error: str
error_type: str | None = None
@@ -571,7 +571,7 @@ class A2AServerTaskStartedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_server_task_started"] = "a2a_server_task_started"
type: str = "a2a_server_task_started"
task_id: str
context_id: str
metadata: dict[str, Any] | None = None
@@ -587,7 +587,7 @@ class A2AServerTaskCompletedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_server_task_completed"] = "a2a_server_task_completed"
type: str = "a2a_server_task_completed"
task_id: str
context_id: str
result: str
@@ -603,7 +603,7 @@ class A2AServerTaskCanceledEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_server_task_canceled"] = "a2a_server_task_canceled"
type: str = "a2a_server_task_canceled"
task_id: str
context_id: str
metadata: dict[str, Any] | None = None
@@ -619,7 +619,7 @@ class A2AServerTaskFailedEvent(A2AEventBase):
metadata: Custom A2A metadata key-value pairs.
"""
type: Literal["a2a_server_task_failed"] = "a2a_server_task_failed"
type: str = "a2a_server_task_failed"
task_id: str
context_id: str
error: str
@@ -634,7 +634,7 @@ class A2AParallelDelegationStartedEvent(A2AEventBase):
task_description: Description of the task being delegated.
"""
type: Literal["a2a_parallel_delegation_started"] = "a2a_parallel_delegation_started"
type: str = "a2a_parallel_delegation_started"
endpoints: list[str]
task_description: str
@@ -649,8 +649,170 @@ class A2AParallelDelegationCompletedEvent(A2AEventBase):
results: Summary of results from each agent.
"""
type: Literal["a2a_parallel_delegation_completed"] = "a2a_parallel_delegation_completed"
type: str = "a2a_parallel_delegation_completed"
endpoints: list[str]
success_count: int
failure_count: int
results: dict[str, str] | None = None
class A2ATransportNegotiatedEvent(A2AEventBase):
"""Event emitted when transport protocol is negotiated with an A2A agent.
This event is emitted after comparing client and server transport capabilities
to select the optimal transport protocol and endpoint URL.
Attributes:
endpoint: Original A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
negotiated_transport: The transport protocol selected (JSONRPC, GRPC, HTTP+JSON).
negotiated_url: The URL to use for the selected transport.
source: How the transport was selected ('client_preferred', 'server_preferred', 'fallback').
client_supported_transports: Transports the client can use.
server_supported_transports: Transports the server supports.
server_preferred_transport: The server's preferred transport from AgentCard.
client_preferred_transport: The client's preferred transport if set.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_transport_negotiated"
endpoint: str
a2a_agent_name: str | None = None
negotiated_transport: str
negotiated_url: str
source: str
client_supported_transports: list[str]
server_supported_transports: list[str]
server_preferred_transport: str
client_preferred_transport: str | None = None
metadata: dict[str, Any] | None = None
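
The source field suggests the precedence the negotiator applies. A minimal sketch of that selection, assuming server preference wins when the client supports it, then the first mutual transport, then a JSONRPC fallback (all assumptions; the exact ordering lives in the transport negotiation module this branch adds):

def negotiate_transport(
    client: list[str], server: list[str], server_preferred: str
) -> tuple[str, str]:
    if server_preferred in client:
        return server_preferred, "server_preferred"
    for transport in client:
        if transport in server:
            return transport, "client_preferred"
    return "JSONRPC", "fallback"

negotiate_transport(["GRPC", "JSONRPC"], ["JSONRPC", "HTTP+JSON"], "JSONRPC")
# -> ("JSONRPC", "server_preferred")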
class A2AContentTypeNegotiatedEvent(A2AEventBase):
"""Event emitted when content types are negotiated with an A2A agent.
This event is emitted after comparing client and server input/output mode
capabilities to determine compatible MIME types for communication.
Attributes:
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
skill_name: Skill name if negotiation was skill-specific.
client_input_modes: MIME types the client can send.
client_output_modes: MIME types the client can accept.
server_input_modes: MIME types the server accepts.
server_output_modes: MIME types the server produces.
negotiated_input_modes: Compatible input MIME types selected.
negotiated_output_modes: Compatible output MIME types selected.
negotiation_success: Whether compatible types were found for both directions.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_content_type_negotiated"
endpoint: str
a2a_agent_name: str | None = None
skill_name: str | None = None
client_input_modes: list[str]
client_output_modes: list[str]
server_input_modes: list[str]
server_output_modes: list[str]
negotiated_input_modes: list[str]
negotiated_output_modes: list[str]
negotiation_success: bool = True
metadata: dict[str, Any] | None = None
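
The negotiated modes are directional intersections of the advertised MIME types. A compact sketch (ordering and tie-breaking are assumptions):

def negotiate_modes(
    client_in: list[str],
    client_out: list[str],
    server_in: list[str],
    server_out: list[str],
) -> tuple[list[str], list[str], bool]:
    # inputs: what the client can send and the server accepts;
    # outputs: what the server produces and the client accepts
    inputs = [m for m in client_in if m in server_in]
    outputs = [m for m in server_out if m in client_out]
    return inputs, outputs, bool(inputs and outputs)

negotiate_modes(
    ["text/plain", "application/json"], ["text/plain"],
    ["application/json"], ["text/plain", "image/png"],
)
# -> (["application/json"], ["text/plain"], True)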
# -----------------------------------------------------------------------------
# Context Lifecycle Events
# -----------------------------------------------------------------------------
class A2AContextCreatedEvent(A2AEventBase):
"""Event emitted when an A2A context is created.
Contexts group related tasks in a conversation or workflow.
Attributes:
context_id: Unique identifier for the context.
created_at: Unix timestamp when context was created.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_created"
context_id: str
created_at: float
metadata: dict[str, Any] | None = None
class A2AContextExpiredEvent(A2AEventBase):
"""Event emitted when an A2A context expires due to TTL.
Attributes:
context_id: The expired context identifier.
created_at: Unix timestamp when context was created.
age_seconds: How long the context existed before expiring.
task_count: Number of tasks in the context when expired.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_expired"
context_id: str
created_at: float
age_seconds: float
task_count: int
metadata: dict[str, Any] | None = None
class A2AContextIdleEvent(A2AEventBase):
"""Event emitted when an A2A context becomes idle.
Idle contexts have had no activity for at least the configured threshold.
Attributes:
context_id: The idle context identifier.
idle_seconds: Seconds since last activity.
task_count: Number of tasks in the context.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_idle"
context_id: str
idle_seconds: float
task_count: int
metadata: dict[str, Any] | None = None
class A2AContextCompletedEvent(A2AEventBase):
"""Event emitted when all tasks in an A2A context complete.
Attributes:
context_id: The completed context identifier.
total_tasks: Total number of tasks that were in the context.
duration_seconds: Total context lifetime in seconds.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_completed"
context_id: str
total_tasks: int
duration_seconds: float
metadata: dict[str, Any] | None = None
class A2AContextPrunedEvent(A2AEventBase):
"""Event emitted when an A2A context is pruned (deleted).
Pruning removes the context metadata and optionally associated tasks.
Attributes:
context_id: The pruned context identifier.
task_count: Number of tasks that were in the context.
age_seconds: How long the context existed before pruning.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_pruned"
context_id: str
task_count: int
age_seconds: float
metadata: dict[str, Any] | None = None
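
Together these events trace a context from creation through idle, expiry or completion, and finally pruning. A sketch of the TTL check that would emit the expiry event (the emission site and TTL source are assumptions):

import time

def check_expired(
    context_id: str, created_at: float, ttl_seconds: float, task_count: int
) -> A2AContextExpiredEvent | None:
    age = time.time() - created_at
    if age <= ttl_seconds:
        return None
    return A2AContextExpiredEvent(
        context_id=context_id,
        created_at=created_at,
        age_seconds=age,
        task_count=task_count,
    )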

View File

@@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, Literal
from collections.abc import Sequence
from typing import Any
from pydantic import ConfigDict, model_validator
@@ -17,9 +18,9 @@ class AgentExecutionStartedEvent(BaseEvent):
agent: BaseAgent
task: Any
tools: list[BaseTool | CrewStructuredTool] | None
tools: Sequence[BaseTool | CrewStructuredTool] | None
task_prompt: str
type: Literal["agent_execution_started"] = "agent_execution_started"
type: str = "agent_execution_started"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -43,7 +44,7 @@ class AgentExecutionCompletedEvent(BaseEvent):
agent: BaseAgent
task: Any
output: str
type: Literal["agent_execution_completed"] = "agent_execution_completed"
type: str = "agent_execution_completed"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -67,7 +68,7 @@ class AgentExecutionErrorEvent(BaseEvent):
agent: BaseAgent
task: Any
error: str
type: Literal["agent_execution_error"] = "agent_execution_error"
type: str = "agent_execution_error"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -90,9 +91,9 @@ class LiteAgentExecutionStartedEvent(BaseEvent):
"""Event emitted when a LiteAgent starts executing"""
agent_info: dict[str, Any]
tools: list[BaseTool | CrewStructuredTool] | None
tools: Sequence[BaseTool | CrewStructuredTool] | None
messages: str | list[dict[str, str]]
type: Literal["lite_agent_execution_started"] = "lite_agent_execution_started"
type: str = "lite_agent_execution_started"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -102,7 +103,7 @@ class LiteAgentExecutionCompletedEvent(BaseEvent):
agent_info: dict[str, Any]
output: str
type: Literal["lite_agent_execution_completed"] = "lite_agent_execution_completed"
type: str = "lite_agent_execution_completed"
class LiteAgentExecutionErrorEvent(BaseEvent):
@@ -110,7 +111,7 @@ class LiteAgentExecutionErrorEvent(BaseEvent):
agent_info: dict[str, Any]
error: str
type: Literal["lite_agent_execution_error"] = "lite_agent_execution_error"
type: str = "lite_agent_execution_error"
# Agent Eval events
@@ -119,7 +120,7 @@ class AgentEvaluationStartedEvent(BaseEvent):
agent_role: str
task_id: str | None = None
iteration: int
type: Literal["agent_evaluation_started"] = "agent_evaluation_started"
type: str = "agent_evaluation_started"
class AgentEvaluationCompletedEvent(BaseEvent):
@@ -129,7 +130,7 @@ class AgentEvaluationCompletedEvent(BaseEvent):
iteration: int
metric_category: Any
score: Any
type: Literal["agent_evaluation_completed"] = "agent_evaluation_completed"
type: str = "agent_evaluation_completed"
class AgentEvaluationFailedEvent(BaseEvent):
@@ -138,4 +139,4 @@ class AgentEvaluationFailedEvent(BaseEvent):
task_id: str | None = None
iteration: int
error: str
type: Literal["agent_evaluation_failed"] = "agent_evaluation_failed"
type: str = "agent_evaluation_failed"

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any
from crewai.events.base_events import BaseEvent
@@ -40,14 +40,14 @@ class CrewKickoffStartedEvent(CrewBaseEvent):
"""Event emitted when a crew starts execution"""
inputs: dict[str, Any] | None
type: Literal["crew_kickoff_started"] = "crew_kickoff_started"
type: str = "crew_kickoff_started"
class CrewKickoffCompletedEvent(CrewBaseEvent):
"""Event emitted when a crew completes execution"""
output: Any
type: Literal["crew_kickoff_completed"] = "crew_kickoff_completed"
type: str = "crew_kickoff_completed"
total_tokens: int = 0
@@ -55,7 +55,7 @@ class CrewKickoffFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete execution"""
error: str
type: Literal["crew_kickoff_failed"] = "crew_kickoff_failed"
type: str = "crew_kickoff_failed"
class CrewTrainStartedEvent(CrewBaseEvent):
@@ -64,7 +64,7 @@ class CrewTrainStartedEvent(CrewBaseEvent):
n_iterations: int
filename: str
inputs: dict[str, Any] | None
type: Literal["crew_train_started"] = "crew_train_started"
type: str = "crew_train_started"
class CrewTrainCompletedEvent(CrewBaseEvent):
@@ -72,14 +72,14 @@ class CrewTrainCompletedEvent(CrewBaseEvent):
n_iterations: int
filename: str
type: Literal["crew_train_completed"] = "crew_train_completed"
type: str = "crew_train_completed"
class CrewTrainFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete training"""
error: str
type: Literal["crew_train_failed"] = "crew_train_failed"
type: str = "crew_train_failed"
class CrewTestStartedEvent(CrewBaseEvent):
@@ -88,20 +88,20 @@ class CrewTestStartedEvent(CrewBaseEvent):
n_iterations: int
eval_llm: str | Any | None
inputs: dict[str, Any] | None
type: Literal["crew_test_started"] = "crew_test_started"
type: str = "crew_test_started"
class CrewTestCompletedEvent(CrewBaseEvent):
"""Event emitted when a crew completes testing"""
type: Literal["crew_test_completed"] = "crew_test_completed"
type: str = "crew_test_completed"
class CrewTestFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete testing"""
error: str
type: Literal["crew_test_failed"] = "crew_test_failed"
type: str = "crew_test_failed"
class CrewTestResultEvent(CrewBaseEvent):
@@ -110,4 +110,4 @@ class CrewTestResultEvent(CrewBaseEvent):
quality: float
execution_duration: float
model: str
type: Literal["crew_test_result"] = "crew_test_result"
type: str = "crew_test_result"

View File

@@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any
from pydantic import BaseModel, ConfigDict
@@ -17,14 +17,14 @@ class FlowStartedEvent(FlowEvent):
flow_name: str
inputs: dict[str, Any] | None = None
type: Literal["flow_started"] = "flow_started"
type: str = "flow_started"
class FlowCreatedEvent(FlowEvent):
"""Event emitted when a flow is created"""
flow_name: str
type: Literal["flow_created"] = "flow_created"
type: str = "flow_created"
class MethodExecutionStartedEvent(FlowEvent):
@@ -34,7 +34,7 @@ class MethodExecutionStartedEvent(FlowEvent):
method_name: str
state: dict[str, Any] | BaseModel
params: dict[str, Any] | None = None
type: Literal["method_execution_started"] = "method_execution_started"
type: str = "method_execution_started"
class MethodExecutionFinishedEvent(FlowEvent):
@@ -44,7 +44,7 @@ class MethodExecutionFinishedEvent(FlowEvent):
method_name: str
result: Any = None
state: dict[str, Any] | BaseModel
type: Literal["method_execution_finished"] = "method_execution_finished"
type: str = "method_execution_finished"
class MethodExecutionFailedEvent(FlowEvent):
@@ -53,7 +53,7 @@ class MethodExecutionFailedEvent(FlowEvent):
flow_name: str
method_name: str
error: Exception
type: Literal["method_execution_failed"] = "method_execution_failed"
type: str = "method_execution_failed"
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -78,7 +78,7 @@ class MethodExecutionPausedEvent(FlowEvent):
flow_id: str
message: str
emit: list[str] | None = None
type: Literal["method_execution_paused"] = "method_execution_paused"
type: str = "method_execution_paused"
class FlowFinishedEvent(FlowEvent):
@@ -86,7 +86,7 @@ class FlowFinishedEvent(FlowEvent):
flow_name: str
result: Any | None = None
type: Literal["flow_finished"] = "flow_finished"
type: str = "flow_finished"
state: dict[str, Any] | BaseModel
@@ -110,14 +110,14 @@ class FlowPausedEvent(FlowEvent):
state: dict[str, Any] | BaseModel
message: str
emit: list[str] | None = None
type: Literal["flow_paused"] = "flow_paused"
type: str = "flow_paused"
class FlowPlotEvent(FlowEvent):
"""Event emitted when a flow plot is created"""
flow_name: str
type: Literal["flow_plot"] = "flow_plot"
type: str = "flow_plot"
class HumanFeedbackRequestedEvent(FlowEvent):
@@ -138,7 +138,7 @@ class HumanFeedbackRequestedEvent(FlowEvent):
output: Any
message: str
emit: list[str] | None = None
type: Literal["human_feedback_requested"] = "human_feedback_requested"
type: str = "human_feedback_requested"
class HumanFeedbackReceivedEvent(FlowEvent):
@@ -157,4 +157,4 @@ class HumanFeedbackReceivedEvent(FlowEvent):
method_name: str
feedback: str
outcome: str | None = None
type: Literal["human_feedback_received"] = "human_feedback_received"
type: str = "human_feedback_received"

View File

@@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any
from crewai.events.base_events import BaseEvent
@@ -20,14 +20,14 @@ class KnowledgeEventBase(BaseEvent):
class KnowledgeRetrievalStartedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge retrieval is started."""
type: Literal["knowledge_search_query_started"] = "knowledge_search_query_started"
type: str = "knowledge_search_query_started"
class KnowledgeRetrievalCompletedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge retrieval is completed."""
query: str
type: Literal["knowledge_search_query_completed"] = "knowledge_search_query_completed"
type: str = "knowledge_search_query_completed"
retrieved_knowledge: str
@@ -35,13 +35,13 @@ class KnowledgeQueryStartedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query is started."""
task_prompt: str
type: Literal["knowledge_query_started"] = "knowledge_query_started"
type: str = "knowledge_query_started"
class KnowledgeQueryFailedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query fails."""
type: Literal["knowledge_query_failed"] = "knowledge_query_failed"
type: str = "knowledge_query_failed"
error: str
@@ -49,12 +49,12 @@ class KnowledgeQueryCompletedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query is completed."""
query: str
type: Literal["knowledge_query_completed"] = "knowledge_query_completed"
type: str = "knowledge_query_completed"
class KnowledgeSearchQueryFailedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge search query fails."""
query: str
type: Literal["knowledge_search_query_failed"] = "knowledge_search_query_failed"
type: str = "knowledge_search_query_failed"
error: str

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Literal
from typing import Any
from pydantic import BaseModel
@@ -42,7 +42,7 @@ class LLMCallStartedEvent(LLMEventBase):
multimodal content (text, images, etc.)
"""
type: Literal["llm_call_started"] = "llm_call_started"
type: str = "llm_call_started"
messages: str | list[dict[str, Any]] | None = None
tools: list[dict[str, Any]] | None = None
callbacks: list[Any] | None = None
@@ -52,7 +52,7 @@ class LLMCallStartedEvent(LLMEventBase):
class LLMCallCompletedEvent(LLMEventBase):
"""Event emitted when a LLM call completes"""
type: Literal["llm_call_completed"] = "llm_call_completed"
type: str = "llm_call_completed"
messages: str | list[dict[str, Any]] | None = None
response: Any
call_type: LLMCallType
@@ -62,7 +62,7 @@ class LLMCallFailedEvent(LLMEventBase):
"""Event emitted when a LLM call fails"""
error: str
type: Literal["llm_call_failed"] = "llm_call_failed"
type: str = "llm_call_failed"
class FunctionCall(BaseModel):
@@ -80,7 +80,7 @@ class ToolCall(BaseModel):
class LLMStreamChunkEvent(LLMEventBase):
"""Event emitted when a streaming chunk is received"""
type: Literal["llm_stream_chunk"] = "llm_stream_chunk"
type: str = "llm_stream_chunk"
chunk: str
tool_call: ToolCall | None = None
call_type: LLMCallType | None = None

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from inspect import getsource
from typing import Any, Literal
from typing import Any
from crewai.events.base_events import BaseEvent
@@ -27,7 +27,7 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
retry_count: The number of times the guardrail has been retried
"""
type: Literal["llm_guardrail_started"] = "llm_guardrail_started"
type: str = "llm_guardrail_started"
guardrail: str | Callable
retry_count: int
@@ -53,7 +53,7 @@ class LLMGuardrailCompletedEvent(LLMGuardrailBaseEvent):
retry_count: The number of times the guardrail has been retried
"""
type: Literal["llm_guardrail_completed"] = "llm_guardrail_completed"
type: str = "llm_guardrail_completed"
success: bool
result: Any
error: str | None = None
@@ -68,6 +68,6 @@ class LLMGuardrailFailedEvent(LLMGuardrailBaseEvent):
retry_count: The number of times the guardrail has been retried
"""
type: Literal["llm_guardrail_failed"] = "llm_guardrail_failed"
type: str = "llm_guardrail_failed"
error: str
retry_count: int

View File

@@ -1,6 +1,6 @@
"""Agent logging events that don't reference BaseAgent to avoid circular imports."""
from typing import Any, Literal
from typing import Any
from pydantic import ConfigDict
@@ -13,7 +13,7 @@ class AgentLogsStartedEvent(BaseEvent):
agent_role: str
task_description: str | None = None
verbose: bool = False
type: Literal["agent_logs_started"] = "agent_logs_started"
type: str = "agent_logs_started"
class AgentLogsExecutionEvent(BaseEvent):
@@ -22,6 +22,6 @@ class AgentLogsExecutionEvent(BaseEvent):
agent_role: str
formatted_answer: Any
verbose: bool = False
type: Literal["agent_logs_execution"] = "agent_logs_execution"
type: str = "agent_logs_execution"
model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Literal
from typing import Any
from crewai.events.base_events import BaseEvent
@@ -24,7 +24,7 @@ class MCPEvent(BaseEvent):
class MCPConnectionStartedEvent(MCPEvent):
"""Event emitted when starting to connect to an MCP server."""
type: Literal["mcp_connection_started"] = "mcp_connection_started"
type: str = "mcp_connection_started"
connect_timeout: int | None = None
is_reconnect: bool = (
False # True if this is a reconnection, False for first connection
@@ -34,7 +34,7 @@ class MCPConnectionStartedEvent(MCPEvent):
class MCPConnectionCompletedEvent(MCPEvent):
"""Event emitted when successfully connected to an MCP server."""
type: Literal["mcp_connection_completed"] = "mcp_connection_completed"
type: str = "mcp_connection_completed"
started_at: datetime | None = None
completed_at: datetime | None = None
connection_duration_ms: float | None = None
@@ -46,7 +46,7 @@ class MCPConnectionCompletedEvent(MCPEvent):
class MCPConnectionFailedEvent(MCPEvent):
"""Event emitted when connection to an MCP server fails."""
type: Literal["mcp_connection_failed"] = "mcp_connection_failed"
type: str = "mcp_connection_failed"
error: str
error_type: str | None = None # "timeout", "authentication", "network", etc.
started_at: datetime | None = None
@@ -56,7 +56,7 @@ class MCPConnectionFailedEvent(MCPEvent):
class MCPToolExecutionStartedEvent(MCPEvent):
"""Event emitted when starting to execute an MCP tool."""
type: Literal["mcp_tool_execution_started"] = "mcp_tool_execution_started"
type: str = "mcp_tool_execution_started"
tool_name: str
tool_args: dict[str, Any] | None = None
@@ -64,7 +64,7 @@ class MCPToolExecutionStartedEvent(MCPEvent):
class MCPToolExecutionCompletedEvent(MCPEvent):
"""Event emitted when MCP tool execution completes."""
type: Literal["mcp_tool_execution_completed"] = "mcp_tool_execution_completed"
type: str = "mcp_tool_execution_completed"
tool_name: str
tool_args: dict[str, Any] | None = None
result: Any | None = None
@@ -76,7 +76,7 @@ class MCPToolExecutionCompletedEvent(MCPEvent):
class MCPToolExecutionFailedEvent(MCPEvent):
"""Event emitted when MCP tool execution fails."""
type: Literal["mcp_tool_execution_failed"] = "mcp_tool_execution_failed"
type: str = "mcp_tool_execution_failed"
tool_name: str
tool_args: dict[str, Any] | None = None
error: str

View File

@@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any
from crewai.events.base_events import BaseEvent
@@ -23,7 +23,7 @@ class MemoryBaseEvent(BaseEvent):
class MemoryQueryStartedEvent(MemoryBaseEvent):
"""Event emitted when a memory query is started"""
type: Literal["memory_query_started"] = "memory_query_started"
type: str = "memory_query_started"
query: str
limit: int
score_threshold: float | None = None
@@ -32,7 +32,7 @@ class MemoryQueryStartedEvent(MemoryBaseEvent):
class MemoryQueryCompletedEvent(MemoryBaseEvent):
"""Event emitted when a memory query is completed successfully"""
type: Literal["memory_query_completed"] = "memory_query_completed"
type: str = "memory_query_completed"
query: str
results: Any
limit: int
@@ -43,7 +43,7 @@ class MemoryQueryCompletedEvent(MemoryBaseEvent):
class MemoryQueryFailedEvent(MemoryBaseEvent):
"""Event emitted when a memory query fails"""
type: Literal["memory_query_failed"] = "memory_query_failed"
type: str = "memory_query_failed"
query: str
limit: int
score_threshold: float | None = None
@@ -53,7 +53,7 @@ class MemoryQueryFailedEvent(MemoryBaseEvent):
class MemorySaveStartedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation is started"""
type: Literal["memory_save_started"] = "memory_save_started"
type: str = "memory_save_started"
value: str | None = None
metadata: dict[str, Any] | None = None
agent_role: str | None = None
@@ -62,7 +62,7 @@ class MemorySaveStartedEvent(MemoryBaseEvent):
class MemorySaveCompletedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation is completed successfully"""
type: Literal["memory_save_completed"] = "memory_save_completed"
type: str = "memory_save_completed"
value: str
metadata: dict[str, Any] | None = None
agent_role: str | None = None
@@ -72,7 +72,7 @@ class MemorySaveCompletedEvent(MemoryBaseEvent):
class MemorySaveFailedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation fails"""
type: Literal["memory_save_failed"] = "memory_save_failed"
type: str = "memory_save_failed"
value: str | None = None
metadata: dict[str, Any] | None = None
agent_role: str | None = None
@@ -82,14 +82,14 @@ class MemorySaveFailedEvent(MemoryBaseEvent):
class MemoryRetrievalStartedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt starts"""
type: Literal["memory_retrieval_started"] = "memory_retrieval_started"
type: str = "memory_retrieval_started"
task_id: str | None = None
class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt completes successfully"""
type: Literal["memory_retrieval_completed"] = "memory_retrieval_completed"
type: str = "memory_retrieval_completed"
task_id: str | None = None
memory_content: str
retrieval_time_ms: float
@@ -98,6 +98,6 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
class MemoryRetrievalFailedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt fails."""
type: Literal["memory_retrieval_failed"] = "memory_retrieval_failed"
type: str = "memory_retrieval_failed"
task_id: str | None = None
error: str

View File

@@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any
from crewai.events.base_events import BaseEvent
@@ -24,7 +24,7 @@ class ReasoningEvent(BaseEvent):
class AgentReasoningStartedEvent(ReasoningEvent):
"""Event emitted when an agent starts reasoning about a task."""
type: Literal["agent_reasoning_started"] = "agent_reasoning_started"
type: str = "agent_reasoning_started"
agent_role: str
task_id: str
@@ -32,7 +32,7 @@ class AgentReasoningStartedEvent(ReasoningEvent):
class AgentReasoningCompletedEvent(ReasoningEvent):
"""Event emitted when an agent finishes its reasoning process."""
type: Literal["agent_reasoning_completed"] = "agent_reasoning_completed"
type: str = "agent_reasoning_completed"
agent_role: str
task_id: str
plan: str
@@ -42,7 +42,7 @@ class AgentReasoningCompletedEvent(ReasoningEvent):
class AgentReasoningFailedEvent(ReasoningEvent):
"""Event emitted when the reasoning process fails."""
type: Literal["agent_reasoning_failed"] = "agent_reasoning_failed"
type: str = "agent_reasoning_failed"
agent_role: str
task_id: str
error: str

View File

@@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any
from crewai.events.base_events import BaseEvent
from crewai.tasks.task_output import TaskOutput
@@ -7,7 +7,7 @@ from crewai.tasks.task_output import TaskOutput
class TaskStartedEvent(BaseEvent):
"""Event emitted when a task starts"""
type: Literal["task_started"] = "task_started"
type: str = "task_started"
context: str | None
task: Any | None = None
@@ -28,7 +28,7 @@ class TaskCompletedEvent(BaseEvent):
"""Event emitted when a task completes"""
output: TaskOutput
type: Literal["task_completed"] = "task_completed"
type: str = "task_completed"
task: Any | None = None
def __init__(self, **data):
@@ -48,7 +48,7 @@ class TaskFailedEvent(BaseEvent):
"""Event emitted when a task fails"""
error: str
type: Literal["task_failed"] = "task_failed"
type: str = "task_failed"
task: Any | None = None
def __init__(self, **data):
@@ -67,7 +67,7 @@ class TaskFailedEvent(BaseEvent):
class TaskEvaluationEvent(BaseEvent):
"""Event emitted when a task evaluation is completed"""
type: Literal["task_evaluation"] = "task_evaluation"
type: str = "task_evaluation"
evaluation_type: str
task: Any | None = None

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from datetime import datetime
from typing import Any, Literal
from typing import Any
from pydantic import ConfigDict
@@ -55,7 +55,7 @@ class ToolUsageEvent(BaseEvent):
class ToolUsageStartedEvent(ToolUsageEvent):
"""Event emitted when a tool execution is started"""
type: Literal["tool_usage_started"] = "tool_usage_started"
type: str = "tool_usage_started"
class ToolUsageFinishedEvent(ToolUsageEvent):
@@ -65,35 +65,35 @@ class ToolUsageFinishedEvent(ToolUsageEvent):
finished_at: datetime
from_cache: bool = False
output: Any
type: Literal["tool_usage_finished"] = "tool_usage_finished"
type: str = "tool_usage_finished"
class ToolUsageErrorEvent(ToolUsageEvent):
"""Event emitted when a tool execution encounters an error"""
error: Any
type: Literal["tool_usage_error"] = "tool_usage_error"
type: str = "tool_usage_error"
class ToolValidateInputErrorEvent(ToolUsageEvent):
"""Event emitted when a tool input validation encounters an error"""
error: Any
type: Literal["tool_validate_input_error"] = "tool_validate_input_error"
type: str = "tool_validate_input_error"
class ToolSelectionErrorEvent(ToolUsageEvent):
"""Event emitted when a tool selection encounters an error"""
error: Any
type: Literal["tool_selection_error"] = "tool_selection_error"
type: str = "tool_selection_error"
class ToolExecutionErrorEvent(BaseEvent):
"""Event emitted when a tool execution encounters an error"""
error: Any
type: Literal["tool_execution_error"] = "tool_execution_error"
type: str = "tool_execution_error"
tool_name: str
tool_args: dict[str, Any]
tool_class: Callable

View File

@@ -341,6 +341,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
messages=list(self.state.messages),
llm=self.llm,
callbacks=self.callbacks,
verbose=self.agent.verbose,
)
self.state.current_answer = formatted_answer
@@ -366,6 +367,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
from_agent=self.agent,
response_model=None,
executor_context=self,
verbose=self.agent.verbose,
)
# Parse the LLM response
@@ -401,7 +403,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
return "context_error"
if e.__class__.__module__.startswith("litellm"):
raise e
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise
@listen("continue_reasoning_native")
@@ -436,6 +438,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
from_agent=self.agent,
response_model=None,
executor_context=self,
verbose=self.agent.verbose,
)
# Check if the response is a list of tool calls
@@ -474,7 +477,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
return "context_error"
if e.__class__.__module__.startswith("litellm"):
raise e
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise
@router(call_llm_and_parse)
@@ -670,10 +673,10 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
track_delegation_if_needed(func_name, args_dict, self.task)
structured_tool = None
for tool in self.tools or []:
if sanitize_tool_name(tool.name) == func_name:
structured_tool = tool
structured_tool: CrewStructuredTool | None = None
for structured in self.tools or []:
if sanitize_tool_name(structured.name) == func_name:
structured_tool = structured
break
hook_blocked = False
@@ -693,10 +696,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
hook_blocked = True
break
except Exception as hook_error:
self._printer.print(
content=f"Error in before_tool_call hook: {hook_error}",
color="red",
)
if self.agent.verbose:
self._printer.print(
content=f"Error in before_tool_call hook: {hook_error}",
color="red",
)
if hook_blocked:
result = f"Tool execution blocked by hook. Tool: {func_name}"
@@ -758,15 +762,16 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
after_hooks = get_after_tool_call_hooks()
try:
for after_hook in after_hooks:
hook_result = after_hook(after_hook_context)
if hook_result is not None:
result = hook_result
after_hook_result = after_hook(after_hook_context)
if after_hook_result is not None:
result = after_hook_result
after_hook_context.tool_result = result
except Exception as hook_error:
self._printer.print(
content=f"Error in after_tool_call hook: {hook_error}",
color="red",
)
if self.agent.verbose:
self._printer.print(
content=f"Error in after_tool_call hook: {hook_error}",
color="red",
)
# Emit tool usage finished event
crewai_event_bus.emit(
@@ -814,15 +819,6 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
self.state.is_finished = True
return "tool_result_is_final"
# Add reflection prompt once after all tools in the batch
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
"role": "user",
"content": reasoning_prompt,
}
self.state.messages.append(reasoning_message)
return "native_tool_completed"
def _extract_tool_name(self, tool_call: Any) -> str:
@@ -911,6 +907,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
iterations=self.state.iterations,
log_error_after=self.log_error_after,
printer=self._printer,
verbose=self.agent.verbose,
)
if formatted_answer:
@@ -930,6 +927,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
verbose=self.agent.verbose,
)
self.state.iterations += 1
@@ -1021,7 +1019,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
self._console.print(fail_text)
raise
except Exception as e:
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise
finally:
self._is_executing = False
@@ -1106,7 +1104,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
self._console.print(fail_text)
raise
except Exception as e:
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.agent.verbose)
raise
finally:
self._is_executing = False

View File

@@ -7,7 +7,7 @@ for building event-driven workflows with conditional execution and routing.
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Callable, Sequence
from concurrent.futures import Future
import copy
import inspect
@@ -2382,7 +2382,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
message: str,
output: Any,
metadata: dict[str, Any] | None = None,
emit: list[str] | None = None,
emit: Sequence[str] | None = None,
) -> str:
"""Request feedback from a human.
Args:
@@ -2453,7 +2453,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
def _collapse_to_outcome(
self,
feedback: str,
outcomes: list[str],
outcomes: Sequence[str],
llm: str | BaseLLM,
) -> str:
"""Collapse free-form feedback to a predefined outcome using LLM.

View File

@@ -53,7 +53,7 @@ Example (asynchronous with custom provider):
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
@@ -128,7 +128,7 @@ class HumanFeedbackConfig:
"""
message: str
emit: list[str] | None = None
emit: Sequence[str] | None = None
llm: str | BaseLLM | None = None
default_outcome: str | None = None
metadata: dict[str, Any] | None = None
@@ -154,7 +154,7 @@ class HumanFeedbackMethod(FlowMethod[Any, Any]):
def human_feedback(
message: str,
emit: list[str] | None = None,
emit: Sequence[str] | None = None,
llm: str | BaseLLM | None = None,
default_outcome: str | None = None,
metadata: dict[str, Any] | None = None,
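
Assuming human_feedback is applied as a flow-method decorator (the HumanFeedbackMethod wrapper above points that way), a call site might look like the following; the model id and outcome labels are illustrative only:

@human_feedback(
    message="Approve the draft?",
    emit=("approved", "rejected"),  # any Sequence[str] now works
    llm="gpt-4o-mini",
    default_outcome="rejected",
)
def review_draft(self):
    ...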

View File

@@ -118,17 +118,20 @@ class PersistenceDecorator:
)
except Exception as e:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red")
if verbose:
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise RuntimeError(f"State persistence failed: {e!s}") from e
except AttributeError as e:
error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red")
if verbose:
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise ValueError(error_msg) from e
except (TypeError, ValueError) as e:
error_msg = LOG_MESSAGES["id_missing"]
cls._printer.print(error_msg, color="red")
if verbose:
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise ValueError(error_msg) from e

View File

@@ -151,7 +151,9 @@ def _unwrap_function(function: Any) -> Any:
return function
def get_possible_return_constants(function: Any) -> list[str] | None:
def get_possible_return_constants(
function: Any, verbose: bool = True
) -> list[str] | None:
"""Extract possible string return values from a function using AST parsing.
This function analyzes the source code of a router method to identify
@@ -178,10 +180,11 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
# Can't get source code
return None
except Exception as e:
_printer.print(
f"Error retrieving source code for function {function.__name__}: {e}",
color="red",
)
if verbose:
_printer.print(
f"Error retrieving source code for function {function.__name__}: {e}",
color="red",
)
return None
try:
@@ -190,25 +193,28 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
# Parse the source code into an AST
code_ast = ast.parse(source)
except IndentationError as e:
_printer.print(
f"IndentationError while parsing source code of {function.__name__}: {e}",
color="red",
)
_printer.print(f"Source code:\n{source}", color="yellow")
if verbose:
_printer.print(
f"IndentationError while parsing source code of {function.__name__}: {e}",
color="red",
)
_printer.print(f"Source code:\n{source}", color="yellow")
return None
except SyntaxError as e:
_printer.print(
f"SyntaxError while parsing source code of {function.__name__}: {e}",
color="red",
)
_printer.print(f"Source code:\n{source}", color="yellow")
if verbose:
_printer.print(
f"SyntaxError while parsing source code of {function.__name__}: {e}",
color="red",
)
_printer.print(f"Source code:\n{source}", color="yellow")
return None
except Exception as e:
_printer.print(
f"Unexpected error while parsing source code of {function.__name__}: {e}",
color="red",
)
_printer.print(f"Source code:\n{source}", color="yellow")
if verbose:
_printer.print(
f"Unexpected error while parsing source code of {function.__name__}: {e}",
color="red",
)
_printer.print(f"Source code:\n{source}", color="yellow")
return None
return_values: set[str] = set()
@@ -388,15 +394,17 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
StateAttributeVisitor().visit(class_ast)
except Exception as e:
_printer.print(
f"Could not analyze class context for {function.__name__}: {e}",
color="yellow",
)
if verbose:
_printer.print(
f"Could not analyze class context for {function.__name__}: {e}",
color="yellow",
)
except Exception as e:
_printer.print(
f"Could not introspect class for {function.__name__}: {e}",
color="yellow",
)
if verbose:
_printer.print(
f"Could not introspect class for {function.__name__}: {e}",
color="yellow",
)
VariableAssignmentVisitor().visit(code_ast)
ReturnVisitor().visit(code_ast)
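
Stripped to its core, the helper walks the router's AST for literal string returns; the visitor classes above extend this to variables and state attributes. A self-contained sketch of just the literal case:

import ast
import inspect
import textwrap

def string_returns(fn) -> list[str] | None:
    # collect every literal string returned anywhere in fn's body
    source = textwrap.dedent(inspect.getsource(fn))
    found = {
        node.value.value
        for node in ast.walk(ast.parse(source))
        if isinstance(node, ast.Return)
        and isinstance(node.value, ast.Constant)
        and isinstance(node.value.value, str)
    }
    return sorted(found) or None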

View File

@@ -2,8 +2,10 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import wraps
import inspect
import json
from types import MethodType
from typing import (
TYPE_CHECKING,
Any,
@@ -30,6 +32,8 @@ from typing_extensions import Self
if TYPE_CHECKING:
from crewai_files import FileInput
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
@@ -72,18 +76,92 @@ from crewai.utilities.agent_utils import (
from crewai.utilities.converter import (
Converter,
ConverterError,
generate_model_description,
)
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.guardrail_types import GuardrailCallable, GuardrailType
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.types import LLMMessage
def _kickoff_with_a2a_support(
agent: LiteAgent,
original_kickoff: Callable[..., LiteAgentOutput],
messages: str | list[LLMMessage],
response_format: type[BaseModel] | None,
input_files: dict[str, FileInput] | None,
extension_registry: Any,
) -> LiteAgentOutput:
"""Wrap kickoff with A2A delegation using Task adapter.
Args:
agent: The LiteAgent instance.
original_kickoff: The original kickoff method.
messages: Input messages.
response_format: Optional response format.
input_files: Optional input files.
extension_registry: A2A extension registry.
Returns:
LiteAgentOutput from either local execution or A2A delegation.
"""
from crewai.a2a.utils.response_model import get_a2a_agents_and_response_model
from crewai.a2a.wrapper import _execute_task_with_a2a
from crewai.task import Task
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(agent.a2a)
if not a2a_agents:
return original_kickoff(messages, response_format, input_files)
if isinstance(messages, str):
description = messages
else:
content = next(
(m["content"] for m in reversed(messages) if m["role"] == "user"),
None,
)
description = content if isinstance(content, str) else ""
if not description:
return original_kickoff(messages, response_format, input_files)
fake_task = Task(
description=description,
agent=agent,
expected_output="Result from A2A delegation",
)
def task_to_kickoff_adapter(
self: Any, task: Task, context: str | None, tools: list[Any] | None
) -> str:
result = original_kickoff(messages, response_format, input_files)
return result.raw
result_str = _execute_task_with_a2a(
self=agent, # type: ignore[arg-type]
a2a_agents=a2a_agents,
original_fn=task_to_kickoff_adapter,
task=fake_task,
agent_response_model=agent_response_model,
context=None,
tools=None,
extension_registry=extension_registry,
)
return LiteAgentOutput(
raw=result_str,
pydantic=None,
agent_role=agent.role,
usage_metrics=None,
messages=[],
)
class LiteAgent(FlowTrackable, BaseModel):
"""
A lightweight agent that can process messages and use tools.
@@ -154,6 +232,17 @@ class LiteAgent(FlowTrackable, BaseModel):
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
a2a: (
list[A2AConfig | A2AServerConfig | A2AClientConfig]
| A2AConfig
| A2AServerConfig
| A2AClientConfig
| None
) = Field(
default=None,
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. "
"Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of configurations.",
)
tools_results: list[dict[str, Any]] = Field(
default_factory=list, description="Results of the tools used by the agent."
)
@@ -209,6 +298,52 @@ class LiteAgent(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def setup_a2a_support(self) -> Self:
"""Setup A2A extensions and server methods if a2a config exists."""
if self.a2a:
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.registry import (
create_extension_registry_from_config,
)
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
configs = self.a2a if isinstance(self.a2a, list) else [self.a2a]
client_configs = [
config
for config in configs
if isinstance(config, (A2AConfig, A2AClientConfig))
]
extension_registry = (
create_extension_registry_from_config(client_configs)
if client_configs
else create_extension_registry_from_config([])
)
extension_registry.inject_all_tools(self) # type: ignore[arg-type]
inject_a2a_server_methods(self) # type: ignore[arg-type]
original_kickoff = self.kickoff
@wraps(original_kickoff)
def kickoff_with_a2a(
messages: str | list[LLMMessage],
response_format: type[BaseModel] | None = None,
input_files: dict[str, FileInput] | None = None,
) -> LiteAgentOutput:
return _kickoff_with_a2a_support(
self,
original_kickoff,
messages,
response_format,
input_files,
extension_registry,
)
object.__setattr__(self, "kickoff", MethodType(kickoff_with_a2a, self))
return self
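
The object.__setattr__ plus MethodType pair above is the standard way to rebind a method on a pydantic model instance, since BaseModel.__setattr__ rejects non-field attributes. A standalone illustration with a stand-in model:

from types import MethodType
from pydantic import BaseModel

class Demo(BaseModel):  # stand-in, not the real LiteAgent
    role: str = "demo"

    def kickoff(self) -> str:
        return "local run"

def kickoff_wrapped(self: Demo) -> str:
    return "a2a-aware: " + Demo.kickoff(self)

agent = Demo()
# plain attribute assignment would raise, so bypass pydantic's __setattr__
object.__setattr__(agent, "kickoff", MethodType(kickoff_wrapped, agent))
assert agent.kickoff() == "a2a-aware: local run"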
@model_validator(mode="after")
def ensure_guardrail_is_callable(self) -> Self:
if callable(self.guardrail):
@@ -344,11 +479,12 @@ class LiteAgent(FlowTrackable, BaseModel):
)
except Exception as e:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
handle_unknown_error(self._printer, e)
if self.verbose:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
handle_unknown_error(self._printer, e, verbose=self.verbose)
# Emit error event
crewai_event_bus.emit(
self,
@@ -396,10 +532,11 @@ class LiteAgent(FlowTrackable, BaseModel):
if isinstance(result, BaseModel):
formatted_result = result
except ConverterError as e:
self._printer.print(
content=f"Failed to parse output into response format after retries: {e.message}",
color="yellow",
)
if self.verbose:
self._printer.print(
content=f"Failed to parse output into response format after retries: {e.message}",
color="yellow",
)
# Calculate token usage metrics
if isinstance(self.llm, BaseLLM):
@@ -605,6 +742,7 @@ class LiteAgent(FlowTrackable, BaseModel):
messages=self._messages,
llm=cast(LLM, self.llm),
callbacks=self._callbacks,
verbose=self.verbose,
)
enforce_rpm_limit(self.request_within_rpm_limit)
@@ -617,12 +755,15 @@ class LiteAgent(FlowTrackable, BaseModel):
printer=self._printer,
from_agent=self,
executor_context=self,
verbose=self.verbose,
)
except Exception as e:
raise e
formatted_answer = process_llm_response(answer, self.use_stop_words)
formatted_answer = process_llm_response(
cast(str, answer), self.use_stop_words
)
if isinstance(formatted_answer, AgentAction):
try:
@@ -646,16 +787,18 @@ class LiteAgent(FlowTrackable, BaseModel):
self._append_message(formatted_answer.text, role="assistant")
except OutputParserError as e: # noqa: PERF203
self._printer.print(
content="Failed to parse LLM output. Retrying...",
color="yellow",
)
if self.verbose:
self._printer.print(
content="Failed to parse LLM output. Retrying...",
color="yellow",
)
formatted_answer = handle_output_parser_exception(
e=e,
messages=self._messages,
iterations=self._iterations,
log_error_after=3,
printer=self._printer,
verbose=self.verbose,
)
except Exception as e:
@@ -670,9 +813,10 @@ class LiteAgent(FlowTrackable, BaseModel):
llm=cast(LLM, self.llm),
callbacks=self._callbacks,
i18n=self.i18n,
verbose=self.verbose,
)
continue
handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e, verbose=self.verbose)
raise e
finally:
@@ -702,3 +846,21 @@ class LiteAgent(FlowTrackable, BaseModel):
) -> None:
"""Append a message to the message list with the given role."""
self._messages.append(format_message_for_llm(text, role=role))
try:
from crewai.a2a.config import (
A2AClientConfig as _A2AClientConfig,
A2AConfig as _A2AConfig,
A2AServerConfig as _A2AServerConfig,
)
LiteAgent.model_rebuild(
_types_namespace={
"A2AConfig": _A2AConfig,
"A2AClientConfig": _A2AClientConfig,
"A2AServerConfig": _A2AServerConfig,
}
)
except ImportError:
pass

View File

@@ -497,7 +497,7 @@ class BaseLLM(ABC):
from_agent=from_agent,
)
return result
return str(result) if not isinstance(result, str) else result
except Exception as e:
error_msg = f"Error executing function '{function_name}': {e!s}"
@@ -620,11 +620,13 @@ class BaseLLM(ABC):
try:
# Try to parse as JSON first
if response.strip().startswith("{") or response.strip().startswith("["):
return response_format.model_validate_json(response)
data = json.loads(response)
return response_format.model_validate(data)
json_match = _JSON_EXTRACTION_PATTERN.search(response)
if json_match:
return response_format.model_validate_json(json_match.group())
data = json.loads(json_match.group())
return response_format.model_validate(data)
raise ValueError("No JSON found in response")
@@ -735,22 +737,25 @@ class BaseLLM(ABC):
task=None,
crew=None,
)
verbose = getattr(from_agent, "verbose", True) if from_agent else True
printer = Printer()
try:
for hook in before_hooks:
result = hook(hook_context)
if result is False:
printer.print(
content="LLM call blocked by before_llm_call hook",
color="yellow",
)
if verbose:
printer.print(
content="LLM call blocked by before_llm_call hook",
color="yellow",
)
return False
except Exception as e:
printer.print(
content=f"Error in before_llm_call hook: {e}",
color="yellow",
)
if verbose:
printer.print(
content=f"Error in before_llm_call hook: {e}",
color="yellow",
)
return True
@@ -803,6 +808,7 @@ class BaseLLM(ABC):
crew=None,
response=response,
)
verbose = getattr(from_agent, "verbose", True) if from_agent else True
printer = Printer()
modified_response = response
@@ -813,9 +819,10 @@ class BaseLLM(ABC):
modified_response = result
hook_context.response = modified_response
except Exception as e:
printer.print(
content=f"Error in after_llm_call hook: {e}",
color="yellow",
)
if verbose:
printer.print(
content=f"Error in after_llm_call hook: {e}",
color="yellow",
)
return modified_response

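Reviewer note: swapping `model_validate_json` for an explicit `json.loads` plus `model_validate` pair separates JSON syntax failures from schema-validation failures. A condensed, runnable sketch of the fallback chain, with a stand-in for the module's `_JSON_EXTRACTION_PATTERN` (the real regex is an assumption here):

import json
import re
from pydantic import BaseModel

_JSON_PATTERN = re.compile(r"\{.*\}", re.DOTALL)  # assumed stand-in pattern

class Answer(BaseModel):
    city: str

def parse_structured(response: str) -> Answer:
    text = response.strip()
    if text.startswith(("{", "[")):
        data = json.loads(text)             # raises JSONDecodeError on bad syntax
        return Answer.model_validate(data)  # raises ValidationError on bad shape
    match = _JSON_PATTERN.search(response)
    if match:
        return Answer.model_validate(json.loads(match.group()))
    raise ValueError("No JSON found in response")

print(parse_structured('Sure! {"city": "Paris"}'))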
View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
try:
from anthropic import Anthropic, AsyncAnthropic, transform_schema
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
from anthropic.types.beta import BetaMessage
from anthropic.types.beta import BetaMessage, BetaTextBlock
import httpx
except ImportError:
raise ImportError(
@@ -337,6 +337,7 @@ class AnthropicCompletion(BaseLLM):
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Optional response model.
Returns:
Chat completion response or tool call result
@@ -677,31 +678,31 @@ class AnthropicCompletion(BaseLLM):
if _is_pydantic_model_class(response_model) and response.content:
if use_native_structured_output:
for block in response.content:
if isinstance(block, TextBlock):
structured_json = block.text
if isinstance(block, (TextBlock, BetaTextBlock)):
structured_data = response_model.model_validate_json(block.text)
self._emit_call_completed_event(
response=structured_json,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return structured_data
else:
for block in response.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
structured_data = response_model.model_validate(block.input)
self._emit_call_completed_event(
response=structured_json,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return structured_data
# Check if Claude wants to use tools
if response.content:
@@ -897,28 +898,29 @@ class AnthropicCompletion(BaseLLM):
if _is_pydantic_model_class(response_model):
if use_native_structured_output:
structured_data = response_model.model_validate_json(full_response)
self._emit_call_completed_event(
response=full_response,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return full_response
return structured_data
for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
structured_data = response_model.model_validate(block.input)
self._emit_call_completed_event(
response=structured_json,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return structured_data
if final_message.content:
tool_uses = [
@@ -1166,31 +1168,31 @@ class AnthropicCompletion(BaseLLM):
if _is_pydantic_model_class(response_model) and response.content:
if use_native_structured_output:
for block in response.content:
if isinstance(block, TextBlock):
structured_json = block.text
if isinstance(block, (TextBlock, BetaTextBlock)):
structured_data = response_model.model_validate_json(block.text)
self._emit_call_completed_event(
response=structured_json,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return structured_data
else:
for block in response.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
structured_data = response_model.model_validate(block.input)
self._emit_call_completed_event(
response=structured_json,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return structured_data
if response.content:
tool_uses = [
@@ -1362,28 +1364,29 @@ class AnthropicCompletion(BaseLLM):
if _is_pydantic_model_class(response_model):
if use_native_structured_output:
structured_data = response_model.model_validate_json(full_response)
self._emit_call_completed_event(
response=full_response,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return full_response
return structured_data
for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
structured_data = response_model.model_validate(block.input)
self._emit_call_completed_event(
response=structured_json,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return structured_data
if final_message.content:
tool_uses = [

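Reviewer note: every structured-output branch above now validates into the Pydantic model and returns the instance, serializing with `model_dump_json()` only for the event payload. A sketch of the tool-use branch, with a plain dict standing in for the Anthropic block's `input`:

from pydantic import BaseModel

class Greeting(BaseModel):
    greeting: str
    language: str

# Stand-in for the `input` payload of a ToolUseBlock named "structured_output".
tool_use_input = {"greeting": "Bonjour", "language": "French"}

structured_data = Greeting.model_validate(tool_use_input)  # was: json.dumps(block.input)
event_payload = structured_data.model_dump_json()          # events still get JSON text
assert isinstance(structured_data, Greeting)
print(event_payload)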
View File

@@ -557,7 +557,7 @@ class AzureCompletion(BaseLLM):
params: AzureCompletionParams,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
) -> BaseModel:
"""Validate content against response model and emit completion event.
Args:
@@ -568,24 +568,23 @@ class AzureCompletion(BaseLLM):
from_agent: Agent that initiated the call
Returns:
Validated and serialized JSON string
Validated Pydantic model instance
Raises:
ValueError: If validation fails
"""
try:
structured_data = response_model.model_validate_json(content)
structured_json = structured_data.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return structured_data
except Exception as e:
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from contextlib import AsyncExitStack
import json
import logging
@@ -538,7 +538,7 @@ class BedrockCompletion(BaseLLM):
self,
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: dict[str, Any] | None = None,
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
@@ -1009,7 +1009,7 @@ class BedrockCompletion(BaseLLM):
self,
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: dict[str, Any] | None = None,
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,

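Reviewer note: widening `dict[str, Any]` to `Mapping[str, Any]` lets callers pass read-only mappings without a cast, since `Mapping` promises no mutation. A quick illustration:

from collections.abc import Mapping
from types import MappingProxyType
from typing import Any

def call_tool(available_functions: Mapping[str, Any] | None = None) -> int:
    return len(available_functions or {})

funcs: dict[str, Any] = {"search": print}
print(call_tool(funcs))                    # plain dict still works
print(call_tool(MappingProxyType(funcs)))  # read-only view now type-checks too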
View File

@@ -132,6 +132,9 @@ class GeminiCompletion(BaseLLM):
self.supports_tools = bool(
version_match and float(version_match.group(1)) >= 1.5
)
self.is_gemini_2_0 = bool(
version_match and float(version_match.group(1)) >= 2.0
)
@property
def stop(self) -> list[str]:
@@ -439,6 +442,11 @@ class GeminiCompletion(BaseLLM):
Returns:
GenerateContentConfig object for Gemini API
Note:
Structured output support varies by model version:
- Gemini 1.5 and earlier: Uses response_schema (Pydantic model)
- Gemini 2.0+: Uses response_json_schema (JSON Schema) with propertyOrdering
"""
self.tools = tools
config_params: dict[str, Any] = {}
@@ -466,9 +474,13 @@ class GeminiCompletion(BaseLLM):
if response_model:
config_params["response_mime_type"] = "application/json"
schema_output = generate_model_description(response_model)
config_params["response_schema"] = schema_output.get("json_schema", {}).get(
"schema", {}
)
schema = schema_output.get("json_schema", {}).get("schema", {})
if self.is_gemini_2_0:
schema = self._add_property_ordering(schema)
config_params["response_json_schema"] = schema
else:
config_params["response_schema"] = response_model
# Handle tools for supported models
if tools and self.supports_tools:
@@ -632,7 +644,7 @@ class GeminiCompletion(BaseLLM):
messages_for_event: list[LLMMessage],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
) -> BaseModel:
"""Validate content against response model and emit completion event.
Args:
@@ -643,24 +655,23 @@ class GeminiCompletion(BaseLLM):
from_agent: Agent that initiated the call
Returns:
Validated and serialized JSON string
Validated Pydantic model instance
Raises:
ValueError: If validation fails
"""
try:
structured_data = response_model.model_validate_json(content)
structured_json = structured_data.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
response=structured_data.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return structured_json
return structured_data
except Exception as e:
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)
@@ -673,7 +684,7 @@ class GeminiCompletion(BaseLLM):
response_model: type[BaseModel] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
) -> str | BaseModel:
"""Finalize completion response with validation and event emission.
Args:
@@ -684,7 +695,7 @@ class GeminiCompletion(BaseLLM):
from_agent: Agent that initiated the call
Returns:
Final response content after processing
Final response content after processing (str or Pydantic model if response_model provided)
"""
messages_for_event = self._convert_contents_to_dict(contents)
@@ -870,7 +881,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | list[dict[str, Any]]:
) -> str | BaseModel | list[dict[str, Any]]:
"""Finalize streaming response with usage tracking, function execution, and events.
Args:
@@ -990,7 +1001,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
) -> str | BaseModel | list[dict[str, Any]] | Any:
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
@@ -1190,6 +1201,36 @@ class GeminiCompletion(BaseLLM):
return "".join(text_parts)
@staticmethod
def _add_property_ordering(schema: dict[str, Any]) -> dict[str, Any]:
"""Add propertyOrdering to JSON schema for Gemini 2.0 compatibility.
Gemini 2.0 models require an explicit propertyOrdering list to define
the preferred structure of JSON objects. This recursively adds
propertyOrdering to all objects in the schema.
Args:
schema: JSON schema dictionary.
Returns:
Modified schema with propertyOrdering added to all objects.
"""
if isinstance(schema, dict):
if schema.get("type") == "object" and "properties" in schema:
properties = schema["properties"]
if properties and "propertyOrdering" not in schema:
schema["propertyOrdering"] = list(properties.keys())
for value in schema.values():
if isinstance(value, dict):
GeminiCompletion._add_property_ordering(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
GeminiCompletion._add_property_ordering(item)
return schema
@staticmethod
def _convert_contents_to_dict(
contents: list[types.Content],

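Reviewer note: structured output is now branched on model version — 2.0+ gets a plain JSON Schema with recursive `propertyOrdering`, older models keep the Pydantic class as `response_schema`. A condensed, runnable sketch of the branch (the config dict mimics the `GenerateContentConfig` fields used above):

from typing import Any
from pydantic import BaseModel

class Person(BaseModel):
    name: str
    age: int

def add_property_ordering(schema: dict[str, Any]) -> dict[str, Any]:
    # Recursively pin key order; Gemini 2.0 expects an explicit propertyOrdering.
    if schema.get("type") == "object" and "properties" in schema:
        schema.setdefault("propertyOrdering", list(schema["properties"].keys()))
    for value in schema.values():
        if isinstance(value, dict):
            add_property_ordering(value)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    add_property_ordering(item)
    return schema

is_gemini_2_0 = True
config: dict[str, Any] = {"response_mime_type": "application/json"}
if is_gemini_2_0:
    config["response_json_schema"] = add_property_ordering(Person.model_json_schema())
else:
    config["response_schema"] = Person

print(config["response_json_schema"]["propertyOrdering"])  # ['name', 'age']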
View File

@@ -1570,15 +1570,14 @@ class OpenAICompletion(BaseLLM):
parsed_object = parsed_response.choices[0].message.parsed
if parsed_object:
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return parsed_object
response: ChatCompletion = self.client.chat.completions.create(**params)
@@ -1692,7 +1691,7 @@ class OpenAICompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | BaseModel:
"""Handle streaming chat completion."""
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
@@ -1728,15 +1727,14 @@ class OpenAICompletion(BaseLLM):
if final_completion.choices:
parsed_result = final_completion.choices[0].message.parsed
if parsed_result:
structured_json = parsed_result.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
response=parsed_result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return parsed_result
logging.error("Failed to get parsed result from stream")
return ""
@@ -1887,15 +1885,14 @@ class OpenAICompletion(BaseLLM):
parsed_object = parsed_response.choices[0].message.parsed
if parsed_object:
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return parsed_object
response: ChatCompletion = await self.async_client.chat.completions.create(
**params
@@ -2006,7 +2003,7 @@ class OpenAICompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | BaseModel:
"""Handle async streaming chat completion."""
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
@@ -2044,17 +2041,16 @@ class OpenAICompletion(BaseLLM):
try:
parsed_object = response_model.model_validate_json(accumulated_content)
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return parsed_object
except Exception as e:
logging.error(f"Failed to parse structured output from stream: {e}")
self._emit_call_completed_event(

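Reviewer note: with a `response_model`, `call()` now hands back the parsed instance, so callers drop their own `model_validate_json` step (see the test updates further down). A hedged call-site sketch with a fake LLM standing in for the provider client:

from pydantic import BaseModel

class TestResponse(BaseModel):
    answer: str
    confidence: float

class FakeLLM:
    """Stand-in that mimics the new contract of the completion classes."""
    def call(self, prompt: str, response_model: type[BaseModel] | None = None):
        raw = '{"answer": "test", "confidence": 0.95}'
        if response_model is not None:
            return response_model.model_validate_json(raw)  # parsed instance, not str
        return raw

result = FakeLLM().call("Test question", response_model=TestResponse)
assert isinstance(result, TestResponse) and result.confidence == 0.95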
View File

@@ -12,15 +12,17 @@ from crewai.utilities.paths import db_storage_path
class LTMSQLiteStorage:
"""SQLite storage class for long-term memory data."""
def __init__(self, db_path: str | None = None) -> None:
def __init__(self, db_path: str | None = None, verbose: bool = True) -> None:
"""Initialize the SQLite storage.
Args:
db_path: Optional path to the database file.
verbose: Whether to print error messages.
"""
if db_path is None:
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
self.db_path = db_path
self._verbose = verbose
self._printer: Printer = Printer()
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._initialize_db()
@@ -44,10 +46,11 @@ class LTMSQLiteStorage:
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred during database initialization: {e}",
color="red",
)
if self._verbose:
self._printer.print(
content=f"MEMORY ERROR: An error occurred during database initialization: {e}",
color="red",
)
def save(
self,
@@ -69,10 +72,11 @@ class LTMSQLiteStorage:
)
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
color="red",
)
if self._verbose:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
color="red",
)
def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
"""Queries the LTM table by task description with error handling."""
@@ -101,10 +105,11 @@ class LTMSQLiteStorage:
]
except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
color="red",
)
if self._verbose:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
color="red",
)
return None
def reset(self) -> None:
@@ -116,10 +121,11 @@ class LTMSQLiteStorage:
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
if self._verbose:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
async def asave(
self,
@@ -147,10 +153,11 @@ class LTMSQLiteStorage:
)
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
color="red",
)
if self._verbose:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
color="red",
)
async def aload(
self, task_description: str, latest_n: int
@@ -187,10 +194,11 @@ class LTMSQLiteStorage:
for row in rows
]
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
color="red",
)
if self._verbose:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
color="red",
)
return None
async def areset(self) -> None:
@@ -200,7 +208,8 @@ class LTMSQLiteStorage:
await conn.execute("DELETE FROM long_term_memories")
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
if self._verbose:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)

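Reviewer note: the storage layer now takes a `verbose` flag so embedded use can run silently; the error paths behave the same, they just skip the printer. Usage follows from the new constructor signature above (import path assumed from the repository layout):

from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage

# verbose=False keeps sqlite error reporting off stdout; behavior is otherwise unchanged.
storage = LTMSQLiteStorage(db_path="/tmp/ltm.db", verbose=False)
print(storage.load("summarize the report", latest_n=3))  # None when nothing is stored
storage.reset()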
View File

@@ -1,5 +1,6 @@
"""Type definitions specific to ChromaDB implementation."""
from collections.abc import Mapping
from typing import Any, NamedTuple
from chromadb.api import AsyncClientAPI, ClientAPI
@@ -48,7 +49,7 @@ class PreparedDocuments(NamedTuple):
ids: list[str]
texts: list[str]
metadatas: list[dict[str, str | int | float | bool]]
metadatas: list[Mapping[str, str | int | float | bool]]
class ExtractedSearchParams(NamedTuple):

View File

@@ -1,5 +1,6 @@
"""Utility functions for ChromaDB client implementation."""
from collections.abc import Mapping
import hashlib
import json
from typing import Literal, TypeGuard, cast
@@ -65,7 +66,7 @@ def _prepare_documents_for_chromadb(
"""
ids: list[str] = []
texts: list[str] = []
metadatas: list[dict[str, str | int | float | bool]] = []
metadatas: list[Mapping[str, str | int | float | bool]] = []
seen_ids: dict[str, int] = {}
try:
@@ -110,7 +111,7 @@ def _prepare_documents_for_chromadb(
def _create_batch_slice(
prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[dict[str, str | int | float | bool]] | None]:
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
"""Create a batch slice from prepared documents.
Args:

View File

@@ -1,6 +1,6 @@
"""IBM WatsonX embedding function implementation."""
from typing import cast
from typing import Any, cast
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack
@@ -15,14 +15,18 @@ _printer = Printer()
class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for IBM WatsonX models."""
def __init__(self, **kwargs: Unpack[WatsonXProviderConfig]) -> None:
def __init__(
self, *, verbose: bool = True, **kwargs: Unpack[WatsonXProviderConfig]
) -> None:
"""Initialize WatsonX embedding function.
Args:
verbose: Whether to print error messages.
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
"""
super().__init__(**kwargs)
self._config = kwargs
self._verbose = verbose
@staticmethod
def name() -> str:
@@ -56,7 +60,7 @@ class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
if isinstance(input, str):
input = [input]
embeddings_config: dict = {
embeddings_config: dict[str, Any] = {
"model_id": self._config["model_id"],
}
if "params" in self._config and self._config["params"] is not None:
@@ -90,7 +94,7 @@ class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
if "credentials" in self._config and self._config["credentials"] is not None:
embeddings_config["credentials"] = self._config["credentials"]
else:
cred_config: dict = {}
cred_config: dict[str, Any] = {}
if "url" in self._config and self._config["url"] is not None:
cred_config["url"] = self._config["url"]
if "api_key" in self._config and self._config["api_key"] is not None:
@@ -159,5 +163,6 @@ class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
_printer.print(f"Error during WatsonX embedding: {e}", color="red")
if self._verbose:
_printer.print(f"Error during WatsonX embedding: {e}", color="red")
raise

View File

@@ -1,8 +1,6 @@
"""Type definitions for the embeddings module."""
from typing import Annotated, Any, Literal, TypeAlias
from pydantic import Field
from typing import Any, Literal, TypeAlias
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
@@ -31,7 +29,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
ProviderSpec: TypeAlias = Annotated[
ProviderSpec: TypeAlias = (
AzureProviderSpec
| BedrockProviderSpec
| CohereProviderSpec
@@ -49,9 +47,8 @@ ProviderSpec: TypeAlias = Annotated[
| Text2VecProviderSpec
| VertexAIProviderSpec
| VoyageAIProviderSpec
| WatsonXProviderSpec,
Field(discriminator="provider"),
]
| WatsonXProviderSpec
)
AllowedEmbeddingProviders = Literal[
"azure",

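Reviewer note: dropping `Annotated[..., Field(discriminator="provider")]` turns `ProviderSpec` into a plain union, so Pydantic would try members in order instead of dispatching on the `provider` tag — harmless for a pure typing alias, a behavior change if anything validates against it. A minimal contrast with two hypothetical specs:

from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter

class AzureSpec(BaseModel):
    provider: Literal["azure"]
    deployment: str

class CohereSpec(BaseModel):
    provider: Literal["cohere"]
    model: str

Plain = Union[AzureSpec, CohereSpec]                        # members tried in order
Tagged = Annotated[Plain, Field(discriminator="provider")]  # dispatched on the tag

spec = {"provider": "cohere", "model": "embed-v3"}
print(type(TypeAdapter(Plain).validate_python(spec)).__name__)   # CohereSpec
print(type(TypeAdapter(Tagged).validate_python(spec)).__name__)  # CohereSpec, via the tag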
View File

@@ -1,6 +1,6 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable
from collections.abc import Callable, Mapping
from typing import Any, TypeAlias
from typing_extensions import Required, TypedDict
@@ -19,8 +19,8 @@ class BaseRecord(TypedDict, total=False):
doc_id: str
content: Required[str]
metadata: (
dict[str, str | int | float | bool]
| list[dict[str, str | int | float | bool]]
Mapping[str, str | int | float | bool]
| list[Mapping[str, str | int | float | bool]]
)

View File

@@ -767,10 +767,11 @@ class Task(BaseModel):
if files:
supported_types: list[str] = []
if self.agent.llm and self.agent.llm.supports_multimodal():
provider = getattr(self.agent.llm, "provider", None) or getattr(
self.agent.llm, "model", "openai"
provider: str = str(
getattr(self.agent.llm, "provider", None)
or getattr(self.agent.llm, "model", "openai")
)
api = getattr(self.agent.llm, "api", None)
api: str | None = getattr(self.agent.llm, "api", None)
supported_types = get_supported_content_types(provider, api)
def is_auto_injected(content_type: str) -> bool:
@@ -887,10 +888,11 @@ Follow these guidelines:
try:
crew_chat_messages = json.loads(crew_chat_messages_json)
except json.JSONDecodeError as e:
_printer.print(
f"An error occurred while parsing crew chat messages: {e}",
color="red",
)
if self.agent and self.agent.verbose:
_printer.print(
f"An error occurred while parsing crew chat messages: {e}",
color="red",
)
raise
conversation_history = "\n".join(
@@ -1132,11 +1134,12 @@ Follow these guidelines:
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
color="yellow",
)
if agent and agent.verbose:
printer = Printer()
printer.print(
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
color="yellow",
)
# Regenerate output from agent
result = agent.execute_task(
@@ -1229,11 +1232,12 @@ Follow these guidelines:
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
color="yellow",
)
if agent and agent.verbose:
printer = Printer()
printer.print(
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
color="yellow",
)
result = await agent.aexecute_task(
task=self,

View File

@@ -200,12 +200,9 @@ class CrewStructuredTool:
"""
if isinstance(raw_args, str):
try:
validated_args = self.args_schema.model_validate_json(raw_args)
return validated_args.model_dump()
raw_args = json.loads(raw_args)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
except Exception as e:
raise ValueError(f"Arguments validation failed: {e}") from e
try:
validated_args = self.args_schema.model_validate(raw_args)

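Reviewer note: string arguments are now decoded with `json.loads` first and only then validated, so a malformed payload raises the JSON error rather than a generic validation failure. A condensed sketch of the new flow:

import json
from pydantic import BaseModel

class Args(BaseModel):
    query: str

def parse_args(raw_args: str | dict) -> dict:
    if isinstance(raw_args, str):
        try:
            raw_args = json.loads(raw_args)  # syntax errors surface here, distinctly
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
    return Args.model_validate(raw_args).model_dump()

print(parse_args('{"query": "weather in Paris"}'))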
View File

@@ -384,6 +384,8 @@ class ToolUsage:
if (
hasattr(available_tool, "max_usage_count")
and available_tool.max_usage_count is not None
and self.agent
and self.agent.verbose
):
self._printer.print(
content=f"Tool '{sanitize_tool_name(available_tool.name)}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
@@ -396,6 +398,8 @@ class ToolUsage:
if (
hasattr(available_tool, "max_usage_count")
and available_tool.max_usage_count is not None
and self.agent
and self.agent.verbose
):
self._printer.print(
content=f"Tool '{sanitize_tool_name(available_tool.name)}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
@@ -610,6 +614,8 @@ class ToolUsage:
if (
hasattr(available_tool, "max_usage_count")
and available_tool.max_usage_count is not None
and self.agent
and self.agent.verbose
):
self._printer.print(
content=f"Tool '{sanitize_tool_name(available_tool.name)}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
@@ -622,6 +628,8 @@ class ToolUsage:
if (
hasattr(available_tool, "max_usage_count")
and available_tool.max_usage_count is not None
and self.agent
and self.agent.verbose
):
self._printer.print(
content=f"Tool '{sanitize_tool_name(available_tool.name)}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
@@ -884,15 +892,17 @@ class ToolUsage:
# Attempt 4: Repair JSON
try:
repaired_input = str(repair_json(tool_input, skip_json_loads=True))
self._printer.print(
content=f"Repaired JSON: {repaired_input}", color="blue"
)
if self.agent and self.agent.verbose:
self._printer.print(
content=f"Repaired JSON: {repaired_input}", color="blue"
)
arguments = json.loads(repaired_input)
if isinstance(arguments, dict):
return arguments
except Exception as e:
error = f"Failed to repair JSON: {e}"
self._printer.print(content=error, color="red")
if self.agent and self.agent.verbose:
self._printer.print(content=error, color="red")
error_message = (
"Tool input must be a valid dictionary in JSON or Python literal format"

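Reviewer note: the repair path itself is unchanged apart from the verbose gating; for reference, the `repair_json` call it leans on comes from the json_repair package:

import json
from json_repair import repair_json

broken = "{'query': 'weather', 'city': 'Paris',}"  # single quotes, trailing comma
repaired = str(repair_json(broken, skip_json_loads=True))
arguments = json.loads(repaired)
assert isinstance(arguments, dict)
print(arguments)  # {'query': 'weather', 'city': 'Paris'}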
View File

@@ -10,9 +10,10 @@
"memory": "\n\n# Useful context: \n{memory}",
"role_playing": "You are {role}. {backstory}\nYour personal goal is: {goal}",
"tools": "\nYou ONLY have access to the following tools, and should NEVER make up tools that are not listed here:\n\n{tools}\n\nIMPORTANT: Use the following format in your response:\n\n```\nThought: you should always think about what to do\nAction: the action to take, only one name of [{tool_names}], just the name, exactly as it's written.\nAction Input: the input to the action, just a simple JSON object, enclosed in curly braces, using \" to wrap keys and values.\nObservation: the result of the action\n```\n\nOnce all necessary information is gathered, return the following format:\n\n```\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\n```",
"no_tools": "\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!",
"native_tools": "\nUse available tools to gather information and complete your task.",
"native_task": "\nCurrent Task: {input}\n\nThis is VERY important to you, your job depends on it!",
"no_tools": "",
"task_no_tools": "\nCurrent Task: {input}\n\nProvide your complete response:",
"native_tools": "",
"native_task": "\nCurrent Task: {input}",
"post_tool_reasoning": "Analyze the tool result. If requirements are met, provide the Final Answer. Otherwise, call the next tool. Deliver only the answer without meta-commentary.",
"format": "Decide if you need a tool or can provide the final answer. Use one at a time.\nTo use a tool, use:\nThought: [reasoning]\nAction: [name from {tool_names}]\nAction Input: [JSON object]\n\nTo provide the final answer, use:\nThought: [reasoning]\nFinal Answer: [complete response]",
"final_answer_format": "If you don't need to use any more tools, you must give your best complete final answer, make sure it satisfies the expected criteria, use the EXACT format below:\n\n```\nThought: I now can give a great answer\nFinal Answer: my best complete final answer to the task.\n\n```",

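Reviewer note: the reworked slices drop the ReAct scaffolding for tool-less agents. A quick render of the new `task_no_tools` slice shows the prompt an agent without tools now receives (slice text copied from the JSON above):

slices = {
    "task_no_tools": "\nCurrent Task: {input}\n\nProvide your complete response:",
    "native_task": "\nCurrent Task: {input}",
}

print(slices["task_no_tools"].format(input="What language is 'Hello'?"))
# -> Current Task: What language is 'Hello'?
#    Provide your complete response: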
View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Callable, Sequence
import json
import re
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
@@ -98,7 +98,7 @@ def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
return tools_list
def get_tool_names(tools: list[CrewStructuredTool | BaseTool]) -> str:
def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str:
"""Get the sanitized names of the tools.
Args:
@@ -111,7 +111,7 @@ def get_tool_names(tools: list[CrewStructuredTool | BaseTool]) -> str:
def render_text_description_and_args(
tools: list[CrewStructuredTool | BaseTool],
tools: Sequence[CrewStructuredTool | BaseTool],
) -> str:
"""Render the tool name, description, and args in plain text.
@@ -130,7 +130,7 @@ def render_text_description_and_args(
def convert_tools_to_openai_schema(
tools: list[BaseTool | CrewStructuredTool],
tools: Sequence[BaseTool | CrewStructuredTool],
) -> tuple[list[dict[str, Any]], dict[str, Callable[..., Any]]]:
"""Convert CrewAI tools to OpenAI function calling format.
@@ -210,6 +210,7 @@ def handle_max_iterations_exceeded(
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[TokenCalcHandler],
verbose: bool = True,
) -> AgentFinish:
"""Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer.
@@ -220,14 +221,16 @@ def handle_max_iterations_exceeded(
messages: List of messages to send to the LLM.
llm: The LLM instance to call.
callbacks: List of callbacks for the LLM call.
verbose: Whether to print output.
Returns:
AgentFinish with the final answer after exceeding max iterations.
"""
printer.print(
content="Maximum iterations reached. Requesting final answer.",
color="yellow",
)
if verbose:
printer.print(
content="Maximum iterations reached. Requesting final answer.",
color="yellow",
)
if formatted_answer and hasattr(formatted_answer, "text"):
assistant_message = (
@@ -245,10 +248,11 @@ def handle_max_iterations_exceeded(
)
if answer is None or answer == "":
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
if verbose:
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
raise ValueError("Invalid response from LLM call - None or empty.")
formatted = format_answer(answer=answer)
@@ -322,7 +326,8 @@ def get_llm_response(
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
) -> str | Any:
verbose: bool = True,
) -> str | BaseModel | Any:
"""Call the LLM and return the response, handling any invalid responses.
Args:
@@ -336,10 +341,11 @@ def get_llm_response(
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
verbose: Whether to print output.
Returns:
The response from the LLM as a string, or tool call results if
native function calling is used.
The response from the LLM as a string, Pydantic model (when response_model is provided),
or tool call results if native function calling is used.
Raises:
Exception: If an error occurs.
@@ -347,7 +353,7 @@ def get_llm_response(
"""
if executor_context is not None:
if not _setup_before_llm_call_hooks(executor_context, printer):
if not _setup_before_llm_call_hooks(executor_context, printer, verbose=verbose):
raise ValueError("LLM call blocked by before_llm_call hook")
messages = executor_context.messages
@@ -364,13 +370,16 @@ def get_llm_response(
except Exception as e:
raise e
if not answer:
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
if verbose:
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
raise ValueError("Invalid response from LLM call - None or empty.")
return _setup_after_llm_call_hooks(executor_context, answer, printer)
return _setup_after_llm_call_hooks(
executor_context, answer, printer, verbose=verbose
)
async def aget_llm_response(
@@ -384,7 +393,8 @@ async def aget_llm_response(
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | AgentExecutor | None = None,
) -> str | Any:
verbose: bool = True,
) -> str | BaseModel | Any:
"""Call the LLM asynchronously and return the response.
Args:
@@ -400,15 +410,15 @@ async def aget_llm_response(
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string, or tool call results if
native function calling is used.
The response from the LLM as a string, Pydantic model (when response_model is provided),
or tool call results if native function calling is used.
Raises:
Exception: If an error occurs.
ValueError: If the response is None or empty.
"""
if executor_context is not None:
if not _setup_before_llm_call_hooks(executor_context, printer):
if not _setup_before_llm_call_hooks(executor_context, printer, verbose=verbose):
raise ValueError("LLM call blocked by before_llm_call hook")
messages = executor_context.messages
@@ -425,13 +435,16 @@ async def aget_llm_response(
except Exception as e:
raise e
if not answer:
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
if verbose:
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
raise ValueError("Invalid response from LLM call - None or empty.")
return _setup_after_llm_call_hooks(executor_context, answer, printer)
return _setup_after_llm_call_hooks(
executor_context, answer, printer, verbose=verbose
)
def process_llm_response(
@@ -498,13 +511,19 @@ def handle_agent_action_core(
return formatted_answer
def handle_unknown_error(printer: Printer, exception: Exception) -> None:
def handle_unknown_error(
printer: Printer, exception: Exception, verbose: bool = True
) -> None:
"""Handle unknown errors by informing the user.
Args:
printer: Printer instance for output
exception: The exception that occurred
verbose: Whether to print output.
"""
if not verbose:
return
error_message = str(exception)
if "litellm" in error_message:
@@ -526,6 +545,7 @@ def handle_output_parser_exception(
iterations: int,
log_error_after: int = 3,
printer: Printer | None = None,
verbose: bool = True,
) -> AgentAction:
"""Handle OutputParserError by updating messages and formatted_answer.
@@ -548,7 +568,7 @@ def handle_output_parser_exception(
thought="",
)
if iterations > log_error_after and printer:
if verbose and iterations > log_error_after and printer:
printer.print(
content=f"Error parsing LLM output, agent will retry: {e.error}",
color="red",
@@ -578,6 +598,7 @@ def handle_context_length(
llm: LLM | BaseLLM,
callbacks: list[TokenCalcHandler],
i18n: I18N,
verbose: bool = True,
) -> None:
"""Handle context length exceeded by either summarizing or raising an error.
@@ -593,16 +614,20 @@ def handle_context_length(
SystemExit: If context length is exceeded and user opts not to summarize
"""
if respect_context_window:
printer.print(
content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...",
color="yellow",
if verbose:
printer.print(
content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...",
color="yellow",
)
summarize_messages(
messages=messages, llm=llm, callbacks=callbacks, i18n=i18n, verbose=verbose
)
summarize_messages(messages=messages, llm=llm, callbacks=callbacks, i18n=i18n)
else:
printer.print(
content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
color="red",
)
if verbose:
printer.print(
content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
color="red",
)
raise SystemExit(
"Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
)
@@ -613,6 +638,7 @@ def summarize_messages(
llm: LLM | BaseLLM,
callbacks: list[TokenCalcHandler],
i18n: I18N,
verbose: bool = True,
) -> None:
"""Summarize messages to fit within context window.
@@ -644,10 +670,11 @@ def summarize_messages(
total_groups = len(messages_groups)
for idx, group in enumerate(messages_groups, 1):
Printer().print(
content=f"Summarizing {idx}/{total_groups}...",
color="yellow",
)
if verbose:
Printer().print(
content=f"Summarizing {idx}/{total_groups}...",
color="yellow",
)
summarization_messages = [
format_message_for_llm(
@@ -905,12 +932,14 @@ def extract_tool_call_info(
def _setup_before_llm_call_hooks(
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
printer: Printer,
verbose: bool = True,
) -> bool:
"""Setup and invoke before_llm_call hooks for the executor context.
Args:
executor_context: The executor context to setup the hooks for.
printer: Printer instance for error logging.
verbose: Whether to print output.
Returns:
True if LLM execution should proceed, False if blocked by a hook.
@@ -925,26 +954,29 @@ def _setup_before_llm_call_hooks(
for hook in executor_context.before_llm_call_hooks:
result = hook(hook_context)
if result is False:
printer.print(
content="LLM call blocked by before_llm_call hook",
color="yellow",
)
if verbose:
printer.print(
content="LLM call blocked by before_llm_call hook",
color="yellow",
)
return False
except Exception as e:
printer.print(
content=f"Error in before_llm_call hook: {e}",
color="yellow",
)
if verbose:
printer.print(
content=f"Error in before_llm_call hook: {e}",
color="yellow",
)
if not isinstance(executor_context.messages, list):
printer.print(
content=(
"Warning: before_llm_call hook replaced messages with non-list. "
"Restoring original messages list. Hooks should modify messages in-place, "
"not replace the list (e.g., use context.messages.append() not context.messages = [])."
),
color="yellow",
)
if verbose:
printer.print(
content=(
"Warning: before_llm_call hook replaced messages with non-list. "
"Restoring original messages list. Hooks should modify messages in-place, "
"not replace the list (e.g., use context.messages.append() not context.messages = [])."
),
color="yellow",
)
if isinstance(original_messages, list):
executor_context.messages = original_messages
else:
@@ -955,49 +987,79 @@ def _setup_before_llm_call_hooks(
def _setup_after_llm_call_hooks(
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
answer: str,
answer: str | BaseModel,
printer: Printer,
) -> str:
verbose: bool = True,
) -> str | BaseModel:
"""Setup and invoke after_llm_call hooks for the executor context.
Args:
executor_context: The executor context to setup the hooks for.
answer: The LLM response string.
answer: The LLM response (string or Pydantic model).
printer: Printer instance for error logging.
verbose: Whether to print output.
Returns:
The potentially modified response string.
The potentially modified response (string or Pydantic model).
"""
if executor_context and executor_context.after_llm_call_hooks:
from crewai.hooks.llm_hooks import LLMCallHookContext
original_messages = executor_context.messages
hook_context = LLMCallHookContext(executor_context, response=answer)
# For Pydantic models, serialize to JSON for hooks
if isinstance(answer, BaseModel):
pydantic_answer = answer
hook_response: str = pydantic_answer.model_dump_json()
original_json: str = hook_response
else:
pydantic_answer = None
hook_response = str(answer)
hook_context = LLMCallHookContext(executor_context, response=hook_response)
try:
for hook in executor_context.after_llm_call_hooks:
modified_response = hook(hook_context)
if modified_response is not None and isinstance(modified_response, str):
answer = modified_response
hook_response = modified_response
except Exception as e:
printer.print(
content=f"Error in after_llm_call hook: {e}",
color="yellow",
)
if verbose:
printer.print(
content=f"Error in after_llm_call hook: {e}",
color="yellow",
)
if not isinstance(executor_context.messages, list):
printer.print(
content=(
"Warning: after_llm_call hook replaced messages with non-list. "
"Restoring original messages list. Hooks should modify messages in-place, "
"not replace the list (e.g., use context.messages.append() not context.messages = [])."
),
color="yellow",
)
if verbose:
printer.print(
content=(
"Warning: after_llm_call hook replaced messages with non-list. "
"Restoring original messages list. Hooks should modify messages in-place, "
"not replace the list (e.g., use context.messages.append() not context.messages = [])."
),
color="yellow",
)
if isinstance(original_messages, list):
executor_context.messages = original_messages
else:
executor_context.messages = []
# If hooks modified the response, update answer accordingly
if pydantic_answer is not None:
# For Pydantic models, reparse the JSON if it was modified
if hook_response != original_json:
try:
model_class: type[BaseModel] = type(pydantic_answer)
answer = model_class.model_validate_json(hook_response)
except Exception as e:
if verbose:
printer.print(
content=f"Warning: Hook modified response but failed to reparse as {type(pydantic_answer).__name__}: {e}. Using original model.",
color="yellow",
)
else:
# For string responses, use the hook-modified response
answer = hook_response
return answer

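Reviewer note: since after-call hooks only understand strings, a `BaseModel` answer is serialized for them and reparsed afterwards only if a hook actually changed the JSON. A compressed sketch of that round-trip (real hooks receive an LLMCallHookContext; plain strings keep this short):

from pydantic import BaseModel

class Answer(BaseModel):
    text: str

def run_after_hooks(answer: str | BaseModel, hooks) -> str | BaseModel:
    if isinstance(answer, BaseModel):
        hook_response = original_json = answer.model_dump_json()
    else:
        hook_response, original_json = str(answer), None
    for hook in hooks:
        modified = hook(hook_response)
        if isinstance(modified, str):
            hook_response = modified
    if isinstance(answer, BaseModel):
        if hook_response != original_json:
            return type(answer).model_validate_json(hook_response)  # reparse on change
        return answer
    return hook_response

rewrite = lambda resp: resp.replace("hi", "hello")
print(run_after_hooks(Answer(text="hi"), [rewrite]))  # text='hello'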
View File

@@ -62,7 +62,10 @@ class Converter(OutputConverter):
],
response_model=self.model,
)
result = self.model.model_validate_json(response)
if isinstance(response, BaseModel):
result = response
else:
result = self.model.model_validate_json(response)
else:
response = self.llm.call(
[
@@ -205,10 +208,11 @@ def convert_to_model(
)
except Exception as e:
Printer().print(
content=f"Unexpected error during model conversion: {type(e).__name__}: {e}. Returning original result.",
color="red",
)
if agent and getattr(agent, "verbose", True):
Printer().print(
content=f"Unexpected error during model conversion: {type(e).__name__}: {e}. Returning original result.",
color="red",
)
return result
@@ -262,10 +266,11 @@ def handle_partial_json(
except ValidationError:
raise
except Exception as e:
Printer().print(
content=f"Unexpected error during partial JSON handling: {type(e).__name__}: {e}. Attempting alternative conversion method.",
color="red",
)
if agent and getattr(agent, "verbose", True):
Printer().print(
content=f"Unexpected error during partial JSON handling: {type(e).__name__}: {e}. Attempting alternative conversion method.",
color="red",
)
return convert_with_instructions(
result=result,
@@ -323,10 +328,11 @@ def convert_with_instructions(
)
if isinstance(exported_result, ConverterError):
Printer().print(
content=f"Failed to convert result to model: {exported_result}",
color="red",
)
if agent and getattr(agent, "verbose", True):
Printer().print(
content=f"Failed to convert result to model: {exported_result}",
color="red",
)
return result
return exported_result

View File

@@ -23,7 +23,13 @@ class SystemPromptResult(StandardPromptResult):
COMPONENTS = Literal[
"role_playing", "tools", "no_tools", "native_tools", "task", "native_task"
"role_playing",
"tools",
"no_tools",
"native_tools",
"task",
"native_task",
"task_no_tools",
]
@@ -74,11 +80,14 @@ class Prompts(BaseModel):
slices.append("no_tools")
system: str = self._build_prompt(slices)
# Use native_task for native tool calling (no "Thought:" prompt)
# Use task for ReAct pattern (includes "Thought:" prompt)
task_slice: COMPONENTS = (
"native_task" if self.use_native_tool_calling else "task"
)
# Determine which task slice to use:
task_slice: COMPONENTS
if self.use_native_tool_calling:
task_slice = "native_task"
elif self.has_tools:
task_slice = "task"
else:
task_slice = "task_no_tools"
slices.append(task_slice)
if (

View File

@@ -104,6 +104,7 @@ class TestA2AStreamingIntegration:
message=test_message,
new_messages=new_messages,
agent_card=agent_card,
endpoint=agent_card.url,
)
assert isinstance(result, dict)
@@ -225,6 +226,7 @@ class TestA2APushNotificationHandler:
result_store=mock_store,
polling_timeout=30.0,
polling_interval=1.0,
endpoint=mock_agent_card.url,
)
mock_store.wait_for_result.assert_called_once_with(
@@ -287,6 +289,7 @@ class TestA2APushNotificationHandler:
result_store=mock_store,
polling_timeout=5.0,
polling_interval=0.5,
endpoint=mock_agent_card.url,
)
assert result["status"] == TaskState.failed
@@ -317,6 +320,7 @@ class TestA2APushNotificationHandler:
message=test_msg,
new_messages=new_messages,
agent_card=mock_agent_card,
endpoint=mock_agent_card.url,
)
assert result["status"] == TaskState.failed

View File

@@ -43,6 +43,7 @@ def mock_context() -> MagicMock:
context.context_id = "test-context-456"
context.get_user_input.return_value = "Test user message"
context.message = MagicMock(spec=Message)
context.message.parts = []
context.current_task = None
return context

View File

@@ -1004,3 +1004,53 @@ def test_prepare_kickoff_param_files_override_message_files():
assert "files" in inputs
assert inputs["files"]["same.png"] is param_file # param takes precedence
def test_lite_agent_verbose_false_suppresses_printer_output():
"""Test that setting verbose=False suppresses all printer output."""
from crewai.agents.parser import AgentFinish
from crewai.types.usage_metrics import UsageMetrics
mock_llm = Mock(spec=LLM)
mock_llm.call.return_value = "Final Answer: Hello!"
mock_llm.stop = []
mock_llm.supports_stop_words.return_value = False
mock_llm.get_token_usage_summary.return_value = UsageMetrics(
total_tokens=100,
prompt_tokens=50,
completion_tokens=50,
cached_prompt_tokens=0,
successful_requests=1,
)
with pytest.warns(DeprecationWarning):
agent = LiteAgent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm=mock_llm,
verbose=False,
)
result = agent.kickoff("Say hello")
assert result is not None
assert isinstance(result, LiteAgentOutput)
# Patching the printer only after kickoff proves nothing; for a clean
# verification, patch it on a fresh agent before execution:
with pytest.warns(DeprecationWarning):
agent2 = LiteAgent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
llm=mock_llm,
verbose=False,
)
mock_printer = Mock()
agent2._printer = mock_printer
agent2.kickoff("Say hello")
mock_printer.print.assert_not_called()

View File

@@ -0,0 +1,112 @@
interactions:
- request:
body: '{"messages":[{"role":"system","content":"You are Language Detector. You
are an expert linguist who can identify languages.\nYour personal goal is: Detect
the language of text"},{"role":"user","content":"\nCurrent Task: What language
is this text written in: ''Hello, how are you?''\n\nThis is the expected criteria
for your final answer: The detected language (e.g., English, Spanish, etc.)\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nProvide
your complete response:"}],"model":"gpt-4o-mini"}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- application/json
accept-encoding:
- ACCEPT-ENCODING-XXX
authorization:
- AUTHORIZATION-XXX
connection:
- keep-alive
content-length:
- '530'
content-type:
- application/json
host:
- api.openai.com
x-stainless-arch:
- X-STAINLESS-ARCH-XXX
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- X-STAINLESS-OS-XXX
x-stainless-package-version:
- 1.83.0
x-stainless-read-timeout:
- X-STAINLESS-READ-TIMEOUT-XXX
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.13.3
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: "{\n \"id\": \"chatcmpl-D39bkotgEapBcz1sSIXvhPhK9G7FD\",\n \"object\":
\"chat.completion\",\n \"created\": 1769644288,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": \"English\",\n \"refusal\": null,\n
\ \"annotations\": []\n },\n \"logprobs\": null,\n \"finish_reason\":
\"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 101,\n \"completion_tokens\":
1,\n \"total_tokens\": 102,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
\"default\",\n \"system_fingerprint\": \"fp_3683ee3deb\"\n}\n"
headers:
CF-RAY:
- CF-RAY-XXX
Connection:
- keep-alive
Content-Type:
- application/json
Date:
- Wed, 28 Jan 2026 23:51:28 GMT
Server:
- cloudflare
Set-Cookie:
- SET-COOKIE-XXX
Strict-Transport-Security:
- STS-XXX
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- X-CONTENT-TYPE-XXX
access-control-expose-headers:
- ACCESS-CONTROL-XXX
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- OPENAI-ORG-XXX
openai-processing-ms:
- '279'
openai-project:
- OPENAI-PROJECT-XXX
openai-version:
- '2020-10-01'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-requests:
- X-RATELIMIT-LIMIT-REQUESTS-XXX
x-ratelimit-limit-tokens:
- X-RATELIMIT-LIMIT-TOKENS-XXX
x-ratelimit-remaining-requests:
- X-RATELIMIT-REMAINING-REQUESTS-XXX
x-ratelimit-remaining-tokens:
- X-RATELIMIT-REMAINING-TOKENS-XXX
x-ratelimit-reset-requests:
- X-RATELIMIT-RESET-REQUESTS-XXX
x-ratelimit-reset-tokens:
- X-RATELIMIT-RESET-TOKENS-XXX
x-request-id:
- X-REQUEST-ID-XXX
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,111 @@
interactions:
- request:
body: '{"messages":[{"role":"system","content":"You are Classifier. You classify
text sentiment accurately.\nYour personal goal is: Classify text sentiment"},{"role":"user","content":"\nCurrent
Task: Classify the sentiment of: ''I love this product!''\n\nThis is the expected
criteria for your final answer: One word: positive, negative, or neutral\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nProvide
your complete response:"}],"model":"gpt-4o-mini"}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- application/json
accept-encoding:
- ACCEPT-ENCODING-XXX
authorization:
- AUTHORIZATION-XXX
connection:
- keep-alive
content-length:
- '481'
content-type:
- application/json
host:
- api.openai.com
x-stainless-arch:
- X-STAINLESS-ARCH-XXX
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- X-STAINLESS-OS-XXX
x-stainless-package-version:
- 1.83.0
x-stainless-read-timeout:
- X-STAINLESS-READ-TIMEOUT-XXX
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.13.3
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: "{\n \"id\": \"chatcmpl-D39bkVPelOZanWIMBoIyzsuj072sM\",\n \"object\":
\"chat.completion\",\n \"created\": 1769644288,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": \"positive\",\n \"refusal\": null,\n
\ \"annotations\": []\n },\n \"logprobs\": null,\n \"finish_reason\":
\"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 89,\n \"completion_tokens\":
1,\n \"total_tokens\": 90,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
\"default\",\n \"system_fingerprint\": \"fp_3683ee3deb\"\n}\n"
headers:
CF-RAY:
- CF-RAY-XXX
Connection:
- keep-alive
Content-Type:
- application/json
Date:
- Wed, 28 Jan 2026 23:51:29 GMT
Server:
- cloudflare
Set-Cookie:
- SET-COOKIE-XXX
Strict-Transport-Security:
- STS-XXX
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- X-CONTENT-TYPE-XXX
access-control-expose-headers:
- ACCESS-CONTROL-XXX
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- OPENAI-ORG-XXX
openai-processing-ms:
- '323'
openai-project:
- OPENAI-PROJECT-XXX
openai-version:
- '2020-10-01'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-requests:
- X-RATELIMIT-LIMIT-REQUESTS-XXX
x-ratelimit-limit-tokens:
- X-RATELIMIT-LIMIT-TOKENS-XXX
x-ratelimit-remaining-requests:
- X-RATELIMIT-REMAINING-REQUESTS-XXX
x-ratelimit-remaining-tokens:
- X-RATELIMIT-REMAINING-TOKENS-XXX
x-ratelimit-reset-requests:
- X-RATELIMIT-RESET-REQUESTS-XXX
x-ratelimit-reset-tokens:
- X-RATELIMIT-RESET-TOKENS-XXX
x-request-id:
- X-REQUEST-ID-XXX
status:
code: 200
message: OK
version: 1

View File

@@ -157,10 +157,10 @@ async def test_anthropic_async_with_response_model():
"Say hello in French",
response_model=GreetingResponse
)
model = GreetingResponse.model_validate_json(result)
assert isinstance(model, GreetingResponse)
assert isinstance(model.greeting, str)
assert isinstance(model.language, str)
# When response_model is provided, the result is already a parsed Pydantic model instance
assert isinstance(result, GreetingResponse)
assert isinstance(result.greeting, str)
assert isinstance(result.language, str)
@pytest.mark.vcr()

View File

@@ -799,3 +799,131 @@ def test_google_express_mode_works() -> None:
     assert result.token_usage.prompt_tokens > 0
     assert result.token_usage.completion_tokens > 0
     assert result.token_usage.successful_requests >= 1
+
+
+def test_gemini_2_0_model_detection():
+    """Test that Gemini 2.0 models are properly detected."""
+    # Test Gemini 2.0 models
+    llm_2_0 = LLM(model="google/gemini-2.0-flash-001")
+    from crewai.llms.providers.gemini.completion import GeminiCompletion
+
+    assert isinstance(llm_2_0, GeminiCompletion)
+    assert llm_2_0.is_gemini_2_0 is True
+
+    llm_2_5 = LLM(model="google/gemini-2.5-flash")
+    assert isinstance(llm_2_5, GeminiCompletion)
+    assert llm_2_5.is_gemini_2_0 is True
+
+    # Test non-2.0 models
+    llm_1_5 = LLM(model="google/gemini-1.5-pro")
+    assert isinstance(llm_1_5, GeminiCompletion)
+    assert llm_1_5.is_gemini_2_0 is False
+
+
+def test_add_property_ordering_to_schema():
+    """Test that _add_property_ordering correctly adds propertyOrdering to schemas."""
+    from crewai.llms.providers.gemini.completion import GeminiCompletion
+
+    # Test simple object schema
+    simple_schema = {
+        "type": "object",
+        "properties": {
+            "name": {"type": "string"},
+            "age": {"type": "integer"},
+            "email": {"type": "string"}
+        }
+    }
+    result = GeminiCompletion._add_property_ordering(simple_schema)
+    assert "propertyOrdering" in result
+    assert result["propertyOrdering"] == ["name", "age", "email"]
+
+    # Test nested object schema
+    nested_schema = {
+        "type": "object",
+        "properties": {
+            "user": {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "contact": {
+                        "type": "object",
+                        "properties": {
+                            "email": {"type": "string"},
+                            "phone": {"type": "string"}
+                        }
+                    }
+                }
+            },
+            "id": {"type": "integer"}
+        }
+    }
+    result = GeminiCompletion._add_property_ordering(nested_schema)
+    assert "propertyOrdering" in result
+    assert result["propertyOrdering"] == ["user", "id"]
+    assert "propertyOrdering" in result["properties"]["user"]
+    assert result["properties"]["user"]["propertyOrdering"] == ["name", "contact"]
+    assert "propertyOrdering" in result["properties"]["user"]["properties"]["contact"]
+    assert result["properties"]["user"]["properties"]["contact"]["propertyOrdering"] == ["email", "phone"]
+
+
+def test_gemini_2_0_response_model_with_property_ordering():
+    """Test that Gemini 2.0 models include propertyOrdering in response schemas."""
+    from pydantic import BaseModel, Field
+
+    class TestResponse(BaseModel):
+        """Test response model."""
+        name: str = Field(..., description="The name")
+        age: int = Field(..., description="The age")
+        email: str = Field(..., description="The email")
+
+    llm = LLM(model="google/gemini-2.0-flash-001")
+
+    # Prepare generation config with response model
+    config = llm._prepare_generation_config(response_model=TestResponse)
+
+    # Verify that the config has response_json_schema
+    assert hasattr(config, 'response_json_schema') or 'response_json_schema' in config.__dict__
+
+    # Get the schema
+    if hasattr(config, 'response_json_schema'):
+        schema = config.response_json_schema
+    else:
+        schema = config.__dict__.get('response_json_schema', {})
+
+    # Verify propertyOrdering is present for Gemini 2.0
+    assert "propertyOrdering" in schema
+    assert "name" in schema["propertyOrdering"]
+    assert "age" in schema["propertyOrdering"]
+    assert "email" in schema["propertyOrdering"]
+
+
+def test_gemini_1_5_response_model_uses_response_schema():
+    """Test that Gemini 1.5 models use response_schema parameter (not response_json_schema)."""
+    from pydantic import BaseModel, Field
+
+    class TestResponse(BaseModel):
+        """Test response model."""
+        name: str = Field(..., description="The name")
+        age: int = Field(..., description="The age")
+
+    llm = LLM(model="google/gemini-1.5-pro")
+
+    # Prepare generation config with response model
+    config = llm._prepare_generation_config(response_model=TestResponse)
+
+    # Verify that the config uses response_schema (not response_json_schema)
+    assert hasattr(config, 'response_schema') or 'response_schema' in config.__dict__
+    assert not (hasattr(config, 'response_json_schema') and config.response_json_schema is not None)
+
+    # Get the schema
+    if hasattr(config, 'response_schema'):
+        schema = config.response_schema
+    else:
+        schema = config.__dict__.get('response_schema')
+
+    # For Gemini 1.5, response_schema should be the Pydantic model itself
+    # The SDK handles conversion internally
+    assert schema is TestResponse or isinstance(schema, type)
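These assertions constrain the behavior of `_add_property_ordering` and the schema-field split even though the implementation itself is not in this hunk: Gemini 2.0+ models receive a raw JSON schema via `response_json_schema` with `propertyOrdering` injected recursively, while Gemini 1.5 models keep passing the Pydantic class via `response_schema` and let the SDK convert it. A sketch consistent with the tests; names and structure are inferred, not the actual source in `crewai.llms.providers.gemini.completion`:

from typing import Any

from pydantic import BaseModel


def _add_property_ordering(schema: dict[str, Any]) -> dict[str, Any]:
    """Recursively record key order on every object node (inferred sketch)."""
    properties = schema.get("properties")
    if isinstance(properties, dict):
        # dict insertion order mirrors the field declaration order
        schema["propertyOrdering"] = list(properties.keys())
        for sub_schema in properties.values():
            if isinstance(sub_schema, dict):
                _add_property_ordering(sub_schema)
    return schema


def schema_config_kwargs(model: type[BaseModel], is_gemini_2_0: bool) -> dict[str, Any]:
    """Pick the generation-config field each model family expects (inferred sketch)."""
    if is_gemini_2_0:
        return {"response_json_schema": _add_property_ordering(model.model_json_schema())}
    # Gemini 1.5: hand over the Pydantic class; the SDK converts it internally.
    return {"response_schema": model}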


@@ -540,7 +540,9 @@ def test_openai_streaming_with_response_model():
     result = llm.call("Test question", response_model=TestResponse)
 
     assert result is not None
-    assert isinstance(result, str)
+    assert isinstance(result, TestResponse)
+    assert result.answer == "test"
+    assert result.confidence == 0.95
 
     assert mock_stream.called
     call_kwargs = mock_stream.call_args[1]
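Per the updated assertions, the streaming path honors the same contract: a call carrying `response_model` resolves to the parsed object, not the concatenated chunk text. A hypothetical usage sketch; the `stream` constructor flag is assumed rather than shown in this hunk:

from pydantic import BaseModel

from crewai.llm import LLM


class TestResponse(BaseModel):
    answer: str
    confidence: float


llm = LLM(model="gpt-4o-mini", stream=True)  # assumed constructor flag
result = llm.call("Test question", response_model=TestResponse)

# Streamed chunks are accumulated and validated internally.
assert isinstance(result, TestResponse)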


@@ -2585,6 +2585,7 @@ def test_warning_long_term_memory_without_entity_memory():
goal="You research about math.",
backstory="You're an expert in research and you love to learn new things.",
allow_delegation=False,
verbose=True,
)
task1 = Task(


@@ -0,0 +1,234 @@
"""Tests for prompt generation to prevent thought leakage.
These tests verify that:
1. Agents without tools don't get ReAct format instructions
2. The generated prompts don't encourage "Thought:" prefixes that leak into output
3. Real LLM calls produce clean output without internal reasoning
"""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from crewai import Agent, Crew, Task
from crewai.llm import LLM
from crewai.utilities.prompts import Prompts
class TestNoToolsPromptGeneration:
"""Tests for prompt generation when agent has no tools."""
def test_no_tools_uses_task_no_tools_slice(self) -> None:
"""Test that agents without tools use task_no_tools slice instead of task."""
mock_agent = MagicMock()
mock_agent.role = "Test Agent"
mock_agent.goal = "Test goal"
mock_agent.backstory = "Test backstory"
prompts = Prompts(
has_tools=False,
use_native_tool_calling=False,
use_system_prompt=True,
agent=mock_agent,
)
result = prompts.task_execution()
# Verify it's a SystemPromptResult with system and user keys
assert "system" in result
assert "user" in result
assert "prompt" in result
# The user prompt should NOT contain "Thought:" (ReAct format)
assert "Thought:" not in result["user"]
# The user prompt should NOT mention tools
assert "use the tools available" not in result["user"]
assert "tools available" not in result["user"].lower()
# The system prompt should NOT contain ReAct format instructions
assert "Thought:" not in result["system"]
assert "Final Answer:" not in result["system"]
def test_no_tools_prompt_is_simple(self) -> None:
"""Test that no-tools prompt is simple and direct."""
mock_agent = MagicMock()
mock_agent.role = "Language Detector"
mock_agent.goal = "Detect language"
mock_agent.backstory = "Expert linguist"
prompts = Prompts(
has_tools=False,
use_native_tool_calling=False,
use_system_prompt=True,
agent=mock_agent,
)
result = prompts.task_execution()
# Should contain the role playing info
assert "Language Detector" in result["system"]
# User prompt should be simple with just the task
assert "Current Task:" in result["user"]
assert "Provide your complete response:" in result["user"]
def test_with_tools_uses_task_slice_with_react(self) -> None:
"""Test that agents WITH tools use the task slice (ReAct format)."""
mock_agent = MagicMock()
mock_agent.role = "Test Agent"
mock_agent.goal = "Test goal"
mock_agent.backstory = "Test backstory"
prompts = Prompts(
has_tools=True,
use_native_tool_calling=False,
use_system_prompt=True,
agent=mock_agent,
)
result = prompts.task_execution()
# With tools and ReAct, the prompt SHOULD contain Thought:
assert "Thought:" in result["user"]
def test_native_tools_uses_native_task_slice(self) -> None:
"""Test that native tool calling uses native_task slice."""
mock_agent = MagicMock()
mock_agent.role = "Test Agent"
mock_agent.goal = "Test goal"
mock_agent.backstory = "Test backstory"
prompts = Prompts(
has_tools=True,
use_native_tool_calling=True,
use_system_prompt=True,
agent=mock_agent,
)
result = prompts.task_execution()
# Native tool calling should NOT have Thought: in user prompt
assert "Thought:" not in result["user"]
# Should NOT have emotional manipulation
assert "your job depends on it" not in result["user"]
class TestNoThoughtLeakagePatterns:
"""Tests to verify prompts don't encourage thought leakage."""
def test_no_job_depends_on_it_in_no_tools(self) -> None:
"""Test that 'your job depends on it' is not in no-tools prompts."""
mock_agent = MagicMock()
mock_agent.role = "Test"
mock_agent.goal = "Test"
mock_agent.backstory = "Test"
prompts = Prompts(
has_tools=False,
use_native_tool_calling=False,
use_system_prompt=True,
agent=mock_agent,
)
result = prompts.task_execution()
full_prompt = result["prompt"]
assert "your job depends on it" not in full_prompt.lower()
assert "i must use these formats" not in full_prompt.lower()
def test_no_job_depends_on_it_in_native_task(self) -> None:
"""Test that 'your job depends on it' is not in native task prompts."""
mock_agent = MagicMock()
mock_agent.role = "Test"
mock_agent.goal = "Test"
mock_agent.backstory = "Test"
prompts = Prompts(
has_tools=True,
use_native_tool_calling=True,
use_system_prompt=True,
agent=mock_agent,
)
result = prompts.task_execution()
full_prompt = result["prompt"]
assert "your job depends on it" not in full_prompt.lower()
class TestRealLLMNoThoughtLeakage:
"""Integration tests with real LLM calls to verify no thought leakage."""
@pytest.mark.vcr()
def test_agent_without_tools_no_thought_in_output(self) -> None:
"""Test that agent without tools produces clean output without 'Thought:' prefix."""
agent = Agent(
role="Language Detector",
goal="Detect the language of text",
backstory="You are an expert linguist who can identify languages.",
tools=[], # No tools
llm=LLM(model="gpt-4o-mini"),
verbose=False,
)
task = Task(
description="What language is this text written in: 'Hello, how are you?'",
expected_output="The detected language (e.g., English, Spanish, etc.)",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
assert result.raw is not None
# The output should NOT start with "Thought:" or contain ReAct artifacts
output = str(result.raw)
assert not output.strip().startswith("Thought:")
assert "Final Answer:" not in output
assert "I now can give a great answer" not in output
# Should contain an actual answer about the language
assert any(
lang in output.lower()
for lang in ["english", "en", "language"]
)
@pytest.mark.vcr()
def test_simple_task_clean_output(self) -> None:
"""Test that a simple task produces clean output without internal reasoning."""
agent = Agent(
role="Classifier",
goal="Classify text sentiment",
backstory="You classify text sentiment accurately.",
tools=[],
llm=LLM(model="gpt-4o-mini"),
verbose=False,
)
task = Task(
description="Classify the sentiment of: 'I love this product!'",
expected_output="One word: positive, negative, or neutral",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
assert result is not None
output = str(result.raw).strip().lower()
# Output should be clean - just the classification
assert not output.startswith("thought:")
assert "final answer:" not in output
# Should contain the actual classification
assert any(
sentiment in output
for sentiment in ["positive", "negative", "neutral"]
)
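Taken together, these tests imply a three-way dispatch inside `Prompts.task_execution()`. A compact statement of the rule the assertions enforce; the slice names come from the tests, while the helper below is illustrative rather than the actual implementation:

def choose_task_slice(has_tools: bool, use_native_tool_calling: bool) -> str:
    """Return the prompt slice the tests above expect (illustrative)."""
    if not has_tools:
        # Plain "Current Task: ..." prompt; no ReAct scaffold to leak as "Thought:"
        return "task_no_tools"
    if use_native_tool_calling:
        # Provider-native tool calls; no "Thought:"/"Final Answer:" markers either
        return "native_task"
    # Classic ReAct loop: "Thought:" is expected in the user prompt here
    return "task"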


@@ -1,3 +1,3 @@
"""CrewAI development tools."""
__version__ = "1.9.0"
__version__ = "1.9.2"