feat: add transport negotiation and content type handling

- add transport negotiation logic with fallback support
- add content type parser and encoder utilities
- add transport configuration models (client and server)
- add transport types and enums
- enhance config with transport settings
- add negotiation events for transport and content type
This commit is contained in:
Greyson LaLonde
2026-01-29 05:13:42 -05:00
parent e4be1329a0
commit 4543c66697
24 changed files with 3994 additions and 641 deletions

View File

@@ -1,4 +1,4 @@
"""Authentication schemes for A2A protocol agents.
"""Authentication schemes for A2A protocol clients.
Supported authentication methods:
- Bearer tokens
@@ -6,24 +6,135 @@ Supported authentication methods:
- API Keys (header, query, cookie)
- HTTP Basic authentication
- HTTP Digest authentication
- mTLS (mutual TLS) client certificate authentication
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
import base64
from collections.abc import Awaitable, Callable, MutableMapping
from pathlib import Path
import ssl
import time
from typing import Literal
from typing import TYPE_CHECKING, ClassVar, Literal
import urllib.parse
import httpx
from httpx import DigestAuth
from pydantic import BaseModel, Field, PrivateAttr
from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr
from typing_extensions import deprecated
class AuthScheme(ABC, BaseModel):
"""Base class for authentication schemes."""
if TYPE_CHECKING:
import grpc # type: ignore[import-untyped]
class TLSConfig(BaseModel):
"""TLS/mTLS configuration for secure client connections.
Supports mutual TLS (mTLS) where the client presents a certificate to the server,
and standard TLS with custom CA verification.
Attributes:
client_cert_path: Path to client certificate file (PEM format) for mTLS.
client_key_path: Path to client private key file (PEM format) for mTLS.
ca_cert_path: Path to CA certificate bundle for server verification.
verify: Whether to verify server certificates. Set False only for development.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
client_cert_path: FilePath | None = Field(
default=None,
description="Path to client certificate file (PEM format) for mTLS",
)
client_key_path: FilePath | None = Field(
default=None,
description="Path to client private key file (PEM format) for mTLS",
)
ca_cert_path: FilePath | None = Field(
default=None,
description="Path to CA certificate bundle for server verification",
)
verify: bool = Field(
default=True,
description="Whether to verify server certificates. Set False only for development.",
)
def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
"""Build SSL context for httpx client.
Returns:
SSL context if certificates configured, True for default verification,
False if verification disabled, or path to CA bundle.
"""
if not self.verify:
return False
if self.client_cert_path and self.client_key_path:
context = ssl.create_default_context()
if self.ca_cert_path:
context.load_verify_locations(cafile=str(self.ca_cert_path))
context.load_cert_chain(
certfile=str(self.client_cert_path),
keyfile=str(self.client_key_path),
)
return context
if self.ca_cert_path:
return str(self.ca_cert_path)
return True
def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported]
"""Build gRPC channel credentials for secure connections.
Returns:
gRPC SSL credentials if certificates configured, None otherwise.
"""
try:
import grpc
except ImportError:
return None
if not self.verify and not self.client_cert_path:
return None
root_certs: bytes | None = None
private_key: bytes | None = None
certificate_chain: bytes | None = None
if self.ca_cert_path:
root_certs = Path(self.ca_cert_path).read_bytes()
if self.client_cert_path and self.client_key_path:
private_key = Path(self.client_key_path).read_bytes()
certificate_chain = Path(self.client_cert_path).read_bytes()
return grpc.ssl_channel_credentials(
root_certificates=root_certs,
private_key=private_key,
certificate_chain=certificate_chain,
)
class ClientAuthScheme(ABC, BaseModel):
"""Base class for client-side authentication schemes.
Client auth schemes apply credentials to outgoing requests.
Attributes:
tls: Optional TLS/mTLS configuration for secure connections.
"""
tls: TLSConfig | None = Field(
default=None,
description="TLS/mTLS configuration for secure connections",
)
@abstractmethod
async def apply_auth(
@@ -41,7 +152,12 @@ class AuthScheme(ABC, BaseModel):
...
class BearerTokenAuth(AuthScheme):
@deprecated("Use ClientAuthScheme instead", category=FutureWarning)
class AuthScheme(ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme instead."""
class BearerTokenAuth(ClientAuthScheme):
"""Bearer token authentication (Authorization: Bearer <token>).
Attributes:
@@ -66,7 +182,7 @@ class BearerTokenAuth(AuthScheme):
return headers
class HTTPBasicAuth(AuthScheme):
class HTTPBasicAuth(ClientAuthScheme):
"""HTTP Basic authentication.
Attributes:
@@ -95,7 +211,7 @@ class HTTPBasicAuth(AuthScheme):
return headers
class HTTPDigestAuth(AuthScheme):
class HTTPDigestAuth(ClientAuthScheme):
"""HTTP Digest authentication.
Note: Uses httpx-auth library for digest implementation.
@@ -108,6 +224,8 @@ class HTTPDigestAuth(AuthScheme):
username: str = Field(description="Username")
password: str = Field(description="Password")
_configured_client_id: int | None = PrivateAttr(default=None)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
@@ -125,13 +243,21 @@ class HTTPDigestAuth(AuthScheme):
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client with Digest auth.
Idempotent: Only configures the client once. Subsequent calls on the same
client instance are no-ops to prevent overwriting auth configuration.
Args:
client: HTTP client to configure with Digest authentication.
"""
client_id = id(client)
if self._configured_client_id == client_id:
return
client.auth = DigestAuth(self.username, self.password)
self._configured_client_id = client_id
class APIKeyAuth(AuthScheme):
class APIKeyAuth(ClientAuthScheme):
"""API Key authentication (header, query, or cookie).
Attributes:
@@ -146,6 +272,8 @@ class APIKeyAuth(AuthScheme):
)
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
_configured_client_ids: set[int] = PrivateAttr(default_factory=set)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
@@ -167,21 +295,31 @@ class APIKeyAuth(AuthScheme):
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client for query param API keys.
Idempotent: Only adds the request hook once per client instance.
Subsequent calls on the same client are no-ops to prevent hook accumulation.
Args:
client: HTTP client to configure with query param API key hook.
"""
if self.location == "query":
client_id = id(client)
if client_id in self._configured_client_ids:
return
async def _add_api_key_param(request: httpx.Request) -> None:
url = httpx.URL(request.url)
request.url = url.copy_add_param(self.name, self.api_key)
client.event_hooks["request"].append(_add_api_key_param)
self._configured_client_ids.add(client_id)
class OAuth2ClientCredentials(AuthScheme):
class OAuth2ClientCredentials(ClientAuthScheme):
"""OAuth2 Client Credentials flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches
when multiple requests share the same auth instance.
Attributes:
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
@@ -198,12 +336,17 @@ class OAuth2ClientCredentials(AuthScheme):
_access_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time,
preventing race conditions when multiple concurrent requests use the same
auth instance.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
@@ -216,7 +359,13 @@ class OAuth2ClientCredentials(AuthScheme):
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
async with self._lock:
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
@@ -250,9 +399,11 @@ class OAuth2ClientCredentials(AuthScheme):
self._token_expires_at = time.time() + expires_in - 60
class OAuth2AuthorizationCode(AuthScheme):
class OAuth2AuthorizationCode(ClientAuthScheme):
"""OAuth2 Authorization Code flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token operations.
Note: Requires interactive authorization.
Attributes:
@@ -279,6 +430,7 @@ class OAuth2AuthorizationCode(AuthScheme):
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
default=None
)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
def set_authorization_callback(
self, callback: Callable[[str], Awaitable[str]] | None
@@ -295,6 +447,9 @@ class OAuth2AuthorizationCode(AuthScheme):
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine handles token operations
(initial fetch or refresh) at a time.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
@@ -305,14 +460,17 @@ class OAuth2AuthorizationCode(AuthScheme):
Raises:
ValueError: If authorization callback is not set.
"""
if self._access_token is None:
if self._authorization_callback is None:
msg = "Authorization callback not set. Use set_authorization_callback()"
raise ValueError(msg)
await self._fetch_initial_token(client)
async with self._lock:
if self._access_token is None:
await self._fetch_initial_token(client)
elif self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
async with self._lock:
if self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"

View File

@@ -5,14 +5,25 @@ This module is separate from experimental.a2a to avoid circular imports.
from __future__ import annotations
from importlib.metadata import version
from typing import Any, ClassVar, Literal
from pathlib import Path
from typing import Any, ClassVar, Literal, cast
import warnings
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import deprecated
from pydantic import (
BaseModel,
ConfigDict,
Field,
FilePath,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import Self, deprecated
from crewai.a2a.auth.schemas import AuthScheme
from crewai.a2a.types import TransportType, Url
from crewai.a2a.auth.client_schemes import ClientAuthScheme
from crewai.a2a.auth.server_schemes import ServerAuthScheme
from crewai.a2a.extensions.base import ValidatedA2AExtension
from crewai.a2a.types import ProtocolVersion, TransportType, Url
try:
@@ -25,16 +36,17 @@ try:
SecurityScheme,
)
from crewai.a2a.extensions.server import ServerExtension
from crewai.a2a.updates import UpdateConfig
except ImportError:
UpdateConfig = Any
AgentCapabilities = Any
AgentCardSignature = Any
AgentInterface = Any
AgentProvider = Any
SecurityScheme = Any
AgentSkill = Any
UpdateConfig = Any # type: ignore[misc,assignment]
UpdateConfig: Any = Any # type: ignore[no-redef]
AgentCapabilities: Any = Any # type: ignore[no-redef]
AgentCardSignature: Any = Any # type: ignore[no-redef]
AgentInterface: Any = Any # type: ignore[no-redef]
AgentProvider: Any = Any # type: ignore[no-redef]
SecurityScheme: Any = Any # type: ignore[no-redef]
AgentSkill: Any = Any # type: ignore[no-redef]
ServerExtension: Any = Any # type: ignore[no-redef]
def _get_default_update_config() -> UpdateConfig:
@@ -43,6 +55,309 @@ def _get_default_update_config() -> UpdateConfig:
return StreamingConfig()
SigningAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
class AgentCardSigningConfig(BaseModel):
"""Configuration for AgentCard JWS signing.
Provides the private key and algorithm settings for signing AgentCards.
Either private_key_path or private_key_pem must be provided, but not both.
Attributes:
private_key_path: Path to a PEM-encoded private key file.
private_key_pem: PEM-encoded private key as a secret string.
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
private_key_path: FilePath | None = Field(
default=None,
description="Path to PEM-encoded private key file",
)
private_key_pem: SecretStr | None = Field(
default=None,
description="PEM-encoded private key",
)
key_id: str | None = Field(
default=None,
description="Key identifier for JWS header (kid claim)",
)
algorithm: SigningAlgorithm = Field(
default="RS256",
description="Signing algorithm (RS256, ES256, PS256, etc.)",
)
@model_validator(mode="after")
def _validate_key_source(self) -> Self:
"""Ensure exactly one key source is provided."""
has_path = self.private_key_path is not None
has_pem = self.private_key_pem is not None
if not has_path and not has_pem:
raise ValueError(
"Either private_key_path or private_key_pem must be provided"
)
if has_path and has_pem:
raise ValueError(
"Only one of private_key_path or private_key_pem should be provided"
)
return self
def get_private_key(self) -> str:
"""Get the private key content.
Returns:
The PEM-encoded private key as a string.
"""
if self.private_key_pem:
return self.private_key_pem.get_secret_value()
if self.private_key_path:
return Path(self.private_key_path).read_text()
raise ValueError("No private key configured")
class GRPCServerConfig(BaseModel):
"""gRPC server transport configuration.
Presence of this config in ServerTransportConfig.grpc enables gRPC transport.
Attributes:
host: Hostname to advertise in agent cards (default: localhost).
Use docker service name (e.g., 'web') for docker-compose setups.
port: Port for the gRPC server.
tls_cert_path: Path to TLS certificate file for gRPC.
tls_key_path: Path to TLS private key file for gRPC.
max_workers: Maximum number of workers for the gRPC thread pool.
reflection_enabled: Whether to enable gRPC reflection for debugging.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
host: str = Field(
default="localhost",
description="Hostname to advertise in agent cards for gRPC connections",
)
port: int = Field(
default=50051,
description="Port for the gRPC server",
)
tls_cert_path: str | None = Field(
default=None,
description="Path to TLS certificate file for gRPC",
)
tls_key_path: str | None = Field(
default=None,
description="Path to TLS private key file for gRPC",
)
max_workers: int = Field(
default=10,
description="Maximum number of workers for the gRPC thread pool",
)
reflection_enabled: bool = Field(
default=False,
description="Whether to enable gRPC reflection for debugging",
)
class GRPCClientConfig(BaseModel):
"""gRPC client transport configuration.
Attributes:
max_send_message_length: Maximum size for outgoing messages in bytes.
max_receive_message_length: Maximum size for incoming messages in bytes.
keepalive_time_ms: Time between keepalive pings in milliseconds.
keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_send_message_length: int | None = Field(
default=None,
description="Maximum size for outgoing messages in bytes",
)
max_receive_message_length: int | None = Field(
default=None,
description="Maximum size for incoming messages in bytes",
)
keepalive_time_ms: int | None = Field(
default=None,
description="Time between keepalive pings in milliseconds",
)
keepalive_timeout_ms: int | None = Field(
default=None,
description="Timeout for keepalive ping response in milliseconds",
)
class JSONRPCServerConfig(BaseModel):
"""JSON-RPC server transport configuration.
Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport.
Attributes:
rpc_path: URL path for the JSON-RPC endpoint.
agent_card_path: URL path for the agent card endpoint.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
rpc_path: str = Field(
default="/a2a",
description="URL path for the JSON-RPC endpoint",
)
agent_card_path: str = Field(
default="/.well-known/agent-card.json",
description="URL path for the agent card endpoint",
)
class JSONRPCClientConfig(BaseModel):
"""JSON-RPC client transport configuration.
Attributes:
max_request_size: Maximum request body size in bytes.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_request_size: int | None = Field(
default=None,
description="Maximum request body size in bytes",
)
class HTTPJSONConfig(BaseModel):
"""HTTP+JSON transport configuration.
Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
class ServerPushNotificationConfig(BaseModel):
"""Configuration for outgoing webhook push notifications.
Controls how the server signs and delivers push notifications to clients.
Attributes:
signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
signature_secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC-SHA256 signing of outgoing push notifications",
)
class ServerTransportConfig(BaseModel):
"""Transport configuration for A2A server.
Groups all transport-related settings including preferred transport
and protocol-specific configurations.
Attributes:
preferred: Transport protocol for the preferred endpoint.
jsonrpc: JSON-RPC server transport configuration.
grpc: gRPC server transport configuration.
http_json: HTTP+JSON transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType = Field(
default="JSONRPC",
description="Transport protocol for the preferred endpoint",
)
jsonrpc: JSONRPCServerConfig = Field(
default_factory=JSONRPCServerConfig,
description="JSON-RPC server transport configuration",
)
grpc: GRPCServerConfig | None = Field(
default=None,
description="gRPC server transport configuration",
)
http_json: HTTPJSONConfig | None = Field(
default=None,
description="HTTP+JSON transport configuration",
)
def _migrate_client_transport_fields(
transport: ClientTransportConfig,
transport_protocol: TransportType | None,
supported_transports: list[TransportType] | None,
) -> None:
"""Migrate deprecated transport fields to new config."""
if transport_protocol is not None:
warnings.warn(
"transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "preferred", transport_protocol)
if supported_transports is not None:
warnings.warn(
"supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "supported", supported_transports)
class ClientTransportConfig(BaseModel):
"""Transport configuration for A2A client.
Groups all client transport-related settings including preferred transport,
supported transports for negotiation, and protocol-specific configurations.
Transport negotiation logic:
1. If `preferred` is set and server supports it → use client's preferred
2. Otherwise, if server's preferred is in client's `supported` → use server's preferred
3. Otherwise, find first match from client's `supported` in server's interfaces
Attributes:
preferred: Client's preferred transport. If set, client preference takes priority.
supported: Transports the client can use, in order of preference.
jsonrpc: JSON-RPC client transport configuration.
grpc: gRPC client transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType | None = Field(
default=None,
description="Client's preferred transport. If set, takes priority over server preference.",
)
supported: list[TransportType] = Field(
default_factory=lambda: cast(list[TransportType], ["JSONRPC"]),
description="Transports the client can use, in order of preference",
)
jsonrpc: JSONRPCClientConfig = Field(
default_factory=JSONRPCClientConfig,
description="JSON-RPC client transport configuration",
)
grpc: GRPCClientConfig = Field(
default_factory=GRPCClientConfig,
description="gRPC client transport configuration",
)
@deprecated(
"""
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
@@ -65,13 +380,14 @@ class A2AConfig(BaseModel):
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
transport_protocol: A2A transport protocol (grpc, jsonrpc, http+json).
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: AuthScheme | None = Field(
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
@@ -95,10 +411,48 @@ class A2AConfig(BaseModel):
default_factory=_get_default_update_config,
description="Update mechanism config",
)
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
default="JSONRPC",
description="Specified mode of A2A transport protocol",
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
use_client_preference: bool | None = Field(
default=None,
description="Deprecated: Set transport.preferred to enable client preference",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
if self.use_client_preference is not None:
warnings.warn(
"use_client_preference is deprecated, set transport.preferred to enable client preference",
FutureWarning,
stacklevel=4,
)
if self.use_client_preference and self.transport.supported:
object.__setattr__(
self.transport, "preferred", self.transport.supported[0]
)
return self
class A2AClientConfig(BaseModel):
@@ -114,15 +468,15 @@ class A2AClientConfig(BaseModel):
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
accepted_output_modes: Media types the client can accept in responses.
supported_transports: Ordered list of transport protocols the client supports.
use_client_preference: Whether to prioritize client transport preferences over server.
extensions: Extension URIs the client supports.
extensions: Extension URIs the client supports (A2A protocol extensions).
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: AuthScheme | None = Field(
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
@@ -150,22 +504,37 @@ class A2AClientConfig(BaseModel):
default_factory=lambda: ["application/json"],
description="Media types the client can accept in responses",
)
supported_transports: list[str] = Field(
default_factory=lambda: ["JSONRPC"],
description="Ordered list of transport protocols the client supports",
)
use_client_preference: bool = Field(
default=False,
description="Whether to prioritize client transport preferences over server",
)
extensions: list[str] = Field(
default_factory=list,
description="Extension URIs the client supports",
)
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
default="JSONRPC",
description="Specified mode of A2A transport protocol",
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
return self
class A2AServerConfig(BaseModel):
@@ -182,7 +551,6 @@ class A2AServerConfig(BaseModel):
default_input_modes: Default supported input MIME types.
default_output_modes: Default supported output MIME types.
capabilities: Declaration of optional capabilities.
preferred_transport: Transport protocol for the preferred endpoint.
protocol_version: A2A protocol version this agent supports.
provider: Information about the agent's service provider.
documentation_url: URL to the agent's documentation.
@@ -192,7 +560,12 @@ class A2AServerConfig(BaseModel):
security_schemes: Security schemes available to authorize requests.
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
url: Preferred endpoint URL for the agent.
signatures: JSON Web Signatures for the AgentCard.
signing_config: Configuration for signing the AgentCard with JWS.
signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
push_notifications: Configuration for outgoing push notifications.
transport: Transport configuration (preferred transport, gRPC, REST settings).
auth: Authentication scheme for A2A endpoints.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@@ -228,12 +601,8 @@ class A2AServerConfig(BaseModel):
),
description="Declaration of optional capabilities supported by the agent",
)
preferred_transport: TransportType = Field(
default="JSONRPC",
description="Transport protocol for the preferred endpoint",
)
protocol_version: str = Field(
default_factory=lambda: version("a2a-sdk"),
protocol_version: ProtocolVersion = Field(
default="0.3.0",
description="A2A protocol version this agent supports",
)
provider: AgentProvider | None = Field(
@@ -250,7 +619,7 @@ class A2AServerConfig(BaseModel):
)
additional_interfaces: list[AgentInterface] = Field(
default_factory=list,
description="Additional supported interfaces (transport and URL combinations)",
description="Additional supported interfaces.",
)
security: list[dict[str, list[str]]] = Field(
default_factory=list,
@@ -268,7 +637,54 @@ class A2AServerConfig(BaseModel):
default=None,
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
)
signatures: list[AgentCardSignature] = Field(
default_factory=list,
description="JSON Web Signatures for the AgentCard",
signing_config: AgentCardSigningConfig | None = Field(
default=None,
description="Configuration for signing the AgentCard with JWS",
)
signatures: list[AgentCardSignature] | None = Field(
default=None,
description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
exclude=True,
deprecated=True,
)
server_extensions: list[ServerExtension] = Field(
default_factory=list,
description="Server-side A2A protocol extensions that modify agent behavior",
)
push_notifications: ServerPushNotificationConfig | None = Field(
default=None,
description="Configuration for outgoing push notifications",
)
transport: ServerTransportConfig = Field(
default_factory=ServerTransportConfig,
description="Transport configuration (preferred transport, gRPC, REST settings)",
)
preferred_transport: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
deprecated=True,
)
auth: ServerAuthScheme | None = Field(
default=None,
description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
)
@model_validator(mode="after")
def _migrate_deprecated_fields(self) -> Self:
"""Migrate deprecated fields to new config."""
if self.preferred_transport is not None:
warnings.warn(
"preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=4,
)
object.__setattr__(self.transport, "preferred", self.preferred_transport)
if self.signatures is not None:
warnings.warn(
"signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
"The signatures field will be removed in v2.0.0.",
FutureWarning,
stacklevel=4,
)
return self

View File

@@ -1,7 +1,491 @@
"""A2A protocol error types."""
"""A2A error codes and error response utilities.
This module provides a centralized mapping of all A2A protocol error codes
as defined in the A2A specification, plus custom CrewAI extensions.
Error codes follow JSON-RPC 2.0 conventions:
- -32700 to -32600: Standard JSON-RPC errors
- -32099 to -32000: Server errors (A2A-specific)
- -32768 to -32100: Reserved for implementation-defined errors
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any
from a2a.client.errors import A2AClientTimeoutError
class A2APollingTimeoutError(A2AClientTimeoutError):
"""Raised when polling exceeds the configured timeout."""
class A2AErrorCode(IntEnum):
"""A2A protocol error codes.
Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
"""
# JSON-RPC 2.0 Standard Errors (-32700 to -32600)
JSON_PARSE_ERROR = -32700
"""Invalid JSON was received by the server."""
INVALID_REQUEST = -32600
"""The JSON sent is not a valid Request object."""
METHOD_NOT_FOUND = -32601
"""The method does not exist / is not available."""
INVALID_PARAMS = -32602
"""Invalid method parameter(s)."""
INTERNAL_ERROR = -32603
"""Internal JSON-RPC error."""
# A2A-Specific Errors (-32099 to -32000)
TASK_NOT_FOUND = -32001
"""The specified task was not found."""
TASK_NOT_CANCELABLE = -32002
"""The task cannot be canceled (already completed/failed)."""
PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
"""Push notifications are not supported by this agent."""
UNSUPPORTED_OPERATION = -32004
"""The requested operation is not supported."""
CONTENT_TYPE_NOT_SUPPORTED = -32005
"""Incompatible content types between client and server."""
INVALID_AGENT_RESPONSE = -32006
"""The agent produced an invalid response."""
# CrewAI Custom Extensions (-32768 to -32100)
UNSUPPORTED_VERSION = -32009
"""The requested A2A protocol version is not supported."""
UNSUPPORTED_EXTENSION = -32010
"""Client does not support required protocol extensions."""
AUTHENTICATION_REQUIRED = -32011
"""Authentication is required for this operation."""
AUTHORIZATION_FAILED = -32012
"""Authorization check failed (insufficient permissions)."""
RATE_LIMIT_EXCEEDED = -32013
"""Rate limit exceeded for this client/operation."""
TASK_TIMEOUT = -32014
"""Task execution timed out."""
TRANSPORT_NEGOTIATION_FAILED = -32015
"""Failed to negotiate a compatible transport protocol."""
CONTEXT_NOT_FOUND = -32016
"""The specified context was not found."""
SKILL_NOT_FOUND = -32017
"""The specified skill was not found."""
ARTIFACT_NOT_FOUND = -32018
"""The specified artifact was not found."""
# Error code to default message mapping
ERROR_MESSAGES: dict[int, str] = {
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
A2AErrorCode.INVALID_PARAMS: "Invalid params",
A2AErrorCode.INTERNAL_ERROR: "Internal error",
A2AErrorCode.TASK_NOT_FOUND: "Task not found",
A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
}
@dataclass
class A2AError(Exception):
"""Base exception for A2A protocol errors.
Attributes:
code: The A2A/JSON-RPC error code.
message: Human-readable error message.
data: Optional additional error data.
"""
code: int
message: str | None = None
data: Any = None
def __post_init__(self) -> None:
if self.message is None:
self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
super().__init__(self.message)
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-RPC error object format."""
error: dict[str, Any] = {
"code": self.code,
"message": self.message,
}
if self.data is not None:
error["data"] = self.data
return error
def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
"""Convert to full JSON-RPC error response."""
return {
"jsonrpc": "2.0",
"error": self.to_dict(),
"id": request_id,
}
@dataclass
class JSONParseError(A2AError):
"""Invalid JSON was received."""
code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)
@dataclass
class InvalidRequestError(A2AError):
"""The JSON sent is not a valid Request object."""
code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)
@dataclass
class MethodNotFoundError(A2AError):
"""The method does not exist / is not available."""
code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
method: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.method:
self.message = f"Method not found: {self.method}"
super().__post_init__()
@dataclass
class InvalidParamsError(A2AError):
"""Invalid method parameter(s)."""
code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
param: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.param and self.reason:
self.message = f"Invalid parameter '{self.param}': {self.reason}"
elif self.param:
self.message = f"Invalid parameter: {self.param}"
super().__post_init__()
@dataclass
class InternalError(A2AError):
"""Internal JSON-RPC error."""
code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False)
@dataclass
class TaskNotFoundError(A2AError):
"""The specified task was not found."""
code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
task_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.task_id:
self.message = f"Task not found: {self.task_id}"
super().__post_init__()
@dataclass
class TaskNotCancelableError(A2AError):
"""The task cannot be canceled."""
code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
task_id: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.reason:
self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
elif self.task_id:
self.message = f"Task {self.task_id} cannot be canceled"
super().__post_init__()
@dataclass
class PushNotificationNotSupportedError(A2AError):
"""Push notifications are not supported."""
code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)
@dataclass
class UnsupportedOperationError(A2AError):
"""The requested operation is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
operation: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.operation:
self.message = f"Operation not supported: {self.operation}"
super().__post_init__()
@dataclass
class ContentTypeNotSupportedError(A2AError):
"""Incompatible content types."""
code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
requested_types: list[str] | None = None
supported_types: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_types and self.supported_types:
self.message = (
f"Content type not supported. Requested: {self.requested_types}, "
f"Supported: {self.supported_types}"
)
super().__post_init__()
@dataclass
class InvalidAgentResponseError(A2AError):
"""The agent produced an invalid response."""
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
@dataclass
class UnsupportedVersionError(A2AError):
"""The requested A2A version is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
requested_version: str | None = None
supported_versions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_version:
msg = f"Unsupported A2A version: {self.requested_version}"
if self.supported_versions:
msg += f". Supported versions: {', '.join(self.supported_versions)}"
self.message = msg
super().__post_init__()
@dataclass
class UnsupportedExtensionError(A2AError):
"""Client does not support required extensions."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
required_extensions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_extensions:
self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
super().__post_init__()
@dataclass
class AuthenticationRequiredError(A2AError):
"""Authentication is required."""
code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)
@dataclass
class AuthorizationFailedError(A2AError):
"""Authorization check failed."""
code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
required_scope: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_scope:
self.message = (
f"Authorization failed. Required scope: {self.required_scope}"
)
super().__post_init__()
@dataclass
class RateLimitExceededError(A2AError):
"""Rate limit exceeded."""
code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False)
retry_after: int | None = None
def __post_init__(self) -> None:
if self.message is None and self.retry_after:
self.message = (
f"Rate limit exceeded. Retry after {self.retry_after} seconds"
)
if self.retry_after:
self.data = {"retry_after": self.retry_after}
super().__post_init__()
@dataclass
class TaskTimeoutError(A2AError):
"""Task execution timed out."""
code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False)
task_id: str | None = None
timeout_seconds: float | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.timeout_seconds:
self.message = (
f"Task {self.task_id} timed out after {self.timeout_seconds}s"
)
elif self.task_id:
self.message = f"Task {self.task_id} timed out"
super().__post_init__()
@dataclass
class TransportNegotiationFailedError(A2AError):
"""Failed to negotiate a compatible transport protocol."""
code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False)
client_transports: list[str] | None = None
server_transports: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.client_transports and self.server_transports:
self.message = (
f"Transport negotiation failed. Client: {self.client_transports}, "
f"Server: {self.server_transports}"
)
super().__post_init__()
@dataclass
class ContextNotFoundError(A2AError):
"""The specified context was not found."""
code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
context_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.context_id:
self.message = f"Context not found: {self.context_id}"
super().__post_init__()
@dataclass
class SkillNotFoundError(A2AError):
"""The specified skill was not found."""
code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
skill_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.skill_id:
self.message = f"Skill not found: {self.skill_id}"
super().__post_init__()
@dataclass
class ArtifactNotFoundError(A2AError):
"""The specified artifact was not found."""
code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
artifact_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.artifact_id:
self.message = f"Artifact not found: {self.artifact_id}"
super().__post_init__()
def create_error_response(
code: int | A2AErrorCode,
message: str | None = None,
data: Any = None,
request_id: str | int | None = None,
) -> dict[str, Any]:
"""Create a JSON-RPC error response.
Args:
code: Error code (A2AErrorCode or int).
message: Optional error message (uses default if not provided).
data: Optional additional error data.
request_id: Request ID for correlation.
Returns:
Dict in JSON-RPC error response format.
"""
error = A2AError(code=int(code), message=message, data=data)
return error.to_response(request_id)
def is_retryable_error(code: int) -> bool:
"""Check if an error is potentially retryable.
Args:
code: Error code to check.
Returns:
True if the error might be resolved by retrying.
"""
retryable_codes = {
A2AErrorCode.INTERNAL_ERROR,
A2AErrorCode.RATE_LIMIT_EXCEEDED,
A2AErrorCode.TASK_TIMEOUT,
}
return code in retryable_codes
def is_client_error(code: int) -> bool:
"""Check if an error is a client-side error.
Args:
code: Error code to check.
Returns:
True if the error is due to client request issues.
"""
client_error_codes = {
A2AErrorCode.JSON_PARSE_ERROR,
A2AErrorCode.INVALID_REQUEST,
A2AErrorCode.METHOD_NOT_FOUND,
A2AErrorCode.INVALID_PARAMS,
A2AErrorCode.TASK_NOT_FOUND,
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
A2AErrorCode.UNSUPPORTED_VERSION,
A2AErrorCode.UNSUPPORTED_EXTENSION,
A2AErrorCode.CONTEXT_NOT_FOUND,
A2AErrorCode.SKILL_NOT_FOUND,
A2AErrorCode.ARTIFACT_NOT_FOUND,
}
return code in client_error_codes

