From 4543c666970be9951dcd180a4c853ffe4cb064e5 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 29 Jan 2026 05:13:42 -0500 Subject: [PATCH] feat: add transport negotiation and content type handling - add transport negotiation logic with fallback support - add content type parser and encoder utilities - add transport configuration models (client and server) - add transport types and enums - enhance config with transport settings - add negotiation events for transport and content type --- .../auth/{schemas.py => client_schemes.py} | 188 +++++- lib/crewai/src/crewai/a2a/config.py | 508 +++++++++++++-- lib/crewai/src/crewai/a2a/errors.py | 486 ++++++++++++++- lib/crewai/src/crewai/a2a/task_helpers.py | 10 + lib/crewai/src/crewai/a2a/templates.py | 15 + lib/crewai/src/crewai/a2a/types.py | 11 + lib/crewai/src/crewai/a2a/updates/base.py | 50 +- .../src/crewai/a2a/updates/polling/handler.py | 4 +- .../a2a/updates/push_notifications/config.py | 29 +- .../a2a/updates/push_notifications/handler.py | 258 ++++---- .../updates/push_notifications/signature.py | 87 +++ .../crewai/a2a/updates/streaming/handler.py | 500 ++++++++++++--- .../crewai/a2a/updates/streaming/params.py | 28 + lib/crewai/src/crewai/a2a/utils/agent_card.py | 125 +++- .../crewai/a2a/utils/agent_card_signing.py | 236 +++++++ .../src/crewai/a2a/utils/content_type.py | 339 ++++++++++ lib/crewai/src/crewai/a2a/utils/delegation.py | 478 ++++++++++++-- lib/crewai/src/crewai/a2a/utils/logging.py | 131 ++++ lib/crewai/src/crewai/a2a/utils/task.py | 180 +++++- lib/crewai/src/crewai/a2a/utils/transport.py | 215 +++++++ lib/crewai/src/crewai/a2a/wrapper.py | 590 +++++++++++------- .../src/crewai/events/types/a2a_events.py | 162 +++++ lib/crewai/tests/a2a/test_a2a_integration.py | 4 + lib/crewai/tests/a2a/utils/test_task.py | 1 + 24 files changed, 3994 insertions(+), 641 deletions(-) rename lib/crewai/src/crewai/a2a/auth/{schemas.py => client_schemes.py} (65%) create mode 100644 
lib/crewai/src/crewai/a2a/updates/push_notifications/signature.py create mode 100644 lib/crewai/src/crewai/a2a/updates/streaming/params.py create mode 100644 lib/crewai/src/crewai/a2a/utils/agent_card_signing.py create mode 100644 lib/crewai/src/crewai/a2a/utils/content_type.py create mode 100644 lib/crewai/src/crewai/a2a/utils/logging.py create mode 100644 lib/crewai/src/crewai/a2a/utils/transport.py diff --git a/lib/crewai/src/crewai/a2a/auth/schemas.py b/lib/crewai/src/crewai/a2a/auth/client_schemes.py similarity index 65% rename from lib/crewai/src/crewai/a2a/auth/schemas.py rename to lib/crewai/src/crewai/a2a/auth/client_schemes.py index af9344279..0356b8aef 100644 --- a/lib/crewai/src/crewai/a2a/auth/schemas.py +++ b/lib/crewai/src/crewai/a2a/auth/client_schemes.py @@ -1,4 +1,4 @@ -"""Authentication schemes for A2A protocol agents. +"""Authentication schemes for A2A protocol clients. Supported authentication methods: - Bearer tokens @@ -6,24 +6,135 @@ Supported authentication methods: - API Keys (header, query, cookie) - HTTP Basic authentication - HTTP Digest authentication +- mTLS (mutual TLS) client certificate authentication """ from __future__ import annotations from abc import ABC, abstractmethod +import asyncio import base64 from collections.abc import Awaitable, Callable, MutableMapping +from pathlib import Path +import ssl import time -from typing import Literal +from typing import TYPE_CHECKING, ClassVar, Literal import urllib.parse import httpx from httpx import DigestAuth -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr +from typing_extensions import deprecated -class AuthScheme(ABC, BaseModel): - """Base class for authentication schemes.""" +if TYPE_CHECKING: + import grpc # type: ignore[import-untyped] + + +class TLSConfig(BaseModel): + """TLS/mTLS configuration for secure client connections. 
+ + Supports mutual TLS (mTLS) where the client presents a certificate to the server, + and standard TLS with custom CA verification. + + Attributes: + client_cert_path: Path to client certificate file (PEM format) for mTLS. + client_key_path: Path to client private key file (PEM format) for mTLS. + ca_cert_path: Path to CA certificate bundle for server verification. + verify: Whether to verify server certificates. Set False only for development. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + client_cert_path: FilePath | None = Field( + default=None, + description="Path to client certificate file (PEM format) for mTLS", + ) + client_key_path: FilePath | None = Field( + default=None, + description="Path to client private key file (PEM format) for mTLS", + ) + ca_cert_path: FilePath | None = Field( + default=None, + description="Path to CA certificate bundle for server verification", + ) + verify: bool = Field( + default=True, + description="Whether to verify server certificates. Set False only for development.", + ) + + def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str: + """Build SSL context for httpx client. + + Returns: + SSL context if certificates configured, True for default verification, + False if verification disabled, or path to CA bundle. + """ + if not self.verify: + return False + + if self.client_cert_path and self.client_key_path: + context = ssl.create_default_context() + + if self.ca_cert_path: + context.load_verify_locations(cafile=str(self.ca_cert_path)) + + context.load_cert_chain( + certfile=str(self.client_cert_path), + keyfile=str(self.client_key_path), + ) + return context + + if self.ca_cert_path: + return str(self.ca_cert_path) + + return True + + def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported] + """Build gRPC channel credentials for secure connections. + + Returns: + gRPC SSL credentials if certificates configured, None otherwise. 
+ """ + try: + import grpc + except ImportError: + return None + + if not self.verify and not self.client_cert_path: + return None + + root_certs: bytes | None = None + private_key: bytes | None = None + certificate_chain: bytes | None = None + + if self.ca_cert_path: + root_certs = Path(self.ca_cert_path).read_bytes() + + if self.client_cert_path and self.client_key_path: + private_key = Path(self.client_key_path).read_bytes() + certificate_chain = Path(self.client_cert_path).read_bytes() + + return grpc.ssl_channel_credentials( + root_certificates=root_certs, + private_key=private_key, + certificate_chain=certificate_chain, + ) + + +class ClientAuthScheme(ABC, BaseModel): + """Base class for client-side authentication schemes. + + Client auth schemes apply credentials to outgoing requests. + + Attributes: + tls: Optional TLS/mTLS configuration for secure connections. + """ + + tls: TLSConfig | None = Field( + default=None, + description="TLS/mTLS configuration for secure connections", + ) @abstractmethod async def apply_auth( @@ -41,7 +152,12 @@ class AuthScheme(ABC, BaseModel): ... -class BearerTokenAuth(AuthScheme): +@deprecated("Use ClientAuthScheme instead", category=FutureWarning) +class AuthScheme(ClientAuthScheme): + """Deprecated: Use ClientAuthScheme instead.""" + + +class BearerTokenAuth(ClientAuthScheme): """Bearer token authentication (Authorization: Bearer ). Attributes: @@ -66,7 +182,7 @@ class BearerTokenAuth(AuthScheme): return headers -class HTTPBasicAuth(AuthScheme): +class HTTPBasicAuth(ClientAuthScheme): """HTTP Basic authentication. Attributes: @@ -95,7 +211,7 @@ class HTTPBasicAuth(AuthScheme): return headers -class HTTPDigestAuth(AuthScheme): +class HTTPDigestAuth(ClientAuthScheme): """HTTP Digest authentication. Note: Uses httpx-auth library for digest implementation. 
@@ -108,6 +224,8 @@ class HTTPDigestAuth(AuthScheme): username: str = Field(description="Username") password: str = Field(description="Password") + _configured_client_id: int | None = PrivateAttr(default=None) + async def apply_auth( self, client: httpx.AsyncClient, headers: MutableMapping[str, str] ) -> MutableMapping[str, str]: @@ -125,13 +243,21 @@ class HTTPDigestAuth(AuthScheme): def configure_client(self, client: httpx.AsyncClient) -> None: """Configure client with Digest auth. + Idempotent: Only configures the client once. Subsequent calls on the same + client instance are no-ops to prevent overwriting auth configuration. + Args: client: HTTP client to configure with Digest authentication. """ + client_id = id(client) + if self._configured_client_id == client_id: + return + client.auth = DigestAuth(self.username, self.password) + self._configured_client_id = client_id -class APIKeyAuth(AuthScheme): +class APIKeyAuth(ClientAuthScheme): """API Key authentication (header, query, or cookie). Attributes: @@ -146,6 +272,8 @@ class APIKeyAuth(AuthScheme): ) name: str = Field(default="X-API-Key", description="Parameter name for the API key") + _configured_client_ids: set[int] = PrivateAttr(default_factory=set) + async def apply_auth( self, client: httpx.AsyncClient, headers: MutableMapping[str, str] ) -> MutableMapping[str, str]: @@ -167,21 +295,31 @@ class APIKeyAuth(AuthScheme): def configure_client(self, client: httpx.AsyncClient) -> None: """Configure client for query param API keys. + Idempotent: Only adds the request hook once per client instance. + Subsequent calls on the same client are no-ops to prevent hook accumulation. + Args: client: HTTP client to configure with query param API key hook. 
""" if self.location == "query": + client_id = id(client) + if client_id in self._configured_client_ids: + return async def _add_api_key_param(request: httpx.Request) -> None: url = httpx.URL(request.url) request.url = url.copy_add_param(self.name, self.api_key) client.event_hooks["request"].append(_add_api_key_param) + self._configured_client_ids.add(client_id) -class OAuth2ClientCredentials(AuthScheme): +class OAuth2ClientCredentials(ClientAuthScheme): """OAuth2 Client Credentials flow authentication. + Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches + when multiple requests share the same auth instance. + Attributes: token_url: OAuth2 token endpoint URL. client_id: OAuth2 client identifier. @@ -198,12 +336,17 @@ class OAuth2ClientCredentials(AuthScheme): _access_token: str | None = PrivateAttr(default=None) _token_expires_at: float | None = PrivateAttr(default=None) + _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) async def apply_auth( self, client: httpx.AsyncClient, headers: MutableMapping[str, str] ) -> MutableMapping[str, str]: """Apply OAuth2 access token to Authorization header. + Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time, + preventing race conditions when multiple concurrent requests use the same + auth instance. + Args: client: HTTP client for making token requests. headers: Current request headers. 
@@ -216,7 +359,13 @@ class OAuth2ClientCredentials(AuthScheme): or self._token_expires_at is None or time.time() >= self._token_expires_at ): - await self._fetch_token(client) + async with self._lock: + if ( + self._access_token is None + or self._token_expires_at is None + or time.time() >= self._token_expires_at + ): + await self._fetch_token(client) if self._access_token: headers["Authorization"] = f"Bearer {self._access_token}" @@ -250,9 +399,11 @@ class OAuth2ClientCredentials(AuthScheme): self._token_expires_at = time.time() + expires_in - 60 -class OAuth2AuthorizationCode(AuthScheme): +class OAuth2AuthorizationCode(ClientAuthScheme): """OAuth2 Authorization Code flow authentication. + Thread-safe implementation with asyncio.Lock to prevent concurrent token operations. + Note: Requires interactive authorization. Attributes: @@ -279,6 +430,7 @@ class OAuth2AuthorizationCode(AuthScheme): _authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr( default=None ) + _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) def set_authorization_callback( self, callback: Callable[[str], Awaitable[str]] | None @@ -295,6 +447,9 @@ class OAuth2AuthorizationCode(AuthScheme): ) -> MutableMapping[str, str]: """Apply OAuth2 access token to Authorization header. + Uses asyncio.Lock to ensure only one coroutine handles token operations + (initial fetch or refresh) at a time. + Args: client: HTTP client for making token requests. headers: Current request headers. @@ -305,14 +460,17 @@ class OAuth2AuthorizationCode(AuthScheme): Raises: ValueError: If authorization callback is not set. """ - if self._access_token is None: if self._authorization_callback is None: msg = "Authorization callback not set. 
Use set_authorization_callback()" raise ValueError(msg) - await self._fetch_initial_token(client) + async with self._lock: + if self._access_token is None: + await self._fetch_initial_token(client) elif self._token_expires_at and time.time() >= self._token_expires_at: - await self._refresh_access_token(client) + async with self._lock: + if self._token_expires_at and time.time() >= self._token_expires_at: + await self._refresh_access_token(client) if self._access_token: headers["Authorization"] = f"Bearer {self._access_token}" diff --git a/lib/crewai/src/crewai/a2a/config.py b/lib/crewai/src/crewai/a2a/config.py index 1b8cd7d81..1b9d63db4 100644 --- a/lib/crewai/src/crewai/a2a/config.py +++ b/lib/crewai/src/crewai/a2a/config.py @@ -5,14 +5,25 @@ This module is separate from experimental.a2a to avoid circular imports. from __future__ import annotations -from importlib.metadata import version -from typing import Any, ClassVar, Literal +from pathlib import Path +from typing import Any, ClassVar, Literal, cast +import warnings -from pydantic import BaseModel, ConfigDict, Field -from typing_extensions import deprecated +from pydantic import ( + BaseModel, + ConfigDict, + Field, + FilePath, + PrivateAttr, + SecretStr, + model_validator, +) +from typing_extensions import Self, deprecated -from crewai.a2a.auth.schemas import AuthScheme -from crewai.a2a.types import TransportType, Url +from crewai.a2a.auth.client_schemes import ClientAuthScheme +from crewai.a2a.auth.server_schemes import ServerAuthScheme +from crewai.a2a.extensions.base import ValidatedA2AExtension +from crewai.a2a.types import ProtocolVersion, TransportType, Url try: @@ -25,16 +36,17 @@ try: SecurityScheme, ) + from crewai.a2a.extensions.server import ServerExtension from crewai.a2a.updates import UpdateConfig except ImportError: - UpdateConfig = Any - AgentCapabilities = Any - AgentCardSignature = Any - AgentInterface = Any - AgentProvider = Any - SecurityScheme = Any - AgentSkill = Any - UpdateConfig = 
Any # type: ignore[misc,assignment] + UpdateConfig: Any = Any # type: ignore[no-redef] + AgentCapabilities: Any = Any # type: ignore[no-redef] + AgentCardSignature: Any = Any # type: ignore[no-redef] + AgentInterface: Any = Any # type: ignore[no-redef] + AgentProvider: Any = Any # type: ignore[no-redef] + SecurityScheme: Any = Any # type: ignore[no-redef] + AgentSkill: Any = Any # type: ignore[no-redef] + ServerExtension: Any = Any # type: ignore[no-redef] def _get_default_update_config() -> UpdateConfig: @@ -43,6 +55,309 @@ def _get_default_update_config() -> UpdateConfig: return StreamingConfig() +SigningAlgorithm = Literal[ + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", +] + + +class AgentCardSigningConfig(BaseModel): + """Configuration for AgentCard JWS signing. + + Provides the private key and algorithm settings for signing AgentCards. + Either private_key_path or private_key_pem must be provided, but not both. + + Attributes: + private_key_path: Path to a PEM-encoded private key file. + private_key_pem: PEM-encoded private key as a secret string. + key_id: Optional key identifier for the JWS header (kid claim). + algorithm: Signing algorithm (RS256, ES256, PS256, etc.). 
+ """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + private_key_path: FilePath | None = Field( + default=None, + description="Path to PEM-encoded private key file", + ) + private_key_pem: SecretStr | None = Field( + default=None, + description="PEM-encoded private key", + ) + key_id: str | None = Field( + default=None, + description="Key identifier for JWS header (kid claim)", + ) + algorithm: SigningAlgorithm = Field( + default="RS256", + description="Signing algorithm (RS256, ES256, PS256, etc.)", + ) + + @model_validator(mode="after") + def _validate_key_source(self) -> Self: + """Ensure exactly one key source is provided.""" + has_path = self.private_key_path is not None + has_pem = self.private_key_pem is not None + + if not has_path and not has_pem: + raise ValueError( + "Either private_key_path or private_key_pem must be provided" + ) + if has_path and has_pem: + raise ValueError( + "Only one of private_key_path or private_key_pem should be provided" + ) + return self + + def get_private_key(self) -> str: + """Get the private key content. + + Returns: + The PEM-encoded private key as a string. + """ + if self.private_key_pem: + return self.private_key_pem.get_secret_value() + if self.private_key_path: + return Path(self.private_key_path).read_text() + raise ValueError("No private key configured") + + +class GRPCServerConfig(BaseModel): + """gRPC server transport configuration. + + Presence of this config in ServerTransportConfig.grpc enables gRPC transport. + + Attributes: + host: Hostname to advertise in agent cards (default: localhost). + Use docker service name (e.g., 'web') for docker-compose setups. + port: Port for the gRPC server. + tls_cert_path: Path to TLS certificate file for gRPC. + tls_key_path: Path to TLS private key file for gRPC. + max_workers: Maximum number of workers for the gRPC thread pool. + reflection_enabled: Whether to enable gRPC reflection for debugging. 
+ """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + host: str = Field( + default="localhost", + description="Hostname to advertise in agent cards for gRPC connections", + ) + port: int = Field( + default=50051, + description="Port for the gRPC server", + ) + tls_cert_path: str | None = Field( + default=None, + description="Path to TLS certificate file for gRPC", + ) + tls_key_path: str | None = Field( + default=None, + description="Path to TLS private key file for gRPC", + ) + max_workers: int = Field( + default=10, + description="Maximum number of workers for the gRPC thread pool", + ) + reflection_enabled: bool = Field( + default=False, + description="Whether to enable gRPC reflection for debugging", + ) + + +class GRPCClientConfig(BaseModel): + """gRPC client transport configuration. + + Attributes: + max_send_message_length: Maximum size for outgoing messages in bytes. + max_receive_message_length: Maximum size for incoming messages in bytes. + keepalive_time_ms: Time between keepalive pings in milliseconds. + keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + max_send_message_length: int | None = Field( + default=None, + description="Maximum size for outgoing messages in bytes", + ) + max_receive_message_length: int | None = Field( + default=None, + description="Maximum size for incoming messages in bytes", + ) + keepalive_time_ms: int | None = Field( + default=None, + description="Time between keepalive pings in milliseconds", + ) + keepalive_timeout_ms: int | None = Field( + default=None, + description="Timeout for keepalive ping response in milliseconds", + ) + + +class JSONRPCServerConfig(BaseModel): + """JSON-RPC server transport configuration. + + Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport. + + Attributes: + rpc_path: URL path for the JSON-RPC endpoint. 
+ agent_card_path: URL path for the agent card endpoint. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + rpc_path: str = Field( + default="/a2a", + description="URL path for the JSON-RPC endpoint", + ) + agent_card_path: str = Field( + default="/.well-known/agent-card.json", + description="URL path for the agent card endpoint", + ) + + +class JSONRPCClientConfig(BaseModel): + """JSON-RPC client transport configuration. + + Attributes: + max_request_size: Maximum request body size in bytes. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + max_request_size: int | None = Field( + default=None, + description="Maximum request body size in bytes", + ) + + +class HTTPJSONConfig(BaseModel): + """HTTP+JSON transport configuration. + + Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class ServerPushNotificationConfig(BaseModel): + """Configuration for outgoing webhook push notifications. + + Controls how the server signs and delivers push notifications to clients. + + Attributes: + signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + signature_secret: SecretStr | None = Field( + default=None, + description="Shared secret for HMAC-SHA256 signing of outgoing push notifications", + ) + + +class ServerTransportConfig(BaseModel): + """Transport configuration for A2A server. + + Groups all transport-related settings including preferred transport + and protocol-specific configurations. + + Attributes: + preferred: Transport protocol for the preferred endpoint. + jsonrpc: JSON-RPC server transport configuration. + grpc: gRPC server transport configuration. + http_json: HTTP+JSON transport configuration. 
+ """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + preferred: TransportType = Field( + default="JSONRPC", + description="Transport protocol for the preferred endpoint", + ) + jsonrpc: JSONRPCServerConfig = Field( + default_factory=JSONRPCServerConfig, + description="JSON-RPC server transport configuration", + ) + grpc: GRPCServerConfig | None = Field( + default=None, + description="gRPC server transport configuration", + ) + http_json: HTTPJSONConfig | None = Field( + default=None, + description="HTTP+JSON transport configuration", + ) + + +def _migrate_client_transport_fields( + transport: ClientTransportConfig, + transport_protocol: TransportType | None, + supported_transports: list[TransportType] | None, +) -> None: + """Migrate deprecated transport fields to new config.""" + if transport_protocol is not None: + warnings.warn( + "transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead", + FutureWarning, + stacklevel=5, + ) + object.__setattr__(transport, "preferred", transport_protocol) + if supported_transports is not None: + warnings.warn( + "supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead", + FutureWarning, + stacklevel=5, + ) + object.__setattr__(transport, "supported", supported_transports) + + +class ClientTransportConfig(BaseModel): + """Transport configuration for A2A client. + + Groups all client transport-related settings including preferred transport, + supported transports for negotiation, and protocol-specific configurations. + + Transport negotiation logic: + 1. If `preferred` is set and server supports it → use client's preferred + 2. Otherwise, if server's preferred is in client's `supported` → use server's preferred + 3. Otherwise, find first match from client's `supported` in server's interfaces + + Attributes: + preferred: Client's preferred transport. If set, client preference takes priority. 
+ supported: Transports the client can use, in order of preference. + jsonrpc: JSON-RPC client transport configuration. + grpc: gRPC client transport configuration. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + preferred: TransportType | None = Field( + default=None, + description="Client's preferred transport. If set, takes priority over server preference.", + ) + supported: list[TransportType] = Field( + default_factory=lambda: cast(list[TransportType], ["JSONRPC"]), + description="Transports the client can use, in order of preference", + ) + jsonrpc: JSONRPCClientConfig = Field( + default_factory=JSONRPCClientConfig, + description="JSON-RPC client transport configuration", + ) + grpc: GRPCClientConfig = Field( + default_factory=GRPCClientConfig, + description="gRPC client transport configuration", + ) + + @deprecated( """ `crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0, @@ -65,13 +380,14 @@ class A2AConfig(BaseModel): fail_fast: If True, raise error when agent unreachable; if False, skip and continue. trust_remote_completion_status: If True, return A2A agent's result directly when completed. updates: Update mechanism config. - transport_protocol: A2A transport protocol (grpc, jsonrpc, http+json). + client_extensions: Client-side processing hooks for tool injection and prompt augmentation. + transport: Transport configuration (preferred, supported transports, gRPC settings). 
""" model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") endpoint: Url = Field(description="A2A agent endpoint URL") - auth: AuthScheme | None = Field( + auth: ClientAuthScheme | None = Field( default=None, description="Authentication scheme", ) @@ -95,10 +411,48 @@ class A2AConfig(BaseModel): default_factory=_get_default_update_config, description="Update mechanism config", ) - transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field( - default="JSONRPC", - description="Specified mode of A2A transport protocol", + client_extensions: list[ValidatedA2AExtension] = Field( + default_factory=list, + description="Client-side processing hooks for tool injection and prompt augmentation", ) + transport: ClientTransportConfig = Field( + default_factory=ClientTransportConfig, + description="Transport configuration (preferred, supported transports, gRPC settings)", + ) + transport_protocol: TransportType | None = Field( + default=None, + description="Deprecated: Use transport.preferred instead", + exclude=True, + ) + supported_transports: list[TransportType] | None = Field( + default=None, + description="Deprecated: Use transport.supported instead", + exclude=True, + ) + use_client_preference: bool | None = Field( + default=None, + description="Deprecated: Set transport.preferred to enable client preference", + exclude=True, + ) + _parallel_delegation: bool = PrivateAttr(default=False) + + @model_validator(mode="after") + def _migrate_deprecated_transport_fields(self) -> Self: + """Migrate deprecated transport fields to new config.""" + _migrate_client_transport_fields( + self.transport, self.transport_protocol, self.supported_transports + ) + if self.use_client_preference is not None: + warnings.warn( + "use_client_preference is deprecated, set transport.preferred to enable client preference", + FutureWarning, + stacklevel=4, + ) + if self.use_client_preference and self.transport.supported: + object.__setattr__( + self.transport, "preferred", 
self.transport.supported[0] + ) + return self class A2AClientConfig(BaseModel): @@ -114,15 +468,15 @@ class A2AClientConfig(BaseModel): trust_remote_completion_status: If True, return A2A agent's result directly when completed. updates: Update mechanism config. accepted_output_modes: Media types the client can accept in responses. - supported_transports: Ordered list of transport protocols the client supports. - use_client_preference: Whether to prioritize client transport preferences over server. - extensions: Extension URIs the client supports. + extensions: Extension URIs the client supports (A2A protocol extensions). + client_extensions: Client-side processing hooks for tool injection and prompt augmentation. + transport: Transport configuration (preferred, supported transports, gRPC settings). """ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") endpoint: Url = Field(description="A2A agent endpoint URL") - auth: AuthScheme | None = Field( + auth: ClientAuthScheme | None = Field( default=None, description="Authentication scheme", ) @@ -150,22 +504,37 @@ class A2AClientConfig(BaseModel): default_factory=lambda: ["application/json"], description="Media types the client can accept in responses", ) - supported_transports: list[str] = Field( - default_factory=lambda: ["JSONRPC"], - description="Ordered list of transport protocols the client supports", - ) - use_client_preference: bool = Field( - default=False, - description="Whether to prioritize client transport preferences over server", - ) extensions: list[str] = Field( default_factory=list, description="Extension URIs the client supports", ) - transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field( - default="JSONRPC", - description="Specified mode of A2A transport protocol", + client_extensions: list[ValidatedA2AExtension] = Field( + default_factory=list, + description="Client-side processing hooks for tool injection and prompt augmentation", ) + transport: ClientTransportConfig = 
Field( + default_factory=ClientTransportConfig, + description="Transport configuration (preferred, supported transports, gRPC settings)", + ) + transport_protocol: TransportType | None = Field( + default=None, + description="Deprecated: Use transport.preferred instead", + exclude=True, + ) + supported_transports: list[TransportType] | None = Field( + default=None, + description="Deprecated: Use transport.supported instead", + exclude=True, + ) + _parallel_delegation: bool = PrivateAttr(default=False) + + @model_validator(mode="after") + def _migrate_deprecated_transport_fields(self) -> Self: + """Migrate deprecated transport fields to new config.""" + _migrate_client_transport_fields( + self.transport, self.transport_protocol, self.supported_transports + ) + return self class A2AServerConfig(BaseModel): @@ -182,7 +551,6 @@ class A2AServerConfig(BaseModel): default_input_modes: Default supported input MIME types. default_output_modes: Default supported output MIME types. capabilities: Declaration of optional capabilities. - preferred_transport: Transport protocol for the preferred endpoint. protocol_version: A2A protocol version this agent supports. provider: Information about the agent's service provider. documentation_url: URL to the agent's documentation. @@ -192,7 +560,12 @@ class A2AServerConfig(BaseModel): security_schemes: Security schemes available to authorize requests. supports_authenticated_extended_card: Whether agent provides extended card to authenticated users. url: Preferred endpoint URL for the agent. - signatures: JSON Web Signatures for the AgentCard. + signing_config: Configuration for signing the AgentCard with JWS. + signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead. + server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks. + push_notifications: Configuration for outgoing push notifications. + transport: Transport configuration (preferred transport, gRPC, REST settings). 
+ auth: Authentication scheme for A2A endpoints. """ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") @@ -228,12 +601,8 @@ class A2AServerConfig(BaseModel): ), description="Declaration of optional capabilities supported by the agent", ) - preferred_transport: TransportType = Field( - default="JSONRPC", - description="Transport protocol for the preferred endpoint", - ) - protocol_version: str = Field( - default_factory=lambda: version("a2a-sdk"), + protocol_version: ProtocolVersion = Field( + default="0.3.0", description="A2A protocol version this agent supports", ) provider: AgentProvider | None = Field( @@ -250,7 +619,7 @@ class A2AServerConfig(BaseModel): ) additional_interfaces: list[AgentInterface] = Field( default_factory=list, - description="Additional supported interfaces (transport and URL combinations)", + description="Additional supported interfaces.", ) security: list[dict[str, list[str]]] = Field( default_factory=list, @@ -268,7 +637,54 @@ class A2AServerConfig(BaseModel): default=None, description="Preferred endpoint URL for the agent. Set at runtime if not provided.", ) - signatures: list[AgentCardSignature] = Field( - default_factory=list, - description="JSON Web Signatures for the AgentCard", + signing_config: AgentCardSigningConfig | None = Field( + default=None, + description="Configuration for signing the AgentCard with JWS", ) + signatures: list[AgentCardSignature] | None = Field( + default=None, + description="Deprecated: Use signing_config instead. 
Pre-computed JWS signatures for the AgentCard.", + exclude=True, + deprecated=True, + ) + server_extensions: list[ServerExtension] = Field( + default_factory=list, + description="Server-side A2A protocol extensions that modify agent behavior", + ) + push_notifications: ServerPushNotificationConfig | None = Field( + default=None, + description="Configuration for outgoing push notifications", + ) + transport: ServerTransportConfig = Field( + default_factory=ServerTransportConfig, + description="Transport configuration (preferred transport, gRPC, REST settings)", + ) + preferred_transport: TransportType | None = Field( + default=None, + description="Deprecated: Use transport.preferred instead", + exclude=True, + deprecated=True, + ) + auth: ServerAuthScheme | None = Field( + default=None, + description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.", + ) + + @model_validator(mode="after") + def _migrate_deprecated_fields(self) -> Self: + """Migrate deprecated fields to new config.""" + if self.preferred_transport is not None: + warnings.warn( + "preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead", + FutureWarning, + stacklevel=4, + ) + object.__setattr__(self.transport, "preferred", self.preferred_transport) + if self.signatures is not None: + warnings.warn( + "signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. " + "The signatures field will be removed in v2.0.0.", + FutureWarning, + stacklevel=4, + ) + return self diff --git a/lib/crewai/src/crewai/a2a/errors.py b/lib/crewai/src/crewai/a2a/errors.py index e24e9c296..aabe10288 100644 --- a/lib/crewai/src/crewai/a2a/errors.py +++ b/lib/crewai/src/crewai/a2a/errors.py @@ -1,7 +1,491 @@ -"""A2A protocol error types.""" +"""A2A error codes and error response utilities. 
+
+This module provides a centralized mapping of all A2A protocol error codes
+as defined in the A2A specification, plus custom CrewAI extensions.
+
+Error codes follow JSON-RPC 2.0 conventions:
+- -32700 to -32600: Standard JSON-RPC errors
+- -32099 to -32000: Server errors (A2A-specific)
+- -32018 to -32009: CrewAI custom extensions (within the server-error range)
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from enum import IntEnum
+from typing import Any
 
 from a2a.client.errors import A2AClientTimeoutError
 
 
 class A2APollingTimeoutError(A2AClientTimeoutError):
     """Raised when polling exceeds the configured timeout."""
+
+
+class A2AErrorCode(IntEnum):
+    """A2A protocol error codes.
+
+    Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
+    """
+
+    # JSON-RPC 2.0 Standard Errors (-32700 to -32600)
+    JSON_PARSE_ERROR = -32700
+    """Invalid JSON was received by the server."""
+
+    INVALID_REQUEST = -32600
+    """The JSON sent is not a valid Request object."""
+
+    METHOD_NOT_FOUND = -32601
+    """The method does not exist / is not available."""
+
+    INVALID_PARAMS = -32602
+    """Invalid method parameter(s)."""
+
+    INTERNAL_ERROR = -32603
+    """Internal JSON-RPC error."""
+
+    # A2A-Specific Errors (-32099 to -32000)
+    TASK_NOT_FOUND = -32001
+    """The specified task was not found."""
+
+    TASK_NOT_CANCELABLE = -32002
+    """The task cannot be canceled (already completed/failed)."""
+
+    PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
+    """Push notifications are not supported by this agent."""
+
+    UNSUPPORTED_OPERATION = -32004
+    """The requested operation is not supported."""
+
+    CONTENT_TYPE_NOT_SUPPORTED = -32005
+    """Incompatible content types between client and server."""
+
+    INVALID_AGENT_RESPONSE = -32006
+    """The agent produced an invalid response."""
+
+    # CrewAI Custom Extensions (-32018 to -32009)
+    UNSUPPORTED_VERSION = -32009
+    """The requested A2A protocol version is not supported."""
+
+    UNSUPPORTED_EXTENSION = -32010
+    """Client
does not support required protocol extensions.""" + + AUTHENTICATION_REQUIRED = -32011 + """Authentication is required for this operation.""" + + AUTHORIZATION_FAILED = -32012 + """Authorization check failed (insufficient permissions).""" + + RATE_LIMIT_EXCEEDED = -32013 + """Rate limit exceeded for this client/operation.""" + + TASK_TIMEOUT = -32014 + """Task execution timed out.""" + + TRANSPORT_NEGOTIATION_FAILED = -32015 + """Failed to negotiate a compatible transport protocol.""" + + CONTEXT_NOT_FOUND = -32016 + """The specified context was not found.""" + + SKILL_NOT_FOUND = -32017 + """The specified skill was not found.""" + + ARTIFACT_NOT_FOUND = -32018 + """The specified artifact was not found.""" + + +# Error code to default message mapping +ERROR_MESSAGES: dict[int, str] = { + A2AErrorCode.JSON_PARSE_ERROR: "Parse error", + A2AErrorCode.INVALID_REQUEST: "Invalid Request", + A2AErrorCode.METHOD_NOT_FOUND: "Method not found", + A2AErrorCode.INVALID_PARAMS: "Invalid params", + A2AErrorCode.INTERNAL_ERROR: "Internal error", + A2AErrorCode.TASK_NOT_FOUND: "Task not found", + A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable", + A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported", + A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported", + A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types", + A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response", + A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version", + A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions", + A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required", + A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed", + A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded", + A2AErrorCode.TASK_TIMEOUT: "Task execution timed out", + A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed", + A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found", + 
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found", + A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found", +} + + +@dataclass +class A2AError(Exception): + """Base exception for A2A protocol errors. + + Attributes: + code: The A2A/JSON-RPC error code. + message: Human-readable error message. + data: Optional additional error data. + """ + + code: int + message: str | None = None + data: Any = None + + def __post_init__(self) -> None: + if self.message is None: + self.message = ERROR_MESSAGES.get(self.code, "Unknown error") + super().__init__(self.message) + + def to_dict(self) -> dict[str, Any]: + """Convert to JSON-RPC error object format.""" + error: dict[str, Any] = { + "code": self.code, + "message": self.message, + } + if self.data is not None: + error["data"] = self.data + return error + + def to_response(self, request_id: str | int | None = None) -> dict[str, Any]: + """Convert to full JSON-RPC error response.""" + return { + "jsonrpc": "2.0", + "error": self.to_dict(), + "id": request_id, + } + + +@dataclass +class JSONParseError(A2AError): + """Invalid JSON was received.""" + + code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False) + + +@dataclass +class InvalidRequestError(A2AError): + """The JSON sent is not a valid Request object.""" + + code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False) + + +@dataclass +class MethodNotFoundError(A2AError): + """The method does not exist / is not available.""" + + code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False) + method: str | None = None + + def __post_init__(self) -> None: + if self.message is None and self.method: + self.message = f"Method not found: {self.method}" + super().__post_init__() + + +@dataclass +class InvalidParamsError(A2AError): + """Invalid method parameter(s).""" + + code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False) + param: str | None = None + reason: str | None = None + + def __post_init__(self) -> None: + if self.message is None: + 
if self.param and self.reason: + self.message = f"Invalid parameter '{self.param}': {self.reason}" + elif self.param: + self.message = f"Invalid parameter: {self.param}" + super().__post_init__() + + +@dataclass +class InternalError(A2AError): + """Internal JSON-RPC error.""" + + code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False) + + +@dataclass +class TaskNotFoundError(A2AError): + """The specified task was not found.""" + + code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False) + task_id: str | None = None + + def __post_init__(self) -> None: + if self.message is None and self.task_id: + self.message = f"Task not found: {self.task_id}" + super().__post_init__() + + +@dataclass +class TaskNotCancelableError(A2AError): + """The task cannot be canceled.""" + + code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False) + task_id: str | None = None + reason: str | None = None + + def __post_init__(self) -> None: + if self.message is None: + if self.task_id and self.reason: + self.message = f"Task {self.task_id} cannot be canceled: {self.reason}" + elif self.task_id: + self.message = f"Task {self.task_id} cannot be canceled" + super().__post_init__() + + +@dataclass +class PushNotificationNotSupportedError(A2AError): + """Push notifications are not supported.""" + + code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False) + + +@dataclass +class UnsupportedOperationError(A2AError): + """The requested operation is not supported.""" + + code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False) + operation: str | None = None + + def __post_init__(self) -> None: + if self.message is None and self.operation: + self.message = f"Operation not supported: {self.operation}" + super().__post_init__() + + +@dataclass +class ContentTypeNotSupportedError(A2AError): + """Incompatible content types.""" + + code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False) + requested_types: 
list[str] | None = None + supported_types: list[str] | None = None + + def __post_init__(self) -> None: + if self.message is None and self.requested_types and self.supported_types: + self.message = ( + f"Content type not supported. Requested: {self.requested_types}, " + f"Supported: {self.supported_types}" + ) + super().__post_init__() + + +@dataclass +class InvalidAgentResponseError(A2AError): + """The agent produced an invalid response.""" + + code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False) + + +@dataclass +class UnsupportedVersionError(A2AError): + """The requested A2A version is not supported.""" + + code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False) + requested_version: str | None = None + supported_versions: list[str] | None = None + + def __post_init__(self) -> None: + if self.message is None and self.requested_version: + msg = f"Unsupported A2A version: {self.requested_version}" + if self.supported_versions: + msg += f". Supported versions: {', '.join(self.supported_versions)}" + self.message = msg + super().__post_init__() + + +@dataclass +class UnsupportedExtensionError(A2AError): + """Client does not support required extensions.""" + + code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False) + required_extensions: list[str] | None = None + + def __post_init__(self) -> None: + if self.message is None and self.required_extensions: + self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}" + super().__post_init__() + + +@dataclass +class AuthenticationRequiredError(A2AError): + """Authentication is required.""" + + code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False) + + +@dataclass +class AuthorizationFailedError(A2AError): + """Authorization check failed.""" + + code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False) + required_scope: str | None = None + + def __post_init__(self) -> None: + if self.message is 
None and self.required_scope: + self.message = ( + f"Authorization failed. Required scope: {self.required_scope}" + ) + super().__post_init__() + + +@dataclass +class RateLimitExceededError(A2AError): + """Rate limit exceeded.""" + + code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False) + retry_after: int | None = None + + def __post_init__(self) -> None: + if self.message is None and self.retry_after: + self.message = ( + f"Rate limit exceeded. Retry after {self.retry_after} seconds" + ) + if self.retry_after: + self.data = {"retry_after": self.retry_after} + super().__post_init__() + + +@dataclass +class TaskTimeoutError(A2AError): + """Task execution timed out.""" + + code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False) + task_id: str | None = None + timeout_seconds: float | None = None + + def __post_init__(self) -> None: + if self.message is None: + if self.task_id and self.timeout_seconds: + self.message = ( + f"Task {self.task_id} timed out after {self.timeout_seconds}s" + ) + elif self.task_id: + self.message = f"Task {self.task_id} timed out" + super().__post_init__() + + +@dataclass +class TransportNegotiationFailedError(A2AError): + """Failed to negotiate a compatible transport protocol.""" + + code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False) + client_transports: list[str] | None = None + server_transports: list[str] | None = None + + def __post_init__(self) -> None: + if self.message is None and self.client_transports and self.server_transports: + self.message = ( + f"Transport negotiation failed. 
Client: {self.client_transports}, " + f"Server: {self.server_transports}" + ) + super().__post_init__() + + +@dataclass +class ContextNotFoundError(A2AError): + """The specified context was not found.""" + + code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False) + context_id: str | None = None + + def __post_init__(self) -> None: + if self.message is None and self.context_id: + self.message = f"Context not found: {self.context_id}" + super().__post_init__() + + +@dataclass +class SkillNotFoundError(A2AError): + """The specified skill was not found.""" + + code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False) + skill_id: str | None = None + + def __post_init__(self) -> None: + if self.message is None and self.skill_id: + self.message = f"Skill not found: {self.skill_id}" + super().__post_init__() + + +@dataclass +class ArtifactNotFoundError(A2AError): + """The specified artifact was not found.""" + + code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False) + artifact_id: str | None = None + + def __post_init__(self) -> None: + if self.message is None and self.artifact_id: + self.message = f"Artifact not found: {self.artifact_id}" + super().__post_init__() + + +def create_error_response( + code: int | A2AErrorCode, + message: str | None = None, + data: Any = None, + request_id: str | int | None = None, +) -> dict[str, Any]: + """Create a JSON-RPC error response. + + Args: + code: Error code (A2AErrorCode or int). + message: Optional error message (uses default if not provided). + data: Optional additional error data. + request_id: Request ID for correlation. + + Returns: + Dict in JSON-RPC error response format. + """ + error = A2AError(code=int(code), message=message, data=data) + return error.to_response(request_id) + + +def is_retryable_error(code: int) -> bool: + """Check if an error is potentially retryable. + + Args: + code: Error code to check. + + Returns: + True if the error might be resolved by retrying. 
+ """ + retryable_codes = { + A2AErrorCode.INTERNAL_ERROR, + A2AErrorCode.RATE_LIMIT_EXCEEDED, + A2AErrorCode.TASK_TIMEOUT, + } + return code in retryable_codes + + +def is_client_error(code: int) -> bool: + """Check if an error is a client-side error. + + Args: + code: Error code to check. + + Returns: + True if the error is due to client request issues. + """ + client_error_codes = { + A2AErrorCode.JSON_PARSE_ERROR, + A2AErrorCode.INVALID_REQUEST, + A2AErrorCode.METHOD_NOT_FOUND, + A2AErrorCode.INVALID_PARAMS, + A2AErrorCode.TASK_NOT_FOUND, + A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, + A2AErrorCode.UNSUPPORTED_VERSION, + A2AErrorCode.UNSUPPORTED_EXTENSION, + A2AErrorCode.CONTEXT_NOT_FOUND, + A2AErrorCode.SKILL_NOT_FOUND, + A2AErrorCode.ARTIFACT_NOT_FOUND, + } + return code in client_error_codes diff --git a/lib/crewai/src/crewai/a2a/task_helpers.py b/lib/crewai/src/crewai/a2a/task_helpers.py index 1b513612a..b4a758656 100644 --- a/lib/crewai/src/crewai/a2a/task_helpers.py +++ b/lib/crewai/src/crewai/a2a/task_helpers.py @@ -51,6 +51,13 @@ ACTIONABLE_STATES: frozenset[TaskState] = frozenset( } ) +PENDING_STATES: frozenset[TaskState] = frozenset( + { + TaskState.submitted, + TaskState.working, + } +) + class TaskStateResult(TypedDict): """Result dictionary from processing A2A task state.""" @@ -272,6 +279,9 @@ def process_task_state( history=new_messages, ) + if a2a_task.status.state in PENDING_STATES: + return None + return None diff --git a/lib/crewai/src/crewai/a2a/templates.py b/lib/crewai/src/crewai/a2a/templates.py index 83bce22e5..16f0c479e 100644 --- a/lib/crewai/src/crewai/a2a/templates.py +++ b/lib/crewai/src/crewai/a2a/templates.py @@ -38,3 +38,18 @@ You MUST now: DO NOT send another request - the task is already done. """ + +REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """ + +STATUS: RESPONSE_RECEIVED +The remote agent has responded. Their response is in the conversation history above. + +You MUST now: +1. 
Set is_a2a=false (the remote task is complete and cannot receive more messages) +2. Provide YOUR OWN response to the original task based on the information received + +IMPORTANT: Your response should be addressed to the USER who gave you the original task. +Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that..."). +Do NOT address the remote agent directly or use "you" to refer to them. + +""" diff --git a/lib/crewai/src/crewai/a2a/types.py b/lib/crewai/src/crewai/a2a/types.py index ea15abd80..5a4a7672a 100644 --- a/lib/crewai/src/crewai/a2a/types.py +++ b/lib/crewai/src/crewai/a2a/types.py @@ -36,6 +36,17 @@ except ImportError: TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"] +ProtocolVersion = Literal[ + "0.2.0", + "0.2.1", + "0.2.2", + "0.2.3", + "0.2.4", + "0.2.5", + "0.2.6", + "0.3.0", + "0.4.0", +] http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl) diff --git a/lib/crewai/src/crewai/a2a/updates/base.py b/lib/crewai/src/crewai/a2a/updates/base.py index f81edf0bf..8a6a53aa3 100644 --- a/lib/crewai/src/crewai/a2a/updates/base.py +++ b/lib/crewai/src/crewai/a2a/updates/base.py @@ -2,12 +2,28 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol, TypedDict +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict from pydantic import GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema +class CommonParams(NamedTuple): + """Common parameters shared across all update handlers. + + Groups the frequently-passed parameters to reduce duplication. 
+ """ + + turn_number: int + is_multiturn: bool + agent_role: str | None + endpoint: str + a2a_agent_name: str | None + context_id: str | None + from_task: Any + from_agent: Any + + if TYPE_CHECKING: from a2a.client import Client from a2a.types import AgentCard, Message, Task @@ -63,8 +79,8 @@ class PushNotificationResultStore(Protocol): @classmethod def __get_pydantic_core_schema__( cls, - source_type: Any, - handler: GetCoreSchemaHandler, + _source_type: Any, + _handler: GetCoreSchemaHandler, ) -> CoreSchema: return core_schema.any_schema() @@ -130,3 +146,31 @@ class UpdateHandler(Protocol): Result dictionary with status, result/error, and history. """ ... + + +def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams: + """Extract common parameters from handler kwargs. + + Args: + kwargs: Handler kwargs dict. + + Returns: + CommonParams with extracted values. + + Raises: + ValueError: If endpoint is not provided. + """ + endpoint = kwargs.get("endpoint") + if endpoint is None: + raise ValueError("endpoint is required for update handlers") + + return CommonParams( + turn_number=kwargs.get("turn_number", 0), + is_multiturn=kwargs.get("is_multiturn", False), + agent_role=kwargs.get("agent_role"), + endpoint=endpoint, + a2a_agent_name=kwargs.get("a2a_agent_name"), + context_id=kwargs.get("context_id"), + from_task=kwargs.get("from_task"), + from_agent=kwargs.get("from_agent"), + ) diff --git a/lib/crewai/src/crewai/a2a/updates/polling/handler.py b/lib/crewai/src/crewai/a2a/updates/polling/handler.py index 3981e554b..dad5bca57 100644 --- a/lib/crewai/src/crewai/a2a/updates/polling/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/polling/handler.py @@ -94,7 +94,7 @@ async def _poll_task_until_complete( A2APollingStatusEvent( task_id=task_id, context_id=effective_context_id, - state=str(task.status.state.value) if task.status.state else "unknown", + state=str(task.status.state.value), elapsed_seconds=elapsed, poll_count=poll_count, endpoint=endpoint, @@ 
-325,7 +325,7 @@ class PollingHandler: crewai_event_bus.emit( agent_branch, A2AConnectionErrorEvent( - endpoint=endpoint or "", + endpoint=endpoint, error=str(e), error_type="unexpected_error", a2a_agent_name=a2a_agent_name, diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/config.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/config.py index 2cd22bc21..de81dbe80 100644 --- a/lib/crewai/src/crewai/a2a/updates/push_notifications/config.py +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/config.py @@ -2,10 +2,30 @@ from __future__ import annotations +from typing import Annotated + from a2a.types import PushNotificationAuthenticationInfo -from pydantic import AnyHttpUrl, BaseModel, Field +from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field from crewai.a2a.updates.base import PushNotificationResultStore +from crewai.a2a.updates.push_notifications.signature import WebhookSignatureConfig + + +def _coerce_signature( + value: str | WebhookSignatureConfig | None, +) -> WebhookSignatureConfig | None: + """Convert string secret to WebhookSignatureConfig.""" + if value is None: + return None + if isinstance(value, str): + return WebhookSignatureConfig.hmac_sha256(secret=value) + return value + + +SignatureInput = Annotated[ + WebhookSignatureConfig | None, + BeforeValidator(_coerce_signature), +] class PushNotificationConfig(BaseModel): @@ -19,6 +39,8 @@ class PushNotificationConfig(BaseModel): timeout: Max seconds to wait for task completion. interval: Seconds between result polling attempts. result_store: Store for receiving push notification results. + signature: HMAC signature config. Pass a string (secret) for defaults, + or WebhookSignatureConfig for custom settings. 
""" url: AnyHttpUrl = Field(description="Callback URL for push notifications") @@ -36,3 +58,8 @@ class PushNotificationConfig(BaseModel): result_store: PushNotificationResultStore | None = Field( default=None, description="Result store for push notification handling" ) + signature: SignatureInput = Field( + default=None, + description="HMAC signature config. Pass a string (secret) for simple usage, " + "or WebhookSignatureConfig for custom headers/tolerance.", + ) diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py index b2bddf8f1..783bf6483 100644 --- a/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py @@ -24,8 +24,10 @@ from crewai.a2a.task_helpers import ( send_message_and_get_task_id, ) from crewai.a2a.updates.base import ( + CommonParams, PushNotificationHandlerKwargs, PushNotificationResultStore, + extract_common_params, ) from crewai.events.event_bus import crewai_event_bus from crewai.events.types.a2a_events import ( @@ -39,10 +41,81 @@ from crewai.events.types.a2a_events import ( if TYPE_CHECKING: from a2a.types import Task as A2ATask - logger = logging.getLogger(__name__) +def _handle_push_error( + error: Exception, + error_msg: str, + error_type: str, + new_messages: list[Message], + agent_branch: Any | None, + params: CommonParams, + task_id: str | None, + status_code: int | None = None, +) -> TaskStateResult: + """Handle push notification errors with consistent event emission. + + Args: + error: The exception that occurred. + error_msg: Formatted error message for the result. + error_type: Type of error for the event. + new_messages: List to append error message to. + agent_branch: Agent tree branch for events. + params: Common handler parameters. + task_id: A2A task ID. + status_code: HTTP status code if applicable. + + Returns: + TaskStateResult with failed status. 
+ """ + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=params.context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=params.endpoint, + error=str(error), + error_type=error_type, + status_code=status_code, + a2a_agent_name=params.a2a_agent_name, + operation="push_notification", + context_id=params.context_id, + task_id=task_id, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=params.turn_number, + context_id=params.context_id, + is_multiturn=params.is_multiturn, + status="failed", + final=True, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + async def _wait_for_push_result( task_id: str, result_store: PushNotificationResultStore, @@ -126,15 +199,8 @@ class PushNotificationHandler: polling_timeout = kwargs.get("polling_timeout", 300.0) polling_interval = kwargs.get("polling_interval", 2.0) agent_branch = kwargs.get("agent_branch") - turn_number = kwargs.get("turn_number", 0) - is_multiturn = kwargs.get("is_multiturn", False) - agent_role = kwargs.get("agent_role") - context_id = kwargs.get("context_id") task_id = kwargs.get("task_id") - endpoint = kwargs.get("endpoint") - a2a_agent_name = kwargs.get("a2a_agent_name") - from_task = kwargs.get("from_task") - from_agent = kwargs.get("from_agent") + params = extract_common_params(kwargs) if config is None: error_msg = ( @@ -143,15 +209,15 @@ class PushNotificationHandler: crewai_event_bus.emit( agent_branch, A2AConnectionErrorEvent( - endpoint=endpoint or "", + endpoint=params.endpoint, 
error=error_msg, error_type="configuration_error", - a2a_agent_name=a2a_agent_name, + a2a_agent_name=params.a2a_agent_name, operation="push_notification", - context_id=context_id, + context_id=params.context_id, task_id=task_id, - from_task=from_task, - from_agent=from_agent, + from_task=params.from_task, + from_agent=params.from_agent, ), ) return TaskStateResult( @@ -167,15 +233,15 @@ class PushNotificationHandler: crewai_event_bus.emit( agent_branch, A2AConnectionErrorEvent( - endpoint=endpoint or "", + endpoint=params.endpoint, error=error_msg, error_type="configuration_error", - a2a_agent_name=a2a_agent_name, + a2a_agent_name=params.a2a_agent_name, operation="push_notification", - context_id=context_id, + context_id=params.context_id, task_id=task_id, - from_task=from_task, - from_agent=from_agent, + from_task=params.from_task, + from_agent=params.from_agent, ), ) return TaskStateResult( @@ -189,14 +255,14 @@ class PushNotificationHandler: event_stream=client.send_message(message), new_messages=new_messages, agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - context_id=context_id, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + from_task=params.from_task, + from_agent=params.from_agent, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + context_id=params.context_id, ) if not isinstance(result_or_task_id, str): @@ -208,12 +274,12 @@ class PushNotificationHandler: agent_branch, A2APushNotificationRegisteredEvent( task_id=task_id, - context_id=context_id, + context_id=params.context_id, callback_url=str(config.url), - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + 
from_agent=params.from_agent, ), ) @@ -229,11 +295,11 @@ class PushNotificationHandler: timeout=polling_timeout, poll_interval=polling_interval, agent_branch=agent_branch, - from_task=from_task, - from_agent=from_agent, - context_id=context_id, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + context_id=params.context_id, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, ) if final_task is None: @@ -247,13 +313,13 @@ class PushNotificationHandler: a2a_task=final_task, new_messages=new_messages, agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, ) if result: return result @@ -265,98 +331,24 @@ class PushNotificationHandler: ) except A2AClientHTTPError as e: - error_msg = f"HTTP Error {e.status_code}: {e!s}" - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=context_id, + return _handle_push_error( + error=e, + error_msg=f"HTTP Error {e.status_code}: {e!s}", + error_type="http_error", + new_messages=new_messages, + agent_branch=agent_branch, + params=params, task_id=task_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=endpoint or "", - error=str(e), - error_type="http_error", - status_code=e.status_code, - a2a_agent_name=a2a_agent_name, - operation="push_notification", - context_id=context_id, - task_id=task_id, - from_task=from_task, - from_agent=from_agent, - ), - ) - crewai_event_bus.emit( - agent_branch, - A2AResponseReceivedEvent( - 
response=error_msg, - turn_number=turn_number, - context_id=context_id, - is_multiturn=is_multiturn, - status="failed", - final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, + status_code=e.status_code, ) except Exception as e: - error_msg = f"Unexpected error during push notification: {e!s}" - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=context_id, + return _handle_push_error( + error=e, + error_msg=f"Unexpected error during push notification: {e!s}", + error_type="unexpected_error", + new_messages=new_messages, + agent_branch=agent_branch, + params=params, task_id=task_id, ) - new_messages.append(error_message) - - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=endpoint or "", - error=str(e), - error_type="unexpected_error", - a2a_agent_name=a2a_agent_name, - operation="push_notification", - context_id=context_id, - task_id=task_id, - from_task=from_task, - from_agent=from_agent, - ), - ) - crewai_event_bus.emit( - agent_branch, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=turn_number, - context_id=context_id, - is_multiturn=is_multiturn, - status="failed", - final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/signature.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/signature.py new file mode 100644 index 000000000..9cac929ec --- /dev/null +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/signature.py @@ -0,0 +1,87 @@ +"""Webhook signature configuration for 
push notifications.""" + +from __future__ import annotations + +from enum import Enum +import secrets + +from pydantic import BaseModel, Field, SecretStr + + +class WebhookSignatureMode(str, Enum): + """Signature mode for webhook push notifications.""" + + NONE = "none" + HMAC_SHA256 = "hmac_sha256" + + +class WebhookSignatureConfig(BaseModel): + """Configuration for webhook signature verification. + + Provides cryptographic integrity verification and replay attack protection + for A2A push notifications. + + Attributes: + mode: Signature mode (none or hmac_sha256). + secret: Shared secret for HMAC computation (required for hmac_sha256 mode). + timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection. + header_name: HTTP header name for the signature. + timestamp_header_name: HTTP header name for the timestamp. + """ + + mode: WebhookSignatureMode = Field( + default=WebhookSignatureMode.NONE, + description="Signature verification mode", + ) + secret: SecretStr | None = Field( + default=None, + description="Shared secret for HMAC computation", + ) + timestamp_tolerance_seconds: int = Field( + default=300, + ge=0, + description="Max allowed timestamp age in seconds (5 min default)", + ) + header_name: str = Field( + default="X-A2A-Signature", + description="HTTP header name for the signature", + ) + timestamp_header_name: str = Field( + default="X-A2A-Signature-Timestamp", + description="HTTP header name for the timestamp", + ) + + @classmethod + def generate_secret(cls, length: int = 32) -> str: + """Generate a cryptographically secure random secret. + + Args: + length: Number of random bytes to generate (default 32). + + Returns: + URL-safe base64-encoded secret string. + """ + return secrets.token_urlsafe(length) + + @classmethod + def hmac_sha256( + cls, + secret: str | SecretStr, + timestamp_tolerance_seconds: int = 300, + ) -> WebhookSignatureConfig: + """Create an HMAC-SHA256 signature configuration. 
+ + Args: + secret: Shared secret for HMAC computation. + timestamp_tolerance_seconds: Max allowed timestamp age in seconds. + + Returns: + Configured WebhookSignatureConfig for HMAC-SHA256. + """ + if isinstance(secret, str): + secret = SecretStr(secret) + return cls( + mode=WebhookSignatureMode.HMAC_SHA256, + secret=secret, + timestamp_tolerance_seconds=timestamp_tolerance_seconds, + ) diff --git a/lib/crewai/src/crewai/a2a/updates/streaming/handler.py b/lib/crewai/src/crewai/a2a/updates/streaming/handler.py index 2bfe4dbed..9b0c21d12 100644 --- a/lib/crewai/src/crewai/a2a/updates/streaming/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/streaming/handler.py @@ -2,6 +2,9 @@ from __future__ import annotations +import asyncio +import logging +from typing import Final import uuid from a2a.client import Client @@ -11,7 +14,10 @@ from a2a.types import ( Message, Part, Role, + Task, TaskArtifactUpdateEvent, + TaskIdParams, + TaskQueryParams, TaskState, TaskStatusUpdateEvent, TextPart, @@ -24,7 +30,10 @@ from crewai.a2a.task_helpers import ( TaskStateResult, process_task_state, ) -from crewai.a2a.updates.base import StreamingHandlerKwargs +from crewai.a2a.updates.base import StreamingHandlerKwargs, extract_common_params +from crewai.a2a.updates.streaming.params import ( + process_status_update, +) from crewai.events.event_bus import crewai_event_bus from crewai.events.types.a2a_events import ( A2AArtifactReceivedEvent, @@ -35,9 +44,194 @@ from crewai.events.types.a2a_events import ( ) +logger = logging.getLogger(__name__) + +MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3 +RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0 + + class StreamingHandler: """SSE streaming-based update handler.""" + @staticmethod + async def _try_recover_from_interruption( # type: ignore[misc] + client: Client, + task_id: str, + new_messages: list[Message], + agent_card: AgentCard, + result_parts: list[str], + **kwargs: Unpack[StreamingHandlerKwargs], + ) -> TaskStateResult | None: + """Attempt to 
recover from a stream interruption by checking task state. + + If the task completed while we were disconnected, returns the result. + If the task is still running, attempts to resubscribe and continue. + + Args: + client: A2A client instance. + task_id: The task ID to recover. + new_messages: List of collected messages. + agent_card: The agent card. + result_parts: Accumulated result text parts. + **kwargs: Handler parameters. + + Returns: + TaskStateResult if recovery succeeded (task finished or resubscribe worked). + None if recovery not possible (caller should handle failure). + + Note: + When None is returned, recovery failed and the original exception should + be handled by the caller. All recovery attempts are logged. + """ + params = extract_common_params(kwargs) # type: ignore[arg-type] + + try: + a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id)) + + if a2a_task.status.state in TERMINAL_STATES: + logger.info( + "Task completed during stream interruption", + extra={"task_id": task_id, "state": str(a2a_task.status.state)}, + ) + return process_task_state( + a2a_task=a2a_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + result_parts=result_parts, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ) + + if a2a_task.status.state in ACTIONABLE_STATES: + logger.info( + "Task in actionable state during stream interruption", + extra={"task_id": task_id, "state": str(a2a_task.status.state)}, + ) + return process_task_state( + a2a_task=a2a_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + result_parts=result_parts, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + 
is_final=False, + ) + + logger.info( + "Task still running, attempting resubscribe", + extra={"task_id": task_id, "state": str(a2a_task.status.state)}, + ) + + for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS): + try: + backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt) + if attempt > 0: + await asyncio.sleep(backoff) + + event_stream = client.resubscribe(TaskIdParams(id=task_id)) + + async for event in event_stream: + if isinstance(event, tuple): + resubscribed_task, update = event + + is_final_update = ( + process_status_update(update, result_parts) + if isinstance(update, TaskStatusUpdateEvent) + else False + ) + + if isinstance(update, TaskArtifactUpdateEvent): + artifact = update.artifact + result_parts.extend( + part.root.text + for part in artifact.parts + if part.root.kind == "text" + ) + + if ( + is_final_update + or resubscribed_task.status.state + in TERMINAL_STATES | ACTIONABLE_STATES + ): + return process_task_state( + a2a_task=resubscribed_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + result_parts=result_parts, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + is_final=is_final_update, + ) + + elif isinstance(event, Message): + new_messages.append(event) + result_parts.extend( + part.root.text + for part in event.parts + if part.root.kind == "text" + ) + + final_task = await client.get_task(TaskQueryParams(id=task_id)) + return process_task_state( + a2a_task=final_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + result_parts=result_parts, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ) + + except Exception as resubscribe_error: # noqa: PERF203 + logger.warning( 
+ "Resubscribe attempt failed", + extra={ + "task_id": task_id, + "attempt": attempt + 1, + "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS, + "error": str(resubscribe_error), + }, + ) + if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1: + return None + + except Exception as e: + logger.warning( + "Failed to recover from stream interruption due to unexpected error", + extra={ + "task_id": task_id, + "error": str(e), + "error_type": type(e).__name__, + }, + exc_info=True, + ) + return None + + logger.warning( + "Recovery exhausted all resubscribe attempts without success", + extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS}, + ) + return None + @staticmethod async def execute( client: Client, @@ -58,42 +252,40 @@ class StreamingHandler: Returns: Dictionary with status, result/error, and history. """ - context_id = kwargs.get("context_id") task_id = kwargs.get("task_id") - turn_number = kwargs.get("turn_number", 0) - is_multiturn = kwargs.get("is_multiturn", False) - agent_role = kwargs.get("agent_role") - endpoint = kwargs.get("endpoint") - a2a_agent_name = kwargs.get("a2a_agent_name") - from_task = kwargs.get("from_task") - from_agent = kwargs.get("from_agent") agent_branch = kwargs.get("agent_branch") + params = extract_common_params(kwargs) result_parts: list[str] = [] final_result: TaskStateResult | None = None event_stream = client.send_message(message) chunk_index = 0 + current_task_id: str | None = task_id crewai_event_bus.emit( agent_branch, A2AStreamingStartedEvent( task_id=task_id, - context_id=context_id, - endpoint=endpoint or "", - a2a_agent_name=a2a_agent_name, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - from_task=from_task, - from_agent=from_agent, + context_id=params.context_id, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + from_task=params.from_task, + from_agent=params.from_agent, 
), ) try: async for event in event_stream: + if isinstance(event, tuple): + a2a_task, _ = event + current_task_id = a2a_task.id + if isinstance(event, Message): new_messages.append(event) - message_context_id = event.context_id or context_id + message_context_id = event.context_id or params.context_id for part in event.parts: if part.root.kind == "text": text = part.root.text @@ -105,12 +297,12 @@ class StreamingHandler: context_id=message_context_id, chunk=text, chunk_index=chunk_index, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - turn_number=turn_number, - is_multiturn=is_multiturn, - from_task=from_task, - from_agent=from_agent, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + from_task=params.from_task, + from_agent=params.from_agent, ), ) chunk_index += 1 @@ -128,12 +320,12 @@ class StreamingHandler: artifact_size = None if artifact.parts: artifact_size = sum( - len(p.root.text.encode("utf-8")) + len(p.root.text.encode()) if p.root.kind == "text" else len(getattr(p.root, "data", b"")) for p in artifact.parts ) - effective_context_id = a2a_task.context_id or context_id + effective_context_id = a2a_task.context_id or params.context_id crewai_event_bus.emit( agent_branch, A2AArtifactReceivedEvent( @@ -147,29 +339,21 @@ class StreamingHandler: size_bytes=artifact_size, append=update.append or False, last_chunk=update.last_chunk or False, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, context_id=effective_context_id, - turn_number=turn_number, - is_multiturn=is_multiturn, - from_task=from_task, - from_agent=from_agent, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + from_task=params.from_task, + from_agent=params.from_agent, ), ) - is_final_update = False - if isinstance(update, TaskStatusUpdateEvent): - is_final_update = update.final - if ( - update.status - and 
update.status.message - and update.status.message.parts - ): - result_parts.extend( - part.root.text - for part in update.status.message.parts - if part.root.kind == "text" and part.root.text - ) + is_final_update = ( + process_status_update(update, result_parts) + if isinstance(update, TaskStatusUpdateEvent) + else False + ) if ( not is_final_update @@ -182,27 +366,68 @@ class StreamingHandler: a2a_task=a2a_task, new_messages=new_messages, agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, result_parts=result_parts, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, is_final=is_final_update, ) if final_result: break except A2AClientHTTPError as e: + if current_task_id: + logger.info( + "Stream interrupted with HTTP error, attempting recovery", + extra={ + "task_id": current_task_id, + "error": str(e), + "status_code": e.status_code, + }, + ) + recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"} + recovered_result = ( + await StreamingHandler._try_recover_from_interruption( + client=client, + task_id=current_task_id, + new_messages=new_messages, + agent_card=agent_card, + result_parts=result_parts, + **recovery_kwargs, + ) + ) + if recovered_result: + logger.info( + "Successfully recovered task after HTTP error", + extra={ + "task_id": current_task_id, + "status": str(recovered_result.get("status")), + }, + ) + return recovered_result + + logger.warning( + "Failed to recover from HTTP error, returning failure", + extra={ + "task_id": current_task_id, + "status_code": e.status_code, + "original_error": str(e), + }, + ) + error_msg = f"HTTP Error {e.status_code}: {e!s}" + error_type = "http_error" + status_code = 
e.status_code error_message = Message( role=Role.agent, message_id=str(uuid.uuid4()), parts=[Part(root=TextPart(text=error_msg))], - context_id=context_id, + context_id=params.context_id, task_id=task_id, ) new_messages.append(error_message) @@ -210,32 +435,118 @@ class StreamingHandler: crewai_event_bus.emit( agent_branch, A2AConnectionErrorEvent( - endpoint=endpoint or "", + endpoint=params.endpoint, error=str(e), - error_type="http_error", - status_code=e.status_code, - a2a_agent_name=a2a_agent_name, + error_type=error_type, + status_code=status_code, + a2a_agent_name=params.a2a_agent_name, operation="streaming", - context_id=context_id, + context_id=params.context_id, task_id=task_id, - from_task=from_task, - from_agent=from_agent, + from_task=params.from_task, + from_agent=params.from_agent, ), ) crewai_event_bus.emit( agent_branch, A2AResponseReceivedEvent( response=error_msg, - turn_number=turn_number, - context_id=context_id, - is_multiturn=is_multiturn, + turn_number=params.turn_number, + context_id=params.context_id, + is_multiturn=params.is_multiturn, status="failed", final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e: + error_type = type(e).__name__.lower() + if current_task_id: + logger.info( + f"Stream interrupted with {error_type}, attempting recovery", + extra={"task_id": current_task_id, "error": str(e)}, + ) + recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"} + recovered_result = ( + await StreamingHandler._try_recover_from_interruption( + client=client, + task_id=current_task_id, + 
new_messages=new_messages, + agent_card=agent_card, + result_parts=result_parts, + **recovery_kwargs, + ) + ) + if recovered_result: + logger.info( + f"Successfully recovered task after {error_type}", + extra={ + "task_id": current_task_id, + "status": str(recovered_result.get("status")), + }, + ) + return recovered_result + + logger.warning( + f"Failed to recover from {error_type}, returning failure", + extra={ + "task_id": current_task_id, + "error_type": error_type, + "original_error": str(e), + }, + ) + + error_msg = f"Connection error during streaming: {e!s}" + status_code = None + + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=params.context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=params.endpoint, + error=str(e), + error_type=error_type, + status_code=status_code, + a2a_agent_name=params.a2a_agent_name, + operation="streaming", + context_id=params.context_id, + task_id=task_id, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=params.turn_number, + context_id=params.context_id, + is_multiturn=params.is_multiturn, + status="failed", + final=True, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, ), ) return TaskStateResult( @@ -245,13 +556,23 @@ class StreamingHandler: ) except Exception as e: - error_msg = f"Unexpected error during streaming: {e!s}" + logger.exception( + "Unexpected error during streaming", + extra={ + "task_id": current_task_id, + "error_type": type(e).__name__, + "endpoint": params.endpoint, + }, + ) + error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}" + error_type = "unexpected_error" + 
status_code = None error_message = Message( role=Role.agent, message_id=str(uuid.uuid4()), parts=[Part(root=TextPart(text=error_msg))], - context_id=context_id, + context_id=params.context_id, task_id=task_id, ) new_messages.append(error_message) @@ -259,31 +580,32 @@ class StreamingHandler: crewai_event_bus.emit( agent_branch, A2AConnectionErrorEvent( - endpoint=endpoint or "", + endpoint=params.endpoint, error=str(e), - error_type="unexpected_error", - a2a_agent_name=a2a_agent_name, + error_type=error_type, + status_code=status_code, + a2a_agent_name=params.a2a_agent_name, operation="streaming", - context_id=context_id, + context_id=params.context_id, task_id=task_id, - from_task=from_task, - from_agent=from_agent, + from_task=params.from_task, + from_agent=params.from_agent, ), ) crewai_event_bus.emit( agent_branch, A2AResponseReceivedEvent( response=error_msg, - turn_number=turn_number, - context_id=context_id, - is_multiturn=is_multiturn, + turn_number=params.turn_number, + context_id=params.context_id, + is_multiturn=params.is_multiturn, status="failed", final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, ), ) return TaskStateResult( @@ -301,15 +623,15 @@ class StreamingHandler: crewai_event_bus.emit( agent_branch, A2AConnectionErrorEvent( - endpoint=endpoint or "", + endpoint=params.endpoint, error=str(close_error), error_type="stream_close_error", - a2a_agent_name=a2a_agent_name, + a2a_agent_name=params.a2a_agent_name, operation="stream_close", - context_id=context_id, + context_id=params.context_id, task_id=task_id, - from_task=from_task, - from_agent=from_agent, + from_task=params.from_task, + from_agent=params.from_agent, ), ) diff --git a/lib/crewai/src/crewai/a2a/updates/streaming/params.py 
b/lib/crewai/src/crewai/a2a/updates/streaming/params.py new file mode 100644 index 000000000..a4bf8c0a2 --- /dev/null +++ b/lib/crewai/src/crewai/a2a/updates/streaming/params.py @@ -0,0 +1,28 @@ +"""Common parameter extraction for streaming handlers.""" + +from __future__ import annotations + +from a2a.types import TaskStatusUpdateEvent + + +def process_status_update( + update: TaskStatusUpdateEvent, + result_parts: list[str], +) -> bool: + """Process a status update event and extract text parts. + + Args: + update: The status update event. + result_parts: List to append text parts to (modified in place). + + Returns: + True if this is a final update, False otherwise. + """ + is_final = update.final + if update.status and update.status.message and update.status.message.parts: + result_parts.extend( + part.root.text + for part in update.status.message.parts + if part.root.kind == "text" and part.root.text + ) + return is_final diff --git a/lib/crewai/src/crewai/a2a/utils/agent_card.py b/lib/crewai/src/crewai/a2a/utils/agent_card.py index a21bfefdb..c548cd1e7 100644 --- a/lib/crewai/src/crewai/a2a/utils/agent_card.py +++ b/lib/crewai/src/crewai/a2a/utils/agent_card.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio from collections.abc import MutableMapping from functools import lru_cache +import ssl import time from types import MethodType from typing import TYPE_CHECKING @@ -15,7 +16,7 @@ from aiocache import cached # type: ignore[import-untyped] from aiocache.serializers import PickleSerializer # type: ignore[import-untyped] import httpx -from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth +from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth from crewai.a2a.auth.utils import ( _auth_store, configure_auth_client, @@ -32,11 +33,51 @@ from crewai.events.types.a2a_events import ( if TYPE_CHECKING: - from crewai.a2a.auth.schemas import AuthScheme + from crewai.a2a.auth.client_schemes import ClientAuthScheme from 
crewai.agent import Agent from crewai.task import Task +def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str: + """Get TLS verify parameter from auth scheme. + + Args: + auth: Optional authentication scheme with TLS config. + + Returns: + SSL context, CA cert path, True for default verification, + or False if verification disabled. + """ + if auth and auth.tls: + return auth.tls.get_httpx_ssl_context() + return True + + +async def _prepare_auth_headers( + auth: ClientAuthScheme | None, + timeout: int, +) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]: + """Prepare authentication headers and TLS verification settings. + + Args: + auth: Optional authentication scheme. + timeout: Request timeout in seconds. + + Returns: + Tuple of (headers dict, TLS verify setting). + """ + headers: MutableMapping[str, str] = {} + verify = _get_tls_verify(auth) + if auth: + async with httpx.AsyncClient( + timeout=timeout, verify=verify + ) as temp_auth_client: + if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): + configure_auth_client(auth, temp_auth_client) + headers = await auth.apply_auth(temp_auth_client, {}) + return headers, verify + + def _get_server_config(agent: Agent) -> A2AServerConfig | None: """Get A2AServerConfig from an agent's a2a configuration. @@ -59,7 +100,7 @@ def _get_server_config(agent: Agent) -> A2AServerConfig | None: def fetch_agent_card( endpoint: str, - auth: AuthScheme | None = None, + auth: ClientAuthScheme | None = None, timeout: int = 30, use_cache: bool = True, cache_ttl: int = 300, @@ -68,7 +109,7 @@ def fetch_agent_card( Args: endpoint: A2A agent endpoint URL (AgentCard URL). - auth: Optional AuthScheme for authentication. + auth: Optional ClientAuthScheme for authentication. timeout: Request timeout in seconds. use_cache: Whether to use caching (default True). cache_ttl: Cache TTL in seconds (default 300 = 5 minutes). 
@@ -90,10 +131,10 @@ def fetch_agent_card( "_authorization_callback", } ) - auth_hash = hash((type(auth).__name__, auth_data)) + auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data) else: - auth_hash = 0 - _auth_store[auth_hash] = auth + auth_hash = _auth_store.compute_key("none", "") + _auth_store.set(auth_hash, auth) ttl_hash = int(time.time() // cache_ttl) return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash) @@ -109,7 +150,7 @@ def fetch_agent_card( async def afetch_agent_card( endpoint: str, - auth: AuthScheme | None = None, + auth: ClientAuthScheme | None = None, timeout: int = 30, use_cache: bool = True, ) -> AgentCard: @@ -119,7 +160,7 @@ async def afetch_agent_card( Args: endpoint: A2A agent endpoint URL (AgentCard URL). - auth: Optional AuthScheme for authentication. + auth: Optional ClientAuthScheme for authentication. timeout: Request timeout in seconds. use_cache: Whether to use caching (default True). @@ -140,10 +181,10 @@ async def afetch_agent_card( "_authorization_callback", } ) - auth_hash = hash((type(auth).__name__, auth_data)) + auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data) else: - auth_hash = 0 - _auth_store[auth_hash] = auth + auth_hash = _auth_store.compute_key("none", "") + _auth_store.set(auth_hash, auth) agent_card: AgentCard = await _afetch_agent_card_cached( endpoint, auth_hash, timeout ) @@ -155,7 +196,7 @@ async def afetch_agent_card( @lru_cache() def _fetch_agent_card_cached( endpoint: str, - auth_hash: int, + auth_hash: str, timeout: int, _ttl_hash: int, ) -> AgentCard: @@ -175,7 +216,7 @@ def _fetch_agent_card_cached( @cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator] async def _afetch_agent_card_cached( endpoint: str, - auth_hash: int, + auth_hash: str, timeout: int, ) -> AgentCard: """Cached async implementation of AgentCard fetching.""" @@ -185,7 +226,7 @@ async def _afetch_agent_card_cached( async def _afetch_agent_card_impl( endpoint: str, - 
auth: AuthScheme | None, + auth: ClientAuthScheme | None, timeout: int, ) -> AgentCard: """Internal async implementation of AgentCard fetching.""" @@ -197,16 +238,17 @@ async def _afetch_agent_card_impl( else: url_parts = endpoint.split("/", 3) base_url = f"{url_parts[0]}//{url_parts[2]}" - agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/" + agent_card_path = ( + f"/{url_parts[3]}" + if len(url_parts) > 3 and url_parts[3] + else "/.well-known/agent-card.json" + ) - headers: MutableMapping[str, str] = {} - if auth: - async with httpx.AsyncClient(timeout=timeout) as temp_auth_client: - if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): - configure_auth_client(auth, temp_auth_client) - headers = await auth.apply_auth(temp_auth_client, {}) + headers, verify = await _prepare_auth_headers(auth, timeout) - async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client: + async with httpx.AsyncClient( + timeout=timeout, headers=headers, verify=verify + ) as temp_client: if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): configure_auth_client(auth, temp_client) @@ -434,6 +476,7 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard: """Generate an A2A AgentCard from an Agent instance. Uses A2AServerConfig values when available, falling back to agent properties. + If signing_config is provided, the card will be signed with JWS. Args: agent: The Agent instance to generate a card for. @@ -442,6 +485,8 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard: Returns: AgentCard describing the agent's capabilities. 
""" + from crewai.a2a.utils.agent_card_signing import sign_agent_card + server_config = _get_server_config(agent) or A2AServerConfig() name = server_config.name or agent.role @@ -472,15 +517,31 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard: ) ) - return AgentCard( + capabilities = server_config.capabilities + if server_config.server_extensions: + from crewai.a2a.extensions.server import ServerExtensionRegistry + + registry = ServerExtensionRegistry(server_config.server_extensions) + ext_list = registry.get_agent_extensions() + + existing_exts = list(capabilities.extensions) if capabilities.extensions else [] + existing_uris = {e.uri for e in existing_exts} + for ext in ext_list: + if ext.uri not in existing_uris: + existing_exts.append(ext) + + capabilities = capabilities.model_copy(update={"extensions": existing_exts}) + + card = AgentCard( name=name, description=description, url=server_config.url or url, version=server_config.version, - capabilities=server_config.capabilities, + capabilities=capabilities, default_input_modes=server_config.default_input_modes, default_output_modes=server_config.default_output_modes, skills=skills, + preferred_transport=server_config.transport.preferred, protocol_version=server_config.protocol_version, provider=server_config.provider, documentation_url=server_config.documentation_url, @@ -489,9 +550,21 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard: security=server_config.security, security_schemes=server_config.security_schemes, supports_authenticated_extended_card=server_config.supports_authenticated_extended_card, - signatures=server_config.signatures, ) + if server_config.signing_config: + signature = sign_agent_card( + card, + private_key=server_config.signing_config.get_private_key(), + key_id=server_config.signing_config.key_id, + algorithm=server_config.signing_config.algorithm, + ) + card = card.model_copy(update={"signatures": [signature]}) + elif server_config.signatures: + card = 
card.model_copy(update={"signatures": server_config.signatures}) + + return card + def inject_a2a_server_methods(agent: Agent) -> None: """Inject A2A server methods onto an Agent instance. diff --git a/lib/crewai/src/crewai/a2a/utils/agent_card_signing.py b/lib/crewai/src/crewai/a2a/utils/agent_card_signing.py new file mode 100644 index 000000000..d869020af --- /dev/null +++ b/lib/crewai/src/crewai/a2a/utils/agent_card_signing.py @@ -0,0 +1,236 @@ +"""AgentCard JWS signing utilities. + +This module provides functions for signing and verifying AgentCards using +JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients +to verify the authenticity and integrity of agent card information. + +Example: + >>> from crewai.a2a.utils.agent_card_signing import sign_agent_card + >>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1") + >>> card_with_sig = card.model_copy(update={"signatures": [signature]}) +""" + +from __future__ import annotations + +import base64 +import json +import logging +from typing import Any, Literal + +from a2a.types import AgentCard, AgentCardSignature +import jwt +from pydantic import SecretStr + + +logger = logging.getLogger(__name__) + + +SigningAlgorithm = Literal[ + "RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512" +] + + +def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes: + """Normalize private key to bytes format. + + Args: + private_key: PEM-encoded private key as string, bytes, or SecretStr. + + Returns: + Private key as bytes. + """ + if isinstance(private_key, SecretStr): + private_key = private_key.get_secret_value() + if isinstance(private_key, str): + private_key = private_key.encode() + return private_key + + +def _serialize_agent_card(agent_card: AgentCard) -> str: + """Serialize AgentCard to canonical JSON for signing. + + Excludes the signatures field to avoid circular reference during signing. 
+ Uses sorted keys and compact separators for deterministic output. + + Args: + agent_card: The AgentCard to serialize. + + Returns: + Canonical JSON string representation. + """ + card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True) + return json.dumps(card_dict, sort_keys=True, separators=(",", ":")) + + +def _base64url_encode(data: bytes | str) -> str: + """Encode data to URL-safe base64 without padding. + + Args: + data: Data to encode. + + Returns: + URL-safe base64 encoded string without padding. + """ + if isinstance(data, str): + data = data.encode() + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def sign_agent_card( + agent_card: AgentCard, + private_key: str | bytes | SecretStr, + key_id: str | None = None, + algorithm: SigningAlgorithm = "RS256", +) -> AgentCardSignature: + """Sign an AgentCard using JWS (RFC 7515). + + Creates a detached JWS signature for the AgentCard. The signature covers + all fields except the signatures field itself. + + Args: + agent_card: The AgentCard to sign. + private_key: PEM-encoded private key (RSA, EC, or RSA-PSS). + key_id: Optional key identifier for the JWS header (kid claim). + algorithm: Signing algorithm (RS256, ES256, PS256, etc.). + + Returns: + AgentCardSignature with protected header and signature. + + Raises: + jwt.exceptions.InvalidKeyError: If the private key is invalid. + ValueError: If the algorithm is not supported for the key type. + + Example: + >>> signature = sign_agent_card( + ... agent_card, + ... private_key_pem="-----BEGIN PRIVATE KEY-----...", + ... key_id="my-key-id", + ... 
) + """ + key_bytes = _normalize_private_key(private_key) + payload = _serialize_agent_card(agent_card) + + protected_header: dict[str, Any] = {"typ": "JWS"} + if key_id: + protected_header["kid"] = key_id + + jws_token = jwt.api_jws.encode( + payload.encode(), + key_bytes, + algorithm=algorithm, + headers=protected_header, + ) + + parts = jws_token.split(".") + protected_b64 = parts[0] + signature_b64 = parts[2] + + header: dict[str, Any] | None = None + if key_id: + header = {"kid": key_id} + + return AgentCardSignature( + protected=protected_b64, + signature=signature_b64, + header=header, + ) + + +def verify_agent_card_signature( + agent_card: AgentCard, + signature: AgentCardSignature, + public_key: str | bytes, + algorithms: list[str] | None = None, +) -> bool: + """Verify an AgentCard JWS signature. + + Validates that the signature was created with the corresponding private key + and that the AgentCard content has not been modified. + + Args: + agent_card: The AgentCard to verify. + signature: The AgentCardSignature to validate. + public_key: PEM-encoded public key (RSA, EC, or RSA-PSS). + algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms. + + Returns: + True if signature is valid, False otherwise. + + Example: + >>> is_valid = verify_agent_card_signature( + ... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..." + ... 
def verify_agent_card_signature(
    agent_card: AgentCard,
    signature: AgentCardSignature,
    public_key: str | bytes,
    algorithms: list[str] | None = None,
) -> bool:
    """Verify an AgentCard JWS signature.

    Confirms the signature was made with the corresponding private key and
    that the AgentCard content is unmodified.

    Args:
        agent_card: The AgentCard to verify.
        signature: The AgentCardSignature to validate.
        public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
        algorithms: Allowed algorithms; defaults to common asymmetric ones.

    Returns:
        True if the signature is valid, False otherwise.

    Example:
        >>> is_valid = verify_agent_card_signature(
        ...     agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..."
        ... )
    """
    allowed = algorithms if algorithms is not None else [
        "RS256",
        "RS384",
        "RS512",
        "ES256",
        "ES384",
        "ES512",
        "PS256",
        "PS384",
        "PS512",
    ]

    key_bytes = public_key.encode() if isinstance(public_key, str) else public_key

    # Rebuild the compact JWS from the card's canonical payload so the
    # signature check covers the card content we actually hold.
    payload_b64 = _base64url_encode(_serialize_agent_card(agent_card))
    token = f"{signature.protected}.{payload_b64}.{signature.signature}"

    try:
        jwt.api_jws.decode(
            token,
            key_bytes,
            algorithms=allowed,
        )
    except jwt.InvalidSignatureError:
        logger.debug(
            "AgentCard signature verification failed",
            extra={"reason": "invalid_signature"},
        )
        return False
    except jwt.DecodeError as e:
        logger.debug(
            "AgentCard signature verification failed",
            extra={"reason": "decode_error", "error": str(e)},
        )
        return False
    except jwt.InvalidAlgorithmError as e:
        logger.debug(
            "AgentCard signature verification failed",
            extra={"reason": "algorithm_error", "error": str(e)},
        )
        return False
    return True


def get_key_id_from_signature(signature: AgentCardSignature) -> str | None:
    """Extract the key ID (kid) from an AgentCardSignature.

    Looks in the unprotected header first, then falls back to decoding the
    protected (base64url JSON) header.

    Args:
        signature: The AgentCardSignature to extract from.

    Returns:
        The key ID if present, None otherwise.
    """
    if signature.header and "kid" in signature.header:
        kid: str = signature.header["kid"]
        return kid

    raw = signature.protected
    # base64url in JWS is unpadded; restore padding before decoding.
    remainder = len(raw) % 4
    if remainder:
        raw += "=" * (4 - remainder)

    try:
        decoded = base64.urlsafe_b64decode(raw).decode()
        header: dict[str, Any] = json.loads(decoded)
    except (ValueError, json.JSONDecodeError):
        return None
    return header.get("kid")
NegotiatedContentTypes: + """Result of content type negotiation.""" + + input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"] + output_modes: Annotated[ + list[str], "Negotiated output MIME types the server will produce" + ] + effective_input_modes: Annotated[list[str], "Server's effective input modes"] + effective_output_modes: Annotated[list[str], "Server's effective output modes"] + skill_name: Annotated[ + str | None, "Skill name if negotiation was skill-specific" + ] = None + + +class ContentTypeNegotiationError(Exception): + """Raised when no compatible content types can be negotiated.""" + + def __init__( + self, + client_input_modes: list[str], + client_output_modes: list[str], + server_input_modes: list[str], + server_output_modes: list[str], + direction: str = "both", + message: str | None = None, + ) -> None: + self.client_input_modes = client_input_modes + self.client_output_modes = client_output_modes + self.server_input_modes = server_input_modes + self.server_output_modes = server_output_modes + self.direction = direction + + if message is None: + if direction == "input": + message = ( + f"No compatible input content types. " + f"Client supports: {client_input_modes}, " + f"Server accepts: {server_input_modes}" + ) + elif direction == "output": + message = ( + f"No compatible output content types. " + f"Client accepts: {client_output_modes}, " + f"Server produces: {server_output_modes}" + ) + else: + message = ( + f"No compatible content types. " + f"Input - Client: {client_input_modes}, Server: {server_input_modes}. " + f"Output - Client: {client_output_modes}, Server: {server_output_modes}" + ) + + super().__init__(message) + + +def _normalize_mime_type(mime_type: str) -> str: + """Normalize MIME type for comparison (lowercase, strip whitespace).""" + return mime_type.lower().strip() + + +def _mime_types_compatible(client_type: str, server_type: str) -> bool: + """Check if two MIME types are compatible. 
+ + Handles wildcards like image/* matching image/png. + """ + client_normalized = _normalize_mime_type(client_type) + server_normalized = _normalize_mime_type(server_type) + + if client_normalized == server_normalized: + return True + + if "*" in client_normalized or "*" in server_normalized: + client_parts = client_normalized.split("/") + server_parts = server_normalized.split("/") + + if len(client_parts) == 2 and len(server_parts) == 2: + type_match = ( + client_parts[0] == server_parts[0] + or client_parts[0] == "*" + or server_parts[0] == "*" + ) + subtype_match = ( + client_parts[1] == server_parts[1] + or client_parts[1] == "*" + or server_parts[1] == "*" + ) + return type_match and subtype_match + + return False + + +def _find_compatible_modes( + client_modes: list[str], server_modes: list[str] +) -> list[str]: + """Find compatible MIME types between client and server. + + Returns modes in client preference order. + """ + compatible = [] + for client_mode in client_modes: + for server_mode in server_modes: + if _mime_types_compatible(client_mode, server_mode): + if "*" in client_mode and "*" not in server_mode: + if server_mode not in compatible: + compatible.append(server_mode) + else: + if client_mode not in compatible: + compatible.append(client_mode) + break + return compatible + + +def _get_effective_modes( + agent_card: AgentCard, + skill_name: str | None = None, +) -> tuple[list[str], list[str], AgentSkill | None]: + """Get effective input/output modes from agent card. + + If skill_name is provided and the skill has custom modes, those are used. + Otherwise, falls back to agent card defaults. 
+ """ + skill: AgentSkill | None = None + + if skill_name and agent_card.skills: + for s in agent_card.skills: + if s.name == skill_name or s.id == skill_name: + skill = s + break + + if skill: + input_modes = ( + skill.input_modes if skill.input_modes else agent_card.default_input_modes + ) + output_modes = ( + skill.output_modes + if skill.output_modes + else agent_card.default_output_modes + ) + else: + input_modes = agent_card.default_input_modes + output_modes = agent_card.default_output_modes + + return input_modes, output_modes, skill + + +def negotiate_content_types( + agent_card: AgentCard, + client_input_modes: list[str] | None = None, + client_output_modes: list[str] | None = None, + skill_name: str | None = None, + emit_event: bool = True, + endpoint: str | None = None, + a2a_agent_name: str | None = None, + strict: bool = False, +) -> NegotiatedContentTypes: + """Negotiate content types between client and server. + + Args: + agent_card: The remote agent's card with capability info. + client_input_modes: MIME types the client can send. Defaults to text/plain and application/json. + client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json. + skill_name: Optional skill to use for mode lookup. + emit_event: Whether to emit a content type negotiation event. + endpoint: Agent endpoint (for event metadata). + a2a_agent_name: Agent name (for event metadata). + strict: If True, raises error when no compatible types found. + If False, returns empty lists for incompatible directions. + + Returns: + NegotiatedContentTypes with compatible input and output modes. + + Raises: + ContentTypeNegotiationError: If strict=True and no compatible types found. 
+ """ + if client_input_modes is None: + client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy()) + if client_output_modes is None: + client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy()) + + server_input_modes, server_output_modes, skill = _get_effective_modes( + agent_card, skill_name + ) + + compatible_input = _find_compatible_modes(client_input_modes, server_input_modes) + compatible_output = _find_compatible_modes(client_output_modes, server_output_modes) + + if strict: + if not compatible_input and not compatible_output: + raise ContentTypeNegotiationError( + client_input_modes=client_input_modes, + client_output_modes=client_output_modes, + server_input_modes=server_input_modes, + server_output_modes=server_output_modes, + ) + if not compatible_input: + raise ContentTypeNegotiationError( + client_input_modes=client_input_modes, + client_output_modes=client_output_modes, + server_input_modes=server_input_modes, + server_output_modes=server_output_modes, + direction="input", + ) + if not compatible_output: + raise ContentTypeNegotiationError( + client_input_modes=client_input_modes, + client_output_modes=client_output_modes, + server_input_modes=server_input_modes, + server_output_modes=server_output_modes, + direction="output", + ) + + result = NegotiatedContentTypes( + input_modes=compatible_input, + output_modes=compatible_output, + effective_input_modes=server_input_modes, + effective_output_modes=server_output_modes, + skill_name=skill.name if skill else None, + ) + + if emit_event: + crewai_event_bus.emit( + None, + A2AContentTypeNegotiatedEvent( + endpoint=endpoint or agent_card.url, + a2a_agent_name=a2a_agent_name or agent_card.name, + skill_name=skill_name, + client_input_modes=client_input_modes, + client_output_modes=client_output_modes, + server_input_modes=server_input_modes, + server_output_modes=server_output_modes, + negotiated_input_modes=compatible_input, + negotiated_output_modes=compatible_output, + 
negotiation_success=bool(compatible_input and compatible_output), + ), + ) + + return result + + +def validate_content_type( + content_type: str, + allowed_modes: list[str], +) -> bool: + """Validate that a content type is allowed by a list of modes. + + Args: + content_type: The MIME type to validate. + allowed_modes: List of allowed MIME types (may include wildcards). + + Returns: + True if content_type is compatible with any allowed mode. + """ + for mode in allowed_modes: + if _mime_types_compatible(content_type, mode): + return True + return False + + +def get_part_content_type(part: Part) -> str: + """Extract MIME type from an A2A Part. + + Args: + part: A Part object containing TextPart, DataPart, or FilePart. + + Returns: + The MIME type string for this part. + """ + root = part.root + if root.kind == "text": + return TEXT_PLAIN + if root.kind == "data": + return APPLICATION_JSON + if root.kind == "file": + return root.file.mime_type or APPLICATION_OCTET_STREAM + return APPLICATION_OCTET_STREAM + + +def validate_message_parts( + parts: list[Part], + allowed_modes: list[str], +) -> list[str]: + """Validate that all message parts have allowed content types. + + Args: + parts: List of Parts from the incoming message. + allowed_modes: List of allowed MIME types (from default_input_modes). + + Returns: + List of invalid content types found (empty if all valid). 
+ """ + invalid_types: list[str] = [] + for part in parts: + content_type = get_part_content_type(part) + if not validate_content_type(content_type, allowed_modes): + if content_type not in invalid_types: + invalid_types.append(content_type) + return invalid_types diff --git a/lib/crewai/src/crewai/a2a/utils/delegation.py b/lib/crewai/src/crewai/a2a/utils/delegation.py index f322bbf74..b2315c13f 100644 --- a/lib/crewai/src/crewai/a2a/utils/delegation.py +++ b/lib/crewai/src/crewai/a2a/utils/delegation.py @@ -3,9 +3,10 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncIterator, MutableMapping +from collections.abc import AsyncIterator, Callable, MutableMapping from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Any, Literal +import logging +from typing import TYPE_CHECKING, Any, Final, Literal import uuid from a2a.client import Client, ClientConfig, ClientFactory @@ -20,18 +21,24 @@ from a2a.types import ( import httpx from pydantic import BaseModel -from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth +from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth from crewai.a2a.auth.utils import ( _auth_store, configure_auth_client, validate_auth_against_agent_card, ) +from crewai.a2a.config import ClientTransportConfig, GRPCClientConfig +from crewai.a2a.extensions.registry import ( + ExtensionsMiddleware, + validate_required_extensions, +) from crewai.a2a.task_helpers import TaskStateResult from crewai.a2a.types import ( HANDLER_REGISTRY, HandlerType, PartsDict, PartsMetadataDict, + TransportType, ) from crewai.a2a.updates import ( PollingConfig, @@ -39,7 +46,20 @@ from crewai.a2a.updates import ( StreamingHandler, UpdateConfig, ) -from crewai.a2a.utils.agent_card import _afetch_agent_card_cached +from crewai.a2a.utils.agent_card import ( + _afetch_agent_card_cached, + _get_tls_verify, + _prepare_auth_headers, +) +from crewai.a2a.utils.content_type import ( + 
DEFAULT_CLIENT_OUTPUT_MODES, + negotiate_content_types, +) +from crewai.a2a.utils.transport import ( + NegotiatedTransport, + TransportNegotiationError, + negotiate_transport, +) from crewai.events.event_bus import crewai_event_bus from crewai.events.types.a2a_events import ( A2AConversationStartedEvent, @@ -49,10 +69,16 @@ from crewai.events.types.a2a_events import ( ) +logger = logging.getLogger(__name__) + + if TYPE_CHECKING: from a2a.types import Message - from crewai.a2a.auth.schemas import AuthScheme + from crewai.a2a.auth.client_schemes import ClientAuthScheme + + +_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC" def get_handler(config: UpdateConfig | None) -> HandlerType: @@ -71,8 +97,7 @@ def get_handler(config: UpdateConfig | None) -> HandlerType: def execute_a2a_delegation( endpoint: str, - transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"], - auth: AuthScheme | None, + auth: ClientAuthScheme | None, timeout: int, task_description: str, context: str | None = None, @@ -91,32 +116,23 @@ def execute_a2a_delegation( from_task: Any | None = None, from_agent: Any | None = None, skill_id: str | None = None, + client_extensions: list[str] | None = None, + transport: ClientTransportConfig | None = None, + accepted_output_modes: list[str] | None = None, ) -> TaskStateResult: """Execute a task delegation to a remote A2A agent synchronously. - This is the sync wrapper around aexecute_a2a_delegation. For async contexts, - use aexecute_a2a_delegation directly. + WARNING: This function blocks the entire thread by creating and running a new + event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts + for better performance and resource efficiency. + + This is a synchronous wrapper around aexecute_a2a_delegation that creates a + new event loop to run the async implementation. It is provided for compatibility + with synchronous code paths only. 
Args: - endpoint: A2A agent endpoint URL (AgentCard URL) - transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json) - auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest) - timeout: Request timeout in seconds - task_description: The task to delegate - context: Optional context information - context_id: Context ID for correlating messages/tasks - task_id: Specific task identifier - reference_task_ids: List of related task IDs - metadata: Additional metadata (external_id, request_id, etc.) - extensions: Protocol extensions for custom fields - conversation_history: Previous Message objects from conversation - agent_id: Agent identifier for logging - agent_role: Role of the CrewAI agent delegating the task - agent_branch: Optional agent tree branch for logging - response_model: Optional Pydantic model for structured outputs - turn_number: Optional turn number for multi-turn conversations - endpoint: A2A agent endpoint URL. - auth: Optional AuthScheme for authentication. + endpoint: A2A agent endpoint URL (AgentCard URL). + auth: Optional ClientAuthScheme for authentication. timeout: Request timeout in seconds. task_description: The task to delegate. context: Optional context information. @@ -135,10 +151,26 @@ def execute_a2a_delegation( from_task: Optional CrewAI Task object for event metadata. from_agent: Optional CrewAI Agent object for event metadata. skill_id: Optional skill ID to target a specific agent capability. + client_extensions: A2A protocol extension URIs the client supports. + transport: Transport configuration (preferred, supported transports, gRPC settings). + accepted_output_modes: MIME types the client can accept in responses. Returns: TaskStateResult with status, result/error, history, and agent_card. + + Raises: + RuntimeError: If called from an async context with a running event loop. 
""" + try: + asyncio.get_running_loop() + raise RuntimeError( + "execute_a2a_delegation() cannot be called from an async context. " + "Use 'await aexecute_a2a_delegation()' instead." + ) + except RuntimeError as e: + if "no running event loop" not in str(e).lower(): + raise + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: @@ -159,12 +191,14 @@ def execute_a2a_delegation( agent_role=agent_role, agent_branch=agent_branch, response_model=response_model, - transport_protocol=transport_protocol, turn_number=turn_number, updates=updates, from_task=from_task, from_agent=from_agent, skill_id=skill_id, + client_extensions=client_extensions, + transport=transport, + accepted_output_modes=accepted_output_modes, ) ) finally: @@ -176,8 +210,7 @@ def execute_a2a_delegation( async def aexecute_a2a_delegation( endpoint: str, - transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"], - auth: AuthScheme | None, + auth: ClientAuthScheme | None, timeout: int, task_description: str, context: str | None = None, @@ -196,6 +229,9 @@ async def aexecute_a2a_delegation( from_task: Any | None = None, from_agent: Any | None = None, skill_id: str | None = None, + client_extensions: list[str] | None = None, + transport: ClientTransportConfig | None = None, + accepted_output_modes: list[str] | None = None, ) -> TaskStateResult: """Execute a task delegation to a remote A2A agent asynchronously. @@ -203,25 +239,8 @@ async def aexecute_a2a_delegation( in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()). 
Args: - endpoint: A2A agent endpoint URL - transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json) - auth: Optional AuthScheme for authentication - timeout: Request timeout in seconds - task_description: Task to delegate - context: Optional context - context_id: Context ID for correlation - task_id: Specific task identifier - reference_task_ids: Related task IDs - metadata: Additional metadata - extensions: Protocol extensions - conversation_history: Previous Message objects - turn_number: Current turn number - agent_branch: Agent tree branch for logging - agent_id: Agent identifier for logging - agent_role: Agent role for logging - response_model: Optional Pydantic model for structured outputs endpoint: A2A agent endpoint URL. - auth: Optional AuthScheme for authentication. + auth: Optional ClientAuthScheme for authentication. timeout: Request timeout in seconds. task_description: The task to delegate. context: Optional context information. @@ -240,6 +259,9 @@ async def aexecute_a2a_delegation( from_task: Optional CrewAI Task object for event metadata. from_agent: Optional CrewAI Agent object for event metadata. skill_id: Optional skill ID to target a specific agent capability. + client_extensions: A2A protocol extension URIs the client supports. + transport: Transport configuration (preferred, supported transports, gRPC settings). + accepted_output_modes: MIME types the client can accept in responses. Returns: TaskStateResult with status, result/error, history, and agent_card. 
@@ -271,10 +293,12 @@ async def aexecute_a2a_delegation( agent_role=agent_role, response_model=response_model, updates=updates, - transport_protocol=transport_protocol, from_task=from_task, from_agent=from_agent, skill_id=skill_id, + client_extensions=client_extensions, + transport=transport, + accepted_output_modes=accepted_output_modes, ) except Exception as e: crewai_event_bus.emit( @@ -294,7 +318,7 @@ async def aexecute_a2a_delegation( ) raise - agent_card_data: dict[str, Any] = result.get("agent_card") or {} + agent_card_data = result.get("agent_card") crewai_event_bus.emit( agent_branch, A2ADelegationCompletedEvent( @@ -306,7 +330,7 @@ async def aexecute_a2a_delegation( endpoint=endpoint, a2a_agent_name=result.get("a2a_agent_name"), agent_card=agent_card_data, - provider=agent_card_data.get("provider"), + provider=agent_card_data.get("provider") if agent_card_data else None, metadata=metadata, extensions=list(extensions.keys()) if extensions else None, from_task=from_task, @@ -319,8 +343,7 @@ async def aexecute_a2a_delegation( async def _aexecute_a2a_delegation_impl( endpoint: str, - transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"], - auth: AuthScheme | None, + auth: ClientAuthScheme | None, timeout: int, task_description: str, context: str | None, @@ -340,8 +363,13 @@ async def _aexecute_a2a_delegation_impl( from_task: Any | None = None, from_agent: Any | None = None, skill_id: str | None = None, + client_extensions: list[str] | None = None, + transport: ClientTransportConfig | None = None, + accepted_output_modes: list[str] | None = None, ) -> TaskStateResult: """Internal async implementation of A2A delegation.""" + if transport is None: + transport = ClientTransportConfig() if auth: auth_data = auth.model_dump_json( exclude={ @@ -351,22 +379,70 @@ async def _aexecute_a2a_delegation_impl( "_authorization_callback", } ) - auth_hash = hash((type(auth).__name__, auth_data)) + auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data) else: 
- auth_hash = 0 - _auth_store[auth_hash] = auth + auth_hash = _auth_store.compute_key("none", endpoint) + _auth_store.set(auth_hash, auth) agent_card = await _afetch_agent_card_cached( endpoint=endpoint, auth_hash=auth_hash, timeout=timeout ) validate_auth_against_agent_card(agent_card, auth) - headers: MutableMapping[str, str] = {} - if auth: - async with httpx.AsyncClient(timeout=timeout) as temp_auth_client: - if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): - configure_auth_client(auth, temp_auth_client) - headers = await auth.apply_auth(temp_auth_client, {}) + unsupported_exts = validate_required_extensions(agent_card, client_extensions) + if unsupported_exts: + ext_uris = [ext.uri for ext in unsupported_exts] + raise ValueError( + f"Agent requires extensions not supported by client: {ext_uris}" + ) + + negotiated: NegotiatedTransport | None = None + effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT + effective_url = endpoint + + client_transports: list[str] = ( + list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT] + ) + + try: + negotiated = negotiate_transport( + agent_card=agent_card, + client_supported_transports=client_transports, + client_preferred_transport=transport.preferred, + endpoint=endpoint, + a2a_agent_name=agent_card.name, + ) + effective_transport = negotiated.transport # type: ignore[assignment] + effective_url = negotiated.url + except TransportNegotiationError as e: + logger.warning( + "Transport negotiation failed, using fallback", + extra={ + "error": str(e), + "fallback_transport": effective_transport, + "fallback_url": effective_url, + "endpoint": endpoint, + "client_transports": client_transports, + "server_transports": [ + iface.transport for iface in agent_card.additional_interfaces or [] + ] + + [agent_card.preferred_transport or "JSONRPC"], + }, + ) + + effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy() + + content_negotiated = 
negotiate_content_types( + agent_card=agent_card, + client_output_modes=accepted_output_modes, + skill_name=skill_id, + endpoint=endpoint, + a2a_agent_name=agent_card.name, + ) + if content_negotiated.output_modes: + effective_output_modes = content_negotiated.output_modes + + headers, _ = await _prepare_auth_headers(auth, timeout) a2a_agent_name = None if agent_card.name: @@ -513,15 +589,22 @@ async def _aexecute_a2a_delegation_impl( use_streaming = not use_polling and push_config_for_client is None + client_agent_card = agent_card + if effective_url != agent_card.url: + client_agent_card = agent_card.model_copy(update={"url": effective_url}) + async with _create_a2a_client( - agent_card=agent_card, - transport_protocol=transport_protocol, + agent_card=client_agent_card, + transport_protocol=effective_transport, timeout=timeout, headers=headers, streaming=use_streaming, auth=auth, use_polling=use_polling, push_notification_config=push_config_for_client, + client_extensions=client_extensions, + accepted_output_modes=effective_output_modes, # type: ignore[arg-type] + grpc_config=transport.grpc, ) as client: result = await handler.execute( client=client, @@ -535,6 +618,245 @@ async def _aexecute_a2a_delegation_impl( return result +def _normalize_grpc_metadata( + metadata: tuple[tuple[str, str], ...] | None, +) -> tuple[tuple[str, str], ...] | None: + """Lowercase all gRPC metadata keys. + + gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK) + use mixed-case headers like 'X-A2A-Extensions'. This normalizes them. + """ + if metadata is None: + return None + return tuple((key.lower(), value) for key, value in metadata) + + +def _create_grpc_interceptors( + auth_metadata: list[tuple[str, str]] | None = None, +) -> list[Any]: + """Create gRPC interceptors for metadata normalization and auth injection. + + Args: + auth_metadata: Optional auth metadata to inject into all calls. 
+ Used for insecure channels that need auth (non-localhost without TLS). + + Returns a list of interceptors that lowercase metadata keys for gRPC + compatibility. Must be called after grpc is imported. + """ + import grpc.aio # type: ignore[import-untyped] + + def _merge_metadata( + existing: tuple[tuple[str, str], ...] | None, + auth: list[tuple[str, str]] | None, + ) -> tuple[tuple[str, str], ...] | None: + """Merge existing metadata with auth metadata and normalize keys.""" + merged: list[tuple[str, str]] = [] + if existing: + merged.extend(existing) + if auth: + merged.extend(auth) + if not merged: + return None + return tuple((key.lower(), value) for key, value in merged) + + def _inject_metadata(client_call_details: Any) -> Any: + """Inject merged metadata into call details.""" + return client_call_details._replace( + metadata=_merge_metadata(client_call_details.metadata, auth_metadata) + ) + + class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported] + """Interceptor for unary-unary calls that injects auth metadata.""" + + async def intercept_unary_unary( # type: ignore[no-untyped-def] + self, continuation, client_call_details, request + ): + """Intercept unary-unary call and inject metadata.""" + return await continuation(_inject_metadata(client_call_details), request) + + class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported] + """Interceptor for unary-stream calls that injects auth metadata.""" + + async def intercept_unary_stream( # type: ignore[no-untyped-def] + self, continuation, client_call_details, request + ): + """Intercept unary-stream call and inject metadata.""" + return await continuation(_inject_metadata(client_call_details), request) + + class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported] + """Interceptor for stream-unary calls that injects auth metadata.""" + + async def intercept_stream_unary( 
def _create_grpc_channel_factory(
    grpc_config: GRPCClientConfig,
    auth: ClientAuthScheme | None = None,
) -> Callable[[str], Any]:
    """Create a gRPC channel factory with the given configuration.

    Args:
        grpc_config: gRPC client configuration with channel options.
        auth: Optional ClientAuthScheme for TLS and auth configuration.

    Returns:
        A callable that creates gRPC channels from URLs.

    Raises:
        ImportError: If grpcio is not installed.
        ValueError: If the auth scheme cannot be represented as gRPC metadata
            (HTTPDigestAuth, or APIKeyAuth in query/cookie locations).
    """
    # grpc is an optional dependency; import lazily so non-gRPC transports
    # work without grpcio installed.
    try:
        import grpc
    except ImportError as e:
        raise ImportError(
            "gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
        ) from e

    # Auth schemes are translated to static metadata pairs attached to calls.
    auth_metadata: list[tuple[str, str]] = []

    if auth is not None:
        # Local import avoids a circular import with the auth module.
        from crewai.a2a.auth.client_schemes import (
            APIKeyAuth,
            BearerTokenAuth,
            HTTPBasicAuth,
            HTTPDigestAuth,
            OAuth2AuthorizationCode,
            OAuth2ClientCredentials,
        )

        # Reject schemes that rely on HTTP-only mechanics before building
        # any metadata.
        if isinstance(auth, HTTPDigestAuth):
            raise ValueError(
                "HTTPDigestAuth is not supported with gRPC transport. "
                "Digest authentication requires HTTP challenge-response flow. "
                "Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
            )
        if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
            raise ValueError(
                f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
                "gRPC only supports header-based authentication. "
                "Use APIKeyAuth with location='header' instead."
            )

        if isinstance(auth, BearerTokenAuth):
            auth_metadata.append(("authorization", f"Bearer {auth.token}"))
        elif isinstance(auth, HTTPBasicAuth):
            import base64

            basic_credentials = f"{auth.username}:{auth.password}"
            encoded = base64.b64encode(basic_credentials.encode()).decode()
            auth_metadata.append(("authorization", f"Basic {encoded}"))
        elif isinstance(auth, APIKeyAuth) and auth.location == "header":
            # gRPC metadata keys must be lowercase.
            header_name = auth.name.lower()
            auth_metadata.append((header_name, auth.api_key))
        elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
            # NOTE(review): reads the private _access_token attribute —
            # assumes the token was fetched earlier; no refresh happens here.
            # TODO confirm token lifecycle for long-lived channels.
            if auth._access_token:
                auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))

    def factory(url: str) -> Any:
        """Create a gRPC channel for the given URL."""
        target = url
        use_tls = False

        # Strip scheme prefixes; grpcs:// and https:// imply TLS.
        if url.startswith("grpcs://"):
            target = url[8:]
            use_tls = True
        elif url.startswith("grpc://"):
            target = url[7:]
        elif url.startswith("https://"):
            target = url[8:]
            use_tls = True
        elif url.startswith("http://"):
            target = url[7:]

        # Only explicitly configured limits/keepalive values are passed on;
        # everything else uses gRPC defaults.
        options: list[tuple[str, Any]] = []
        if grpc_config.max_send_message_length is not None:
            options.append(
                ("grpc.max_send_message_length", grpc_config.max_send_message_length)
            )
        if grpc_config.max_receive_message_length is not None:
            options.append(
                (
                    "grpc.max_receive_message_length",
                    grpc_config.max_receive_message_length,
                )
            )
        if grpc_config.keepalive_time_ms is not None:
            options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
        if grpc_config.keepalive_timeout_ms is not None:
            options.append(
                ("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
            )

        # Explicit client TLS config (mTLS) wins over scheme-implied TLS.
        channel_credentials = None
        if auth and hasattr(auth, "tls") and auth.tls:
            channel_credentials = auth.tls.get_grpc_credentials()
        elif use_tls:
            channel_credentials = grpc.ssl_channel_credentials()

        if channel_credentials and auth_metadata:
            # Secure channel with auth: attach auth metadata as per-call
            # credentials composed with the channel credentials.

            class AuthMetadataPlugin(grpc.AuthMetadataPlugin):  # type: ignore[misc,no-any-unimported]
                """gRPC auth metadata plugin that adds auth headers as metadata."""

                def __init__(self, metadata: list[tuple[str, str]]) -> None:
                    self._metadata = tuple(metadata)

                def __call__(  # type: ignore[no-any-unimported]
                    self,
                    context: grpc.AuthMetadataContext,
                    callback: grpc.AuthMetadataPluginCallback,
                ) -> None:
                    callback(self._metadata, None)

            call_creds = grpc.metadata_call_credentials(
                AuthMetadataPlugin(auth_metadata)
            )
            credentials = grpc.composite_channel_credentials(
                channel_credentials, call_creds
            )
            interceptors = _create_grpc_interceptors()
            return grpc.aio.secure_channel(
                target, credentials, options=options or None, interceptors=interceptors
            )
        if channel_credentials:
            # Secure channel, no auth metadata to inject.
            interceptors = _create_grpc_interceptors()
            return grpc.aio.secure_channel(
                target,
                channel_credentials,
                options=options or None,
                interceptors=interceptors,
            )
        # Insecure channel: call credentials are unavailable, so auth
        # metadata (if any) is injected via interceptors instead.
        interceptors = _create_grpc_interceptors(
            auth_metadata=auth_metadata if auth_metadata else None
        )
        return grpc.aio.insecure_channel(
            target, options=options or None, interceptors=interceptors
        )

    return factory
@@ -554,16 +879,21 @@ async def _create_a2a_client( timeout: Request timeout in seconds. headers: HTTP headers (already with auth applied). streaming: Enable streaming responses. - auth: Optional AuthScheme for client configuration. + auth: Optional ClientAuthScheme for client configuration. use_polling: Enable polling mode. push_notification_config: Optional push notification config. + client_extensions: A2A protocol extension URIs to declare support for. + accepted_output_modes: MIME types the client can accept in responses. + grpc_config: Optional gRPC client configuration. Yields: Configured A2A client instance. """ + verify = _get_tls_verify(auth) async with httpx.AsyncClient( timeout=timeout, headers=headers, + verify=verify, ) as httpx_client: if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): configure_auth_client(auth, httpx_client) @@ -579,15 +909,27 @@ async def _create_a2a_client( ) ) + grpc_channel_factory = None + if transport_protocol == "GRPC": + grpc_channel_factory = _create_grpc_channel_factory( + grpc_config or GRPCClientConfig(), + auth=auth, + ) + config = ClientConfig( httpx_client=httpx_client, supported_transports=[transport_protocol], streaming=streaming and not use_polling, polling=use_polling, - accepted_output_modes=["application/json"], + accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type] push_notification_configs=push_configs, + grpc_channel_factory=grpc_channel_factory, ) factory = ClientFactory(config) client = factory.create(agent_card) + + if client_extensions: + await client.add_request_middleware(ExtensionsMiddleware(client_extensions)) + yield client diff --git a/lib/crewai/src/crewai/a2a/utils/logging.py b/lib/crewai/src/crewai/a2a/utils/logging.py new file mode 100644 index 000000000..585d1d8f3 --- /dev/null +++ b/lib/crewai/src/crewai/a2a/utils/logging.py @@ -0,0 +1,131 @@ +"""Structured JSON logging utilities for A2A module.""" + +from __future__ import annotations 
+ +from contextvars import ContextVar +from datetime import datetime, timezone +import json +import logging +from typing import Any + + +_log_context: ContextVar[dict[str, Any] | None] = ContextVar( + "log_context", default=None +) + + +class JSONFormatter(logging.Formatter): + """JSON formatter for structured logging. + + Outputs logs as JSON with consistent fields for log aggregators. + """ + + def format(self, record: logging.LogRecord) -> str: + """Format log record as JSON string.""" + log_data: dict[str, Any] = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + + if record.exc_info: + log_data["exception"] = self.formatException(record.exc_info) + + context = _log_context.get() + if context is not None: + log_data.update(context) + + if hasattr(record, "task_id"): + log_data["task_id"] = record.task_id + if hasattr(record, "context_id"): + log_data["context_id"] = record.context_id + if hasattr(record, "agent"): + log_data["agent"] = record.agent + if hasattr(record, "endpoint"): + log_data["endpoint"] = record.endpoint + if hasattr(record, "extension"): + log_data["extension"] = record.extension + if hasattr(record, "error"): + log_data["error"] = record.error + + for key, value in record.__dict__.items(): + if key.startswith("_") or key in ( + "name", + "msg", + "args", + "created", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "module", + "msecs", + "pathname", + "process", + "processName", + "relativeCreated", + "stack_info", + "exc_info", + "exc_text", + "thread", + "threadName", + "taskName", + "message", + ): + continue + if key not in log_data: + log_data[key] = value + + return json.dumps(log_data, default=str) + + +class LogContext: + """Context manager for adding fields to all logs within a scope. 
+ + Example: + with LogContext(task_id="abc", context_id="xyz"): + logger.info("Processing task") # Includes task_id and context_id + """ + + def __init__(self, **fields: Any) -> None: + self._fields = fields + self._token: Any = None + + def __enter__(self) -> LogContext: + current = _log_context.get() or {} + new_context = {**current, **self._fields} + self._token = _log_context.set(new_context) + return self + + def __exit__(self, *args: Any) -> None: + _log_context.reset(self._token) + + +def configure_json_logging(logger_name: str = "crewai.a2a") -> None: + """Configure JSON logging for the A2A module. + + Args: + logger_name: Logger name to configure. + """ + logger = logging.getLogger(logger_name) + + for handler in logger.handlers[:]: + logger.removeHandler(handler) + + handler = logging.StreamHandler() + handler.setFormatter(JSONFormatter()) + logger.addHandler(handler) + + +def get_logger(name: str) -> logging.Logger: + """Get a logger configured for structured JSON output. + + Args: + name: Logger name. + + Returns: + Configured logger instance. 
+ """ + return logging.getLogger(name) diff --git a/lib/crewai/src/crewai/a2a/utils/task.py b/lib/crewai/src/crewai/a2a/utils/task.py index 479a3e1c9..63868b841 100644 --- a/lib/crewai/src/crewai/a2a/utils/task.py +++ b/lib/crewai/src/crewai/a2a/utils/task.py @@ -7,26 +7,37 @@ import base64 from collections.abc import Callable, Coroutine from datetime import datetime from functools import wraps +import json import logging import os -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast from urllib.parse import urlparse from a2a.server.agent_execution import RequestContext from a2a.server.events import EventQueue from a2a.types import ( + Artifact, InternalError, InvalidParamsError, Message, + Part, Task as A2ATask, TaskState, TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import new_agent_text_message, new_text_artifact +from a2a.utils import ( + get_data_parts, + new_agent_text_message, + new_data_artifact, + new_text_artifact, +) from a2a.utils.errors import ServerError from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped] +from pydantic import BaseModel +from crewai.a2a.utils.agent_card import _get_server_config +from crewai.a2a.utils.content_type import validate_message_parts from crewai.events.event_bus import crewai_event_bus from crewai.events.types.a2a_events import ( A2AServerTaskCanceledEvent, @@ -35,9 +46,11 @@ from crewai.events.types.a2a_events import ( A2AServerTaskStartedEvent, ) from crewai.task import Task +from crewai.utilities.pydantic_schema_utils import create_model_from_schema if TYPE_CHECKING: + from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry from crewai.agent import Agent @@ -47,7 +60,17 @@ P = ParamSpec("P") T = TypeVar("T") -def _parse_redis_url(url: str) -> dict[str, Any]: +class RedisCacheConfig(TypedDict, total=False): + """Configuration for aiocache Redis backend.""" + + cache: str + 
endpoint: str + port: int + db: int + password: str + + +def _parse_redis_url(url: str) -> RedisCacheConfig: """Parse a Redis URL into aiocache configuration. Args: @@ -56,9 +79,8 @@ def _parse_redis_url(url: str) -> dict[str, Any]: Returns: Configuration dict for aiocache.RedisCache. """ - parsed = urlparse(url) - config: dict[str, Any] = { + config: RedisCacheConfig = { "cache": "aiocache.RedisCache", "endpoint": parsed.hostname or "localhost", "port": parsed.port or 6379, @@ -138,7 +160,10 @@ def cancellable( if message["type"] == "message": return True except (OSError, ConnectionError) as e: - logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e) + logger.warning( + "Cancel watcher Redis error, falling back to polling", + extra={"task_id": task_id, "error": str(e)}, + ) return await poll_for_cancel() return False @@ -166,7 +191,67 @@ def cancellable( return wrapper -@cancellable +def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None: + """Extract response schema from message parts metadata. + + The client may include a JSON schema in TextPart metadata to specify + the expected response format (see delegation.py line 463). + + Args: + parts: List of message parts. + + Returns: + JSON schema dict if found, None otherwise. + """ + for part in parts: + if part.root.kind == "text" and part.root.metadata: + schema = part.root.metadata.get("schema") + if schema and isinstance(schema, dict): + return schema # type: ignore[no-any-return] + return None + + +def _create_result_artifact( + result: Any, + task_id: str, +) -> Artifact: + """Create artifact from task result, using DataPart for structured data. + + Args: + result: The task execution result. + task_id: The task ID for naming the artifact. + + Returns: + Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings). 
+ """ + artifact_name = f"result_{task_id}" + if isinstance(result, dict): + return new_data_artifact(artifact_name, result) + if isinstance(result, BaseModel): + return new_data_artifact(artifact_name, result.model_dump()) + return new_text_artifact(artifact_name, str(result)) + + +def _build_task_description( + user_message: str, + structured_inputs: list[dict[str, Any]], +) -> str: + """Build task description including structured data if present. + + Args: + user_message: The original user message text. + structured_inputs: List of structured data from DataParts. + + Returns: + Task description with structured data appended if present. + """ + if not structured_inputs: + return user_message + + structured_json = json.dumps(structured_inputs, indent=2) + return f"{user_message}\n\nStructured Data:\n{structured_json}" + + async def execute( agent: Agent, context: RequestContext, @@ -178,15 +263,52 @@ async def execute( agent: The CrewAI agent to execute the task. context: The A2A request context containing the user's message. event_queue: The event queue for sending responses back. - - TODOs: - * need to impl both of structured output and file inputs, depends on `file_inputs` for - `crewai.task.Task`, pass the below two to Task. 
both utils in `a2a.utils.parts` - * structured outputs ingestion, `structured_inputs = get_data_parts(parts=context.message.parts)` - * file inputs ingestion, `file_inputs = get_file_parts(parts=context.message.parts)` """ + await _execute_impl(agent, context, event_queue, None, None) + + +@cancellable +async def _execute_impl( + agent: Agent, + context: RequestContext, + event_queue: EventQueue, + extension_registry: ServerExtensionRegistry | None, + extension_context: ExtensionContext | None, +) -> None: + """Internal implementation for task execution with optional extensions.""" + server_config = _get_server_config(agent) + if context.message and context.message.parts and server_config: + allowed_modes = server_config.default_input_modes + invalid_types = validate_message_parts(context.message.parts, allowed_modes) + if invalid_types: + raise ServerError( + InvalidParamsError( + message=f"Unsupported content type(s): {', '.join(invalid_types)}. " + f"Supported: {', '.join(allowed_modes)}" + ) + ) + + if extension_registry and extension_context: + await extension_registry.invoke_on_request(extension_context) user_message = context.get_user_input() + + response_model: type[BaseModel] | None = None + structured_inputs: list[dict[str, Any]] = [] + + if context.message and context.message.parts: + schema = _extract_response_schema(context.message.parts) + if schema: + try: + response_model = create_model_from_schema(schema) + except Exception as e: + logger.debug( + "Failed to create response model from schema", + extra={"error": str(e), "schema_title": schema.get("title")}, + ) + + structured_inputs = get_data_parts(context.message.parts) + task_id = context.task_id context_id = context.context_id if task_id is None or context_id is None: @@ -203,9 +325,10 @@ async def execute( raise ServerError(InvalidParamsError(message=msg)) from None task = Task( - description=user_message, + description=_build_task_description(user_message, structured_inputs), 
expected_output="Response to the user's request", agent=agent, + response_model=response_model, ) crewai_event_bus.emit( @@ -220,6 +343,10 @@ async def execute( try: result = await agent.aexecute_task(task=task, tools=agent.tools) + if extension_registry and extension_context: + result = await extension_registry.invoke_on_response( + extension_context, result + ) result_str = str(result) history: list[Message] = [context.message] if context.message else [] history.append(new_agent_text_message(result_str, context_id, task_id)) @@ -227,8 +354,8 @@ async def execute( A2ATask( id=task_id, context_id=context_id, - status=TaskStatus(state=TaskState.input_required), - artifacts=[new_text_artifact(result_str, f"result_{task_id}")], + status=TaskStatus(state=TaskState.completed), + artifacts=[_create_result_artifact(result, task_id)], history=history, ) ) @@ -269,6 +396,27 @@ async def execute( ) from e +async def execute_with_extensions( + agent: Agent, + context: RequestContext, + event_queue: EventQueue, + extension_registry: ServerExtensionRegistry, + extension_context: ExtensionContext, +) -> None: + """Execute an A2A task with extension hooks. + + Args: + agent: The CrewAI agent to execute the task. + context: The A2A request context containing the user's message. + event_queue: The event queue for sending responses back. + extension_registry: Registry of server extensions. + extension_context: Context for extension hooks. + """ + await _execute_impl( + agent, context, event_queue, extension_registry, extension_context + ) + + async def cancel( context: RequestContext, event_queue: EventQueue, diff --git a/lib/crewai/src/crewai/a2a/utils/transport.py b/lib/crewai/src/crewai/a2a/utils/transport.py new file mode 100644 index 000000000..cc57ba20c --- /dev/null +++ b/lib/crewai/src/crewai/a2a/utils/transport.py @@ -0,0 +1,215 @@ +"""Transport negotiation utilities for A2A protocol. 
+
+This module provides functionality for negotiating the transport protocol
+between an A2A client and server based on their respective capabilities
+and preferences.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Final, Literal
+
+from a2a.types import AgentCard, AgentInterface
+
+from crewai.events.event_bus import crewai_event_bus
+from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
+
+
+TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
+NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"]
+
+JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC"
+GRPC_TRANSPORT: Literal["GRPC"] = "GRPC"
+HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON"
+
+DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [
+    JSONRPC_TRANSPORT,
+    GRPC_TRANSPORT,
+    HTTP_JSON_TRANSPORT,
+]
+
+
+@dataclass
+class NegotiatedTransport:
+    """Result of transport negotiation.
+
+    Attributes:
+        transport: The negotiated transport protocol.
+        url: The URL to use for this transport.
+        source: How the transport was selected ('client_preferred', 'server_preferred', 'fallback').
+    """
+
+    transport: str
+    url: str
+    source: NegotiationSource
+
+
+class TransportNegotiationError(Exception):
+    """Raised when no compatible transport can be negotiated."""
+
+    def __init__(
+        self,
+        client_transports: list[str],
+        server_transports: list[str],
+        message: str | None = None,
+    ) -> None:
+        """Initialize the error with negotiation details.
+
+        Args:
+            client_transports: Transports supported by the client.
+            server_transports: Transports supported by the server.
+            message: Optional custom error message.
+        """
+        self.client_transports = client_transports
+        self.server_transports = server_transports
+        if message is None:
+            message = (
+                f"No compatible transport found. "
+                f"Client supports: {client_transports}. "
+                f"Server supports: {server_transports}."
+ ) + super().__init__(message) + + +def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]: + """Extract all available interfaces from an AgentCard. + + Creates a unified list of interfaces including the primary URL and + any additional interfaces declared by the agent. + + Args: + agent_card: The agent's card containing transport information. + + Returns: + List of AgentInterface objects representing all available endpoints. + """ + interfaces: list[AgentInterface] = [] + + primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT + interfaces.append( + AgentInterface( + transport=primary_transport, + url=agent_card.url, + ) + ) + + if agent_card.additional_interfaces: + for interface in agent_card.additional_interfaces: + is_duplicate = any( + i.url == interface.url and i.transport == interface.transport + for i in interfaces + ) + if not is_duplicate: + interfaces.append(interface) + + return interfaces + + +def negotiate_transport( + agent_card: AgentCard, + client_supported_transports: list[str] | None = None, + client_preferred_transport: str | None = None, + emit_event: bool = True, + endpoint: str | None = None, + a2a_agent_name: str | None = None, +) -> NegotiatedTransport: + """Negotiate the transport protocol between client and server. + + Compares the client's supported transports with the server's available + interfaces to find a compatible transport and URL. + + Negotiation logic: + 1. If client_preferred_transport is set and server supports it → use it + 2. Otherwise, if server's preferred is in client's supported → use server's + 3. Otherwise, find first match from client's supported in server's interfaces + + Args: + agent_card: The server's AgentCard with transport information. + client_supported_transports: Transports the client can use. + Defaults to ["JSONRPC"] if not specified. + client_preferred_transport: Client's preferred transport. If set and + server supports it, takes priority over server preference. 
+ emit_event: Whether to emit a transport negotiation event. + endpoint: Original endpoint URL for event metadata. + a2a_agent_name: Agent name for event metadata. + + Returns: + NegotiatedTransport with the selected transport, URL, and source. + + Raises: + TransportNegotiationError: If no compatible transport is found. + """ + if client_supported_transports is None: + client_supported_transports = [JSONRPC_TRANSPORT] + + client_transports = [t.upper() for t in client_supported_transports] + client_preferred = ( + client_preferred_transport.upper() if client_preferred_transport else None + ) + + server_interfaces = _get_server_interfaces(agent_card) + server_transports = [i.transport.upper() for i in server_interfaces] + + transport_to_interface: dict[str, AgentInterface] = {} + for interface in server_interfaces: + transport_upper = interface.transport.upper() + if transport_upper not in transport_to_interface: + transport_to_interface[transport_upper] = interface + + result: NegotiatedTransport | None = None + + if client_preferred and client_preferred in transport_to_interface: + interface = transport_to_interface[client_preferred] + result = NegotiatedTransport( + transport=interface.transport, + url=interface.url, + source="client_preferred", + ) + else: + server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper() + if ( + server_preferred in client_transports + and server_preferred in transport_to_interface + ): + interface = transport_to_interface[server_preferred] + result = NegotiatedTransport( + transport=interface.transport, + url=interface.url, + source="server_preferred", + ) + else: + for transport in client_transports: + if transport in transport_to_interface: + interface = transport_to_interface[transport] + result = NegotiatedTransport( + transport=interface.transport, + url=interface.url, + source="fallback", + ) + break + + if result is None: + raise TransportNegotiationError( + client_transports=client_transports, + 
server_transports=server_transports, + ) + + if emit_event: + crewai_event_bus.emit( + None, + A2ATransportNegotiatedEvent( + endpoint=endpoint or agent_card.url, + a2a_agent_name=a2a_agent_name or agent_card.name, + negotiated_transport=result.transport, + negotiated_url=result.url, + source=result.source, + client_supported_transports=client_transports, + server_supported_transports=server_transports, + server_preferred_transport=agent_card.preferred_transport + or JSONRPC_TRANSPORT, + client_preferred_transport=client_preferred, + ), + ) + + return result diff --git a/lib/crewai/src/crewai/a2a/wrapper.py b/lib/crewai/src/crewai/a2a/wrapper.py index a149c46a0..e77b9fb9e 100644 --- a/lib/crewai/src/crewai/a2a/wrapper.py +++ b/lib/crewai/src/crewai/a2a/wrapper.py @@ -11,19 +11,23 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from functools import wraps import json from types import MethodType -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple from a2a.types import Role, TaskState from pydantic import BaseModel, ValidationError from crewai.a2a.config import A2AClientConfig, A2AConfig -from crewai.a2a.extensions.base import ExtensionRegistry +from crewai.a2a.extensions.base import ( + A2AExtension, + ConversationState, + ExtensionRegistry, +) from crewai.a2a.task_helpers import TaskStateResult from crewai.a2a.templates import ( AVAILABLE_AGENTS_TEMPLATE, CONVERSATION_TURN_INFO_TEMPLATE, PREVIOUS_A2A_CONVERSATION_TEMPLATE, - REMOTE_AGENT_COMPLETED_NOTICE, + REMOTE_AGENT_RESPONSE_NOTICE, UNAVAILABLE_AGENTS_NOTICE_TEMPLATE, ) from crewai.a2a.types import AgentResponseProtocol @@ -52,6 +56,42 @@ if TYPE_CHECKING: from crewai.tools.base_tool import BaseTool +class DelegationContext(NamedTuple): + """Context prepared for A2A delegation. + + Groups all the values needed to execute a delegation to a remote A2A agent. 
+ """ + + a2a_agents: list[A2AConfig | A2AClientConfig] + agent_response_model: type[BaseModel] | None + current_request: str + agent_id: str + agent_config: A2AConfig | A2AClientConfig + context_id: str | None + task_id: str | None + metadata: dict[str, Any] | None + extensions: dict[str, Any] | None + reference_task_ids: list[str] + original_task_description: str + max_turns: int + + +class DelegationState(NamedTuple): + """Mutable state for A2A delegation loop. + + Groups values that may change during delegation turns. + """ + + current_request: str + context_id: str | None + task_id: str | None + reference_task_ids: list[str] + conversation_history: list[Message] + agent_card: AgentCard | None + agent_card_dict: dict[str, Any] | None + agent_name: str | None + + def wrap_agent_with_a2a_instance( agent: Agent, extension_registry: ExtensionRegistry | None = None ) -> None: @@ -165,7 +205,11 @@ def _fetch_agent_cards_concurrently( agent_cards: dict[str, AgentCard] = {} failed_agents: dict[str, str] = {} - with ThreadPoolExecutor(max_workers=len(a2a_agents)) as executor: + if not a2a_agents: + return agent_cards, failed_agents + + max_workers = min(len(a2a_agents), 10) + with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = { executor.submit(_fetch_card_from_config, config): config for config in a2a_agents @@ -231,7 +275,7 @@ def _execute_task_with_a2a( finally: task.description = original_description - task.description, _ = _augment_prompt_with_a2a( + task.description, _, extension_states = _augment_prompt_with_a2a( a2a_agents=a2a_agents, task_description=original_description, agent_cards=agent_cards, @@ -248,7 +292,7 @@ def _execute_task_with_a2a( if extension_registry and isinstance(agent_response, BaseModel): agent_response = extension_registry.process_response_with_all( - agent_response, {} + agent_response, extension_states ) if isinstance(agent_response, BaseModel) and isinstance( @@ -264,7 +308,7 @@ def _execute_task_with_a2a( tools=tools, 
agent_cards=agent_cards, original_task_description=original_description, - extension_registry=extension_registry, + _extension_registry=extension_registry, ) return str(agent_response.message) @@ -284,8 +328,8 @@ def _augment_prompt_with_a2a( max_turns: int | None = None, failed_agents: dict[str, str] | None = None, extension_registry: ExtensionRegistry | None = None, - remote_task_completed: bool = False, -) -> tuple[str, bool]: + remote_status_notice: str = "", +) -> tuple[str, bool, dict[type[A2AExtension], ConversationState]]: """Add A2A delegation instructions to prompt. Args: @@ -297,13 +341,14 @@ def _augment_prompt_with_a2a( max_turns: Maximum allowed turns (from config) failed_agents: Dictionary mapping failed agent endpoints to error messages extension_registry: Optional registry of A2A extensions + remote_status_notice: Optional notice about remote agent status to append Returns: - Tuple of (augmented prompt, disable_structured_output flag) + Tuple of (augmented prompt, disable_structured_output flag, extension_states dict) """ if not agent_cards: - return task_description, False + return task_description, False, {} agents_text = "" @@ -365,15 +410,11 @@ def _augment_prompt_with_a2a( warning=warning, ) - completion_notice = "" - if remote_task_completed and conversation_history: - completion_notice = REMOTE_AGENT_COMPLETED_NOTICE - augmented_prompt = f"""{task_description} IMPORTANT: You have the ability to delegate this task to remote A2A agents. {agents_text} -{history_text}{turn_info}{completion_notice} +{history_text}{turn_info}{remote_status_notice} """ @@ -382,7 +423,7 @@ IMPORTANT: You have the ability to delegate this task to remote A2A agents. 
augmented_prompt, extension_states ) - return augmented_prompt, disable_structured_output + return augmented_prompt, disable_structured_output, extension_states def _parse_agent_response( @@ -461,12 +502,41 @@ def _handle_max_turns_exceeded( raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})") +def _emit_delegation_failed( + error_msg: str, + turn_num: int, + from_task: Any | None, + from_agent: Any | None, + endpoint: str | None, + a2a_agent_name: str | None, + agent_card: dict[str, Any] | None, +) -> str: + """Emit failure event and return formatted error message.""" + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="failed", + final_result=None, + error=error_msg, + total_turns=turn_num + 1, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + agent_card=agent_card, + ), + ) + return f"A2A delegation failed: {error_msg}" + + def _process_response_result( raw_result: str, disable_structured_output: bool, turn_num: int, agent_role: str, agent_response_model: type[BaseModel] | None, + extension_registry: ExtensionRegistry | None = None, + extension_states: dict[type[A2AExtension], ConversationState] | None = None, from_task: Any | None = None, from_agent: Any | None = None, endpoint: str | None = None, @@ -516,6 +586,11 @@ def _process_response_result( raw_result=raw_result, agent_response_model=agent_response_model ) + if extension_registry and isinstance(llm_response, BaseModel): + llm_response = extension_registry.process_response_with_all( + llm_response, extension_states or {} + ) + if isinstance(llm_response, BaseModel) and isinstance( llm_response, AgentResponseProtocol ): @@ -571,72 +646,103 @@ def _prepare_agent_cards_dict( return agent_cards_dict +def _init_delegation_state( + ctx: DelegationContext, + agent_cards: dict[str, AgentCard] | None, +) -> DelegationState: + """Initialize delegation state from context and agent cards. 
+ + Args: + ctx: Delegation context with config and settings. + agent_cards: Pre-fetched agent cards. + + Returns: + Initial delegation state for the conversation loop. + """ + current_agent_card = agent_cards.get(ctx.agent_id) if agent_cards else None + return DelegationState( + current_request=ctx.current_request, + context_id=ctx.context_id, + task_id=ctx.task_id, + reference_task_ids=list(ctx.reference_task_ids), + conversation_history=[], + agent_card=current_agent_card, + agent_card_dict=current_agent_card.model_dump() if current_agent_card else None, + agent_name=current_agent_card.name if current_agent_card else None, + ) + + +def _get_turn_context( + agent_config: A2AConfig | A2AClientConfig, +) -> tuple[Any | None, list[str] | None]: + """Get context for a delegation turn. + + Returns: + Tuple of (agent_branch, accepted_output_modes). + """ + console_formatter = getattr(crewai_event_bus, "_console", None) + agent_branch = None + if console_formatter: + agent_branch = getattr( + console_formatter, "current_agent_branch", None + ) or getattr(console_formatter, "current_task_branch", None) + + accepted_output_modes = None + if isinstance(agent_config, A2AClientConfig): + accepted_output_modes = agent_config.accepted_output_modes + + return agent_branch, accepted_output_modes + + def _prepare_delegation_context( self: Agent, agent_response: AgentResponseProtocol, task: Task, original_task_description: str | None, -) -> tuple[ - list[A2AConfig | A2AClientConfig], - type[BaseModel] | None, - str, - str, - A2AConfig | A2AClientConfig, - str | None, - str | None, - dict[str, Any] | None, - dict[str, Any] | None, - list[str], - str, - int, -]: +) -> DelegationContext: """Prepare delegation context from agent response and task. Shared logic for both sync and async delegation. Returns: - Tuple containing all the context values needed for delegation. + DelegationContext with all values needed for delegation. 
""" a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) agent_ids = tuple(config.endpoint for config in a2a_agents) current_request = str(agent_response.message) - if hasattr(agent_response, "a2a_ids") and agent_response.a2a_ids: + if not a2a_agents: + raise ValueError("No A2A agents configured for delegation") + + if isinstance(agent_response, AgentResponseProtocol) and agent_response.a2a_ids: agent_id = agent_response.a2a_ids[0] else: - agent_id = agent_ids[0] if agent_ids else "" + agent_id = agent_ids[0] - if agent_id and agent_id not in agent_ids: - raise ValueError( - f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}" - ) + if agent_id not in agent_ids: + raise ValueError(f"Unknown A2A agent ID: {agent_id} not in {agent_ids}") - agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents)) + agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents), None) + if agent_config is None: + raise ValueError(f"Agent configuration not found for endpoint: {agent_id}") task_config = task.config or {} - context_id = task_config.get("context_id") - task_id_config = task_config.get("task_id") - metadata = task_config.get("metadata") - extensions = task_config.get("extensions") - reference_task_ids = task_config.get("reference_task_ids", []) if original_task_description is None: original_task_description = task.description - max_turns = agent_config.max_turns - - return ( - a2a_agents, - agent_response_model, - current_request, - agent_id, - agent_config, - context_id, - task_id_config, - metadata, - extensions, - reference_task_ids, - original_task_description, - max_turns, + return DelegationContext( + a2a_agents=a2a_agents, + agent_response_model=agent_response_model, + current_request=current_request, + agent_id=agent_id, + agent_config=agent_config, + context_id=task_config.get("context_id"), + task_id=task_config.get("task_id"), + metadata=task_config.get("metadata"), + 
extensions=task_config.get("extensions"), + reference_task_ids=task_config.get("reference_task_ids", []), + original_task_description=original_task_description, + max_turns=agent_config.max_turns, ) @@ -652,20 +758,50 @@ def _handle_task_completion( endpoint: str | None = None, a2a_agent_name: str | None = None, agent_card: dict[str, Any] | None = None, -) -> tuple[str | None, str | None, list[str]]: +) -> tuple[str | None, str | None, list[str], str]: """Handle task completion state including reference task updates. + When a remote task completes, this function: + 1. Adds the completed task_id to reference_task_ids (if not already present) + 2. Clears task_id_config to signal that a new task ID should be generated for next turn + 3. Updates task.config with the reference list for subsequent A2A calls + + The reference_task_ids list tracks all completed tasks in this conversation chain, + allowing the remote agent to maintain context across multi-turn interactions. + Shared logic for both sync and async delegation. + Args: + a2a_result: Result from A2A delegation containing task status. + task: CrewAI Task object to update with reference IDs. + task_id_config: Current task ID (will be added to references if task completed). + reference_task_ids: Mutable list of completed task IDs (updated in place). + agent_config: A2A configuration with trust settings. + turn_num: Current turn number. + from_task: Optional CrewAI Task for event metadata. + from_agent: Optional CrewAI Agent for event metadata. + endpoint: A2A endpoint URL. + a2a_agent_name: Name of remote A2A agent. + agent_card: Agent card dict for event metadata. + Returns: - Tuple of (result_if_trusted, updated_task_id, updated_reference_task_ids). + Tuple of (result_if_trusted, updated_task_id, updated_reference_task_ids, remote_notice). 
+ - result_if_trusted: Final result if trust_remote_completion_status=True, else None + - updated_task_id: None (cleared to generate new ID for next turn) + - updated_reference_task_ids: The mutated list with completed task added + - remote_notice: Template notice about remote agent response """ + remote_notice = "" if a2a_result["status"] == TaskState.completed: + remote_notice = REMOTE_AGENT_RESPONSE_NOTICE + if task_id_config is not None and task_id_config not in reference_task_ids: reference_task_ids.append(task_id_config) + if task.config is None: task.config = {} - task.config["reference_task_ids"] = reference_task_ids + task.config["reference_task_ids"] = list(reference_task_ids) + task_id_config = None if agent_config.trust_remote_completion_status: @@ -685,9 +821,9 @@ def _handle_task_completion( agent_card=agent_card, ), ) - return str(result_text), task_id_config, reference_task_ids + return str(result_text), task_id_config, reference_task_ids, remote_notice - return None, task_id_config, reference_task_ids + return None, task_id_config, reference_task_ids, remote_notice def _handle_agent_response_and_continue( @@ -705,7 +841,8 @@ def _handle_agent_response_and_continue( context: str | None, tools: list[BaseTool] | None, agent_response_model: type[BaseModel] | None, - remote_task_completed: bool = False, + extension_registry: ExtensionRegistry | None = None, + remote_status_notice: str = "", endpoint: str | None = None, a2a_agent_name: str | None = None, agent_card: dict[str, Any] | None = None, @@ -735,14 +872,18 @@ def _handle_agent_response_and_continue( """ agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards) - task.description, disable_structured_output = _augment_prompt_with_a2a( + ( + task.description, + disable_structured_output, + extension_states, + ) = _augment_prompt_with_a2a( a2a_agents=a2a_agents, task_description=original_task_description, conversation_history=conversation_history, turn_num=turn_num, 
max_turns=max_turns, agent_cards=agent_cards_dict, - remote_task_completed=remote_task_completed, + remote_status_notice=remote_status_notice, ) original_response_model = task.response_model @@ -760,6 +901,8 @@ def _handle_agent_response_and_continue( turn_num=turn_num, agent_role=self.role, agent_response_model=agent_response_model, + extension_registry=extension_registry, + extension_states=extension_states, from_task=task, from_agent=self, endpoint=endpoint, @@ -777,7 +920,7 @@ def _delegate_to_a2a( tools: list[BaseTool] | None, agent_cards: dict[str, AgentCard] | None = None, original_task_description: str | None = None, - extension_registry: ExtensionRegistry | None = None, + _extension_registry: ExtensionRegistry | None = None, ) -> str: """Delegate to A2A agent with multi-turn conversation support. @@ -790,7 +933,7 @@ def _delegate_to_a2a( tools: Optional tools available to the agent agent_cards: Pre-fetched agent cards from _execute_task_with_a2a original_task_description: The original task description before A2A augmentation - extension_registry: Optional registry of A2A extensions + _extension_registry: Optional registry of A2A extensions (unused, reserved for future use) Returns: Result from A2A agent @@ -798,60 +941,42 @@ def _delegate_to_a2a( Raises: ImportError: If a2a-sdk is not installed """ - ( - a2a_agents, - agent_response_model, - current_request, - agent_id, - agent_config, - context_id, - task_id_config, - metadata, - extensions, - reference_task_ids, - original_task_description, - max_turns, - ) = _prepare_delegation_context( + ctx = _prepare_delegation_context( self, agent_response, task, original_task_description ) - - conversation_history: list[Message] = [] - - current_agent_card = agent_cards.get(agent_id) if agent_cards else None - current_agent_card_dict = ( - current_agent_card.model_dump() if current_agent_card else None - ) - current_a2a_agent_name = current_agent_card.name if current_agent_card else None + state = 
_init_delegation_state(ctx, agent_cards) + current_request = state.current_request + context_id = state.context_id + task_id = state.task_id + reference_task_ids = state.reference_task_ids + conversation_history = state.conversation_history try: - for turn_num in range(max_turns): - console_formatter = getattr(crewai_event_bus, "_console", None) - agent_branch = None - if console_formatter: - agent_branch = getattr( - console_formatter, "current_agent_branch", None - ) or getattr(console_formatter, "current_task_branch", None) + for turn_num in range(ctx.max_turns): + agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config) a2a_result = execute_a2a_delegation( - endpoint=agent_config.endpoint, - auth=agent_config.auth, - timeout=agent_config.timeout, + endpoint=ctx.agent_config.endpoint, + auth=ctx.agent_config.auth, + timeout=ctx.agent_config.timeout, task_description=current_request, context_id=context_id, - task_id=task_id_config, + task_id=task_id, reference_task_ids=reference_task_ids, - metadata=metadata, - extensions=extensions, + metadata=ctx.metadata, + extensions=ctx.extensions, conversation_history=conversation_history, - agent_id=agent_id, + agent_id=ctx.agent_id, agent_role=Role.user, agent_branch=agent_branch, - response_model=agent_config.response_model, + response_model=ctx.agent_config.response_model, turn_number=turn_num + 1, - updates=agent_config.updates, - transport_protocol=agent_config.transport_protocol, + updates=ctx.agent_config.updates, + transport=ctx.agent_config.transport, from_task=task, from_agent=self, + client_extensions=getattr(ctx.agent_config, "extensions", None), + accepted_output_modes=accepted_output_modes, ) conversation_history = a2a_result.get("history", []) @@ -859,24 +984,24 @@ def _delegate_to_a2a( if conversation_history: latest_message = conversation_history[-1] if latest_message.task_id is not None: - task_id_config = latest_message.task_id + task_id = latest_message.task_id if 
latest_message.context_id is not None: context_id = latest_message.context_id if a2a_result["status"] in [TaskState.completed, TaskState.input_required]: - trusted_result, task_id_config, reference_task_ids = ( + trusted_result, task_id, reference_task_ids, remote_notice = ( _handle_task_completion( a2a_result, task, - task_id_config, + task_id, reference_task_ids, - agent_config, + ctx.agent_config, turn_num, from_task=task, from_agent=self, - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - agent_card=current_agent_card_dict, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, ) ) if trusted_result is not None: @@ -885,22 +1010,23 @@ def _delegate_to_a2a( final_result, next_request = _handle_agent_response_and_continue( self=self, a2a_result=a2a_result, - agent_id=agent_id, + agent_id=ctx.agent_id, agent_cards=agent_cards, - a2a_agents=a2a_agents, - original_task_description=original_task_description, + a2a_agents=ctx.a2a_agents, + original_task_description=ctx.original_task_description, conversation_history=conversation_history, turn_num=turn_num, - max_turns=max_turns, + max_turns=ctx.max_turns, task=task, original_fn=original_fn, context=context, tools=tools, - agent_response_model=agent_response_model, - remote_task_completed=(a2a_result["status"] == TaskState.completed), - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - agent_card=current_agent_card_dict, + agent_response_model=ctx.agent_response_model, + extension_registry=_extension_registry, + remote_status_notice=remote_notice, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, ) if final_result is not None: @@ -916,22 +1042,22 @@ def _delegate_to_a2a( final_result, next_request = _handle_agent_response_and_continue( self=self, a2a_result=a2a_result, - agent_id=agent_id, + agent_id=ctx.agent_id, agent_cards=agent_cards, - 
a2a_agents=a2a_agents, - original_task_description=original_task_description, + a2a_agents=ctx.a2a_agents, + original_task_description=ctx.original_task_description, conversation_history=conversation_history, turn_num=turn_num, - max_turns=max_turns, + max_turns=ctx.max_turns, task=task, original_fn=original_fn, context=context, tools=tools, - agent_response_model=agent_response_model, - remote_task_completed=False, - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - agent_card=current_agent_card_dict, + agent_response_model=ctx.agent_response_model, + extension_registry=_extension_registry, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, ) if final_result is not None: @@ -941,34 +1067,28 @@ def _delegate_to_a2a( current_request = next_request continue - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="failed", - final_result=None, - error=error_msg, - total_turns=turn_num + 1, - from_task=task, - from_agent=self, - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - agent_card=current_agent_card_dict, - ), + return _emit_delegation_failed( + error_msg, + turn_num, + task, + self, + ctx.agent_config.endpoint, + state.agent_name, + state.agent_card_dict, ) - return f"A2A delegation failed: {error_msg}" return _handle_max_turns_exceeded( conversation_history, - max_turns, + ctx.max_turns, from_task=task, from_agent=self, - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - agent_card=current_agent_card_dict, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, ) finally: - task.description = original_task_description + task.description = ctx.original_task_description async def _afetch_card_from_config( @@ -993,6 +1113,9 @@ async def _afetch_agent_cards_concurrently( agent_cards: dict[str, AgentCard] = {} failed_agents: dict[str, str] = {} + if not 
a2a_agents: + return agent_cards, failed_agents + tasks = [_afetch_card_from_config(config) for config in a2a_agents] results = await asyncio.gather(*tasks) @@ -1042,7 +1165,7 @@ async def _aexecute_task_with_a2a( finally: task.description = original_description - task.description, _ = _augment_prompt_with_a2a( + task.description, _, extension_states = _augment_prompt_with_a2a( a2a_agents=a2a_agents, task_description=original_description, agent_cards=agent_cards, @@ -1059,7 +1182,7 @@ async def _aexecute_task_with_a2a( if extension_registry and isinstance(agent_response, BaseModel): agent_response = extension_registry.process_response_with_all( - agent_response, {} + agent_response, extension_states ) if isinstance(agent_response, BaseModel) and isinstance( @@ -1075,7 +1198,7 @@ async def _aexecute_task_with_a2a( tools=tools, agent_cards=agent_cards, original_task_description=original_description, - extension_registry=extension_registry, + _extension_registry=extension_registry, ) return str(agent_response.message) @@ -1101,7 +1224,8 @@ async def _ahandle_agent_response_and_continue( context: str | None, tools: list[BaseTool] | None, agent_response_model: type[BaseModel] | None, - remote_task_completed: bool = False, + extension_registry: ExtensionRegistry | None = None, + remote_status_notice: str = "", endpoint: str | None = None, a2a_agent_name: str | None = None, agent_card: dict[str, Any] | None = None, @@ -1109,14 +1233,18 @@ async def _ahandle_agent_response_and_continue( """Async version of _handle_agent_response_and_continue.""" agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards) - task.description, disable_structured_output = _augment_prompt_with_a2a( + ( + task.description, + disable_structured_output, + extension_states, + ) = _augment_prompt_with_a2a( a2a_agents=a2a_agents, task_description=original_task_description, conversation_history=conversation_history, turn_num=turn_num, max_turns=max_turns, 
agent_cards=agent_cards_dict, - remote_task_completed=remote_task_completed, + remote_status_notice=remote_status_notice, ) original_response_model = task.response_model @@ -1134,6 +1262,8 @@ async def _ahandle_agent_response_and_continue( turn_num=turn_num, agent_role=self.role, agent_response_model=agent_response_model, + extension_registry=extension_registry, + extension_states=extension_states, from_task=task, from_agent=self, endpoint=endpoint, @@ -1151,63 +1281,45 @@ async def _adelegate_to_a2a( tools: list[BaseTool] | None, agent_cards: dict[str, AgentCard] | None = None, original_task_description: str | None = None, - extension_registry: ExtensionRegistry | None = None, + _extension_registry: ExtensionRegistry | None = None, ) -> str: """Async version of _delegate_to_a2a.""" - ( - a2a_agents, - agent_response_model, - current_request, - agent_id, - agent_config, - context_id, - task_id_config, - metadata, - extensions, - reference_task_ids, - original_task_description, - max_turns, - ) = _prepare_delegation_context( + ctx = _prepare_delegation_context( self, agent_response, task, original_task_description ) - - conversation_history: list[Message] = [] - - current_agent_card = agent_cards.get(agent_id) if agent_cards else None - current_agent_card_dict = ( - current_agent_card.model_dump() if current_agent_card else None - ) - current_a2a_agent_name = current_agent_card.name if current_agent_card else None + state = _init_delegation_state(ctx, agent_cards) + current_request = state.current_request + context_id = state.context_id + task_id = state.task_id + reference_task_ids = state.reference_task_ids + conversation_history = state.conversation_history try: - for turn_num in range(max_turns): - console_formatter = getattr(crewai_event_bus, "_console", None) - agent_branch = None - if console_formatter: - agent_branch = getattr( - console_formatter, "current_agent_branch", None - ) or getattr(console_formatter, "current_task_branch", None) + for turn_num in 
range(ctx.max_turns): + agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config) a2a_result = await aexecute_a2a_delegation( - endpoint=agent_config.endpoint, - auth=agent_config.auth, - timeout=agent_config.timeout, + endpoint=ctx.agent_config.endpoint, + auth=ctx.agent_config.auth, + timeout=ctx.agent_config.timeout, task_description=current_request, context_id=context_id, - task_id=task_id_config, + task_id=task_id, reference_task_ids=reference_task_ids, - metadata=metadata, - extensions=extensions, + metadata=ctx.metadata, + extensions=ctx.extensions, conversation_history=conversation_history, - agent_id=agent_id, + agent_id=ctx.agent_id, agent_role=Role.user, agent_branch=agent_branch, - response_model=agent_config.response_model, + response_model=ctx.agent_config.response_model, turn_number=turn_num + 1, - transport_protocol=agent_config.transport_protocol, - updates=agent_config.updates, + transport=ctx.agent_config.transport, + updates=ctx.agent_config.updates, from_task=task, from_agent=self, + client_extensions=getattr(ctx.agent_config, "extensions", None), + accepted_output_modes=accepted_output_modes, ) conversation_history = a2a_result.get("history", []) @@ -1215,24 +1327,24 @@ async def _adelegate_to_a2a( if conversation_history: latest_message = conversation_history[-1] if latest_message.task_id is not None: - task_id_config = latest_message.task_id + task_id = latest_message.task_id if latest_message.context_id is not None: context_id = latest_message.context_id if a2a_result["status"] in [TaskState.completed, TaskState.input_required]: - trusted_result, task_id_config, reference_task_ids = ( + trusted_result, task_id, reference_task_ids, remote_notice = ( _handle_task_completion( a2a_result, task, - task_id_config, + task_id, reference_task_ids, - agent_config, + ctx.agent_config, turn_num, from_task=task, from_agent=self, - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - 
agent_card=current_agent_card_dict, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, ) ) if trusted_result is not None: @@ -1241,22 +1353,23 @@ async def _adelegate_to_a2a( final_result, next_request = await _ahandle_agent_response_and_continue( self=self, a2a_result=a2a_result, - agent_id=agent_id, + agent_id=ctx.agent_id, agent_cards=agent_cards, - a2a_agents=a2a_agents, - original_task_description=original_task_description, + a2a_agents=ctx.a2a_agents, + original_task_description=ctx.original_task_description, conversation_history=conversation_history, turn_num=turn_num, - max_turns=max_turns, + max_turns=ctx.max_turns, task=task, original_fn=original_fn, context=context, tools=tools, - agent_response_model=agent_response_model, - remote_task_completed=(a2a_result["status"] == TaskState.completed), - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - agent_card=current_agent_card_dict, + agent_response_model=ctx.agent_response_model, + extension_registry=_extension_registry, + remote_status_notice=remote_notice, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, ) if final_result is not None: @@ -1272,21 +1385,22 @@ async def _adelegate_to_a2a( final_result, next_request = await _ahandle_agent_response_and_continue( self=self, a2a_result=a2a_result, - agent_id=agent_id, + agent_id=ctx.agent_id, agent_cards=agent_cards, - a2a_agents=a2a_agents, - original_task_description=original_task_description, + a2a_agents=ctx.a2a_agents, + original_task_description=ctx.original_task_description, conversation_history=conversation_history, turn_num=turn_num, - max_turns=max_turns, + max_turns=ctx.max_turns, task=task, original_fn=original_fn, context=context, tools=tools, - agent_response_model=agent_response_model, - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - agent_card=current_agent_card_dict, + 
agent_response_model=ctx.agent_response_model, + extension_registry=_extension_registry, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, ) if final_result is not None: @@ -1296,31 +1410,25 @@ async def _adelegate_to_a2a( current_request = next_request continue - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="failed", - final_result=None, - error=error_msg, - total_turns=turn_num + 1, - from_task=task, - from_agent=self, - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - agent_card=current_agent_card_dict, - ), + return _emit_delegation_failed( + error_msg, + turn_num, + task, + self, + ctx.agent_config.endpoint, + state.agent_name, + state.agent_card_dict, ) - return f"A2A delegation failed: {error_msg}" return _handle_max_turns_exceeded( conversation_history, - max_turns, + ctx.max_turns, from_task=task, from_agent=self, - endpoint=agent_config.endpoint, - a2a_agent_name=current_a2a_agent_name, - agent_card=current_agent_card_dict, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, ) finally: - task.description = original_task_description + task.description = ctx.original_task_description diff --git a/lib/crewai/src/crewai/events/types/a2a_events.py b/lib/crewai/src/crewai/events/types/a2a_events.py index d69878aac..55de064f8 100644 --- a/lib/crewai/src/crewai/events/types/a2a_events.py +++ b/lib/crewai/src/crewai/events/types/a2a_events.py @@ -654,3 +654,165 @@ class A2AParallelDelegationCompletedEvent(A2AEventBase): success_count: int failure_count: int results: dict[str, str] | None = None + + +class A2ATransportNegotiatedEvent(A2AEventBase): + """Event emitted when transport protocol is negotiated with an A2A agent. + + This event is emitted after comparing client and server transport capabilities + to select the optimal transport protocol and endpoint URL. 
+ + Attributes: + endpoint: Original A2A agent endpoint URL. + a2a_agent_name: Name of the A2A agent from agent card. + negotiated_transport: The transport protocol selected (JSONRPC, GRPC, HTTP+JSON). + negotiated_url: The URL to use for the selected transport. + source: How the transport was selected ('client_preferred', 'server_preferred', 'fallback'). + client_supported_transports: Transports the client can use. + server_supported_transports: Transports the server supports. + server_preferred_transport: The server's preferred transport from AgentCard. + client_preferred_transport: The client's preferred transport if set. + metadata: Custom A2A metadata key-value pairs. + """ + + type: str = "a2a_transport_negotiated" + endpoint: str + a2a_agent_name: str | None = None + negotiated_transport: str + negotiated_url: str + source: str + client_supported_transports: list[str] + server_supported_transports: list[str] + server_preferred_transport: str + client_preferred_transport: str | None = None + metadata: dict[str, Any] | None = None + + +class A2AContentTypeNegotiatedEvent(A2AEventBase): + """Event emitted when content types are negotiated with an A2A agent. + + This event is emitted after comparing client and server input/output mode + capabilities to determine compatible MIME types for communication. + + Attributes: + endpoint: A2A agent endpoint URL. + a2a_agent_name: Name of the A2A agent from agent card. + skill_name: Skill name if negotiation was skill-specific. + client_input_modes: MIME types the client can send. + client_output_modes: MIME types the client can accept. + server_input_modes: MIME types the server accepts. + server_output_modes: MIME types the server produces. + negotiated_input_modes: Compatible input MIME types selected. + negotiated_output_modes: Compatible output MIME types selected. + negotiation_success: Whether compatible types were found for both directions. + metadata: Custom A2A metadata key-value pairs. 
+ """ + + type: str = "a2a_content_type_negotiated" + endpoint: str + a2a_agent_name: str | None = None + skill_name: str | None = None + client_input_modes: list[str] + client_output_modes: list[str] + server_input_modes: list[str] + server_output_modes: list[str] + negotiated_input_modes: list[str] + negotiated_output_modes: list[str] + negotiation_success: bool = True + metadata: dict[str, Any] | None = None + + +# ----------------------------------------------------------------------------- +# Context Lifecycle Events +# ----------------------------------------------------------------------------- + + +class A2AContextCreatedEvent(A2AEventBase): + """Event emitted when an A2A context is created. + + Contexts group related tasks in a conversation or workflow. + + Attributes: + context_id: Unique identifier for the context. + created_at: Unix timestamp when context was created. + metadata: Custom A2A metadata key-value pairs. + """ + + type: str = "a2a_context_created" + context_id: str + created_at: float + metadata: dict[str, Any] | None = None + + +class A2AContextExpiredEvent(A2AEventBase): + """Event emitted when an A2A context expires due to TTL. + + Attributes: + context_id: The expired context identifier. + created_at: Unix timestamp when context was created. + age_seconds: How long the context existed before expiring. + task_count: Number of tasks in the context when expired. + metadata: Custom A2A metadata key-value pairs. + """ + + type: str = "a2a_context_expired" + context_id: str + created_at: float + age_seconds: float + task_count: int + metadata: dict[str, Any] | None = None + + +class A2AContextIdleEvent(A2AEventBase): + """Event emitted when an A2A context becomes idle. + + Idle contexts have had no activity for the configured threshold. + + Attributes: + context_id: The idle context identifier. + idle_seconds: Seconds since last activity. + task_count: Number of tasks in the context. + metadata: Custom A2A metadata key-value pairs. 
+ """ + + type: str = "a2a_context_idle" + context_id: str + idle_seconds: float + task_count: int + metadata: dict[str, Any] | None = None + + +class A2AContextCompletedEvent(A2AEventBase): + """Event emitted when all tasks in an A2A context complete. + + Attributes: + context_id: The completed context identifier. + total_tasks: Total number of tasks that were in the context. + duration_seconds: Total context lifetime in seconds. + metadata: Custom A2A metadata key-value pairs. + """ + + type: str = "a2a_context_completed" + context_id: str + total_tasks: int + duration_seconds: float + metadata: dict[str, Any] | None = None + + +class A2AContextPrunedEvent(A2AEventBase): + """Event emitted when an A2A context is pruned (deleted). + + Pruning removes the context metadata and optionally associated tasks. + + Attributes: + context_id: The pruned context identifier. + task_count: Number of tasks that were in the context. + age_seconds: How long the context existed before pruning. + metadata: Custom A2A metadata key-value pairs. 
+ """ + + type: str = "a2a_context_pruned" + context_id: str + task_count: int + age_seconds: float + metadata: dict[str, Any] | None = None diff --git a/lib/crewai/tests/a2a/test_a2a_integration.py b/lib/crewai/tests/a2a/test_a2a_integration.py index f46af4789..9950ee0a2 100644 --- a/lib/crewai/tests/a2a/test_a2a_integration.py +++ b/lib/crewai/tests/a2a/test_a2a_integration.py @@ -104,6 +104,7 @@ class TestA2AStreamingIntegration: message=test_message, new_messages=new_messages, agent_card=agent_card, + endpoint=agent_card.url, ) assert isinstance(result, dict) @@ -225,6 +226,7 @@ class TestA2APushNotificationHandler: result_store=mock_store, polling_timeout=30.0, polling_interval=1.0, + endpoint=mock_agent_card.url, ) mock_store.wait_for_result.assert_called_once_with( @@ -287,6 +289,7 @@ class TestA2APushNotificationHandler: result_store=mock_store, polling_timeout=5.0, polling_interval=0.5, + endpoint=mock_agent_card.url, ) assert result["status"] == TaskState.failed @@ -317,6 +320,7 @@ class TestA2APushNotificationHandler: message=test_msg, new_messages=new_messages, agent_card=mock_agent_card, + endpoint=mock_agent_card.url, ) assert result["status"] == TaskState.failed diff --git a/lib/crewai/tests/a2a/utils/test_task.py b/lib/crewai/tests/a2a/utils/test_task.py index 3c3f8865e..781827ac8 100644 --- a/lib/crewai/tests/a2a/utils/test_task.py +++ b/lib/crewai/tests/a2a/utils/test_task.py @@ -43,6 +43,7 @@ def mock_context() -> MagicMock: context.context_id = "test-context-456" context.get_user_input.return_value = "Test user message" context.message = MagicMock(spec=Message) + context.message.parts = [] context.current_task = None return context