View File

@@ -51,6 +51,13 @@ ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
}
)
PENDING_STATES: frozenset[TaskState] = frozenset(
{
TaskState.submitted,
TaskState.working,
}
)
class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state."""
@@ -272,6 +279,9 @@ def process_task_state(
history=new_messages,
)
if a2a_task.status.state in PENDING_STATES:
return None
return None

View File

@@ -38,3 +38,18 @@ You MUST now:
DO NOT send another request - the task is already done.
</REMOTE_AGENT_STATUS>
"""
REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """
<REMOTE_AGENT_STATUS>
STATUS: RESPONSE_RECEIVED
The remote agent has responded. Their response is in the conversation history above.
You MUST now:
1. Set is_a2a=false (the remote task is complete and cannot receive more messages)
2. Provide YOUR OWN response to the original task based on the information received
IMPORTANT: Your response should be addressed to the USER who gave you the original task.
Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that...").
Do NOT address the remote agent directly or use "you" to refer to them.
</REMOTE_AGENT_STATUS>
"""

View File

@@ -36,6 +36,17 @@ except ImportError:
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
ProtocolVersion = Literal[
"0.2.0",
"0.2.1",
"0.2.2",
"0.2.3",
"0.2.4",
"0.2.5",
"0.2.6",
"0.3.0",
"0.4.0",
]
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)

View File

@@ -2,12 +2,28 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
class CommonParams(NamedTuple):
"""Common parameters shared across all update handlers.
Groups the frequently-passed parameters to reduce duplication.
"""
turn_number: int
is_multiturn: bool
agent_role: str | None
endpoint: str
a2a_agent_name: str | None
context_id: str | None
from_task: Any
from_agent: Any
if TYPE_CHECKING:
from a2a.client import Client
from a2a.types import AgentCard, Message, Task
@@ -63,8 +79,8 @@ class PushNotificationResultStore(Protocol):
@classmethod
def __get_pydantic_core_schema__(
cls,
source_type: Any,
handler: GetCoreSchemaHandler,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
return core_schema.any_schema()
@@ -130,3 +146,31 @@ class UpdateHandler(Protocol):
Result dictionary with status, result/error, and history.
"""
...
def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams:
"""Extract common parameters from handler kwargs.
Args:
kwargs: Handler kwargs dict.
Returns:
CommonParams with extracted values.
Raises:
ValueError: If endpoint is not provided.
"""
endpoint = kwargs.get("endpoint")
if endpoint is None:
raise ValueError("endpoint is required for update handlers")
return CommonParams(
turn_number=kwargs.get("turn_number", 0),
is_multiturn=kwargs.get("is_multiturn", False),
agent_role=kwargs.get("agent_role"),
endpoint=endpoint,
a2a_agent_name=kwargs.get("a2a_agent_name"),
context_id=kwargs.get("context_id"),
from_task=kwargs.get("from_task"),
from_agent=kwargs.get("from_agent"),
)

View File

@@ -94,7 +94,7 @@ async def _poll_task_until_complete(
A2APollingStatusEvent(
task_id=task_id,
context_id=effective_context_id,
state=str(task.status.state.value) if task.status.state else "unknown",
state=str(task.status.state.value),
elapsed_seconds=elapsed,
poll_count=poll_count,
endpoint=endpoint,
@@ -325,7 +325,7 @@ class PollingHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
endpoint=endpoint,
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,

View File

@@ -2,10 +2,30 @@
from __future__ import annotations
from typing import Annotated
from a2a.types import PushNotificationAuthenticationInfo
from pydantic import AnyHttpUrl, BaseModel, Field
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
from crewai.a2a.updates.base import PushNotificationResultStore
from crewai.a2a.updates.push_notifications.signature import WebhookSignatureConfig
def _coerce_signature(
value: str | WebhookSignatureConfig | None,
) -> WebhookSignatureConfig | None:
"""Convert string secret to WebhookSignatureConfig."""
if value is None:
return None
if isinstance(value, str):
return WebhookSignatureConfig.hmac_sha256(secret=value)
return value
SignatureInput = Annotated[
WebhookSignatureConfig | None,
BeforeValidator(_coerce_signature),
]
class PushNotificationConfig(BaseModel):
@@ -19,6 +39,8 @@ class PushNotificationConfig(BaseModel):
timeout: Max seconds to wait for task completion.
interval: Seconds between result polling attempts.
result_store: Store for receiving push notification results.
signature: HMAC signature config. Pass a string (secret) for defaults,
or WebhookSignatureConfig for custom settings.
"""
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
@@ -36,3 +58,8 @@ class PushNotificationConfig(BaseModel):
result_store: PushNotificationResultStore | None = Field(
default=None, description="Result store for push notification handling"
)
signature: SignatureInput = Field(
default=None,
description="HMAC signature config. Pass a string (secret) for simple usage, "
"or WebhookSignatureConfig for custom headers/tolerance.",
)

View File

@@ -24,8 +24,10 @@ from crewai.a2a.task_helpers import (
send_message_and_get_task_id,
)
from crewai.a2a.updates.base import (
CommonParams,
PushNotificationHandlerKwargs,
PushNotificationResultStore,
extract_common_params,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
@@ -39,10 +41,81 @@ from crewai.events.types.a2a_events import (
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
logger = logging.getLogger(__name__)
def _handle_push_error(
error: Exception,
error_msg: str,
error_type: str,
new_messages: list[Message],
agent_branch: Any | None,
params: CommonParams,
task_id: str | None,
status_code: int | None = None,
) -> TaskStateResult:
"""Handle push notification errors with consistent event emission.
Args:
error: The exception that occurred.
error_msg: Formatted error message for the result.
error_type: Type of error for the event.
new_messages: List to append error message to.
agent_branch: Agent tree branch for events.
params: Common handler parameters.
task_id: A2A task ID.
status_code: HTTP status code if applicable.
Returns:
TaskStateResult with failed status.
"""
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(error),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
async def _wait_for_push_result(
task_id: str,
result_store: PushNotificationResultStore,
@@ -126,15 +199,8 @@ class PushNotificationHandler:
polling_timeout = kwargs.get("polling_timeout", 300.0)
polling_interval = kwargs.get("polling_interval", 2.0)
agent_branch = kwargs.get("agent_branch")
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
endpoint = kwargs.get("endpoint")
a2a_agent_name = kwargs.get("a2a_agent_name")
from_task = kwargs.get("from_task")
from_agent = kwargs.get("from_agent")
params = extract_common_params(kwargs)
if config is None:
error_msg = (
@@ -143,15 +209,15 @@ class PushNotificationHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=a2a_agent_name,
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=context_id,
context_id=params.context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
@@ -167,15 +233,15 @@ class PushNotificationHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=a2a_agent_name,
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=context_id,
context_id=params.context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
@@ -189,14 +255,14 @@ class PushNotificationHandler:
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=context_id,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=params.context_id,
)
if not isinstance(result_or_task_id, str):
@@ -208,12 +274,12 @@ class PushNotificationHandler:
agent_branch,
A2APushNotificationRegisteredEvent(
task_id=task_id,
context_id=context_id,
context_id=params.context_id,
callback_url=str(config.url),
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
@@ -229,11 +295,11 @@ class PushNotificationHandler:
timeout=polling_timeout,
poll_interval=polling_interval,
agent_branch=agent_branch,
from_task=from_task,
from_agent=from_agent,
context_id=context_id,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
)
if final_task is None:
@@ -247,13 +313,13 @@ class PushNotificationHandler:
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if result:
return result
@@ -265,98 +331,24 @@ class PushNotificationHandler:
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
return _handle_push_error(
error=e,
error_msg=f"HTTP Error {e.status_code}: {e!s}",
error_type="http_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="push_notification",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
status_code=e.status_code,
)
except Exception as e:
error_msg = f"Unexpected error during push notification: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
return _handle_push_error(
error=e,
error_msg=f"Unexpected error during push notification: {e!s}",
error_type="unexpected_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="push_notification",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)

View File

@@ -0,0 +1,87 @@
"""Webhook signature configuration for push notifications."""
from __future__ import annotations
from enum import Enum
import secrets
from pydantic import BaseModel, Field, SecretStr
class WebhookSignatureMode(str, Enum):
"""Signature mode for webhook push notifications."""
NONE = "none"
HMAC_SHA256 = "hmac_sha256"
class WebhookSignatureConfig(BaseModel):
"""Configuration for webhook signature verification.
Provides cryptographic integrity verification and replay attack protection
for A2A push notifications.
Attributes:
mode: Signature mode (none or hmac_sha256).
secret: Shared secret for HMAC computation (required for hmac_sha256 mode).
timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection.
header_name: HTTP header name for the signature.
timestamp_header_name: HTTP header name for the timestamp.
"""
mode: WebhookSignatureMode = Field(
default=WebhookSignatureMode.NONE,
description="Signature verification mode",
)
secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC computation",
)
timestamp_tolerance_seconds: int = Field(
default=300,
ge=0,
description="Max allowed timestamp age in seconds (5 min default)",
)
header_name: str = Field(
default="X-A2A-Signature",
description="HTTP header name for the signature",
)
timestamp_header_name: str = Field(
default="X-A2A-Signature-Timestamp",
description="HTTP header name for the timestamp",
)
@classmethod
def generate_secret(cls, length: int = 32) -> str:
"""Generate a cryptographically secure random secret.
Args:
length: Number of random bytes to generate (default 32).
Returns:
URL-safe base64-encoded secret string.
"""
return secrets.token_urlsafe(length)
@classmethod
def hmac_sha256(
cls,
secret: str | SecretStr,
timestamp_tolerance_seconds: int = 300,
) -> WebhookSignatureConfig:
"""Create an HMAC-SHA256 signature configuration.
Args:
secret: Shared secret for HMAC computation.
timestamp_tolerance_seconds: Max allowed timestamp age in seconds.
Returns:
Configured WebhookSignatureConfig for HMAC-SHA256.
"""
if isinstance(secret, str):
secret = SecretStr(secret)
return cls(
mode=WebhookSignatureMode.HMAC_SHA256,
secret=secret,
timestamp_tolerance_seconds=timestamp_tolerance_seconds,
)

View File

@@ -2,6 +2,9 @@
from __future__ import annotations
import asyncio
import logging
from typing import Final
import uuid
from a2a.client import Client
@@ -11,7 +14,10 @@ from a2a.types import (
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskQueryParams,
TaskState,
TaskStatusUpdateEvent,
TextPart,
@@ -24,7 +30,10 @@ from crewai.a2a.task_helpers import (
TaskStateResult,
process_task_state,
)
from crewai.a2a.updates.base import StreamingHandlerKwargs
from crewai.a2a.updates.base import StreamingHandlerKwargs, extract_common_params
from crewai.a2a.updates.streaming.params import (
process_status_update,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AArtifactReceivedEvent,
@@ -35,9 +44,194 @@ from crewai.events.types.a2a_events import (
)
logger = logging.getLogger(__name__)
MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
class StreamingHandler:
"""SSE streaming-based update handler."""
@staticmethod
async def _try_recover_from_interruption( # type: ignore[misc]
client: Client,
task_id: str,
new_messages: list[Message],
agent_card: AgentCard,
result_parts: list[str],
**kwargs: Unpack[StreamingHandlerKwargs],
) -> TaskStateResult | None:
"""Attempt to recover from a stream interruption by checking task state.
If the task completed while we were disconnected, returns the result.
If the task is still running, attempts to resubscribe and continue.
Args:
client: A2A client instance.
task_id: The task ID to recover.
new_messages: List of collected messages.
agent_card: The agent card.
result_parts: Accumulated result text parts.
**kwargs: Handler parameters.
Returns:
TaskStateResult if recovery succeeded (task finished or resubscribe worked).
None if recovery not possible (caller should handle failure).
Note:
When None is returned, recovery failed and the original exception should
be handled by the caller. All recovery attempts are logged.
"""
params = extract_common_params(kwargs) # type: ignore[arg-type]
try:
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
if a2a_task.status.state in TERMINAL_STATES:
logger.info(
"Task completed during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if a2a_task.status.state in ACTIONABLE_STATES:
logger.info(
"Task in actionable state during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=False,
)
logger.info(
"Task still running, attempting resubscribe",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS):
try:
backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt)
if attempt > 0:
await asyncio.sleep(backoff)
event_stream = client.resubscribe(TaskIdParams(id=task_id))
async for event in event_stream:
if isinstance(event, tuple):
resubscribed_task, update = event
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
if (
is_final_update
or resubscribed_task.status.state
in TERMINAL_STATES | ACTIONABLE_STATES
):
return process_task_state(
a2a_task=resubscribed_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
elif isinstance(event, Message):
new_messages.append(event)
result_parts.extend(
part.root.text
for part in event.parts
if part.root.kind == "text"
)
final_task = await client.get_task(TaskQueryParams(id=task_id))
return process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
except Exception as resubscribe_error: # noqa: PERF203
logger.warning(
"Resubscribe attempt failed",
extra={
"task_id": task_id,
"attempt": attempt + 1,
"max_attempts": MAX_RESUBSCRIBE_ATTEMPTS,
"error": str(resubscribe_error),
},
)
if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1:
return None
except Exception as e:
logger.warning(
"Failed to recover from stream interruption due to unexpected error",
extra={
"task_id": task_id,
"error": str(e),
"error_type": type(e).__name__,
},
exc_info=True,
)
return None
logger.warning(
"Recovery exhausted all resubscribe attempts without success",
extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS},
)
return None
@staticmethod
async def execute(
client: Client,
@@ -58,42 +252,40 @@ class StreamingHandler:
Returns:
Dictionary with status, result/error, and history.
"""
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
endpoint = kwargs.get("endpoint")
a2a_agent_name = kwargs.get("a2a_agent_name")
from_task = kwargs.get("from_task")
from_agent = kwargs.get("from_agent")
agent_branch = kwargs.get("agent_branch")
params = extract_common_params(kwargs)
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
chunk_index = 0
current_task_id: str | None = task_id
crewai_event_bus.emit(
agent_branch,
A2AStreamingStartedEvent(
task_id=task_id,
context_id=context_id,
endpoint=endpoint or "",
a2a_agent_name=a2a_agent_name,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
try:
async for event in event_stream:
if isinstance(event, tuple):
a2a_task, _ = event
current_task_id = a2a_task.id
if isinstance(event, Message):
new_messages.append(event)
message_context_id = event.context_id or context_id
message_context_id = event.context_id or params.context_id
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
@@ -105,12 +297,12 @@ class StreamingHandler:
context_id=message_context_id,
chunk=text,
chunk_index=chunk_index,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
turn_number=turn_number,
is_multiturn=is_multiturn,
from_task=from_task,
from_agent=from_agent,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
chunk_index += 1
@@ -128,12 +320,12 @@ class StreamingHandler:
artifact_size = None
if artifact.parts:
artifact_size = sum(
len(p.root.text.encode("utf-8"))
len(p.root.text.encode())
if p.root.kind == "text"
else len(getattr(p.root, "data", b""))
for p in artifact.parts
)
effective_context_id = a2a_task.context_id or context_id
effective_context_id = a2a_task.context_id or params.context_id
crewai_event_bus.emit(
agent_branch,
A2AArtifactReceivedEvent(
@@ -147,29 +339,21 @@ class StreamingHandler:
size_bytes=artifact_size,
append=update.append or False,
last_chunk=update.last_chunk or False,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=effective_context_id,
turn_number=turn_number,
is_multiturn=is_multiturn,
from_task=from_task,
from_agent=from_agent,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
is_final_update = False
if isinstance(update, TaskStatusUpdateEvent):
is_final_update = update.final
if (
update.status
and update.status.message
and update.status.message.parts
):
result_parts.extend(
part.root.text
for part in update.status.message.parts
if part.root.kind == "text" and part.root.text
)
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if (
not is_final_update
@@ -182,27 +366,68 @@ class StreamingHandler:
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
if final_result:
break
except A2AClientHTTPError as e:
if current_task_id:
logger.info(
"Stream interrupted with HTTP error, attempting recovery",
extra={
"task_id": current_task_id,
"error": str(e),
"status_code": e.status_code,
},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
"Successfully recovered task after HTTP error",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
"Failed to recover from HTTP error, returning failure",
extra={
"task_id": current_task_id,
"status_code": e.status_code,
"original_error": str(e),
},
)
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_type = "http_error"
status_code = e.status_code
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
@@ -210,32 +435,118 @@ class StreamingHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
endpoint=params.endpoint,
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=context_id,
context_id=params.context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
error_type = type(e).__name__.lower()
if current_task_id:
logger.info(
f"Stream interrupted with {error_type}, attempting recovery",
extra={"task_id": current_task_id, "error": str(e)},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
f"Successfully recovered task after {error_type}",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
f"Failed to recover from {error_type}, returning failure",
extra={
"task_id": current_task_id,
"error_type": error_type,
"original_error": str(e),
},
)
error_msg = f"Connection error during streaming: {e!s}"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
@@ -245,13 +556,23 @@ class StreamingHandler:
)
except Exception as e:
error_msg = f"Unexpected error during streaming: {e!s}"
logger.exception(
"Unexpected error during streaming",
extra={
"task_id": current_task_id,
"error_type": type(e).__name__,
"endpoint": params.endpoint,
},
)
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
error_type = "unexpected_error"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
@@ -259,31 +580,32 @@ class StreamingHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
endpoint=params.endpoint,
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=context_id,
context_id=params.context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
@@ -301,15 +623,15 @@ class StreamingHandler:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
endpoint=params.endpoint,
error=str(close_error),
error_type="stream_close_error",
a2a_agent_name=a2a_agent_name,
a2a_agent_name=params.a2a_agent_name,
operation="stream_close",
context_id=context_id,
context_id=params.context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
from_task=params.from_task,
from_agent=params.from_agent,
),
)

View File

@@ -0,0 +1,28 @@
"""Common parameter extraction for streaming handlers."""
from __future__ import annotations
from a2a.types import TaskStatusUpdateEvent
def process_status_update(
update: TaskStatusUpdateEvent,
result_parts: list[str],
) -> bool:
"""Process a status update event and extract text parts.
Args:
update: The status update event.
result_parts: List to append text parts to (modified in place).
Returns:
True if this is a final update, False otherwise.
"""
is_final = update.final
if update.status and update.status.message and update.status.message.parts:
result_parts.extend(
part.root.text
for part in update.status.message.parts
if part.root.kind == "text" and part.root.text
)
return is_final

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
from collections.abc import MutableMapping
from functools import lru_cache
import ssl
import time
from types import MethodType
from typing import TYPE_CHECKING
@@ -15,7 +16,7 @@ from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
import httpx
from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
configure_auth_client,
@@ -32,11 +33,51 @@ from crewai.events.types.a2a_events import (
if TYPE_CHECKING:
from crewai.a2a.auth.schemas import AuthScheme
from crewai.a2a.auth.client_schemes import ClientAuthScheme
from crewai.agent import Agent
from crewai.task import Task
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
"""Get TLS verify parameter from auth scheme.
Args:
auth: Optional authentication scheme with TLS config.
Returns:
SSL context, CA cert path, True for default verification,
or False if verification disabled.
"""
if auth and auth.tls:
return auth.tls.get_httpx_ssl_context()
return True
async def _prepare_auth_headers(
auth: ClientAuthScheme | None,
timeout: int,
) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]:
"""Prepare authentication headers and TLS verification settings.
Args:
auth: Optional authentication scheme.
timeout: Request timeout in seconds.
Returns:
Tuple of (headers dict, TLS verify setting).
"""
headers: MutableMapping[str, str] = {}
verify = _get_tls_verify(auth)
if auth:
async with httpx.AsyncClient(
timeout=timeout, verify=verify
) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
return headers, verify
def _get_server_config(agent: Agent) -> A2AServerConfig | None:
"""Get A2AServerConfig from an agent's a2a configuration.
@@ -59,7 +100,7 @@ def _get_server_config(agent: Agent) -> A2AServerConfig | None:
def fetch_agent_card(
endpoint: str,
auth: AuthScheme | None = None,
auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
cache_ttl: int = 300,
@@ -68,7 +109,7 @@ def fetch_agent_card(
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional AuthScheme for authentication.
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).
@@ -90,10 +131,10 @@ def fetch_agent_card(
"_authorization_callback",
}
)
auth_hash = hash((type(auth).__name__, auth_data))
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = 0
_auth_store[auth_hash] = auth
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
@@ -109,7 +150,7 @@ def fetch_agent_card(
async def afetch_agent_card(
endpoint: str,
auth: AuthScheme | None = None,
auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
) -> AgentCard:
@@ -119,7 +160,7 @@ async def afetch_agent_card(
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional AuthScheme for authentication.
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
@@ -140,10 +181,10 @@ async def afetch_agent_card(
"_authorization_callback",
}
)
auth_hash = hash((type(auth).__name__, auth_data))
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = 0
_auth_store[auth_hash] = auth
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)
@@ -155,7 +196,7 @@ async def afetch_agent_card(
@lru_cache()
def _fetch_agent_card_cached(
endpoint: str,
auth_hash: int,
auth_hash: str,
timeout: int,
_ttl_hash: int,
) -> AgentCard:
@@ -175,7 +216,7 @@ def _fetch_agent_card_cached(
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
async def _afetch_agent_card_cached(
endpoint: str,
auth_hash: int,
auth_hash: str,
timeout: int,
) -> AgentCard:
"""Cached async implementation of AgentCard fetching."""
@@ -185,7 +226,7 @@ async def _afetch_agent_card_cached(
async def _afetch_agent_card_impl(
endpoint: str,
auth: AuthScheme | None,
auth: ClientAuthScheme | None,
timeout: int,
) -> AgentCard:
"""Internal async implementation of AgentCard fetching."""
@@ -197,16 +238,17 @@ async def _afetch_agent_card_impl(
else:
url_parts = endpoint.split("/", 3)
base_url = f"{url_parts[0]}//{url_parts[2]}"
agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/"
agent_card_path = (
f"/{url_parts[3]}"
if len(url_parts) > 3 and url_parts[3]
else "/.well-known/agent-card.json"
)
headers: MutableMapping[str, str] = {}
if auth:
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
headers, verify = await _prepare_auth_headers(auth, timeout)
async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client:
async with httpx.AsyncClient(
timeout=timeout, headers=headers, verify=verify
) as temp_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_client)
@@ -434,6 +476,7 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
"""Generate an A2A AgentCard from an Agent instance.
Uses A2AServerConfig values when available, falling back to agent properties.
If signing_config is provided, the card will be signed with JWS.
Args:
agent: The Agent instance to generate a card for.
@@ -442,6 +485,8 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
Returns:
AgentCard describing the agent's capabilities.
"""
from crewai.a2a.utils.agent_card_signing import sign_agent_card
server_config = _get_server_config(agent) or A2AServerConfig()
name = server_config.name or agent.role
@@ -472,15 +517,31 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
)
)
return AgentCard(
capabilities = server_config.capabilities
if server_config.server_extensions:
from crewai.a2a.extensions.server import ServerExtensionRegistry
registry = ServerExtensionRegistry(server_config.server_extensions)
ext_list = registry.get_agent_extensions()
existing_exts = list(capabilities.extensions) if capabilities.extensions else []
existing_uris = {e.uri for e in existing_exts}
for ext in ext_list:
if ext.uri not in existing_uris:
existing_exts.append(ext)
capabilities = capabilities.model_copy(update={"extensions": existing_exts})
card = AgentCard(
name=name,
description=description,
url=server_config.url or url,
version=server_config.version,
capabilities=server_config.capabilities,
capabilities=capabilities,
default_input_modes=server_config.default_input_modes,
default_output_modes=server_config.default_output_modes,
skills=skills,
preferred_transport=server_config.transport.preferred,
protocol_version=server_config.protocol_version,
provider=server_config.provider,
documentation_url=server_config.documentation_url,
@@ -489,9 +550,21 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
security=server_config.security,
security_schemes=server_config.security_schemes,
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
signatures=server_config.signatures,
)
if server_config.signing_config:
signature = sign_agent_card(
card,
private_key=server_config.signing_config.get_private_key(),
key_id=server_config.signing_config.key_id,
algorithm=server_config.signing_config.algorithm,
)
card = card.model_copy(update={"signatures": [signature]})
elif server_config.signatures:
card = card.model_copy(update={"signatures": server_config.signatures})
return card
def inject_a2a_server_methods(agent: Agent) -> None:
"""Inject A2A server methods onto an Agent instance.

View File

@@ -0,0 +1,236 @@
"""AgentCard JWS signing utilities.
This module provides functions for signing and verifying AgentCards using
JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients
to verify the authenticity and integrity of agent card information.
Example:
>>> from crewai.a2a.utils.agent_card_signing import sign_agent_card
>>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1")
>>> card_with_sig = card.model_copy(update={"signatures": [signature]})
"""
from __future__ import annotations
import base64
import json
import logging
from typing import Any, Literal
from a2a.types import AgentCard, AgentCardSignature
import jwt
from pydantic import SecretStr
logger = logging.getLogger(__name__)
SigningAlgorithm = Literal[
"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512"
]
def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes:
"""Normalize private key to bytes format.
Args:
private_key: PEM-encoded private key as string, bytes, or SecretStr.
Returns:
Private key as bytes.
"""
if isinstance(private_key, SecretStr):
private_key = private_key.get_secret_value()
if isinstance(private_key, str):
private_key = private_key.encode()
return private_key
def _serialize_agent_card(agent_card: AgentCard) -> str:
"""Serialize AgentCard to canonical JSON for signing.
Excludes the signatures field to avoid circular reference during signing.
Uses sorted keys and compact separators for deterministic output.
Args:
agent_card: The AgentCard to serialize.
Returns:
Canonical JSON string representation.
"""
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))
def _base64url_encode(data: bytes | str) -> str:
"""Encode data to URL-safe base64 without padding.
Args:
data: Data to encode.
Returns:
URL-safe base64 encoded string without padding.
"""
if isinstance(data, str):
data = data.encode()
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def sign_agent_card(
agent_card: AgentCard,
private_key: str | bytes | SecretStr,
key_id: str | None = None,
algorithm: SigningAlgorithm = "RS256",
) -> AgentCardSignature:
"""Sign an AgentCard using JWS (RFC 7515).
Creates a detached JWS signature for the AgentCard. The signature covers
all fields except the signatures field itself.
Args:
agent_card: The AgentCard to sign.
private_key: PEM-encoded private key (RSA, EC, or RSA-PSS).
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
Returns:
AgentCardSignature with protected header and signature.
Raises:
jwt.exceptions.InvalidKeyError: If the private key is invalid.
ValueError: If the algorithm is not supported for the key type.
Example:
>>> signature = sign_agent_card(
... agent_card,
... private_key_pem="-----BEGIN PRIVATE KEY-----...",
... key_id="my-key-id",
... )
"""
key_bytes = _normalize_private_key(private_key)
payload = _serialize_agent_card(agent_card)
protected_header: dict[str, Any] = {"typ": "JWS"}
if key_id:
protected_header["kid"] = key_id
jws_token = jwt.api_jws.encode(
payload.encode(),
key_bytes,
algorithm=algorithm,
headers=protected_header,
)
parts = jws_token.split(".")
protected_b64 = parts[0]
signature_b64 = parts[2]
header: dict[str, Any] | None = None
if key_id:
header = {"kid": key_id}
return AgentCardSignature(
protected=protected_b64,
signature=signature_b64,
header=header,
)
def verify_agent_card_signature(
agent_card: AgentCard,
signature: AgentCardSignature,
public_key: str | bytes,
algorithms: list[str] | None = None,
) -> bool:
"""Verify an AgentCard JWS signature.
Validates that the signature was created with the corresponding private key
and that the AgentCard content has not been modified.
Args:
agent_card: The AgentCard to verify.
signature: The AgentCardSignature to validate.
public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms.
Returns:
True if signature is valid, False otherwise.
Example:
>>> is_valid = verify_agent_card_signature(
... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..."
... )
"""
if algorithms is None:
algorithms = [
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
if isinstance(public_key, str):
public_key = public_key.encode()
payload = _serialize_agent_card(agent_card)
payload_b64 = _base64url_encode(payload)
jws_token = f"{signature.protected}.{payload_b64}.{signature.signature}"
try:
jwt.api_jws.decode(
jws_token,
public_key,
algorithms=algorithms,
)
return True
except jwt.InvalidSignatureError:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "invalid_signature"},
)
return False
except jwt.DecodeError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "decode_error", "error": str(e)},
)
return False
except jwt.InvalidAlgorithmError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "algorithm_error", "error": str(e)},
)
return False
def get_key_id_from_signature(signature: AgentCardSignature) -> str | None:
"""Extract the key ID (kid) from an AgentCardSignature.
Checks both the unprotected header and the protected header for the kid claim.
Args:
signature: The AgentCardSignature to extract from.
Returns:
The key ID if present, None otherwise.
"""
if signature.header and "kid" in signature.header:
kid: str = signature.header["kid"]
return kid
try:
protected = signature.protected
padding_needed = 4 - (len(protected) % 4)
if padding_needed != 4:
protected += "=" * padding_needed
protected_json = base64.urlsafe_b64decode(protected).decode()
protected_header: dict[str, Any] = json.loads(protected_json)
return protected_header.get("kid")
except (ValueError, json.JSONDecodeError):
return None

View File

@@ -0,0 +1,339 @@
"""Content type negotiation for A2A protocol.
This module handles negotiation of input/output MIME types between A2A clients
and servers based on AgentCard capabilities.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Final, Literal, cast
from a2a.types import Part
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent
if TYPE_CHECKING:
from a2a.types import AgentCard, AgentSkill
TEXT_PLAIN: Literal["text/plain"] = "text/plain"
APPLICATION_JSON: Literal["application/json"] = "application/json"
IMAGE_PNG: Literal["image/png"] = "image/png"
IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg"
IMAGE_WILDCARD: Literal["image/*"] = "image/*"
APPLICATION_PDF: Literal["application/pdf"] = "application/pdf"
APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = (
"application/octet-stream"
)
DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
@dataclass
class NegotiatedContentTypes:
"""Result of content type negotiation."""
input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"]
output_modes: Annotated[
list[str], "Negotiated output MIME types the server will produce"
]
effective_input_modes: Annotated[list[str], "Server's effective input modes"]
effective_output_modes: Annotated[list[str], "Server's effective output modes"]
skill_name: Annotated[
str | None, "Skill name if negotiation was skill-specific"
] = None
class ContentTypeNegotiationError(Exception):
"""Raised when no compatible content types can be negotiated."""
def __init__(
self,
client_input_modes: list[str],
client_output_modes: list[str],
server_input_modes: list[str],
server_output_modes: list[str],
direction: str = "both",
message: str | None = None,
) -> None:
self.client_input_modes = client_input_modes
self.client_output_modes = client_output_modes
self.server_input_modes = server_input_modes
self.server_output_modes = server_output_modes
self.direction = direction
if message is None:
if direction == "input":
message = (
f"No compatible input content types. "
f"Client supports: {client_input_modes}, "
f"Server accepts: {server_input_modes}"
)
elif direction == "output":
message = (
f"No compatible output content types. "
f"Client accepts: {client_output_modes}, "
f"Server produces: {server_output_modes}"
)
else:
message = (
f"No compatible content types. "
f"Input - Client: {client_input_modes}, Server: {server_input_modes}. "
f"Output - Client: {client_output_modes}, Server: {server_output_modes}"
)
super().__init__(message)
def _normalize_mime_type(mime_type: str) -> str:
"""Normalize MIME type for comparison (lowercase, strip whitespace)."""
return mime_type.lower().strip()
def _mime_types_compatible(client_type: str, server_type: str) -> bool:
"""Check if two MIME types are compatible.
Handles wildcards like image/* matching image/png.
"""
client_normalized = _normalize_mime_type(client_type)
server_normalized = _normalize_mime_type(server_type)
if client_normalized == server_normalized:
return True
if "*" in client_normalized or "*" in server_normalized:
client_parts = client_normalized.split("/")
server_parts = server_normalized.split("/")
if len(client_parts) == 2 and len(server_parts) == 2:
type_match = (
client_parts[0] == server_parts[0]
or client_parts[0] == "*"
or server_parts[0] == "*"
)
subtype_match = (
client_parts[1] == server_parts[1]
or client_parts[1] == "*"
or server_parts[1] == "*"
)
return type_match and subtype_match
return False
def _find_compatible_modes(
client_modes: list[str], server_modes: list[str]
) -> list[str]:
"""Find compatible MIME types between client and server.
Returns modes in client preference order.
"""
compatible = []
for client_mode in client_modes:
for server_mode in server_modes:
if _mime_types_compatible(client_mode, server_mode):
if "*" in client_mode and "*" not in server_mode:
if server_mode not in compatible:
compatible.append(server_mode)
else:
if client_mode not in compatible:
compatible.append(client_mode)
break
return compatible
def _get_effective_modes(
agent_card: AgentCard,
skill_name: str | None = None,
) -> tuple[list[str], list[str], AgentSkill | None]:
"""Get effective input/output modes from agent card.
If skill_name is provided and the skill has custom modes, those are used.
Otherwise, falls back to agent card defaults.
"""
skill: AgentSkill | None = None
if skill_name and agent_card.skills:
for s in agent_card.skills:
if s.name == skill_name or s.id == skill_name:
skill = s
break
if skill:
input_modes = (
skill.input_modes if skill.input_modes else agent_card.default_input_modes
)
output_modes = (
skill.output_modes
if skill.output_modes
else agent_card.default_output_modes
)
else:
input_modes = agent_card.default_input_modes
output_modes = agent_card.default_output_modes
return input_modes, output_modes, skill
def negotiate_content_types(
agent_card: AgentCard,
client_input_modes: list[str] | None = None,
client_output_modes: list[str] | None = None,
skill_name: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
strict: bool = False,
) -> NegotiatedContentTypes:
"""Negotiate content types between client and server.
Args:
agent_card: The remote agent's card with capability info.
client_input_modes: MIME types the client can send. Defaults to text/plain and application/json.
client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json.
skill_name: Optional skill to use for mode lookup.
emit_event: Whether to emit a content type negotiation event.
endpoint: Agent endpoint (for event metadata).
a2a_agent_name: Agent name (for event metadata).
strict: If True, raises error when no compatible types found.
If False, returns empty lists for incompatible directions.
Returns:
NegotiatedContentTypes with compatible input and output modes.
Raises:
ContentTypeNegotiationError: If strict=True and no compatible types found.
"""
if client_input_modes is None:
client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy())
if client_output_modes is None:
client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy())
server_input_modes, server_output_modes, skill = _get_effective_modes(
agent_card, skill_name
)
compatible_input = _find_compatible_modes(client_input_modes, server_input_modes)
compatible_output = _find_compatible_modes(client_output_modes, server_output_modes)
if strict:
if not compatible_input and not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
)
if not compatible_input:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="input",
)
if not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="output",
)
result = NegotiatedContentTypes(
input_modes=compatible_input,
output_modes=compatible_output,
effective_input_modes=server_input_modes,
effective_output_modes=server_output_modes,
skill_name=skill.name if skill else None,
)
if emit_event:
crewai_event_bus.emit(
None,
A2AContentTypeNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
skill_name=skill_name,
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
negotiated_input_modes=compatible_input,
negotiated_output_modes=compatible_output,
negotiation_success=bool(compatible_input and compatible_output),
),
)
return result
def validate_content_type(
content_type: str,
allowed_modes: list[str],
) -> bool:
"""Validate that a content type is allowed by a list of modes.
Args:
content_type: The MIME type to validate.
allowed_modes: List of allowed MIME types (may include wildcards).
Returns:
True if content_type is compatible with any allowed mode.
"""
for mode in allowed_modes:
if _mime_types_compatible(content_type, mode):
return True
return False
def get_part_content_type(part: Part) -> str:
"""Extract MIME type from an A2A Part.
Args:
part: A Part object containing TextPart, DataPart, or FilePart.
Returns:
The MIME type string for this part.
"""
root = part.root
if root.kind == "text":
return TEXT_PLAIN
if root.kind == "data":
return APPLICATION_JSON
if root.kind == "file":
return root.file.mime_type or APPLICATION_OCTET_STREAM
return APPLICATION_OCTET_STREAM
def validate_message_parts(
parts: list[Part],
allowed_modes: list[str],
) -> list[str]:
"""Validate that all message parts have allowed content types.
Args:
parts: List of Parts from the incoming message.
allowed_modes: List of allowed MIME types (from default_input_modes).
Returns:
List of invalid content types found (empty if all valid).
"""
invalid_types: list[str] = []
for part in parts:
content_type = get_part_content_type(part)
if not validate_content_type(content_type, allowed_modes):
if content_type not in invalid_types:
invalid_types.append(content_type)
return invalid_types

View File

@@ -3,9 +3,10 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, MutableMapping
from collections.abc import AsyncIterator, Callable, MutableMapping
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Literal
import logging
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
from a2a.client import Client, ClientConfig, ClientFactory
@@ -20,18 +21,24 @@ from a2a.types import (
import httpx
from pydantic import BaseModel
from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
configure_auth_client,
validate_auth_against_agent_card,
)
from crewai.a2a.config import ClientTransportConfig, GRPCClientConfig
from crewai.a2a.extensions.registry import (
ExtensionsMiddleware,
validate_required_extensions,
)
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.types import (
HANDLER_REGISTRY,
HandlerType,
PartsDict,
PartsMetadataDict,
TransportType,
)
from crewai.a2a.updates import (
PollingConfig,
@@ -39,7 +46,20 @@ from crewai.a2a.updates import (
StreamingHandler,
UpdateConfig,
)
from crewai.a2a.utils.agent_card import _afetch_agent_card_cached
from crewai.a2a.utils.agent_card import (
_afetch_agent_card_cached,
_get_tls_verify,
_prepare_auth_headers,
)
from crewai.a2a.utils.content_type import (
DEFAULT_CLIENT_OUTPUT_MODES,
negotiate_content_types,
)
from crewai.a2a.utils.transport import (
NegotiatedTransport,
TransportNegotiationError,
negotiate_transport,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationStartedEvent,
@@ -49,10 +69,16 @@ from crewai.events.types.a2a_events import (
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from a2a.types import Message
from crewai.a2a.auth.schemas import AuthScheme
from crewai.a2a.auth.client_schemes import ClientAuthScheme
_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC"
def get_handler(config: UpdateConfig | None) -> HandlerType:
@@ -71,8 +97,7 @@ def get_handler(config: UpdateConfig | None) -> HandlerType:
def execute_a2a_delegation(
endpoint: str,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
auth: AuthScheme | None,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
@@ -91,32 +116,23 @@ def execute_a2a_delegation(
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent synchronously.
This is the sync wrapper around aexecute_a2a_delegation. For async contexts,
use aexecute_a2a_delegation directly.
WARNING: This function blocks the entire thread by creating and running a new
event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts
for better performance and resource efficiency.
This is a synchronous wrapper around aexecute_a2a_delegation that creates a
new event loop to run the async implementation. It is provided for compatibility
with synchronous code paths only.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL)
transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest)
timeout: Request timeout in seconds
task_description: The task to delegate
context: Optional context information
context_id: Context ID for correlating messages/tasks
task_id: Specific task identifier
reference_task_ids: List of related task IDs
metadata: Additional metadata (external_id, request_id, etc.)
extensions: Protocol extensions for custom fields
conversation_history: Previous Message objects from conversation
agent_id: Agent identifier for logging
agent_role: Role of the CrewAI agent delegating the task
agent_branch: Optional agent tree branch for logging
response_model: Optional Pydantic model for structured outputs
turn_number: Optional turn number for multi-turn conversations
endpoint: A2A agent endpoint URL.
auth: Optional AuthScheme for authentication.
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
task_description: The task to delegate.
context: Optional context information.
@@ -135,10 +151,26 @@ def execute_a2a_delegation(
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
Raises:
RuntimeError: If called from an async context with a running event loop.
"""
try:
asyncio.get_running_loop()
raise RuntimeError(
"execute_a2a_delegation() cannot be called from an async context. "
"Use 'await aexecute_a2a_delegation()' instead."
)
except RuntimeError as e:
if "no running event loop" not in str(e).lower():
raise
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
@@ -159,12 +191,14 @@ def execute_a2a_delegation(
agent_role=agent_role,
agent_branch=agent_branch,
response_model=response_model,
transport_protocol=transport_protocol,
turn_number=turn_number,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
)
)
finally:
@@ -176,8 +210,7 @@ def execute_a2a_delegation(
async def aexecute_a2a_delegation(
endpoint: str,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
auth: AuthScheme | None,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
@@ -196,6 +229,9 @@ async def aexecute_a2a_delegation(
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent asynchronously.
@@ -203,25 +239,8 @@ async def aexecute_a2a_delegation(
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
Args:
endpoint: A2A agent endpoint URL
transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
auth: Optional AuthScheme for authentication
timeout: Request timeout in seconds
task_description: Task to delegate
context: Optional context
context_id: Context ID for correlation
task_id: Specific task identifier
reference_task_ids: Related task IDs
metadata: Additional metadata
extensions: Protocol extensions
conversation_history: Previous Message objects
turn_number: Current turn number
agent_branch: Agent tree branch for logging
agent_id: Agent identifier for logging
agent_role: Agent role for logging
response_model: Optional Pydantic model for structured outputs
endpoint: A2A agent endpoint URL.
auth: Optional AuthScheme for authentication.
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
task_description: The task to delegate.
context: Optional context information.
@@ -240,6 +259,9 @@ async def aexecute_a2a_delegation(
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
@@ -271,10 +293,12 @@ async def aexecute_a2a_delegation(
agent_role=agent_role,
response_model=response_model,
updates=updates,
transport_protocol=transport_protocol,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
)
except Exception as e:
crewai_event_bus.emit(
@@ -294,7 +318,7 @@ async def aexecute_a2a_delegation(
)
raise
agent_card_data: dict[str, Any] = result.get("agent_card") or {}
agent_card_data = result.get("agent_card")
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
@@ -306,7 +330,7 @@ async def aexecute_a2a_delegation(
endpoint=endpoint,
a2a_agent_name=result.get("a2a_agent_name"),
agent_card=agent_card_data,
provider=agent_card_data.get("provider"),
provider=agent_card_data.get("provider") if agent_card_data else None,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
@@ -319,8 +343,7 @@ async def aexecute_a2a_delegation(
async def _aexecute_a2a_delegation_impl(
endpoint: str,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
auth: AuthScheme | None,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None,
@@ -340,8 +363,13 @@ async def _aexecute_a2a_delegation_impl(
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
) -> TaskStateResult:
"""Internal async implementation of A2A delegation."""
if transport is None:
transport = ClientTransportConfig()
if auth:
auth_data = auth.model_dump_json(
exclude={
@@ -351,22 +379,70 @@ async def _aexecute_a2a_delegation_impl(
"_authorization_callback",
}
)
auth_hash = hash((type(auth).__name__, auth_data))
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = 0
_auth_store[auth_hash] = auth
auth_hash = _auth_store.compute_key("none", endpoint)
_auth_store.set(auth_hash, auth)
agent_card = await _afetch_agent_card_cached(
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
)
validate_auth_against_agent_card(agent_card, auth)
headers: MutableMapping[str, str] = {}
if auth:
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
if unsupported_exts:
ext_uris = [ext.uri for ext in unsupported_exts]
raise ValueError(
f"Agent requires extensions not supported by client: {ext_uris}"
)
negotiated: NegotiatedTransport | None = None
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
effective_url = endpoint
client_transports: list[str] = (
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
)
try:
negotiated = negotiate_transport(
agent_card=agent_card,
client_supported_transports=client_transports,
client_preferred_transport=transport.preferred,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
effective_transport = negotiated.transport # type: ignore[assignment]
effective_url = negotiated.url
except TransportNegotiationError as e:
logger.warning(
"Transport negotiation failed, using fallback",
extra={
"error": str(e),
"fallback_transport": effective_transport,
"fallback_url": effective_url,
"endpoint": endpoint,
"client_transports": client_transports,
"server_transports": [
iface.transport for iface in agent_card.additional_interfaces or []
]
+ [agent_card.preferred_transport or "JSONRPC"],
},
)
effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy()
content_negotiated = negotiate_content_types(
agent_card=agent_card,
client_output_modes=accepted_output_modes,
skill_name=skill_id,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
if content_negotiated.output_modes:
effective_output_modes = content_negotiated.output_modes
headers, _ = await _prepare_auth_headers(auth, timeout)
a2a_agent_name = None
if agent_card.name:
@@ -513,15 +589,22 @@ async def _aexecute_a2a_delegation_impl(
use_streaming = not use_polling and push_config_for_client is None
client_agent_card = agent_card
if effective_url != agent_card.url:
client_agent_card = agent_card.model_copy(update={"url": effective_url})
async with _create_a2a_client(
agent_card=agent_card,
transport_protocol=transport_protocol,
agent_card=client_agent_card,
transport_protocol=effective_transport,
timeout=timeout,
headers=headers,
streaming=use_streaming,
auth=auth,
use_polling=use_polling,
push_notification_config=push_config_for_client,
client_extensions=client_extensions,
accepted_output_modes=effective_output_modes, # type: ignore[arg-type]
grpc_config=transport.grpc,
) as client:
result = await handler.execute(
client=client,
@@ -535,6 +618,245 @@ async def _aexecute_a2a_delegation_impl(
return result
def _normalize_grpc_metadata(
metadata: tuple[tuple[str, str], ...] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Lowercase all gRPC metadata keys.
gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK)
use mixed-case headers like 'X-A2A-Extensions'. This normalizes them.
"""
if metadata is None:
return None
return tuple((key.lower(), value) for key, value in metadata)
def _create_grpc_interceptors(
auth_metadata: list[tuple[str, str]] | None = None,
) -> list[Any]:
"""Create gRPC interceptors for metadata normalization and auth injection.
Args:
auth_metadata: Optional auth metadata to inject into all calls.
Used for insecure channels that need auth (non-localhost without TLS).
Returns a list of interceptors that lowercase metadata keys for gRPC
compatibility. Must be called after grpc is imported.
"""
import grpc.aio # type: ignore[import-untyped]
def _merge_metadata(
existing: tuple[tuple[str, str], ...] | None,
auth: list[tuple[str, str]] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Merge existing metadata with auth metadata and normalize keys."""
merged: list[tuple[str, str]] = []
if existing:
merged.extend(existing)
if auth:
merged.extend(auth)
if not merged:
return None
return tuple((key.lower(), value) for key, value in merged)
def _inject_metadata(client_call_details: Any) -> Any:
"""Inject merged metadata into call details."""
return client_call_details._replace(
metadata=_merge_metadata(client_call_details.metadata, auth_metadata)
)
class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-unary calls that injects auth metadata."""
async def intercept_unary_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-unary call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-stream calls that injects auth metadata."""
async def intercept_unary_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-stream call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-unary calls that injects auth metadata."""
async def intercept_stream_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-unary call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-stream calls that injects auth metadata."""
async def intercept_stream_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-stream call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
return [
MetadataUnaryUnary(),
MetadataUnaryStream(),
MetadataStreamUnary(),
MetadataStreamStream(),
]
def _create_grpc_channel_factory(
grpc_config: GRPCClientConfig,
auth: ClientAuthScheme | None = None,
) -> Callable[[str], Any]:
"""Create a gRPC channel factory with the given configuration.
Args:
grpc_config: gRPC client configuration with channel options.
auth: Optional ClientAuthScheme for TLS and auth configuration.
Returns:
A callable that creates gRPC channels from URLs.
"""
try:
import grpc
except ImportError as e:
raise ImportError(
"gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
) from e
auth_metadata: list[tuple[str, str]] = []
if auth is not None:
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
BearerTokenAuth,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
)
if isinstance(auth, HTTPDigestAuth):
raise ValueError(
"HTTPDigestAuth is not supported with gRPC transport. "
"Digest authentication requires HTTP challenge-response flow. "
"Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
)
if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
raise ValueError(
f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
"gRPC only supports header-based authentication. "
"Use APIKeyAuth with location='header' instead."
)
if isinstance(auth, BearerTokenAuth):
auth_metadata.append(("authorization", f"Bearer {auth.token}"))
elif isinstance(auth, HTTPBasicAuth):
import base64
basic_credentials = f"{auth.username}:{auth.password}"
encoded = base64.b64encode(basic_credentials.encode()).decode()
auth_metadata.append(("authorization", f"Basic {encoded}"))
elif isinstance(auth, APIKeyAuth) and auth.location == "header":
header_name = auth.name.lower()
auth_metadata.append((header_name, auth.api_key))
elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
if auth._access_token:
auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))
def factory(url: str) -> Any:
"""Create a gRPC channel for the given URL."""
target = url
use_tls = False
if url.startswith("grpcs://"):
target = url[8:]
use_tls = True
elif url.startswith("grpc://"):
target = url[7:]
elif url.startswith("https://"):
target = url[8:]
use_tls = True
elif url.startswith("http://"):
target = url[7:]
options: list[tuple[str, Any]] = []
if grpc_config.max_send_message_length is not None:
options.append(
("grpc.max_send_message_length", grpc_config.max_send_message_length)
)
if grpc_config.max_receive_message_length is not None:
options.append(
(
"grpc.max_receive_message_length",
grpc_config.max_receive_message_length,
)
)
if grpc_config.keepalive_time_ms is not None:
options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
if grpc_config.keepalive_timeout_ms is not None:
options.append(
("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
)
channel_credentials = None
if auth and hasattr(auth, "tls") and auth.tls:
channel_credentials = auth.tls.get_grpc_credentials()
elif use_tls:
channel_credentials = grpc.ssl_channel_credentials()
if channel_credentials and auth_metadata:
class AuthMetadataPlugin(grpc.AuthMetadataPlugin): # type: ignore[misc,no-any-unimported]
"""gRPC auth metadata plugin that adds auth headers as metadata."""
def __init__(self, metadata: list[tuple[str, str]]) -> None:
self._metadata = tuple(metadata)
def __call__( # type: ignore[no-any-unimported]
self,
context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
) -> None:
callback(self._metadata, None)
call_creds = grpc.metadata_call_credentials(
AuthMetadataPlugin(auth_metadata)
)
credentials = grpc.composite_channel_credentials(
channel_credentials, call_creds
)
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target, credentials, options=options or None, interceptors=interceptors
)
if channel_credentials:
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target,
channel_credentials,
options=options or None,
interceptors=interceptors,
)
interceptors = _create_grpc_interceptors(
auth_metadata=auth_metadata if auth_metadata else None
)
return grpc.aio.insecure_channel(
target, options=options or None, interceptors=interceptors
)
return factory
@asynccontextmanager
async def _create_a2a_client(
agent_card: AgentCard,
@@ -542,9 +864,12 @@ async def _create_a2a_client(
timeout: int,
headers: MutableMapping[str, str],
streaming: bool,
auth: AuthScheme | None = None,
auth: ClientAuthScheme | None = None,
use_polling: bool = False,
push_notification_config: PushNotificationConfig | None = None,
client_extensions: list[str] | None = None,
accepted_output_modes: list[str] | None = None,
grpc_config: GRPCClientConfig | None = None,
) -> AsyncIterator[Client]:
"""Create and configure an A2A client.
@@ -554,16 +879,21 @@ async def _create_a2a_client(
timeout: Request timeout in seconds.
headers: HTTP headers (already with auth applied).
streaming: Enable streaming responses.
auth: Optional AuthScheme for client configuration.
auth: Optional ClientAuthScheme for client configuration.
use_polling: Enable polling mode.
push_notification_config: Optional push notification config.
client_extensions: A2A protocol extension URIs to declare support for.
accepted_output_modes: MIME types the client can accept in responses.
grpc_config: Optional gRPC client configuration.
Yields:
Configured A2A client instance.
"""
verify = _get_tls_verify(auth)
async with httpx.AsyncClient(
timeout=timeout,
headers=headers,
verify=verify,
) as httpx_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, httpx_client)
@@ -579,15 +909,27 @@ async def _create_a2a_client(
)
)
grpc_channel_factory = None
if transport_protocol == "GRPC":
grpc_channel_factory = _create_grpc_channel_factory(
grpc_config or GRPCClientConfig(),
auth=auth,
)
config = ClientConfig(
httpx_client=httpx_client,
supported_transports=[transport_protocol],
streaming=streaming and not use_polling,
polling=use_polling,
accepted_output_modes=["application/json"],
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
push_notification_configs=push_configs,
grpc_channel_factory=grpc_channel_factory,
)
factory = ClientFactory(config)
client = factory.create(agent_card)
if client_extensions:
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
yield client

View File

@@ -0,0 +1,131 @@
"""Structured JSON logging utilities for A2A module."""
from __future__ import annotations
from contextvars import ContextVar
from datetime import datetime, timezone
import json
import logging
from typing import Any
_log_context: ContextVar[dict[str, Any] | None] = ContextVar(
"log_context", default=None
)
class JSONFormatter(logging.Formatter):
"""JSON formatter for structured logging.
Outputs logs as JSON with consistent fields for log aggregators.
"""
def format(self, record: logging.LogRecord) -> str:
"""Format log record as JSON string."""
log_data: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
if record.exc_info:
log_data["exception"] = self.formatException(record.exc_info)
context = _log_context.get()
if context is not None:
log_data.update(context)
if hasattr(record, "task_id"):
log_data["task_id"] = record.task_id
if hasattr(record, "context_id"):
log_data["context_id"] = record.context_id
if hasattr(record, "agent"):
log_data["agent"] = record.agent
if hasattr(record, "endpoint"):
log_data["endpoint"] = record.endpoint
if hasattr(record, "extension"):
log_data["extension"] = record.extension
if hasattr(record, "error"):
log_data["error"] = record.error
for key, value in record.__dict__.items():
if key.startswith("_") or key in (
"name",
"msg",
"args",
"created",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"exc_info",
"exc_text",
"thread",
"threadName",
"taskName",
"message",
):
continue
if key not in log_data:
log_data[key] = value
return json.dumps(log_data, default=str)
class LogContext:
"""Context manager for adding fields to all logs within a scope.
Example:
with LogContext(task_id="abc", context_id="xyz"):
logger.info("Processing task") # Includes task_id and context_id
"""
def __init__(self, **fields: Any) -> None:
self._fields = fields
self._token: Any = None
def __enter__(self) -> LogContext:
current = _log_context.get() or {}
new_context = {**current, **self._fields}
self._token = _log_context.set(new_context)
return self
def __exit__(self, *args: Any) -> None:
_log_context.reset(self._token)
def configure_json_logging(logger_name: str = "crewai.a2a") -> None:
"""Configure JSON logging for the A2A module.
Args:
logger_name: Logger name to configure.
"""
logger = logging.getLogger(logger_name)
for handler in logger.handlers[:]:
logger.removeHandler(handler)
handler = logging.StreamHandler()
handler.setFormatter(JSONFormatter())
logger.addHandler(handler)
def get_logger(name: str) -> logging.Logger:
"""Get a logger configured for structured JSON output.
Args:
name: Logger name.
Returns:
Configured logger instance.
"""
return logging.getLogger(name)

View File

@@ -7,26 +7,37 @@ import base64
from collections.abc import Callable, Coroutine
from datetime import datetime
from functools import wraps
import json
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast
from urllib.parse import urlparse
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import (
Artifact,
InternalError,
InvalidParamsError,
Message,
Part,
Task as A2ATask,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils import new_agent_text_message, new_text_artifact
from a2a.utils import (
get_data_parts,
new_agent_text_message,
new_data_artifact,
new_text_artifact,
)
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from pydantic import BaseModel
from crewai.a2a.utils.agent_card import _get_server_config
from crewai.a2a.utils.content_type import validate_message_parts
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AServerTaskCanceledEvent,
@@ -35,9 +46,11 @@ from crewai.events.types.a2a_events import (
A2AServerTaskStartedEvent,
)
from crewai.task import Task
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
if TYPE_CHECKING:
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
from crewai.agent import Agent
@@ -47,7 +60,17 @@ P = ParamSpec("P")
T = TypeVar("T")
def _parse_redis_url(url: str) -> dict[str, Any]:
class RedisCacheConfig(TypedDict, total=False):
"""Configuration for aiocache Redis backend."""
cache: str
endpoint: str
port: int
db: int
password: str
def _parse_redis_url(url: str) -> RedisCacheConfig:
"""Parse a Redis URL into aiocache configuration.
Args:
@@ -56,9 +79,8 @@ def _parse_redis_url(url: str) -> dict[str, Any]:
Returns:
Configuration dict for aiocache.RedisCache.
"""
parsed = urlparse(url)
config: dict[str, Any] = {
config: RedisCacheConfig = {
"cache": "aiocache.RedisCache",
"endpoint": parsed.hostname or "localhost",
"port": parsed.port or 6379,
@@ -138,7 +160,10 @@ def cancellable(
if message["type"] == "message":
return True
except (OSError, ConnectionError) as e:
logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
logger.warning(
"Cancel watcher Redis error, falling back to polling",
extra={"task_id": task_id, "error": str(e)},
)
return await poll_for_cancel()
return False
@@ -166,7 +191,67 @@ def cancellable(
return wrapper
@cancellable
def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
"""Extract response schema from message parts metadata.
The client may include a JSON schema in TextPart metadata to specify
the expected response format (see delegation.py line 463).
Args:
parts: List of message parts.
Returns:
JSON schema dict if found, None otherwise.
"""
for part in parts:
if part.root.kind == "text" and part.root.metadata:
schema = part.root.metadata.get("schema")
if schema and isinstance(schema, dict):
return schema # type: ignore[no-any-return]
return None
def _create_result_artifact(
result: Any,
task_id: str,
) -> Artifact:
"""Create artifact from task result, using DataPart for structured data.
Args:
result: The task execution result.
task_id: The task ID for naming the artifact.
Returns:
Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings).
"""
artifact_name = f"result_{task_id}"
if isinstance(result, dict):
return new_data_artifact(artifact_name, result)
if isinstance(result, BaseModel):
return new_data_artifact(artifact_name, result.model_dump())
return new_text_artifact(artifact_name, str(result))
def _build_task_description(
user_message: str,
structured_inputs: list[dict[str, Any]],
) -> str:
"""Build task description including structured data if present.
Args:
user_message: The original user message text.
structured_inputs: List of structured data from DataParts.
Returns:
Task description with structured data appended if present.
"""
if not structured_inputs:
return user_message
structured_json = json.dumps(structured_inputs, indent=2)
return f"{user_message}\n\nStructured Data:\n{structured_json}"
async def execute(
agent: Agent,
context: RequestContext,
@@ -178,15 +263,52 @@ async def execute(
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
TODOs:
* need to impl both of structured output and file inputs, depends on `file_inputs` for
`crewai.task.Task`, pass the below two to Task. both utils in `a2a.utils.parts`
* structured outputs ingestion, `structured_inputs = get_data_parts(parts=context.message.parts)`
* file inputs ingestion, `file_inputs = get_file_parts(parts=context.message.parts)`
"""
await _execute_impl(agent, context, event_queue, None, None)
@cancellable
async def _execute_impl(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry | None,
extension_context: ExtensionContext | None,
) -> None:
"""Internal implementation for task execution with optional extensions."""
server_config = _get_server_config(agent)
if context.message and context.message.parts and server_config:
allowed_modes = server_config.default_input_modes
invalid_types = validate_message_parts(context.message.parts, allowed_modes)
if invalid_types:
raise ServerError(
InvalidParamsError(
message=f"Unsupported content type(s): {', '.join(invalid_types)}. "
f"Supported: {', '.join(allowed_modes)}"
)
)
if extension_registry and extension_context:
await extension_registry.invoke_on_request(extension_context)
user_message = context.get_user_input()
response_model: type[BaseModel] | None = None
structured_inputs: list[dict[str, Any]] = []
if context.message and context.message.parts:
schema = _extract_response_schema(context.message.parts)
if schema:
try:
response_model = create_model_from_schema(schema)
except Exception as e:
logger.debug(
"Failed to create response model from schema",
extra={"error": str(e), "schema_title": schema.get("title")},
)
structured_inputs = get_data_parts(context.message.parts)
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
@@ -203,9 +325,10 @@ async def execute(
raise ServerError(InvalidParamsError(message=msg)) from None
task = Task(
description=user_message,
description=_build_task_description(user_message, structured_inputs),
expected_output="Response to the user's request",
agent=agent,
response_model=response_model,
)
crewai_event_bus.emit(
@@ -220,6 +343,10 @@ async def execute(
try:
result = await agent.aexecute_task(task=task, tools=agent.tools)
if extension_registry and extension_context:
result = await extension_registry.invoke_on_response(
extension_context, result
)
result_str = str(result)
history: list[Message] = [context.message] if context.message else []
history.append(new_agent_text_message(result_str, context_id, task_id))
@@ -227,8 +354,8 @@ async def execute(
A2ATask(
id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.input_required),
artifacts=[new_text_artifact(result_str, f"result_{task_id}")],
status=TaskStatus(state=TaskState.completed),
artifacts=[_create_result_artifact(result, task_id)],
history=history,
)
)
@@ -269,6 +396,27 @@ async def execute(
) from e
async def execute_with_extensions(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry,
extension_context: ExtensionContext,
) -> None:
"""Execute an A2A task with extension hooks.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
extension_registry: Registry of server extensions.
extension_context: Context for extension hooks.
"""
await _execute_impl(
agent, context, event_queue, extension_registry, extension_context
)
async def cancel(
context: RequestContext,
event_queue: EventQueue,

View File

@@ -0,0 +1,215 @@
"""Transport negotiation utilities for A2A protocol.
This module provides functionality for negotiating the transport protocol
between an A2A client and server based on their respective capabilities
and preferences.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Final, Literal
from a2a.types import AgentCard, AgentInterface
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"]
JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC"
GRPC_TRANSPORT: Literal["GRPC"] = "GRPC"
HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON"
DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [
JSONRPC_TRANSPORT,
GRPC_TRANSPORT,
HTTP_JSON_TRANSPORT,
]
@dataclass
class NegotiatedTransport:
"""Result of transport negotiation.
Attributes:
transport: The negotiated transport protocol.
url: The URL to use for this transport.
source: How the transport was selected ('preferred', 'additional', 'fallback').
"""
transport: str
url: str
source: NegotiationSource
class TransportNegotiationError(Exception):
"""Raised when no compatible transport can be negotiated."""
def __init__(
self,
client_transports: list[str],
server_transports: list[str],
message: str | None = None,
) -> None:
"""Initialize the error with negotiation details.
Args:
client_transports: Transports supported by the client.
server_transports: Transports supported by the server.
message: Optional custom error message.
"""
self.client_transports = client_transports
self.server_transports = server_transports
if message is None:
message = (
f"No compatible transport found. "
f"Client supports: {client_transports}. "
f"Server supports: {server_transports}."
)
super().__init__(message)
def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
"""Extract all available interfaces from an AgentCard.
Creates a unified list of interfaces including the primary URL and
any additional interfaces declared by the agent.
Args:
agent_card: The agent's card containing transport information.
Returns:
List of AgentInterface objects representing all available endpoints.
"""
interfaces: list[AgentInterface] = []
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
interfaces.append(
AgentInterface(
transport=primary_transport,
url=agent_card.url,
)
)
if agent_card.additional_interfaces:
for interface in agent_card.additional_interfaces:
is_duplicate = any(
i.url == interface.url and i.transport == interface.transport
for i in interfaces
)
if not is_duplicate:
interfaces.append(interface)
return interfaces
def negotiate_transport(
agent_card: AgentCard,
client_supported_transports: list[str] | None = None,
client_preferred_transport: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> NegotiatedTransport:
"""Negotiate the transport protocol between client and server.
Compares the client's supported transports with the server's available
interfaces to find a compatible transport and URL.
Negotiation logic:
1. If client_preferred_transport is set and server supports it → use it
2. Otherwise, if server's preferred is in client's supported → use server's
3. Otherwise, find first match from client's supported in server's interfaces
Args:
agent_card: The server's AgentCard with transport information.
client_supported_transports: Transports the client can use.
Defaults to ["JSONRPC"] if not specified.
client_preferred_transport: Client's preferred transport. If set and
server supports it, takes priority over server preference.
emit_event: Whether to emit a transport negotiation event.
endpoint: Original endpoint URL for event metadata.
a2a_agent_name: Agent name for event metadata.
Returns:
NegotiatedTransport with the selected transport, URL, and source.
Raises:
TransportNegotiationError: If no compatible transport is found.
"""
if client_supported_transports is None:
client_supported_transports = [JSONRPC_TRANSPORT]
client_transports = [t.upper() for t in client_supported_transports]
client_preferred = (
client_preferred_transport.upper() if client_preferred_transport else None
)
server_interfaces = _get_server_interfaces(agent_card)
server_transports = [i.transport.upper() for i in server_interfaces]
transport_to_interface: dict[str, AgentInterface] = {}
for interface in server_interfaces:
transport_upper = interface.transport.upper()
if transport_upper not in transport_to_interface:
transport_to_interface[transport_upper] = interface
result: NegotiatedTransport | None = None
if client_preferred and client_preferred in transport_to_interface:
interface = transport_to_interface[client_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="client_preferred",
)
else:
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
if (
server_preferred in client_transports
and server_preferred in transport_to_interface
):
interface = transport_to_interface[server_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="server_preferred",
)
else:
for transport in client_transports:
if transport in transport_to_interface:
interface = transport_to_interface[transport]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="fallback",
)
break
if result is None:
raise TransportNegotiationError(
client_transports=client_transports,
server_transports=server_transports,
)
if emit_event:
crewai_event_bus.emit(
None,
A2ATransportNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
negotiated_transport=result.transport,
negotiated_url=result.url,
source=result.source,
client_supported_transports=client_transports,
server_supported_transports=server_transports,
server_preferred_transport=agent_card.preferred_transport
or JSONRPC_TRANSPORT,
client_preferred_transport=client_preferred,
),
)
return result

View File

@@ -11,19 +11,23 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
import json
from types import MethodType
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NamedTuple
from a2a.types import Role, TaskState
from pydantic import BaseModel, ValidationError
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
from crewai.a2a.extensions.base import (
A2AExtension,
ConversationState,
ExtensionRegistry,
)
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.templates import (
AVAILABLE_AGENTS_TEMPLATE,
CONVERSATION_TURN_INFO_TEMPLATE,
PREVIOUS_A2A_CONVERSATION_TEMPLATE,
REMOTE_AGENT_COMPLETED_NOTICE,
REMOTE_AGENT_RESPONSE_NOTICE,
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE,
)
from crewai.a2a.types import AgentResponseProtocol
@@ -52,6 +56,42 @@ if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
class DelegationContext(NamedTuple):
"""Context prepared for A2A delegation.
Groups all the values needed to execute a delegation to a remote A2A agent.
"""
a2a_agents: list[A2AConfig | A2AClientConfig]
agent_response_model: type[BaseModel] | None
current_request: str
agent_id: str
agent_config: A2AConfig | A2AClientConfig
context_id: str | None
task_id: str | None
metadata: dict[str, Any] | None
extensions: dict[str, Any] | None
reference_task_ids: list[str]
original_task_description: str
max_turns: int
class DelegationState(NamedTuple):
"""Mutable state for A2A delegation loop.
Groups values that may change during delegation turns.
"""
current_request: str
context_id: str | None
task_id: str | None
reference_task_ids: list[str]
conversation_history: list[Message]
agent_card: AgentCard | None
agent_card_dict: dict[str, Any] | None
agent_name: str | None
def wrap_agent_with_a2a_instance(
agent: Agent, extension_registry: ExtensionRegistry | None = None
) -> None:
@@ -165,7 +205,11 @@ def _fetch_agent_cards_concurrently(
agent_cards: dict[str, AgentCard] = {}
failed_agents: dict[str, str] = {}
with ThreadPoolExecutor(max_workers=len(a2a_agents)) as executor:
if not a2a_agents:
return agent_cards, failed_agents
max_workers = min(len(a2a_agents), 10)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(_fetch_card_from_config, config): config
for config in a2a_agents
@@ -231,7 +275,7 @@ def _execute_task_with_a2a(
finally:
task.description = original_description
task.description, _ = _augment_prompt_with_a2a(
task.description, _, extension_states = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_description,
agent_cards=agent_cards,
@@ -248,7 +292,7 @@ def _execute_task_with_a2a(
if extension_registry and isinstance(agent_response, BaseModel):
agent_response = extension_registry.process_response_with_all(
agent_response, {}
agent_response, extension_states
)
if isinstance(agent_response, BaseModel) and isinstance(
@@ -264,7 +308,7 @@ def _execute_task_with_a2a(
tools=tools,
agent_cards=agent_cards,
original_task_description=original_description,
extension_registry=extension_registry,
_extension_registry=extension_registry,
)
return str(agent_response.message)
@@ -284,8 +328,8 @@ def _augment_prompt_with_a2a(
max_turns: int | None = None,
failed_agents: dict[str, str] | None = None,
extension_registry: ExtensionRegistry | None = None,
remote_task_completed: bool = False,
) -> tuple[str, bool]:
remote_status_notice: str = "",
) -> tuple[str, bool, dict[type[A2AExtension], ConversationState]]:
"""Add A2A delegation instructions to prompt.
Args:
@@ -297,13 +341,14 @@ def _augment_prompt_with_a2a(
max_turns: Maximum allowed turns (from config)
failed_agents: Dictionary mapping failed agent endpoints to error messages
extension_registry: Optional registry of A2A extensions
remote_status_notice: Optional notice about remote agent status to append
Returns:
Tuple of (augmented prompt, disable_structured_output flag)
Tuple of (augmented prompt, disable_structured_output flag, extension_states dict)
"""
if not agent_cards:
return task_description, False
return task_description, False, {}
agents_text = ""
@@ -365,15 +410,11 @@ def _augment_prompt_with_a2a(
warning=warning,
)
completion_notice = ""
if remote_task_completed and conversation_history:
completion_notice = REMOTE_AGENT_COMPLETED_NOTICE
augmented_prompt = f"""{task_description}
IMPORTANT: You have the ability to delegate this task to remote A2A agents.
{agents_text}
{history_text}{turn_info}{completion_notice}
{history_text}{turn_info}{remote_status_notice}
"""
@@ -382,7 +423,7 @@ IMPORTANT: You have the ability to delegate this task to remote A2A agents.
augmented_prompt, extension_states
)
return augmented_prompt, disable_structured_output
return augmented_prompt, disable_structured_output, extension_states
def _parse_agent_response(
@@ -461,12 +502,41 @@ def _handle_max_turns_exceeded(
raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})")
def _emit_delegation_failed(
error_msg: str,
turn_num: int,
from_task: Any | None,
from_agent: Any | None,
endpoint: str | None,
a2a_agent_name: str | None,
agent_card: dict[str, Any] | None,
) -> str:
"""Emit failure event and return formatted error message."""
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="failed",
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
),
)
return f"A2A delegation failed: {error_msg}"
def _process_response_result(
raw_result: str,
disable_structured_output: bool,
turn_num: int,
agent_role: str,
agent_response_model: type[BaseModel] | None,
extension_registry: ExtensionRegistry | None = None,
extension_states: dict[type[A2AExtension], ConversationState] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
@@ -516,6 +586,11 @@ def _process_response_result(
raw_result=raw_result, agent_response_model=agent_response_model
)
if extension_registry and isinstance(llm_response, BaseModel):
llm_response = extension_registry.process_response_with_all(
llm_response, extension_states or {}
)
if isinstance(llm_response, BaseModel) and isinstance(
llm_response, AgentResponseProtocol
):
@@ -571,72 +646,103 @@ def _prepare_agent_cards_dict(
return agent_cards_dict
def _init_delegation_state(
ctx: DelegationContext,
agent_cards: dict[str, AgentCard] | None,
) -> DelegationState:
"""Initialize delegation state from context and agent cards.
Args:
ctx: Delegation context with config and settings.
agent_cards: Pre-fetched agent cards.
Returns:
Initial delegation state for the conversation loop.
"""
current_agent_card = agent_cards.get(ctx.agent_id) if agent_cards else None
return DelegationState(
current_request=ctx.current_request,
context_id=ctx.context_id,
task_id=ctx.task_id,
reference_task_ids=list(ctx.reference_task_ids),
conversation_history=[],
agent_card=current_agent_card,
agent_card_dict=current_agent_card.model_dump() if current_agent_card else None,
agent_name=current_agent_card.name if current_agent_card else None,
)
def _get_turn_context(
agent_config: A2AConfig | A2AClientConfig,
) -> tuple[Any | None, list[str] | None]:
"""Get context for a delegation turn.
Returns:
Tuple of (agent_branch, accepted_output_modes).
"""
console_formatter = getattr(crewai_event_bus, "_console", None)
agent_branch = None
if console_formatter:
agent_branch = getattr(
console_formatter, "current_agent_branch", None
) or getattr(console_formatter, "current_task_branch", None)
accepted_output_modes = None
if isinstance(agent_config, A2AClientConfig):
accepted_output_modes = agent_config.accepted_output_modes
return agent_branch, accepted_output_modes
def _prepare_delegation_context(
self: Agent,
agent_response: AgentResponseProtocol,
task: Task,
original_task_description: str | None,
) -> tuple[
list[A2AConfig | A2AClientConfig],
type[BaseModel] | None,
str,
str,
A2AConfig | A2AClientConfig,
str | None,
str | None,
dict[str, Any] | None,
dict[str, Any] | None,
list[str],
str,
int,
]:
) -> DelegationContext:
"""Prepare delegation context from agent response and task.
Shared logic for both sync and async delegation.
Returns:
Tuple containing all the context values needed for delegation.
DelegationContext with all values needed for delegation.
"""
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
agent_ids = tuple(config.endpoint for config in a2a_agents)
current_request = str(agent_response.message)
if hasattr(agent_response, "a2a_ids") and agent_response.a2a_ids:
if not a2a_agents:
raise ValueError("No A2A agents configured for delegation")
if isinstance(agent_response, AgentResponseProtocol) and agent_response.a2a_ids:
agent_id = agent_response.a2a_ids[0]
else:
agent_id = agent_ids[0] if agent_ids else ""
agent_id = agent_ids[0]
if agent_id and agent_id not in agent_ids:
raise ValueError(
f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}"
)
if agent_id not in agent_ids:
raise ValueError(f"Unknown A2A agent ID: {agent_id} not in {agent_ids}")
agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents))
agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents), None)
if agent_config is None:
raise ValueError(f"Agent configuration not found for endpoint: {agent_id}")
task_config = task.config or {}
context_id = task_config.get("context_id")
task_id_config = task_config.get("task_id")
metadata = task_config.get("metadata")
extensions = task_config.get("extensions")
reference_task_ids = task_config.get("reference_task_ids", [])
if original_task_description is None:
original_task_description = task.description
max_turns = agent_config.max_turns
return (
a2a_agents,
agent_response_model,
current_request,
agent_id,
agent_config,
context_id,
task_id_config,
metadata,
extensions,
reference_task_ids,
original_task_description,
max_turns,
return DelegationContext(
a2a_agents=a2a_agents,
agent_response_model=agent_response_model,
current_request=current_request,
agent_id=agent_id,
agent_config=agent_config,
context_id=task_config.get("context_id"),
task_id=task_config.get("task_id"),
metadata=task_config.get("metadata"),
extensions=task_config.get("extensions"),
reference_task_ids=task_config.get("reference_task_ids", []),
original_task_description=original_task_description,
max_turns=agent_config.max_turns,
)
@@ -652,20 +758,50 @@ def _handle_task_completion(
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> tuple[str | None, str | None, list[str]]:
) -> tuple[str | None, str | None, list[str], str]:
"""Handle task completion state including reference task updates.
When a remote task completes, this function:
1. Adds the completed task_id to reference_task_ids (if not already present)
2. Clears task_id_config to signal that a new task ID should be generated for next turn
3. Updates task.config with the reference list for subsequent A2A calls
The reference_task_ids list tracks all completed tasks in this conversation chain,
allowing the remote agent to maintain context across multi-turn interactions.
Shared logic for both sync and async delegation.
Args:
a2a_result: Result from A2A delegation containing task status.
task: CrewAI Task object to update with reference IDs.
task_id_config: Current task ID (will be added to references if task completed).
reference_task_ids: Mutable list of completed task IDs (updated in place).
agent_config: A2A configuration with trust settings.
turn_num: Current turn number.
from_task: Optional CrewAI Task for event metadata.
from_agent: Optional CrewAI Agent for event metadata.
endpoint: A2A endpoint URL.
a2a_agent_name: Name of remote A2A agent.
agent_card: Agent card dict for event metadata.
Returns:
Tuple of (result_if_trusted, updated_task_id, updated_reference_task_ids).
Tuple of (result_if_trusted, updated_task_id, updated_reference_task_ids, remote_notice).
- result_if_trusted: Final result if trust_remote_completion_status=True, else None
- updated_task_id: None (cleared to generate new ID for next turn)
- updated_reference_task_ids: The mutated list with completed task added
- remote_notice: Template notice about remote agent response
"""
remote_notice = ""
if a2a_result["status"] == TaskState.completed:
remote_notice = REMOTE_AGENT_RESPONSE_NOTICE
if task_id_config is not None and task_id_config not in reference_task_ids:
reference_task_ids.append(task_id_config)
if task.config is None:
task.config = {}
task.config["reference_task_ids"] = reference_task_ids
task.config["reference_task_ids"] = list(reference_task_ids)
task_id_config = None
if agent_config.trust_remote_completion_status:
@@ -685,9 +821,9 @@ def _handle_task_completion(
agent_card=agent_card,
),
)
return str(result_text), task_id_config, reference_task_ids
return str(result_text), task_id_config, reference_task_ids, remote_notice
return None, task_id_config, reference_task_ids
return None, task_id_config, reference_task_ids, remote_notice
def _handle_agent_response_and_continue(
@@ -705,7 +841,8 @@ def _handle_agent_response_and_continue(
context: str | None,
tools: list[BaseTool] | None,
agent_response_model: type[BaseModel] | None,
remote_task_completed: bool = False,
extension_registry: ExtensionRegistry | None = None,
remote_status_notice: str = "",
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
@@ -735,14 +872,18 @@ def _handle_agent_response_and_continue(
"""
agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards)
task.description, disable_structured_output = _augment_prompt_with_a2a(
(
task.description,
disable_structured_output,
extension_states,
) = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
agent_cards=agent_cards_dict,
remote_task_completed=remote_task_completed,
remote_status_notice=remote_status_notice,
)
original_response_model = task.response_model
@@ -760,6 +901,8 @@ def _handle_agent_response_and_continue(
turn_num=turn_num,
agent_role=self.role,
agent_response_model=agent_response_model,
extension_registry=extension_registry,
extension_states=extension_states,
from_task=task,
from_agent=self,
endpoint=endpoint,
@@ -777,7 +920,7 @@ def _delegate_to_a2a(
tools: list[BaseTool] | None,
agent_cards: dict[str, AgentCard] | None = None,
original_task_description: str | None = None,
extension_registry: ExtensionRegistry | None = None,
_extension_registry: ExtensionRegistry | None = None,
) -> str:
"""Delegate to A2A agent with multi-turn conversation support.
@@ -790,7 +933,7 @@ def _delegate_to_a2a(
tools: Optional tools available to the agent
agent_cards: Pre-fetched agent cards from _execute_task_with_a2a
original_task_description: The original task description before A2A augmentation
extension_registry: Optional registry of A2A extensions
_extension_registry: Optional registry of A2A extensions (unused, reserved for future use)
Returns:
Result from A2A agent
@@ -798,60 +941,42 @@ def _delegate_to_a2a(
Raises:
ImportError: If a2a-sdk is not installed
"""
(
a2a_agents,
agent_response_model,
current_request,
agent_id,
agent_config,
context_id,
task_id_config,
metadata,
extensions,
reference_task_ids,
original_task_description,
max_turns,
) = _prepare_delegation_context(
ctx = _prepare_delegation_context(
self, agent_response, task, original_task_description
)
conversation_history: list[Message] = []
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
current_agent_card_dict = (
current_agent_card.model_dump() if current_agent_card else None
)
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
state = _init_delegation_state(ctx, agent_cards)
current_request = state.current_request
context_id = state.context_id
task_id = state.task_id
reference_task_ids = state.reference_task_ids
conversation_history = state.conversation_history
try:
for turn_num in range(max_turns):
console_formatter = getattr(crewai_event_bus, "_console", None)
agent_branch = None
if console_formatter:
agent_branch = getattr(
console_formatter, "current_agent_branch", None
) or getattr(console_formatter, "current_task_branch", None)
for turn_num in range(ctx.max_turns):
agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config)
a2a_result = execute_a2a_delegation(
endpoint=agent_config.endpoint,
auth=agent_config.auth,
timeout=agent_config.timeout,
endpoint=ctx.agent_config.endpoint,
auth=ctx.agent_config.auth,
timeout=ctx.agent_config.timeout,
task_description=current_request,
context_id=context_id,
task_id=task_id_config,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
metadata=ctx.metadata,
extensions=ctx.extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_role=Role.user,
agent_branch=agent_branch,
response_model=agent_config.response_model,
response_model=ctx.agent_config.response_model,
turn_number=turn_num + 1,
updates=agent_config.updates,
transport_protocol=agent_config.transport_protocol,
updates=ctx.agent_config.updates,
transport=ctx.agent_config.transport,
from_task=task,
from_agent=self,
client_extensions=getattr(ctx.agent_config, "extensions", None),
accepted_output_modes=accepted_output_modes,
)
conversation_history = a2a_result.get("history", [])
@@ -859,24 +984,24 @@ def _delegate_to_a2a(
if conversation_history:
latest_message = conversation_history[-1]
if latest_message.task_id is not None:
task_id_config = latest_message.task_id
task_id = latest_message.task_id
if latest_message.context_id is not None:
context_id = latest_message.context_id
if a2a_result["status"] in [TaskState.completed, TaskState.input_required]:
trusted_result, task_id_config, reference_task_ids = (
trusted_result, task_id, reference_task_ids, remote_notice = (
_handle_task_completion(
a2a_result,
task,
task_id_config,
task_id,
reference_task_ids,
agent_config,
ctx.agent_config,
turn_num,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
)
if trusted_result is not None:
@@ -885,22 +1010,23 @@ def _delegate_to_a2a(
final_result, next_request = _handle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
a2a_agents=ctx.a2a_agents,
original_task_description=ctx.original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
max_turns=ctx.max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=(a2a_result["status"] == TaskState.completed),
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
agent_response_model=ctx.agent_response_model,
extension_registry=_extension_registry,
remote_status_notice=remote_notice,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
if final_result is not None:
@@ -916,22 +1042,22 @@ def _delegate_to_a2a(
final_result, next_request = _handle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
a2a_agents=ctx.a2a_agents,
original_task_description=ctx.original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
max_turns=ctx.max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=False,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
agent_response_model=ctx.agent_response_model,
extension_registry=_extension_registry,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
if final_result is not None:
@@ -941,34 +1067,28 @@ def _delegate_to_a2a(
current_request = next_request
continue
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="failed",
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
),
return _emit_delegation_failed(
error_msg,
turn_num,
task,
self,
ctx.agent_config.endpoint,
state.agent_name,
state.agent_card_dict,
)
return f"A2A delegation failed: {error_msg}"
return _handle_max_turns_exceeded(
conversation_history,
max_turns,
ctx.max_turns,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
finally:
task.description = original_task_description
task.description = ctx.original_task_description
async def _afetch_card_from_config(
@@ -993,6 +1113,9 @@ async def _afetch_agent_cards_concurrently(
agent_cards: dict[str, AgentCard] = {}
failed_agents: dict[str, str] = {}
if not a2a_agents:
return agent_cards, failed_agents
tasks = [_afetch_card_from_config(config) for config in a2a_agents]
results = await asyncio.gather(*tasks)
@@ -1042,7 +1165,7 @@ async def _aexecute_task_with_a2a(
finally:
task.description = original_description
task.description, _ = _augment_prompt_with_a2a(
task.description, _, extension_states = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_description,
agent_cards=agent_cards,
@@ -1059,7 +1182,7 @@ async def _aexecute_task_with_a2a(
if extension_registry and isinstance(agent_response, BaseModel):
agent_response = extension_registry.process_response_with_all(
agent_response, {}
agent_response, extension_states
)
if isinstance(agent_response, BaseModel) and isinstance(
@@ -1075,7 +1198,7 @@ async def _aexecute_task_with_a2a(
tools=tools,
agent_cards=agent_cards,
original_task_description=original_description,
extension_registry=extension_registry,
_extension_registry=extension_registry,
)
return str(agent_response.message)
@@ -1101,7 +1224,8 @@ async def _ahandle_agent_response_and_continue(
context: str | None,
tools: list[BaseTool] | None,
agent_response_model: type[BaseModel] | None,
remote_task_completed: bool = False,
extension_registry: ExtensionRegistry | None = None,
remote_status_notice: str = "",
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
@@ -1109,14 +1233,18 @@ async def _ahandle_agent_response_and_continue(
"""Async version of _handle_agent_response_and_continue."""
agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards)
task.description, disable_structured_output = _augment_prompt_with_a2a(
(
task.description,
disable_structured_output,
extension_states,
) = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
agent_cards=agent_cards_dict,
remote_task_completed=remote_task_completed,
remote_status_notice=remote_status_notice,
)
original_response_model = task.response_model
@@ -1134,6 +1262,8 @@ async def _ahandle_agent_response_and_continue(
turn_num=turn_num,
agent_role=self.role,
agent_response_model=agent_response_model,
extension_registry=extension_registry,
extension_states=extension_states,
from_task=task,
from_agent=self,
endpoint=endpoint,
@@ -1151,63 +1281,45 @@ async def _adelegate_to_a2a(
tools: list[BaseTool] | None,
agent_cards: dict[str, AgentCard] | None = None,
original_task_description: str | None = None,
extension_registry: ExtensionRegistry | None = None,
_extension_registry: ExtensionRegistry | None = None,
) -> str:
"""Async version of _delegate_to_a2a."""
(
a2a_agents,
agent_response_model,
current_request,
agent_id,
agent_config,
context_id,
task_id_config,
metadata,
extensions,
reference_task_ids,
original_task_description,
max_turns,
) = _prepare_delegation_context(
ctx = _prepare_delegation_context(
self, agent_response, task, original_task_description
)
conversation_history: list[Message] = []
current_agent_card = agent_cards.get(agent_id) if agent_cards else None
current_agent_card_dict = (
current_agent_card.model_dump() if current_agent_card else None
)
current_a2a_agent_name = current_agent_card.name if current_agent_card else None
state = _init_delegation_state(ctx, agent_cards)
current_request = state.current_request
context_id = state.context_id
task_id = state.task_id
reference_task_ids = state.reference_task_ids
conversation_history = state.conversation_history
try:
for turn_num in range(max_turns):
console_formatter = getattr(crewai_event_bus, "_console", None)
agent_branch = None
if console_formatter:
agent_branch = getattr(
console_formatter, "current_agent_branch", None
) or getattr(console_formatter, "current_task_branch", None)
for turn_num in range(ctx.max_turns):
agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config)
a2a_result = await aexecute_a2a_delegation(
endpoint=agent_config.endpoint,
auth=agent_config.auth,
timeout=agent_config.timeout,
endpoint=ctx.agent_config.endpoint,
auth=ctx.agent_config.auth,
timeout=ctx.agent_config.timeout,
task_description=current_request,
context_id=context_id,
task_id=task_id_config,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
metadata=ctx.metadata,
extensions=ctx.extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_role=Role.user,
agent_branch=agent_branch,
response_model=agent_config.response_model,
response_model=ctx.agent_config.response_model,
turn_number=turn_num + 1,
transport_protocol=agent_config.transport_protocol,
updates=agent_config.updates,
transport=ctx.agent_config.transport,
updates=ctx.agent_config.updates,
from_task=task,
from_agent=self,
client_extensions=getattr(ctx.agent_config, "extensions", None),
accepted_output_modes=accepted_output_modes,
)
conversation_history = a2a_result.get("history", [])
@@ -1215,24 +1327,24 @@ async def _adelegate_to_a2a(
if conversation_history:
latest_message = conversation_history[-1]
if latest_message.task_id is not None:
task_id_config = latest_message.task_id
task_id = latest_message.task_id
if latest_message.context_id is not None:
context_id = latest_message.context_id
if a2a_result["status"] in [TaskState.completed, TaskState.input_required]:
trusted_result, task_id_config, reference_task_ids = (
trusted_result, task_id, reference_task_ids, remote_notice = (
_handle_task_completion(
a2a_result,
task,
task_id_config,
task_id,
reference_task_ids,
agent_config,
ctx.agent_config,
turn_num,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
)
if trusted_result is not None:
@@ -1241,22 +1353,23 @@ async def _adelegate_to_a2a(
final_result, next_request = await _ahandle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
a2a_agents=ctx.a2a_agents,
original_task_description=ctx.original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
max_turns=ctx.max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=(a2a_result["status"] == TaskState.completed),
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
agent_response_model=ctx.agent_response_model,
extension_registry=_extension_registry,
remote_status_notice=remote_notice,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
if final_result is not None:
@@ -1272,21 +1385,22 @@ async def _adelegate_to_a2a(
final_result, next_request = await _ahandle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_id=ctx.agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
a2a_agents=ctx.a2a_agents,
original_task_description=ctx.original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
max_turns=ctx.max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
agent_response_model=ctx.agent_response_model,
extension_registry=_extension_registry,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
if final_result is not None:
@@ -1296,31 +1410,25 @@ async def _adelegate_to_a2a(
current_request = next_request
continue
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="failed",
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
),
return _emit_delegation_failed(
error_msg,
turn_num,
task,
self,
ctx.agent_config.endpoint,
state.agent_name,
state.agent_card_dict,
)
return f"A2A delegation failed: {error_msg}"
return _handle_max_turns_exceeded(
conversation_history,
max_turns,
ctx.max_turns,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
endpoint=ctx.agent_config.endpoint,
a2a_agent_name=state.agent_name,
agent_card=state.agent_card_dict,
)
finally:
task.description = original_task_description
task.description = ctx.original_task_description

View File

@@ -654,3 +654,165 @@ class A2AParallelDelegationCompletedEvent(A2AEventBase):
success_count: int
failure_count: int
results: dict[str, str] | None = None
class A2ATransportNegotiatedEvent(A2AEventBase):
"""Event emitted when transport protocol is negotiated with an A2A agent.
This event is emitted after comparing client and server transport capabilities
to select the optimal transport protocol and endpoint URL.
Attributes:
endpoint: Original A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
negotiated_transport: The transport protocol selected (JSONRPC, GRPC, HTTP+JSON).
negotiated_url: The URL to use for the selected transport.
source: How the transport was selected ('client_preferred', 'server_preferred', 'fallback').
client_supported_transports: Transports the client can use.
server_supported_transports: Transports the server supports.
server_preferred_transport: The server's preferred transport from AgentCard.
client_preferred_transport: The client's preferred transport if set.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_transport_negotiated"
endpoint: str
a2a_agent_name: str | None = None
negotiated_transport: str
negotiated_url: str
source: str
client_supported_transports: list[str]
server_supported_transports: list[str]
server_preferred_transport: str
client_preferred_transport: str | None = None
metadata: dict[str, Any] | None = None
class A2AContentTypeNegotiatedEvent(A2AEventBase):
"""Event emitted when content types are negotiated with an A2A agent.
This event is emitted after comparing client and server input/output mode
capabilities to determine compatible MIME types for communication.
Attributes:
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
skill_name: Skill name if negotiation was skill-specific.
client_input_modes: MIME types the client can send.
client_output_modes: MIME types the client can accept.
server_input_modes: MIME types the server accepts.
server_output_modes: MIME types the server produces.
negotiated_input_modes: Compatible input MIME types selected.
negotiated_output_modes: Compatible output MIME types selected.
negotiation_success: Whether compatible types were found for both directions.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_content_type_negotiated"
endpoint: str
a2a_agent_name: str | None = None
skill_name: str | None = None
client_input_modes: list[str]
client_output_modes: list[str]
server_input_modes: list[str]
server_output_modes: list[str]
negotiated_input_modes: list[str]
negotiated_output_modes: list[str]
negotiation_success: bool = True
metadata: dict[str, Any] | None = None
# -----------------------------------------------------------------------------
# Context Lifecycle Events
# -----------------------------------------------------------------------------
class A2AContextCreatedEvent(A2AEventBase):
"""Event emitted when an A2A context is created.
Contexts group related tasks in a conversation or workflow.
Attributes:
context_id: Unique identifier for the context.
created_at: Unix timestamp when context was created.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_created"
context_id: str
created_at: float
metadata: dict[str, Any] | None = None
class A2AContextExpiredEvent(A2AEventBase):
"""Event emitted when an A2A context expires due to TTL.
Attributes:
context_id: The expired context identifier.
created_at: Unix timestamp when context was created.
age_seconds: How long the context existed before expiring.
task_count: Number of tasks in the context when expired.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_expired"
context_id: str
created_at: float
age_seconds: float
task_count: int
metadata: dict[str, Any] | None = None
class A2AContextIdleEvent(A2AEventBase):
"""Event emitted when an A2A context becomes idle.
Idle contexts have had no activity for the configured threshold.
Attributes:
context_id: The idle context identifier.
idle_seconds: Seconds since last activity.
task_count: Number of tasks in the context.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_idle"
context_id: str
idle_seconds: float
task_count: int
metadata: dict[str, Any] | None = None
class A2AContextCompletedEvent(A2AEventBase):
"""Event emitted when all tasks in an A2A context complete.
Attributes:
context_id: The completed context identifier.
total_tasks: Total number of tasks that were in the context.
duration_seconds: Total context lifetime in seconds.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_completed"
context_id: str
total_tasks: int
duration_seconds: float
metadata: dict[str, Any] | None = None
class A2AContextPrunedEvent(A2AEventBase):
"""Event emitted when an A2A context is pruned (deleted).
Pruning removes the context metadata and optionally associated tasks.
Attributes:
context_id: The pruned context identifier.
task_count: Number of tasks that were in the context.
age_seconds: How long the context existed before pruning.
metadata: Custom A2A metadata key-value pairs.
"""
type: str = "a2a_context_pruned"
context_id: str
task_count: int
age_seconds: float
metadata: dict[str, Any] | None = None

View File

@@ -104,6 +104,7 @@ class TestA2AStreamingIntegration:
message=test_message,
new_messages=new_messages,
agent_card=agent_card,
endpoint=agent_card.url,
)
assert isinstance(result, dict)
@@ -225,6 +226,7 @@ class TestA2APushNotificationHandler:
result_store=mock_store,
polling_timeout=30.0,
polling_interval=1.0,
endpoint=mock_agent_card.url,
)
mock_store.wait_for_result.assert_called_once_with(
@@ -287,6 +289,7 @@ class TestA2APushNotificationHandler:
result_store=mock_store,
polling_timeout=5.0,
polling_interval=0.5,
endpoint=mock_agent_card.url,
)
assert result["status"] == TaskState.failed
@@ -317,6 +320,7 @@ class TestA2APushNotificationHandler:
message=test_msg,
new_messages=new_messages,
agent_card=mock_agent_card,
endpoint=mock_agent_card.url,
)
assert result["status"] == TaskState.failed

View File

@@ -43,6 +43,7 @@ def mock_context() -> MagicMock:
context.context_id = "test-context-456"
context.get_user_input.return_value = "Test user message"
context.message = MagicMock(spec=Message)
context.message.parts = []
context.current_task = None
return context