diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index defe87b5c..ba68fec38 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
         language: system
         pass_filenames: true
         types: [python]
-        exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
+        exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/|lib/crewai-a2a/tests/)
   - repo: https://github.com/astral-sh/uv-pre-commit
     rev: 0.9.3
     hooks:
diff --git a/conftest.py b/conftest.py
index 1cce71c26..a821f4c31 100644
--- a/conftest.py
+++ b/conftest.py
@@ -12,6 +12,7 @@ from dotenv import load_dotenv
 import pytest
 from vcr.request import Request  # type: ignore[import-untyped]
 
+
 try:
     import vcr.stubs.httpx_stubs as httpx_stubs  # type: ignore[import-untyped]
 except ModuleNotFoundError:
@@ -225,7 +226,7 @@ def vcr_cassette_dir(request: Any) -> str:
 
     for parent in test_file.parents:
         if (
-            parent.name in ("crewai", "crewai-tools", "crewai-files")
+            parent.name in ("crewai", "crewai-tools", "crewai-files", "crewai-a2a")
             and parent.parent.name == "lib"
         ):
             package_root = parent
diff --git a/lib/crewai-a2a/README.md b/lib/crewai-a2a/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/lib/crewai-a2a/pyproject.toml b/lib/crewai-a2a/pyproject.toml
new file mode 100644
index 000000000..4d8cfbb2e
--- /dev/null
+++ b/lib/crewai-a2a/pyproject.toml
@@ -0,0 +1,26 @@
+[project]
+name = "crewai-a2a"
+dynamic = ["version"]
+description = "A2A (Agent-to-Agent) protocol support for CrewAI"
+readme = "README.md"
+authors = [{ name = "Greyson LaLonde", email = "greyson@crewai.com" }]
+requires-python = ">=3.10, <3.14"
+dependencies = [
+    "crewai==1.10.1b1",
+    "crewai-files",
+    "a2a-sdk~=0.3.10",
+    "httpx-auth~=0.23.1",
+    "httpx-sse~=0.4.0",
+    "aiocache[redis,memcached]~=0.12.3",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.version] +path = "src/crewai_a2a/__init__.py" + +[tool.uv.sources] +crewai = { workspace = true } +crewai-files = { workspace = true } diff --git a/lib/crewai-a2a/src/crewai_a2a/__init__.py b/lib/crewai-a2a/src/crewai_a2a/__init__.py new file mode 100644 index 000000000..0abef96b7 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/__init__.py @@ -0,0 +1,12 @@ +"""Agent-to-Agent (A2A) protocol communication module for CrewAI.""" + +__version__ = "1.10.1b1" + +from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig + + +__all__ = [ + "A2AClientConfig", + "A2AConfig", + "A2AServerConfig", +] diff --git a/lib/crewai-a2a/src/crewai_a2a/auth/__init__.py b/lib/crewai-a2a/src/crewai_a2a/auth/__init__.py new file mode 100644 index 000000000..214b579ab --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/auth/__init__.py @@ -0,0 +1,36 @@ +"""A2A authentication schemas.""" + +from crewai_a2a.auth.client_schemes import ( + APIKeyAuth, + AuthScheme, + BearerTokenAuth, + ClientAuthScheme, + HTTPBasicAuth, + HTTPDigestAuth, + OAuth2AuthorizationCode, + OAuth2ClientCredentials, + TLSConfig, +) +from crewai_a2a.auth.server_schemes import ( + AuthenticatedUser, + OIDCAuth, + ServerAuthScheme, + SimpleTokenAuth, +) + + +__all__ = [ + "APIKeyAuth", + "AuthScheme", + "AuthenticatedUser", + "BearerTokenAuth", + "ClientAuthScheme", + "HTTPBasicAuth", + "HTTPDigestAuth", + "OAuth2AuthorizationCode", + "OAuth2ClientCredentials", + "OIDCAuth", + "ServerAuthScheme", + "SimpleTokenAuth", + "TLSConfig", +] diff --git a/lib/crewai-a2a/src/crewai_a2a/auth/client_schemes.py b/lib/crewai-a2a/src/crewai_a2a/auth/client_schemes.py new file mode 100644 index 000000000..0356b8aef --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/auth/client_schemes.py @@ -0,0 +1,550 @@ +"""Authentication schemes for A2A protocol clients. 
+ +Supported authentication methods: +- Bearer tokens +- OAuth2 (Client Credentials, Authorization Code) +- API Keys (header, query, cookie) +- HTTP Basic authentication +- HTTP Digest authentication +- mTLS (mutual TLS) client certificate authentication +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +import asyncio +import base64 +from collections.abc import Awaitable, Callable, MutableMapping +from pathlib import Path +import ssl +import time +from typing import TYPE_CHECKING, ClassVar, Literal +import urllib.parse + +import httpx +from httpx import DigestAuth +from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr +from typing_extensions import deprecated + + +if TYPE_CHECKING: + import grpc # type: ignore[import-untyped] + + +class TLSConfig(BaseModel): + """TLS/mTLS configuration for secure client connections. + + Supports mutual TLS (mTLS) where the client presents a certificate to the server, + and standard TLS with custom CA verification. + + Attributes: + client_cert_path: Path to client certificate file (PEM format) for mTLS. + client_key_path: Path to client private key file (PEM format) for mTLS. + ca_cert_path: Path to CA certificate bundle for server verification. + verify: Whether to verify server certificates. Set False only for development. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + client_cert_path: FilePath | None = Field( + default=None, + description="Path to client certificate file (PEM format) for mTLS", + ) + client_key_path: FilePath | None = Field( + default=None, + description="Path to client private key file (PEM format) for mTLS", + ) + ca_cert_path: FilePath | None = Field( + default=None, + description="Path to CA certificate bundle for server verification", + ) + verify: bool = Field( + default=True, + description="Whether to verify server certificates. 
Set False only for development.",
+    )
+
+    def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
+        """Build SSL context for httpx client.
+
+        Returns:
+            SSL context when client certificates are configured (server
+            verification honored per `verify`), False when verification is
+            disabled without certs, or path to CA bundle.
+        """
+        if self.client_cert_path and self.client_key_path:
+            context = ssl.create_default_context()
+            if not self.verify:  # dev only: present mTLS cert, skip server verification
+                context.check_hostname = False
+                context.verify_mode = ssl.CERT_NONE
+            if self.ca_cert_path:
+                context.load_verify_locations(cafile=str(self.ca_cert_path))
+            context.load_cert_chain(
+                certfile=str(self.client_cert_path),
+                keyfile=str(self.client_key_path),
+            )
+            return context
+
+        if not self.verify:
+            return False
+        if self.ca_cert_path:
+            return str(self.ca_cert_path)
+
+        return True
+
+    def get_grpc_credentials(self) -> grpc.ChannelCredentials | None:  # type: ignore[no-any-unimported]
+        """Build gRPC channel credentials for secure connections.
+
+        Returns:
+            gRPC SSL credentials if certificates configured, None otherwise.
+        """
+        try:
+            import grpc
+        except ImportError:
+            return None
+
+        if not self.verify and not self.client_cert_path:
+            return None
+
+        root_certs: bytes | None = None
+        private_key: bytes | None = None
+        certificate_chain: bytes | None = None
+
+        if self.ca_cert_path:
+            root_certs = Path(self.ca_cert_path).read_bytes()
+
+        if self.client_cert_path and self.client_key_path:
+            private_key = Path(self.client_key_path).read_bytes()
+            certificate_chain = Path(self.client_cert_path).read_bytes()
+
+        return grpc.ssl_channel_credentials(
+            root_certificates=root_certs,
+            private_key=private_key,
+            certificate_chain=certificate_chain,
+        )
+
+
+class ClientAuthScheme(ABC, BaseModel):
+    """Base class for client-side authentication schemes.
+
+    Client auth schemes apply credentials to outgoing requests.
+
+    Attributes:
+        tls: Optional TLS/mTLS configuration for secure connections.
+ """ + + tls: TLSConfig | None = Field( + default=None, + description="TLS/mTLS configuration for secure connections", + ) + + @abstractmethod + async def apply_auth( + self, client: httpx.AsyncClient, headers: MutableMapping[str, str] + ) -> MutableMapping[str, str]: + """Apply authentication to request headers. + + Args: + client: HTTP client for making auth requests. + headers: Current request headers. + + Returns: + Updated headers with authentication applied. + """ + ... + + +@deprecated("Use ClientAuthScheme instead", category=FutureWarning) +class AuthScheme(ClientAuthScheme): + """Deprecated: Use ClientAuthScheme instead.""" + + +class BearerTokenAuth(ClientAuthScheme): + """Bearer token authentication (Authorization: Bearer ). + + Attributes: + token: Bearer token for authentication. + """ + + token: str = Field(description="Bearer token") + + async def apply_auth( + self, client: httpx.AsyncClient, headers: MutableMapping[str, str] + ) -> MutableMapping[str, str]: + """Apply Bearer token to Authorization header. + + Args: + client: HTTP client for making auth requests. + headers: Current request headers. + + Returns: + Updated headers with Bearer token in Authorization header. + """ + headers["Authorization"] = f"Bearer {self.token}" + return headers + + +class HTTPBasicAuth(ClientAuthScheme): + """HTTP Basic authentication. + + Attributes: + username: Username for Basic authentication. + password: Password for Basic authentication. + """ + + username: str = Field(description="Username") + password: str = Field(description="Password") + + async def apply_auth( + self, client: httpx.AsyncClient, headers: MutableMapping[str, str] + ) -> MutableMapping[str, str]: + """Apply HTTP Basic authentication. + + Args: + client: HTTP client for making auth requests. + headers: Current request headers. + + Returns: + Updated headers with Basic auth in Authorization header. 
+ """ + credentials = f"{self.username}:{self.password}" + encoded = base64.b64encode(credentials.encode()).decode() + headers["Authorization"] = f"Basic {encoded}" + return headers + + +class HTTPDigestAuth(ClientAuthScheme): + """HTTP Digest authentication. + + Note: Uses httpx-auth library for digest implementation. + + Attributes: + username: Username for Digest authentication. + password: Password for Digest authentication. + """ + + username: str = Field(description="Username") + password: str = Field(description="Password") + + _configured_client_id: int | None = PrivateAttr(default=None) + + async def apply_auth( + self, client: httpx.AsyncClient, headers: MutableMapping[str, str] + ) -> MutableMapping[str, str]: + """Digest auth is handled by httpx auth flow, not headers. + + Args: + client: HTTP client for making auth requests. + headers: Current request headers. + + Returns: + Unchanged headers (Digest auth handled by httpx auth flow). + """ + return headers + + def configure_client(self, client: httpx.AsyncClient) -> None: + """Configure client with Digest auth. + + Idempotent: Only configures the client once. Subsequent calls on the same + client instance are no-ops to prevent overwriting auth configuration. + + Args: + client: HTTP client to configure with Digest authentication. + """ + client_id = id(client) + if self._configured_client_id == client_id: + return + + client.auth = DigestAuth(self.username, self.password) + self._configured_client_id = client_id + + +class APIKeyAuth(ClientAuthScheme): + """API Key authentication (header, query, or cookie). + + Attributes: + api_key: API key value for authentication. + location: Where to send the API key (header, query, or cookie). + name: Parameter name for the API key (default: X-API-Key). 
+ """ + + api_key: str = Field(description="API key value") + location: Literal["header", "query", "cookie"] = Field( + default="header", description="Where to send the API key" + ) + name: str = Field(default="X-API-Key", description="Parameter name for the API key") + + _configured_client_ids: set[int] = PrivateAttr(default_factory=set) + + async def apply_auth( + self, client: httpx.AsyncClient, headers: MutableMapping[str, str] + ) -> MutableMapping[str, str]: + """Apply API key authentication. + + Args: + client: HTTP client for making auth requests. + headers: Current request headers. + + Returns: + Updated headers with API key (for header/cookie locations). + """ + if self.location == "header": + headers[self.name] = self.api_key + elif self.location == "cookie": + headers["Cookie"] = f"{self.name}={self.api_key}" + return headers + + def configure_client(self, client: httpx.AsyncClient) -> None: + """Configure client for query param API keys. + + Idempotent: Only adds the request hook once per client instance. + Subsequent calls on the same client are no-ops to prevent hook accumulation. + + Args: + client: HTTP client to configure with query param API key hook. + """ + if self.location == "query": + client_id = id(client) + if client_id in self._configured_client_ids: + return + + async def _add_api_key_param(request: httpx.Request) -> None: + url = httpx.URL(request.url) + request.url = url.copy_add_param(self.name, self.api_key) + + client.event_hooks["request"].append(_add_api_key_param) + self._configured_client_ids.add(client_id) + + +class OAuth2ClientCredentials(ClientAuthScheme): + """OAuth2 Client Credentials flow authentication. + + Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches + when multiple requests share the same auth instance. + + Attributes: + token_url: OAuth2 token endpoint URL. + client_id: OAuth2 client identifier. + client_secret: OAuth2 client secret. + scopes: List of required OAuth2 scopes. 
+ """ + + token_url: str = Field(description="OAuth2 token endpoint") + client_id: str = Field(description="OAuth2 client ID") + client_secret: str = Field(description="OAuth2 client secret") + scopes: list[str] = Field( + default_factory=list, description="Required OAuth2 scopes" + ) + + _access_token: str | None = PrivateAttr(default=None) + _token_expires_at: float | None = PrivateAttr(default=None) + _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) + + async def apply_auth( + self, client: httpx.AsyncClient, headers: MutableMapping[str, str] + ) -> MutableMapping[str, str]: + """Apply OAuth2 access token to Authorization header. + + Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time, + preventing race conditions when multiple concurrent requests use the same + auth instance. + + Args: + client: HTTP client for making token requests. + headers: Current request headers. + + Returns: + Updated headers with OAuth2 access token in Authorization header. + """ + if ( + self._access_token is None + or self._token_expires_at is None + or time.time() >= self._token_expires_at + ): + async with self._lock: + if ( + self._access_token is None + or self._token_expires_at is None + or time.time() >= self._token_expires_at + ): + await self._fetch_token(client) + + if self._access_token: + headers["Authorization"] = f"Bearer {self._access_token}" + + return headers + + async def _fetch_token(self, client: httpx.AsyncClient) -> None: + """Fetch OAuth2 access token using client credentials flow. + + Args: + client: HTTP client for making token request. + + Raises: + httpx.HTTPStatusError: If token request fails. 
+ """ + data = { + "grant_type": "client_credentials", + "client_id": self.client_id, + "client_secret": self.client_secret, + } + + if self.scopes: + data["scope"] = " ".join(self.scopes) + + response = await client.post(self.token_url, data=data) + response.raise_for_status() + + token_data = response.json() + self._access_token = token_data["access_token"] + expires_in = token_data.get("expires_in", 3600) + self._token_expires_at = time.time() + expires_in - 60 + + +class OAuth2AuthorizationCode(ClientAuthScheme): + """OAuth2 Authorization Code flow authentication. + + Thread-safe implementation with asyncio.Lock to prevent concurrent token operations. + + Note: Requires interactive authorization. + + Attributes: + authorization_url: OAuth2 authorization endpoint URL. + token_url: OAuth2 token endpoint URL. + client_id: OAuth2 client identifier. + client_secret: OAuth2 client secret. + redirect_uri: OAuth2 redirect URI for callback. + scopes: List of required OAuth2 scopes. + """ + + authorization_url: str = Field(description="OAuth2 authorization endpoint") + token_url: str = Field(description="OAuth2 token endpoint") + client_id: str = Field(description="OAuth2 client ID") + client_secret: str = Field(description="OAuth2 client secret") + redirect_uri: str = Field(description="OAuth2 redirect URI") + scopes: list[str] = Field( + default_factory=list, description="Required OAuth2 scopes" + ) + + _access_token: str | None = PrivateAttr(default=None) + _refresh_token: str | None = PrivateAttr(default=None) + _token_expires_at: float | None = PrivateAttr(default=None) + _authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr( + default=None + ) + _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) + + def set_authorization_callback( + self, callback: Callable[[str], Awaitable[str]] | None + ) -> None: + """Set callback to handle authorization URL. 
+ + Args: + callback: Async function that receives authorization URL and returns auth code. + """ + self._authorization_callback = callback + + async def apply_auth( + self, client: httpx.AsyncClient, headers: MutableMapping[str, str] + ) -> MutableMapping[str, str]: + """Apply OAuth2 access token to Authorization header. + + Uses asyncio.Lock to ensure only one coroutine handles token operations + (initial fetch or refresh) at a time. + + Args: + client: HTTP client for making token requests. + headers: Current request headers. + + Returns: + Updated headers with OAuth2 access token in Authorization header. + + Raises: + ValueError: If authorization callback is not set. + """ + if self._access_token is None: + if self._authorization_callback is None: + msg = "Authorization callback not set. Use set_authorization_callback()" + raise ValueError(msg) + async with self._lock: + if self._access_token is None: + await self._fetch_initial_token(client) + elif self._token_expires_at and time.time() >= self._token_expires_at: + async with self._lock: + if self._token_expires_at and time.time() >= self._token_expires_at: + await self._refresh_access_token(client) + + if self._access_token: + headers["Authorization"] = f"Bearer {self._access_token}" + + return headers + + async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None: + """Fetch initial access token using authorization code flow. + + Args: + client: HTTP client for making token request. + + Raises: + ValueError: If authorization callback is not set. + httpx.HTTPStatusError: If token request fails. 
+ """ + params = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(self.scopes), + } + auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}" + + if self._authorization_callback is None: + msg = "Authorization callback not set" + raise ValueError(msg) + auth_code = await self._authorization_callback(auth_url) + + data = { + "grant_type": "authorization_code", + "code": auth_code, + "client_id": self.client_id, + "client_secret": self.client_secret, + "redirect_uri": self.redirect_uri, + } + + response = await client.post(self.token_url, data=data) + response.raise_for_status() + + token_data = response.json() + self._access_token = token_data["access_token"] + self._refresh_token = token_data.get("refresh_token") + + expires_in = token_data.get("expires_in", 3600) + self._token_expires_at = time.time() + expires_in - 60 + + async def _refresh_access_token(self, client: httpx.AsyncClient) -> None: + """Refresh the access token using refresh token. + + Args: + client: HTTP client for making token request. + + Raises: + httpx.HTTPStatusError: If token refresh request fails. 
+ """ + if not self._refresh_token: + await self._fetch_initial_token(client) + return + + data = { + "grant_type": "refresh_token", + "refresh_token": self._refresh_token, + "client_id": self.client_id, + "client_secret": self.client_secret, + } + + response = await client.post(self.token_url, data=data) + response.raise_for_status() + + token_data = response.json() + self._access_token = token_data["access_token"] + if "refresh_token" in token_data: + self._refresh_token = token_data["refresh_token"] + + expires_in = token_data.get("expires_in", 3600) + self._token_expires_at = time.time() + expires_in - 60 diff --git a/lib/crewai-a2a/src/crewai_a2a/auth/schemas.py b/lib/crewai-a2a/src/crewai_a2a/auth/schemas.py new file mode 100644 index 000000000..4372ed5fe --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/auth/schemas.py @@ -0,0 +1,71 @@ +"""Deprecated: Authentication schemes for A2A protocol agents. + +This module is deprecated. Import from crewai_a2a.auth instead: +- crewai_a2a.auth.ClientAuthScheme (replaces AuthScheme) +- crewai_a2a.auth.BearerTokenAuth +- crewai_a2a.auth.HTTPBasicAuth +- crewai_a2a.auth.HTTPDigestAuth +- crewai_a2a.auth.APIKeyAuth +- crewai_a2a.auth.OAuth2ClientCredentials +- crewai_a2a.auth.OAuth2AuthorizationCode +""" + +from __future__ import annotations + +from typing_extensions import deprecated + +from crewai_a2a.auth.client_schemes import ( + APIKeyAuth as _APIKeyAuth, + BearerTokenAuth as _BearerTokenAuth, + ClientAuthScheme as _ClientAuthScheme, + HTTPBasicAuth as _HTTPBasicAuth, + HTTPDigestAuth as _HTTPDigestAuth, + OAuth2AuthorizationCode as _OAuth2AuthorizationCode, + OAuth2ClientCredentials as _OAuth2ClientCredentials, +) + + +@deprecated("Use ClientAuthScheme from crewai_a2a.auth instead", category=FutureWarning) +class AuthScheme(_ClientAuthScheme): + """Deprecated: Use ClientAuthScheme from crewai_a2a.auth instead.""" + + +@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning) +class 
BearerTokenAuth(_BearerTokenAuth): + """Deprecated: Import from crewai_a2a.auth instead.""" + + +@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning) +class HTTPBasicAuth(_HTTPBasicAuth): + """Deprecated: Import from crewai_a2a.auth instead.""" + + +@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning) +class HTTPDigestAuth(_HTTPDigestAuth): + """Deprecated: Import from crewai_a2a.auth instead.""" + + +@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning) +class APIKeyAuth(_APIKeyAuth): + """Deprecated: Import from crewai_a2a.auth instead.""" + + +@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning) +class OAuth2ClientCredentials(_OAuth2ClientCredentials): + """Deprecated: Import from crewai_a2a.auth instead.""" + + +@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning) +class OAuth2AuthorizationCode(_OAuth2AuthorizationCode): + """Deprecated: Import from crewai_a2a.auth instead.""" + + +__all__ = [ + "APIKeyAuth", + "AuthScheme", + "BearerTokenAuth", + "HTTPBasicAuth", + "HTTPDigestAuth", + "OAuth2AuthorizationCode", + "OAuth2ClientCredentials", +] diff --git a/lib/crewai-a2a/src/crewai_a2a/auth/server_schemes.py b/lib/crewai-a2a/src/crewai_a2a/auth/server_schemes.py new file mode 100644 index 000000000..64aa7de37 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/auth/server_schemes.py @@ -0,0 +1,742 @@ +"""Server-side authentication schemes for A2A protocol. + +These schemes validate incoming requests to A2A server endpoints. 
+ +Supported authentication methods: +- Simple token validation with static bearer tokens +- OpenID Connect with JWT validation using JWKS +- OAuth2 with JWT validation or token introspection +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +import logging +import os +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal + +import jwt +from jwt import PyJWKClient +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + HttpUrl, + PrivateAttr, + SecretStr, + model_validator, +) +from typing_extensions import Self + + +if TYPE_CHECKING: + from a2a.types import OAuth2SecurityScheme + + +logger = logging.getLogger(__name__) + + +try: + from fastapi import ( # type: ignore[import-not-found] + HTTPException, + status as http_status, + ) + + HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED + HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR + HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE +except ImportError: + + class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818 + """Fallback HTTPException when FastAPI is not installed.""" + + def __init__( + self, + status_code: int, + detail: str | None = None, + headers: dict[str, str] | None = None, + ) -> None: + self.status_code = status_code + self.detail = detail + self.headers = headers + super().__init__(detail) + + HTTP_401_UNAUTHORIZED = 401 + HTTP_500_INTERNAL_SERVER_ERROR = 500 + HTTP_503_SERVICE_UNAVAILABLE = 503 + + +def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None: + """Coerce string to SecretStr.""" + if v is None or isinstance(v, SecretStr): + return v + return SecretStr(v) + + +CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)] + +JWTAlgorithm = Literal[ + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", +] + + +@dataclass +class 
AuthenticatedUser: + """Result of successful authentication. + + Attributes: + token: The original token that was validated. + scheme: Name of the authentication scheme used. + claims: JWT claims from OIDC or OAuth2 authentication. + """ + + token: str + scheme: str + claims: dict[str, Any] | None = None + + +class ServerAuthScheme(ABC, BaseModel): + """Base class for server-side authentication schemes. + + Each scheme validates incoming requests and returns an AuthenticatedUser + on success, or raises HTTPException on failure. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + @abstractmethod + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate the provided token. + + Args: + token: The bearer token to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. + """ + ... + + +class SimpleTokenAuth(ServerAuthScheme): + """Simple bearer token authentication. + + Validates tokens against a configured static token or AUTH_TOKEN env var. + + Attributes: + token: Expected token value. Falls back to AUTH_TOKEN env var if not set. + """ + + token: CoercedSecretStr | None = Field( + default=None, + description="Expected token. Falls back to AUTH_TOKEN env var.", + ) + + def _get_expected_token(self) -> str | None: + """Get the expected token value.""" + if self.token: + return self.token.get_secret_value() + return os.environ.get("AUTH_TOKEN") + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate using simple token comparison. + + Args: + token: The bearer token to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. 
+ """ + expected = self._get_expected_token() + + if expected is None: + logger.warning( + "Simple token authentication failed", + extra={"reason": "no_token_configured"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Authentication not configured", + ) + + if token != expected: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid or missing authentication credentials", + ) + + return AuthenticatedUser( + token=token, + scheme="simple_token", + ) + + +class OIDCAuth(ServerAuthScheme): + """OpenID Connect authentication. + + Validates JWTs using JWKS with caching support via PyJWT. + + Attributes: + issuer: The OpenID Connect issuer URL. + audience: The expected audience claim. + jwks_url: Optional explicit JWKS URL. Derived from issuer if not set. + algorithms: List of allowed signing algorithms. + required_claims: List of claims that must be present in the token. + jwks_cache_ttl: TTL for JWKS cache in seconds. + clock_skew_seconds: Allowed clock skew for token validation. + """ + + issuer: HttpUrl = Field( + description="OpenID Connect issuer URL (e.g., https://auth.example.com)" + ) + audience: str = Field(description="Expected audience claim (e.g., api://my-agent)") + jwks_url: HttpUrl | None = Field( + default=None, + description="Explicit JWKS URL. 
Derived from issuer if not set.", + ) + algorithms: list[str] = Field( + default_factory=lambda: ["RS256"], + description="List of allowed signing algorithms (RS256, ES256, etc.)", + ) + required_claims: list[str] = Field( + default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"], + description="List of claims that must be present in the token", + ) + jwks_cache_ttl: int = Field( + default=3600, + description="TTL for JWKS cache in seconds", + ge=60, + ) + clock_skew_seconds: float = Field( + default=30.0, + description="Allowed clock skew for token validation", + ge=0.0, + ) + + _jwk_client: PyJWKClient | None = PrivateAttr(default=None) + + @model_validator(mode="after") + def _init_jwk_client(self) -> Self: + """Initialize the JWK client after model creation.""" + jwks_url = ( + str(self.jwks_url) + if self.jwks_url + else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json" + ) + self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl) + return self + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate using OIDC JWT validation. + + Args: + token: The JWT to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. 
+ """ + if self._jwk_client is None: + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail="OIDC not initialized", + ) + + try: + signing_key = self._jwk_client.get_signing_key_from_jwt(token) + + claims = jwt.decode( + token, + signing_key.key, + algorithms=self.algorithms, + audience=self.audience, + issuer=str(self.issuer).rstrip("/"), + leeway=self.clock_skew_seconds, + options={ + "require": self.required_claims, + }, + ) + + return AuthenticatedUser( + token=token, + scheme="oidc", + claims=claims, + ) + + except jwt.ExpiredSignatureError: + logger.debug( + "OIDC authentication failed", + extra={"reason": "token_expired", "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Token has expired", + ) from None + except jwt.InvalidAudienceError: + logger.debug( + "OIDC authentication failed", + extra={"reason": "invalid_audience", "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid token audience", + ) from None + except jwt.InvalidIssuerError: + logger.debug( + "OIDC authentication failed", + extra={"reason": "invalid_issuer", "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid token issuer", + ) from None + except jwt.MissingRequiredClaimError as e: + logger.debug( + "OIDC authentication failed", + extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=f"Missing required claim: {e.claim}", + ) from None + except jwt.PyJWKClientError as e: + logger.error( + "OIDC authentication failed", + extra={ + "reason": "jwks_client_error", + "error": str(e), + "scheme": "oidc", + }, + ) + raise HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, + detail="Unable to fetch signing keys", + ) from None + except jwt.InvalidTokenError as e: + logger.debug( + "OIDC authentication failed", + extra={"reason": "invalid_token", 
"error": str(e), "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid or missing authentication credentials", + ) from None + + +class OAuth2ServerAuth(ServerAuthScheme): + """OAuth2 authentication for A2A server. + + Declares OAuth2 security scheme in AgentCard and validates tokens using + either JWKS for JWT tokens or token introspection for opaque tokens. + + This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme + with flows, rather than an OpenIdConnectSecurityScheme with discovery URL. + + Attributes: + token_url: OAuth2 token endpoint URL for client_credentials flow. + authorization_url: OAuth2 authorization endpoint for authorization_code flow. + refresh_url: Optional refresh token endpoint URL. + scopes: Available OAuth2 scopes with descriptions. + jwks_url: JWKS URL for JWT validation. Required if not using introspection. + introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS. + introspection_client_id: Client ID for introspection endpoint authentication. + introspection_client_secret: Client secret for introspection endpoint. + audience: Expected audience claim for JWT validation. + issuer: Expected issuer claim for JWT validation. + algorithms: Allowed JWT signing algorithms. + required_claims: Claims that must be present in the token. + jwks_cache_ttl: TTL for JWKS cache in seconds. + clock_skew_seconds: Allowed clock skew for token validation. 
+ """ + + token_url: HttpUrl = Field( + description="OAuth2 token endpoint URL", + ) + authorization_url: HttpUrl | None = Field( + default=None, + description="OAuth2 authorization endpoint URL for authorization_code flow", + ) + refresh_url: HttpUrl | None = Field( + default=None, + description="OAuth2 refresh token endpoint URL", + ) + scopes: dict[str, str] = Field( + default_factory=dict, + description="Available OAuth2 scopes with descriptions", + ) + jwks_url: HttpUrl | None = Field( + default=None, + description="JWKS URL for JWT validation. Required if not using introspection.", + ) + introspection_url: HttpUrl | None = Field( + default=None, + description="Token introspection endpoint (RFC 7662). Alternative to JWKS.", + ) + introspection_client_id: str | None = Field( + default=None, + description="Client ID for introspection endpoint authentication", + ) + introspection_client_secret: CoercedSecretStr | None = Field( + default=None, + description="Client secret for introspection endpoint authentication", + ) + audience: str | None = Field( + default=None, + description="Expected audience claim for JWT validation", + ) + issuer: str | None = Field( + default=None, + description="Expected issuer claim for JWT validation", + ) + algorithms: list[str] = Field( + default_factory=lambda: ["RS256"], + description="Allowed JWT signing algorithms", + ) + required_claims: list[str] = Field( + default_factory=lambda: ["exp", "iat"], + description="Claims that must be present in the token", + ) + jwks_cache_ttl: int = Field( + default=3600, + description="TTL for JWKS cache in seconds", + ge=60, + ) + clock_skew_seconds: float = Field( + default=30.0, + description="Allowed clock skew for token validation", + ge=0.0, + ) + + _jwk_client: PyJWKClient | None = PrivateAttr(default=None) + + @model_validator(mode="after") + def _validate_and_init(self) -> Self: + """Validate configuration and initialize JWKS client if needed.""" + if not self.jwks_url and not 
self.introspection_url: + raise ValueError( + "Either jwks_url or introspection_url must be provided for token validation" + ) + + if self.introspection_url: + if not self.introspection_client_id or not self.introspection_client_secret: + raise ValueError( + "introspection_client_id and introspection_client_secret are required " + "when using token introspection" + ) + + if self.jwks_url: + self._jwk_client = PyJWKClient( + str(self.jwks_url), lifespan=self.jwks_cache_ttl + ) + + return self + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate using OAuth2 token validation. + + Uses JWKS validation if jwks_url is configured, otherwise falls back + to token introspection. + + Args: + token: The OAuth2 access token to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. + """ + if self._jwk_client: + return await self._authenticate_jwt(token) + return await self._authenticate_introspection(token) + + async def _authenticate_jwt(self, token: str) -> AuthenticatedUser: + """Authenticate using JWKS JWT validation.""" + if self._jwk_client is None: + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth2 JWKS not initialized", + ) + + try: + signing_key = self._jwk_client.get_signing_key_from_jwt(token) + + decode_options: dict[str, Any] = { + "require": self.required_claims, + } + + claims = jwt.decode( + token, + signing_key.key, + algorithms=self.algorithms, + audience=self.audience, + issuer=self.issuer, + leeway=self.clock_skew_seconds, + options=decode_options, # type: ignore[arg-type] + ) + + return AuthenticatedUser( + token=token, + scheme="oauth2", + claims=claims, + ) + + except jwt.ExpiredSignatureError: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "token_expired", "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Token has expired", + ) from None + 
except jwt.InvalidAudienceError: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "invalid_audience", "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid token audience", + ) from None + except jwt.InvalidIssuerError: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "invalid_issuer", "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid token issuer", + ) from None + except jwt.MissingRequiredClaimError as e: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=f"Missing required claim: {e.claim}", + ) from None + except jwt.PyJWKClientError as e: + logger.error( + "OAuth2 authentication failed", + extra={ + "reason": "jwks_client_error", + "error": str(e), + "scheme": "oauth2", + }, + ) + raise HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, + detail="Unable to fetch signing keys", + ) from None + except jwt.InvalidTokenError as e: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid or missing authentication credentials", + ) from None + + async def _authenticate_introspection(self, token: str) -> AuthenticatedUser: + """Authenticate using OAuth2 token introspection (RFC 7662).""" + import httpx + + if not self.introspection_url: + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth2 introspection not configured", + ) + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + str(self.introspection_url), + data={"token": token}, + auth=( + self.introspection_client_id or "", + self.introspection_client_secret.get_secret_value() + if self.introspection_client_secret 
+ else "", + ), + ) + response.raise_for_status() + introspection_result = response.json() + + except httpx.HTTPStatusError as e: + logger.error( + "OAuth2 introspection failed", + extra={"reason": "http_error", "status_code": e.response.status_code}, + ) + raise HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, + detail="Token introspection service unavailable", + ) from None + except Exception as e: + logger.error( + "OAuth2 introspection failed", + extra={"reason": "unexpected_error", "error": str(e)}, + ) + raise HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, + detail="Token introspection failed", + ) from None + + if not introspection_result.get("active", False): + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "token_not_active", "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Token is not active", + ) + + return AuthenticatedUser( + token=token, + scheme="oauth2", + claims=introspection_result, + ) + + def to_security_scheme(self) -> OAuth2SecurityScheme: + """Generate OAuth2SecurityScheme for AgentCard declaration. + + Creates an OAuth2SecurityScheme with appropriate flows based on + the configured URLs. Includes client_credentials flow if token_url + is set, and authorization_code flow if authorization_url is set. + + Returns: + OAuth2SecurityScheme suitable for use in AgentCard security_schemes. 
+ """ + from a2a.types import ( + AuthorizationCodeOAuthFlow, + ClientCredentialsOAuthFlow, + OAuth2SecurityScheme, + OAuthFlows, + ) + + client_credentials = None + authorization_code = None + + if self.token_url: + client_credentials = ClientCredentialsOAuthFlow( + token_url=str(self.token_url), + refresh_url=str(self.refresh_url) if self.refresh_url else None, + scopes=self.scopes, + ) + + if self.authorization_url: + authorization_code = AuthorizationCodeOAuthFlow( + authorization_url=str(self.authorization_url), + token_url=str(self.token_url), + refresh_url=str(self.refresh_url) if self.refresh_url else None, + scopes=self.scopes, + ) + + return OAuth2SecurityScheme( + flows=OAuthFlows( + client_credentials=client_credentials, + authorization_code=authorization_code, + ), + description="OAuth2 authentication", + ) + + +class APIKeyServerAuth(ServerAuthScheme): + """API Key authentication for A2A server. + + Validates requests using an API key in a header, query parameter, or cookie. + + Attributes: + name: The name of the API key parameter (default: X-API-Key). + location: Where to look for the API key (header, query, or cookie). + api_key: The expected API key value. + """ + + name: str = Field( + default="X-API-Key", + description="Name of the API key parameter", + ) + location: Literal["header", "query", "cookie"] = Field( + default="header", + description="Where to look for the API key", + ) + api_key: CoercedSecretStr = Field( + description="Expected API key value", + ) + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate using API key comparison. + + Args: + token: The API key to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. 
+ """ + if token != self.api_key.get_secret_value(): + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + return AuthenticatedUser( + token=token, + scheme="api_key", + ) + + +class MTLSServerAuth(ServerAuthScheme): + """Mutual TLS authentication marker for AgentCard declaration. + + This scheme is primarily for AgentCard security_schemes declaration. + Actual mTLS verification happens at the TLS/transport layer, not + at the application layer via token validation. + + When configured, this signals to clients that the server requires + client certificates for authentication. + """ + + description: str = Field( + default="Mutual TLS certificate authentication", + description="Description for the security scheme", + ) + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Return authenticated user for mTLS. + + mTLS verification happens at the transport layer before this is called. + If we reach this point, the TLS handshake with client cert succeeded. + + Args: + token: Certificate subject or identifier (from TLS layer). + + Returns: + AuthenticatedUser indicating mTLS authentication. + """ + return AuthenticatedUser( + token=token or "mtls-verified", + scheme="mtls", + ) diff --git a/lib/crewai-a2a/src/crewai_a2a/auth/utils.py b/lib/crewai-a2a/src/crewai_a2a/auth/utils.py new file mode 100644 index 000000000..d1699143e --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/auth/utils.py @@ -0,0 +1,273 @@ +"""Authentication utilities for A2A protocol agent communication. + +Provides validation and retry logic for various authentication schemes including +OAuth2, API keys, and HTTP authentication methods. 
+""" + +import asyncio +from collections.abc import Awaitable, Callable, MutableMapping +import hashlib +import re +import threading +from typing import Final, Literal, cast + +from a2a.client.errors import A2AClientHTTPError +from a2a.types import ( + APIKeySecurityScheme, + AgentCard, + HTTPAuthSecurityScheme, + OAuth2SecurityScheme, +) +from httpx import AsyncClient, Response + +from crewai_a2a.auth.client_schemes import ( + APIKeyAuth, + BearerTokenAuth, + ClientAuthScheme, + HTTPBasicAuth, + HTTPDigestAuth, + OAuth2AuthorizationCode, + OAuth2ClientCredentials, +) + + +class _AuthStore: + """Store for authentication schemes with safe concurrent access.""" + + def __init__(self) -> None: + self._store: dict[str, ClientAuthScheme | None] = {} + self._lock = threading.RLock() + + @staticmethod + def compute_key(auth_type: str, auth_data: str) -> str: + """Compute a collision-resistant key using SHA-256.""" + content = f"{auth_type}:{auth_data}" + return hashlib.sha256(content.encode()).hexdigest() + + def set(self, key: str, auth: ClientAuthScheme | None) -> None: + """Store an auth scheme.""" + with self._lock: + self._store[key] = auth + + def get(self, key: str) -> ClientAuthScheme | None: + """Retrieve an auth scheme by key.""" + with self._lock: + return self._store.get(key) + + def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None: + with self._lock: + self._store[key] = value + + def __getitem__(self, key: str) -> ClientAuthScheme | None: + with self._lock: + return self._store[key] + + +_auth_store = _AuthStore() + +_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)") +_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))') + +_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = { + OAuth2SecurityScheme: ( + OAuth2ClientCredentials, + OAuth2AuthorizationCode, + BearerTokenAuth, + ), + APIKeySecurityScheme: (APIKeyAuth,), +} + +_HTTPSchemeType = 
Literal["basic", "digest", "bearer"] + +_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = { + "basic": HTTPBasicAuth, + "digest": HTTPDigestAuth, + "bearer": BearerTokenAuth, +} + + +def _raise_auth_mismatch( + expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...], + provided_auth: ClientAuthScheme, +) -> None: + """Raise authentication mismatch error. + + Args: + expected_classes: Expected authentication class or tuple of classes. + provided_auth: Actually provided authentication instance. + + Raises: + A2AClientHTTPError: Always raises with 401 status code. + """ + if isinstance(expected_classes, tuple): + if len(expected_classes) == 1: + required = expected_classes[0].__name__ + else: + names = [cls.__name__ for cls in expected_classes] + required = f"one of ({', '.join(names)})" + else: + required = expected_classes.__name__ + + msg = ( + f"AgentCard requires {required} authentication, " + f"but {type(provided_auth).__name__} was provided" + ) + raise A2AClientHTTPError(401, msg) + + +def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]: + """Parse WWW-Authenticate header into auth challenges. + + Args: + header_value: The WWW-Authenticate header value. + + Returns: + Dictionary mapping auth scheme to its parameters. 
+ Example: {"Bearer": {"realm": "api", "scope": "read write"}} + """ + if not header_value: + return {} + + challenges: dict[str, dict[str, str]] = {} + + for match in _SCHEME_PATTERN.finditer(header_value): + scheme = match.group(1) + params_str = match.group(2) + + params: dict[str, str] = {} + + for param_match in _PARAM_PATTERN.finditer(params_str): + key = param_match.group(1) + value = param_match.group(2) or param_match.group(3) + params[key] = value + + challenges[scheme] = params + + return challenges + + +def validate_auth_against_agent_card( + agent_card: AgentCard, auth: ClientAuthScheme | None +) -> None: + """Validate that provided auth matches AgentCard security requirements. + + Args: + agent_card: The A2A AgentCard containing security requirements. + auth: User-provided authentication scheme (or None). + + Raises: + A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401). + """ + + if not agent_card.security or not agent_card.security_schemes: + return + + if not auth: + msg = "AgentCard requires authentication but no auth scheme provided" + raise A2AClientHTTPError(401, msg) + + first_security_req = agent_card.security[0] if agent_card.security else {} + + for scheme_name in first_security_req.keys(): + security_scheme_wrapper = agent_card.security_schemes.get(scheme_name) + if not security_scheme_wrapper: + continue + + scheme = security_scheme_wrapper.root + + if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)): + if not isinstance(auth, allowed_classes): + _raise_auth_mismatch(allowed_classes, auth) + return + + if isinstance(scheme, HTTPAuthSecurityScheme): + scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower()) + if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key): + if not isinstance(auth, required_class): + _raise_auth_mismatch(required_class, auth) + return + + msg = "Could not validate auth against AgentCard security requirements" + raise A2AClientHTTPError(401, msg) + + +async def 
retry_on_401( + request_func: Callable[[], Awaitable[Response]], + auth_scheme: ClientAuthScheme | None, + client: AsyncClient, + headers: MutableMapping[str, str], + max_retries: int = 3, +) -> Response: + """Retry a request on 401 authentication error. + + Handles 401 errors by: + 1. Parsing WWW-Authenticate header + 2. Re-acquiring credentials + 3. Retrying the request + + Args: + request_func: Async function that makes the HTTP request. + auth_scheme: Authentication scheme to refresh credentials with. + client: HTTP client for making requests. + headers: Request headers to update with new auth. + max_retries: Maximum number of retry attempts (default: 3). + + Returns: + HTTP response from the request. + + Raises: + httpx.HTTPStatusError: If retries are exhausted or auth scheme is None. + """ + last_response: Response | None = None + last_challenges: dict[str, dict[str, str]] = {} + + for attempt in range(max_retries): + response = await request_func() + + if response.status_code != 401: + return response + + last_response = response + + if auth_scheme is None: + response.raise_for_status() + return response + + www_authenticate = response.headers.get("WWW-Authenticate", "") + challenges = parse_www_authenticate(www_authenticate) + last_challenges = challenges + + if attempt >= max_retries - 1: + break + + backoff_time = 2**attempt + await asyncio.sleep(backoff_time) + + await auth_scheme.apply_auth(client, headers) + + if last_response: + last_response.raise_for_status() + return last_response + + msg = "retry_on_401 failed without making any requests" + if last_challenges: + challenge_info = ", ".join( + f"{scheme} (realm={params.get('realm', 'N/A')})" + for scheme, params in last_challenges.items() + ) + msg = f"{msg}. Server challenges: {challenge_info}" + raise RuntimeError(msg) + + +def configure_auth_client( + auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient +) -> None: + """Configure HTTP client with auth-specific settings. 
+ + Only HTTPDigestAuth and APIKeyAuth need client configuration. + + Args: + auth: Authentication scheme that requires client configuration. + client: HTTP client to configure. + """ + auth.configure_client(client) diff --git a/lib/crewai-a2a/src/crewai_a2a/config.py b/lib/crewai-a2a/src/crewai_a2a/config.py new file mode 100644 index 000000000..58f1dd66f --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/config.py @@ -0,0 +1,690 @@ +"""A2A configuration types. + +This module is separate from experimental.a2a to avoid circular imports. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, ClassVar, Literal, cast +import warnings + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + FilePath, + PrivateAttr, + SecretStr, + model_validator, +) +from typing_extensions import Self, deprecated + +from crewai_a2a.auth.client_schemes import ClientAuthScheme +from crewai_a2a.auth.server_schemes import ServerAuthScheme +from crewai_a2a.extensions.base import ValidatedA2AExtension +from crewai_a2a.types import ProtocolVersion, TransportType, Url + + +try: + from a2a.types import ( + AgentCapabilities, + AgentCardSignature, + AgentInterface, + AgentProvider, + AgentSkill, + SecurityScheme, + ) + + from crewai_a2a.extensions.server import ServerExtension + from crewai_a2a.updates import UpdateConfig +except ImportError: + UpdateConfig: Any = Any # type: ignore[no-redef] + AgentCapabilities: Any = Any # type: ignore[no-redef] + AgentCardSignature: Any = Any # type: ignore[no-redef] + AgentInterface: Any = Any # type: ignore[no-redef] + AgentProvider: Any = Any # type: ignore[no-redef] + SecurityScheme: Any = Any # type: ignore[no-redef] + AgentSkill: Any = Any # type: ignore[no-redef] + ServerExtension: Any = Any # type: ignore[no-redef] + + +def _get_default_update_config() -> UpdateConfig: + from crewai_a2a.updates import StreamingConfig + + return StreamingConfig() + + +SigningAlgorithm = Literal[ + "RS256", + "RS384", + 
"RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", +] + + +class AgentCardSigningConfig(BaseModel): + """Configuration for AgentCard JWS signing. + + Provides the private key and algorithm settings for signing AgentCards. + Either private_key_path or private_key_pem must be provided, but not both. + + Attributes: + private_key_path: Path to a PEM-encoded private key file. + private_key_pem: PEM-encoded private key as a secret string. + key_id: Optional key identifier for the JWS header (kid claim). + algorithm: Signing algorithm (RS256, ES256, PS256, etc.). + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + private_key_path: FilePath | None = Field( + default=None, + description="Path to PEM-encoded private key file", + ) + private_key_pem: SecretStr | None = Field( + default=None, + description="PEM-encoded private key", + ) + key_id: str | None = Field( + default=None, + description="Key identifier for JWS header (kid claim)", + ) + algorithm: SigningAlgorithm = Field( + default="RS256", + description="Signing algorithm (RS256, ES256, PS256, etc.)", + ) + + @model_validator(mode="after") + def _validate_key_source(self) -> Self: + """Ensure exactly one key source is provided.""" + has_path = self.private_key_path is not None + has_pem = self.private_key_pem is not None + + if not has_path and not has_pem: + raise ValueError( + "Either private_key_path or private_key_pem must be provided" + ) + if has_path and has_pem: + raise ValueError( + "Only one of private_key_path or private_key_pem should be provided" + ) + return self + + def get_private_key(self) -> str: + """Get the private key content. + + Returns: + The PEM-encoded private key as a string. 
+ """ + if self.private_key_pem: + return self.private_key_pem.get_secret_value() + if self.private_key_path: + return Path(self.private_key_path).read_text() + raise ValueError("No private key configured") + + +class GRPCServerConfig(BaseModel): + """gRPC server transport configuration. + + Presence of this config in ServerTransportConfig.grpc enables gRPC transport. + + Attributes: + host: Hostname to advertise in agent cards (default: localhost). + Use docker service name (e.g., 'web') for docker-compose setups. + port: Port for the gRPC server. + tls_cert_path: Path to TLS certificate file for gRPC. + tls_key_path: Path to TLS private key file for gRPC. + max_workers: Maximum number of workers for the gRPC thread pool. + reflection_enabled: Whether to enable gRPC reflection for debugging. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + host: str = Field( + default="localhost", + description="Hostname to advertise in agent cards for gRPC connections", + ) + port: int = Field( + default=50051, + description="Port for the gRPC server", + ) + tls_cert_path: str | None = Field( + default=None, + description="Path to TLS certificate file for gRPC", + ) + tls_key_path: str | None = Field( + default=None, + description="Path to TLS private key file for gRPC", + ) + max_workers: int = Field( + default=10, + description="Maximum number of workers for the gRPC thread pool", + ) + reflection_enabled: bool = Field( + default=False, + description="Whether to enable gRPC reflection for debugging", + ) + + +class GRPCClientConfig(BaseModel): + """gRPC client transport configuration. + + Attributes: + max_send_message_length: Maximum size for outgoing messages in bytes. + max_receive_message_length: Maximum size for incoming messages in bytes. + keepalive_time_ms: Time between keepalive pings in milliseconds. + keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds. 
+ """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + max_send_message_length: int | None = Field( + default=None, + description="Maximum size for outgoing messages in bytes", + ) + max_receive_message_length: int | None = Field( + default=None, + description="Maximum size for incoming messages in bytes", + ) + keepalive_time_ms: int | None = Field( + default=None, + description="Time between keepalive pings in milliseconds", + ) + keepalive_timeout_ms: int | None = Field( + default=None, + description="Timeout for keepalive ping response in milliseconds", + ) + + +class JSONRPCServerConfig(BaseModel): + """JSON-RPC server transport configuration. + + Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport. + + Attributes: + rpc_path: URL path for the JSON-RPC endpoint. + agent_card_path: URL path for the agent card endpoint. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + rpc_path: str = Field( + default="/a2a", + description="URL path for the JSON-RPC endpoint", + ) + agent_card_path: str = Field( + default="/.well-known/agent-card.json", + description="URL path for the agent card endpoint", + ) + + +class JSONRPCClientConfig(BaseModel): + """JSON-RPC client transport configuration. + + Attributes: + max_request_size: Maximum request body size in bytes. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + max_request_size: int | None = Field( + default=None, + description="Maximum request body size in bytes", + ) + + +class HTTPJSONConfig(BaseModel): + """HTTP+JSON transport configuration. + + Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class ServerPushNotificationConfig(BaseModel): + """Configuration for outgoing webhook push notifications. + + Controls how the server signs and delivers push notifications to clients. 
+ + Attributes: + signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + signature_secret: SecretStr | None = Field( + default=None, + description="Shared secret for HMAC-SHA256 signing of outgoing push notifications", + ) + + +class ServerTransportConfig(BaseModel): + """Transport configuration for A2A server. + + Groups all transport-related settings including preferred transport + and protocol-specific configurations. + + Attributes: + preferred: Transport protocol for the preferred endpoint. + jsonrpc: JSON-RPC server transport configuration. + grpc: gRPC server transport configuration. + http_json: HTTP+JSON transport configuration. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + preferred: TransportType = Field( + default="JSONRPC", + description="Transport protocol for the preferred endpoint", + ) + jsonrpc: JSONRPCServerConfig = Field( + default_factory=JSONRPCServerConfig, + description="JSON-RPC server transport configuration", + ) + grpc: GRPCServerConfig | None = Field( + default=None, + description="gRPC server transport configuration", + ) + http_json: HTTPJSONConfig | None = Field( + default=None, + description="HTTP+JSON transport configuration", + ) + + +def _migrate_client_transport_fields( + transport: ClientTransportConfig, + transport_protocol: TransportType | None, + supported_transports: list[TransportType] | None, +) -> None: + """Migrate deprecated transport fields to new config.""" + if transport_protocol is not None: + warnings.warn( + "transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead", + FutureWarning, + stacklevel=5, + ) + object.__setattr__(transport, "preferred", transport_protocol) + if supported_transports is not None: + warnings.warn( + "supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) 
instead", + FutureWarning, + stacklevel=5, + ) + object.__setattr__(transport, "supported", supported_transports) + + +class ClientTransportConfig(BaseModel): + """Transport configuration for A2A client. + + Groups all client transport-related settings including preferred transport, + supported transports for negotiation, and protocol-specific configurations. + + Transport negotiation logic: + 1. If `preferred` is set and server supports it → use client's preferred + 2. Otherwise, if server's preferred is in client's `supported` → use server's preferred + 3. Otherwise, find first match from client's `supported` in server's interfaces + + Attributes: + preferred: Client's preferred transport. If set, client preference takes priority. + supported: Transports the client can use, in order of preference. + jsonrpc: JSON-RPC client transport configuration. + grpc: gRPC client transport configuration. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + preferred: TransportType | None = Field( + default=None, + description="Client's preferred transport. If set, takes priority over server preference.", + ) + supported: list[TransportType] = Field( + default_factory=lambda: cast(list[TransportType], ["JSONRPC"]), + description="Transports the client can use, in order of preference", + ) + jsonrpc: JSONRPCClientConfig = Field( + default_factory=JSONRPCClientConfig, + description="JSON-RPC client transport configuration", + ) + grpc: GRPCClientConfig = Field( + default_factory=GRPCClientConfig, + description="gRPC client transport configuration", + ) + + +@deprecated( + """ + `crewai_a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0, + use `crewai_a2a.config.A2AClientConfig` or `crewai_a2a.config.A2AServerConfig` instead. + """, + category=FutureWarning, +) +class A2AConfig(BaseModel): + """Configuration for A2A protocol integration. + + Deprecated: + Use A2AClientConfig instead. This class will be removed in a future version. 

    Attributes:
        endpoint: A2A agent endpoint URL.
        auth: Authentication scheme.
        timeout: Request timeout in seconds.
        max_turns: Maximum conversation turns with A2A agent.
        response_model: Optional Pydantic model for structured A2A agent responses.
        fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
        trust_remote_completion_status: If True, return A2A agent's result directly when completed.
        updates: Update mechanism config.
        client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
        transport: Transport configuration (preferred, supported transports, gRPC settings).
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")

    endpoint: Url = Field(description="A2A agent endpoint URL")
    auth: ClientAuthScheme | None = Field(
        default=None,
        description="Authentication scheme",
    )
    timeout: int = Field(default=120, description="Request timeout in seconds")
    max_turns: int = Field(
        default=10, description="Maximum conversation turns with A2A agent"
    )
    response_model: type[BaseModel] | None = Field(
        default=None,
        description="Optional Pydantic model for structured A2A agent responses",
    )
    fail_fast: bool = Field(
        default=True,
        description="If True, raise error when agent unreachable; if False, skip",
    )
    trust_remote_completion_status: bool = Field(
        default=False,
        description="If True, return A2A result directly when completed",
    )
    updates: UpdateConfig = Field(
        default_factory=_get_default_update_config,
        description="Update mechanism config",
    )
    client_extensions: list[ValidatedA2AExtension] = Field(
        default_factory=list,
        description="Client-side processing hooks for tool injection and prompt augmentation",
    )
    transport: ClientTransportConfig = Field(
        default_factory=ClientTransportConfig,
        description="Transport configuration (preferred, supported transports, gRPC settings)",
    )
    # Deprecated fields below are excluded from serialization and migrated into
    # `transport` by the model validator.
    transport_protocol: TransportType | None = Field(
        default=None,
        description="Deprecated: Use transport.preferred instead",
        exclude=True,
    )
    supported_transports: list[TransportType] | None = Field(
        default=None,
        description="Deprecated: Use transport.supported instead",
        exclude=True,
    )
    use_client_preference: bool | None = Field(
        default=None,
        description="Deprecated: Set transport.preferred to enable client preference",
        exclude=True,
    )
    # NOTE(review): private toggle, presumably read by the delegation layer to
    # enable concurrent delegation — confirm against the wrapper code.
    _parallel_delegation: bool = PrivateAttr(default=False)

    @model_validator(mode="after")
    def _migrate_deprecated_transport_fields(self) -> Self:
        """Migrate deprecated transport fields to new config."""
        _migrate_client_transport_fields(
            self.transport, self.transport_protocol, self.supported_transports
        )
        if self.use_client_preference is not None:
            # stacklevel=4 aims the warning past pydantic's validation frames at
            # the user's config construction site — TODO confirm frame depth.
            warnings.warn(
                "use_client_preference is deprecated, set transport.preferred to enable client preference",
                FutureWarning,
                stacklevel=4,
            )
        if self.use_client_preference and self.transport.supported:
            object.__setattr__(
                self.transport, "preferred", self.transport.supported[0]
            )
        return self


class A2AClientConfig(BaseModel):
    """Configuration for connecting to remote A2A agents.

    Attributes:
        endpoint: A2A agent endpoint URL.
        auth: Authentication scheme.
        timeout: Request timeout in seconds.
        max_turns: Maximum conversation turns with A2A agent.
        response_model: Optional Pydantic model for structured A2A agent responses.
        fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
        trust_remote_completion_status: If True, return A2A agent's result directly when completed.
        updates: Update mechanism config.
        accepted_output_modes: Media types the client can accept in responses.
        extensions: Extension URIs the client supports (A2A protocol extensions).
        client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
        transport: Transport configuration (preferred, supported transports, gRPC settings).
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")

    endpoint: Url = Field(description="A2A agent endpoint URL")
    auth: ClientAuthScheme | None = Field(
        default=None,
        description="Authentication scheme",
    )
    timeout: int = Field(default=120, description="Request timeout in seconds")
    max_turns: int = Field(
        default=10, description="Maximum conversation turns with A2A agent"
    )
    response_model: type[BaseModel] | None = Field(
        default=None,
        description="Optional Pydantic model for structured A2A agent responses",
    )
    fail_fast: bool = Field(
        default=True,
        description="If True, raise error when agent unreachable; if False, skip",
    )
    trust_remote_completion_status: bool = Field(
        default=False,
        description="If True, return A2A result directly when completed",
    )
    updates: UpdateConfig = Field(
        default_factory=_get_default_update_config,
        description="Update mechanism config",
    )
    accepted_output_modes: list[str] = Field(
        default_factory=lambda: ["application/json"],
        description="Media types the client can accept in responses",
    )
    extensions: list[str] = Field(
        default_factory=list,
        description="Extension URIs the client supports",
    )
    client_extensions: list[ValidatedA2AExtension] = Field(
        default_factory=list,
        description="Client-side processing hooks for tool injection and prompt augmentation",
    )
    transport: ClientTransportConfig = Field(
        default_factory=ClientTransportConfig,
        description="Transport configuration (preferred, supported transports, gRPC settings)",
    )
    # Deprecated fields, migrated into `transport` by the validator below.
    transport_protocol: TransportType | None = Field(
        default=None,
        description="Deprecated: Use transport.preferred instead",
        exclude=True,
    )
    supported_transports: list[TransportType] | None = Field(
        default=None,
        description="Deprecated: Use transport.supported instead",
        exclude=True,
    )
    # NOTE(review): see the same private flag on the sibling config class —
    # presumably enables concurrent delegation; confirm against the wrapper.
    _parallel_delegation: bool = PrivateAttr(default=False)

    @model_validator(mode="after")
    def _migrate_deprecated_transport_fields(self) -> Self:
        """Migrate deprecated transport fields to new config."""
        _migrate_client_transport_fields(
            self.transport, self.transport_protocol, self.supported_transports
        )
        return self


class A2AServerConfig(BaseModel):
    """Configuration for exposing a Crew or Agent as an A2A server.

    All fields correspond to A2A AgentCard fields. Fields like name, description,
    and skills can be auto-derived from the Crew/Agent if not provided.

    Attributes:
        name: Human-readable name for the agent.
        description: Human-readable description of the agent.
        version: Version string for the agent card.
        skills: List of agent skills/capabilities.
        default_input_modes: Default supported input MIME types.
        default_output_modes: Default supported output MIME types.
        capabilities: Declaration of optional capabilities.
        protocol_version: A2A protocol version this agent supports.
        provider: Information about the agent's service provider.
        documentation_url: URL to the agent's documentation.
        icon_url: URL to an icon for the agent.
        additional_interfaces: Additional supported interfaces.
        security: Security requirement objects for all interactions.
        security_schemes: Security schemes available to authorize requests.
        supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
        url: Preferred endpoint URL for the agent.
        signing_config: Configuration for signing the AgentCard with JWS.
        signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
        server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
        push_notifications: Configuration for outgoing push notifications.
        transport: Transport configuration (preferred transport, gRPC, REST settings).
        auth: Authentication scheme for A2A endpoints.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")

    name: str | None = Field(
        default=None,
        description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.",
    )
    description: str | None = Field(
        default=None,
        description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.",
    )
    version: str = Field(
        default="1.0.0",
        description="Version string for the agent card",
    )
    skills: list[AgentSkill] = Field(
        default_factory=list,
        description="List of agent skills. Auto-derived from tasks/tools if not provided.",
    )
    default_input_modes: list[str] = Field(
        default_factory=lambda: ["text/plain", "application/json"],
        description="Default supported input MIME types",
    )
    default_output_modes: list[str] = Field(
        default_factory=lambda: ["text/plain", "application/json"],
        description="Default supported output MIME types",
    )
    capabilities: AgentCapabilities = Field(
        default_factory=lambda: AgentCapabilities(
            streaming=True,
            push_notifications=False,
        ),
        description="Declaration of optional capabilities supported by the agent",
    )
    protocol_version: ProtocolVersion = Field(
        default="0.3.0",
        description="A2A protocol version this agent supports",
    )
    provider: AgentProvider | None = Field(
        default=None,
        description="Information about the agent's service provider",
    )
    documentation_url: Url | None = Field(
        default=None,
        description="URL to the agent's documentation",
    )
    icon_url: Url | None = Field(
        default=None,
        description="URL to an icon for the agent",
    )
    additional_interfaces: list[AgentInterface] = Field(
        default_factory=list,
        description="Additional supported interfaces.",
    )
    security: list[dict[str, list[str]]] = Field(
        default_factory=list,
        description="Security requirement objects for all agent interactions",
    )
    security_schemes: dict[str, SecurityScheme] = Field(
        default_factory=dict,
        description="Security schemes available to authorize requests",
    )
    supports_authenticated_extended_card: bool = Field(
        default=False,
        description="Whether agent provides extended card to authenticated users",
    )
    url: Url | None = Field(
        default=None,
        description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
    )
    signing_config: AgentCardSigningConfig | None = Field(
        default=None,
        description="Configuration for signing the AgentCard with JWS",
    )
    signatures: list[AgentCardSignature] | None = Field(
        default=None,
        description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
        exclude=True,
        deprecated=True,
    )
    server_extensions: list[ServerExtension] = Field(
        default_factory=list,
        description="Server-side A2A protocol extensions that modify agent behavior",
    )
    push_notifications: ServerPushNotificationConfig | None = Field(
        default=None,
        description="Configuration for outgoing push notifications",
    )
    transport: ServerTransportConfig = Field(
        default_factory=ServerTransportConfig,
        description="Transport configuration (preferred transport, gRPC, REST settings)",
    )
    preferred_transport: TransportType | None = Field(
        default=None,
        description="Deprecated: Use transport.preferred instead",
        exclude=True,
        deprecated=True,
    )
    auth: ServerAuthScheme | None = Field(
        default=None,
        description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
    )

    @model_validator(mode="after")
    def _migrate_deprecated_fields(self) -> Self:
        """Migrate deprecated fields to new config."""
        if self.preferred_transport is not None:
            warnings.warn(
                "preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) 
instead",
                FutureWarning,
                stacklevel=4,
            )
            object.__setattr__(self.transport, "preferred", self.preferred_transport)
        if self.signatures is not None:
            warnings.warn(
                "signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
                "The signatures field will be removed in v2.0.0.",
                FutureWarning,
                stacklevel=4,
            )
        return self
diff --git a/lib/crewai-a2a/src/crewai_a2a/errors.py b/lib/crewai-a2a/src/crewai_a2a/errors.py
new file mode 100644
index 000000000..aabe10288
--- /dev/null
+++ b/lib/crewai-a2a/src/crewai_a2a/errors.py
@@ -0,0 +1,491 @@
"""A2A error codes and error response utilities.

This module provides a centralized mapping of all A2A protocol error codes
as defined in the A2A specification, plus custom CrewAI extensions.

Error codes follow JSON-RPC 2.0 conventions:
- -32700 to -32600: Standard JSON-RPC errors
- -32000 to -32099: Implementation-defined server errors (A2A-specific codes
  and the CrewAI extension codes both live in this range)
- -32768 to -32100: Reserved by JSON-RPC 2.0 for pre-defined errors (not used
  for custom codes)
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any

from a2a.client.errors import A2AClientTimeoutError


class A2APollingTimeoutError(A2AClientTimeoutError):
    """Raised when polling exceeds the configured timeout."""


class A2AErrorCode(IntEnum):
    """A2A protocol error codes.

    Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
    """

    # JSON-RPC 2.0 Standard Errors (-32700 to -32600)
    JSON_PARSE_ERROR = -32700
    """Invalid JSON was received by the server."""

    INVALID_REQUEST = -32600
    """The JSON sent is not a valid Request object."""

    METHOD_NOT_FOUND = -32601
    """The method does not exist / is not available."""

    INVALID_PARAMS = -32602
    """Invalid method parameter(s)."""

    INTERNAL_ERROR = -32603
    """Internal JSON-RPC error."""

    # A2A-Specific Errors (-32099 to -32000)
    TASK_NOT_FOUND = -32001
    """The specified task was not found."""

    TASK_NOT_CANCELABLE = -32002
    """The task cannot be canceled (already completed/failed)."""

    PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
    """Push notifications are not supported by this agent."""

    UNSUPPORTED_OPERATION = -32004
    """The requested operation is not supported."""

    CONTENT_TYPE_NOT_SUPPORTED = -32005
    """Incompatible content types between client and server."""

    INVALID_AGENT_RESPONSE = -32006
    """The agent produced an invalid response."""

    # CrewAI Custom Extensions (-32009 to -32018). These sit inside the
    # JSON-RPC implementation-defined server-error range (-32000 to -32099),
    # kept clear of the codes the A2A spec already assigns above.
    UNSUPPORTED_VERSION = -32009
    """The requested A2A protocol version is not supported."""

    UNSUPPORTED_EXTENSION = -32010
    """Client does not support required protocol extensions."""

    AUTHENTICATION_REQUIRED = -32011
    """Authentication is required for this operation."""

    AUTHORIZATION_FAILED = -32012
    """Authorization check failed (insufficient permissions)."""

    RATE_LIMIT_EXCEEDED = -32013
    """Rate limit exceeded for this client/operation."""

    TASK_TIMEOUT = -32014
    """Task execution timed out."""

    TRANSPORT_NEGOTIATION_FAILED = -32015
    """Failed to negotiate a compatible transport protocol."""

    CONTEXT_NOT_FOUND = -32016
    """The specified context was not found."""

    SKILL_NOT_FOUND = -32017
    """The specified skill was not found."""

    ARTIFACT_NOT_FOUND = -32018
    """The specified artifact was not found."""


# Error code to default message mapping
ERROR_MESSAGES: dict[int, str] = 
{
    A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
    A2AErrorCode.INVALID_REQUEST: "Invalid Request",
    A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
    A2AErrorCode.INVALID_PARAMS: "Invalid params",
    A2AErrorCode.INTERNAL_ERROR: "Internal error",
    A2AErrorCode.TASK_NOT_FOUND: "Task not found",
    A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
    A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
    A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
    A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
    A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
    A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
    A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
    A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
    A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
    A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
    A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
    A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
    A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
    A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
    A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
}


@dataclass
class A2AError(Exception):
    """Base exception for A2A protocol errors.

    Attributes:
        code: The A2A/JSON-RPC error code.
        message: Human-readable error message.
        data: Optional additional error data.
    """

    code: int
    message: str | None = None
    data: Any = None

    def __post_init__(self) -> None:
        # Fall back to the spec default message for this code, then initialize
        # Exception.args so str(exc) shows the message.
        if self.message is None:
            self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
        super().__init__(self.message)

    def to_dict(self) -> dict[str, Any]:
        """Convert to JSON-RPC error object format."""
        error: dict[str, Any] = {
            "code": self.code,
            "message": self.message,
        }
        # `data` is optional per JSON-RPC; omit the key entirely when unset.
        if self.data is not None:
            error["data"] = self.data
        return error

    def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
        """Convert to full JSON-RPC error response."""
        return {
            "jsonrpc": "2.0",
            "error": self.to_dict(),
            "id": request_id,
        }


@dataclass
class JSONParseError(A2AError):
    """Invalid JSON was received."""

    code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)


@dataclass
class InvalidRequestError(A2AError):
    """The JSON sent is not a valid Request object."""

    code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)


@dataclass
class MethodNotFoundError(A2AError):
    """The method does not exist / is not available."""

    code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
    method: str | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.method:
            self.message = f"Method not found: {self.method}"
        super().__post_init__()


@dataclass
class InvalidParamsError(A2AError):
    """Invalid method parameter(s)."""

    code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
    param: str | None = None
    reason: str | None = None

    def __post_init__(self) -> None:
        if self.message is None:
            if self.param and self.reason:
                self.message = f"Invalid parameter '{self.param}': {self.reason}"
            elif self.param:
                self.message = f"Invalid parameter: {self.param}"
        super().__post_init__()


@dataclass
class InternalError(A2AError):
    """Internal JSON-RPC error."""

    code: int = 
field(default=A2AErrorCode.INTERNAL_ERROR, init=False)


# Subclasses below share one pattern: `code` is pinned via field(init=False),
# and __post_init__ synthesizes a contextual default message before
# A2AError.__post_init__ applies the spec fallback and Exception init.
@dataclass
class TaskNotFoundError(A2AError):
    """The specified task was not found."""

    code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
    task_id: str | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.task_id:
            self.message = f"Task not found: {self.task_id}"
        super().__post_init__()


@dataclass
class TaskNotCancelableError(A2AError):
    """The task cannot be canceled."""

    code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
    task_id: str | None = None
    reason: str | None = None

    def __post_init__(self) -> None:
        if self.message is None:
            if self.task_id and self.reason:
                self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
            elif self.task_id:
                self.message = f"Task {self.task_id} cannot be canceled"
        super().__post_init__()


@dataclass
class PushNotificationNotSupportedError(A2AError):
    """Push notifications are not supported."""

    code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)


@dataclass
class UnsupportedOperationError(A2AError):
    """The requested operation is not supported."""

    code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
    operation: str | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.operation:
            self.message = f"Operation not supported: {self.operation}"
        super().__post_init__()


@dataclass
class ContentTypeNotSupportedError(A2AError):
    """Incompatible content types."""

    code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
    requested_types: list[str] | None = None
    supported_types: list[str] | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.requested_types and self.supported_types:
            self.message = (
                f"Content type not supported. Requested: {self.requested_types}, "
                f"Supported: {self.supported_types}"
            )
        super().__post_init__()


@dataclass
class InvalidAgentResponseError(A2AError):
    """The agent produced an invalid response."""

    code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)


@dataclass
class UnsupportedVersionError(A2AError):
    """The requested A2A version is not supported."""

    code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
    requested_version: str | None = None
    supported_versions: list[str] | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.requested_version:
            msg = f"Unsupported A2A version: {self.requested_version}"
            if self.supported_versions:
                msg += f". Supported versions: {', '.join(self.supported_versions)}"
            self.message = msg
        super().__post_init__()


@dataclass
class UnsupportedExtensionError(A2AError):
    """Client does not support required extensions."""

    code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
    required_extensions: list[str] | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.required_extensions:
            self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
        super().__post_init__()


@dataclass
class AuthenticationRequiredError(A2AError):
    """Authentication is required."""

    code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)


@dataclass
class AuthorizationFailedError(A2AError):
    """Authorization check failed."""

    code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
    required_scope: str | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.required_scope:
            self.message = (
                f"Authorization failed. 
Required scope: {self.required_scope}" + ) + super().__post_init__() + + +@dataclass +class RateLimitExceededError(A2AError): + """Rate limit exceeded.""" + + code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False) + retry_after: int | None = None + + def __post_init__(self) -> None: + if self.message is None and self.retry_after: + self.message = ( + f"Rate limit exceeded. Retry after {self.retry_after} seconds" + ) + if self.retry_after: + self.data = {"retry_after": self.retry_after} + super().__post_init__() + + +@dataclass +class TaskTimeoutError(A2AError): + """Task execution timed out.""" + + code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False) + task_id: str | None = None + timeout_seconds: float | None = None + + def __post_init__(self) -> None: + if self.message is None: + if self.task_id and self.timeout_seconds: + self.message = ( + f"Task {self.task_id} timed out after {self.timeout_seconds}s" + ) + elif self.task_id: + self.message = f"Task {self.task_id} timed out" + super().__post_init__() + + +@dataclass +class TransportNegotiationFailedError(A2AError): + """Failed to negotiate a compatible transport protocol.""" + + code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False) + client_transports: list[str] | None = None + server_transports: list[str] | None = None + + def __post_init__(self) -> None: + if self.message is None and self.client_transports and self.server_transports: + self.message = ( + f"Transport negotiation failed. 
Client: {self.client_transports}, "
                f"Server: {self.server_transports}"
            )
        super().__post_init__()


@dataclass
class ContextNotFoundError(A2AError):
    """The specified context was not found."""

    code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
    context_id: str | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.context_id:
            self.message = f"Context not found: {self.context_id}"
        super().__post_init__()


@dataclass
class SkillNotFoundError(A2AError):
    """The specified skill was not found."""

    code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
    skill_id: str | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.skill_id:
            self.message = f"Skill not found: {self.skill_id}"
        super().__post_init__()


@dataclass
class ArtifactNotFoundError(A2AError):
    """The specified artifact was not found."""

    code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
    artifact_id: str | None = None

    def __post_init__(self) -> None:
        if self.message is None and self.artifact_id:
            self.message = f"Artifact not found: {self.artifact_id}"
        super().__post_init__()


def create_error_response(
    code: int | A2AErrorCode,
    message: str | None = None,
    data: Any = None,
    request_id: str | int | None = None,
) -> dict[str, Any]:
    """Create a JSON-RPC error response.

    Args:
        code: Error code (A2AErrorCode or int).
        message: Optional error message (uses default if not provided).
        data: Optional additional error data.
        request_id: Request ID for correlation.

    Returns:
        Dict in JSON-RPC error response format.
    """
    error = A2AError(code=int(code), message=message, data=data)
    return error.to_response(request_id)


def is_retryable_error(code: int) -> bool:
    """Check if an error is potentially retryable.

    Args:
        code: Error code to check.

    Returns:
        True if the error might be resolved by retrying.
    """
    retryable_codes = {
        A2AErrorCode.INTERNAL_ERROR,
        A2AErrorCode.RATE_LIMIT_EXCEEDED,
        A2AErrorCode.TASK_TIMEOUT,
    }
    # IntEnum members compare equal to plain ints, so a raw int code works here.
    return code in retryable_codes


def is_client_error(code: int) -> bool:
    """Check if an error is a client-side error.

    Args:
        code: Error code to check.

    Returns:
        True if the error is due to client request issues.
    """
    client_error_codes = {
        A2AErrorCode.JSON_PARSE_ERROR,
        A2AErrorCode.INVALID_REQUEST,
        A2AErrorCode.METHOD_NOT_FOUND,
        A2AErrorCode.INVALID_PARAMS,
        A2AErrorCode.TASK_NOT_FOUND,
        A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
        A2AErrorCode.UNSUPPORTED_VERSION,
        A2AErrorCode.UNSUPPORTED_EXTENSION,
        A2AErrorCode.CONTEXT_NOT_FOUND,
        A2AErrorCode.SKILL_NOT_FOUND,
        A2AErrorCode.ARTIFACT_NOT_FOUND,
    }
    return code in client_error_codes
diff --git a/lib/crewai-a2a/src/crewai_a2a/extensions/__init__.py b/lib/crewai-a2a/src/crewai_a2a/extensions/__init__.py
new file mode 100644
index 000000000..ef066ca96
--- /dev/null
+++ b/lib/crewai-a2a/src/crewai_a2a/extensions/__init__.py
@@ -0,0 +1,37 @@
"""A2A Protocol Extensions for CrewAI.

This module contains extensions to the A2A (Agent-to-Agent) protocol.

**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
processes requests and responses during delegation to remote agents. These provide
hooks for tool injection, prompt augmentation, and response processing.

**Server-side extensions** (ServerExtension) allow agents to offer additional
functionality beyond the core A2A specification. Clients activate extensions
via the X-A2A-Extensions header.

See: https://a2a-protocol.org/latest/topics/extensions/
"""

from crewai_a2a.extensions.base import (
    A2AExtension,
    ConversationState,
    ExtensionRegistry,
    ValidatedA2AExtension,
)
from crewai_a2a.extensions.server import (
    ExtensionContext,
    ServerExtension,
    ServerExtensionRegistry,
)


__all__ = [
    "A2AExtension",
    "ConversationState",
    "ExtensionContext",
    "ExtensionRegistry",
    "ServerExtension",
    "ServerExtensionRegistry",
    "ValidatedA2AExtension",
]
diff --git a/lib/crewai-a2a/src/crewai_a2a/extensions/base.py b/lib/crewai-a2a/src/crewai_a2a/extensions/base.py
new file mode 100644
index 000000000..f3b3201b1
--- /dev/null
+++ b/lib/crewai-a2a/src/crewai_a2a/extensions/base.py
@@ -0,0 +1,237 @@
"""Base extension interface for CrewAI A2A wrapper processing hooks.

This module defines the protocol for extending CrewAI's A2A wrapper functionality
with custom logic for tool injection, prompt augmentation, and response processing.

Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
A2A protocol extensions are capability declarations using AgentExtension objects
in AgentCard.capabilities.extensions, activated via the X-A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable

from pydantic import BeforeValidator


if TYPE_CHECKING:
    from a2a.types import Message
    from crewai.agent.core import Agent


def _validate_a2a_extension(v: Any) -> Any:
    """Validate that value implements A2AExtension protocol."""
    # A2AExtension is @runtime_checkable, so isinstance checks method presence
    # (structural typing), not inheritance.
    if not isinstance(v, A2AExtension):
        raise ValueError(
            f"Value must implement A2AExtension protocol. "
            f"Got {type(v).__name__} which is missing required methods."
        )
    return v


ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]


@runtime_checkable
class ConversationState(Protocol):
    """Protocol for extension-specific conversation state.

    Extensions can define their own state classes that implement this protocol
    to track conversation-specific data extracted from message history.
    """

    def is_ready(self) -> bool:
        """Check if the state indicates readiness for some action.

        Returns:
            True if the state is ready, False otherwise.
        """
        ...


@runtime_checkable
class A2AExtension(Protocol):
    """Protocol for A2A wrapper extensions.

    Extensions can implement this protocol to inject custom logic into
    the A2A conversation flow at various integration points.

    Example:
        class MyExtension:
            def inject_tools(self, agent: Agent) -> None:
                # Add custom tools to the agent
                pass

            def extract_state_from_history(
                self, conversation_history: Sequence[Message]
            ) -> ConversationState | None:
                # Extract state from conversation
                return None

            def augment_prompt(
                self, base_prompt: str, conversation_state: ConversationState | None
            ) -> str:
                # Add custom instructions
                return base_prompt

            def process_response(
                self, agent_response: Any, conversation_state: ConversationState | None
            ) -> Any:
                # Modify response if needed
                return agent_response
    """

    def inject_tools(self, agent: Agent) -> None:
        """Inject extension-specific tools into the agent.

        Called when an agent is wrapped with A2A capabilities. Extensions
        can add tools that enable extension-specific functionality.

        Args:
            agent: The agent instance to inject tools into.
        """
        ...

    def extract_state_from_history(
        self, conversation_history: Sequence[Message]
    ) -> ConversationState | None:
        """Extract extension-specific state from conversation history.

        Called during prompt augmentation to allow extensions to analyze
        the conversation history and extract relevant state information.

        Args:
            conversation_history: The sequence of A2A messages exchanged.

        Returns:
            Extension-specific conversation state, or None if no relevant state.
        """
        ...

    def augment_prompt(
        self,
        base_prompt: str,
        conversation_state: ConversationState | None,
    ) -> str:
        """Augment the task prompt with extension-specific instructions.

        Called during prompt augmentation to allow extensions to add
        custom instructions based on conversation state.

        Args:
            base_prompt: The base prompt to augment.
            conversation_state: Extension-specific state from extract_state_from_history.

        Returns:
            The augmented prompt with extension-specific instructions.
        """
        ...

    def process_response(
        self,
        agent_response: Any,
        conversation_state: ConversationState | None,
    ) -> Any:
        """Process and potentially modify the agent response.

        Called after parsing the agent's response, allowing extensions to
        enhance or modify the response based on conversation state.

        Args:
            agent_response: The parsed agent response.
            conversation_state: Extension-specific state from extract_state_from_history.

        Returns:
            The processed agent response (may be modified or original).
        """
        ...


class ExtensionRegistry:
    """Registry for managing A2A extensions.

    Maintains a collection of extensions and provides methods to invoke
    their hooks at various integration points. Hooks run in registration
    order.
    """

    def __init__(self) -> None:
        """Initialize the extension registry."""
        self._extensions: list[A2AExtension] = []

    def register(self, extension: A2AExtension) -> None:
        """Register an extension.

        Args:
            extension: The extension to register.
        """
        self._extensions.append(extension)

    def inject_all_tools(self, agent: Agent) -> None:
        """Inject tools from all registered extensions.

        Args:
            agent: The agent instance to inject tools into.
        """
        for extension in self._extensions:
            extension.inject_tools(agent)

    def extract_all_states(
        self, conversation_history: Sequence[Message]
    ) -> dict[type[A2AExtension], ConversationState]:
        """Extract conversation states from all registered extensions.

        Args:
            conversation_history: The sequence of A2A messages exchanged.

        Returns:
            Mapping of extension types to their conversation states.
        """
        # Keyed by extension *type*: if two instances of the same class are
        # registered, the later one's state wins.
        states: dict[type[A2AExtension], ConversationState] = {}
        for extension in self._extensions:
            state = extension.extract_state_from_history(conversation_history)
            if state is not None:
                states[type(extension)] = state
        return states

    def augment_prompt_with_all(
        self,
        base_prompt: str,
        extension_states: dict[type[A2AExtension], ConversationState],
    ) -> str:
        """Augment prompt with instructions from all registered extensions.

        Args:
            base_prompt: The base prompt to augment.
            extension_states: Mapping of extension types to conversation states.

        Returns:
            The fully augmented prompt.
        """
        # Each extension sees the output of the previous one (pipeline style).
        augmented = base_prompt
        for extension in self._extensions:
            state = extension_states.get(type(extension))
            augmented = extension.augment_prompt(augmented, state)
        return augmented

    def process_response_with_all(
        self,
        agent_response: Any,
        extension_states: dict[type[A2AExtension], ConversationState],
    ) -> Any:
        """Process response through all registered extensions.

        Args:
            agent_response: The parsed agent response.
            extension_states: Mapping of extension types to conversation states.

        Returns:
            The processed agent response.
+ """ + processed = agent_response + for extension in self._extensions: + state = extension_states.get(type(extension)) + processed = extension.process_response(processed, state) + return processed diff --git a/lib/crewai-a2a/src/crewai_a2a/extensions/registry.py b/lib/crewai-a2a/src/crewai_a2a/extensions/registry.py new file mode 100644 index 000000000..9403de2a0 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/extensions/registry.py @@ -0,0 +1,170 @@ +"""A2A Protocol extension utilities. + +This module provides utilities for working with A2A protocol extensions as +defined in the A2A specification. Extensions are capability declarations in +AgentCard.capabilities.extensions using AgentExtension objects, activated +via the X-A2A-Extensions HTTP header. + +See: https://a2a-protocol.org/latest/topics/extensions/ +""" + +from __future__ import annotations + +from typing import Any + +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, +) +from a2a.types import AgentCard, AgentExtension + +from crewai_a2a.config import A2AClientConfig, A2AConfig +from crewai_a2a.extensions.base import ExtensionRegistry + + +def get_extensions_from_config( + a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig, +) -> list[str]: + """Extract extension URIs from A2A configuration. + + Args: + a2a_config: A2A configuration (single or list). + + Returns: + Deduplicated list of extension URIs from all configs. + """ + configs = a2a_config if isinstance(a2a_config, list) else [a2a_config] + seen: set[str] = set() + result: list[str] = [] + + for config in configs: + if not isinstance(config, A2AClientConfig): + continue + for uri in config.extensions: + if uri not in seen: + seen.add(uri) + result.append(uri) + + return result + + +class ExtensionsMiddleware(ClientCallInterceptor): + """Middleware to add X-A2A-Extensions header to requests. 
def __init__(self, extensions: list[str]) -> None:
    """Store the extension URIs this client declares support for.

    Args:
        extensions: Extension URIs to advertise on every outgoing request.
    """
    self._extensions = extensions


async def intercept(
    self,
    method_name: str,
    request_payload: dict[str, Any],
    http_kwargs: dict[str, Any],
    agent_card: AgentCard | None,
    context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Attach the X-A2A-Extensions header to an outgoing request.

    Args:
        method_name: The A2A method being invoked.
        request_payload: The JSON-RPC request payload (returned unmodified).
        http_kwargs: HTTP request kwargs; headers are added in place.
        agent_card: The target agent's card, if resolved.
        context: Optional call context.

    Returns:
        Tuple of the unchanged payload and the http kwargs, with the
        extensions header set when any extensions are declared.
    """
    # Guard clause: nothing to declare, pass the call through untouched.
    if not self._extensions:
        return request_payload, http_kwargs

    http_kwargs.setdefault("headers", {})[HTTP_EXTENSION_HEADER] = ",".join(
        self._extensions
    )
    return request_payload, http_kwargs
+ """ + unsupported: list[AgentExtension] = [] + client_set = set(client_extensions or []) + + if not agent_card.capabilities or not agent_card.capabilities.extensions: + return unsupported + + unsupported.extend( + ext + for ext in agent_card.capabilities.extensions + if ext.required and ext.uri not in client_set + ) + + return unsupported + + +def create_extension_registry_from_config( + a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig, +) -> ExtensionRegistry: + """Create an extension registry from A2A client configuration. + + Extracts client_extensions from each A2AClientConfig and registers them + with the ExtensionRegistry. These extensions provide CrewAI-specific + processing hooks (tool injection, prompt augmentation, response processing). + + Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header) + are handled separately via get_extensions_from_config() and ExtensionsMiddleware. + + Args: + a2a_config: A2A configuration (single or list). + + Returns: + Extension registry with all client_extensions registered. 
def create_extension_registry_from_config(
    a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> ExtensionRegistry:
    """Build an ExtensionRegistry from the client_extensions of each config.

    Each ``client_extensions`` entry of every ``A2AConfig`` /
    ``A2AClientConfig`` is registered exactly once (deduplicated by object
    identity). These extensions provide CrewAI-specific processing hooks
    (tool injection, prompt augmentation, response processing). A2A
    *protocol* extensions (URI strings sent via the X-A2A-Extensions
    header) are handled separately by get_extensions_from_config() and
    ExtensionsMiddleware.

    Args:
        a2a_config: A2A configuration (single or list).

    Returns:
        Extension registry with all client_extensions registered.

    Example:
        config = A2AClientConfig(
            endpoint="https://agent.example.com",
            client_extensions=[LoggingExtension()],
        )
        registry = create_extension_registry_from_config(config)
    """
    registry = ExtensionRegistry()
    configs = [a2a_config] if not isinstance(a2a_config, list) else a2a_config

    # Deduplicate by object identity so a shared extension instance that
    # appears in several configs is only registered once.
    registered: set[int] = set()
    for cfg in configs:
        if not isinstance(cfg, (A2AConfig, A2AClientConfig)):
            continue
        for ext in getattr(cfg, "client_extensions", []):
            if id(ext) in registered:
                continue
            registered.add(id(ext))
            registry.register(ext)

    return registry
@dataclass
class ExtensionContext:
    """Per-request context handed to extension hooks.

    Gives extensions access to request metadata, the client's declared
    extensions, and a mutable scratch space shared across hooks for the
    lifetime of one request.

    Attributes:
        metadata: Request metadata dict, including extension-namespaced keys.
        client_extensions: Extension URIs the client declared support for.
        state: Mutable dict extensions use to share data during the request.
        server_context: The underlying A2A server call context, if any.
    """

    metadata: dict[str, Any]
    client_extensions: set[str]
    state: dict[str, Any] = field(default_factory=dict)
    server_context: ServerCallContext | None = None

    def get_extension_metadata(self, uri: str, key: str) -> Any | None:
        """Read a namespaced metadata value.

        Extension metadata uses keys of the form "{extension_uri}/{key}".

        Args:
            uri: The extension URI namespace.
            key: The metadata key within that namespace.

        Returns:
            The stored value, or None when absent.
        """
        return self.metadata.get(f"{uri}/{key}")

    def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
        """Write a namespaced metadata value.

        Args:
            uri: The extension URI namespace.
            key: The metadata key within that namespace.
            value: The value to store.
        """
        self.metadata[f"{uri}/{key}"] = value
+ """ + full_key = f"{uri}/{key}" + self.metadata[full_key] = value + + +class ServerExtension(ABC): + """Base class for A2A protocol server extensions. + + Subclass this to create custom extensions that modify agent behavior + when clients activate them. Extensions are identified by URI and can + be marked as required. + + Example: + class SamplingExtension(ServerExtension): + uri = "urn:crewai:ext:sampling/v1" + required = True + + def __init__(self, max_tokens: int = 4096): + self.max_tokens = max_tokens + + @property + def params(self) -> dict[str, Any]: + return {"max_tokens": self.max_tokens} + + async def on_request(self, context: ExtensionContext) -> None: + limit = context.get_extension_metadata(self.uri, "limit") + if limit: + context.state["token_limit"] = int(limit) + + async def on_response(self, context: ExtensionContext, result: Any) -> Any: + return result + """ + + uri: Annotated[str, "Extension URI identifier. Must be unique."] + required: Annotated[bool, "Whether clients must support this extension."] = False + description: Annotated[ + str | None, "Human-readable description of the extension." + ] = None + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> CoreSchema: + """Tell Pydantic how to validate ServerExtension instances.""" + return core_schema.is_instance_schema(cls) + + @property + def params(self) -> dict[str, Any] | None: + """Extension parameters to advertise in AgentCard. + + Override this property to expose configuration that clients can read. + + Returns: + Dict of parameter names to values, or None. + """ + return None + + def agent_extension(self) -> AgentExtension: + """Generate the AgentExtension object for the AgentCard. + + Returns: + AgentExtension with this extension's URI, required flag, and params. 
+ """ + return AgentExtension( + uri=self.uri, + required=self.required if self.required else None, + description=self.description, + params=self.params, + ) + + def is_active(self, context: ExtensionContext) -> bool: + """Check if this extension is active for the current request. + + An extension is active if the client declared support for it. + + Args: + context: The extension context for the current request. + + Returns: + True if the client supports this extension. + """ + return self.uri in context.client_extensions + + @abstractmethod + async def on_request(self, context: ExtensionContext) -> None: + """Called before agent execution if extension is active. + + Use this hook to: + - Read extension-specific metadata from the request + - Set up state for the execution + - Modify execution parameters via context.state + + Args: + context: The extension context with request metadata and state. + """ + ... + + @abstractmethod + async def on_response(self, context: ExtensionContext, result: Any) -> Any: + """Called after agent execution if extension is active. + + Use this hook to: + - Modify or enhance the result + - Add extension-specific metadata to the response + - Clean up any resources + + Args: + context: The extension context with request metadata and state. + result: The agent execution result. + + Returns: + The result, potentially modified. + """ + ... + + +class ServerExtensionRegistry: + """Registry for managing server-side A2A protocol extensions. + + Collects extensions and provides methods to generate AgentCapabilities + and invoke extension hooks during request processing. + """ + + def __init__(self, extensions: list[ServerExtension] | None = None) -> None: + """Initialize the registry with optional extensions. + + Args: + extensions: Initial list of extensions to register. 
+ """ + self._extensions: list[ServerExtension] = list(extensions) if extensions else [] + self._by_uri: dict[str, ServerExtension] = { + ext.uri: ext for ext in self._extensions + } + + def register(self, extension: ServerExtension) -> None: + """Register an extension. + + Args: + extension: The extension to register. + + Raises: + ValueError: If an extension with the same URI is already registered. + """ + if extension.uri in self._by_uri: + raise ValueError(f"Extension already registered: {extension.uri}") + self._extensions.append(extension) + self._by_uri[extension.uri] = extension + + def get_agent_extensions(self) -> list[AgentExtension]: + """Get AgentExtension objects for all registered extensions. + + Returns: + List of AgentExtension objects for the AgentCard. + """ + return [ext.agent_extension() for ext in self._extensions] + + def get_extension(self, uri: str) -> ServerExtension | None: + """Get an extension by URI. + + Args: + uri: The extension URI. + + Returns: + The extension, or None if not found. + """ + return self._by_uri.get(uri) + + @staticmethod + def create_context( + metadata: dict[str, Any], + client_extensions: set[str], + server_context: ServerCallContext | None = None, + ) -> ExtensionContext: + """Create an ExtensionContext for a request. + + Args: + metadata: Request metadata dict. + client_extensions: Set of extension URIs from client. + server_context: Optional server call context. + + Returns: + ExtensionContext for use in hooks. + """ + return ExtensionContext( + metadata=metadata, + client_extensions=client_extensions, + server_context=server_context, + ) + + async def invoke_on_request(self, context: ExtensionContext) -> None: + """Invoke on_request hooks for all active extensions. + + Tracks activated extensions and isolates errors from individual hooks. + + Args: + context: The extension context for the request. 
+ """ + for extension in self._extensions: + if extension.is_active(context): + try: + await extension.on_request(context) + if context.server_context is not None: + context.server_context.activated_extensions.add(extension.uri) + except Exception: + logger.exception( + "Extension on_request hook failed", + extra={"extension": extension.uri}, + ) + + async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any: + """Invoke on_response hooks for all active extensions. + + Isolates errors from individual hooks to prevent one failing extension + from breaking the entire response. + + Args: + context: The extension context for the request. + result: The agent execution result. + + Returns: + The result after all extensions have processed it. + """ + processed = result + for extension in self._extensions: + if extension.is_active(context): + try: + processed = await extension.on_response(context, processed) + except Exception: + logger.exception( + "Extension on_response hook failed", + extra={"extension": extension.uri}, + ) + return processed diff --git a/lib/crewai-a2a/src/crewai_a2a/py.typed b/lib/crewai-a2a/src/crewai_a2a/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/lib/crewai-a2a/src/crewai_a2a/task_helpers.py b/lib/crewai-a2a/src/crewai_a2a/task_helpers.py new file mode 100644 index 000000000..43ea61be2 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/task_helpers.py @@ -0,0 +1,479 @@ +"""Helper functions for processing A2A task results.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any, TypedDict +import uuid + +from a2a.client.errors import A2AClientHTTPError +from a2a.types import ( + AgentCard, + Message, + Part, + Role, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatusUpdateEvent, + TextPart, +) +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2AConnectionErrorEvent, + 
A2AResponseReceivedEvent, +) +from typing_extensions import NotRequired + + +if TYPE_CHECKING: + from a2a.types import Task as A2ATask + +SendMessageEvent = ( + tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message +) + + +TERMINAL_STATES: frozenset[TaskState] = frozenset( + { + TaskState.completed, + TaskState.failed, + TaskState.rejected, + TaskState.canceled, + } +) + +ACTIONABLE_STATES: frozenset[TaskState] = frozenset( + { + TaskState.input_required, + TaskState.auth_required, + } +) + +PENDING_STATES: frozenset[TaskState] = frozenset( + { + TaskState.submitted, + TaskState.working, + } +) + + +class TaskStateResult(TypedDict): + """Result dictionary from processing A2A task state.""" + + status: TaskState + history: list[Message] + result: NotRequired[str] + error: NotRequired[str] + agent_card: NotRequired[dict[str, Any]] + a2a_agent_name: NotRequired[str | None] + + +def extract_task_result_parts(a2a_task: A2ATask) -> list[str]: + """Extract result parts from A2A task status message, history, and artifacts. + + Args: + a2a_task: A2A Task object with status, history, and artifacts + + Returns: + List of result text parts + """ + result_parts: list[str] = [] + + if a2a_task.status and a2a_task.status.message: + msg = a2a_task.status.message + result_parts.extend( + part.root.text for part in msg.parts if part.root.kind == "text" + ) + + if not result_parts and a2a_task.history: + for history_msg in reversed(a2a_task.history): + if history_msg.role == Role.agent: + result_parts.extend( + part.root.text + for part in history_msg.parts + if part.root.kind == "text" + ) + break + + if a2a_task.artifacts: + result_parts.extend( + part.root.text + for artifact in a2a_task.artifacts + for part in artifact.parts + if part.root.kind == "text" + ) + + return result_parts + + +def extract_error_message(a2a_task: A2ATask, default: str) -> str: + """Extract error message from A2A task. 
def extract_error_message(a2a_task: A2ATask, default: str) -> str:
    """Best-effort extraction of a human-readable error from a task.

    Checks the status message's text parts first (falling back to the
    message object's string form when it has no text part), then any text
    part in the history starting from the newest message, and finally the
    supplied default.

    Args:
        a2a_task: The A2A task to inspect.
        default: Message returned when the task carries no text at all.

    Returns:
        The extracted or default error message.
    """
    status = a2a_task.status
    if status and status.message:
        message = status.message
        for part in message.parts:
            if part.root.kind == "text":
                return str(part.root.text)
        # No text part on the status message: use its string form.
        return str(message)

    for history_msg in reversed(a2a_task.history or []):
        for part in history_msg.parts:
            if part.root.kind == "text":
                return str(part.root.text)

    return default
+ """ + if result_parts is None: + result_parts = [] + + if a2a_task.status.state == TaskState.completed: + if not result_parts: + extracted_parts = extract_task_result_parts(a2a_task) + result_parts.extend(extracted_parts) + if a2a_task.history: + new_messages.extend(a2a_task.history) + + response_text = " ".join(result_parts) if result_parts else "" + message_id = None + if a2a_task.status and a2a_task.status.message: + message_id = a2a_task.status.message.message_id + crewai_event_bus.emit( + None, + A2AResponseReceivedEvent( + response=response_text, + turn_number=turn_number, + context_id=a2a_task.context_id, + message_id=message_id, + is_multiturn=is_multiturn, + status="completed", + final=is_final, + agent_role=agent_role, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + from_task=from_task, + from_agent=from_agent, + ), + ) + + return TaskStateResult( + status=TaskState.completed, + agent_card=agent_card.model_dump(exclude_none=True), + result=response_text, + history=new_messages, + ) + + if a2a_task.status.state == TaskState.input_required: + if a2a_task.history: + new_messages.extend(a2a_task.history) + + response_text = extract_error_message(a2a_task, "Additional input required") + if response_text and not a2a_task.history: + agent_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=response_text))], + context_id=a2a_task.context_id, + task_id=a2a_task.id, + ) + new_messages.append(agent_message) + + input_message_id = None + if a2a_task.status and a2a_task.status.message: + input_message_id = a2a_task.status.message.message_id + crewai_event_bus.emit( + None, + A2AResponseReceivedEvent( + response=response_text, + turn_number=turn_number, + context_id=a2a_task.context_id, + message_id=input_message_id, + is_multiturn=is_multiturn, + status="input_required", + final=is_final, + agent_role=agent_role, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + from_task=from_task, + 
from_agent=from_agent, + ), + ) + + return TaskStateResult( + status=TaskState.input_required, + error=response_text, + history=new_messages, + agent_card=agent_card.model_dump(exclude_none=True), + ) + + if a2a_task.status.state in {TaskState.failed, TaskState.rejected}: + error_msg = extract_error_message(a2a_task, "Task failed without error message") + if a2a_task.history: + new_messages.extend(a2a_task.history) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + if a2a_task.status.state == TaskState.auth_required: + error_msg = extract_error_message(a2a_task, "Authentication required") + return TaskStateResult( + status=TaskState.auth_required, + error=error_msg, + history=new_messages, + ) + + if a2a_task.status.state == TaskState.canceled: + error_msg = extract_error_message(a2a_task, "Task was canceled") + return TaskStateResult( + status=TaskState.canceled, + error=error_msg, + history=new_messages, + ) + + if a2a_task.status.state in PENDING_STATES: + return None + + return None + + +async def send_message_and_get_task_id( + event_stream: AsyncIterator[SendMessageEvent], + new_messages: list[Message], + agent_card: AgentCard, + turn_number: int, + is_multiturn: bool, + agent_role: str | None, + from_task: Any | None = None, + from_agent: Any | None = None, + endpoint: str | None = None, + a2a_agent_name: str | None = None, + context_id: str | None = None, +) -> str | TaskStateResult: + """Send message and process initial response. 
async def send_message_and_get_task_id(
    event_stream: AsyncIterator[SendMessageEvent],
    new_messages: list[Message],
    agent_card: AgentCard,
    turn_number: int,
    is_multiturn: bool,
    agent_role: str | None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    endpoint: str | None = None,
    a2a_agent_name: str | None = None,
    context_id: str | None = None,
) -> str | TaskStateResult:
    """Send message and process initial response.

    Handles the common pattern of sending a message and either:
    - Getting an immediate Message response (task completed synchronously)
    - Getting a Task that needs polling/waiting for completion

    Args:
        event_stream: Async iterator from client.send_message()
        new_messages: List to collect messages (modified in place)
        agent_card: The agent card
        turn_number: Current turn number
        is_multiturn: Whether multi-turn conversation
        agent_role: Agent role for logging
        from_task: Optional CrewAI Task object for event metadata.
        from_agent: Optional CrewAI Agent object for event metadata.
        endpoint: Optional A2A endpoint URL.
        a2a_agent_name: Optional A2A agent name.
        context_id: Optional A2A context ID for correlation.

    Returns:
        Task ID string if agent needs polling/waiting, or TaskStateResult if done.
    """
    try:
        async for event in event_stream:
            # Case 1: the agent answered synchronously with a plain Message.
            if isinstance(event, Message):
                new_messages.append(event)
                result_parts = [
                    part.root.text for part in event.parts if part.root.kind == "text"
                ]
                response_text = " ".join(result_parts) if result_parts else ""

                crewai_event_bus.emit(
                    None,
                    A2AResponseReceivedEvent(
                        response=response_text,
                        turn_number=turn_number,
                        context_id=event.context_id,
                        message_id=event.message_id,
                        is_multiturn=is_multiturn,
                        status="completed",
                        final=True,
                        agent_role=agent_role,
                        endpoint=endpoint,
                        a2a_agent_name=a2a_agent_name,
                        from_task=from_task,
                        from_agent=from_agent,
                    ),
                )

                return TaskStateResult(
                    status=TaskState.completed,
                    result=response_text,
                    history=new_messages,
                    agent_card=agent_card.model_dump(exclude_none=True),
                )

            # Case 2: the agent created a Task; the stream yields
            # (Task, optional update event) tuples.
            if isinstance(event, tuple):
                a2a_task, _ = event

                # Task may already be in a reportable state; resolve it now.
                if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
                    result = process_task_state(
                        a2a_task=a2a_task,
                        new_messages=new_messages,
                        agent_card=agent_card,
                        turn_number=turn_number,
                        is_multiturn=is_multiturn,
                        agent_role=agent_role,
                        endpoint=endpoint,
                        a2a_agent_name=a2a_agent_name,
                        from_task=from_task,
                        from_agent=from_agent,
                    )
                    if result:
                        return result

                # Pending task: hand back the ID so the caller can poll/stream.
                return a2a_task.id

        # Stream ended without producing a Message or a Task.
        return TaskStateResult(
            status=TaskState.failed,
            error="No task ID received from initial message",
            history=new_messages,
        )

    except A2AClientHTTPError as e:
        error_msg = f"HTTP Error {e.status_code}: {e!s}"

        # Record the failure as an agent message so the conversation history
        # reflects what happened.
        error_message = Message(
            role=Role.agent,
            message_id=str(uuid.uuid4()),
            parts=[Part(root=TextPart(text=error_msg))],
            context_id=context_id,
        )
        new_messages.append(error_message)

        crewai_event_bus.emit(
            None,
            A2AConnectionErrorEvent(
                endpoint=endpoint or "",
                error=str(e),
                error_type="http_error",
                status_code=e.status_code,
                a2a_agent_name=a2a_agent_name,
                operation="send_message",
                context_id=context_id,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        crewai_event_bus.emit(
            None,
            A2AResponseReceivedEvent(
                response=error_msg,
                turn_number=turn_number,
                context_id=context_id,
                is_multiturn=is_multiturn,
                status="failed",
                final=True,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        return TaskStateResult(
            status=TaskState.failed,
            error=error_msg,
            history=new_messages,
        )

    except Exception as e:
        error_msg = f"Unexpected error during send_message: {e!s}"

        error_message = Message(
            role=Role.agent,
            message_id=str(uuid.uuid4()),
            parts=[Part(root=TextPart(text=error_msg))],
            context_id=context_id,
        )
        new_messages.append(error_message)

        crewai_event_bus.emit(
            None,
            A2AConnectionErrorEvent(
                endpoint=endpoint or "",
                error=str(e),
                error_type="unexpected_error",
                a2a_agent_name=a2a_agent_name,
                operation="send_message",
                context_id=context_id,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        crewai_event_bus.emit(
            None,
            A2AResponseReceivedEvent(
                response=error_msg,
                turn_number=turn_number,
                context_id=context_id,
                is_multiturn=is_multiturn,
                status="failed",
                final=True,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        return TaskStateResult(
            status=TaskState.failed,
            error=error_msg,
            history=new_messages,
        )

    finally:
        # Close the underlying stream if it supports it (duck-typed:
        # generators expose aclose(), plain iterators may not).
        aclose = getattr(event_stream, "aclose", None)
        if aclose:
            await aclose()
Set is_a2a=false (the remote task is complete and cannot receive more messages) +2. Provide YOUR OWN response to the original task based on the information received + +IMPORTANT: Your response should be addressed to the USER who gave you the original task. +Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that..."). +Do NOT address the remote agent directly or use "you" to refer to them. + +""" diff --git a/lib/crewai-a2a/src/crewai_a2a/types.py b/lib/crewai-a2a/src/crewai_a2a/types.py new file mode 100644 index 000000000..8826f67d6 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/types.py @@ -0,0 +1,104 @@ +"""Type definitions for A2A protocol message parts.""" + +from __future__ import annotations + +from typing import ( + Annotated, + Any, + Literal, + Protocol, + TypedDict, + runtime_checkable, +) + +from pydantic import BeforeValidator, HttpUrl, TypeAdapter +from typing_extensions import NotRequired + + +try: + from crewai_a2a.updates import ( + PollingConfig, + PollingHandler, + PushNotificationConfig, + PushNotificationHandler, + StreamingConfig, + StreamingHandler, + UpdateConfig, + ) +except ImportError: + PollingConfig = Any # type: ignore[misc,assignment] + PollingHandler = Any # type: ignore[misc,assignment] + PushNotificationConfig = Any # type: ignore[misc,assignment] + PushNotificationHandler = Any # type: ignore[misc,assignment] + StreamingConfig = Any # type: ignore[misc,assignment] + StreamingHandler = Any # type: ignore[misc,assignment] + UpdateConfig = Any # type: ignore[misc,assignment] + + +TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"] +ProtocolVersion = Literal[ + "0.2.0", + "0.2.1", + "0.2.2", + "0.2.3", + "0.2.4", + "0.2.5", + "0.2.6", + "0.3.0", + "0.4.0", +] + +http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl) + +Url = Annotated[ + str, + BeforeValidator( + lambda value: str(http_url_adapter.validate_python(value, strict=True)) + ), +] + + +@runtime_checkable 
+class AgentResponseProtocol(Protocol): + """Protocol for the dynamically created AgentResponse model.""" + + a2a_ids: tuple[str, ...] + message: str + is_a2a: bool + + +class PartsMetadataDict(TypedDict, total=False): + """Metadata for A2A message parts. + + Attributes: + mimeType: MIME type for the part content. + schema: JSON schema for the part content. + """ + + mimeType: Literal["application/json"] + schema: dict[str, Any] + + +class PartsDict(TypedDict): + """A2A message part containing text and optional metadata. + + Attributes: + text: The text content of the message part. + metadata: Optional metadata describing the part content. + """ + + text: str + metadata: NotRequired[PartsMetadataDict] + + +PollingHandlerType = type[PollingHandler] +StreamingHandlerType = type[StreamingHandler] +PushNotificationHandlerType = type[PushNotificationHandler] + +HandlerType = PollingHandlerType | StreamingHandlerType | PushNotificationHandlerType + +HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = { + PollingConfig: PollingHandler, + StreamingConfig: StreamingHandler, + PushNotificationConfig: PushNotificationHandler, +} diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/__init__.py b/lib/crewai-a2a/src/crewai_a2a/updates/__init__.py new file mode 100644 index 000000000..5c9f98c05 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/__init__.py @@ -0,0 +1,35 @@ +"""A2A update mechanism configuration types.""" + +from crewai_a2a.updates.base import ( + BaseHandlerKwargs, + PollingHandlerKwargs, + PushNotificationHandlerKwargs, + PushNotificationResultStore, + StreamingHandlerKwargs, + UpdateHandler, +) +from crewai_a2a.updates.polling.config import PollingConfig +from crewai_a2a.updates.polling.handler import PollingHandler +from crewai_a2a.updates.push_notifications.config import PushNotificationConfig +from crewai_a2a.updates.push_notifications.handler import PushNotificationHandler +from crewai_a2a.updates.streaming.config import StreamingConfig +from 
crewai_a2a.updates.streaming.handler import StreamingHandler + + +UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig + +__all__ = [ + "BaseHandlerKwargs", + "PollingConfig", + "PollingHandler", + "PollingHandlerKwargs", + "PushNotificationConfig", + "PushNotificationHandler", + "PushNotificationHandlerKwargs", + "PushNotificationResultStore", + "StreamingConfig", + "StreamingHandler", + "StreamingHandlerKwargs", + "UpdateConfig", + "UpdateHandler", +] diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/base.py b/lib/crewai-a2a/src/crewai_a2a/updates/base.py new file mode 100644 index 000000000..ea44ebc10 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/base.py @@ -0,0 +1,176 @@ +"""Base types for A2A update mechanism handlers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict + +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema + + +class CommonParams(NamedTuple): + """Common parameters shared across all update handlers. + + Groups the frequently-passed parameters to reduce duplication. 
+ """ + + turn_number: int + is_multiturn: bool + agent_role: str | None + endpoint: str + a2a_agent_name: str | None + context_id: str | None + from_task: Any + from_agent: Any + + +if TYPE_CHECKING: + from a2a.client import Client + from a2a.types import AgentCard, Message, Task + + from crewai_a2a.task_helpers import TaskStateResult + from crewai_a2a.updates.push_notifications.config import PushNotificationConfig + + +class BaseHandlerKwargs(TypedDict, total=False): + """Base kwargs shared by all handlers.""" + + turn_number: int + is_multiturn: bool + agent_role: str | None + context_id: str | None + task_id: str | None + endpoint: str | None + agent_branch: Any + a2a_agent_name: str | None + from_task: Any + from_agent: Any + + +class PollingHandlerKwargs(BaseHandlerKwargs, total=False): + """Kwargs for polling handler.""" + + polling_interval: float + polling_timeout: float + history_length: int + max_polls: int | None + + +class StreamingHandlerKwargs(BaseHandlerKwargs, total=False): + """Kwargs for streaming handler.""" + + +class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False): + """Kwargs for push notification handler.""" + + config: PushNotificationConfig + result_store: PushNotificationResultStore + polling_timeout: float + polling_interval: float + + +class PushNotificationResultStore(Protocol): + """Protocol for storing and retrieving push notification results. + + This protocol defines the interface for a result store that the + PushNotificationHandler uses to wait for task completion. + """ + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> CoreSchema: + return core_schema.any_schema() + + async def wait_for_result( + self, + task_id: str, + timeout: float, + poll_interval: float = 1.0, + ) -> Task | None: + """Wait for a task result to be available. + + Args: + task_id: The task ID to wait for. + timeout: Max seconds to wait before returning None. 
+ poll_interval: Seconds between polling attempts. + + Returns: + The completed Task object, or None if timeout. + """ + ... + + async def get_result(self, task_id: str) -> Task | None: + """Get a task result if available. + + Args: + task_id: The task ID to retrieve. + + Returns: + The Task object if available, None otherwise. + """ + ... + + async def store_result(self, task: Task) -> None: + """Store a task result. + + Args: + task: The Task object to store. + """ + ... + + +class UpdateHandler(Protocol): + """Protocol for A2A update mechanism handlers.""" + + @staticmethod + async def execute( + client: Client, + message: Message, + new_messages: list[Message], + agent_card: AgentCard, + **kwargs: Any, + ) -> TaskStateResult: + """Execute the update mechanism and return result. + + Args: + client: A2A client instance. + message: Message to send. + new_messages: List to collect messages (modified in place). + agent_card: The agent card. + **kwargs: Additional handler-specific parameters. + + Returns: + Result dictionary with status, result/error, and history. + """ + ... + + +def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams: + """Extract common parameters from handler kwargs. + + Args: + kwargs: Handler kwargs dict. + + Returns: + CommonParams with extracted values. + + Raises: + ValueError: If endpoint is not provided. 
+ """ + endpoint = kwargs.get("endpoint") + if endpoint is None: + raise ValueError("endpoint is required for update handlers") + + return CommonParams( + turn_number=kwargs.get("turn_number", 0), + is_multiturn=kwargs.get("is_multiturn", False), + agent_role=kwargs.get("agent_role"), + endpoint=endpoint, + a2a_agent_name=kwargs.get("a2a_agent_name"), + context_id=kwargs.get("context_id"), + from_task=kwargs.get("from_task"), + from_agent=kwargs.get("from_agent"), + ) diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/polling/__init__.py b/lib/crewai-a2a/src/crewai_a2a/updates/polling/__init__.py new file mode 100644 index 000000000..7199db700 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/polling/__init__.py @@ -0,0 +1 @@ +"""Polling update mechanism module.""" diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/polling/config.py b/lib/crewai-a2a/src/crewai_a2a/updates/polling/config.py new file mode 100644 index 000000000..1dcf970a6 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/polling/config.py @@ -0,0 +1,25 @@ +"""Polling update mechanism configuration.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class PollingConfig(BaseModel): + """Configuration for polling-based task updates. + + Attributes: + interval: Seconds between poll attempts. + timeout: Max seconds to poll before raising timeout error. + max_polls: Max number of poll attempts. + history_length: Number of messages to retrieve per poll. 
+ """ + + interval: float = Field( + default=2.0, gt=0, description="Seconds between poll attempts" + ) + timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll") + max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts") + history_length: int = Field( + default=100, gt=0, description="Messages to retrieve per poll" + ) diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/polling/handler.py b/lib/crewai-a2a/src/crewai_a2a/updates/polling/handler.py new file mode 100644 index 000000000..30fbf90e5 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/polling/handler.py @@ -0,0 +1,359 @@ +"""Polling update mechanism handler.""" + +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING, Any +import uuid + +from a2a.client import Client +from a2a.client.errors import A2AClientHTTPError +from a2a.types import ( + AgentCard, + Message, + Part, + Role, + TaskQueryParams, + TaskState, + TextPart, +) +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2AConnectionErrorEvent, + A2APollingStartedEvent, + A2APollingStatusEvent, + A2AResponseReceivedEvent, +) +from typing_extensions import Unpack + +from crewai_a2a.errors import A2APollingTimeoutError +from crewai_a2a.task_helpers import ( + ACTIONABLE_STATES, + TERMINAL_STATES, + TaskStateResult, + process_task_state, + send_message_and_get_task_id, +) +from crewai_a2a.updates.base import PollingHandlerKwargs + + +if TYPE_CHECKING: + from a2a.types import Task as A2ATask + + +async def _poll_task_until_complete( + client: Client, + task_id: str, + polling_interval: float, + polling_timeout: float, + agent_branch: Any | None = None, + history_length: int = 100, + max_polls: int | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + context_id: str | None = None, + endpoint: str | None = None, + a2a_agent_name: str | None = None, +) -> A2ATask: + """Poll 
task status until terminal state reached. + + Args: + client: A2A client instance. + task_id: Task ID to poll. + polling_interval: Seconds between poll attempts. + polling_timeout: Max seconds before timeout. + agent_branch: Agent tree branch for logging. + history_length: Number of messages to retrieve per poll. + max_polls: Max number of poll attempts (None = unlimited). + from_task: Optional CrewAI Task object for event metadata. + from_agent: Optional CrewAI Agent object for event metadata. + context_id: A2A context ID for correlation. + endpoint: A2A agent endpoint URL. + a2a_agent_name: Name of the A2A agent from agent card. + + Returns: + Final task object in terminal state. + + Raises: + A2APollingTimeoutError: If polling exceeds timeout or max_polls. + """ + start_time = time.monotonic() + poll_count = 0 + + while True: + poll_count += 1 + task = await client.get_task( + TaskQueryParams(id=task_id, history_length=history_length) + ) + + elapsed = time.monotonic() - start_time + effective_context_id = task.context_id or context_id + crewai_event_bus.emit( + agent_branch, + A2APollingStatusEvent( + task_id=task_id, + context_id=effective_context_id, + state=str(task.status.state.value), + elapsed_seconds=elapsed, + poll_count=poll_count, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + from_task=from_task, + from_agent=from_agent, + ), + ) + + if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES: + return task + + if elapsed > polling_timeout: + raise A2APollingTimeoutError( + f"Polling timeout after {polling_timeout}s ({poll_count} polls)" + ) + + if max_polls and poll_count >= max_polls: + raise A2APollingTimeoutError( + f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s" + ) + + await asyncio.sleep(polling_interval) + + +class PollingHandler: + """Polling-based update handler.""" + + @staticmethod + async def execute( + client: Client, + message: Message, + new_messages: list[Message], + agent_card: AgentCard, + **kwargs: 
Unpack[PollingHandlerKwargs], + ) -> TaskStateResult: + """Execute A2A delegation using polling for updates. + + Args: + client: A2A client instance. + message: Message to send. + new_messages: List to collect messages. + agent_card: The agent card. + **kwargs: Polling-specific parameters. + + Returns: + Dictionary with status, result/error, and history. + """ + polling_interval = kwargs.get("polling_interval", 2.0) + polling_timeout = kwargs.get("polling_timeout", 300.0) + endpoint = kwargs.get("endpoint", "") + agent_branch = kwargs.get("agent_branch") + turn_number = kwargs.get("turn_number", 0) + is_multiturn = kwargs.get("is_multiturn", False) + agent_role = kwargs.get("agent_role") + history_length = kwargs.get("history_length", 100) + max_polls = kwargs.get("max_polls") + context_id = kwargs.get("context_id") + task_id = kwargs.get("task_id") + a2a_agent_name = kwargs.get("a2a_agent_name") + from_task = kwargs.get("from_task") + from_agent = kwargs.get("from_agent") + + try: + result_or_task_id = await send_message_and_get_task_id( + event_stream=client.send_message(message), + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + context_id=context_id, + ) + + if not isinstance(result_or_task_id, str): + return result_or_task_id + + task_id = result_or_task_id + + crewai_event_bus.emit( + agent_branch, + A2APollingStartedEvent( + task_id=task_id, + context_id=context_id, + polling_interval=polling_interval, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + from_task=from_task, + from_agent=from_agent, + ), + ) + + final_task = await _poll_task_until_complete( + client=client, + task_id=task_id, + polling_interval=polling_interval, + polling_timeout=polling_timeout, + agent_branch=agent_branch, + history_length=history_length, + max_polls=max_polls, + from_task=from_task, 
+ from_agent=from_agent, + context_id=context_id, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + ) + + result = process_task_state( + a2a_task=final_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + from_task=from_task, + from_agent=from_agent, + ) + if result: + return result + + return TaskStateResult( + status=TaskState.failed, + error=f"Unexpected task state: {final_task.status.state}", + history=new_messages, + ) + + except A2APollingTimeoutError as e: + error_msg = str(e) + + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=turn_number, + context_id=context_id, + is_multiturn=is_multiturn, + status="failed", + final=True, + agent_role=agent_role, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + from_task=from_task, + from_agent=from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + except A2AClientHTTPError as e: + error_msg = f"HTTP Error {e.status_code}: {e!s}" + + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=endpoint, + error=str(e), + error_type="http_error", + status_code=e.status_code, + a2a_agent_name=a2a_agent_name, + operation="polling", + context_id=context_id, + task_id=task_id, + from_task=from_task, + from_agent=from_agent, + ), + ) + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + 
response=error_msg, + turn_number=turn_number, + context_id=context_id, + is_multiturn=is_multiturn, + status="failed", + final=True, + agent_role=agent_role, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + from_task=from_task, + from_agent=from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + except Exception as e: + error_msg = f"Unexpected error during polling: {e!s}" + + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=endpoint, + error=str(e), + error_type="unexpected_error", + a2a_agent_name=a2a_agent_name, + operation="polling", + context_id=context_id, + task_id=task_id, + from_task=from_task, + from_agent=from_agent, + ), + ) + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=turn_number, + context_id=context_id, + is_multiturn=is_multiturn, + status="failed", + final=True, + agent_role=agent_role, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + from_task=from_task, + from_agent=from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/__init__.py b/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/__init__.py new file mode 100644 index 000000000..abb3c2f23 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/__init__.py @@ -0,0 +1 @@ +"""Push notification update mechanism module.""" diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/config.py b/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/config.py new file mode 100644 index 000000000..d0f486cf0 --- /dev/null +++ 
b/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/config.py @@ -0,0 +1,65 @@ +"""Push notification update mechanism configuration.""" + +from __future__ import annotations + +from typing import Annotated + +from a2a.types import PushNotificationAuthenticationInfo +from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field + +from crewai_a2a.updates.base import PushNotificationResultStore +from crewai_a2a.updates.push_notifications.signature import WebhookSignatureConfig + + +def _coerce_signature( + value: str | WebhookSignatureConfig | None, +) -> WebhookSignatureConfig | None: + """Convert string secret to WebhookSignatureConfig.""" + if value is None: + return None + if isinstance(value, str): + return WebhookSignatureConfig.hmac_sha256(secret=value) + return value + + +SignatureInput = Annotated[ + WebhookSignatureConfig | None, + BeforeValidator(_coerce_signature), +] + + +class PushNotificationConfig(BaseModel): + """Configuration for webhook-based task updates. + + Attributes: + url: Callback URL where agent sends push notifications. + id: Unique identifier for this config. + token: Token to validate incoming notifications. + authentication: Auth info for agent to use when calling webhook. + timeout: Max seconds to wait for task completion. + interval: Seconds between result polling attempts. + result_store: Store for receiving push notification results. + signature: HMAC signature config. Pass a string (secret) for defaults, + or WebhookSignatureConfig for custom settings. 
+ """ + + url: AnyHttpUrl = Field(description="Callback URL for push notifications") + id: str | None = Field(default=None, description="Unique config identifier") + token: str | None = Field(default=None, description="Validation token") + authentication: PushNotificationAuthenticationInfo | None = Field( + default=None, description="Auth info for agent to use when calling webhook" + ) + timeout: float | None = Field( + default=300.0, gt=0, description="Max seconds to wait for task completion" + ) + interval: float = Field( + default=2.0, gt=0, description="Seconds between result polling attempts" + ) + result_store: PushNotificationResultStore | None = Field( + default=None, description="Result store for push notification handling" + ) + signature: SignatureInput = Field( + default=None, + description="HMAC signature config. Pass a string (secret) for simple usage, " + "or WebhookSignatureConfig for custom headers/tolerance.", + ) diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/handler.py b/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/handler.py new file mode 100644 index 000000000..19add538b --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/handler.py @@ -0,0 +1,354 @@ +"""Push notification (webhook) update mechanism handler.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any +import uuid + +from a2a.client import Client +from a2a.client.errors import A2AClientHTTPError +from a2a.types import ( + AgentCard, + Message, + Part, + Role, + TaskState, + TextPart, +) +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2AConnectionErrorEvent, + A2APushNotificationRegisteredEvent, + A2APushNotificationTimeoutEvent, + A2AResponseReceivedEvent, +) +from typing_extensions import Unpack + +from crewai_a2a.task_helpers import ( + TaskStateResult, + process_task_state, + send_message_and_get_task_id, +) +from 
crewai_a2a.updates.base import ( + CommonParams, + PushNotificationHandlerKwargs, + PushNotificationResultStore, + extract_common_params, +) + + +if TYPE_CHECKING: + from a2a.types import Task as A2ATask + +logger = logging.getLogger(__name__) + + +def _handle_push_error( + error: Exception, + error_msg: str, + error_type: str, + new_messages: list[Message], + agent_branch: Any | None, + params: CommonParams, + task_id: str | None, + status_code: int | None = None, +) -> TaskStateResult: + """Handle push notification errors with consistent event emission. + + Args: + error: The exception that occurred. + error_msg: Formatted error message for the result. + error_type: Type of error for the event. + new_messages: List to append error message to. + agent_branch: Agent tree branch for events. + params: Common handler parameters. + task_id: A2A task ID. + status_code: HTTP status code if applicable. + + Returns: + TaskStateResult with failed status. + """ + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=params.context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=params.endpoint, + error=str(error), + error_type=error_type, + status_code=status_code, + a2a_agent_name=params.a2a_agent_name, + operation="push_notification", + context_id=params.context_id, + task_id=task_id, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=params.turn_number, + context_id=params.context_id, + is_multiturn=params.is_multiturn, + status="failed", + final=True, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + 
error=error_msg, + history=new_messages, + ) + + +async def _wait_for_push_result( + task_id: str, + result_store: PushNotificationResultStore, + timeout: float, + poll_interval: float, + agent_branch: Any | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + context_id: str | None = None, + endpoint: str | None = None, + a2a_agent_name: str | None = None, +) -> A2ATask | None: + """Wait for push notification result. + + Args: + task_id: Task ID to wait for. + result_store: Store to retrieve results from. + timeout: Max seconds to wait. + poll_interval: Seconds between polling attempts. + agent_branch: Agent tree branch for logging. + from_task: Optional CrewAI Task object for event metadata. + from_agent: Optional CrewAI Agent object for event metadata. + context_id: A2A context ID for correlation. + endpoint: A2A agent endpoint URL. + a2a_agent_name: Name of the A2A agent. + + Returns: + Final task object, or None if timeout. + """ + task = await result_store.wait_for_result( + task_id=task_id, + timeout=timeout, + poll_interval=poll_interval, + ) + + if task is None: + crewai_event_bus.emit( + agent_branch, + A2APushNotificationTimeoutEvent( + task_id=task_id, + context_id=context_id, + timeout_seconds=timeout, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + from_task=from_task, + from_agent=from_agent, + ), + ) + + return task + + +class PushNotificationHandler: + """Push notification (webhook) based update handler.""" + + @staticmethod + async def execute( + client: Client, + message: Message, + new_messages: list[Message], + agent_card: AgentCard, + **kwargs: Unpack[PushNotificationHandlerKwargs], + ) -> TaskStateResult: + """Execute A2A delegation using push notifications for updates. + + Args: + client: A2A client instance. + message: Message to send. + new_messages: List to collect messages. + agent_card: The agent card. + **kwargs: Push notification-specific parameters. 
+ + Returns: + Dictionary with status, result/error, and history. + + Raises: + ValueError: If result_store or config not provided. + """ + config = kwargs.get("config") + result_store = kwargs.get("result_store") + polling_timeout = kwargs.get("polling_timeout", 300.0) + polling_interval = kwargs.get("polling_interval", 2.0) + agent_branch = kwargs.get("agent_branch") + task_id = kwargs.get("task_id") + params = extract_common_params(kwargs) + + if config is None: + error_msg = ( + "PushNotificationConfig is required for push notification handler" + ) + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=params.endpoint, + error=error_msg, + error_type="configuration_error", + a2a_agent_name=params.a2a_agent_name, + operation="push_notification", + context_id=params.context_id, + task_id=task_id, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + if result_store is None: + error_msg = ( + "PushNotificationResultStore is required for push notification handler" + ) + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=params.endpoint, + error=error_msg, + error_type="configuration_error", + a2a_agent_name=params.a2a_agent_name, + operation="push_notification", + context_id=params.context_id, + task_id=task_id, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + try: + result_or_task_id = await send_message_and_get_task_id( + event_stream=client.send_message(message), + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + from_task=params.from_task, + from_agent=params.from_agent, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + 
context_id=params.context_id, + ) + + if not isinstance(result_or_task_id, str): + return result_or_task_id + + task_id = result_or_task_id + + crewai_event_bus.emit( + agent_branch, + A2APushNotificationRegisteredEvent( + task_id=task_id, + context_id=params.context_id, + callback_url=str(config.url), + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + + logger.debug( + "Push notification callback for task %s configured at %s (via initial request)", + task_id, + config.url, + ) + + final_task = await _wait_for_push_result( + task_id=task_id, + result_store=result_store, + timeout=polling_timeout, + poll_interval=polling_interval, + agent_branch=agent_branch, + from_task=params.from_task, + from_agent=params.from_agent, + context_id=params.context_id, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + ) + + if final_task is None: + return TaskStateResult( + status=TaskState.failed, + error=f"Push notification timeout after {polling_timeout}s", + history=new_messages, + ) + + result = process_task_state( + a2a_task=final_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ) + if result: + return result + + return TaskStateResult( + status=TaskState.failed, + error=f"Unexpected task state: {final_task.status.state}", + history=new_messages, + ) + + except A2AClientHTTPError as e: + return _handle_push_error( + error=e, + error_msg=f"HTTP Error {e.status_code}: {e!s}", + error_type="http_error", + new_messages=new_messages, + agent_branch=agent_branch, + params=params, + task_id=task_id, + status_code=e.status_code, + ) + + except Exception as e: + return _handle_push_error( + error=e, + error_msg=f"Unexpected error during 
push notification: {e!s}", + error_type="unexpected_error", + new_messages=new_messages, + agent_branch=agent_branch, + params=params, + task_id=task_id, + ) diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/signature.py b/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/signature.py new file mode 100644 index 000000000..9cac929ec --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/push_notifications/signature.py @@ -0,0 +1,87 @@ +"""Webhook signature configuration for push notifications.""" + +from __future__ import annotations + +from enum import Enum +import secrets + +from pydantic import BaseModel, Field, SecretStr + + +class WebhookSignatureMode(str, Enum): + """Signature mode for webhook push notifications.""" + + NONE = "none" + HMAC_SHA256 = "hmac_sha256" + + +class WebhookSignatureConfig(BaseModel): + """Configuration for webhook signature verification. + + Provides cryptographic integrity verification and replay attack protection + for A2A push notifications. + + Attributes: + mode: Signature mode (none or hmac_sha256). + secret: Shared secret for HMAC computation (required for hmac_sha256 mode). + timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection. + header_name: HTTP header name for the signature. + timestamp_header_name: HTTP header name for the timestamp. 
+ """ + + mode: WebhookSignatureMode = Field( + default=WebhookSignatureMode.NONE, + description="Signature verification mode", + ) + secret: SecretStr | None = Field( + default=None, + description="Shared secret for HMAC computation", + ) + timestamp_tolerance_seconds: int = Field( + default=300, + ge=0, + description="Max allowed timestamp age in seconds (5 min default)", + ) + header_name: str = Field( + default="X-A2A-Signature", + description="HTTP header name for the signature", + ) + timestamp_header_name: str = Field( + default="X-A2A-Signature-Timestamp", + description="HTTP header name for the timestamp", + ) + + @classmethod + def generate_secret(cls, length: int = 32) -> str: + """Generate a cryptographically secure random secret. + + Args: + length: Number of random bytes to generate (default 32). + + Returns: + URL-safe base64-encoded secret string. + """ + return secrets.token_urlsafe(length) + + @classmethod + def hmac_sha256( + cls, + secret: str | SecretStr, + timestamp_tolerance_seconds: int = 300, + ) -> WebhookSignatureConfig: + """Create an HMAC-SHA256 signature configuration. + + Args: + secret: Shared secret for HMAC computation. + timestamp_tolerance_seconds: Max allowed timestamp age in seconds. + + Returns: + Configured WebhookSignatureConfig for HMAC-SHA256. 
+ """ + if isinstance(secret, str): + secret = SecretStr(secret) + return cls( + mode=WebhookSignatureMode.HMAC_SHA256, + secret=secret, + timestamp_tolerance_seconds=timestamp_tolerance_seconds, + ) diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/streaming/__init__.py b/lib/crewai-a2a/src/crewai_a2a/updates/streaming/__init__.py new file mode 100644 index 000000000..7adada8b5 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/streaming/__init__.py @@ -0,0 +1 @@ +"""Streaming update mechanism module.""" diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/streaming/config.py b/lib/crewai-a2a/src/crewai_a2a/updates/streaming/config.py new file mode 100644 index 000000000..6098bf550 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/streaming/config.py @@ -0,0 +1,9 @@ +"""Streaming update mechanism configuration.""" + +from __future__ import annotations + +from pydantic import BaseModel + + +class StreamingConfig(BaseModel): + """Configuration for SSE-based task updates.""" diff --git a/lib/crewai-a2a/src/crewai_a2a/updates/streaming/handler.py b/lib/crewai-a2a/src/crewai_a2a/updates/streaming/handler.py new file mode 100644 index 000000000..fc5abcbcd --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/updates/streaming/handler.py @@ -0,0 +1,646 @@ +"""Streaming (SSE) update mechanism handler.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Final +import uuid + +from a2a.client import Client +from a2a.client.errors import A2AClientHTTPError +from a2a.types import ( + AgentCard, + Message, + Part, + Role, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskQueryParams, + TaskState, + TaskStatusUpdateEvent, + TextPart, +) +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2AArtifactReceivedEvent, + A2AConnectionErrorEvent, + A2AResponseReceivedEvent, + A2AStreamingChunkEvent, + A2AStreamingStartedEvent, +) +from typing_extensions import Unpack + +from 
crewai_a2a.task_helpers import ( + ACTIONABLE_STATES, + TERMINAL_STATES, + TaskStateResult, + process_task_state, +) +from crewai_a2a.updates.base import StreamingHandlerKwargs, extract_common_params +from crewai_a2a.updates.streaming.params import ( + process_status_update, +) + + +logger = logging.getLogger(__name__) + +MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3 +RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0 + + +class StreamingHandler: + """SSE streaming-based update handler.""" + + @staticmethod + async def _try_recover_from_interruption( # type: ignore[misc] + client: Client, + task_id: str, + new_messages: list[Message], + agent_card: AgentCard, + result_parts: list[str], + **kwargs: Unpack[StreamingHandlerKwargs], + ) -> TaskStateResult | None: + """Attempt to recover from a stream interruption by checking task state. + + If the task completed while we were disconnected, returns the result. + If the task is still running, attempts to resubscribe and continue. + + Args: + client: A2A client instance. + task_id: The task ID to recover. + new_messages: List of collected messages. + agent_card: The agent card. + result_parts: Accumulated result text parts. + **kwargs: Handler parameters. + + Returns: + TaskStateResult if recovery succeeded (task finished or resubscribe worked). + None if recovery not possible (caller should handle failure). + + Note: + When None is returned, recovery failed and the original exception should + be handled by the caller. All recovery attempts are logged. 
+ """ + params = extract_common_params(kwargs) # type: ignore[arg-type] + + try: + a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id)) + + if a2a_task.status.state in TERMINAL_STATES: + logger.info( + "Task completed during stream interruption", + extra={"task_id": task_id, "state": str(a2a_task.status.state)}, + ) + return process_task_state( + a2a_task=a2a_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + result_parts=result_parts, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ) + + if a2a_task.status.state in ACTIONABLE_STATES: + logger.info( + "Task in actionable state during stream interruption", + extra={"task_id": task_id, "state": str(a2a_task.status.state)}, + ) + return process_task_state( + a2a_task=a2a_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + result_parts=result_parts, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + is_final=False, + ) + + logger.info( + "Task still running, attempting resubscribe", + extra={"task_id": task_id, "state": str(a2a_task.status.state)}, + ) + + for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS): + try: + backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt) + if attempt > 0: + await asyncio.sleep(backoff) + + event_stream = client.resubscribe(TaskIdParams(id=task_id)) + + async for event in event_stream: + if isinstance(event, tuple): + resubscribed_task, update = event + + is_final_update = ( + process_status_update(update, result_parts) + if isinstance(update, TaskStatusUpdateEvent) + else False + ) + + if isinstance(update, TaskArtifactUpdateEvent): + artifact = update.artifact + result_parts.extend( + 
part.root.text + for part in artifact.parts + if part.root.kind == "text" + ) + + if ( + is_final_update + or resubscribed_task.status.state + in TERMINAL_STATES | ACTIONABLE_STATES + ): + return process_task_state( + a2a_task=resubscribed_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + result_parts=result_parts, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + is_final=is_final_update, + ) + + elif isinstance(event, Message): + new_messages.append(event) + result_parts.extend( + part.root.text + for part in event.parts + if part.root.kind == "text" + ) + + final_task = await client.get_task(TaskQueryParams(id=task_id)) + return process_task_state( + a2a_task=final_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + result_parts=result_parts, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ) + + except Exception as resubscribe_error: # noqa: PERF203 + logger.warning( + "Resubscribe attempt failed", + extra={ + "task_id": task_id, + "attempt": attempt + 1, + "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS, + "error": str(resubscribe_error), + }, + ) + if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1: + return None + + except Exception as e: + logger.warning( + "Failed to recover from stream interruption due to unexpected error", + extra={ + "task_id": task_id, + "error": str(e), + "error_type": type(e).__name__, + }, + exc_info=True, + ) + return None + + logger.warning( + "Recovery exhausted all resubscribe attempts without success", + extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS}, + ) + return None + + @staticmethod + async def execute( + client: Client, + message: 
Message, + new_messages: list[Message], + agent_card: AgentCard, + **kwargs: Unpack[StreamingHandlerKwargs], + ) -> TaskStateResult: + """Execute A2A delegation using SSE streaming for updates. + + Args: + client: A2A client instance. + message: Message to send. + new_messages: List to collect messages. + agent_card: The agent card. + **kwargs: Streaming-specific parameters. + + Returns: + Dictionary with status, result/error, and history. + """ + task_id = kwargs.get("task_id") + agent_branch = kwargs.get("agent_branch") + params = extract_common_params(kwargs) + + result_parts: list[str] = [] + final_result: TaskStateResult | None = None + event_stream = client.send_message(message) + chunk_index = 0 + current_task_id: str | None = task_id + + crewai_event_bus.emit( + agent_branch, + A2AStreamingStartedEvent( + task_id=task_id, + context_id=params.context_id, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + + try: + async for event in event_stream: + if isinstance(event, tuple): + a2a_task, _ = event + current_task_id = a2a_task.id + + if isinstance(event, Message): + new_messages.append(event) + message_context_id = event.context_id or params.context_id + for part in event.parts: + if part.root.kind == "text": + text = part.root.text + result_parts.append(text) + crewai_event_bus.emit( + agent_branch, + A2AStreamingChunkEvent( + task_id=event.task_id or task_id, + context_id=message_context_id, + chunk=text, + chunk_index=chunk_index, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + chunk_index += 1 + + elif isinstance(event, tuple): + a2a_task, update = event + + if isinstance(update, TaskArtifactUpdateEvent): + 
artifact = update.artifact + result_parts.extend( + part.root.text + for part in artifact.parts + if part.root.kind == "text" + ) + artifact_size = None + if artifact.parts: + artifact_size = sum( + len(p.root.text.encode()) + if p.root.kind == "text" + else len(getattr(p.root, "data", b"")) + for p in artifact.parts + ) + effective_context_id = a2a_task.context_id or params.context_id + crewai_event_bus.emit( + agent_branch, + A2AArtifactReceivedEvent( + task_id=a2a_task.id, + artifact_id=artifact.artifact_id, + artifact_name=artifact.name, + artifact_description=artifact.description, + mime_type=artifact.parts[0].root.kind + if artifact.parts + else None, + size_bytes=artifact_size, + append=update.append or False, + last_chunk=update.last_chunk or False, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + context_id=effective_context_id, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + + is_final_update = ( + process_status_update(update, result_parts) + if isinstance(update, TaskStatusUpdateEvent) + else False + ) + + if ( + not is_final_update + and a2a_task.status.state + not in TERMINAL_STATES | ACTIONABLE_STATES + ): + continue + + final_result = process_task_state( + a2a_task=a2a_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=params.turn_number, + is_multiturn=params.is_multiturn, + agent_role=params.agent_role, + result_parts=result_parts, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + is_final=is_final_update, + ) + if final_result: + break + + except A2AClientHTTPError as e: + if current_task_id: + logger.info( + "Stream interrupted with HTTP error, attempting recovery", + extra={ + "task_id": current_task_id, + "error": str(e), + "status_code": e.status_code, + }, + ) + recovery_kwargs = {k: v for k, v in kwargs.items() if k != 
"task_id"} + recovered_result = ( + await StreamingHandler._try_recover_from_interruption( + client=client, + task_id=current_task_id, + new_messages=new_messages, + agent_card=agent_card, + result_parts=result_parts, + **recovery_kwargs, + ) + ) + if recovered_result: + logger.info( + "Successfully recovered task after HTTP error", + extra={ + "task_id": current_task_id, + "status": str(recovered_result.get("status")), + }, + ) + return recovered_result + + logger.warning( + "Failed to recover from HTTP error, returning failure", + extra={ + "task_id": current_task_id, + "status_code": e.status_code, + "original_error": str(e), + }, + ) + + error_msg = f"HTTP Error {e.status_code}: {e!s}" + error_type = "http_error" + status_code = e.status_code + + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=params.context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=params.endpoint, + error=str(e), + error_type=error_type, + status_code=status_code, + a2a_agent_name=params.a2a_agent_name, + operation="streaming", + context_id=params.context_id, + task_id=task_id, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=params.turn_number, + context_id=params.context_id, + is_multiturn=params.is_multiturn, + status="failed", + final=True, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e: + error_type = type(e).__name__.lower() + if current_task_id: + logger.info( + f"Stream 
interrupted with {error_type}, attempting recovery", + extra={"task_id": current_task_id, "error": str(e)}, + ) + recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"} + recovered_result = ( + await StreamingHandler._try_recover_from_interruption( + client=client, + task_id=current_task_id, + new_messages=new_messages, + agent_card=agent_card, + result_parts=result_parts, + **recovery_kwargs, + ) + ) + if recovered_result: + logger.info( + f"Successfully recovered task after {error_type}", + extra={ + "task_id": current_task_id, + "status": str(recovered_result.get("status")), + }, + ) + return recovered_result + + logger.warning( + f"Failed to recover from {error_type}, returning failure", + extra={ + "task_id": current_task_id, + "error_type": error_type, + "original_error": str(e), + }, + ) + + error_msg = f"Connection error during streaming: {e!s}" + status_code = None + + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=params.context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=params.endpoint, + error=str(e), + error_type=error_type, + status_code=status_code, + a2a_agent_name=params.a2a_agent_name, + operation="streaming", + context_id=params.context_id, + task_id=task_id, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=params.turn_number, + context_id=params.context_id, + is_multiturn=params.is_multiturn, + status="failed", + final=True, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + except Exception as e: + 
logger.exception( + "Unexpected error during streaming", + extra={ + "task_id": current_task_id, + "error_type": type(e).__name__, + "endpoint": params.endpoint, + }, + ) + error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}" + error_type = "unexpected_error" + status_code = None + + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=params.context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=params.endpoint, + error=str(e), + error_type=error_type, + status_code=status_code, + a2a_agent_name=params.a2a_agent_name, + operation="streaming", + context_id=params.context_id, + task_id=task_id, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=params.turn_number, + context_id=params.context_id, + is_multiturn=params.is_multiturn, + status="failed", + final=True, + agent_role=params.agent_role, + endpoint=params.endpoint, + a2a_agent_name=params.a2a_agent_name, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + finally: + aclose = getattr(event_stream, "aclose", None) + if aclose: + try: + await aclose() + except Exception as close_error: + crewai_event_bus.emit( + agent_branch, + A2AConnectionErrorEvent( + endpoint=params.endpoint, + error=str(close_error), + error_type="stream_close_error", + a2a_agent_name=params.a2a_agent_name, + operation="stream_close", + context_id=params.context_id, + task_id=task_id, + from_task=params.from_task, + from_agent=params.from_agent, + ), + ) + + if final_result: + return final_result + + return TaskStateResult( + status=TaskState.completed, + result=" ".join(result_parts) if 
def process_status_update(
    update: TaskStatusUpdateEvent,
    result_parts: list[str],
) -> bool:
    """Collect text fragments from a status update into *result_parts*.

    Args:
        update: The status update event received from the stream.
        result_parts: Accumulator for text parts (mutated in place).

    Returns:
        True if this is a final update, False otherwise.
    """
    status = update.status
    message = status.message if status else None
    if message and message.parts:
        for part in message.parts:
            root = part.root
            # Only non-empty text parts contribute to the result.
            if root.kind == "text" and root.text:
                result_parts.append(root.text)
    return update.final
A2AClientHTTPError +from a2a.types import AgentCapabilities, AgentCard, AgentSkill +from aiocache import cached # type: ignore[import-untyped] +from aiocache.serializers import PickleSerializer # type: ignore[import-untyped] +from crewai.crew import Crew +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2AAgentCardFetchedEvent, + A2AAuthenticationFailedEvent, + A2AConnectionErrorEvent, +) +import httpx + +from crewai_a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth +from crewai_a2a.auth.utils import ( + _auth_store, + configure_auth_client, + retry_on_401, +) +from crewai_a2a.config import A2AServerConfig + + +if TYPE_CHECKING: + from crewai.agent import Agent + from crewai.task import Task + + from crewai_a2a.auth.client_schemes import ClientAuthScheme + + +def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str: + """Get TLS verify parameter from auth scheme. + + Args: + auth: Optional authentication scheme with TLS config. + + Returns: + SSL context, CA cert path, True for default verification, + or False if verification disabled. + """ + if auth and auth.tls: + return auth.tls.get_httpx_ssl_context() + return True + + +async def _prepare_auth_headers( + auth: ClientAuthScheme | None, + timeout: int, +) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]: + """Prepare authentication headers and TLS verification settings. + + Args: + auth: Optional authentication scheme. + timeout: Request timeout in seconds. + + Returns: + Tuple of (headers dict, TLS verify setting). 
+ """ + headers: MutableMapping[str, str] = {} + verify = _get_tls_verify(auth) + if auth: + async with httpx.AsyncClient( + timeout=timeout, verify=verify + ) as temp_auth_client: + if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): + configure_auth_client(auth, temp_auth_client) + headers = await auth.apply_auth(temp_auth_client, {}) + return headers, verify + + +def _get_server_config(agent: Agent) -> A2AServerConfig | None: + """Get A2AServerConfig from an agent's a2a configuration. + + Args: + agent: The Agent instance to check. + + Returns: + A2AServerConfig if present, None otherwise. + """ + if agent.a2a is None: + return None + if isinstance(agent.a2a, A2AServerConfig): + return agent.a2a + if isinstance(agent.a2a, list): + for config in agent.a2a: + if isinstance(config, A2AServerConfig): + return config + return None + + +def fetch_agent_card( + endpoint: str, + auth: ClientAuthScheme | None = None, + timeout: int = 30, + use_cache: bool = True, + cache_ttl: int = 300, +) -> AgentCard: + """Fetch AgentCard from an A2A endpoint with optional caching. + + Args: + endpoint: A2A agent endpoint URL (AgentCard URL). + auth: Optional ClientAuthScheme for authentication. + timeout: Request timeout in seconds. + use_cache: Whether to use caching (default True). + cache_ttl: Cache TTL in seconds (default 300 = 5 minutes). + + Returns: + AgentCard object with agent capabilities and skills. + + Raises: + httpx.HTTPStatusError: If the request fails. + A2AClientHTTPError: If authentication fails. 
+ """ + if use_cache: + if auth: + auth_data = auth.model_dump_json( + exclude={ + "_access_token", + "_token_expires_at", + "_refresh_token", + "_authorization_callback", + } + ) + auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data) + else: + auth_hash = _auth_store.compute_key("none", "") + _auth_store.set(auth_hash, auth) + ttl_hash = int(time.time() // cache_ttl) + return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout) + ) + finally: + loop.close() + + +async def afetch_agent_card( + endpoint: str, + auth: ClientAuthScheme | None = None, + timeout: int = 30, + use_cache: bool = True, +) -> AgentCard: + """Fetch AgentCard from an A2A endpoint asynchronously. + + Native async implementation. Use this when running in an async context. + + Args: + endpoint: A2A agent endpoint URL (AgentCard URL). + auth: Optional ClientAuthScheme for authentication. + timeout: Request timeout in seconds. + use_cache: Whether to use caching (default True). + + Returns: + AgentCard object with agent capabilities and skills. + + Raises: + httpx.HTTPStatusError: If the request fails. + A2AClientHTTPError: If authentication fails. 
+ """ + if use_cache: + if auth: + auth_data = auth.model_dump_json( + exclude={ + "_access_token", + "_token_expires_at", + "_refresh_token", + "_authorization_callback", + } + ) + auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data) + else: + auth_hash = _auth_store.compute_key("none", "") + _auth_store.set(auth_hash, auth) + agent_card: AgentCard = await _afetch_agent_card_cached( + endpoint, auth_hash, timeout + ) + return agent_card + + return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout) + + +@lru_cache() +def _fetch_agent_card_cached( + endpoint: str, + auth_hash: str, + timeout: int, + _ttl_hash: int, +) -> AgentCard: + """Cached sync version of fetch_agent_card.""" + auth = _auth_store.get(auth_hash) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout) + ) + finally: + loop.close() + + +@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator] +async def _afetch_agent_card_cached( + endpoint: str, + auth_hash: str, + timeout: int, +) -> AgentCard: + """Cached async implementation of AgentCard fetching.""" + auth = _auth_store.get(auth_hash) + return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout) + + +async def _afetch_agent_card_impl( + endpoint: str, + auth: ClientAuthScheme | None, + timeout: int, +) -> AgentCard: + """Internal async implementation of AgentCard fetching.""" + start_time = time.perf_counter() + + if "/.well-known/agent-card.json" in endpoint: + base_url = endpoint.replace("/.well-known/agent-card.json", "") + agent_card_path = "/.well-known/agent-card.json" + else: + url_parts = endpoint.split("/", 3) + base_url = f"{url_parts[0]}//{url_parts[2]}" + agent_card_path = ( + f"/{url_parts[3]}" + if len(url_parts) > 3 and url_parts[3] + else "/.well-known/agent-card.json" + ) + + headers, verify = await 
_prepare_auth_headers(auth, timeout) + + async with httpx.AsyncClient( + timeout=timeout, headers=headers, verify=verify + ) as temp_client: + if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): + configure_auth_client(auth, temp_client) + + agent_card_url = f"{base_url}{agent_card_path}" + + async def _fetch_agent_card_request() -> httpx.Response: + return await temp_client.get(agent_card_url) + + try: + response = await retry_on_401( + request_func=_fetch_agent_card_request, + auth_scheme=auth, + client=temp_client, + headers=temp_client.headers, + max_retries=2, + ) + response.raise_for_status() + + agent_card = AgentCard.model_validate(response.json()) + fetch_time_ms = (time.perf_counter() - start_time) * 1000 + agent_card_dict = agent_card.model_dump(exclude_none=True) + + crewai_event_bus.emit( + None, + A2AAgentCardFetchedEvent( + endpoint=endpoint, + a2a_agent_name=agent_card.name, + agent_card=agent_card_dict, + protocol_version=agent_card.protocol_version, + provider=agent_card_dict.get("provider"), + cached=False, + fetch_time_ms=fetch_time_ms, + ), + ) + + return agent_card + + except httpx.HTTPStatusError as e: + elapsed_ms = (time.perf_counter() - start_time) * 1000 + response_body = e.response.text[:1000] if e.response.text else None + + if e.response.status_code == 401: + error_details = ["Authentication failed"] + www_auth = e.response.headers.get("WWW-Authenticate") + if www_auth: + error_details.append(f"WWW-Authenticate: {www_auth}") + if not auth: + error_details.append("No auth scheme provided") + msg = " | ".join(error_details) + + auth_type = type(auth).__name__ if auth else None + crewai_event_bus.emit( + None, + A2AAuthenticationFailedEvent( + endpoint=endpoint, + auth_type=auth_type, + error=msg, + status_code=401, + metadata={ + "elapsed_ms": elapsed_ms, + "response_body": response_body, + "www_authenticate": www_auth, + "request_url": str(e.request.url), + }, + ), + ) + + raise A2AClientHTTPError(401, msg) from e + + 
def _task_to_skill(task: Task) -> AgentSkill:
    """Convert a CrewAI Task to an A2A AgentSkill.

    Args:
        task: The CrewAI Task to convert.

    Returns:
        AgentSkill representing the task's capability.
    """
    # Fall back to a description prefix when the task has no explicit name.
    skill_name = task.name if task.name else task.description[:50]
    skill_tags = [task.agent.role.lower().replace(" ", "-")] if task.agent else []
    skill_examples = [task.expected_output] if task.expected_output else None

    return AgentSkill(
        id=skill_name.lower().replace(" ", "_"),
        name=skill_name,
        description=task.description,
        tags=skill_tags,
        examples=skill_examples,
    )


def _tool_to_skill(tool_name: str, tool_description: str) -> AgentSkill:
    """Convert an Agent's tool to an A2A AgentSkill.

    Args:
        tool_name: Name of the tool.
        tool_description: Description of what the tool does.

    Returns:
        AgentSkill representing the tool's capability.
    """
    lowered = tool_name.lower()
    return AgentSkill(
        id=lowered.replace(" ", "_"),
        name=tool_name,
        description=tool_description,
        tags=[lowered.replace(" ", "-")],
    )
+ """ + crew_name = getattr(crew, "name", None) or crew.__class__.__name__ + + description_parts: list[str] = [] + crew_description = getattr(crew, "description", None) + if crew_description: + description_parts.append(crew_description) + else: + agent_roles = [agent.role for agent in crew.agents] + description_parts.append( + f"A crew of {len(crew.agents)} agents: {', '.join(agent_roles)}" + ) + + skills = [_task_to_skill(task) for task in crew.tasks] + + return AgentCard( + name=crew_name, + description=" ".join(description_parts), + url=url, + version="1.0.0", + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + ), + default_input_modes=["text/plain", "application/json"], + default_output_modes=["text/plain", "application/json"], + skills=skills, + ) + + +def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard: + """Generate an A2A AgentCard from an Agent instance. + + Uses A2AServerConfig values when available, falling back to agent properties. + If signing_config is provided, the card will be signed with JWS. + + Args: + agent: The Agent instance to generate a card for. + url: The base URL where this agent will be exposed. + + Returns: + AgentCard describing the agent's capabilities. 
+ """ + from crewai_a2a.utils.agent_card_signing import sign_agent_card + + server_config = _get_server_config(agent) or A2AServerConfig() + + name = server_config.name or agent.role + + description_parts = [agent.goal] + if agent.backstory: + description_parts.append(agent.backstory) + description = server_config.description or " ".join(description_parts) + + skills: list[AgentSkill] = ( + server_config.skills.copy() if server_config.skills else [] + ) + + if not skills: + if agent.tools: + for tool in agent.tools: + tool_name = getattr(tool, "name", None) or tool.__class__.__name__ + tool_desc = getattr(tool, "description", None) or f"Tool: {tool_name}" + skills.append(_tool_to_skill(tool_name, tool_desc)) + + if not skills: + skills.append( + AgentSkill( + id=agent.role.lower().replace(" ", "_"), + name=agent.role, + description=agent.goal, + tags=[agent.role.lower().replace(" ", "-")], + ) + ) + + capabilities = server_config.capabilities + if server_config.server_extensions: + from crewai_a2a.extensions.server import ServerExtensionRegistry + + registry = ServerExtensionRegistry(server_config.server_extensions) + ext_list = registry.get_agent_extensions() + + existing_exts = list(capabilities.extensions) if capabilities.extensions else [] + existing_uris = {e.uri for e in existing_exts} + for ext in ext_list: + if ext.uri not in existing_uris: + existing_exts.append(ext) + + capabilities = capabilities.model_copy(update={"extensions": existing_exts}) + + card = AgentCard( + name=name, + description=description, + url=server_config.url or url, + version=server_config.version, + capabilities=capabilities, + default_input_modes=server_config.default_input_modes, + default_output_modes=server_config.default_output_modes, + skills=skills, + preferred_transport=server_config.transport.preferred, + protocol_version=server_config.protocol_version, + provider=server_config.provider, + documentation_url=server_config.documentation_url, + icon_url=server_config.icon_url, 
+ additional_interfaces=server_config.additional_interfaces, + security=server_config.security, + security_schemes=server_config.security_schemes, + supports_authenticated_extended_card=server_config.supports_authenticated_extended_card, + ) + + if server_config.signing_config: + signature = sign_agent_card( + card, + private_key=server_config.signing_config.get_private_key(), + key_id=server_config.signing_config.key_id, + algorithm=server_config.signing_config.algorithm, + ) + card = card.model_copy(update={"signatures": [signature]}) + elif server_config.signatures: + card = card.model_copy(update={"signatures": server_config.signatures}) + + return card + + +def inject_a2a_server_methods(agent: Agent) -> None: + """Inject A2A server methods onto an Agent instance. + + Adds a `to_agent_card(url: str) -> AgentCard` method to the agent + that generates an A2A-compliant AgentCard. + + Only injects if the agent has an A2AServerConfig. + + Args: + agent: The Agent instance to inject methods onto. + """ + if _get_server_config(agent) is None: + return + + def _to_agent_card(self: Agent, url: str) -> AgentCard: + return _agent_to_agent_card(self, url) + + object.__setattr__(agent, "to_agent_card", MethodType(_to_agent_card, agent)) diff --git a/lib/crewai-a2a/src/crewai_a2a/utils/agent_card_signing.py b/lib/crewai-a2a/src/crewai_a2a/utils/agent_card_signing.py new file mode 100644 index 000000000..9dfe0ab89 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/utils/agent_card_signing.py @@ -0,0 +1,236 @@ +"""AgentCard JWS signing utilities. + +This module provides functions for signing and verifying AgentCards using +JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients +to verify the authenticity and integrity of agent card information. 
logger = logging.getLogger(__name__)


SigningAlgorithm = Literal[
    "RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512"
]


def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes:
    """Coerce a PEM private key given as str, bytes, or SecretStr into bytes."""
    key: str | bytes | SecretStr = private_key
    if isinstance(key, SecretStr):
        key = key.get_secret_value()
    if isinstance(key, str):
        key = key.encode()
    return key


def _serialize_agent_card(agent_card: AgentCard) -> str:
    """Render an AgentCard as canonical JSON for signing.

    The ``signatures`` field is excluded so the signed payload never contains
    a signature over itself; sorted keys and compact separators keep the
    output deterministic.
    """
    payload = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
    return json.dumps(payload, sort_keys=True, separators=(",", ":"))


def _base64url_encode(data: bytes | str) -> str:
    """URL-safe base64 without trailing '=' padding, per the JWS convention."""
    raw = data.encode() if isinstance(data, str) else data
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def sign_agent_card(
    agent_card: AgentCard,
    private_key: str | bytes | SecretStr,
    key_id: str | None = None,
    algorithm: SigningAlgorithm = "RS256",
) -> AgentCardSignature:
    """Produce a detached JWS (RFC 7515) signature over an AgentCard.

    The signature covers the card's canonical JSON minus the signatures
    field itself.

    Args:
        agent_card: The AgentCard to sign.
        private_key: PEM-encoded RSA, EC, or RSA-PSS private key.
        key_id: Optional ``kid`` advertised in both JWS headers.
        algorithm: JWS signing algorithm (RS256, ES256, PS256, ...).

    Returns:
        AgentCardSignature carrying the protected header and signature.

    Raises:
        jwt.exceptions.InvalidKeyError: If the private key is invalid.
        ValueError: If the algorithm does not match the key type.
    """
    pem = _normalize_private_key(private_key)
    payload = _serialize_agent_card(agent_card)

    jws_headers: dict[str, Any] = {"typ": "JWS"}
    if key_id:
        jws_headers["kid"] = key_id

    token = jwt.api_jws.encode(
        payload.encode(),
        pem,
        algorithm=algorithm,
        headers=jws_headers,
    )

    # A compact JWS is <protected>.<payload>.<signature>; keep the two
    # detached segments and discard the embedded payload copy.
    protected_segment, _, signature_segment = token.split(".")

    unprotected: dict[str, Any] | None = {"kid": key_id} if key_id else None

    return AgentCardSignature(
        protected=protected_segment,
        signature=signature_segment,
        header=unprotected,
    )


def verify_agent_card_signature(
    agent_card: AgentCard,
    signature: AgentCardSignature,
    public_key: str | bytes,
    algorithms: list[str] | None = None,
) -> bool:
    """Verify a detached AgentCard JWS signature.

    Reassembles the compact JWS from the stored protected header, the card's
    canonical JSON, and the stored signature, then checks it against the
    given public key.

    Args:
        agent_card: The AgentCard whose content is being authenticated.
        signature: The AgentCardSignature to validate.
        public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
        algorithms: Allowed algorithms; defaults to the common asymmetric set.

    Returns:
        True if the signature is valid, False otherwise.
    """
    allowed = algorithms
    if allowed is None:
        allowed = [
            "RS256",
            "RS384",
            "RS512",
            "ES256",
            "ES384",
            "ES512",
            "PS256",
            "PS384",
            "PS512",
        ]

    key = public_key.encode() if isinstance(public_key, str) else public_key

    payload_segment = _base64url_encode(_serialize_agent_card(agent_card))
    token = f"{signature.protected}.{payload_segment}.{signature.signature}"

    # Order matters: InvalidSignatureError subclasses DecodeError in PyJWT.
    try:
        jwt.api_jws.decode(
            token,
            key,
            algorithms=allowed,
        )
    except jwt.InvalidSignatureError:
        logger.debug(
            "AgentCard signature verification failed",
            extra={"reason": "invalid_signature"},
        )
        return False
    except jwt.DecodeError as e:
        logger.debug(
            "AgentCard signature verification failed",
            extra={"reason": "decode_error", "error": str(e)},
        )
        return False
    except jwt.InvalidAlgorithmError as e:
        logger.debug(
            "AgentCard signature verification failed",
            extra={"reason": "algorithm_error", "error": str(e)},
        )
        return False
    return True
+ """ + if signature.header and "kid" in signature.header: + kid: str = signature.header["kid"] + return kid + + try: + protected = signature.protected + padding_needed = 4 - (len(protected) % 4) + if padding_needed != 4: + protected += "=" * padding_needed + + protected_json = base64.urlsafe_b64decode(protected).decode() + protected_header: dict[str, Any] = json.loads(protected_json) + return protected_header.get("kid") + except (ValueError, json.JSONDecodeError): + return None diff --git a/lib/crewai-a2a/src/crewai_a2a/utils/content_type.py b/lib/crewai-a2a/src/crewai_a2a/utils/content_type.py new file mode 100644 index 000000000..b752fb88b --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/utils/content_type.py @@ -0,0 +1,338 @@ +"""Content type negotiation for A2A protocol. + +This module handles negotiation of input/output MIME types between A2A clients +and servers based on AgentCard capabilities. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Annotated, Final, Literal, cast + +from a2a.types import Part +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent + + +if TYPE_CHECKING: + from a2a.types import AgentCard, AgentSkill + + +TEXT_PLAIN: Literal["text/plain"] = "text/plain" +APPLICATION_JSON: Literal["application/json"] = "application/json" +IMAGE_PNG: Literal["image/png"] = "image/png" +IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg" +IMAGE_WILDCARD: Literal["image/*"] = "image/*" +APPLICATION_PDF: Literal["application/pdf"] = "application/pdf" +APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = ( + "application/octet-stream" +) + +DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [ + TEXT_PLAIN, + APPLICATION_JSON, +] +DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [ + TEXT_PLAIN, + APPLICATION_JSON, +] + + +@dataclass +class 
@dataclass
class NegotiatedContentTypes:
    """Outcome of a client/server content type negotiation."""

    # MIME types the client may send, in client preference order.
    input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"]
    # MIME types the server will produce, in client preference order.
    output_modes: Annotated[
        list[str], "Negotiated output MIME types the server will produce"
    ]
    # The server-side modes the negotiation was computed against.
    effective_input_modes: Annotated[list[str], "Server's effective input modes"]
    effective_output_modes: Annotated[list[str], "Server's effective output modes"]
    skill_name: Annotated[
        str | None, "Skill name if negotiation was skill-specific"
    ] = None


class ContentTypeNegotiationError(Exception):
    """Raised when no compatible content types can be negotiated."""

    def __init__(
        self,
        client_input_modes: list[str],
        client_output_modes: list[str],
        server_input_modes: list[str],
        server_output_modes: list[str],
        direction: str = "both",
        message: str | None = None,
    ) -> None:
        self.client_input_modes = client_input_modes
        self.client_output_modes = client_output_modes
        self.server_input_modes = server_input_modes
        self.server_output_modes = server_output_modes
        self.direction = direction

        # Build a direction-specific default message unless one was supplied.
        if message is None:
            if direction == "input":
                message = (
                    f"No compatible input content types. "
                    f"Client supports: {client_input_modes}, "
                    f"Server accepts: {server_input_modes}"
                )
            elif direction == "output":
                message = (
                    f"No compatible output content types. "
                    f"Client accepts: {client_output_modes}, "
                    f"Server produces: {server_output_modes}"
                )
            else:
                message = (
                    f"No compatible content types. "
                    f"Input - Client: {client_input_modes}, Server: {server_input_modes}. "
                    f"Output - Client: {client_output_modes}, Server: {server_output_modes}"
                )

        super().__init__(message)


def _normalize_mime_type(mime_type: str) -> str:
    """Lowercase and trim a MIME type so comparisons are case-insensitive."""
    return mime_type.strip().lower()


def _mime_types_compatible(client_type: str, server_type: str) -> bool:
    """Return True when two MIME types match, honoring ``*`` wildcards.

    ``image/*`` matches ``image/png``; malformed types only match exactly.
    """
    lhs = _normalize_mime_type(client_type)
    rhs = _normalize_mime_type(server_type)

    if lhs == rhs:
        return True

    if "*" not in lhs and "*" not in rhs:
        return False

    lhs_parts = lhs.split("/")
    rhs_parts = rhs.split("/")
    if len(lhs_parts) != 2 or len(rhs_parts) != 2:
        # Wildcard comparison only defined for type/subtype shapes.
        return False

    major_ok = lhs_parts[0] == rhs_parts[0] or "*" in (lhs_parts[0], rhs_parts[0])
    minor_ok = lhs_parts[1] == rhs_parts[1] or "*" in (lhs_parts[1], rhs_parts[1])
    return major_ok and minor_ok


def _find_compatible_modes(
    client_modes: list[str], server_modes: list[str]
) -> list[str]:
    """Intersect client and server modes, preserving client preference order.

    A client wildcard expands to every concrete server mode it matches; a
    concrete client mode is recorded once and the server scan stops early.
    """
    matched: list[str] = []
    for wanted in client_modes:
        for offered in server_modes:
            if not _mime_types_compatible(wanted, offered):
                continue
            if "*" in wanted and "*" not in offered:
                # Wildcard request: collect each concrete server match.
                if offered not in matched:
                    matched.append(offered)
            else:
                if wanted not in matched:
                    matched.append(wanted)
                break
    return matched
+ """ + skill: AgentSkill | None = None + + if skill_name and agent_card.skills: + for s in agent_card.skills: + if s.name == skill_name or s.id == skill_name: + skill = s + break + + if skill: + input_modes = ( + skill.input_modes if skill.input_modes else agent_card.default_input_modes + ) + output_modes = ( + skill.output_modes + if skill.output_modes + else agent_card.default_output_modes + ) + else: + input_modes = agent_card.default_input_modes + output_modes = agent_card.default_output_modes + + return input_modes, output_modes, skill + + +def negotiate_content_types( + agent_card: AgentCard, + client_input_modes: list[str] | None = None, + client_output_modes: list[str] | None = None, + skill_name: str | None = None, + emit_event: bool = True, + endpoint: str | None = None, + a2a_agent_name: str | None = None, + strict: bool = False, +) -> NegotiatedContentTypes: + """Negotiate content types between client and server. + + Args: + agent_card: The remote agent's card with capability info. + client_input_modes: MIME types the client can send. Defaults to text/plain and application/json. + client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json. + skill_name: Optional skill to use for mode lookup. + emit_event: Whether to emit a content type negotiation event. + endpoint: Agent endpoint (for event metadata). + a2a_agent_name: Agent name (for event metadata). + strict: If True, raises error when no compatible types found. + If False, returns empty lists for incompatible directions. + + Returns: + NegotiatedContentTypes with compatible input and output modes. + + Raises: + ContentTypeNegotiationError: If strict=True and no compatible types found. 
+ """ + if client_input_modes is None: + client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy()) + if client_output_modes is None: + client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy()) + + server_input_modes, server_output_modes, skill = _get_effective_modes( + agent_card, skill_name + ) + + compatible_input = _find_compatible_modes(client_input_modes, server_input_modes) + compatible_output = _find_compatible_modes(client_output_modes, server_output_modes) + + if strict: + if not compatible_input and not compatible_output: + raise ContentTypeNegotiationError( + client_input_modes=client_input_modes, + client_output_modes=client_output_modes, + server_input_modes=server_input_modes, + server_output_modes=server_output_modes, + ) + if not compatible_input: + raise ContentTypeNegotiationError( + client_input_modes=client_input_modes, + client_output_modes=client_output_modes, + server_input_modes=server_input_modes, + server_output_modes=server_output_modes, + direction="input", + ) + if not compatible_output: + raise ContentTypeNegotiationError( + client_input_modes=client_input_modes, + client_output_modes=client_output_modes, + server_input_modes=server_input_modes, + server_output_modes=server_output_modes, + direction="output", + ) + + result = NegotiatedContentTypes( + input_modes=compatible_input, + output_modes=compatible_output, + effective_input_modes=server_input_modes, + effective_output_modes=server_output_modes, + skill_name=skill.name if skill else None, + ) + + if emit_event: + crewai_event_bus.emit( + None, + A2AContentTypeNegotiatedEvent( + endpoint=endpoint or agent_card.url, + a2a_agent_name=a2a_agent_name or agent_card.name, + skill_name=skill_name, + client_input_modes=client_input_modes, + client_output_modes=client_output_modes, + server_input_modes=server_input_modes, + server_output_modes=server_output_modes, + negotiated_input_modes=compatible_input, + negotiated_output_modes=compatible_output, + 
negotiation_success=bool(compatible_input and compatible_output), + ), + ) + + return result + + +def validate_content_type( + content_type: str, + allowed_modes: list[str], +) -> bool: + """Validate that a content type is allowed by a list of modes. + + Args: + content_type: The MIME type to validate. + allowed_modes: List of allowed MIME types (may include wildcards). + + Returns: + True if content_type is compatible with any allowed mode. + """ + for mode in allowed_modes: + if _mime_types_compatible(content_type, mode): + return True + return False + + +def get_part_content_type(part: Part) -> str: + """Extract MIME type from an A2A Part. + + Args: + part: A Part object containing TextPart, DataPart, or FilePart. + + Returns: + The MIME type string for this part. + """ + root = part.root + if root.kind == "text": + return TEXT_PLAIN + if root.kind == "data": + return APPLICATION_JSON + if root.kind == "file": + return root.file.mime_type or APPLICATION_OCTET_STREAM + return APPLICATION_OCTET_STREAM + + +def validate_message_parts( + parts: list[Part], + allowed_modes: list[str], +) -> list[str]: + """Validate that all message parts have allowed content types. + + Args: + parts: List of Parts from the incoming message. + allowed_modes: List of allowed MIME types (from default_input_modes). + + Returns: + List of invalid content types found (empty if all valid). 
+ """ + invalid_types: list[str] = [] + for part in parts: + content_type = get_part_content_type(part) + if not validate_content_type(content_type, allowed_modes): + if content_type not in invalid_types: + invalid_types.append(content_type) + return invalid_types diff --git a/lib/crewai-a2a/src/crewai_a2a/utils/delegation.py b/lib/crewai-a2a/src/crewai_a2a/utils/delegation.py new file mode 100644 index 000000000..acbad00e1 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/utils/delegation.py @@ -0,0 +1,980 @@ +"""A2A delegation utilities for executing tasks on remote agents.""" + +from __future__ import annotations + +import asyncio +import base64 +from collections.abc import AsyncIterator, Callable, MutableMapping +from contextlib import asynccontextmanager +import logging +from typing import TYPE_CHECKING, Any, Final, Literal +import uuid + +from a2a.client import Client, ClientConfig, ClientFactory +from a2a.types import ( + AgentCard, + FilePart, + FileWithBytes, + Message, + Part, + PushNotificationConfig as A2APushNotificationConfig, + Role, + TextPart, +) +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2AConversationStartedEvent, + A2ADelegationCompletedEvent, + A2ADelegationStartedEvent, + A2AMessageSentEvent, +) +import httpx +from pydantic import BaseModel + +from crewai_a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth +from crewai_a2a.auth.utils import ( + _auth_store, + configure_auth_client, + validate_auth_against_agent_card, +) +from crewai_a2a.config import ClientTransportConfig, GRPCClientConfig +from crewai_a2a.extensions.registry import ( + ExtensionsMiddleware, + validate_required_extensions, +) +from crewai_a2a.task_helpers import TaskStateResult +from crewai_a2a.types import ( + HANDLER_REGISTRY, + HandlerType, + PartsDict, + PartsMetadataDict, + TransportType, +) +from crewai_a2a.updates import ( + PollingConfig, + PushNotificationConfig, + StreamingHandler, + UpdateConfig, +) 
logger = logging.getLogger(__name__)


def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]:
    """Convert FileInput entries into A2A FilePart objects.

    File contents are read eagerly and base64-encoded. When crewai_files is
    not installed (or no files are given) this degrades to an empty list
    rather than failing the delegation.

    Args:
        input_files: Mapping of display names to FileInput objects, or None.

    Returns:
        Part objects wrapping FilePart payloads; possibly empty.
    """
    if not input_files:
        return []

    try:
        import crewai_files  # noqa: F401
    except ImportError:
        logger.debug("crewai_files not installed, skipping file parts")
        return []

    file_parts: list[Part] = []
    for display_name, file_input in input_files.items():
        encoded = base64.b64encode(file_input.read()).decode()
        payload = FileWithBytes(
            bytes=encoded,
            mimeType=file_input.content_type,
            # Fall back to the mapping key when the file has no filename.
            name=file_input.filename or display_name,
        )
        file_parts.append(Part(root=FilePart(file=payload)))

    return file_parts
+ """ + if config is None: + return StreamingHandler + return HANDLER_REGISTRY.get(type(config), StreamingHandler) + + +def execute_a2a_delegation( + endpoint: str, + auth: ClientAuthScheme | None, + timeout: int, + task_description: str, + context: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + reference_task_ids: list[str] | None = None, + metadata: dict[str, Any] | None = None, + extensions: dict[str, Any] | None = None, + conversation_history: list[Message] | None = None, + agent_id: str | None = None, + agent_role: Role | None = None, + agent_branch: Any | None = None, + response_model: type[BaseModel] | None = None, + turn_number: int | None = None, + updates: UpdateConfig | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + skill_id: str | None = None, + client_extensions: list[str] | None = None, + transport: ClientTransportConfig | None = None, + accepted_output_modes: list[str] | None = None, + input_files: dict[str, Any] | None = None, +) -> TaskStateResult: + """Execute a task delegation to a remote A2A agent synchronously. + + WARNING: This function blocks the entire thread by creating and running a new + event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts + for better performance and resource efficiency. + + This is a synchronous wrapper around aexecute_a2a_delegation that creates a + new event loop to run the async implementation. It is provided for compatibility + with synchronous code paths only. + + Args: + endpoint: A2A agent endpoint URL (AgentCard URL). + auth: Optional ClientAuthScheme for authentication. + timeout: Request timeout in seconds. + task_description: The task to delegate. + context: Optional context information. + context_id: Context ID for correlating messages/tasks. + task_id: Specific task identifier. + reference_task_ids: List of related task IDs. + metadata: Additional metadata. + extensions: Protocol extensions for custom fields. 
+ conversation_history: Previous Message objects from conversation. + agent_id: Agent identifier for logging. + agent_role: Role of the CrewAI agent delegating the task. + agent_branch: Optional agent tree branch for logging. + response_model: Optional Pydantic model for structured outputs. + turn_number: Optional turn number for multi-turn conversations. + updates: Update mechanism config from A2AConfig.updates. + from_task: Optional CrewAI Task object for event metadata. + from_agent: Optional CrewAI Agent object for event metadata. + skill_id: Optional skill ID to target a specific agent capability. + client_extensions: A2A protocol extension URIs the client supports. + transport: Transport configuration (preferred, supported transports, gRPC settings). + accepted_output_modes: MIME types the client can accept in responses. + input_files: Optional dictionary of files to send to remote agent. + + Returns: + TaskStateResult with status, result/error, history, and agent_card. + + Raises: + RuntimeError: If called from an async context with a running event loop. + """ + try: + asyncio.get_running_loop() + raise RuntimeError( + "execute_a2a_delegation() cannot be called from an async context. " + "Use 'await aexecute_a2a_delegation()' instead." 
+ ) + except RuntimeError as e: + if "no running event loop" not in str(e).lower(): + raise + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + aexecute_a2a_delegation( + endpoint=endpoint, + auth=auth, + timeout=timeout, + task_description=task_description, + context=context, + context_id=context_id, + task_id=task_id, + reference_task_ids=reference_task_ids, + metadata=metadata, + extensions=extensions, + conversation_history=conversation_history, + agent_id=agent_id, + agent_role=agent_role, + agent_branch=agent_branch, + response_model=response_model, + turn_number=turn_number, + updates=updates, + from_task=from_task, + from_agent=from_agent, + skill_id=skill_id, + client_extensions=client_extensions, + transport=transport, + accepted_output_modes=accepted_output_modes, + input_files=input_files, + ) + ) + finally: + try: + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + loop.close() + + +async def aexecute_a2a_delegation( + endpoint: str, + auth: ClientAuthScheme | None, + timeout: int, + task_description: str, + context: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + reference_task_ids: list[str] | None = None, + metadata: dict[str, Any] | None = None, + extensions: dict[str, Any] | None = None, + conversation_history: list[Message] | None = None, + agent_id: str | None = None, + agent_role: Role | None = None, + agent_branch: Any | None = None, + response_model: type[BaseModel] | None = None, + turn_number: int | None = None, + updates: UpdateConfig | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + skill_id: str | None = None, + client_extensions: list[str] | None = None, + transport: ClientTransportConfig | None = None, + accepted_output_modes: list[str] | None = None, + input_files: dict[str, Any] | None = None, +) -> TaskStateResult: + """Execute a task delegation to a remote A2A agent asynchronously. 
async def aexecute_a2a_delegation(
    endpoint: str,
    auth: ClientAuthScheme | None,
    timeout: int,
    task_description: str,
    context: str | None = None,
    context_id: str | None = None,
    task_id: str | None = None,
    reference_task_ids: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    extensions: dict[str, Any] | None = None,
    conversation_history: list[Message] | None = None,
    agent_id: str | None = None,
    agent_role: Role | None = None,
    agent_branch: Any | None = None,
    response_model: type[BaseModel] | None = None,
    turn_number: int | None = None,
    updates: UpdateConfig | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    skill_id: str | None = None,
    client_extensions: list[str] | None = None,
    transport: ClientTransportConfig | None = None,
    accepted_output_modes: list[str] | None = None,
    input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
    """Execute a task delegation to a remote A2A agent asynchronously.

    Native async implementation with multi-turn support. Use this when
    running in an async context (e.g. Crew.akickoff / agent.aexecute_task).
    Emits an A2ADelegationCompletedEvent on both success and failure.

    Args:
        endpoint: A2A agent endpoint URL.
        auth: Optional ClientAuthScheme for authentication.
        timeout: Request timeout in seconds.
        task_description: The task to delegate.
        context: Optional context information.
        context_id: Context ID for correlating messages/tasks.
        task_id: Specific task identifier.
        reference_task_ids: List of related task IDs.
        metadata: Additional metadata.
        extensions: Protocol extensions for custom fields.
        conversation_history: Previous Message objects from conversation.
        agent_id: Agent identifier for logging.
        agent_role: Role of the CrewAI agent delegating the task.
        agent_branch: Optional agent tree branch for logging.
        response_model: Optional Pydantic model for structured outputs.
        turn_number: Optional turn number for multi-turn conversations.
        updates: Update mechanism config from A2AConfig.updates.
        from_task: Optional CrewAI Task object for event metadata.
        from_agent: Optional CrewAI Agent object for event metadata.
        skill_id: Optional skill ID to target a specific agent capability.
        client_extensions: A2A protocol extension URIs the client supports.
        transport: Transport configuration (preferred/supported, gRPC settings).
        accepted_output_modes: MIME types the client can accept in responses.
        input_files: Optional dictionary of files to send to remote agent.

    Returns:
        TaskStateResult with status, result/error, history, and agent_card.
    """
    history = conversation_history if conversation_history is not None else []
    is_multiturn = bool(history)
    if turn_number is None:
        # One past the number of user turns already in the conversation.
        turn_number = 1 + sum(1 for m in history if m.role == Role.user)

    try:
        result = await _aexecute_a2a_delegation_impl(
            endpoint=endpoint,
            auth=auth,
            timeout=timeout,
            task_description=task_description,
            context=context,
            context_id=context_id,
            task_id=task_id,
            reference_task_ids=reference_task_ids,
            metadata=metadata,
            extensions=extensions,
            conversation_history=history,
            is_multiturn=is_multiturn,
            turn_number=turn_number,
            agent_branch=agent_branch,
            agent_id=agent_id,
            agent_role=agent_role,
            response_model=response_model,
            updates=updates,
            from_task=from_task,
            from_agent=from_agent,
            skill_id=skill_id,
            client_extensions=client_extensions,
            transport=transport,
            accepted_output_modes=accepted_output_modes,
            input_files=input_files,
        )
    except Exception as exc:
        # Surface the failure to observers before propagating it.
        crewai_event_bus.emit(
            agent_branch,
            A2ADelegationCompletedEvent(
                status="failed",
                result=None,
                error=str(exc),
                context_id=context_id,
                is_multiturn=is_multiturn,
                endpoint=endpoint,
                metadata=metadata,
                extensions=list(extensions.keys()) if extensions else None,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        raise

    agent_card_data = result.get("agent_card")
    crewai_event_bus.emit(
        agent_branch,
        A2ADelegationCompletedEvent(
            status=result["status"],
            result=result.get("result"),
            error=result.get("error"),
            context_id=context_id,
            is_multiturn=is_multiturn,
            endpoint=endpoint,
            a2a_agent_name=result.get("a2a_agent_name"),
            agent_card=agent_card_data,
            provider=agent_card_data.get("provider") if agent_card_data else None,
            metadata=metadata,
            extensions=list(extensions.keys()) if extensions else None,
            from_task=from_task,
            from_agent=from_agent,
        ),
    )

    return result
async def _aexecute_a2a_delegation_impl(
    endpoint: str,
    auth: ClientAuthScheme | None,
    timeout: int,
    task_description: str,
    context: str | None,
    context_id: str | None,
    task_id: str | None,
    reference_task_ids: list[str] | None,
    metadata: dict[str, Any] | None,
    extensions: dict[str, Any] | None,
    conversation_history: list[Message],
    is_multiturn: bool,
    turn_number: int,
    agent_branch: Any | None,
    agent_id: str | None,
    agent_role: str | None,
    response_model: type[BaseModel] | None,
    updates: UpdateConfig | None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    skill_id: str | None = None,
    client_extensions: list[str] | None = None,
    transport: ClientTransportConfig | None = None,
    accepted_output_modes: list[str] | None = None,
    input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
    """Internal async implementation of A2A delegation.

    Flow: fetch + validate the agent card, negotiate transport and content
    types, emit started/conversation events, build the outgoing Message,
    then run the configured update handler against an A2A client.

    NOTE(review): ``agent_role`` is annotated ``str | None`` here while the
    public wrappers pass a ``Role | None`` — confirm the intended type.
    """
    if transport is None:
        transport = ClientTransportConfig()
    # Key the agent-card cache by the auth configuration (minus runtime
    # token state) so different credentials don't share cached cards.
    if auth:
        auth_data = auth.model_dump_json(
            exclude={
                "_access_token",
                "_token_expires_at",
                "_refresh_token",
                "_authorization_callback",
            }
        )
        auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
    else:
        auth_hash = _auth_store.compute_key("none", endpoint)
    _auth_store.set(auth_hash, auth)
    agent_card = await _afetch_agent_card_cached(
        endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
    )

    validate_auth_against_agent_card(agent_card, auth)

    # Fail fast when the server mandates extensions we cannot speak.
    unsupported_exts = validate_required_extensions(agent_card, client_extensions)
    if unsupported_exts:
        ext_uris = [ext.uri for ext in unsupported_exts]
        raise ValueError(
            f"Agent requires extensions not supported by client: {ext_uris}"
        )

    negotiated: NegotiatedTransport | None = None
    effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
    effective_url = endpoint

    client_transports: list[str] = (
        list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
    )

    # Transport negotiation is best-effort: on failure we log and fall back
    # to the preferred/default transport at the original endpoint.
    try:
        negotiated = negotiate_transport(
            agent_card=agent_card,
            client_supported_transports=client_transports,
            client_preferred_transport=transport.preferred,
            endpoint=endpoint,
            a2a_agent_name=agent_card.name,
        )
        effective_transport = negotiated.transport  # type: ignore[assignment]
        effective_url = negotiated.url
    except TransportNegotiationError as e:
        logger.warning(
            "Transport negotiation failed, using fallback",
            extra={
                "error": str(e),
                "fallback_transport": effective_transport,
                "fallback_url": effective_url,
                "endpoint": endpoint,
                "client_transports": client_transports,
                "server_transports": [
                    iface.transport for iface in agent_card.additional_interfaces or []
                ]
                + [agent_card.preferred_transport or "JSONRPC"],
            },
        )

    effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy()

    # Content-type negotiation (non-strict): empty result keeps the defaults.
    content_negotiated = negotiate_content_types(
        agent_card=agent_card,
        client_output_modes=accepted_output_modes,
        skill_name=skill_id,
        endpoint=endpoint,
        a2a_agent_name=agent_card.name,
    )
    if content_negotiated.output_modes:
        effective_output_modes = content_negotiated.output_modes

    headers, _ = await _prepare_auth_headers(auth, timeout)

    a2a_agent_name = None
    if agent_card.name:
        a2a_agent_name = agent_card.name

    agent_card_dict = agent_card.model_dump(exclude_none=True)
    crewai_event_bus.emit(
        agent_branch,
        A2ADelegationStartedEvent(
            endpoint=endpoint,
            task_description=task_description,
            agent_id=agent_id or endpoint,
            context_id=context_id,
            is_multiturn=is_multiturn,
            turn_number=turn_number,
            a2a_agent_name=a2a_agent_name,
            agent_card=agent_card_dict,
            protocol_version=agent_card.protocol_version,
            provider=agent_card_dict.get("provider"),
            skill_id=skill_id,
            metadata=metadata,
            extensions=list(extensions.keys()) if extensions else None,
            from_task=from_task,
            from_agent=from_agent,
        ),
    )

    # A conversation-started event is only emitted on the very first turn.
    if turn_number == 1:
        agent_id_for_event = agent_id or endpoint
        crewai_event_bus.emit(
            agent_branch,
            A2AConversationStartedEvent(
                agent_id=agent_id_for_event,
                endpoint=endpoint,
                context_id=context_id,
                a2a_agent_name=a2a_agent_name,
                agent_card=agent_card_dict,
                protocol_version=agent_card.protocol_version,
                provider=agent_card_dict.get("provider"),
                skill_id=skill_id,
                reference_task_ids=reference_task_ids,
                metadata=metadata,
                extensions=list(extensions.keys()) if extensions else None,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

    message_parts = []

    if context:
        message_parts.append(f"Context:\n{context}\n\n")
    message_parts.append(f"{task_description}")
    message_text = "".join(message_parts)

    # Multi-turn follow-ups without an explicit task_id reuse the task id
    # from the first message of the conversation.
    if is_multiturn and conversation_history and not task_id:
        if first_task_id := conversation_history[0].task_id:
            task_id = first_task_id

    parts: PartsDict = {"text": message_text}
    if response_model:
        # Advertise the expected structured-output schema to the server.
        parts.update(
            {
                "metadata": PartsMetadataDict(
                    mimeType="application/json",
                    schema=response_model.model_json_schema(),
                )
            }
        )

    message_metadata = metadata.copy() if metadata else {}
    if skill_id:
        message_metadata["skill_id"] = skill_id

    parts_list: list[Part] = [Part(root=TextPart(**parts))]
    parts_list.extend(_create_file_parts(input_files))

    message = Message(
        role=Role.user,
        message_id=str(uuid.uuid4()),
        parts=parts_list,
        context_id=context_id,
        task_id=task_id,
        reference_task_ids=reference_task_ids,
        metadata=message_metadata if message_metadata else None,
        extensions=extensions,
    )

    new_messages: list[Message] = [*conversation_history, message]
    crewai_event_bus.emit(
        None,
        A2AMessageSentEvent(
            message=message_text,
            turn_number=turn_number,
            context_id=context_id,
            message_id=message.message_id,
            is_multiturn=is_multiturn,
            agent_role=agent_role,
            endpoint=endpoint,
            a2a_agent_name=a2a_agent_name,
            skill_id=skill_id,
            metadata=message_metadata if message_metadata else None,
            extensions=list(extensions.keys()) if extensions else None,
            from_task=from_task,
            from_agent=from_agent,
        ),
    )

    handler = get_handler(updates)
    use_polling = isinstance(updates, PollingConfig)

    # Common kwargs for every handler type; config-specific ones are added below.
    handler_kwargs: dict[str, Any] = {
        "turn_number": turn_number,
        "is_multiturn": is_multiturn,
        "agent_role": agent_role,
        "context_id": context_id,
        "task_id": task_id,
        "endpoint": endpoint,
        "agent_branch": agent_branch,
        "a2a_agent_name": a2a_agent_name,
        "from_task": from_task,
        "from_agent": from_agent,
    }

    if isinstance(updates, PollingConfig):
        handler_kwargs.update(
            {
                "polling_interval": updates.interval,
                "polling_timeout": updates.timeout or float(timeout),
                "history_length": updates.history_length,
                "max_polls": updates.max_polls,
            }
        )
    elif isinstance(updates, PushNotificationConfig):
        handler_kwargs.update(
            {
                "config": updates,
                "result_store": updates.result_store,
                "polling_timeout": updates.timeout or float(timeout),
                "polling_interval": updates.interval,
            }
        )

    push_config_for_client = (
        updates if isinstance(updates, PushNotificationConfig) else None
    )

    # Streaming is the default unless polling or push notifications are configured.
    use_streaming = not use_polling and push_config_for_client is None

    # Point the client at the negotiated URL when it differs from the card's.
    client_agent_card = agent_card
    if effective_url != agent_card.url:
        client_agent_card = agent_card.model_copy(update={"url": effective_url})

    async with _create_a2a_client(
        agent_card=client_agent_card,
        transport_protocol=effective_transport,
        timeout=timeout,
        headers=headers,
        streaming=use_streaming,
        auth=auth,
        use_polling=use_polling,
        push_notification_config=push_config_for_client,
        client_extensions=client_extensions,
        accepted_output_modes=effective_output_modes,  # type: ignore[arg-type]
        grpc_config=transport.grpc,
    ) as client:
        result = await handler.execute(
            client=client,
            message=message,
            new_messages=new_messages,
            agent_card=agent_card,
            **handler_kwargs,
        )
        result["a2a_agent_name"] = a2a_agent_name
        result["agent_card"] = agent_card.model_dump(exclude_none=True)
        return result
...] | None, +) -> tuple[tuple[str, str], ...] | None: + """Lowercase all gRPC metadata keys. + + gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK) + use mixed-case headers like 'X-A2A-Extensions'. This normalizes them. + """ + if metadata is None: + return None + return tuple((key.lower(), value) for key, value in metadata) + + +def _create_grpc_interceptors( + auth_metadata: list[tuple[str, str]] | None = None, +) -> list[Any]: + """Create gRPC interceptors for metadata normalization and auth injection. + + Args: + auth_metadata: Optional auth metadata to inject into all calls. + Used for insecure channels that need auth (non-localhost without TLS). + + Returns a list of interceptors that lowercase metadata keys for gRPC + compatibility. Must be called after grpc is imported. + """ + import grpc.aio # type: ignore[import-untyped] + + def _merge_metadata( + existing: tuple[tuple[str, str], ...] | None, + auth: list[tuple[str, str]] | None, + ) -> tuple[tuple[str, str], ...] 
def _create_grpc_interceptors(
    auth_metadata: list[tuple[str, str]] | None = None,
) -> list[Any]:
    """Create gRPC interceptors for metadata normalization and auth injection.

    Args:
        auth_metadata: Optional auth metadata to inject into all calls.
            Used for insecure channels that need auth (non-localhost without TLS).

    Returns:
        A list of interceptors that lowercase metadata keys for gRPC
        compatibility. Must be called after grpc is imported.
    """
    # Imported lazily so this module stays importable without grpcio installed.
    import grpc.aio  # type: ignore[import-untyped]

    def _merge_metadata(
        existing: tuple[tuple[str, str], ...] | None,
        auth: list[tuple[str, str]] | None,
    ) -> tuple[tuple[str, str], ...] | None:
        """Merge existing metadata with auth metadata and normalize keys."""
        merged: list[tuple[str, str]] = []
        if existing:
            merged.extend(existing)
        if auth:
            merged.extend(auth)
        if not merged:
            return None
        # gRPC requires lowercase metadata keys; normalize on the way out.
        return tuple((key.lower(), value) for key, value in merged)

    def _inject_metadata(client_call_details: Any) -> Any:
        """Inject merged metadata into call details."""
        # NOTE(review): relies on call details exposing a NamedTuple-style
        # _replace (grpc.aio ClientCallDetails does) -- confirm on upgrade.
        return client_call_details._replace(
            metadata=_merge_metadata(client_call_details.metadata, auth_metadata)
        )

    class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor):  # type: ignore[misc,no-any-unimported]
        """Interceptor for unary-unary calls that injects auth metadata."""

        async def intercept_unary_unary(  # type: ignore[no-untyped-def]
            self, continuation, client_call_details, request
        ):
            """Intercept unary-unary call and inject metadata."""
            return await continuation(_inject_metadata(client_call_details), request)

    class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor):  # type: ignore[misc,no-any-unimported]
        """Interceptor for unary-stream calls that injects auth metadata."""

        async def intercept_unary_stream(  # type: ignore[no-untyped-def]
            self, continuation, client_call_details, request
        ):
            """Intercept unary-stream call and inject metadata."""
            return await continuation(_inject_metadata(client_call_details), request)

    class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor):  # type: ignore[misc,no-any-unimported]
        """Interceptor for stream-unary calls that injects auth metadata."""

        async def intercept_stream_unary(  # type: ignore[no-untyped-def]
            self, continuation, client_call_details, request_iterator
        ):
            """Intercept stream-unary call and inject metadata."""
            return await continuation(
                _inject_metadata(client_call_details), request_iterator
            )

    class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor):  # type: ignore[misc,no-any-unimported]
        """Interceptor for stream-stream calls that injects auth metadata."""

        async def intercept_stream_stream(  # type: ignore[no-untyped-def]
            self, continuation, client_call_details, request_iterator
        ):
            """Intercept stream-stream call and inject metadata."""
            return await continuation(
                _inject_metadata(client_call_details), request_iterator
            )

    # One interceptor per RPC arity; all four share the same injection logic.
    return [
        MetadataUnaryUnary(),
        MetadataUnaryStream(),
        MetadataStreamUnary(),
        MetadataStreamStream(),
    ]
def _create_grpc_channel_factory(
    grpc_config: GRPCClientConfig,
    auth: ClientAuthScheme | None = None,
) -> Callable[[str], Any]:
    """Create a gRPC channel factory with the given configuration.

    Args:
        grpc_config: gRPC client configuration with channel options.
        auth: Optional ClientAuthScheme for TLS and auth configuration.

    Returns:
        A callable that creates gRPC channels from URLs.

    Raises:
        ImportError: If grpcio is not installed.
        ValueError: If the auth scheme cannot be represented as gRPC metadata
            (HTTPDigestAuth, or APIKeyAuth in query/cookie locations).
    """
    try:
        import grpc
    except ImportError as e:
        raise ImportError(
            "gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
        ) from e

    auth_metadata: list[tuple[str, str]] = []

    if auth is not None:
        from crewai_a2a.auth.client_schemes import (
            APIKeyAuth,
            BearerTokenAuth,
            HTTPBasicAuth,
            HTTPDigestAuth,
            OAuth2AuthorizationCode,
            OAuth2ClientCredentials,
        )

        # Schemes that depend on HTTP mechanics cannot be mapped onto gRPC
        # metadata, so reject them up front with actionable guidance.
        if isinstance(auth, HTTPDigestAuth):
            raise ValueError(
                "HTTPDigestAuth is not supported with gRPC transport. "
                "Digest authentication requires HTTP challenge-response flow. "
                "Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
            )
        if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
            raise ValueError(
                f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
                "gRPC only supports header-based authentication. "
                "Use APIKeyAuth with location='header' instead."
            )

        if isinstance(auth, BearerTokenAuth):
            auth_metadata.append(("authorization", f"Bearer {auth.token}"))
        elif isinstance(auth, HTTPBasicAuth):
            import base64

            basic_credentials = f"{auth.username}:{auth.password}"
            encoded = base64.b64encode(basic_credentials.encode()).decode()
            auth_metadata.append(("authorization", f"Basic {encoded}"))
        elif isinstance(auth, APIKeyAuth) and auth.location == "header":
            header_name = auth.name.lower()
            auth_metadata.append((header_name, auth.api_key))
        elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
            # NOTE(review): reads a private attribute; presumably the token
            # was fetched by an earlier auth step -- confirm.
            if auth._access_token:
                auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))

    def factory(url: str) -> Any:
        """Create a gRPC channel for the given URL."""
        target = url
        use_tls = False

        # Strip the scheme; gRPC targets are plain host:port strings.
        if url.startswith("grpcs://"):
            target = url[8:]
            use_tls = True
        elif url.startswith("grpc://"):
            target = url[7:]
        elif url.startswith("https://"):
            target = url[8:]
            use_tls = True
        elif url.startswith("http://"):
            target = url[7:]

        options: list[tuple[str, Any]] = []
        if grpc_config.max_send_message_length is not None:
            options.append(
                ("grpc.max_send_message_length", grpc_config.max_send_message_length)
            )
        if grpc_config.max_receive_message_length is not None:
            options.append(
                (
                    "grpc.max_receive_message_length",
                    grpc_config.max_receive_message_length,
                )
            )
        if grpc_config.keepalive_time_ms is not None:
            options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
        if grpc_config.keepalive_timeout_ms is not None:
            options.append(
                ("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
            )

        # Explicit TLS config on the auth scheme wins over scheme inference.
        channel_credentials = None
        if auth and hasattr(auth, "tls") and auth.tls:
            channel_credentials = auth.tls.get_grpc_credentials()
        elif use_tls:
            channel_credentials = grpc.ssl_channel_credentials()

        if channel_credentials and auth_metadata:

            class AuthMetadataPlugin(grpc.AuthMetadataPlugin):  # type: ignore[misc,no-any-unimported]
                """gRPC auth metadata plugin that adds auth headers as metadata."""

                def __init__(self, metadata: list[tuple[str, str]]) -> None:
                    self._metadata = tuple(metadata)

                def __call__(  # type: ignore[no-any-unimported]
                    self,
                    context: grpc.AuthMetadataContext,
                    callback: grpc.AuthMetadataPluginCallback,
                ) -> None:
                    callback(self._metadata, None)

            call_creds = grpc.metadata_call_credentials(
                AuthMetadataPlugin(auth_metadata)
            )
            credentials = grpc.composite_channel_credentials(
                channel_credentials, call_creds
            )
            interceptors = _create_grpc_interceptors()
            return grpc.aio.secure_channel(
                target, credentials, options=options or None, interceptors=interceptors
            )
        if channel_credentials:
            interceptors = _create_grpc_interceptors()
            return grpc.aio.secure_channel(
                target,
                channel_credentials,
                options=options or None,
                interceptors=interceptors,
            )
        # Insecure channel: call credentials are unavailable, so auth headers
        # (if any) are injected via the interceptors instead.
        interceptors = _create_grpc_interceptors(
            auth_metadata=auth_metadata if auth_metadata else None
        )
        return grpc.aio.insecure_channel(
            target, options=options or None, interceptors=interceptors
        )

    return factory
@asynccontextmanager
async def _create_a2a_client(
    agent_card: AgentCard,
    transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
    timeout: int,
    headers: MutableMapping[str, str],
    streaming: bool,
    auth: ClientAuthScheme | None = None,
    use_polling: bool = False,
    push_notification_config: PushNotificationConfig | None = None,
    client_extensions: list[str] | None = None,
    accepted_output_modes: list[str] | None = None,
    grpc_config: GRPCClientConfig | None = None,
) -> AsyncIterator[Client]:
    """Create and configure an A2A client.

    Args:
        agent_card: The A2A agent card.
        transport_protocol: Transport protocol to use.
        timeout: Request timeout in seconds.
        headers: HTTP headers (already with auth applied).
        streaming: Enable streaming responses.
        auth: Optional ClientAuthScheme for client configuration.
        use_polling: Enable polling mode.
        push_notification_config: Optional push notification config.
        client_extensions: A2A protocol extension URIs to declare support for.
        accepted_output_modes: MIME types the client can accept in responses.
        grpc_config: Optional gRPC client configuration.

    Yields:
        Configured A2A client instance.
    """
    verify = _get_tls_verify(auth)
    # The httpx client's lifetime bounds the A2A client's lifetime.
    async with httpx.AsyncClient(
        timeout=timeout,
        headers=headers,
        verify=verify,
    ) as httpx_client:
        # Digest and API-key auth hook into the httpx client itself; other
        # schemes are already present in `headers` (see the param doc).
        if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
            configure_auth_client(auth, httpx_client)

        push_configs: list[A2APushNotificationConfig] = []
        if push_notification_config is not None:
            push_configs.append(
                A2APushNotificationConfig(
                    url=str(push_notification_config.url),
                    id=push_notification_config.id,
                    token=push_notification_config.token,
                    authentication=push_notification_config.authentication,
                )
            )

        grpc_channel_factory = None
        if transport_protocol == "GRPC":
            grpc_channel_factory = _create_grpc_channel_factory(
                grpc_config or GRPCClientConfig(),
                auth=auth,
            )

        config = ClientConfig(
            httpx_client=httpx_client,
            supported_transports=[transport_protocol],
            # Polling and streaming are mutually exclusive.
            streaming=streaming and not use_polling,
            polling=use_polling,
            accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES,  # type: ignore[arg-type]
            push_notification_configs=push_configs,
            grpc_channel_factory=grpc_channel_factory,
        )

        factory = ClientFactory(config)
        client = factory.create(agent_card)

        if client_extensions:
            await client.add_request_middleware(ExtensionsMiddleware(client_extensions))

        yield client
_log_context: ContextVar[dict[str, Any] | None] = ContextVar(
    "log_context", default=None
)


class JSONFormatter(logging.Formatter):
    """Render log records as single-line JSON objects.

    Emits a stable set of base fields (timestamp, level, logger, message)
    plus context-scoped fields and per-record extras, so log aggregators
    can index entries without parsing free-form text.
    """

    # LogRecord attributes that are logging bookkeeping rather than user
    # data; anything else attached to a record is treated as an extra.
    _STANDARD_ATTRS = frozenset(
        {
            "name",
            "msg",
            "args",
            "created",
            "filename",
            "funcName",
            "levelname",
            "levelno",
            "lineno",
            "module",
            "msecs",
            "pathname",
            "process",
            "processName",
            "relativeCreated",
            "stack_info",
            "exc_info",
            "exc_text",
            "thread",
            "threadName",
            "taskName",
            "message",
        }
    )

    def format(self, record: logging.LogRecord) -> str:
        """Serialize *record* plus any ambient log context to a JSON string."""
        payload: dict[str, Any] = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }

        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)

        ambient = _log_context.get()
        if ambient is not None:
            payload.update(ambient)

        # Well-known A2A fields on the record win over context values.
        for field in ("task_id", "context_id", "agent", "endpoint", "extension", "error"):
            if hasattr(record, field):
                payload[field] = getattr(record, field)

        # Remaining custom extras are added only when not already present.
        for attr, value in record.__dict__.items():
            if attr.startswith("_") or attr in self._STANDARD_ATTRS:
                continue
            if attr not in payload:
                payload[attr] = value

        # default=str keeps non-JSON-serializable extras from raising.
        return json.dumps(payload, default=str)
class LogContext:
    """Bind extra fields to every log record emitted within a scope.

    Example:
        with LogContext(task_id="abc", context_id="xyz"):
            logger.info("Processing task")  # Includes task_id and context_id
    """

    def __init__(self, **fields: Any) -> None:
        # Fields are merged on top of any already-active context on entry.
        self._fields = fields
        self._token: Any = None

    def __enter__(self) -> LogContext:
        active = _log_context.get()
        merged = dict(active) if active else {}
        merged.update(self._fields)
        self._token = _log_context.set(merged)
        return self

    def __exit__(self, *args: Any) -> None:
        # Restore whatever context was active before this scope.
        _log_context.reset(self._token)


def configure_json_logging(logger_name: str = "crewai.a2a") -> None:
    """Replace a logger's handlers with a single JSON stream handler.

    Args:
        logger_name: Logger name to configure.
    """
    target = logging.getLogger(logger_name)
    # Drop previously-installed handlers so output is not duplicated.
    for existing in list(target.handlers):
        target.removeHandler(existing)
    json_handler = logging.StreamHandler()
    json_handler.setFormatter(JSONFormatter())
    target.addHandler(json_handler)


def get_logger(name: str) -> logging.Logger:
    """Return the named logger for structured JSON output.

    Args:
        name: Logger name.

    Returns:
        Configured logger instance.
    """
    return logging.getLogger(name)
+ """ + return logging.getLogger(name) diff --git a/lib/crewai-a2a/src/crewai_a2a/utils/response_model.py b/lib/crewai-a2a/src/crewai_a2a/utils/response_model.py new file mode 100644 index 000000000..6f4b5f057 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/utils/response_model.py @@ -0,0 +1,101 @@ +"""Response model utilities for A2A agent interactions.""" + +from __future__ import annotations + +from typing import TypeAlias + +from crewai.types.utils import create_literals_from_strings +from pydantic import BaseModel, Field, create_model + +from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig + + +A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig +A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig + + +def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None: + """Create a dynamic AgentResponse model with Literal types for agent IDs. + + Args: + agent_ids: List of available A2A agent IDs. + + Returns: + Dynamically created Pydantic model with Literal-constrained a2a_ids field, + or None if agent_ids is empty. + """ + if not agent_ids: + return None + + DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806 + + return create_model( + "AgentResponse", + a2a_ids=( + tuple[DynamicLiteral, ...], # type: ignore[valid-type] + Field( + default_factory=tuple, + max_length=len(agent_ids), + description="A2A agent IDs to delegate to.", + ), + ), + message=( + str, + Field( + description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation." + ), + ), + is_a2a=( + bool, + Field( + description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. 
NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately." + ), + ), + __base__=BaseModel, + ) + + +def extract_a2a_agent_ids_from_config( + a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None, +) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]: + """Extract A2A agent IDs from A2A configuration. + + Filters out A2AServerConfig since it doesn't have an endpoint for delegation. + + Args: + a2a_config: A2A configuration (any type). + + Returns: + Tuple of client A2A configs list and agent endpoint IDs. + """ + if a2a_config is None: + return [], () + + configs: list[A2AConfigTypes] + if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)): + configs = [a2a_config] + else: + configs = a2a_config + + # Filter to only client configs (those with endpoint) + client_configs: list[A2AClientConfigTypes] = [ + config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig)) + ] + + return client_configs, tuple(config.endpoint for config in client_configs) + + +def get_a2a_agents_and_response_model( + a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None, +) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]: + """Get A2A agent configs and response model. + + Args: + a2a_config: A2A configuration (any type). + + Returns: + Tuple of client A2A configs and response model. 
+ """ + a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config) + + return a2a_agents, create_agent_response_model(agent_ids) diff --git a/lib/crewai-a2a/src/crewai_a2a/utils/task.py b/lib/crewai-a2a/src/crewai_a2a/utils/task.py new file mode 100644 index 000000000..8aa2484d9 --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/utils/task.py @@ -0,0 +1,585 @@ +"""A2A task utilities for server-side task management.""" + +from __future__ import annotations + +import asyncio +import base64 +from collections.abc import Callable, Coroutine +from datetime import datetime +from functools import wraps +import json +import logging +import os +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast +from urllib.parse import urlparse + +from a2a.server.agent_execution import RequestContext +from a2a.server.events import EventQueue +from a2a.types import ( + Artifact, + FileWithBytes, + FileWithUri, + InternalError, + InvalidParamsError, + Message, + Part, + Task as A2ATask, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.utils import ( + get_data_parts, + get_file_parts, + new_agent_text_message, + new_data_artifact, + new_text_artifact, +) +from a2a.utils.errors import ServerError +from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped] +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2AServerTaskCanceledEvent, + A2AServerTaskCompletedEvent, + A2AServerTaskFailedEvent, + A2AServerTaskStartedEvent, +) +from crewai.task import Task +from crewai.utilities.pydantic_schema_utils import create_model_from_schema +from pydantic import BaseModel + +from crewai_a2a.utils.agent_card import _get_server_config +from crewai_a2a.utils.content_type import validate_message_parts + + +if TYPE_CHECKING: + from crewai.agent import Agent + + from crewai_a2a.extensions.server import ExtensionContext, ServerExtensionRegistry + + +logger = logging.getLogger(__name__) 
+ +P = ParamSpec("P") +T = TypeVar("T") + + +class RedisCacheConfig(TypedDict, total=False): + """Configuration for aiocache Redis backend.""" + + cache: str + endpoint: str + port: int + db: int + password: str + + +def _parse_redis_url(url: str) -> RedisCacheConfig: + """Parse a Redis URL into aiocache configuration. + + Args: + url: Redis connection URL (e.g., redis://localhost:6379/0). + + Returns: + Configuration dict for aiocache.RedisCache. + """ + parsed = urlparse(url) + config: RedisCacheConfig = { + "cache": "aiocache.RedisCache", + "endpoint": parsed.hostname or "localhost", + "port": parsed.port or 6379, + } + if parsed.path and parsed.path != "/": + try: + config["db"] = int(parsed.path.lstrip("/")) + except ValueError: + pass + if parsed.password: + config["password"] = parsed.password + return config + + +_redis_url = os.environ.get("REDIS_URL") + +caches.set_config( + { + "default": _parse_redis_url(_redis_url) + if _redis_url + else { + "cache": "aiocache.SimpleMemoryCache", + } + } +) + + +def cancellable( + fn: Callable[P, Coroutine[Any, Any, T]], +) -> Callable[P, Coroutine[Any, Any, T]]: + """Decorator that enables cancellation for A2A task execution. + + Runs a cancellation watcher concurrently with the wrapped function. + When a cancel event is published, the execution is cancelled. + + Args: + fn: The async function to wrap. + + Returns: + Wrapped function with cancellation support. 
+ """ + + @wraps(fn) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + """Wrap function with cancellation monitoring.""" + context: RequestContext | None = None + for arg in args: + if isinstance(arg, RequestContext): + context = arg + break + if context is None: + context = cast(RequestContext | None, kwargs.get("context")) + + if context is None: + return await fn(*args, **kwargs) + + task_id = context.task_id + cache = caches.get("default") + + async def poll_for_cancel() -> bool: + """Poll cache for cancellation flag.""" + while True: + if await cache.get(f"cancel:{task_id}"): + return True + await asyncio.sleep(0.1) + + async def watch_for_cancel() -> bool: + """Watch for cancellation events via pub/sub or polling.""" + if isinstance(cache, SimpleMemoryCache): + return await poll_for_cancel() + + try: + client = cache.client + pubsub = client.pubsub() + await pubsub.subscribe(f"cancel:{task_id}") + async for message in pubsub.listen(): + if message["type"] == "message": + return True + except (OSError, ConnectionError) as e: + logger.warning( + "Cancel watcher Redis error, falling back to polling", + extra={"task_id": task_id, "error": str(e)}, + ) + return await poll_for_cancel() + return False + + execute_task = asyncio.create_task(fn(*args, **kwargs)) + cancel_watch = asyncio.create_task(watch_for_cancel()) + + try: + done, _ = await asyncio.wait( + [execute_task, cancel_watch], + return_when=asyncio.FIRST_COMPLETED, + ) + + if cancel_watch in done: + execute_task.cancel() + try: + await execute_task + except asyncio.CancelledError: + pass + raise asyncio.CancelledError(f"Task {task_id} was cancelled") + cancel_watch.cancel() + return execute_task.result() + finally: + await cache.delete(f"cancel:{task_id}") + + return wrapper + + +def _convert_a2a_files_to_file_inputs( + a2a_files: list[FileWithBytes | FileWithUri], +) -> dict[str, Any]: + """Convert a2a file types to crewai FileInput dict. 
+ + Args: + a2a_files: List of FileWithBytes or FileWithUri from a2a SDK. + + Returns: + Dictionary mapping file names to FileInput objects. + """ + try: + from crewai_files import File, FileBytes + except ImportError: + logger.debug("crewai_files not installed, returning empty file dict") + return {} + + file_dict: dict[str, Any] = {} + for idx, a2a_file in enumerate(a2a_files): + if isinstance(a2a_file, FileWithBytes): + file_bytes = base64.b64decode(a2a_file.bytes) + name = a2a_file.name or f"file_{idx}" + file_source = FileBytes(data=file_bytes, filename=a2a_file.name) + file_dict[name] = File(source=file_source) + elif isinstance(a2a_file, FileWithUri): + name = a2a_file.name or f"file_{idx}" + file_dict[name] = File(source=a2a_file.uri) + + return file_dict + + +def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None: + """Extract response schema from message parts metadata. + + The client may include a JSON schema in TextPart metadata to specify + the expected response format (see delegation.py line 463). + + Args: + parts: List of message parts. + + Returns: + JSON schema dict if found, None otherwise. + """ + for part in parts: + if part.root.kind == "text" and part.root.metadata: + schema = part.root.metadata.get("schema") + if schema and isinstance(schema, dict): + return schema # type: ignore[no-any-return] + return None + + +def _create_result_artifact( + result: Any, + task_id: str, +) -> Artifact: + """Create artifact from task result, using DataPart for structured data. + + Args: + result: The task execution result. + task_id: The task ID for naming the artifact. + + Returns: + Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings). 
+ """ + artifact_name = f"result_{task_id}" + if isinstance(result, dict): + return new_data_artifact(artifact_name, result) + if isinstance(result, BaseModel): + return new_data_artifact(artifact_name, result.model_dump()) + return new_text_artifact(artifact_name, str(result)) + + +def _build_task_description( + user_message: str, + structured_inputs: list[dict[str, Any]], +) -> str: + """Build task description including structured data if present. + + Args: + user_message: The original user message text. + structured_inputs: List of structured data from DataParts. + + Returns: + Task description with structured data appended if present. + """ + if not structured_inputs: + return user_message + + structured_json = json.dumps(structured_inputs, indent=2) + return f"{user_message}\n\nStructured Data:\n{structured_json}" + + +async def execute( + agent: Agent, + context: RequestContext, + event_queue: EventQueue, +) -> None: + """Execute an A2A task using a CrewAI agent. + + Args: + agent: The CrewAI agent to execute the task. + context: The A2A request context containing the user's message. + event_queue: The event queue for sending responses back. + """ + await _execute_impl(agent, context, event_queue, None, None) + + +@cancellable +async def _execute_impl( + agent: Agent, + context: RequestContext, + event_queue: EventQueue, + extension_registry: ServerExtensionRegistry | None, + extension_context: ExtensionContext | None, +) -> None: + """Internal implementation for task execution with optional extensions.""" + server_config = _get_server_config(agent) + if context.message and context.message.parts and server_config: + allowed_modes = server_config.default_input_modes + invalid_types = validate_message_parts(context.message.parts, allowed_modes) + if invalid_types: + raise ServerError( + InvalidParamsError( + message=f"Unsupported content type(s): {', '.join(invalid_types)}. 
" + f"Supported: {', '.join(allowed_modes)}" + ) + ) + + if extension_registry and extension_context: + await extension_registry.invoke_on_request(extension_context) + + user_message = context.get_user_input() + + response_model: type[BaseModel] | None = None + structured_inputs: list[dict[str, Any]] = [] + a2a_files: list[FileWithBytes | FileWithUri] = [] + + if context.message and context.message.parts: + schema = _extract_response_schema(context.message.parts) + if schema: + try: + response_model = create_model_from_schema(schema) + except Exception as e: + logger.debug( + "Failed to create response model from schema", + extra={"error": str(e), "schema_title": schema.get("title")}, + ) + + structured_inputs = get_data_parts(context.message.parts) + a2a_files = get_file_parts(context.message.parts) + + task_id = context.task_id + context_id = context.context_id + if task_id is None or context_id is None: + msg = "task_id and context_id are required" + crewai_event_bus.emit( + agent, + A2AServerTaskFailedEvent( + task_id="", + context_id="", + error=msg, + from_agent=agent, + ), + ) + raise ServerError(InvalidParamsError(message=msg)) from None + + task = Task( + description=_build_task_description(user_message, structured_inputs), + expected_output="Response to the user's request", + agent=agent, + response_model=response_model, + input_files=_convert_a2a_files_to_file_inputs(a2a_files), + ) + + crewai_event_bus.emit( + agent, + A2AServerTaskStartedEvent( + task_id=task_id, + context_id=context_id, + from_task=task, + from_agent=agent, + ), + ) + + try: + result = await agent.aexecute_task(task=task, tools=agent.tools) + if extension_registry and extension_context: + result = await extension_registry.invoke_on_response( + extension_context, result + ) + result_str = str(result) + history: list[Message] = [context.message] if context.message else [] + history.append(new_agent_text_message(result_str, context_id, task_id)) + await event_queue.enqueue_event( + 
A2ATask( + id=task_id, + context_id=context_id, + status=TaskStatus(state=TaskState.completed), + artifacts=[_create_result_artifact(result, task_id)], + history=history, + ) + ) + crewai_event_bus.emit( + agent, + A2AServerTaskCompletedEvent( + task_id=task_id, + context_id=context_id, + result=str(result), + from_task=task, + from_agent=agent, + ), + ) + except asyncio.CancelledError: + crewai_event_bus.emit( + agent, + A2AServerTaskCanceledEvent( + task_id=task_id, + context_id=context_id, + from_task=task, + from_agent=agent, + ), + ) + raise + except Exception as e: + crewai_event_bus.emit( + agent, + A2AServerTaskFailedEvent( + task_id=task_id, + context_id=context_id, + error=str(e), + from_task=task, + from_agent=agent, + ), + ) + raise ServerError( + error=InternalError(message=f"Task execution failed: {e}") + ) from e + + +async def execute_with_extensions( + agent: Agent, + context: RequestContext, + event_queue: EventQueue, + extension_registry: ServerExtensionRegistry, + extension_context: ExtensionContext, +) -> None: + """Execute an A2A task with extension hooks. + + Args: + agent: The CrewAI agent to execute the task. + context: The A2A request context containing the user's message. + event_queue: The event queue for sending responses back. + extension_registry: Registry of server extensions. + extension_context: Context for extension hooks. + """ + await _execute_impl( + agent, context, event_queue, extension_registry, extension_context + ) + + +async def cancel( + context: RequestContext, + event_queue: EventQueue, +) -> A2ATask | None: + """Cancel an A2A task. + + Publishes a cancel event that the cancellable decorator listens for. + + Args: + context: The A2A request context containing task information. + event_queue: The event queue for sending the cancellation status. + + Returns: + The canceled task with updated status. 
+ """ + task_id = context.task_id + context_id = context.context_id + if task_id is None or context_id is None: + raise ServerError(InvalidParamsError(message="task_id and context_id required")) + + if context.current_task and context.current_task.status.state in ( + TaskState.completed, + TaskState.failed, + TaskState.canceled, + ): + return context.current_task + + cache = caches.get("default") + + await cache.set(f"cancel:{task_id}", True, ttl=3600) + if not isinstance(cache, SimpleMemoryCache): + await cache.client.publish(f"cancel:{task_id}", "cancel") + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus(state=TaskState.canceled), + final=True, + ) + ) + + if context.current_task: + context.current_task.status = TaskStatus(state=TaskState.canceled) + return context.current_task + return None + + +def list_tasks( + tasks: list[A2ATask], + context_id: str | None = None, + status: TaskState | None = None, + status_timestamp_after: datetime | None = None, + page_size: int = 50, + page_token: str | None = None, + history_length: int | None = None, + include_artifacts: bool = False, +) -> tuple[list[A2ATask], str | None, int]: + """Filter and paginate A2A tasks. + + Provides filtering by context, status, and timestamp, along with + cursor-based pagination. This is a pure utility function that operates + on an in-memory list of tasks - storage retrieval is handled separately. + + Args: + tasks: All tasks to filter. + context_id: Filter by context ID to get tasks in a conversation. + status: Filter by task state (e.g., completed, working). + status_timestamp_after: Filter to tasks updated after this time. + page_size: Maximum tasks per page (default 50). + page_token: Base64-encoded cursor from previous response. + history_length: Limit history messages per task (None = full history). + include_artifacts: Whether to include task artifacts (default False). 
+ + Returns: + Tuple of (filtered_tasks, next_page_token, total_count). + - filtered_tasks: Tasks matching filters, paginated and trimmed. + - next_page_token: Token for next page, or None if no more pages. + - total_count: Total number of tasks matching filters (before pagination). + """ + filtered: list[A2ATask] = [] + for task in tasks: + if context_id and task.context_id != context_id: + continue + if status and task.status.state != status: + continue + if status_timestamp_after and task.status.timestamp: + ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00")) + if ts <= status_timestamp_after: + continue + filtered.append(task) + + def get_timestamp(t: A2ATask) -> datetime: + """Extract timestamp from task status for sorting.""" + if t.status.timestamp is None: + return datetime.min + return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00")) + + filtered.sort(key=get_timestamp, reverse=True) + total = len(filtered) + + start = 0 + if page_token: + try: + cursor_id = base64.b64decode(page_token).decode() + for idx, task in enumerate(filtered): + if task.id == cursor_id: + start = idx + 1 + break + except (ValueError, UnicodeDecodeError): + pass + + page = filtered[start : start + page_size] + + result: list[A2ATask] = [] + for task in page: + task = task.model_copy(deep=True) + if history_length is not None and task.history: + task.history = task.history[-history_length:] + if not include_artifacts: + task.artifacts = None + result.append(task) + + next_token: str | None = None + if result and len(result) == page_size: + next_token = base64.b64encode(result[-1].id.encode()).decode() + + return result, next_token, total diff --git a/lib/crewai-a2a/src/crewai_a2a/utils/transport.py b/lib/crewai-a2a/src/crewai_a2a/utils/transport.py new file mode 100644 index 000000000..d4667e1de --- /dev/null +++ b/lib/crewai-a2a/src/crewai_a2a/utils/transport.py @@ -0,0 +1,214 @@ +"""Transport negotiation utilities for A2A protocol. 
@dataclass
class NegotiatedTransport:
    """Result of transport negotiation.

    Attributes:
        transport: The negotiated transport protocol.
        url: The URL to use for this transport.
        source: How the transport was selected — one of 'client_preferred',
            'server_preferred', or 'fallback' (see NegotiationSource).
    """

    transport: str
    url: str
    source: NegotiationSource
def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
    """Collect every endpoint the server advertises.

    The primary URL (paired with the card's preferred transport, defaulting
    to JSONRPC) always comes first, followed by any additional interfaces
    that are not exact (url, transport) duplicates of one already listed.

    Args:
        agent_card: The agent's card containing transport information.

    Returns:
        All advertised endpoints as AgentInterface objects, primary first.
    """
    interfaces: list[AgentInterface] = [
        AgentInterface(
            transport=agent_card.preferred_transport or JSONRPC_TRANSPORT,
            url=agent_card.url,
        )
    ]

    for extra in agent_card.additional_interfaces or []:
        already_listed = any(
            known.url == extra.url and known.transport == extra.transport
            for known in interfaces
        )
        if not already_listed:
            interfaces.append(extra)

    return interfaces


def negotiate_transport(
    agent_card: AgentCard,
    client_supported_transports: list[str] | None = None,
    client_preferred_transport: str | None = None,
    emit_event: bool = True,
    endpoint: str | None = None,
    a2a_agent_name: str | None = None,
) -> NegotiatedTransport:
    """Negotiate the transport protocol between client and server.

    Compares the client's supported transports with the server's available
    interfaces to find a compatible transport and URL.

    Negotiation logic:
    1. If client_preferred_transport is set and server supports it → use it
    2. Otherwise, if server's preferred is in client's supported → use server's
    3. Otherwise, find first match from client's supported in server's interfaces

    Args:
        agent_card: The server's AgentCard with transport information.
        client_supported_transports: Transports the client can use.
            Defaults to ["JSONRPC"] if not specified.
        client_preferred_transport: Client's preferred transport. If set and
            server supports it, takes priority over server preference.
        emit_event: Whether to emit a transport negotiation event.
        endpoint: Original endpoint URL for event metadata.
        a2a_agent_name: Agent name for event metadata.

    Returns:
        NegotiatedTransport with the selected transport, URL, and source.

    Raises:
        TransportNegotiationError: If no compatible transport is found.
    """
    if client_supported_transports is None:
        client_supported_transports = [JSONRPC_TRANSPORT]

    # Comparison is case-insensitive; keep uppercased copies for matching
    # while preserving the server's original casing in the result.
    client_transports = [t.upper() for t in client_supported_transports]
    client_preferred = (
        client_preferred_transport.upper() if client_preferred_transport else None
    )

    server_interfaces = _get_server_interfaces(agent_card)
    server_transports = [i.transport.upper() for i in server_interfaces]

    # First interface advertising each transport wins (primary URL is first).
    transport_to_interface: dict[str, AgentInterface] = {}
    for interface in server_interfaces:
        transport_to_interface.setdefault(interface.transport.upper(), interface)

    def _pick(name: str, source: NegotiationSource) -> NegotiatedTransport:
        """Build the result for a matched transport key."""
        chosen = transport_to_interface[name]
        return NegotiatedTransport(
            transport=chosen.transport, url=chosen.url, source=source
        )

    result: NegotiatedTransport | None = None

    # 1. Client's explicit preference wins when the server offers it.
    if client_preferred and client_preferred in transport_to_interface:
        result = _pick(client_preferred, "client_preferred")

    # 2. Otherwise honor the server's preference if the client supports it.
    if result is None:
        server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
        if (
            server_preferred in client_transports
            and server_preferred in transport_to_interface
        ):
            result = _pick(server_preferred, "server_preferred")

    # 3. Fall back to the first client-supported transport the server offers.
    if result is None:
        for candidate in client_transports:
            if candidate in transport_to_interface:
                result = _pick(candidate, "fallback")
                break

    if result is None:
        raise TransportNegotiationError(
            client_transports=client_transports,
            server_transports=server_transports,
        )

    if emit_event:
        crewai_event_bus.emit(
            None,
            A2ATransportNegotiatedEvent(
                endpoint=endpoint or agent_card.url,
                a2a_agent_name=a2a_agent_name or agent_card.name,
                negotiated_transport=result.transport,
                negotiated_url=result.url,
                source=result.source,
                client_supported_transports=client_transports,
                server_supported_transports=server_transports,
                server_preferred_transport=agent_card.preferred_transport
                or JSONRPC_TRANSPORT,
                client_preferred_transport=client_preferred,
            ),
        )

    return result
AgentResponseProtocol +from crewai_a2a.utils.agent_card import ( + afetch_agent_card, + fetch_agent_card, + inject_a2a_server_methods, +) +from crewai_a2a.utils.delegation import ( + aexecute_a2a_delegation, + execute_a2a_delegation, +) +from crewai_a2a.utils.response_model import get_a2a_agents_and_response_model + + +if TYPE_CHECKING: + from a2a.types import AgentCard, Message + from crewai.agent.core import Agent + from crewai.tools.base_tool import BaseTool + + +class DelegationContext(NamedTuple): + """Context prepared for A2A delegation. + + Groups all the values needed to execute a delegation to a remote A2A agent. + """ + + a2a_agents: list[A2AConfig | A2AClientConfig] + agent_response_model: type[BaseModel] | None + current_request: str + agent_id: str + agent_config: A2AConfig | A2AClientConfig + context_id: str | None + task_id: str | None + metadata: dict[str, Any] | None + extensions: dict[str, Any] | None + reference_task_ids: list[str] + original_task_description: str + max_turns: int + + +class DelegationState(NamedTuple): + """Mutable state for A2A delegation loop. + + Groups values that may change during delegation turns. + """ + + current_request: str + context_id: str | None + task_id: str | None + reference_task_ids: list[str] + conversation_history: list[Message] + agent_card: AgentCard | None + agent_card_dict: dict[str, Any] | None + agent_name: str | None + + +def wrap_agent_with_a2a_instance( + agent: Agent, extension_registry: ExtensionRegistry | None = None +) -> None: + """Wrap an agent instance's task execution and kickoff methods with A2A support. + + This function modifies the agent instance by wrapping its execute_task, + aexecute_task, kickoff, and kickoff_async methods to add A2A delegation + capabilities. Should only be called when the agent has a2a configuration set. + + Args: + agent: The agent instance to wrap. + extension_registry: Optional registry of A2A extensions. 
+ """ + if extension_registry is None: + extension_registry = ExtensionRegistry() + + extension_registry.inject_all_tools(agent) + + original_execute_task = agent.execute_task.__func__ # type: ignore[attr-defined] + original_aexecute_task = agent.aexecute_task.__func__ # type: ignore[attr-defined] + + @wraps(original_execute_task) + def execute_task_with_a2a( + self: Agent, + task: Task, + context: str | None = None, + tools: list[BaseTool] | None = None, + ) -> str: + """Execute task with A2A delegation support (sync).""" + if not self.a2a: + return original_execute_task(self, task, context, tools) # type: ignore[no-any-return] + + a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) + + return _execute_task_with_a2a( + self=self, + a2a_agents=a2a_agents, + original_fn=original_execute_task, + task=task, + agent_response_model=agent_response_model, + context=context, + tools=tools, + extension_registry=extension_registry, + ) + + @wraps(original_aexecute_task) + async def aexecute_task_with_a2a( + self: Agent, + task: Task, + context: str | None = None, + tools: list[BaseTool] | None = None, + ) -> str: + """Execute task with A2A delegation support (async).""" + if not self.a2a: + return await original_aexecute_task(self, task, context, tools) # type: ignore[no-any-return] + + a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) + + return await _aexecute_task_with_a2a( + self=self, + a2a_agents=a2a_agents, + original_fn=original_aexecute_task, + task=task, + agent_response_model=agent_response_model, + context=context, + tools=tools, + extension_registry=extension_registry, + ) + + object.__setattr__(agent, "execute_task", MethodType(execute_task_with_a2a, agent)) + object.__setattr__( + agent, "aexecute_task", MethodType(aexecute_task_with_a2a, agent) + ) + + original_kickoff = agent.kickoff.__func__ # type: ignore[attr-defined] + original_kickoff_async = agent.kickoff_async.__func__ # type: 
ignore[attr-defined] + + @wraps(original_kickoff) + def kickoff_with_a2a( + self: Agent, + messages: str | list[Any], + response_format: type[Any] | None = None, + input_files: dict[str, Any] | None = None, + ) -> Any: + """Execute agent kickoff with A2A delegation support.""" + if not self.a2a: + return original_kickoff(self, messages, response_format, input_files) + + a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) + + if not a2a_agents: + return original_kickoff(self, messages, response_format, input_files) + + return _kickoff_with_a2a( + self=self, + a2a_agents=a2a_agents, + original_kickoff=original_kickoff, + messages=messages, + response_format=response_format, + input_files=input_files, + agent_response_model=agent_response_model, + extension_registry=extension_registry, + ) + + @wraps(original_kickoff_async) + async def kickoff_async_with_a2a( + self: Agent, + messages: str | list[Any], + response_format: type[Any] | None = None, + input_files: dict[str, Any] | None = None, + ) -> Any: + """Execute agent kickoff with A2A delegation support.""" + if not self.a2a: + return await original_kickoff_async( + self, messages, response_format, input_files + ) + + a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) + + if not a2a_agents: + return await original_kickoff_async( + self, messages, response_format, input_files + ) + + return await _akickoff_with_a2a( + self=self, + a2a_agents=a2a_agents, + original_kickoff_async=original_kickoff_async, + messages=messages, + response_format=response_format, + input_files=input_files, + agent_response_model=agent_response_model, + extension_registry=extension_registry, + ) + + object.__setattr__(agent, "kickoff", MethodType(kickoff_with_a2a, agent)) + object.__setattr__( + agent, "kickoff_async", MethodType(kickoff_async_with_a2a, agent) + ) + + inject_a2a_server_methods(agent) + + +def _fetch_card_from_config( + config: A2AConfig | A2AClientConfig, +) -> 
def _fetch_agent_cards_concurrently(
    a2a_agents: list[A2AConfig | A2AClientConfig],
) -> tuple[dict[str, AgentCard], dict[str, str]]:
    """Fetch agent cards for every configured A2A agent in parallel.

    Cards are fetched on a bounded thread pool (at most 10 workers). A fetch
    failure is fatal when the failing config has ``fail_fast`` set; otherwise
    the error is recorded so callers can report the agent as unavailable.

    Args:
        a2a_agents: List of A2A agent configurations.

    Returns:
        Tuple of (endpoint -> card for successful fetches,
        endpoint -> error message for failed ones).

    Raises:
        RuntimeError: If a fetch fails for a config with ``fail_fast=True``.
    """
    cards: dict[str, AgentCard] = {}
    failures: dict[str, str] = {}

    if not a2a_agents:
        return cards, failures

    with ThreadPoolExecutor(max_workers=min(len(a2a_agents), 10)) as pool:
        pending = [pool.submit(_fetch_card_from_config, cfg) for cfg in a2a_agents]
        for done in as_completed(pending):
            # _fetch_card_from_config never raises; it returns the exception.
            cfg, outcome = done.result()
            if not isinstance(outcome, Exception):
                cards[cfg.endpoint] = outcome
                continue
            if cfg.fail_fast:
                raise RuntimeError(
                    f"Failed to fetch agent card from {cfg.endpoint}. "
                    f"Ensure the A2A agent is running and accessible. Error: {outcome}"
                ) from outcome
            failures[cfg.endpoint] = str(outcome)

    return cards, failures
+ + Args: + self: The agent instance + a2a_agents: Dictionary of A2A agent configurations + original_fn: The original execute_task method + task: The task to execute + context: Optional context for task execution + tools: Optional tools available to the agent + agent_response_model: Optional agent response model + extension_registry: Registry of A2A extensions + + Returns: + Task execution result (either from LLM or A2A agent) + """ + original_description: str = task.description + original_output_pydantic = task.output_pydantic + original_response_model = task.response_model + + agent_cards, failed_agents = _fetch_agent_cards_concurrently(a2a_agents) + + if not agent_cards and a2a_agents and failed_agents: + unavailable_agents_text = "" + for endpoint, error in failed_agents.items(): + unavailable_agents_text += f" - {endpoint}: {error}\n" + + notice = UNAVAILABLE_AGENTS_NOTICE_TEMPLATE.substitute( + unavailable_agents=unavailable_agents_text + ) + task.description = f"{original_description}{notice}" + + try: + return original_fn(self, task, context, tools) + finally: + task.description = original_description + + task.description, _, extension_states = _augment_prompt_with_a2a( + a2a_agents=a2a_agents, + task_description=original_description, + agent_cards=agent_cards, + failed_agents=failed_agents, + extension_registry=extension_registry, + ) + task.response_model = agent_response_model + + try: + raw_result = original_fn(self, task, context, tools) + agent_response = _parse_agent_response( + raw_result=raw_result, agent_response_model=agent_response_model + ) + + if extension_registry and isinstance(agent_response, BaseModel): + agent_response = extension_registry.process_response_with_all( + agent_response, extension_states + ) + + if isinstance(agent_response, BaseModel) and isinstance( + agent_response, AgentResponseProtocol + ): + if agent_response.is_a2a: + return _delegate_to_a2a( + self, + agent_response=agent_response, + task=task, + 
original_fn=original_fn, + context=context, + tools=tools, + agent_cards=agent_cards, + original_task_description=original_description, + _extension_registry=extension_registry, + ) + task.output_pydantic = None + return agent_response.message + + return raw_result + finally: + task.description = original_description + if task.output_pydantic is not None: + task.output_pydantic = original_output_pydantic + task.response_model = original_response_model + + +def _kickoff_with_a2a( + self: Agent, + a2a_agents: list[A2AConfig | A2AClientConfig], + original_kickoff: Callable[..., LiteAgentOutput], + messages: str | list[Any], + response_format: type[Any] | None, + input_files: dict[str, Any] | None, + agent_response_model: type[BaseModel] | None, + extension_registry: ExtensionRegistry, +) -> LiteAgentOutput: + """Execute kickoff with A2A delegation support (sync). + + Args: + self: The agent instance. + a2a_agents: List of A2A agent configurations. + original_kickoff: The original kickoff method. + messages: Messages to send to the agent. + response_format: Optional response format. + input_files: Optional input files. + agent_response_model: Optional agent response model. + extension_registry: Registry of A2A extensions. + + Returns: + LiteAgentOutput from kickoff or A2A delegation. 
+ """ + if isinstance(messages, str): + description = messages + else: + content = next( + (m["content"] for m in reversed(messages) if m["role"] == "user"), + None, + ) + description = content if isinstance(content, str) else "" + + if not description: + return original_kickoff(self, messages, response_format, input_files) + + fake_task = Task( + description=description, + agent=self, + expected_output="Result from A2A delegation", + input_files=input_files or {}, + ) + + agent_cards, failed_agents = _fetch_agent_cards_concurrently(a2a_agents) + + if not agent_cards and a2a_agents and failed_agents: + return original_kickoff(self, messages, response_format, input_files) + + fake_task.description, _, extension_states = _augment_prompt_with_a2a( + a2a_agents=a2a_agents, + task_description=description, + agent_cards=agent_cards, + failed_agents=failed_agents, + extension_registry=extension_registry, + ) + fake_task.response_model = agent_response_model + + try: + result: LiteAgentOutput = original_kickoff( + self, messages, agent_response_model or response_format, input_files + ) + agent_response = _parse_agent_response( + raw_result=result.raw, agent_response_model=agent_response_model + ) + + if extension_registry and isinstance(agent_response, BaseModel): + agent_response = extension_registry.process_response_with_all( + agent_response, extension_states + ) + + if isinstance(agent_response, BaseModel) and isinstance( + agent_response, AgentResponseProtocol + ): + if agent_response.is_a2a: + + def _kickoff_adapter( + self_: Agent, + _task: Task, + _context: str | None, + _tools: list[Any] | None, + ) -> str: + fmt = ( + _task.response_model or agent_response_model or response_format + ) + output: LiteAgentOutput = original_kickoff( + self_, messages, fmt, input_files + ) + return output.raw + + result_str = _delegate_to_a2a( + self, + agent_response=agent_response, + task=fake_task, + original_fn=_kickoff_adapter, + context=None, + tools=None, + 
agent_cards=agent_cards, + original_task_description=description, + _extension_registry=extension_registry, + ) + return LiteAgentOutput( + raw=result_str, + pydantic=None, + agent_role=self.role, + usage_metrics=None, + messages=[], + ) + return LiteAgentOutput( + raw=agent_response.message, + pydantic=None, + agent_role=self.role, + usage_metrics=result.usage_metrics, + messages=result.messages, + ) + + return result + finally: + fake_task.description = description + + +async def _akickoff_with_a2a( + self: Agent, + a2a_agents: list[A2AConfig | A2AClientConfig], + original_kickoff_async: Callable[..., Coroutine[Any, Any, LiteAgentOutput]], + messages: str | list[Any], + response_format: type[Any] | None, + input_files: dict[str, Any] | None, + agent_response_model: type[BaseModel] | None, + extension_registry: ExtensionRegistry, +) -> LiteAgentOutput: + """Execute kickoff with A2A delegation support (async). + + Args: + self: The agent instance. + a2a_agents: List of A2A agent configurations. + original_kickoff_async: The original kickoff_async method. + messages: Messages to send to the agent. + response_format: Optional response format. + input_files: Optional input files. + agent_response_model: Optional agent response model. + extension_registry: Registry of A2A extensions. + + Returns: + LiteAgentOutput from kickoff or A2A delegation. 
+ """ + if isinstance(messages, str): + description = messages + else: + content = next( + (m["content"] for m in reversed(messages) if m["role"] == "user"), + None, + ) + description = content if isinstance(content, str) else "" + + if not description: + return await original_kickoff_async( + self, messages, response_format, input_files + ) + + fake_task = Task( + description=description, + agent=self, + expected_output="Result from A2A delegation", + input_files=input_files or {}, + ) + + agent_cards, failed_agents = await _afetch_agent_cards_concurrently(a2a_agents) + + if not agent_cards and a2a_agents and failed_agents: + return await original_kickoff_async( + self, messages, response_format, input_files + ) + + fake_task.description, _, extension_states = _augment_prompt_with_a2a( + a2a_agents=a2a_agents, + task_description=description, + agent_cards=agent_cards, + failed_agents=failed_agents, + extension_registry=extension_registry, + ) + fake_task.response_model = agent_response_model + + try: + result: LiteAgentOutput = await original_kickoff_async( + self, messages, agent_response_model or response_format, input_files + ) + agent_response = _parse_agent_response( + raw_result=result.raw, agent_response_model=agent_response_model + ) + + if extension_registry and isinstance(agent_response, BaseModel): + agent_response = extension_registry.process_response_with_all( + agent_response, extension_states + ) + + if isinstance(agent_response, BaseModel) and isinstance( + agent_response, AgentResponseProtocol + ): + if agent_response.is_a2a: + + async def _kickoff_adapter( + self_: Agent, + _task: Task, + _context: str | None, + _tools: list[Any] | None, + ) -> str: + fmt = ( + _task.response_model or agent_response_model or response_format + ) + output: LiteAgentOutput = await original_kickoff_async( + self_, messages, fmt, input_files + ) + return output.raw + + result_str = await _adelegate_to_a2a( + self, + agent_response=agent_response, + task=fake_task, + 
original_fn=_kickoff_adapter, + context=None, + tools=None, + agent_cards=agent_cards, + original_task_description=description, + _extension_registry=extension_registry, + ) + return LiteAgentOutput( + raw=result_str, + pydantic=None, + agent_role=self.role, + usage_metrics=None, + messages=[], + ) + return LiteAgentOutput( + raw=agent_response.message, + pydantic=None, + agent_role=self.role, + usage_metrics=result.usage_metrics, + messages=result.messages, + ) + + return result + finally: + fake_task.description = description + + +def _augment_prompt_with_a2a( + a2a_agents: list[A2AConfig | A2AClientConfig], + task_description: str, + agent_cards: Mapping[str, AgentCard | dict[str, Any]], + conversation_history: list[Message] | None = None, + turn_num: int = 0, + max_turns: int | None = None, + failed_agents: dict[str, str] | None = None, + extension_registry: ExtensionRegistry | None = None, + remote_status_notice: str = "", +) -> tuple[str, bool, dict[type[A2AExtension], ConversationState]]: + """Add A2A delegation instructions to prompt. 
+ + Args: + a2a_agents: Dictionary of A2A agent configurations + task_description: Original task description + agent_cards: dictionary mapping agent IDs to AgentCards + conversation_history: Previous A2A Messages from conversation + turn_num: Current turn number (0-indexed) + max_turns: Maximum allowed turns (from config) + failed_agents: Dictionary mapping failed agent endpoints to error messages + extension_registry: Optional registry of A2A extensions + remote_status_notice: Optional notice about remote agent status to append + + Returns: + Tuple of (augmented prompt, disable_structured_output flag, extension_states dict) + """ + + if not agent_cards: + return task_description, False, {} + + agents_text = "" + + for config in a2a_agents: + if config.endpoint in agent_cards: + card = agent_cards[config.endpoint] + if isinstance(card, dict): + filtered = { + k: v + for k, v in card.items() + if k in {"description", "url", "skills"} and v is not None + } + agents_text += f"\n{json.dumps(filtered, indent=2)}\n" + else: + agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n" + + failed_agents = failed_agents or {} + if failed_agents: + agents_text += "\n\n" + for endpoint, error in failed_agents.items(): + agents_text += f"\n\n" + + agents_text = AVAILABLE_AGENTS_TEMPLATE.substitute(available_a2a_agents=agents_text) + + history_text = "" + + if conversation_history: + for msg in conversation_history: + history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n" + + history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute( + previous_a2a_conversation=history_text + ) + + extension_states = {} + disable_structured_output = False + if extension_registry and conversation_history: + extension_states = extension_registry.extract_all_states(conversation_history) + for state in extension_states.values(): + if state.is_ready(): + disable_structured_output = True + break + 
turn_info = "" + + if max_turns is not None and conversation_history: + turn_count = turn_num + 1 + warning = "" + if turn_count >= max_turns: + warning = ( + "CRITICAL: This is the FINAL turn. You MUST conclude the conversation now.\n" + "Set is_a2a=false and provide your final response to complete the task." + ) + elif turn_count == max_turns - 1: + warning = "WARNING: Next turn will be the last. Consider wrapping up the conversation." + + turn_info = CONVERSATION_TURN_INFO_TEMPLATE.substitute( + turn_count=turn_count, + max_turns=max_turns, + warning=warning, + ) + + augmented_prompt = f"""{task_description} + +IMPORTANT: You have the ability to delegate this task to remote A2A agents. +{agents_text} +{history_text}{turn_info}{remote_status_notice} + +""" + + if extension_registry: + augmented_prompt = extension_registry.augment_prompt_with_all( + augmented_prompt, extension_states + ) + + return augmented_prompt, disable_structured_output, extension_states + + +def _parse_agent_response( + raw_result: str | dict[str, Any], agent_response_model: type[BaseModel] | None +) -> BaseModel | str | dict[str, Any]: + """Parse LLM output as AgentResponse or return raw agent response.""" + if agent_response_model: + try: + if isinstance(raw_result, str): + return agent_response_model.model_validate_json(raw_result) + if isinstance(raw_result, dict): + return agent_response_model.model_validate(raw_result) + except ValidationError: + return raw_result + return raw_result + + +def _handle_max_turns_exceeded( + conversation_history: list[Message], + max_turns: int, + from_task: Any | None = None, + from_agent: Any | None = None, + endpoint: str | None = None, + a2a_agent_name: str | None = None, + agent_card: dict[str, Any] | None = None, +) -> str: + """Handle the case when max turns is exceeded. + + Shared logic for both sync and async delegation. + + Returns: + Final message if found in history. + + Raises: + Exception: If no final message found and max turns exceeded. 
+ """ + if conversation_history: + for msg in reversed(conversation_history): + if msg.role == Role.agent: + text_parts = [ + part.root.text for part in msg.parts if part.root.kind == "text" + ] + final_message = ( + " ".join(text_parts) if text_parts else "Conversation completed" + ) + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="completed", + final_result=final_message, + error=None, + total_turns=max_turns, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + agent_card=agent_card, + ), + ) + return final_message + + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="failed", + final_result=None, + error=f"Conversation exceeded maximum turns ({max_turns})", + total_turns=max_turns, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + agent_card=agent_card, + ), + ) + raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})") + + +def _emit_delegation_failed( + error_msg: str, + turn_num: int, + from_task: Any | None, + from_agent: Any | None, + endpoint: str | None, + a2a_agent_name: str | None, + agent_card: dict[str, Any] | None, +) -> str: + """Emit failure event and return formatted error message.""" + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="failed", + final_result=None, + error=error_msg, + total_turns=turn_num + 1, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + agent_card=agent_card, + ), + ) + return f"A2A delegation failed: {error_msg}" + + +def _process_response_result( + raw_result: str, + disable_structured_output: bool, + turn_num: int, + agent_role: str, + agent_response_model: type[BaseModel] | None, + extension_registry: ExtensionRegistry | None = None, + extension_states: dict[type[A2AExtension], ConversationState] | None = None, + from_task: Any | None = None, + from_agent: Any | None = 
None, + endpoint: str | None = None, + a2a_agent_name: str | None = None, + agent_card: dict[str, Any] | None = None, +) -> tuple[str | None, str | None]: + """Process LLM response and determine next action. + + Shared logic for both sync and async handlers. + + Returns: + Tuple of (final_result, next_request). + """ + if disable_structured_output: + final_turn_number = turn_num + 1 + result_text = str(raw_result) + crewai_event_bus.emit( + None, + A2AMessageSentEvent( + message=result_text, + turn_number=final_turn_number, + is_multiturn=True, + agent_role=agent_role, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + ), + ) + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="completed", + final_result=result_text, + error=None, + total_turns=final_turn_number, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + agent_card=agent_card, + ), + ) + return result_text, None + + llm_response = _parse_agent_response( + raw_result=raw_result, agent_response_model=agent_response_model + ) + + if extension_registry and isinstance(llm_response, BaseModel): + llm_response = extension_registry.process_response_with_all( + llm_response, extension_states or {} + ) + + if isinstance(llm_response, BaseModel) and isinstance( + llm_response, AgentResponseProtocol + ): + if not llm_response.is_a2a: + final_turn_number = turn_num + 1 + crewai_event_bus.emit( + None, + A2AMessageSentEvent( + message=str(llm_response.message), + turn_number=final_turn_number, + is_multiturn=True, + agent_role=agent_role, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + ), + ) + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="completed", + final_result=str(llm_response.message), + error=None, + total_turns=final_turn_number, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + 
a2a_agent_name=a2a_agent_name, + agent_card=agent_card, + ), + ) + return llm_response.message, None + return None, llm_response.message + + return str(raw_result), None + + +def _prepare_agent_cards_dict( + a2a_result: TaskStateResult, + agent_id: str, + agent_cards: Mapping[str, AgentCard | dict[str, Any]] | None, +) -> dict[str, AgentCard | dict[str, Any]]: + """Prepare agent cards dictionary from result and existing cards. + + Shared logic for both sync and async response handlers. + """ + agent_cards_dict: dict[str, AgentCard | dict[str, Any]] = ( + dict(agent_cards) if agent_cards else {} + ) + if "agent_card" in a2a_result and agent_id not in agent_cards_dict: + agent_cards_dict[agent_id] = a2a_result["agent_card"] + return agent_cards_dict + + +def _init_delegation_state( + ctx: DelegationContext, + agent_cards: dict[str, AgentCard] | None, +) -> DelegationState: + """Initialize delegation state from context and agent cards. + + Args: + ctx: Delegation context with config and settings. + agent_cards: Pre-fetched agent cards. + + Returns: + Initial delegation state for the conversation loop. + """ + current_agent_card = agent_cards.get(ctx.agent_id) if agent_cards else None + return DelegationState( + current_request=ctx.current_request, + context_id=ctx.context_id, + task_id=ctx.task_id, + reference_task_ids=list(ctx.reference_task_ids), + conversation_history=[], + agent_card=current_agent_card, + agent_card_dict=current_agent_card.model_dump() if current_agent_card else None, + agent_name=current_agent_card.name if current_agent_card else None, + ) + + +def _get_turn_context( + agent_config: A2AConfig | A2AClientConfig, +) -> tuple[Any | None, list[str] | None]: + """Get context for a delegation turn. + + Returns: + Tuple of (agent_branch, accepted_output_modes). 
+ """ + console_formatter = getattr(crewai_event_bus, "_console", None) + agent_branch = None + if console_formatter: + agent_branch = getattr( + console_formatter, "current_agent_branch", None + ) or getattr(console_formatter, "current_task_branch", None) + + accepted_output_modes = None + if isinstance(agent_config, A2AClientConfig): + accepted_output_modes = agent_config.accepted_output_modes + + return agent_branch, accepted_output_modes + + +def _prepare_delegation_context( + self: Agent, + agent_response: AgentResponseProtocol, + task: Task, + original_task_description: str | None, +) -> DelegationContext: + """Prepare delegation context from agent response and task. + + Shared logic for both sync and async delegation. + + Returns: + DelegationContext with all values needed for delegation. + """ + a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) + agent_ids = tuple(config.endpoint for config in a2a_agents) + current_request = str(agent_response.message) + + if not a2a_agents: + raise ValueError("No A2A agents configured for delegation") + + if isinstance(agent_response, AgentResponseProtocol) and agent_response.a2a_ids: + agent_id = agent_response.a2a_ids[0] + else: + agent_id = agent_ids[0] + + if agent_id not in agent_ids: + raise ValueError(f"Unknown A2A agent ID: {agent_id} not in {agent_ids}") + + agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents), None) + if agent_config is None: + raise ValueError(f"Agent configuration not found for endpoint: {agent_id}") + task_config = task.config or {} + + if original_task_description is None: + original_task_description = task.description + + return DelegationContext( + a2a_agents=a2a_agents, + agent_response_model=agent_response_model, + current_request=current_request, + agent_id=agent_id, + agent_config=agent_config, + context_id=task_config.get("context_id"), + task_id=task_config.get("task_id"), + metadata=task_config.get("metadata"), + 
extensions=task_config.get("extensions"), + reference_task_ids=task_config.get("reference_task_ids", []), + original_task_description=original_task_description, + max_turns=agent_config.max_turns, + ) + + +def _handle_task_completion( + a2a_result: TaskStateResult, + task: Task, + task_id_config: str | None, + reference_task_ids: list[str], + agent_config: A2AConfig | A2AClientConfig, + turn_num: int, + from_task: Any | None = None, + from_agent: Any | None = None, + endpoint: str | None = None, + a2a_agent_name: str | None = None, + agent_card: dict[str, Any] | None = None, +) -> tuple[str | None, str | None, list[str], str]: + """Handle task completion state including reference task updates. + + When a remote task completes, this function: + 1. Adds the completed task_id to reference_task_ids (if not already present) + 2. Clears task_id_config to signal that a new task ID should be generated for next turn + 3. Updates task.config with the reference list for subsequent A2A calls + + The reference_task_ids list tracks all completed tasks in this conversation chain, + allowing the remote agent to maintain context across multi-turn interactions. + + Shared logic for both sync and async delegation. + + Args: + a2a_result: Result from A2A delegation containing task status. + task: CrewAI Task object to update with reference IDs. + task_id_config: Current task ID (will be added to references if task completed). + reference_task_ids: Mutable list of completed task IDs (updated in place). + agent_config: A2A configuration with trust settings. + turn_num: Current turn number. + from_task: Optional CrewAI Task for event metadata. + from_agent: Optional CrewAI Agent for event metadata. + endpoint: A2A endpoint URL. + a2a_agent_name: Name of remote A2A agent. + agent_card: Agent card dict for event metadata. + + Returns: + Tuple of (result_if_trusted, updated_task_id, updated_reference_task_ids, remote_notice). 
+ - result_if_trusted: Final result if trust_remote_completion_status=True, else None + - updated_task_id: None (cleared to generate new ID for next turn) + - updated_reference_task_ids: The mutated list with completed task added + - remote_notice: Template notice about remote agent response + """ + remote_notice = "" + if a2a_result["status"] == TaskState.completed: + remote_notice = REMOTE_AGENT_RESPONSE_NOTICE + + if task_id_config is not None and task_id_config not in reference_task_ids: + reference_task_ids.append(task_id_config) + + if task.config is None: + task.config = {} + task.config["reference_task_ids"] = list(reference_task_ids) + + task_id_config = None + + if agent_config.trust_remote_completion_status: + result_text = a2a_result.get("result", "") + final_turn_number = turn_num + 1 + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="completed", + final_result=result_text, + error=None, + total_turns=final_turn_number, + from_task=from_task, + from_agent=from_agent, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + agent_card=agent_card, + ), + ) + return str(result_text), task_id_config, reference_task_ids, remote_notice + + return None, task_id_config, reference_task_ids, remote_notice + + +def _handle_agent_response_and_continue( + self: Agent, + a2a_result: TaskStateResult, + agent_id: str, + agent_cards: dict[str, AgentCard] | None, + a2a_agents: list[A2AConfig | A2AClientConfig], + original_task_description: str, + conversation_history: list[Message], + turn_num: int, + max_turns: int, + task: Task, + original_fn: Callable[..., str], + context: str | None, + tools: list[BaseTool] | None, + agent_response_model: type[BaseModel] | None, + extension_registry: ExtensionRegistry | None = None, + remote_status_notice: str = "", + endpoint: str | None = None, + a2a_agent_name: str | None = None, + agent_card: dict[str, Any] | None = None, +) -> tuple[str | None, str | None]: + """Handle A2A result and get CrewAI agent's 
response. + + Args: + self: The agent instance + a2a_result: Result from A2A delegation + agent_id: ID of the A2A agent + agent_cards: Pre-fetched agent cards + a2a_agents: List of A2A configurations + original_task_description: Original task description + conversation_history: Conversation history + turn_num: Current turn number + max_turns: Maximum turns allowed + task: The task being executed + original_fn: Original execute_task method + context: Optional context + tools: Optional tools + agent_response_model: Response model for parsing + + Returns: + Tuple of (final_result, current_request) where: + - final_result is not None if conversation should end + - current_request is the next message to send if continuing + """ + agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards) + + ( + task.description, + disable_structured_output, + extension_states, + ) = _augment_prompt_with_a2a( + a2a_agents=a2a_agents, + task_description=original_task_description, + conversation_history=conversation_history, + turn_num=turn_num, + max_turns=max_turns, + agent_cards=agent_cards_dict, + remote_status_notice=remote_status_notice, + ) + + original_response_model = task.response_model + if disable_structured_output: + task.response_model = None + + raw_result = original_fn(self, task, context, tools) + + if disable_structured_output: + task.response_model = original_response_model + + return _process_response_result( + raw_result=raw_result, + disable_structured_output=disable_structured_output, + turn_num=turn_num, + agent_role=self.role, + agent_response_model=agent_response_model, + extension_registry=extension_registry, + extension_states=extension_states, + from_task=task, + from_agent=self, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + agent_card=agent_card, + ) + + +def _delegate_to_a2a( + self: Agent, + agent_response: AgentResponseProtocol, + task: Task, + original_fn: Callable[..., str], + context: str | None, + tools: list[BaseTool] | 
None, + agent_cards: dict[str, AgentCard] | None = None, + original_task_description: str | None = None, + _extension_registry: ExtensionRegistry | None = None, +) -> str: + """Delegate to A2A agent with multi-turn conversation support. + + Args: + self: The agent instance + agent_response: The AgentResponse indicating delegation + task: The task being executed (for extracting A2A fields) + original_fn: The original execute_task method for follow-ups + context: Optional context for task execution + tools: Optional tools available to the agent + agent_cards: Pre-fetched agent cards from _execute_task_with_a2a + original_task_description: The original task description before A2A augmentation + _extension_registry: Optional registry of A2A extensions (unused, reserved for future use) + + Returns: + Result from A2A agent + + Raises: + ImportError: If a2a-sdk is not installed + """ + ctx = _prepare_delegation_context( + self, agent_response, task, original_task_description + ) + state = _init_delegation_state(ctx, agent_cards) + current_request = state.current_request + context_id = state.context_id + task_id = state.task_id + reference_task_ids = state.reference_task_ids + conversation_history = state.conversation_history + + try: + for turn_num in range(ctx.max_turns): + agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config) + + a2a_result = execute_a2a_delegation( + endpoint=ctx.agent_config.endpoint, + auth=ctx.agent_config.auth, + timeout=ctx.agent_config.timeout, + task_description=current_request, + context_id=context_id, + task_id=task_id, + reference_task_ids=reference_task_ids, + metadata=ctx.metadata, + extensions=ctx.extensions, + conversation_history=conversation_history, + agent_id=ctx.agent_id, + agent_role=Role.user, + agent_branch=agent_branch, + response_model=ctx.agent_config.response_model, + turn_number=turn_num + 1, + updates=ctx.agent_config.updates, + transport=ctx.agent_config.transport, + from_task=task, + from_agent=self, + 
client_extensions=getattr(ctx.agent_config, "extensions", None), + accepted_output_modes=accepted_output_modes, + input_files=task.input_files, + ) + + conversation_history = a2a_result.get("history", []) + + if conversation_history: + latest_message = conversation_history[-1] + if latest_message.task_id is not None: + task_id = latest_message.task_id + if latest_message.context_id is not None: + context_id = latest_message.context_id + + if a2a_result["status"] in [TaskState.completed, TaskState.input_required]: + trusted_result, task_id, reference_task_ids, remote_notice = ( + _handle_task_completion( + a2a_result, + task, + task_id, + reference_task_ids, + ctx.agent_config, + turn_num, + from_task=task, + from_agent=self, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, + ) + ) + if trusted_result is not None: + return trusted_result + + final_result, next_request = _handle_agent_response_and_continue( + self=self, + a2a_result=a2a_result, + agent_id=ctx.agent_id, + agent_cards=agent_cards, + a2a_agents=ctx.a2a_agents, + original_task_description=ctx.original_task_description, + conversation_history=conversation_history, + turn_num=turn_num, + max_turns=ctx.max_turns, + task=task, + original_fn=original_fn, + context=context, + tools=tools, + agent_response_model=ctx.agent_response_model, + extension_registry=_extension_registry, + remote_status_notice=remote_notice, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, + ) + + if final_result is not None: + return final_result + + if next_request is not None: + current_request = next_request + + continue + + error_msg = a2a_result.get("error", "Unknown error") + + final_result, next_request = _handle_agent_response_and_continue( + self=self, + a2a_result=a2a_result, + agent_id=ctx.agent_id, + agent_cards=agent_cards, + a2a_agents=ctx.a2a_agents, + 
original_task_description=ctx.original_task_description, + conversation_history=conversation_history, + turn_num=turn_num, + max_turns=ctx.max_turns, + task=task, + original_fn=original_fn, + context=context, + tools=tools, + agent_response_model=ctx.agent_response_model, + extension_registry=_extension_registry, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, + ) + + if final_result is not None: + return final_result + + if next_request is not None: + current_request = next_request + continue + + return _emit_delegation_failed( + error_msg, + turn_num, + task, + self, + ctx.agent_config.endpoint, + state.agent_name, + state.agent_card_dict, + ) + + return _handle_max_turns_exceeded( + conversation_history, + ctx.max_turns, + from_task=task, + from_agent=self, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, + ) + + finally: + task.description = ctx.original_task_description + + +async def _afetch_card_from_config( + config: A2AConfig | A2AClientConfig, +) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]: + """Fetch agent card from A2A config asynchronously.""" + try: + card = await afetch_agent_card( + endpoint=config.endpoint, + auth=config.auth, + timeout=config.timeout, + ) + return config, card + except Exception as e: + return config, e + + +async def _afetch_agent_cards_concurrently( + a2a_agents: list[A2AConfig | A2AClientConfig], +) -> tuple[dict[str, AgentCard], dict[str, str]]: + """Fetch agent cards concurrently for multiple A2A agents using asyncio.""" + agent_cards: dict[str, AgentCard] = {} + failed_agents: dict[str, str] = {} + + if not a2a_agents: + return agent_cards, failed_agents + + tasks = [_afetch_card_from_config(config) for config in a2a_agents] + results = await asyncio.gather(*tasks) + + for config, result in results: + if isinstance(result, Exception): + if config.fail_fast: + raise RuntimeError( + f"Failed 
to fetch agent card from {config.endpoint}. " + f"Ensure the A2A agent is running and accessible. Error: {result}" + ) from result + failed_agents[config.endpoint] = str(result) + else: + agent_cards[config.endpoint] = result + + return agent_cards, failed_agents + + +async def _aexecute_task_with_a2a( + self: Agent, + a2a_agents: list[A2AConfig | A2AClientConfig], + original_fn: Callable[..., Coroutine[Any, Any, str]], + task: Task, + agent_response_model: type[BaseModel] | None, + context: str | None, + tools: list[BaseTool] | None, + extension_registry: ExtensionRegistry, +) -> str: + """Async version of _execute_task_with_a2a.""" + original_description: str = task.description + original_output_pydantic = task.output_pydantic + original_response_model = task.response_model + + agent_cards, failed_agents = await _afetch_agent_cards_concurrently(a2a_agents) + + if not agent_cards and a2a_agents and failed_agents: + unavailable_agents_text = "" + for endpoint, error in failed_agents.items(): + unavailable_agents_text += f" - {endpoint}: {error}\n" + + notice = UNAVAILABLE_AGENTS_NOTICE_TEMPLATE.substitute( + unavailable_agents=unavailable_agents_text + ) + task.description = f"{original_description}{notice}" + + try: + return await original_fn(self, task, context, tools) + finally: + task.description = original_description + + task.description, _, extension_states = _augment_prompt_with_a2a( + a2a_agents=a2a_agents, + task_description=original_description, + agent_cards=agent_cards, + failed_agents=failed_agents, + extension_registry=extension_registry, + ) + task.response_model = agent_response_model + + try: + raw_result = await original_fn(self, task, context, tools) + agent_response = _parse_agent_response( + raw_result=raw_result, agent_response_model=agent_response_model + ) + + if extension_registry and isinstance(agent_response, BaseModel): + agent_response = extension_registry.process_response_with_all( + agent_response, extension_states + ) + + if 
isinstance(agent_response, BaseModel) and isinstance( + agent_response, AgentResponseProtocol + ): + if agent_response.is_a2a: + return await _adelegate_to_a2a( + self, + agent_response=agent_response, + task=task, + original_fn=original_fn, + context=context, + tools=tools, + agent_cards=agent_cards, + original_task_description=original_description, + _extension_registry=extension_registry, + ) + task.output_pydantic = None + return agent_response.message + + return raw_result + finally: + task.description = original_description + if task.output_pydantic is not None: + task.output_pydantic = original_output_pydantic + task.response_model = original_response_model + + +async def _ahandle_agent_response_and_continue( + self: Agent, + a2a_result: TaskStateResult, + agent_id: str, + agent_cards: dict[str, AgentCard] | None, + a2a_agents: list[A2AConfig | A2AClientConfig], + original_task_description: str, + conversation_history: list[Message], + turn_num: int, + max_turns: int, + task: Task, + original_fn: Callable[..., Coroutine[Any, Any, str]], + context: str | None, + tools: list[BaseTool] | None, + agent_response_model: type[BaseModel] | None, + extension_registry: ExtensionRegistry | None = None, + remote_status_notice: str = "", + endpoint: str | None = None, + a2a_agent_name: str | None = None, + agent_card: dict[str, Any] | None = None, +) -> tuple[str | None, str | None]: + """Async version of _handle_agent_response_and_continue.""" + agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards) + + ( + task.description, + disable_structured_output, + extension_states, + ) = _augment_prompt_with_a2a( + a2a_agents=a2a_agents, + task_description=original_task_description, + conversation_history=conversation_history, + turn_num=turn_num, + max_turns=max_turns, + agent_cards=agent_cards_dict, + remote_status_notice=remote_status_notice, + ) + + original_response_model = task.response_model + if disable_structured_output: + task.response_model = 
None + + raw_result = await original_fn(self, task, context, tools) + + if disable_structured_output: + task.response_model = original_response_model + + return _process_response_result( + raw_result=raw_result, + disable_structured_output=disable_structured_output, + turn_num=turn_num, + agent_role=self.role, + agent_response_model=agent_response_model, + extension_registry=extension_registry, + extension_states=extension_states, + from_task=task, + from_agent=self, + endpoint=endpoint, + a2a_agent_name=a2a_agent_name, + agent_card=agent_card, + ) + + +async def _adelegate_to_a2a( + self: Agent, + agent_response: AgentResponseProtocol, + task: Task, + original_fn: Callable[..., Coroutine[Any, Any, str]], + context: str | None, + tools: list[BaseTool] | None, + agent_cards: dict[str, AgentCard] | None = None, + original_task_description: str | None = None, + _extension_registry: ExtensionRegistry | None = None, +) -> str: + """Async version of _delegate_to_a2a.""" + ctx = _prepare_delegation_context( + self, agent_response, task, original_task_description + ) + state = _init_delegation_state(ctx, agent_cards) + current_request = state.current_request + context_id = state.context_id + task_id = state.task_id + reference_task_ids = state.reference_task_ids + conversation_history = state.conversation_history + + try: + for turn_num in range(ctx.max_turns): + agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config) + + a2a_result = await aexecute_a2a_delegation( + endpoint=ctx.agent_config.endpoint, + auth=ctx.agent_config.auth, + timeout=ctx.agent_config.timeout, + task_description=current_request, + context_id=context_id, + task_id=task_id, + reference_task_ids=reference_task_ids, + metadata=ctx.metadata, + extensions=ctx.extensions, + conversation_history=conversation_history, + agent_id=ctx.agent_id, + agent_role=Role.user, + agent_branch=agent_branch, + response_model=ctx.agent_config.response_model, + turn_number=turn_num + 1, + 
transport=ctx.agent_config.transport, + updates=ctx.agent_config.updates, + from_task=task, + from_agent=self, + client_extensions=getattr(ctx.agent_config, "extensions", None), + accepted_output_modes=accepted_output_modes, + input_files=task.input_files, + ) + + conversation_history = a2a_result.get("history", []) + + if conversation_history: + latest_message = conversation_history[-1] + if latest_message.task_id is not None: + task_id = latest_message.task_id + if latest_message.context_id is not None: + context_id = latest_message.context_id + + if a2a_result["status"] in [TaskState.completed, TaskState.input_required]: + trusted_result, task_id, reference_task_ids, remote_notice = ( + _handle_task_completion( + a2a_result, + task, + task_id, + reference_task_ids, + ctx.agent_config, + turn_num, + from_task=task, + from_agent=self, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, + ) + ) + if trusted_result is not None: + return trusted_result + + final_result, next_request = await _ahandle_agent_response_and_continue( + self=self, + a2a_result=a2a_result, + agent_id=ctx.agent_id, + agent_cards=agent_cards, + a2a_agents=ctx.a2a_agents, + original_task_description=ctx.original_task_description, + conversation_history=conversation_history, + turn_num=turn_num, + max_turns=ctx.max_turns, + task=task, + original_fn=original_fn, + context=context, + tools=tools, + agent_response_model=ctx.agent_response_model, + extension_registry=_extension_registry, + remote_status_notice=remote_notice, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, + ) + + if final_result is not None: + return final_result + + if next_request is not None: + current_request = next_request + + continue + + error_msg = a2a_result.get("error", "Unknown error") + + final_result, next_request = await _ahandle_agent_response_and_continue( + self=self, + a2a_result=a2a_result, + 
agent_id=ctx.agent_id, + agent_cards=agent_cards, + a2a_agents=ctx.a2a_agents, + original_task_description=ctx.original_task_description, + conversation_history=conversation_history, + turn_num=turn_num, + max_turns=ctx.max_turns, + task=task, + original_fn=original_fn, + context=context, + tools=tools, + agent_response_model=ctx.agent_response_model, + extension_registry=_extension_registry, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, + ) + + if final_result is not None: + return final_result + + if next_request is not None: + current_request = next_request + continue + + return _emit_delegation_failed( + error_msg, + turn_num, + task, + self, + ctx.agent_config.endpoint, + state.agent_name, + state.agent_card_dict, + ) + + return _handle_max_turns_exceeded( + conversation_history, + ctx.max_turns, + from_task=task, + from_agent=self, + endpoint=ctx.agent_config.endpoint, + a2a_agent_name=state.agent_name, + agent_card=state.agent_card_dict, + ) + + finally: + task.description = ctx.original_task_description diff --git a/lib/crewai-a2a/tests/cassettes/TestA2AAgentCardFetching.test_fetch_agent_card.yaml b/lib/crewai-a2a/tests/cassettes/TestA2AAgentCardFetching.test_fetch_agent_card.yaml new file mode 100644 index 000000000..d60788a55 --- /dev/null +++ b/lib/crewai-a2a/tests/cassettes/TestA2AAgentCardFetching.test_fetch_agent_card.yaml @@ -0,0 +1,44 @@ +interactions: +- request: + body: '' + headers: + User-Agent: + - X-USER-AGENT-XXX + accept: + - '*/*' + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - localhost:9999 + method: GET + uri: http://localhost:9999/.well-known/agent-card.json + response: + body: + string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An + AI assistant powered by OpenAI GPT with calculator and time tools. 
Ask questions, + perform calculations, or get the current time in any timezone.","name":"GPT + Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have + a general conversation with the AI assistant. Ask questions, get explanations, + or just chat.","examples":["Hello, how are you?","Explain quantum computing + in simple terms","What can you help me with?"],"id":"conversation","name":"General + Conversation","tags":["chat","conversation","general"]},{"description":"Perform + mathematical calculations including arithmetic, exponents, and more.","examples":["What + is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get + the current date and time in any timezone.","examples":["What time is it?","What''s + the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current + Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}' + headers: + content-length: + - '1198' + content-type: + - application/json + date: + - Tue, 06 Jan 2026 14:17:00 GMT + server: + - uvicorn + status: + code: 200 + message: OK +version: 1 diff --git a/lib/crewai-a2a/tests/cassettes/TestA2APollingIntegration.test_polling_completes_task.yaml b/lib/crewai-a2a/tests/cassettes/TestA2APollingIntegration.test_polling_completes_task.yaml new file mode 100644 index 000000000..3832dc7da --- /dev/null +++ b/lib/crewai-a2a/tests/cassettes/TestA2APollingIntegration.test_polling_completes_task.yaml @@ -0,0 +1,126 @@ +interactions: +- request: + body: '' + headers: + User-Agent: + - X-USER-AGENT-XXX + accept: + - '*/*' + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - localhost:9999 + method: GET + uri: http://localhost:9999/.well-known/agent-card.json + response: + body: + string: 
'{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An + AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions, + perform calculations, or get the current time in any timezone.","name":"GPT + Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have + a general conversation with the AI assistant. Ask questions, get explanations, + or just chat.","examples":["Hello, how are you?","Explain quantum computing + in simple terms","What can you help me with?"],"id":"conversation","name":"General + Conversation","tags":["chat","conversation","general"]},{"description":"Perform + mathematical calculations including arithmetic, exponents, and more.","examples":["What + is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get + the current date and time in any timezone.","examples":["What time is it?","What''s + the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current + Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}' + headers: + content-length: + - '1198' + content-type: + - application/json + date: + - Tue, 06 Jan 2026 14:16:58 GMT + server: + - uvicorn + status: + code: 200 + message: OK +- request: + body: '{"id":"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"e1e63c75-3ea0-49fb-b512-5128a2476416","parts":[{"kind":"text","text":"What + is 2 + 2?"}],"role":"user"}}}' + headers: + User-Agent: + - X-USER-AGENT-XXX + accept: + - '*/*, text/event-stream' + accept-encoding: + - ACCEPT-ENCODING-XXX + cache-control: + - no-store + connection: + - keep-alive + content-length: + - '301' + content-type: + - application/json + host: + - 
localhost:9999 + method: POST + uri: http://localhost:9999/ + response: + body: + string: "data: {\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\ndata: + {\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\ndata: + {\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"54bb7ff3-f2c0-4eb3-b427-bf1c8cf90832\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool: + calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\n" + headers: + Transfer-Encoding: + - chunked + cache-control: + - no-store + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + date: + - Tue, 06 Jan 2026 14:16:58 GMT + server: + - uvicorn + x-accel-buffering: + - 'no' + status: + code: 200 + message: OK +- request: + body: '{"id":"cb1e4af3-d2d0-4848-96b8-7082ee6171d1","jsonrpc":"2.0","method":"tasks/get","params":{"historyLength":100,"id":"0dd4d3af-f35d-409d-9462-01218e5641f9"}}' + headers: + User-Agent: + - X-USER-AGENT-XXX + accept: + - '*/*' + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + content-length: + - '157' + content-type: + - application/json + host: + - localhost:9999 + method: POST + uri: http://localhost:9999/ + response: + body: + string: 
'{"id":"cb1e4af3-d2d0-4848-96b8-7082ee6171d1","jsonrpc":"2.0","result":{"contextId":"b9e14c1b-734d-4d1e-864a-e6dda5231d71","history":[{"contextId":"b9e14c1b-734d-4d1e-864a-e6dda5231d71","kind":"message","messageId":"e1e63c75-3ea0-49fb-b512-5128a2476416","parts":[{"kind":"text","text":"What + is 2 + 2?"}],"role":"user","taskId":"0dd4d3af-f35d-409d-9462-01218e5641f9"}],"id":"0dd4d3af-f35d-409d-9462-01218e5641f9","kind":"task","status":{"message":{"kind":"message","messageId":"54bb7ff3-f2c0-4eb3-b427-bf1c8cf90832","parts":[{"kind":"text","text":"\n[Tool: + calculator] 2 + 2 = 4\n2 + 2 equals 4."}],"role":"agent"},"state":"completed"}}}' + headers: + content-length: + - '635' + content-type: + - application/json + date: + - Tue, 06 Jan 2026 14:17:00 GMT + server: + - uvicorn + status: + code: 200 + message: OK +version: 1 diff --git a/lib/crewai-a2a/tests/cassettes/TestA2AStreamingIntegration.test_streaming_completes_task.yaml b/lib/crewai-a2a/tests/cassettes/TestA2AStreamingIntegration.test_streaming_completes_task.yaml new file mode 100644 index 000000000..e98e61c2b --- /dev/null +++ b/lib/crewai-a2a/tests/cassettes/TestA2AStreamingIntegration.test_streaming_completes_task.yaml @@ -0,0 +1,90 @@ +interactions: +- request: + body: '' + headers: + User-Agent: + - X-USER-AGENT-XXX + accept: + - '*/*' + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - localhost:9999 + method: GET + uri: http://localhost:9999/.well-known/agent-card.json + response: + body: + string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An + AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions, + perform calculations, or get the current time in any timezone.","name":"GPT + Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have + a general conversation with the AI assistant. 
Ask questions, get explanations, + or just chat.","examples":["Hello, how are you?","Explain quantum computing + in simple terms","What can you help me with?"],"id":"conversation","name":"General + Conversation","tags":["chat","conversation","general"]},{"description":"Perform + mathematical calculations including arithmetic, exponents, and more.","examples":["What + is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get + the current date and time in any timezone.","examples":["What time is it?","What''s + the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current + Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}' + headers: + content-length: + - '1198' + content-type: + - application/json + date: + - Tue, 06 Jan 2026 14:17:02 GMT + server: + - uvicorn + status: + code: 200 + message: OK +- request: + body: '{"id":"8cf25b61-8884-4246-adce-fccb32e176ab","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"c145297f-7331-4835-adcc-66b51de92a2b","parts":[{"kind":"text","text":"What + is 2 + 2?"}],"role":"user"}}}' + headers: + User-Agent: + - X-USER-AGENT-XXX + accept: + - '*/*, text/event-stream' + accept-encoding: + - ACCEPT-ENCODING-XXX + cache-control: + - no-store + connection: + - keep-alive + content-length: + - '301' + content-type: + - application/json + host: + - localhost:9999 + method: POST + uri: http://localhost:9999/ + response: + body: + string: "data: {\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\ndata: + 
{\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\ndata: + {\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"25f81e3c-b7e8-48b5-a98a-4066f3637a13\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool: + calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\n" + headers: + Transfer-Encoding: + - chunked + cache-control: + - no-store + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + date: + - Tue, 06 Jan 2026 14:17:02 GMT + server: + - uvicorn + x-accel-buffering: + - 'no' + status: + code: 200 + message: OK +version: 1 diff --git a/lib/crewai-a2a/tests/cassettes/TestA2ATaskOperations.test_send_message_and_get_response.yaml b/lib/crewai-a2a/tests/cassettes/TestA2ATaskOperations.test_send_message_and_get_response.yaml new file mode 100644 index 000000000..e3623e8da --- /dev/null +++ b/lib/crewai-a2a/tests/cassettes/TestA2ATaskOperations.test_send_message_and_get_response.yaml @@ -0,0 +1,90 @@ +interactions: +- request: + body: '' + headers: + User-Agent: + - X-USER-AGENT-XXX + accept: + - '*/*' + accept-encoding: + - ACCEPT-ENCODING-XXX + connection: + - keep-alive + host: + - localhost:9999 + method: GET + uri: http://localhost:9999/.well-known/agent-card.json + response: + body: + string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An + AI assistant powered by OpenAI GPT with calculator and time tools. 
Ask questions, + perform calculations, or get the current time in any timezone.","name":"GPT + Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have + a general conversation with the AI assistant. Ask questions, get explanations, + or just chat.","examples":["Hello, how are you?","Explain quantum computing + in simple terms","What can you help me with?"],"id":"conversation","name":"General + Conversation","tags":["chat","conversation","general"]},{"description":"Perform + mathematical calculations including arithmetic, exponents, and more.","examples":["What + is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get + the current date and time in any timezone.","examples":["What time is it?","What''s + the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current + Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}' + headers: + content-length: + - '1198' + content-type: + - application/json + date: + - Tue, 06 Jan 2026 14:17:00 GMT + server: + - uvicorn + status: + code: 200 + message: OK +- request: + body: '{"id":"3a17c6bf-8db6-45a6-8535-34c45c0c4936","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"712558a3-6d92-4591-be8a-9dd8566dde82","parts":[{"kind":"text","text":"What + is 2 + 2?"}],"role":"user"}}}' + headers: + User-Agent: + - X-USER-AGENT-XXX + accept: + - '*/*, text/event-stream' + accept-encoding: + - ACCEPT-ENCODING-XXX + cache-control: + - no-store + connection: + - keep-alive + content-length: + - '301' + content-type: + - application/json + host: + - localhost:9999 + method: POST + uri: http://localhost:9999/ + response: + body: + string: "data: 
{\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\ndata: + {\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\ndata: + {\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"916324aa-fd25-4849-bceb-c4644e2fcbb0\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool: + calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\n" + headers: + Transfer-Encoding: + - chunked + cache-control: + - no-store + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + date: + - Tue, 06 Jan 2026 14:17:00 GMT + server: + - uvicorn + x-accel-buffering: + - 'no' + status: + code: 200 + message: OK +version: 1 diff --git a/lib/crewai-a2a/tests/conftest.py b/lib/crewai-a2a/tests/conftest.py new file mode 100644 index 000000000..1f8f9f949 --- /dev/null +++ b/lib/crewai-a2a/tests/conftest.py @@ -0,0 +1,21 @@ +"""Pytest configuration for crewai-a2a tests. + +Ensures Agent model is properly rebuilt with A2A types, +which can fail silently during circular import resolution. 
+""" + +from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig + + +def pytest_configure() -> None: + """Rebuild Agent/LiteAgent models after crewai_a2a is fully loaded.""" + from crewai.agent.core import Agent + from crewai.lite_agent import LiteAgent + + ns = { + "A2AConfig": A2AConfig, + "A2AClientConfig": A2AClientConfig, + "A2AServerConfig": A2AServerConfig, + } + Agent.model_rebuild(_types_namespace=ns) + LiteAgent.model_rebuild(_types_namespace=ns) diff --git a/lib/crewai/tests/a2a/test_a2a_integration.py b/lib/crewai-a2a/tests/test_a2a_integration.py similarity index 92% rename from lib/crewai/tests/a2a/test_a2a_integration.py rename to lib/crewai-a2a/tests/test_a2a_integration.py index 9950ee0a2..4f5983cc9 100644 --- a/lib/crewai/tests/a2a/test_a2a_integration.py +++ b/lib/crewai-a2a/tests/test_a2a_integration.py @@ -3,15 +3,13 @@ from __future__ import annotations import os import uuid +from a2a.client import ClientFactory +from a2a.types import AgentCard, Message, Part, Role, Task, TaskState, TextPart +from crewai_a2a.updates.polling.handler import PollingHandler +from crewai_a2a.updates.streaming.handler import StreamingHandler import pytest import pytest_asyncio -from a2a.client import ClientFactory -from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart - -from crewai.a2a.updates.polling.handler import PollingHandler -from crewai.a2a.updates.streaming.handler import StreamingHandler - A2A_TEST_ENDPOINT = os.getenv("A2A_TEST_ENDPOINT", "http://localhost:9999") @@ -162,7 +160,7 @@ class TestA2APushNotificationHandler: ) @pytest.fixture - def mock_task(self) -> "Task": + def mock_task(self) -> Task: """Create a minimal valid task for testing.""" from a2a.types import Task, TaskStatus @@ -182,11 +180,12 @@ class TestA2APushNotificationHandler: from unittest.mock import AsyncMock, MagicMock from a2a.types import Task, TaskStatus + from crewai_a2a.updates.push_notifications.config import PushNotificationConfig + from 
crewai_a2a.updates.push_notifications.handler import ( + PushNotificationHandler, + ) from pydantic import AnyHttpUrl - from crewai.a2a.updates.push_notifications.config import PushNotificationConfig - from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler - completed_task = Task( id="task-123", context_id="ctx-123", @@ -246,11 +245,12 @@ class TestA2APushNotificationHandler: from unittest.mock import AsyncMock, MagicMock from a2a.types import Task, TaskStatus + from crewai_a2a.updates.push_notifications.config import PushNotificationConfig + from crewai_a2a.updates.push_notifications.handler import ( + PushNotificationHandler, + ) from pydantic import AnyHttpUrl - from crewai.a2a.updates.push_notifications.config import PushNotificationConfig - from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler - mock_store = MagicMock() mock_store.wait_for_result = AsyncMock(return_value=None) @@ -303,7 +303,9 @@ class TestA2APushNotificationHandler: """Test that push handler fails gracefully without config.""" from unittest.mock import MagicMock - from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler + from crewai_a2a.updates.push_notifications.handler import ( + PushNotificationHandler, + ) mock_client = MagicMock() diff --git a/lib/crewai/tests/a2a/utils/test_agent_card.py b/lib/crewai-a2a/tests/utils/test_agent_card.py similarity index 98% rename from lib/crewai/tests/a2a/utils/test_agent_card.py rename to lib/crewai-a2a/tests/utils/test_agent_card.py index fb96710a7..9da1a7b6c 100644 --- a/lib/crewai/tests/a2a/utils/test_agent_card.py +++ b/lib/crewai-a2a/tests/utils/test_agent_card.py @@ -3,10 +3,9 @@ from __future__ import annotations from a2a.types import AgentCard, AgentSkill - from crewai import Agent -from crewai.a2a.config import A2AClientConfig, A2AServerConfig -from crewai.a2a.utils.agent_card import inject_a2a_server_methods +from crewai_a2a.config import A2AClientConfig, 
A2AServerConfig +from crewai_a2a.utils.agent_card import inject_a2a_server_methods class TestInjectA2AServerMethods: diff --git a/lib/crewai/tests/a2a/utils/test_task.py b/lib/crewai-a2a/tests/utils/test_task.py similarity index 91% rename from lib/crewai/tests/a2a/utils/test_task.py rename to lib/crewai-a2a/tests/utils/test_task.py index 781827ac8..5882fe723 100644 --- a/lib/crewai/tests/a2a/utils/test_task.py +++ b/lib/crewai-a2a/tests/utils/test_task.py @@ -6,13 +6,12 @@ import asyncio from typing import Any from unittest.mock import AsyncMock, MagicMock, patch -import pytest -import pytest_asyncio from a2a.server.agent_execution import RequestContext from a2a.server.events import EventQueue from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus - -from crewai.a2a.utils.task import cancel, cancellable, execute +from crewai_a2a.utils.task import cancel, cancellable, execute +import pytest +import pytest_asyncio @pytest.fixture @@ -85,8 +84,11 @@ class TestCancellableDecorator: assert call_count == 1 @pytest.mark.asyncio - async def test_executes_function_with_context(self, mock_context: MagicMock) -> None: + async def test_executes_function_with_context( + self, mock_context: MagicMock + ) -> None: """Function executes normally with RequestContext when not cancelled.""" + @cancellable async def my_func(context: RequestContext) -> str: await asyncio.sleep(0.01) @@ -134,6 +136,7 @@ class TestCancellableDecorator: @pytest.mark.asyncio async def test_extracts_context_from_kwargs(self, mock_context: MagicMock) -> None: """Context can be passed as keyword argument.""" + @cancellable async def my_func(value: int, context: RequestContext | None = None) -> int: return value + 1 @@ -156,8 +159,8 @@ class TestExecute: ) -> None: """Execute completes successfully and enqueues completed task.""" with ( - patch("crewai.a2a.utils.task.Task", return_value=mock_task), - patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + 
patch("crewai_a2a.utils.task.Task", return_value=mock_task), + patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus, ): await execute(mock_agent, mock_context, mock_event_queue) @@ -175,8 +178,8 @@ class TestExecute: ) -> None: """Execute emits A2AServerTaskStartedEvent.""" with ( - patch("crewai.a2a.utils.task.Task", return_value=mock_task), - patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + patch("crewai_a2a.utils.task.Task", return_value=mock_task), + patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus, ): await execute(mock_agent, mock_context, mock_event_queue) @@ -197,8 +200,8 @@ class TestExecute: ) -> None: """Execute emits A2AServerTaskCompletedEvent on success.""" with ( - patch("crewai.a2a.utils.task.Task", return_value=mock_task), - patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + patch("crewai_a2a.utils.task.Task", return_value=mock_task), + patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus, ): await execute(mock_agent, mock_context, mock_event_queue) @@ -221,8 +224,8 @@ class TestExecute: mock_agent.aexecute_task = AsyncMock(side_effect=ValueError("Test error")) with ( - patch("crewai.a2a.utils.task.Task", return_value=mock_task), - patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + patch("crewai_a2a.utils.task.Task", return_value=mock_task), + patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus, ): with pytest.raises(Exception): await execute(mock_agent, mock_context, mock_event_queue) @@ -245,8 +248,8 @@ class TestExecute: mock_agent.aexecute_task = AsyncMock(side_effect=asyncio.CancelledError()) with ( - patch("crewai.a2a.utils.task.Task", return_value=mock_task), - patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus, + patch("crewai_a2a.utils.task.Task", return_value=mock_task), + patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus, ): with pytest.raises(asyncio.CancelledError): await execute(mock_agent, mock_context, mock_event_queue) @@ -354,6 +357,7 @@ 
class TestExecuteAndCancelIntegration: mock_task: MagicMock, ) -> None: """Calling cancel stops a running execute.""" + async def slow_task(**kwargs: Any) -> str: await asyncio.sleep(2.0) return "should not complete" @@ -361,8 +365,8 @@ class TestExecuteAndCancelIntegration: mock_agent.aexecute_task = slow_task with ( - patch("crewai.a2a.utils.task.Task", return_value=mock_task), - patch("crewai.a2a.utils.task.crewai_event_bus"), + patch("crewai_a2a.utils.task.Task", return_value=mock_task), + patch("crewai_a2a.utils.task.crewai_event_bus"), ): execute_task = asyncio.create_task( execute(mock_agent, mock_context, mock_event_queue) @@ -372,4 +376,4 @@ class TestExecuteAndCancelIntegration: await cancel(mock_context, mock_event_queue) with pytest.raises(asyncio.CancelledError): - await execute_task \ No newline at end of file + await execute_task diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index 0401ef193..12986893c 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -96,12 +96,7 @@ azure-ai-inference = [ anthropic = [ "anthropic~=0.73.0", ] -a2a = [ - "a2a-sdk~=0.3.10", - "httpx-auth~=0.23.1", - "httpx-sse~=0.4.0", - "aiocache[redis,memcached]~=0.12.3", -] +a2a = ["crewai-a2a==1.10.1b1"] file-processing = [ "crewai-files", ] @@ -132,6 +127,7 @@ torchvision = [ { index = "pytorch", marker = "python_version < '3.13'" }, ] crewai-files = { workspace = true } +crewai-a2a = { workspace = true } [build-system] diff --git a/lib/crewai/src/crewai/a2a/__init__.py b/lib/crewai/src/crewai/a2a/__init__.py index 634f77708..510c67044 100644 --- a/lib/crewai/src/crewai/a2a/__init__.py +++ b/lib/crewai/src/crewai/a2a/__init__.py @@ -1,10 +1,13 @@ -"""Agent-to-Agent (A2A) protocol communication module for CrewAI.""" +"""Backward-compatibility shim — use ``crewai_a2a`` instead.""" -from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig +import warnings -__all__ = [ - "A2AClientConfig", - "A2AConfig", - "A2AServerConfig", 
-] +warnings.warn( + "'crewai.a2a' has been moved to 'crewai_a2a'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) + +from crewai_a2a import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/auth/__init__.py b/lib/crewai/src/crewai/a2a/auth/__init__.py index 093193a8e..06fb67f73 100644 --- a/lib/crewai/src/crewai/a2a/auth/__init__.py +++ b/lib/crewai/src/crewai/a2a/auth/__init__.py @@ -1,36 +1,13 @@ -"""A2A authentication schemas.""" +"""Backward-compatibility shim — use ``crewai_a2a.auth`` instead.""" -from crewai.a2a.auth.client_schemes import ( - APIKeyAuth, - AuthScheme, - BearerTokenAuth, - ClientAuthScheme, - HTTPBasicAuth, - HTTPDigestAuth, - OAuth2AuthorizationCode, - OAuth2ClientCredentials, - TLSConfig, -) -from crewai.a2a.auth.server_schemes import ( - AuthenticatedUser, - OIDCAuth, - ServerAuthScheme, - SimpleTokenAuth, +import warnings + + +warnings.warn( + "'crewai.a2a.auth' has been moved to 'crewai_a2a.auth'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -__all__ = [ - "APIKeyAuth", - "AuthScheme", - "AuthenticatedUser", - "BearerTokenAuth", - "ClientAuthScheme", - "HTTPBasicAuth", - "HTTPDigestAuth", - "OAuth2AuthorizationCode", - "OAuth2ClientCredentials", - "OIDCAuth", - "ServerAuthScheme", - "SimpleTokenAuth", - "TLSConfig", -] +from crewai_a2a.auth import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/auth/client_schemes.py b/lib/crewai/src/crewai/a2a/auth/client_schemes.py index 0356b8aef..ceeac7df6 100644 --- a/lib/crewai/src/crewai/a2a/auth/client_schemes.py +++ b/lib/crewai/src/crewai/a2a/auth/client_schemes.py @@ -1,550 +1,13 @@ -"""Authentication schemes for A2A protocol clients. 
+"""Backward-compatibility shim — use ``crewai_a2a.auth.client_schemes`` instead.""" -Supported authentication methods: -- Bearer tokens -- OAuth2 (Client Credentials, Authorization Code) -- API Keys (header, query, cookie) -- HTTP Basic authentication -- HTTP Digest authentication -- mTLS (mutual TLS) client certificate authentication -""" +import warnings -from __future__ import annotations -from abc import ABC, abstractmethod -import asyncio -import base64 -from collections.abc import Awaitable, Callable, MutableMapping -from pathlib import Path -import ssl -import time -from typing import TYPE_CHECKING, ClassVar, Literal -import urllib.parse +warnings.warn( + "'crewai.a2a.auth.client_schemes' has been moved to 'crewai_a2a.auth.client_schemes'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) -import httpx -from httpx import DigestAuth -from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr -from typing_extensions import deprecated - - -if TYPE_CHECKING: - import grpc # type: ignore[import-untyped] - - -class TLSConfig(BaseModel): - """TLS/mTLS configuration for secure client connections. - - Supports mutual TLS (mTLS) where the client presents a certificate to the server, - and standard TLS with custom CA verification. - - Attributes: - client_cert_path: Path to client certificate file (PEM format) for mTLS. - client_key_path: Path to client private key file (PEM format) for mTLS. - ca_cert_path: Path to CA certificate bundle for server verification. - verify: Whether to verify server certificates. Set False only for development. 
- """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - client_cert_path: FilePath | None = Field( - default=None, - description="Path to client certificate file (PEM format) for mTLS", - ) - client_key_path: FilePath | None = Field( - default=None, - description="Path to client private key file (PEM format) for mTLS", - ) - ca_cert_path: FilePath | None = Field( - default=None, - description="Path to CA certificate bundle for server verification", - ) - verify: bool = Field( - default=True, - description="Whether to verify server certificates. Set False only for development.", - ) - - def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str: - """Build SSL context for httpx client. - - Returns: - SSL context if certificates configured, True for default verification, - False if verification disabled, or path to CA bundle. - """ - if not self.verify: - return False - - if self.client_cert_path and self.client_key_path: - context = ssl.create_default_context() - - if self.ca_cert_path: - context.load_verify_locations(cafile=str(self.ca_cert_path)) - - context.load_cert_chain( - certfile=str(self.client_cert_path), - keyfile=str(self.client_key_path), - ) - return context - - if self.ca_cert_path: - return str(self.ca_cert_path) - - return True - - def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported] - """Build gRPC channel credentials for secure connections. - - Returns: - gRPC SSL credentials if certificates configured, None otherwise. 
- """ - try: - import grpc - except ImportError: - return None - - if not self.verify and not self.client_cert_path: - return None - - root_certs: bytes | None = None - private_key: bytes | None = None - certificate_chain: bytes | None = None - - if self.ca_cert_path: - root_certs = Path(self.ca_cert_path).read_bytes() - - if self.client_cert_path and self.client_key_path: - private_key = Path(self.client_key_path).read_bytes() - certificate_chain = Path(self.client_cert_path).read_bytes() - - return grpc.ssl_channel_credentials( - root_certificates=root_certs, - private_key=private_key, - certificate_chain=certificate_chain, - ) - - -class ClientAuthScheme(ABC, BaseModel): - """Base class for client-side authentication schemes. - - Client auth schemes apply credentials to outgoing requests. - - Attributes: - tls: Optional TLS/mTLS configuration for secure connections. - """ - - tls: TLSConfig | None = Field( - default=None, - description="TLS/mTLS configuration for secure connections", - ) - - @abstractmethod - async def apply_auth( - self, client: httpx.AsyncClient, headers: MutableMapping[str, str] - ) -> MutableMapping[str, str]: - """Apply authentication to request headers. - - Args: - client: HTTP client for making auth requests. - headers: Current request headers. - - Returns: - Updated headers with authentication applied. - """ - ... - - -@deprecated("Use ClientAuthScheme instead", category=FutureWarning) -class AuthScheme(ClientAuthScheme): - """Deprecated: Use ClientAuthScheme instead.""" - - -class BearerTokenAuth(ClientAuthScheme): - """Bearer token authentication (Authorization: Bearer ). - - Attributes: - token: Bearer token for authentication. - """ - - token: str = Field(description="Bearer token") - - async def apply_auth( - self, client: httpx.AsyncClient, headers: MutableMapping[str, str] - ) -> MutableMapping[str, str]: - """Apply Bearer token to Authorization header. - - Args: - client: HTTP client for making auth requests. 
- headers: Current request headers. - - Returns: - Updated headers with Bearer token in Authorization header. - """ - headers["Authorization"] = f"Bearer {self.token}" - return headers - - -class HTTPBasicAuth(ClientAuthScheme): - """HTTP Basic authentication. - - Attributes: - username: Username for Basic authentication. - password: Password for Basic authentication. - """ - - username: str = Field(description="Username") - password: str = Field(description="Password") - - async def apply_auth( - self, client: httpx.AsyncClient, headers: MutableMapping[str, str] - ) -> MutableMapping[str, str]: - """Apply HTTP Basic authentication. - - Args: - client: HTTP client for making auth requests. - headers: Current request headers. - - Returns: - Updated headers with Basic auth in Authorization header. - """ - credentials = f"{self.username}:{self.password}" - encoded = base64.b64encode(credentials.encode()).decode() - headers["Authorization"] = f"Basic {encoded}" - return headers - - -class HTTPDigestAuth(ClientAuthScheme): - """HTTP Digest authentication. - - Note: Uses httpx-auth library for digest implementation. - - Attributes: - username: Username for Digest authentication. - password: Password for Digest authentication. - """ - - username: str = Field(description="Username") - password: str = Field(description="Password") - - _configured_client_id: int | None = PrivateAttr(default=None) - - async def apply_auth( - self, client: httpx.AsyncClient, headers: MutableMapping[str, str] - ) -> MutableMapping[str, str]: - """Digest auth is handled by httpx auth flow, not headers. - - Args: - client: HTTP client for making auth requests. - headers: Current request headers. - - Returns: - Unchanged headers (Digest auth handled by httpx auth flow). - """ - return headers - - def configure_client(self, client: httpx.AsyncClient) -> None: - """Configure client with Digest auth. - - Idempotent: Only configures the client once. 
Subsequent calls on the same - client instance are no-ops to prevent overwriting auth configuration. - - Args: - client: HTTP client to configure with Digest authentication. - """ - client_id = id(client) - if self._configured_client_id == client_id: - return - - client.auth = DigestAuth(self.username, self.password) - self._configured_client_id = client_id - - -class APIKeyAuth(ClientAuthScheme): - """API Key authentication (header, query, or cookie). - - Attributes: - api_key: API key value for authentication. - location: Where to send the API key (header, query, or cookie). - name: Parameter name for the API key (default: X-API-Key). - """ - - api_key: str = Field(description="API key value") - location: Literal["header", "query", "cookie"] = Field( - default="header", description="Where to send the API key" - ) - name: str = Field(default="X-API-Key", description="Parameter name for the API key") - - _configured_client_ids: set[int] = PrivateAttr(default_factory=set) - - async def apply_auth( - self, client: httpx.AsyncClient, headers: MutableMapping[str, str] - ) -> MutableMapping[str, str]: - """Apply API key authentication. - - Args: - client: HTTP client for making auth requests. - headers: Current request headers. - - Returns: - Updated headers with API key (for header/cookie locations). - """ - if self.location == "header": - headers[self.name] = self.api_key - elif self.location == "cookie": - headers["Cookie"] = f"{self.name}={self.api_key}" - return headers - - def configure_client(self, client: httpx.AsyncClient) -> None: - """Configure client for query param API keys. - - Idempotent: Only adds the request hook once per client instance. - Subsequent calls on the same client are no-ops to prevent hook accumulation. - - Args: - client: HTTP client to configure with query param API key hook. 
- """ - if self.location == "query": - client_id = id(client) - if client_id in self._configured_client_ids: - return - - async def _add_api_key_param(request: httpx.Request) -> None: - url = httpx.URL(request.url) - request.url = url.copy_add_param(self.name, self.api_key) - - client.event_hooks["request"].append(_add_api_key_param) - self._configured_client_ids.add(client_id) - - -class OAuth2ClientCredentials(ClientAuthScheme): - """OAuth2 Client Credentials flow authentication. - - Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches - when multiple requests share the same auth instance. - - Attributes: - token_url: OAuth2 token endpoint URL. - client_id: OAuth2 client identifier. - client_secret: OAuth2 client secret. - scopes: List of required OAuth2 scopes. - """ - - token_url: str = Field(description="OAuth2 token endpoint") - client_id: str = Field(description="OAuth2 client ID") - client_secret: str = Field(description="OAuth2 client secret") - scopes: list[str] = Field( - default_factory=list, description="Required OAuth2 scopes" - ) - - _access_token: str | None = PrivateAttr(default=None) - _token_expires_at: float | None = PrivateAttr(default=None) - _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) - - async def apply_auth( - self, client: httpx.AsyncClient, headers: MutableMapping[str, str] - ) -> MutableMapping[str, str]: - """Apply OAuth2 access token to Authorization header. - - Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time, - preventing race conditions when multiple concurrent requests use the same - auth instance. - - Args: - client: HTTP client for making token requests. - headers: Current request headers. - - Returns: - Updated headers with OAuth2 access token in Authorization header. 
- """ - if ( - self._access_token is None - or self._token_expires_at is None - or time.time() >= self._token_expires_at - ): - async with self._lock: - if ( - self._access_token is None - or self._token_expires_at is None - or time.time() >= self._token_expires_at - ): - await self._fetch_token(client) - - if self._access_token: - headers["Authorization"] = f"Bearer {self._access_token}" - - return headers - - async def _fetch_token(self, client: httpx.AsyncClient) -> None: - """Fetch OAuth2 access token using client credentials flow. - - Args: - client: HTTP client for making token request. - - Raises: - httpx.HTTPStatusError: If token request fails. - """ - data = { - "grant_type": "client_credentials", - "client_id": self.client_id, - "client_secret": self.client_secret, - } - - if self.scopes: - data["scope"] = " ".join(self.scopes) - - response = await client.post(self.token_url, data=data) - response.raise_for_status() - - token_data = response.json() - self._access_token = token_data["access_token"] - expires_in = token_data.get("expires_in", 3600) - self._token_expires_at = time.time() + expires_in - 60 - - -class OAuth2AuthorizationCode(ClientAuthScheme): - """OAuth2 Authorization Code flow authentication. - - Thread-safe implementation with asyncio.Lock to prevent concurrent token operations. - - Note: Requires interactive authorization. - - Attributes: - authorization_url: OAuth2 authorization endpoint URL. - token_url: OAuth2 token endpoint URL. - client_id: OAuth2 client identifier. - client_secret: OAuth2 client secret. - redirect_uri: OAuth2 redirect URI for callback. - scopes: List of required OAuth2 scopes. 
- """ - - authorization_url: str = Field(description="OAuth2 authorization endpoint") - token_url: str = Field(description="OAuth2 token endpoint") - client_id: str = Field(description="OAuth2 client ID") - client_secret: str = Field(description="OAuth2 client secret") - redirect_uri: str = Field(description="OAuth2 redirect URI") - scopes: list[str] = Field( - default_factory=list, description="Required OAuth2 scopes" - ) - - _access_token: str | None = PrivateAttr(default=None) - _refresh_token: str | None = PrivateAttr(default=None) - _token_expires_at: float | None = PrivateAttr(default=None) - _authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr( - default=None - ) - _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) - - def set_authorization_callback( - self, callback: Callable[[str], Awaitable[str]] | None - ) -> None: - """Set callback to handle authorization URL. - - Args: - callback: Async function that receives authorization URL and returns auth code. - """ - self._authorization_callback = callback - - async def apply_auth( - self, client: httpx.AsyncClient, headers: MutableMapping[str, str] - ) -> MutableMapping[str, str]: - """Apply OAuth2 access token to Authorization header. - - Uses asyncio.Lock to ensure only one coroutine handles token operations - (initial fetch or refresh) at a time. - - Args: - client: HTTP client for making token requests. - headers: Current request headers. - - Returns: - Updated headers with OAuth2 access token in Authorization header. - - Raises: - ValueError: If authorization callback is not set. - """ - if self._access_token is None: - if self._authorization_callback is None: - msg = "Authorization callback not set. 
Use set_authorization_callback()" - raise ValueError(msg) - async with self._lock: - if self._access_token is None: - await self._fetch_initial_token(client) - elif self._token_expires_at and time.time() >= self._token_expires_at: - async with self._lock: - if self._token_expires_at and time.time() >= self._token_expires_at: - await self._refresh_access_token(client) - - if self._access_token: - headers["Authorization"] = f"Bearer {self._access_token}" - - return headers - - async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None: - """Fetch initial access token using authorization code flow. - - Args: - client: HTTP client for making token request. - - Raises: - ValueError: If authorization callback is not set. - httpx.HTTPStatusError: If token request fails. - """ - params = { - "response_type": "code", - "client_id": self.client_id, - "redirect_uri": self.redirect_uri, - "scope": " ".join(self.scopes), - } - auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}" - - if self._authorization_callback is None: - msg = "Authorization callback not set" - raise ValueError(msg) - auth_code = await self._authorization_callback(auth_url) - - data = { - "grant_type": "authorization_code", - "code": auth_code, - "client_id": self.client_id, - "client_secret": self.client_secret, - "redirect_uri": self.redirect_uri, - } - - response = await client.post(self.token_url, data=data) - response.raise_for_status() - - token_data = response.json() - self._access_token = token_data["access_token"] - self._refresh_token = token_data.get("refresh_token") - - expires_in = token_data.get("expires_in", 3600) - self._token_expires_at = time.time() + expires_in - 60 - - async def _refresh_access_token(self, client: httpx.AsyncClient) -> None: - """Refresh the access token using refresh token. - - Args: - client: HTTP client for making token request. - - Raises: - httpx.HTTPStatusError: If token refresh request fails. 
- """ - if not self._refresh_token: - await self._fetch_initial_token(client) - return - - data = { - "grant_type": "refresh_token", - "refresh_token": self._refresh_token, - "client_id": self.client_id, - "client_secret": self.client_secret, - } - - response = await client.post(self.token_url, data=data) - response.raise_for_status() - - token_data = response.json() - self._access_token = token_data["access_token"] - if "refresh_token" in token_data: - self._refresh_token = token_data["refresh_token"] - - expires_in = token_data.get("expires_in", 3600) - self._token_expires_at = time.time() + expires_in - 60 +from crewai_a2a.auth.client_schemes import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/auth/schemas.py b/lib/crewai/src/crewai/a2a/auth/schemas.py index 4c8f3c0d7..53d032429 100644 --- a/lib/crewai/src/crewai/a2a/auth/schemas.py +++ b/lib/crewai/src/crewai/a2a/auth/schemas.py @@ -1,71 +1,13 @@ -"""Deprecated: Authentication schemes for A2A protocol agents. +"""Backward-compatibility shim — use ``crewai_a2a.auth.schemas`` instead.""" -This module is deprecated. Import from crewai.a2a.auth instead: -- crewai.a2a.auth.ClientAuthScheme (replaces AuthScheme) -- crewai.a2a.auth.BearerTokenAuth -- crewai.a2a.auth.HTTPBasicAuth -- crewai.a2a.auth.HTTPDigestAuth -- crewai.a2a.auth.APIKeyAuth -- crewai.a2a.auth.OAuth2ClientCredentials -- crewai.a2a.auth.OAuth2AuthorizationCode -""" +import warnings -from __future__ import annotations -from typing_extensions import deprecated - -from crewai.a2a.auth.client_schemes import ( - APIKeyAuth as _APIKeyAuth, - BearerTokenAuth as _BearerTokenAuth, - ClientAuthScheme as _ClientAuthScheme, - HTTPBasicAuth as _HTTPBasicAuth, - HTTPDigestAuth as _HTTPDigestAuth, - OAuth2AuthorizationCode as _OAuth2AuthorizationCode, - OAuth2ClientCredentials as _OAuth2ClientCredentials, +warnings.warn( + "'crewai.a2a.auth.schemas' has been moved to 'crewai_a2a.auth.schemas'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -@deprecated("Use ClientAuthScheme from crewai.a2a.auth instead", category=FutureWarning) -class AuthScheme(_ClientAuthScheme): - """Deprecated: Use ClientAuthScheme from crewai.a2a.auth instead.""" - - -@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning) -class BearerTokenAuth(_BearerTokenAuth): - """Deprecated: Import from crewai.a2a.auth instead.""" - - -@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning) -class HTTPBasicAuth(_HTTPBasicAuth): - """Deprecated: Import from crewai.a2a.auth instead.""" - - -@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning) -class HTTPDigestAuth(_HTTPDigestAuth): - """Deprecated: Import from crewai.a2a.auth instead.""" - - -@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning) -class APIKeyAuth(_APIKeyAuth): - """Deprecated: Import from crewai.a2a.auth instead.""" - - -@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning) -class OAuth2ClientCredentials(_OAuth2ClientCredentials): - """Deprecated: Import from crewai.a2a.auth instead.""" - - -@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning) -class OAuth2AuthorizationCode(_OAuth2AuthorizationCode): - """Deprecated: Import from crewai.a2a.auth instead.""" - - -__all__ = [ - "APIKeyAuth", - "AuthScheme", - "BearerTokenAuth", - "HTTPBasicAuth", - "HTTPDigestAuth", - "OAuth2AuthorizationCode", - "OAuth2ClientCredentials", -] +from crewai_a2a.auth.schemas import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/auth/server_schemes.py b/lib/crewai/src/crewai/a2a/auth/server_schemes.py index 25ad597be..2c625a83f 100644 --- a/lib/crewai/src/crewai/a2a/auth/server_schemes.py +++ b/lib/crewai/src/crewai/a2a/auth/server_schemes.py @@ -1,739 +1,13 @@ -"""Server-side authentication schemes for A2A protocol. 
+"""Backward-compatibility shim — use ``crewai_a2a.auth.server_schemes`` instead.""" -These schemes validate incoming requests to A2A server endpoints. +import warnings -Supported authentication methods: -- Simple token validation with static bearer tokens -- OpenID Connect with JWT validation using JWKS -- OAuth2 with JWT validation or token introspection -""" -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -import logging -import os -from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal - -import jwt -from jwt import PyJWKClient -from pydantic import ( - BaseModel, - BeforeValidator, - ConfigDict, - Field, - HttpUrl, - PrivateAttr, - SecretStr, - model_validator, +warnings.warn( + "'crewai.a2a.auth.server_schemes' has been moved to 'crewai_a2a.auth.server_schemes'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) -from typing_extensions import Self - -if TYPE_CHECKING: - from a2a.types import OAuth2SecurityScheme - - -logger = logging.getLogger(__name__) - - -try: - from fastapi import HTTPException, status as http_status - - HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED - HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR - HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE -except ImportError: - - class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818 - """Fallback HTTPException when FastAPI is not installed.""" - - def __init__( - self, - status_code: int, - detail: str | None = None, - headers: dict[str, str] | None = None, - ) -> None: - self.status_code = status_code - self.detail = detail - self.headers = headers - super().__init__(detail) - - HTTP_401_UNAUTHORIZED = 401 - HTTP_500_INTERNAL_SERVER_ERROR = 500 - HTTP_503_SERVICE_UNAVAILABLE = 503 - - -def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None: - """Coerce string to 
SecretStr.""" - if v is None or isinstance(v, SecretStr): - return v - return SecretStr(v) - - -CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)] - -JWTAlgorithm = Literal[ - "RS256", - "RS384", - "RS512", - "ES256", - "ES384", - "ES512", - "PS256", - "PS384", - "PS512", -] - - -@dataclass -class AuthenticatedUser: - """Result of successful authentication. - - Attributes: - token: The original token that was validated. - scheme: Name of the authentication scheme used. - claims: JWT claims from OIDC or OAuth2 authentication. - """ - - token: str - scheme: str - claims: dict[str, Any] | None = None - - -class ServerAuthScheme(ABC, BaseModel): - """Base class for server-side authentication schemes. - - Each scheme validates incoming requests and returns an AuthenticatedUser - on success, or raises HTTPException on failure. - """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - @abstractmethod - async def authenticate(self, token: str) -> AuthenticatedUser: - """Authenticate the provided token. - - Args: - token: The bearer token to authenticate. - - Returns: - AuthenticatedUser on successful authentication. - - Raises: - HTTPException: If authentication fails. - """ - ... - - -class SimpleTokenAuth(ServerAuthScheme): - """Simple bearer token authentication. - - Validates tokens against a configured static token or AUTH_TOKEN env var. - - Attributes: - token: Expected token value. Falls back to AUTH_TOKEN env var if not set. - """ - - token: CoercedSecretStr | None = Field( - default=None, - description="Expected token. Falls back to AUTH_TOKEN env var.", - ) - - def _get_expected_token(self) -> str | None: - """Get the expected token value.""" - if self.token: - return self.token.get_secret_value() - return os.environ.get("AUTH_TOKEN") - - async def authenticate(self, token: str) -> AuthenticatedUser: - """Authenticate using simple token comparison. - - Args: - token: The bearer token to authenticate. 
- - Returns: - AuthenticatedUser on successful authentication. - - Raises: - HTTPException: If authentication fails. - """ - expected = self._get_expected_token() - - if expected is None: - logger.warning( - "Simple token authentication failed", - extra={"reason": "no_token_configured"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Authentication not configured", - ) - - if token != expected: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid or missing authentication credentials", - ) - - return AuthenticatedUser( - token=token, - scheme="simple_token", - ) - - -class OIDCAuth(ServerAuthScheme): - """OpenID Connect authentication. - - Validates JWTs using JWKS with caching support via PyJWT. - - Attributes: - issuer: The OpenID Connect issuer URL. - audience: The expected audience claim. - jwks_url: Optional explicit JWKS URL. Derived from issuer if not set. - algorithms: List of allowed signing algorithms. - required_claims: List of claims that must be present in the token. - jwks_cache_ttl: TTL for JWKS cache in seconds. - clock_skew_seconds: Allowed clock skew for token validation. - """ - - issuer: HttpUrl = Field( - description="OpenID Connect issuer URL (e.g., https://auth.example.com)" - ) - audience: str = Field(description="Expected audience claim (e.g., api://my-agent)") - jwks_url: HttpUrl | None = Field( - default=None, - description="Explicit JWKS URL. 
Derived from issuer if not set.", - ) - algorithms: list[str] = Field( - default_factory=lambda: ["RS256"], - description="List of allowed signing algorithms (RS256, ES256, etc.)", - ) - required_claims: list[str] = Field( - default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"], - description="List of claims that must be present in the token", - ) - jwks_cache_ttl: int = Field( - default=3600, - description="TTL for JWKS cache in seconds", - ge=60, - ) - clock_skew_seconds: float = Field( - default=30.0, - description="Allowed clock skew for token validation", - ge=0.0, - ) - - _jwk_client: PyJWKClient | None = PrivateAttr(default=None) - - @model_validator(mode="after") - def _init_jwk_client(self) -> Self: - """Initialize the JWK client after model creation.""" - jwks_url = ( - str(self.jwks_url) - if self.jwks_url - else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json" - ) - self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl) - return self - - async def authenticate(self, token: str) -> AuthenticatedUser: - """Authenticate using OIDC JWT validation. - - Args: - token: The JWT to authenticate. - - Returns: - AuthenticatedUser on successful authentication. - - Raises: - HTTPException: If authentication fails. 
- """ - if self._jwk_client is None: - raise HTTPException( - status_code=HTTP_500_INTERNAL_SERVER_ERROR, - detail="OIDC not initialized", - ) - - try: - signing_key = self._jwk_client.get_signing_key_from_jwt(token) - - claims = jwt.decode( - token, - signing_key.key, - algorithms=self.algorithms, - audience=self.audience, - issuer=str(self.issuer).rstrip("/"), - leeway=self.clock_skew_seconds, - options={ - "require": self.required_claims, - }, - ) - - return AuthenticatedUser( - token=token, - scheme="oidc", - claims=claims, - ) - - except jwt.ExpiredSignatureError: - logger.debug( - "OIDC authentication failed", - extra={"reason": "token_expired", "scheme": "oidc"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Token has expired", - ) from None - except jwt.InvalidAudienceError: - logger.debug( - "OIDC authentication failed", - extra={"reason": "invalid_audience", "scheme": "oidc"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid token audience", - ) from None - except jwt.InvalidIssuerError: - logger.debug( - "OIDC authentication failed", - extra={"reason": "invalid_issuer", "scheme": "oidc"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid token issuer", - ) from None - except jwt.MissingRequiredClaimError as e: - logger.debug( - "OIDC authentication failed", - extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail=f"Missing required claim: {e.claim}", - ) from None - except jwt.PyJWKClientError as e: - logger.error( - "OIDC authentication failed", - extra={ - "reason": "jwks_client_error", - "error": str(e), - "scheme": "oidc", - }, - ) - raise HTTPException( - status_code=HTTP_503_SERVICE_UNAVAILABLE, - detail="Unable to fetch signing keys", - ) from None - except jwt.InvalidTokenError as e: - logger.debug( - "OIDC authentication failed", - extra={"reason": "invalid_token", 
"error": str(e), "scheme": "oidc"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid or missing authentication credentials", - ) from None - - -class OAuth2ServerAuth(ServerAuthScheme): - """OAuth2 authentication for A2A server. - - Declares OAuth2 security scheme in AgentCard and validates tokens using - either JWKS for JWT tokens or token introspection for opaque tokens. - - This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme - with flows, rather than an OpenIdConnectSecurityScheme with discovery URL. - - Attributes: - token_url: OAuth2 token endpoint URL for client_credentials flow. - authorization_url: OAuth2 authorization endpoint for authorization_code flow. - refresh_url: Optional refresh token endpoint URL. - scopes: Available OAuth2 scopes with descriptions. - jwks_url: JWKS URL for JWT validation. Required if not using introspection. - introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS. - introspection_client_id: Client ID for introspection endpoint authentication. - introspection_client_secret: Client secret for introspection endpoint. - audience: Expected audience claim for JWT validation. - issuer: Expected issuer claim for JWT validation. - algorithms: Allowed JWT signing algorithms. - required_claims: Claims that must be present in the token. - jwks_cache_ttl: TTL for JWKS cache in seconds. - clock_skew_seconds: Allowed clock skew for token validation. 
- """ - - token_url: HttpUrl = Field( - description="OAuth2 token endpoint URL", - ) - authorization_url: HttpUrl | None = Field( - default=None, - description="OAuth2 authorization endpoint URL for authorization_code flow", - ) - refresh_url: HttpUrl | None = Field( - default=None, - description="OAuth2 refresh token endpoint URL", - ) - scopes: dict[str, str] = Field( - default_factory=dict, - description="Available OAuth2 scopes with descriptions", - ) - jwks_url: HttpUrl | None = Field( - default=None, - description="JWKS URL for JWT validation. Required if not using introspection.", - ) - introspection_url: HttpUrl | None = Field( - default=None, - description="Token introspection endpoint (RFC 7662). Alternative to JWKS.", - ) - introspection_client_id: str | None = Field( - default=None, - description="Client ID for introspection endpoint authentication", - ) - introspection_client_secret: CoercedSecretStr | None = Field( - default=None, - description="Client secret for introspection endpoint authentication", - ) - audience: str | None = Field( - default=None, - description="Expected audience claim for JWT validation", - ) - issuer: str | None = Field( - default=None, - description="Expected issuer claim for JWT validation", - ) - algorithms: list[str] = Field( - default_factory=lambda: ["RS256"], - description="Allowed JWT signing algorithms", - ) - required_claims: list[str] = Field( - default_factory=lambda: ["exp", "iat"], - description="Claims that must be present in the token", - ) - jwks_cache_ttl: int = Field( - default=3600, - description="TTL for JWKS cache in seconds", - ge=60, - ) - clock_skew_seconds: float = Field( - default=30.0, - description="Allowed clock skew for token validation", - ge=0.0, - ) - - _jwk_client: PyJWKClient | None = PrivateAttr(default=None) - - @model_validator(mode="after") - def _validate_and_init(self) -> Self: - """Validate configuration and initialize JWKS client if needed.""" - if not self.jwks_url and not 
self.introspection_url: - raise ValueError( - "Either jwks_url or introspection_url must be provided for token validation" - ) - - if self.introspection_url: - if not self.introspection_client_id or not self.introspection_client_secret: - raise ValueError( - "introspection_client_id and introspection_client_secret are required " - "when using token introspection" - ) - - if self.jwks_url: - self._jwk_client = PyJWKClient( - str(self.jwks_url), lifespan=self.jwks_cache_ttl - ) - - return self - - async def authenticate(self, token: str) -> AuthenticatedUser: - """Authenticate using OAuth2 token validation. - - Uses JWKS validation if jwks_url is configured, otherwise falls back - to token introspection. - - Args: - token: The OAuth2 access token to authenticate. - - Returns: - AuthenticatedUser on successful authentication. - - Raises: - HTTPException: If authentication fails. - """ - if self._jwk_client: - return await self._authenticate_jwt(token) - return await self._authenticate_introspection(token) - - async def _authenticate_jwt(self, token: str) -> AuthenticatedUser: - """Authenticate using JWKS JWT validation.""" - if self._jwk_client is None: - raise HTTPException( - status_code=HTTP_500_INTERNAL_SERVER_ERROR, - detail="OAuth2 JWKS not initialized", - ) - - try: - signing_key = self._jwk_client.get_signing_key_from_jwt(token) - - decode_options: dict[str, Any] = { - "require": self.required_claims, - } - - claims = jwt.decode( - token, - signing_key.key, - algorithms=self.algorithms, - audience=self.audience, - issuer=self.issuer, - leeway=self.clock_skew_seconds, - options=decode_options, - ) - - return AuthenticatedUser( - token=token, - scheme="oauth2", - claims=claims, - ) - - except jwt.ExpiredSignatureError: - logger.debug( - "OAuth2 authentication failed", - extra={"reason": "token_expired", "scheme": "oauth2"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Token has expired", - ) from None - except 
jwt.InvalidAudienceError: - logger.debug( - "OAuth2 authentication failed", - extra={"reason": "invalid_audience", "scheme": "oauth2"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid token audience", - ) from None - except jwt.InvalidIssuerError: - logger.debug( - "OAuth2 authentication failed", - extra={"reason": "invalid_issuer", "scheme": "oauth2"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid token issuer", - ) from None - except jwt.MissingRequiredClaimError as e: - logger.debug( - "OAuth2 authentication failed", - extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail=f"Missing required claim: {e.claim}", - ) from None - except jwt.PyJWKClientError as e: - logger.error( - "OAuth2 authentication failed", - extra={ - "reason": "jwks_client_error", - "error": str(e), - "scheme": "oauth2", - }, - ) - raise HTTPException( - status_code=HTTP_503_SERVICE_UNAVAILABLE, - detail="Unable to fetch signing keys", - ) from None - except jwt.InvalidTokenError as e: - logger.debug( - "OAuth2 authentication failed", - extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid or missing authentication credentials", - ) from None - - async def _authenticate_introspection(self, token: str) -> AuthenticatedUser: - """Authenticate using OAuth2 token introspection (RFC 7662).""" - import httpx - - if not self.introspection_url: - raise HTTPException( - status_code=HTTP_500_INTERNAL_SERVER_ERROR, - detail="OAuth2 introspection not configured", - ) - - try: - async with httpx.AsyncClient() as client: - response = await client.post( - str(self.introspection_url), - data={"token": token}, - auth=( - self.introspection_client_id or "", - self.introspection_client_secret.get_secret_value() - if self.introspection_client_secret - else 
"", - ), - ) - response.raise_for_status() - introspection_result = response.json() - - except httpx.HTTPStatusError as e: - logger.error( - "OAuth2 introspection failed", - extra={"reason": "http_error", "status_code": e.response.status_code}, - ) - raise HTTPException( - status_code=HTTP_503_SERVICE_UNAVAILABLE, - detail="Token introspection service unavailable", - ) from None - except Exception as e: - logger.error( - "OAuth2 introspection failed", - extra={"reason": "unexpected_error", "error": str(e)}, - ) - raise HTTPException( - status_code=HTTP_503_SERVICE_UNAVAILABLE, - detail="Token introspection failed", - ) from None - - if not introspection_result.get("active", False): - logger.debug( - "OAuth2 authentication failed", - extra={"reason": "token_not_active", "scheme": "oauth2"}, - ) - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Token is not active", - ) - - return AuthenticatedUser( - token=token, - scheme="oauth2", - claims=introspection_result, - ) - - def to_security_scheme(self) -> OAuth2SecurityScheme: - """Generate OAuth2SecurityScheme for AgentCard declaration. - - Creates an OAuth2SecurityScheme with appropriate flows based on - the configured URLs. Includes client_credentials flow if token_url - is set, and authorization_code flow if authorization_url is set. - - Returns: - OAuth2SecurityScheme suitable for use in AgentCard security_schemes. 
- """ - from a2a.types import ( - AuthorizationCodeOAuthFlow, - ClientCredentialsOAuthFlow, - OAuth2SecurityScheme, - OAuthFlows, - ) - - client_credentials = None - authorization_code = None - - if self.token_url: - client_credentials = ClientCredentialsOAuthFlow( - token_url=str(self.token_url), - refresh_url=str(self.refresh_url) if self.refresh_url else None, - scopes=self.scopes, - ) - - if self.authorization_url: - authorization_code = AuthorizationCodeOAuthFlow( - authorization_url=str(self.authorization_url), - token_url=str(self.token_url), - refresh_url=str(self.refresh_url) if self.refresh_url else None, - scopes=self.scopes, - ) - - return OAuth2SecurityScheme( - flows=OAuthFlows( - client_credentials=client_credentials, - authorization_code=authorization_code, - ), - description="OAuth2 authentication", - ) - - -class APIKeyServerAuth(ServerAuthScheme): - """API Key authentication for A2A server. - - Validates requests using an API key in a header, query parameter, or cookie. - - Attributes: - name: The name of the API key parameter (default: X-API-Key). - location: Where to look for the API key (header, query, or cookie). - api_key: The expected API key value. - """ - - name: str = Field( - default="X-API-Key", - description="Name of the API key parameter", - ) - location: Literal["header", "query", "cookie"] = Field( - default="header", - description="Where to look for the API key", - ) - api_key: CoercedSecretStr = Field( - description="Expected API key value", - ) - - async def authenticate(self, token: str) -> AuthenticatedUser: - """Authenticate using API key comparison. - - Args: - token: The API key to authenticate. - - Returns: - AuthenticatedUser on successful authentication. - - Raises: - HTTPException: If authentication fails. 
- """ - if token != self.api_key.get_secret_value(): - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - ) - - return AuthenticatedUser( - token=token, - scheme="api_key", - ) - - -class MTLSServerAuth(ServerAuthScheme): - """Mutual TLS authentication marker for AgentCard declaration. - - This scheme is primarily for AgentCard security_schemes declaration. - Actual mTLS verification happens at the TLS/transport layer, not - at the application layer via token validation. - - When configured, this signals to clients that the server requires - client certificates for authentication. - """ - - description: str = Field( - default="Mutual TLS certificate authentication", - description="Description for the security scheme", - ) - - async def authenticate(self, token: str) -> AuthenticatedUser: - """Return authenticated user for mTLS. - - mTLS verification happens at the transport layer before this is called. - If we reach this point, the TLS handshake with client cert succeeded. - - Args: - token: Certificate subject or identifier (from TLS layer). - - Returns: - AuthenticatedUser indicating mTLS authentication. - """ - return AuthenticatedUser( - token=token or "mtls-verified", - scheme="mtls", - ) +from crewai_a2a.auth.server_schemes import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/auth/utils.py b/lib/crewai/src/crewai/a2a/auth/utils.py index 3e8de3e0d..cc8fd4f42 100644 --- a/lib/crewai/src/crewai/a2a/auth/utils.py +++ b/lib/crewai/src/crewai/a2a/auth/utils.py @@ -1,273 +1,13 @@ -"""Authentication utilities for A2A protocol agent communication. +"""Backward-compatibility shim — use ``crewai_a2a.auth.utils`` instead.""" -Provides validation and retry logic for various authentication schemes including -OAuth2, API keys, and HTTP authentication methods. 
-""" +import warnings -import asyncio -from collections.abc import Awaitable, Callable, MutableMapping -import hashlib -import re -import threading -from typing import Final, Literal, cast -from a2a.client.errors import A2AClientHTTPError -from a2a.types import ( - APIKeySecurityScheme, - AgentCard, - HTTPAuthSecurityScheme, - OAuth2SecurityScheme, -) -from httpx import AsyncClient, Response - -from crewai.a2a.auth.client_schemes import ( - APIKeyAuth, - BearerTokenAuth, - ClientAuthScheme, - HTTPBasicAuth, - HTTPDigestAuth, - OAuth2AuthorizationCode, - OAuth2ClientCredentials, +warnings.warn( + "'crewai.a2a.auth.utils' has been moved to 'crewai_a2a.auth.utils'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -class _AuthStore: - """Store for authentication schemes with safe concurrent access.""" - - def __init__(self) -> None: - self._store: dict[str, ClientAuthScheme | None] = {} - self._lock = threading.RLock() - - @staticmethod - def compute_key(auth_type: str, auth_data: str) -> str: - """Compute a collision-resistant key using SHA-256.""" - content = f"{auth_type}:{auth_data}" - return hashlib.sha256(content.encode()).hexdigest() - - def set(self, key: str, auth: ClientAuthScheme | None) -> None: - """Store an auth scheme.""" - with self._lock: - self._store[key] = auth - - def get(self, key: str) -> ClientAuthScheme | None: - """Retrieve an auth scheme by key.""" - with self._lock: - return self._store.get(key) - - def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None: - with self._lock: - self._store[key] = value - - def __getitem__(self, key: str) -> ClientAuthScheme | None: - with self._lock: - return self._store[key] - - -_auth_store = _AuthStore() - -_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)") -_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))') - -_SCHEME_AUTH_MAPPING: Final[dict[type, 
tuple[type[ClientAuthScheme], ...]]] = { - OAuth2SecurityScheme: ( - OAuth2ClientCredentials, - OAuth2AuthorizationCode, - BearerTokenAuth, - ), - APIKeySecurityScheme: (APIKeyAuth,), -} - -_HTTPSchemeType = Literal["basic", "digest", "bearer"] - -_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = { - "basic": HTTPBasicAuth, - "digest": HTTPDigestAuth, - "bearer": BearerTokenAuth, -} - - -def _raise_auth_mismatch( - expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...], - provided_auth: ClientAuthScheme, -) -> None: - """Raise authentication mismatch error. - - Args: - expected_classes: Expected authentication class or tuple of classes. - provided_auth: Actually provided authentication instance. - - Raises: - A2AClientHTTPError: Always raises with 401 status code. - """ - if isinstance(expected_classes, tuple): - if len(expected_classes) == 1: - required = expected_classes[0].__name__ - else: - names = [cls.__name__ for cls in expected_classes] - required = f"one of ({', '.join(names)})" - else: - required = expected_classes.__name__ - - msg = ( - f"AgentCard requires {required} authentication, " - f"but {type(provided_auth).__name__} was provided" - ) - raise A2AClientHTTPError(401, msg) - - -def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]: - """Parse WWW-Authenticate header into auth challenges. - - Args: - header_value: The WWW-Authenticate header value. - - Returns: - Dictionary mapping auth scheme to its parameters. 
- Example: {"Bearer": {"realm": "api", "scope": "read write"}} - """ - if not header_value: - return {} - - challenges: dict[str, dict[str, str]] = {} - - for match in _SCHEME_PATTERN.finditer(header_value): - scheme = match.group(1) - params_str = match.group(2) - - params: dict[str, str] = {} - - for param_match in _PARAM_PATTERN.finditer(params_str): - key = param_match.group(1) - value = param_match.group(2) or param_match.group(3) - params[key] = value - - challenges[scheme] = params - - return challenges - - -def validate_auth_against_agent_card( - agent_card: AgentCard, auth: ClientAuthScheme | None -) -> None: - """Validate that provided auth matches AgentCard security requirements. - - Args: - agent_card: The A2A AgentCard containing security requirements. - auth: User-provided authentication scheme (or None). - - Raises: - A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401). - """ - - if not agent_card.security or not agent_card.security_schemes: - return - - if not auth: - msg = "AgentCard requires authentication but no auth scheme provided" - raise A2AClientHTTPError(401, msg) - - first_security_req = agent_card.security[0] if agent_card.security else {} - - for scheme_name in first_security_req.keys(): - security_scheme_wrapper = agent_card.security_schemes.get(scheme_name) - if not security_scheme_wrapper: - continue - - scheme = security_scheme_wrapper.root - - if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)): - if not isinstance(auth, allowed_classes): - _raise_auth_mismatch(allowed_classes, auth) - return - - if isinstance(scheme, HTTPAuthSecurityScheme): - scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower()) - if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key): - if not isinstance(auth, required_class): - _raise_auth_mismatch(required_class, auth) - return - - msg = "Could not validate auth against AgentCard security requirements" - raise A2AClientHTTPError(401, msg) - - -async def 
retry_on_401( - request_func: Callable[[], Awaitable[Response]], - auth_scheme: ClientAuthScheme | None, - client: AsyncClient, - headers: MutableMapping[str, str], - max_retries: int = 3, -) -> Response: - """Retry a request on 401 authentication error. - - Handles 401 errors by: - 1. Parsing WWW-Authenticate header - 2. Re-acquiring credentials - 3. Retrying the request - - Args: - request_func: Async function that makes the HTTP request. - auth_scheme: Authentication scheme to refresh credentials with. - client: HTTP client for making requests. - headers: Request headers to update with new auth. - max_retries: Maximum number of retry attempts (default: 3). - - Returns: - HTTP response from the request. - - Raises: - httpx.HTTPStatusError: If retries are exhausted or auth scheme is None. - """ - last_response: Response | None = None - last_challenges: dict[str, dict[str, str]] = {} - - for attempt in range(max_retries): - response = await request_func() - - if response.status_code != 401: - return response - - last_response = response - - if auth_scheme is None: - response.raise_for_status() - return response - - www_authenticate = response.headers.get("WWW-Authenticate", "") - challenges = parse_www_authenticate(www_authenticate) - last_challenges = challenges - - if attempt >= max_retries - 1: - break - - backoff_time = 2**attempt - await asyncio.sleep(backoff_time) - - await auth_scheme.apply_auth(client, headers) - - if last_response: - last_response.raise_for_status() - return last_response - - msg = "retry_on_401 failed without making any requests" - if last_challenges: - challenge_info = ", ".join( - f"{scheme} (realm={params.get('realm', 'N/A')})" - for scheme, params in last_challenges.items() - ) - msg = f"{msg}. Server challenges: {challenge_info}" - raise RuntimeError(msg) - - -def configure_auth_client( - auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient -) -> None: - """Configure HTTP client with auth-specific settings. 
- - Only HTTPDigestAuth and APIKeyAuth need client configuration. - - Args: - auth: Authentication scheme that requires client configuration. - client: HTTP client to configure. - """ - auth.configure_client(client) +from crewai_a2a.auth.utils import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/config.py b/lib/crewai/src/crewai/a2a/config.py index 1b9d63db4..c34853623 100644 --- a/lib/crewai/src/crewai/a2a/config.py +++ b/lib/crewai/src/crewai/a2a/config.py @@ -1,690 +1,13 @@ -"""A2A configuration types. +"""Backward-compatibility shim — use ``crewai_a2a.config`` instead.""" -This module is separate from experimental.a2a to avoid circular imports. -""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any, ClassVar, Literal, cast import warnings -from pydantic import ( - BaseModel, - ConfigDict, - Field, - FilePath, - PrivateAttr, - SecretStr, - model_validator, + +warnings.warn( + "'crewai.a2a.config' has been moved to 'crewai_a2a.config'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) -from typing_extensions import Self, deprecated -from crewai.a2a.auth.client_schemes import ClientAuthScheme -from crewai.a2a.auth.server_schemes import ServerAuthScheme -from crewai.a2a.extensions.base import ValidatedA2AExtension -from crewai.a2a.types import ProtocolVersion, TransportType, Url - - -try: - from a2a.types import ( - AgentCapabilities, - AgentCardSignature, - AgentInterface, - AgentProvider, - AgentSkill, - SecurityScheme, - ) - - from crewai.a2a.extensions.server import ServerExtension - from crewai.a2a.updates import UpdateConfig -except ImportError: - UpdateConfig: Any = Any # type: ignore[no-redef] - AgentCapabilities: Any = Any # type: ignore[no-redef] - AgentCardSignature: Any = Any # type: ignore[no-redef] - AgentInterface: Any = Any # type: ignore[no-redef] - AgentProvider: Any = Any # type: ignore[no-redef] - SecurityScheme: Any = Any # type: ignore[no-redef] - AgentSkill: Any = Any # type: ignore[no-redef] - ServerExtension: Any = Any # type: ignore[no-redef] - - -def _get_default_update_config() -> UpdateConfig: - from crewai.a2a.updates import StreamingConfig - - return StreamingConfig() - - -SigningAlgorithm = Literal[ - "RS256", - "RS384", - "RS512", - "ES256", - "ES384", - "ES512", - "PS256", - "PS384", - "PS512", -] - - -class AgentCardSigningConfig(BaseModel): - """Configuration for AgentCard JWS signing. - - Provides the private key and algorithm settings for signing AgentCards. - Either private_key_path or private_key_pem must be provided, but not both. - - Attributes: - private_key_path: Path to a PEM-encoded private key file. - private_key_pem: PEM-encoded private key as a secret string. - key_id: Optional key identifier for the JWS header (kid claim). - algorithm: Signing algorithm (RS256, ES256, PS256, etc.). 
- """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - private_key_path: FilePath | None = Field( - default=None, - description="Path to PEM-encoded private key file", - ) - private_key_pem: SecretStr | None = Field( - default=None, - description="PEM-encoded private key", - ) - key_id: str | None = Field( - default=None, - description="Key identifier for JWS header (kid claim)", - ) - algorithm: SigningAlgorithm = Field( - default="RS256", - description="Signing algorithm (RS256, ES256, PS256, etc.)", - ) - - @model_validator(mode="after") - def _validate_key_source(self) -> Self: - """Ensure exactly one key source is provided.""" - has_path = self.private_key_path is not None - has_pem = self.private_key_pem is not None - - if not has_path and not has_pem: - raise ValueError( - "Either private_key_path or private_key_pem must be provided" - ) - if has_path and has_pem: - raise ValueError( - "Only one of private_key_path or private_key_pem should be provided" - ) - return self - - def get_private_key(self) -> str: - """Get the private key content. - - Returns: - The PEM-encoded private key as a string. - """ - if self.private_key_pem: - return self.private_key_pem.get_secret_value() - if self.private_key_path: - return Path(self.private_key_path).read_text() - raise ValueError("No private key configured") - - -class GRPCServerConfig(BaseModel): - """gRPC server transport configuration. - - Presence of this config in ServerTransportConfig.grpc enables gRPC transport. - - Attributes: - host: Hostname to advertise in agent cards (default: localhost). - Use docker service name (e.g., 'web') for docker-compose setups. - port: Port for the gRPC server. - tls_cert_path: Path to TLS certificate file for gRPC. - tls_key_path: Path to TLS private key file for gRPC. - max_workers: Maximum number of workers for the gRPC thread pool. - reflection_enabled: Whether to enable gRPC reflection for debugging. 
- """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - host: str = Field( - default="localhost", - description="Hostname to advertise in agent cards for gRPC connections", - ) - port: int = Field( - default=50051, - description="Port for the gRPC server", - ) - tls_cert_path: str | None = Field( - default=None, - description="Path to TLS certificate file for gRPC", - ) - tls_key_path: str | None = Field( - default=None, - description="Path to TLS private key file for gRPC", - ) - max_workers: int = Field( - default=10, - description="Maximum number of workers for the gRPC thread pool", - ) - reflection_enabled: bool = Field( - default=False, - description="Whether to enable gRPC reflection for debugging", - ) - - -class GRPCClientConfig(BaseModel): - """gRPC client transport configuration. - - Attributes: - max_send_message_length: Maximum size for outgoing messages in bytes. - max_receive_message_length: Maximum size for incoming messages in bytes. - keepalive_time_ms: Time between keepalive pings in milliseconds. - keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds. - """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - max_send_message_length: int | None = Field( - default=None, - description="Maximum size for outgoing messages in bytes", - ) - max_receive_message_length: int | None = Field( - default=None, - description="Maximum size for incoming messages in bytes", - ) - keepalive_time_ms: int | None = Field( - default=None, - description="Time between keepalive pings in milliseconds", - ) - keepalive_timeout_ms: int | None = Field( - default=None, - description="Timeout for keepalive ping response in milliseconds", - ) - - -class JSONRPCServerConfig(BaseModel): - """JSON-RPC server transport configuration. - - Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport. - - Attributes: - rpc_path: URL path for the JSON-RPC endpoint. 
- agent_card_path: URL path for the agent card endpoint. - """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - rpc_path: str = Field( - default="/a2a", - description="URL path for the JSON-RPC endpoint", - ) - agent_card_path: str = Field( - default="/.well-known/agent-card.json", - description="URL path for the agent card endpoint", - ) - - -class JSONRPCClientConfig(BaseModel): - """JSON-RPC client transport configuration. - - Attributes: - max_request_size: Maximum request body size in bytes. - """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - max_request_size: int | None = Field( - default=None, - description="Maximum request body size in bytes", - ) - - -class HTTPJSONConfig(BaseModel): - """HTTP+JSON transport configuration. - - Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport. - """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - -class ServerPushNotificationConfig(BaseModel): - """Configuration for outgoing webhook push notifications. - - Controls how the server signs and delivers push notifications to clients. - - Attributes: - signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks. - """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - signature_secret: SecretStr | None = Field( - default=None, - description="Shared secret for HMAC-SHA256 signing of outgoing push notifications", - ) - - -class ServerTransportConfig(BaseModel): - """Transport configuration for A2A server. - - Groups all transport-related settings including preferred transport - and protocol-specific configurations. - - Attributes: - preferred: Transport protocol for the preferred endpoint. - jsonrpc: JSON-RPC server transport configuration. - grpc: gRPC server transport configuration. - http_json: HTTP+JSON transport configuration. 
- """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - preferred: TransportType = Field( - default="JSONRPC", - description="Transport protocol for the preferred endpoint", - ) - jsonrpc: JSONRPCServerConfig = Field( - default_factory=JSONRPCServerConfig, - description="JSON-RPC server transport configuration", - ) - grpc: GRPCServerConfig | None = Field( - default=None, - description="gRPC server transport configuration", - ) - http_json: HTTPJSONConfig | None = Field( - default=None, - description="HTTP+JSON transport configuration", - ) - - -def _migrate_client_transport_fields( - transport: ClientTransportConfig, - transport_protocol: TransportType | None, - supported_transports: list[TransportType] | None, -) -> None: - """Migrate deprecated transport fields to new config.""" - if transport_protocol is not None: - warnings.warn( - "transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead", - FutureWarning, - stacklevel=5, - ) - object.__setattr__(transport, "preferred", transport_protocol) - if supported_transports is not None: - warnings.warn( - "supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead", - FutureWarning, - stacklevel=5, - ) - object.__setattr__(transport, "supported", supported_transports) - - -class ClientTransportConfig(BaseModel): - """Transport configuration for A2A client. - - Groups all client transport-related settings including preferred transport, - supported transports for negotiation, and protocol-specific configurations. - - Transport negotiation logic: - 1. If `preferred` is set and server supports it → use client's preferred - 2. Otherwise, if server's preferred is in client's `supported` → use server's preferred - 3. Otherwise, find first match from client's `supported` in server's interfaces - - Attributes: - preferred: Client's preferred transport. If set, client preference takes priority. 
- supported: Transports the client can use, in order of preference. - jsonrpc: JSON-RPC client transport configuration. - grpc: gRPC client transport configuration. - """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - preferred: TransportType | None = Field( - default=None, - description="Client's preferred transport. If set, takes priority over server preference.", - ) - supported: list[TransportType] = Field( - default_factory=lambda: cast(list[TransportType], ["JSONRPC"]), - description="Transports the client can use, in order of preference", - ) - jsonrpc: JSONRPCClientConfig = Field( - default_factory=JSONRPCClientConfig, - description="JSON-RPC client transport configuration", - ) - grpc: GRPCClientConfig = Field( - default_factory=GRPCClientConfig, - description="gRPC client transport configuration", - ) - - -@deprecated( - """ - `crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0, - use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead. - """, - category=FutureWarning, -) -class A2AConfig(BaseModel): - """Configuration for A2A protocol integration. - - Deprecated: - Use A2AClientConfig instead. This class will be removed in a future version. - - Attributes: - endpoint: A2A agent endpoint URL. - auth: Authentication scheme. - timeout: Request timeout in seconds. - max_turns: Maximum conversation turns with A2A agent. - response_model: Optional Pydantic model for structured A2A agent responses. - fail_fast: If True, raise error when agent unreachable; if False, skip and continue. - trust_remote_completion_status: If True, return A2A agent's result directly when completed. - updates: Update mechanism config. - client_extensions: Client-side processing hooks for tool injection and prompt augmentation. - transport: Transport configuration (preferred, supported transports, gRPC settings). 
- """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - endpoint: Url = Field(description="A2A agent endpoint URL") - auth: ClientAuthScheme | None = Field( - default=None, - description="Authentication scheme", - ) - timeout: int = Field(default=120, description="Request timeout in seconds") - max_turns: int = Field( - default=10, description="Maximum conversation turns with A2A agent" - ) - response_model: type[BaseModel] | None = Field( - default=None, - description="Optional Pydantic model for structured A2A agent responses", - ) - fail_fast: bool = Field( - default=True, - description="If True, raise error when agent unreachable; if False, skip", - ) - trust_remote_completion_status: bool = Field( - default=False, - description="If True, return A2A result directly when completed", - ) - updates: UpdateConfig = Field( - default_factory=_get_default_update_config, - description="Update mechanism config", - ) - client_extensions: list[ValidatedA2AExtension] = Field( - default_factory=list, - description="Client-side processing hooks for tool injection and prompt augmentation", - ) - transport: ClientTransportConfig = Field( - default_factory=ClientTransportConfig, - description="Transport configuration (preferred, supported transports, gRPC settings)", - ) - transport_protocol: TransportType | None = Field( - default=None, - description="Deprecated: Use transport.preferred instead", - exclude=True, - ) - supported_transports: list[TransportType] | None = Field( - default=None, - description="Deprecated: Use transport.supported instead", - exclude=True, - ) - use_client_preference: bool | None = Field( - default=None, - description="Deprecated: Set transport.preferred to enable client preference", - exclude=True, - ) - _parallel_delegation: bool = PrivateAttr(default=False) - - @model_validator(mode="after") - def _migrate_deprecated_transport_fields(self) -> Self: - """Migrate deprecated transport fields to new config.""" - 
_migrate_client_transport_fields( - self.transport, self.transport_protocol, self.supported_transports - ) - if self.use_client_preference is not None: - warnings.warn( - "use_client_preference is deprecated, set transport.preferred to enable client preference", - FutureWarning, - stacklevel=4, - ) - if self.use_client_preference and self.transport.supported: - object.__setattr__( - self.transport, "preferred", self.transport.supported[0] - ) - return self - - -class A2AClientConfig(BaseModel): - """Configuration for connecting to remote A2A agents. - - Attributes: - endpoint: A2A agent endpoint URL. - auth: Authentication scheme. - timeout: Request timeout in seconds. - max_turns: Maximum conversation turns with A2A agent. - response_model: Optional Pydantic model for structured A2A agent responses. - fail_fast: If True, raise error when agent unreachable; if False, skip and continue. - trust_remote_completion_status: If True, return A2A agent's result directly when completed. - updates: Update mechanism config. - accepted_output_modes: Media types the client can accept in responses. - extensions: Extension URIs the client supports (A2A protocol extensions). - client_extensions: Client-side processing hooks for tool injection and prompt augmentation. - transport: Transport configuration (preferred, supported transports, gRPC settings). 
- """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - endpoint: Url = Field(description="A2A agent endpoint URL") - auth: ClientAuthScheme | None = Field( - default=None, - description="Authentication scheme", - ) - timeout: int = Field(default=120, description="Request timeout in seconds") - max_turns: int = Field( - default=10, description="Maximum conversation turns with A2A agent" - ) - response_model: type[BaseModel] | None = Field( - default=None, - description="Optional Pydantic model for structured A2A agent responses", - ) - fail_fast: bool = Field( - default=True, - description="If True, raise error when agent unreachable; if False, skip", - ) - trust_remote_completion_status: bool = Field( - default=False, - description="If True, return A2A result directly when completed", - ) - updates: UpdateConfig = Field( - default_factory=_get_default_update_config, - description="Update mechanism config", - ) - accepted_output_modes: list[str] = Field( - default_factory=lambda: ["application/json"], - description="Media types the client can accept in responses", - ) - extensions: list[str] = Field( - default_factory=list, - description="Extension URIs the client supports", - ) - client_extensions: list[ValidatedA2AExtension] = Field( - default_factory=list, - description="Client-side processing hooks for tool injection and prompt augmentation", - ) - transport: ClientTransportConfig = Field( - default_factory=ClientTransportConfig, - description="Transport configuration (preferred, supported transports, gRPC settings)", - ) - transport_protocol: TransportType | None = Field( - default=None, - description="Deprecated: Use transport.preferred instead", - exclude=True, - ) - supported_transports: list[TransportType] | None = Field( - default=None, - description="Deprecated: Use transport.supported instead", - exclude=True, - ) - _parallel_delegation: bool = PrivateAttr(default=False) - - @model_validator(mode="after") - def 
_migrate_deprecated_transport_fields(self) -> Self: - """Migrate deprecated transport fields to new config.""" - _migrate_client_transport_fields( - self.transport, self.transport_protocol, self.supported_transports - ) - return self - - -class A2AServerConfig(BaseModel): - """Configuration for exposing a Crew or Agent as an A2A server. - - All fields correspond to A2A AgentCard fields. Fields like name, description, - and skills can be auto-derived from the Crew/Agent if not provided. - - Attributes: - name: Human-readable name for the agent. - description: Human-readable description of the agent. - version: Version string for the agent card. - skills: List of agent skills/capabilities. - default_input_modes: Default supported input MIME types. - default_output_modes: Default supported output MIME types. - capabilities: Declaration of optional capabilities. - protocol_version: A2A protocol version this agent supports. - provider: Information about the agent's service provider. - documentation_url: URL to the agent's documentation. - icon_url: URL to an icon for the agent. - additional_interfaces: Additional supported interfaces. - security: Security requirement objects for all interactions. - security_schemes: Security schemes available to authorize requests. - supports_authenticated_extended_card: Whether agent provides extended card to authenticated users. - url: Preferred endpoint URL for the agent. - signing_config: Configuration for signing the AgentCard with JWS. - signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead. - server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks. - push_notifications: Configuration for outgoing push notifications. - transport: Transport configuration (preferred transport, gRPC, REST settings). - auth: Authentication scheme for A2A endpoints. 
- """ - - model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") - - name: str | None = Field( - default=None, - description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.", - ) - description: str | None = Field( - default=None, - description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.", - ) - version: str = Field( - default="1.0.0", - description="Version string for the agent card", - ) - skills: list[AgentSkill] = Field( - default_factory=list, - description="List of agent skills. Auto-derived from tasks/tools if not provided.", - ) - default_input_modes: list[str] = Field( - default_factory=lambda: ["text/plain", "application/json"], - description="Default supported input MIME types", - ) - default_output_modes: list[str] = Field( - default_factory=lambda: ["text/plain", "application/json"], - description="Default supported output MIME types", - ) - capabilities: AgentCapabilities = Field( - default_factory=lambda: AgentCapabilities( - streaming=True, - push_notifications=False, - ), - description="Declaration of optional capabilities supported by the agent", - ) - protocol_version: ProtocolVersion = Field( - default="0.3.0", - description="A2A protocol version this agent supports", - ) - provider: AgentProvider | None = Field( - default=None, - description="Information about the agent's service provider", - ) - documentation_url: Url | None = Field( - default=None, - description="URL to the agent's documentation", - ) - icon_url: Url | None = Field( - default=None, - description="URL to an icon for the agent", - ) - additional_interfaces: list[AgentInterface] = Field( - default_factory=list, - description="Additional supported interfaces.", - ) - security: list[dict[str, list[str]]] = Field( - default_factory=list, - description="Security requirement objects for all agent interactions", - ) - security_schemes: dict[str, SecurityScheme] = Field( - default_factory=dict, - 
description="Security schemes available to authorize requests", - ) - supports_authenticated_extended_card: bool = Field( - default=False, - description="Whether agent provides extended card to authenticated users", - ) - url: Url | None = Field( - default=None, - description="Preferred endpoint URL for the agent. Set at runtime if not provided.", - ) - signing_config: AgentCardSigningConfig | None = Field( - default=None, - description="Configuration for signing the AgentCard with JWS", - ) - signatures: list[AgentCardSignature] | None = Field( - default=None, - description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.", - exclude=True, - deprecated=True, - ) - server_extensions: list[ServerExtension] = Field( - default_factory=list, - description="Server-side A2A protocol extensions that modify agent behavior", - ) - push_notifications: ServerPushNotificationConfig | None = Field( - default=None, - description="Configuration for outgoing push notifications", - ) - transport: ServerTransportConfig = Field( - default_factory=ServerTransportConfig, - description="Transport configuration (preferred transport, gRPC, REST settings)", - ) - preferred_transport: TransportType | None = Field( - default=None, - description="Deprecated: Use transport.preferred instead", - exclude=True, - deprecated=True, - ) - auth: ServerAuthScheme | None = Field( - default=None, - description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.", - ) - - @model_validator(mode="after") - def _migrate_deprecated_fields(self) -> Self: - """Migrate deprecated fields to new config.""" - if self.preferred_transport is not None: - warnings.warn( - "preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) 
instead", - FutureWarning, - stacklevel=4, - ) - object.__setattr__(self.transport, "preferred", self.preferred_transport) - if self.signatures is not None: - warnings.warn( - "signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. " - "The signatures field will be removed in v2.0.0.", - FutureWarning, - stacklevel=4, - ) - return self +from crewai_a2a.config import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/errors.py b/lib/crewai/src/crewai/a2a/errors.py index aabe10288..f8905a531 100644 --- a/lib/crewai/src/crewai/a2a/errors.py +++ b/lib/crewai/src/crewai/a2a/errors.py @@ -1,491 +1,13 @@ -"""A2A error codes and error response utilities. +"""Backward-compatibility shim — use ``crewai_a2a.errors`` instead.""" -This module provides a centralized mapping of all A2A protocol error codes -as defined in the A2A specification, plus custom CrewAI extensions. +import warnings -Error codes follow JSON-RPC 2.0 conventions: -- -32700 to -32600: Standard JSON-RPC errors -- -32099 to -32000: Server errors (A2A-specific) -- -32768 to -32100: Reserved for implementation-defined errors -""" -from __future__ import annotations +warnings.warn( + "'crewai.a2a.errors' has been moved to 'crewai_a2a.errors'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) -from dataclasses import dataclass, field -from enum import IntEnum -from typing import Any - -from a2a.client.errors import A2AClientTimeoutError - - -class A2APollingTimeoutError(A2AClientTimeoutError): - """Raised when polling exceeds the configured timeout.""" - - -class A2AErrorCode(IntEnum): - """A2A protocol error codes. - - Codes follow JSON-RPC 2.0 specification with A2A-specific extensions. 
- """ - - # JSON-RPC 2.0 Standard Errors (-32700 to -32600) - JSON_PARSE_ERROR = -32700 - """Invalid JSON was received by the server.""" - - INVALID_REQUEST = -32600 - """The JSON sent is not a valid Request object.""" - - METHOD_NOT_FOUND = -32601 - """The method does not exist / is not available.""" - - INVALID_PARAMS = -32602 - """Invalid method parameter(s).""" - - INTERNAL_ERROR = -32603 - """Internal JSON-RPC error.""" - - # A2A-Specific Errors (-32099 to -32000) - TASK_NOT_FOUND = -32001 - """The specified task was not found.""" - - TASK_NOT_CANCELABLE = -32002 - """The task cannot be canceled (already completed/failed).""" - - PUSH_NOTIFICATION_NOT_SUPPORTED = -32003 - """Push notifications are not supported by this agent.""" - - UNSUPPORTED_OPERATION = -32004 - """The requested operation is not supported.""" - - CONTENT_TYPE_NOT_SUPPORTED = -32005 - """Incompatible content types between client and server.""" - - INVALID_AGENT_RESPONSE = -32006 - """The agent produced an invalid response.""" - - # CrewAI Custom Extensions (-32768 to -32100) - UNSUPPORTED_VERSION = -32009 - """The requested A2A protocol version is not supported.""" - - UNSUPPORTED_EXTENSION = -32010 - """Client does not support required protocol extensions.""" - - AUTHENTICATION_REQUIRED = -32011 - """Authentication is required for this operation.""" - - AUTHORIZATION_FAILED = -32012 - """Authorization check failed (insufficient permissions).""" - - RATE_LIMIT_EXCEEDED = -32013 - """Rate limit exceeded for this client/operation.""" - - TASK_TIMEOUT = -32014 - """Task execution timed out.""" - - TRANSPORT_NEGOTIATION_FAILED = -32015 - """Failed to negotiate a compatible transport protocol.""" - - CONTEXT_NOT_FOUND = -32016 - """The specified context was not found.""" - - SKILL_NOT_FOUND = -32017 - """The specified skill was not found.""" - - ARTIFACT_NOT_FOUND = -32018 - """The specified artifact was not found.""" - - -# Error code to default message mapping -ERROR_MESSAGES: dict[int, str] = 
{ - A2AErrorCode.JSON_PARSE_ERROR: "Parse error", - A2AErrorCode.INVALID_REQUEST: "Invalid Request", - A2AErrorCode.METHOD_NOT_FOUND: "Method not found", - A2AErrorCode.INVALID_PARAMS: "Invalid params", - A2AErrorCode.INTERNAL_ERROR: "Internal error", - A2AErrorCode.TASK_NOT_FOUND: "Task not found", - A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable", - A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported", - A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported", - A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types", - A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response", - A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version", - A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions", - A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required", - A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed", - A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded", - A2AErrorCode.TASK_TIMEOUT: "Task execution timed out", - A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed", - A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found", - A2AErrorCode.SKILL_NOT_FOUND: "Skill not found", - A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found", -} - - -@dataclass -class A2AError(Exception): - """Base exception for A2A protocol errors. - - Attributes: - code: The A2A/JSON-RPC error code. - message: Human-readable error message. - data: Optional additional error data. 
- """ - - code: int - message: str | None = None - data: Any = None - - def __post_init__(self) -> None: - if self.message is None: - self.message = ERROR_MESSAGES.get(self.code, "Unknown error") - super().__init__(self.message) - - def to_dict(self) -> dict[str, Any]: - """Convert to JSON-RPC error object format.""" - error: dict[str, Any] = { - "code": self.code, - "message": self.message, - } - if self.data is not None: - error["data"] = self.data - return error - - def to_response(self, request_id: str | int | None = None) -> dict[str, Any]: - """Convert to full JSON-RPC error response.""" - return { - "jsonrpc": "2.0", - "error": self.to_dict(), - "id": request_id, - } - - -@dataclass -class JSONParseError(A2AError): - """Invalid JSON was received.""" - - code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False) - - -@dataclass -class InvalidRequestError(A2AError): - """The JSON sent is not a valid Request object.""" - - code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False) - - -@dataclass -class MethodNotFoundError(A2AError): - """The method does not exist / is not available.""" - - code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False) - method: str | None = None - - def __post_init__(self) -> None: - if self.message is None and self.method: - self.message = f"Method not found: {self.method}" - super().__post_init__() - - -@dataclass -class InvalidParamsError(A2AError): - """Invalid method parameter(s).""" - - code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False) - param: str | None = None - reason: str | None = None - - def __post_init__(self) -> None: - if self.message is None: - if self.param and self.reason: - self.message = f"Invalid parameter '{self.param}': {self.reason}" - elif self.param: - self.message = f"Invalid parameter: {self.param}" - super().__post_init__() - - -@dataclass -class InternalError(A2AError): - """Internal JSON-RPC error.""" - - code: int = 
field(default=A2AErrorCode.INTERNAL_ERROR, init=False) - - -@dataclass -class TaskNotFoundError(A2AError): - """The specified task was not found.""" - - code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False) - task_id: str | None = None - - def __post_init__(self) -> None: - if self.message is None and self.task_id: - self.message = f"Task not found: {self.task_id}" - super().__post_init__() - - -@dataclass -class TaskNotCancelableError(A2AError): - """The task cannot be canceled.""" - - code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False) - task_id: str | None = None - reason: str | None = None - - def __post_init__(self) -> None: - if self.message is None: - if self.task_id and self.reason: - self.message = f"Task {self.task_id} cannot be canceled: {self.reason}" - elif self.task_id: - self.message = f"Task {self.task_id} cannot be canceled" - super().__post_init__() - - -@dataclass -class PushNotificationNotSupportedError(A2AError): - """Push notifications are not supported.""" - - code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False) - - -@dataclass -class UnsupportedOperationError(A2AError): - """The requested operation is not supported.""" - - code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False) - operation: str | None = None - - def __post_init__(self) -> None: - if self.message is None and self.operation: - self.message = f"Operation not supported: {self.operation}" - super().__post_init__() - - -@dataclass -class ContentTypeNotSupportedError(A2AError): - """Incompatible content types.""" - - code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False) - requested_types: list[str] | None = None - supported_types: list[str] | None = None - - def __post_init__(self) -> None: - if self.message is None and self.requested_types and self.supported_types: - self.message = ( - f"Content type not supported. 
Requested: {self.requested_types}, " - f"Supported: {self.supported_types}" - ) - super().__post_init__() - - -@dataclass -class InvalidAgentResponseError(A2AError): - """The agent produced an invalid response.""" - - code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False) - - -@dataclass -class UnsupportedVersionError(A2AError): - """The requested A2A version is not supported.""" - - code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False) - requested_version: str | None = None - supported_versions: list[str] | None = None - - def __post_init__(self) -> None: - if self.message is None and self.requested_version: - msg = f"Unsupported A2A version: {self.requested_version}" - if self.supported_versions: - msg += f". Supported versions: {', '.join(self.supported_versions)}" - self.message = msg - super().__post_init__() - - -@dataclass -class UnsupportedExtensionError(A2AError): - """Client does not support required extensions.""" - - code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False) - required_extensions: list[str] | None = None - - def __post_init__(self) -> None: - if self.message is None and self.required_extensions: - self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}" - super().__post_init__() - - -@dataclass -class AuthenticationRequiredError(A2AError): - """Authentication is required.""" - - code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False) - - -@dataclass -class AuthorizationFailedError(A2AError): - """Authorization check failed.""" - - code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False) - required_scope: str | None = None - - def __post_init__(self) -> None: - if self.message is None and self.required_scope: - self.message = ( - f"Authorization failed. 
Required scope: {self.required_scope}" - ) - super().__post_init__() - - -@dataclass -class RateLimitExceededError(A2AError): - """Rate limit exceeded.""" - - code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False) - retry_after: int | None = None - - def __post_init__(self) -> None: - if self.message is None and self.retry_after: - self.message = ( - f"Rate limit exceeded. Retry after {self.retry_after} seconds" - ) - if self.retry_after: - self.data = {"retry_after": self.retry_after} - super().__post_init__() - - -@dataclass -class TaskTimeoutError(A2AError): - """Task execution timed out.""" - - code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False) - task_id: str | None = None - timeout_seconds: float | None = None - - def __post_init__(self) -> None: - if self.message is None: - if self.task_id and self.timeout_seconds: - self.message = ( - f"Task {self.task_id} timed out after {self.timeout_seconds}s" - ) - elif self.task_id: - self.message = f"Task {self.task_id} timed out" - super().__post_init__() - - -@dataclass -class TransportNegotiationFailedError(A2AError): - """Failed to negotiate a compatible transport protocol.""" - - code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False) - client_transports: list[str] | None = None - server_transports: list[str] | None = None - - def __post_init__(self) -> None: - if self.message is None and self.client_transports and self.server_transports: - self.message = ( - f"Transport negotiation failed. 
Client: {self.client_transports}, " - f"Server: {self.server_transports}" - ) - super().__post_init__() - - -@dataclass -class ContextNotFoundError(A2AError): - """The specified context was not found.""" - - code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False) - context_id: str | None = None - - def __post_init__(self) -> None: - if self.message is None and self.context_id: - self.message = f"Context not found: {self.context_id}" - super().__post_init__() - - -@dataclass -class SkillNotFoundError(A2AError): - """The specified skill was not found.""" - - code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False) - skill_id: str | None = None - - def __post_init__(self) -> None: - if self.message is None and self.skill_id: - self.message = f"Skill not found: {self.skill_id}" - super().__post_init__() - - -@dataclass -class ArtifactNotFoundError(A2AError): - """The specified artifact was not found.""" - - code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False) - artifact_id: str | None = None - - def __post_init__(self) -> None: - if self.message is None and self.artifact_id: - self.message = f"Artifact not found: {self.artifact_id}" - super().__post_init__() - - -def create_error_response( - code: int | A2AErrorCode, - message: str | None = None, - data: Any = None, - request_id: str | int | None = None, -) -> dict[str, Any]: - """Create a JSON-RPC error response. - - Args: - code: Error code (A2AErrorCode or int). - message: Optional error message (uses default if not provided). - data: Optional additional error data. - request_id: Request ID for correlation. - - Returns: - Dict in JSON-RPC error response format. - """ - error = A2AError(code=int(code), message=message, data=data) - return error.to_response(request_id) - - -def is_retryable_error(code: int) -> bool: - """Check if an error is potentially retryable. - - Args: - code: Error code to check. - - Returns: - True if the error might be resolved by retrying. 
- """ - retryable_codes = { - A2AErrorCode.INTERNAL_ERROR, - A2AErrorCode.RATE_LIMIT_EXCEEDED, - A2AErrorCode.TASK_TIMEOUT, - } - return code in retryable_codes - - -def is_client_error(code: int) -> bool: - """Check if an error is a client-side error. - - Args: - code: Error code to check. - - Returns: - True if the error is due to client request issues. - """ - client_error_codes = { - A2AErrorCode.JSON_PARSE_ERROR, - A2AErrorCode.INVALID_REQUEST, - A2AErrorCode.METHOD_NOT_FOUND, - A2AErrorCode.INVALID_PARAMS, - A2AErrorCode.TASK_NOT_FOUND, - A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, - A2AErrorCode.UNSUPPORTED_VERSION, - A2AErrorCode.UNSUPPORTED_EXTENSION, - A2AErrorCode.CONTEXT_NOT_FOUND, - A2AErrorCode.SKILL_NOT_FOUND, - A2AErrorCode.ARTIFACT_NOT_FOUND, - } - return code in client_error_codes +from crewai_a2a.errors import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/extensions/__init__.py b/lib/crewai/src/crewai/a2a/extensions/__init__.py index b21ae10ad..9c902367b 100644 --- a/lib/crewai/src/crewai/a2a/extensions/__init__.py +++ b/lib/crewai/src/crewai/a2a/extensions/__init__.py @@ -1,37 +1,13 @@ -"""A2A Protocol Extensions for CrewAI. +"""Backward-compatibility shim — use ``crewai_a2a.extensions`` instead.""" -This module contains extensions to the A2A (Agent-to-Agent) protocol. +import warnings -**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper -processes requests and responses during delegation to remote agents. These provide -hooks for tool injection, prompt augmentation, and response processing. -**Server-side extensions** (ServerExtension) allow agents to offer additional -functionality beyond the core A2A specification. Clients activate extensions -via the X-A2A-Extensions header. 
- -See: https://a2a-protocol.org/latest/topics/extensions/ -""" - -from crewai.a2a.extensions.base import ( - A2AExtension, - ConversationState, - ExtensionRegistry, - ValidatedA2AExtension, -) -from crewai.a2a.extensions.server import ( - ExtensionContext, - ServerExtension, - ServerExtensionRegistry, +warnings.warn( + "'crewai.a2a.extensions' has been moved to 'crewai_a2a.extensions'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -__all__ = [ - "A2AExtension", - "ConversationState", - "ExtensionContext", - "ExtensionRegistry", - "ServerExtension", - "ServerExtensionRegistry", - "ValidatedA2AExtension", -] +from crewai_a2a.extensions import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/extensions/base.py b/lib/crewai/src/crewai/a2a/extensions/base.py index 2d7a81a22..f94c675c7 100644 --- a/lib/crewai/src/crewai/a2a/extensions/base.py +++ b/lib/crewai/src/crewai/a2a/extensions/base.py @@ -1,238 +1,13 @@ -"""Base extension interface for CrewAI A2A wrapper processing hooks. +"""Backward-compatibility shim — use ``crewai_a2a.extensions.base`` instead.""" -This module defines the protocol for extending CrewAI's A2A wrapper functionality -with custom logic for tool injection, prompt augmentation, and response processing. - -Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions. -A2A protocol extensions are capability declarations using AgentExtension objects -in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header. -See: https://a2a-protocol.org/latest/topics/extensions/ -""" - -from __future__ import annotations - -from collections.abc import Sequence -from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable - -from pydantic import BeforeValidator +import warnings -if TYPE_CHECKING: - from a2a.types import Message +warnings.warn( + "'crewai.a2a.extensions.base' has been moved to 'crewai_a2a.extensions.base'. 
" + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) - from crewai.agent.core import Agent - - -def _validate_a2a_extension(v: Any) -> Any: - """Validate that value implements A2AExtension protocol.""" - if not isinstance(v, A2AExtension): - raise ValueError( - f"Value must implement A2AExtension protocol. " - f"Got {type(v).__name__} which is missing required methods." - ) - return v - - -ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)] - - -@runtime_checkable -class ConversationState(Protocol): - """Protocol for extension-specific conversation state. - - Extensions can define their own state classes that implement this protocol - to track conversation-specific data extracted from message history. - """ - - def is_ready(self) -> bool: - """Check if the state indicates readiness for some action. - - Returns: - True if the state is ready, False otherwise. - """ - ... - - -@runtime_checkable -class A2AExtension(Protocol): - """Protocol for A2A wrapper extensions. - - Extensions can implement this protocol to inject custom logic into - the A2A conversation flow at various integration points. - - Example: - class MyExtension: - def inject_tools(self, agent: Agent) -> None: - # Add custom tools to the agent - pass - - def extract_state_from_history( - self, conversation_history: Sequence[Message] - ) -> ConversationState | None: - # Extract state from conversation - return None - - def augment_prompt( - self, base_prompt: str, conversation_state: ConversationState | None - ) -> str: - # Add custom instructions - return base_prompt - - def process_response( - self, agent_response: Any, conversation_state: ConversationState | None - ) -> Any: - # Modify response if needed - return agent_response - """ - - def inject_tools(self, agent: Agent) -> None: - """Inject extension-specific tools into the agent. - - Called when an agent is wrapped with A2A capabilities. 
Extensions - can add tools that enable extension-specific functionality. - - Args: - agent: The agent instance to inject tools into. - """ - ... - - def extract_state_from_history( - self, conversation_history: Sequence[Message] - ) -> ConversationState | None: - """Extract extension-specific state from conversation history. - - Called during prompt augmentation to allow extensions to analyze - the conversation history and extract relevant state information. - - Args: - conversation_history: The sequence of A2A messages exchanged. - - Returns: - Extension-specific conversation state, or None if no relevant state. - """ - ... - - def augment_prompt( - self, - base_prompt: str, - conversation_state: ConversationState | None, - ) -> str: - """Augment the task prompt with extension-specific instructions. - - Called during prompt augmentation to allow extensions to add - custom instructions based on conversation state. - - Args: - base_prompt: The base prompt to augment. - conversation_state: Extension-specific state from extract_state_from_history. - - Returns: - The augmented prompt with extension-specific instructions. - """ - ... - - def process_response( - self, - agent_response: Any, - conversation_state: ConversationState | None, - ) -> Any: - """Process and potentially modify the agent response. - - Called after parsing the agent's response, allowing extensions to - enhance or modify the response based on conversation state. - - Args: - agent_response: The parsed agent response. - conversation_state: Extension-specific state from extract_state_from_history. - - Returns: - The processed agent response (may be modified or original). - """ - ... - - -class ExtensionRegistry: - """Registry for managing A2A extensions. - - Maintains a collection of extensions and provides methods to invoke - their hooks at various integration points. 
- """ - - def __init__(self) -> None: - """Initialize the extension registry.""" - self._extensions: list[A2AExtension] = [] - - def register(self, extension: A2AExtension) -> None: - """Register an extension. - - Args: - extension: The extension to register. - """ - self._extensions.append(extension) - - def inject_all_tools(self, agent: Agent) -> None: - """Inject tools from all registered extensions. - - Args: - agent: The agent instance to inject tools into. - """ - for extension in self._extensions: - extension.inject_tools(agent) - - def extract_all_states( - self, conversation_history: Sequence[Message] - ) -> dict[type[A2AExtension], ConversationState]: - """Extract conversation states from all registered extensions. - - Args: - conversation_history: The sequence of A2A messages exchanged. - - Returns: - Mapping of extension types to their conversation states. - """ - states: dict[type[A2AExtension], ConversationState] = {} - for extension in self._extensions: - state = extension.extract_state_from_history(conversation_history) - if state is not None: - states[type(extension)] = state - return states - - def augment_prompt_with_all( - self, - base_prompt: str, - extension_states: dict[type[A2AExtension], ConversationState], - ) -> str: - """Augment prompt with instructions from all registered extensions. - - Args: - base_prompt: The base prompt to augment. - extension_states: Mapping of extension types to conversation states. - - Returns: - The fully augmented prompt. - """ - augmented = base_prompt - for extension in self._extensions: - state = extension_states.get(type(extension)) - augmented = extension.augment_prompt(augmented, state) - return augmented - - def process_response_with_all( - self, - agent_response: Any, - extension_states: dict[type[A2AExtension], ConversationState], - ) -> Any: - """Process response through all registered extensions. - - Args: - agent_response: The parsed agent response. 
- extension_states: Mapping of extension types to conversation states. - - Returns: - The processed agent response. - """ - processed = agent_response - for extension in self._extensions: - state = extension_states.get(type(extension)) - processed = extension.process_response(processed, state) - return processed +from crewai_a2a.extensions.base import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/extensions/registry.py b/lib/crewai/src/crewai/a2a/extensions/registry.py index 4d195961b..f81d0db36 100644 --- a/lib/crewai/src/crewai/a2a/extensions/registry.py +++ b/lib/crewai/src/crewai/a2a/extensions/registry.py @@ -1,170 +1,13 @@ -"""A2A Protocol extension utilities. +"""Backward-compatibility shim — use ``crewai_a2a.extensions.registry`` instead.""" -This module provides utilities for working with A2A protocol extensions as -defined in the A2A specification. Extensions are capability declarations in -AgentCard.capabilities.extensions using AgentExtension objects, activated -via the X-A2A-Extensions HTTP header. +import warnings -See: https://a2a-protocol.org/latest/topics/extensions/ -""" -from __future__ import annotations - -from typing import Any - -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.extensions.common import ( - HTTP_EXTENSION_HEADER, +warnings.warn( + "'crewai.a2a.extensions.registry' has been moved to 'crewai_a2a.extensions.registry'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) -from a2a.types import AgentCard, AgentExtension -from crewai.a2a.config import A2AClientConfig, A2AConfig -from crewai.a2a.extensions.base import ExtensionRegistry - - -def get_extensions_from_config( - a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig, -) -> list[str]: - """Extract extension URIs from A2A configuration. - - Args: - a2a_config: A2A configuration (single or list). 
- - Returns: - Deduplicated list of extension URIs from all configs. - """ - configs = a2a_config if isinstance(a2a_config, list) else [a2a_config] - seen: set[str] = set() - result: list[str] = [] - - for config in configs: - if not isinstance(config, A2AClientConfig): - continue - for uri in config.extensions: - if uri not in seen: - seen.add(uri) - result.append(uri) - - return result - - -class ExtensionsMiddleware(ClientCallInterceptor): - """Middleware to add X-A2A-Extensions header to requests. - - This middleware adds the extensions header to all outgoing requests, - declaring which A2A protocol extensions the client supports. - """ - - def __init__(self, extensions: list[str]) -> None: - """Initialize with extension URIs. - - Args: - extensions: List of extension URIs the client supports. - """ - self._extensions = extensions - - async def intercept( - self, - method_name: str, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any], - agent_card: AgentCard | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Add extensions header to the request. - - Args: - method_name: The A2A method being called. - request_payload: The JSON-RPC request payload. - http_kwargs: HTTP request kwargs (headers, etc). - agent_card: The target agent's card. - context: Optional call context. - - Returns: - Tuple of (request_payload, modified_http_kwargs). - """ - if self._extensions: - headers = http_kwargs.setdefault("headers", {}) - headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions) - return request_payload, http_kwargs - - -def validate_required_extensions( - agent_card: AgentCard, - client_extensions: list[str] | None, -) -> list[AgentExtension]: - """Validate that client supports all required extensions from agent. - - Args: - agent_card: The agent's card with declared extensions. - client_extensions: Extension URIs the client supports. - - Returns: - List of unsupported required extensions. 
- - Raises: - None - returns list of unsupported extensions for caller to handle. - """ - unsupported: list[AgentExtension] = [] - client_set = set(client_extensions or []) - - if not agent_card.capabilities or not agent_card.capabilities.extensions: - return unsupported - - unsupported.extend( - ext - for ext in agent_card.capabilities.extensions - if ext.required and ext.uri not in client_set - ) - - return unsupported - - -def create_extension_registry_from_config( - a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig, -) -> ExtensionRegistry: - """Create an extension registry from A2A client configuration. - - Extracts client_extensions from each A2AClientConfig and registers them - with the ExtensionRegistry. These extensions provide CrewAI-specific - processing hooks (tool injection, prompt augmentation, response processing). - - Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header) - are handled separately via get_extensions_from_config() and ExtensionsMiddleware. - - Args: - a2a_config: A2A configuration (single or list). - - Returns: - Extension registry with all client_extensions registered. 
- - Example: - class LoggingExtension: - def inject_tools(self, agent): pass - def extract_state_from_history(self, history): return None - def augment_prompt(self, prompt, state): return prompt - def process_response(self, response, state): - print(f"Response: {response}") - return response - - config = A2AClientConfig( - endpoint="https://agent.example.com", - client_extensions=[LoggingExtension()], - ) - registry = create_extension_registry_from_config(config) - """ - registry = ExtensionRegistry() - configs = a2a_config if isinstance(a2a_config, list) else [a2a_config] - - seen: set[int] = set() - - for config in configs: - if isinstance(config, (A2AConfig, A2AClientConfig)): - client_exts = getattr(config, "client_extensions", []) - for extension in client_exts: - ext_id = id(extension) - if ext_id not in seen: - seen.add(ext_id) - registry.register(extension) - - return registry +from crewai_a2a.extensions.registry import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/extensions/server.py b/lib/crewai/src/crewai/a2a/extensions/server.py index 9bbc9c08b..67d10fee4 100644 --- a/lib/crewai/src/crewai/a2a/extensions/server.py +++ b/lib/crewai/src/crewai/a2a/extensions/server.py @@ -1,305 +1,13 @@ -"""A2A protocol server extensions for CrewAI agents. +"""Backward-compatibility shim — use ``crewai_a2a.extensions.server`` instead.""" -This module provides the base class and context for implementing A2A protocol -extensions on the server side. Extensions allow agents to offer additional -functionality beyond the core A2A specification. +import warnings -See: https://a2a-protocol.org/latest/topics/extensions/ -""" -from __future__ import annotations +warnings.warn( + "'crewai.a2a.extensions.server' has been moved to 'crewai_a2a.extensions.server'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -import logging -from typing import TYPE_CHECKING, Annotated, Any - -from a2a.types import AgentExtension -from pydantic_core import CoreSchema, core_schema - - -if TYPE_CHECKING: - from a2a.server.context import ServerCallContext - from pydantic import GetCoreSchemaHandler - - -logger = logging.getLogger(__name__) - - -@dataclass -class ExtensionContext: - """Context passed to extension hooks during request processing. - - Provides access to request metadata, client extensions, and shared state - that extensions can read from and write to. - - Attributes: - metadata: Request metadata dict, includes extension-namespaced keys. - client_extensions: Set of extension URIs the client declared support for. - state: Mutable dict for extensions to share data during request lifecycle. - server_context: The underlying A2A server call context. - """ - - metadata: dict[str, Any] - client_extensions: set[str] - state: dict[str, Any] = field(default_factory=dict) - server_context: ServerCallContext | None = None - - def get_extension_metadata(self, uri: str, key: str) -> Any | None: - """Get extension-specific metadata value. - - Extension metadata uses namespaced keys in the format: - "{extension_uri}/{key}" - - Args: - uri: The extension URI. - key: The metadata key within the extension namespace. - - Returns: - The metadata value, or None if not present. - """ - full_key = f"{uri}/{key}" - return self.metadata.get(full_key) - - def set_extension_metadata(self, uri: str, key: str, value: Any) -> None: - """Set extension-specific metadata value. - - Args: - uri: The extension URI. - key: The metadata key within the extension namespace. - value: The value to set. - """ - full_key = f"{uri}/{key}" - self.metadata[full_key] = value - - -class ServerExtension(ABC): - """Base class for A2A protocol server extensions. 
- - Subclass this to create custom extensions that modify agent behavior - when clients activate them. Extensions are identified by URI and can - be marked as required. - - Example: - class SamplingExtension(ServerExtension): - uri = "urn:crewai:ext:sampling/v1" - required = True - - def __init__(self, max_tokens: int = 4096): - self.max_tokens = max_tokens - - @property - def params(self) -> dict[str, Any]: - return {"max_tokens": self.max_tokens} - - async def on_request(self, context: ExtensionContext) -> None: - limit = context.get_extension_metadata(self.uri, "limit") - if limit: - context.state["token_limit"] = int(limit) - - async def on_response(self, context: ExtensionContext, result: Any) -> Any: - return result - """ - - uri: Annotated[str, "Extension URI identifier. Must be unique."] - required: Annotated[bool, "Whether clients must support this extension."] = False - description: Annotated[ - str | None, "Human-readable description of the extension." - ] = None - - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, - ) -> CoreSchema: - """Tell Pydantic how to validate ServerExtension instances.""" - return core_schema.is_instance_schema(cls) - - @property - def params(self) -> dict[str, Any] | None: - """Extension parameters to advertise in AgentCard. - - Override this property to expose configuration that clients can read. - - Returns: - Dict of parameter names to values, or None. - """ - return None - - def agent_extension(self) -> AgentExtension: - """Generate the AgentExtension object for the AgentCard. - - Returns: - AgentExtension with this extension's URI, required flag, and params. - """ - return AgentExtension( - uri=self.uri, - required=self.required if self.required else None, - description=self.description, - params=self.params, - ) - - def is_active(self, context: ExtensionContext) -> bool: - """Check if this extension is active for the current request. 
- - An extension is active if the client declared support for it. - - Args: - context: The extension context for the current request. - - Returns: - True if the client supports this extension. - """ - return self.uri in context.client_extensions - - @abstractmethod - async def on_request(self, context: ExtensionContext) -> None: - """Called before agent execution if extension is active. - - Use this hook to: - - Read extension-specific metadata from the request - - Set up state for the execution - - Modify execution parameters via context.state - - Args: - context: The extension context with request metadata and state. - """ - ... - - @abstractmethod - async def on_response(self, context: ExtensionContext, result: Any) -> Any: - """Called after agent execution if extension is active. - - Use this hook to: - - Modify or enhance the result - - Add extension-specific metadata to the response - - Clean up any resources - - Args: - context: The extension context with request metadata and state. - result: The agent execution result. - - Returns: - The result, potentially modified. - """ - ... - - -class ServerExtensionRegistry: - """Registry for managing server-side A2A protocol extensions. - - Collects extensions and provides methods to generate AgentCapabilities - and invoke extension hooks during request processing. - """ - - def __init__(self, extensions: list[ServerExtension] | None = None) -> None: - """Initialize the registry with optional extensions. - - Args: - extensions: Initial list of extensions to register. - """ - self._extensions: list[ServerExtension] = list(extensions) if extensions else [] - self._by_uri: dict[str, ServerExtension] = { - ext.uri: ext for ext in self._extensions - } - - def register(self, extension: ServerExtension) -> None: - """Register an extension. - - Args: - extension: The extension to register. - - Raises: - ValueError: If an extension with the same URI is already registered. 
- """ - if extension.uri in self._by_uri: - raise ValueError(f"Extension already registered: {extension.uri}") - self._extensions.append(extension) - self._by_uri[extension.uri] = extension - - def get_agent_extensions(self) -> list[AgentExtension]: - """Get AgentExtension objects for all registered extensions. - - Returns: - List of AgentExtension objects for the AgentCard. - """ - return [ext.agent_extension() for ext in self._extensions] - - def get_extension(self, uri: str) -> ServerExtension | None: - """Get an extension by URI. - - Args: - uri: The extension URI. - - Returns: - The extension, or None if not found. - """ - return self._by_uri.get(uri) - - @staticmethod - def create_context( - metadata: dict[str, Any], - client_extensions: set[str], - server_context: ServerCallContext | None = None, - ) -> ExtensionContext: - """Create an ExtensionContext for a request. - - Args: - metadata: Request metadata dict. - client_extensions: Set of extension URIs from client. - server_context: Optional server call context. - - Returns: - ExtensionContext for use in hooks. - """ - return ExtensionContext( - metadata=metadata, - client_extensions=client_extensions, - server_context=server_context, - ) - - async def invoke_on_request(self, context: ExtensionContext) -> None: - """Invoke on_request hooks for all active extensions. - - Tracks activated extensions and isolates errors from individual hooks. - - Args: - context: The extension context for the request. - """ - for extension in self._extensions: - if extension.is_active(context): - try: - await extension.on_request(context) - if context.server_context is not None: - context.server_context.activated_extensions.add(extension.uri) - except Exception: - logger.exception( - "Extension on_request hook failed", - extra={"extension": extension.uri}, - ) - - async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any: - """Invoke on_response hooks for all active extensions. 
- - Isolates errors from individual hooks to prevent one failing extension - from breaking the entire response. - - Args: - context: The extension context for the request. - result: The agent execution result. - - Returns: - The result after all extensions have processed it. - """ - processed = result - for extension in self._extensions: - if extension.is_active(context): - try: - processed = await extension.on_response(context, processed) - except Exception: - logger.exception( - "Extension on_response hook failed", - extra={"extension": extension.uri}, - ) - return processed +from crewai_a2a.extensions.server import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/task_helpers.py b/lib/crewai/src/crewai/a2a/task_helpers.py index b4a758656..cef3a7944 100644 --- a/lib/crewai/src/crewai/a2a/task_helpers.py +++ b/lib/crewai/src/crewai/a2a/task_helpers.py @@ -1,480 +1,13 @@ -"""Helper functions for processing A2A task results.""" +"""Backward-compatibility shim — use ``crewai_a2a.task_helpers`` instead.""" -from __future__ import annotations +import warnings -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, TypedDict -import uuid -from a2a.client.errors import A2AClientHTTPError -from a2a.types import ( - AgentCard, - Message, - Part, - Role, - Task, - TaskArtifactUpdateEvent, - TaskState, - TaskStatusUpdateEvent, - TextPart, -) -from typing_extensions import NotRequired - -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import ( - A2AConnectionErrorEvent, - A2AResponseReceivedEvent, +warnings.warn( + "'crewai.a2a.task_helpers' has been moved to 'crewai_a2a.task_helpers'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -if TYPE_CHECKING: - from a2a.types import Task as A2ATask - -SendMessageEvent = ( - tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message -) - - -TERMINAL_STATES: frozenset[TaskState] = frozenset( - { - TaskState.completed, - TaskState.failed, - TaskState.rejected, - TaskState.canceled, - } -) - -ACTIONABLE_STATES: frozenset[TaskState] = frozenset( - { - TaskState.input_required, - TaskState.auth_required, - } -) - -PENDING_STATES: frozenset[TaskState] = frozenset( - { - TaskState.submitted, - TaskState.working, - } -) - - -class TaskStateResult(TypedDict): - """Result dictionary from processing A2A task state.""" - - status: TaskState - history: list[Message] - result: NotRequired[str] - error: NotRequired[str] - agent_card: NotRequired[dict[str, Any]] - a2a_agent_name: NotRequired[str | None] - - -def extract_task_result_parts(a2a_task: A2ATask) -> list[str]: - """Extract result parts from A2A task status message, history, and artifacts. - - Args: - a2a_task: A2A Task object with status, history, and artifacts - - Returns: - List of result text parts - """ - result_parts: list[str] = [] - - if a2a_task.status and a2a_task.status.message: - msg = a2a_task.status.message - result_parts.extend( - part.root.text for part in msg.parts if part.root.kind == "text" - ) - - if not result_parts and a2a_task.history: - for history_msg in reversed(a2a_task.history): - if history_msg.role == Role.agent: - result_parts.extend( - part.root.text - for part in history_msg.parts - if part.root.kind == "text" - ) - break - - if a2a_task.artifacts: - result_parts.extend( - part.root.text - for artifact in a2a_task.artifacts - for part in artifact.parts - if part.root.kind == "text" - ) - - return result_parts - - -def extract_error_message(a2a_task: A2ATask, default: str) -> str: - """Extract error message from A2A task. 
- - Args: - a2a_task: A2A Task object - default: Default message if no error found - - Returns: - Error message string - """ - if a2a_task.status and a2a_task.status.message: - msg = a2a_task.status.message - if msg: - for part in msg.parts: - if part.root.kind == "text": - return str(part.root.text) - return str(msg) - - if a2a_task.history: - for history_msg in reversed(a2a_task.history): - for part in history_msg.parts: - if part.root.kind == "text": - return str(part.root.text) - - return default - - -def process_task_state( - a2a_task: A2ATask, - new_messages: list[Message], - agent_card: AgentCard, - turn_number: int, - is_multiturn: bool, - agent_role: str | None, - result_parts: list[str] | None = None, - endpoint: str | None = None, - a2a_agent_name: str | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, - is_final: bool = True, -) -> TaskStateResult | None: - """Process A2A task state and return result dictionary. - - Shared logic for both polling and streaming handlers. - - Args: - a2a_task: The A2A task to process. - new_messages: List to collect messages (modified in place). - agent_card: The agent card. - turn_number: Current turn number. - is_multiturn: Whether multi-turn conversation. - agent_role: Agent role for logging. - result_parts: Accumulated result parts (streaming passes accumulated, - polling passes None to extract from task). - endpoint: A2A agent endpoint URL. - a2a_agent_name: Name of the A2A agent from agent card. - from_task: Optional CrewAI Task for event metadata. - from_agent: Optional CrewAI Agent for event metadata. - is_final: Whether this is the final response in the stream. - - Returns: - Result dictionary if terminal/actionable state, None otherwise. 
- """ - if result_parts is None: - result_parts = [] - - if a2a_task.status.state == TaskState.completed: - if not result_parts: - extracted_parts = extract_task_result_parts(a2a_task) - result_parts.extend(extracted_parts) - if a2a_task.history: - new_messages.extend(a2a_task.history) - - response_text = " ".join(result_parts) if result_parts else "" - message_id = None - if a2a_task.status and a2a_task.status.message: - message_id = a2a_task.status.message.message_id - crewai_event_bus.emit( - None, - A2AResponseReceivedEvent( - response=response_text, - turn_number=turn_number, - context_id=a2a_task.context_id, - message_id=message_id, - is_multiturn=is_multiturn, - status="completed", - final=is_final, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - - return TaskStateResult( - status=TaskState.completed, - agent_card=agent_card.model_dump(exclude_none=True), - result=response_text, - history=new_messages, - ) - - if a2a_task.status.state == TaskState.input_required: - if a2a_task.history: - new_messages.extend(a2a_task.history) - - response_text = extract_error_message(a2a_task, "Additional input required") - if response_text and not a2a_task.history: - agent_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=response_text))], - context_id=a2a_task.context_id, - task_id=a2a_task.id, - ) - new_messages.append(agent_message) - - input_message_id = None - if a2a_task.status and a2a_task.status.message: - input_message_id = a2a_task.status.message.message_id - crewai_event_bus.emit( - None, - A2AResponseReceivedEvent( - response=response_text, - turn_number=turn_number, - context_id=a2a_task.context_id, - message_id=input_message_id, - is_multiturn=is_multiturn, - status="input_required", - final=is_final, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - 
from_agent=from_agent, - ), - ) - - return TaskStateResult( - status=TaskState.input_required, - error=response_text, - history=new_messages, - agent_card=agent_card.model_dump(exclude_none=True), - ) - - if a2a_task.status.state in {TaskState.failed, TaskState.rejected}: - error_msg = extract_error_message(a2a_task, "Task failed without error message") - if a2a_task.history: - new_messages.extend(a2a_task.history) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - if a2a_task.status.state == TaskState.auth_required: - error_msg = extract_error_message(a2a_task, "Authentication required") - return TaskStateResult( - status=TaskState.auth_required, - error=error_msg, - history=new_messages, - ) - - if a2a_task.status.state == TaskState.canceled: - error_msg = extract_error_message(a2a_task, "Task was canceled") - return TaskStateResult( - status=TaskState.canceled, - error=error_msg, - history=new_messages, - ) - - if a2a_task.status.state in PENDING_STATES: - return None - - return None - - -async def send_message_and_get_task_id( - event_stream: AsyncIterator[SendMessageEvent], - new_messages: list[Message], - agent_card: AgentCard, - turn_number: int, - is_multiturn: bool, - agent_role: str | None, - from_task: Any | None = None, - from_agent: Any | None = None, - endpoint: str | None = None, - a2a_agent_name: str | None = None, - context_id: str | None = None, -) -> str | TaskStateResult: - """Send message and process initial response. 
- - Handles the common pattern of sending a message and either: - - Getting an immediate Message response (task completed synchronously) - - Getting a Task that needs polling/waiting for completion - - Args: - event_stream: Async iterator from client.send_message() - new_messages: List to collect messages (modified in place) - agent_card: The agent card - turn_number: Current turn number - is_multiturn: Whether multi-turn conversation - agent_role: Agent role for logging - from_task: Optional CrewAI Task object for event metadata. - from_agent: Optional CrewAI Agent object for event metadata. - endpoint: Optional A2A endpoint URL. - a2a_agent_name: Optional A2A agent name. - context_id: Optional A2A context ID for correlation. - - Returns: - Task ID string if agent needs polling/waiting, or TaskStateResult if done. - """ - try: - async for event in event_stream: - if isinstance(event, Message): - new_messages.append(event) - result_parts = [ - part.root.text for part in event.parts if part.root.kind == "text" - ] - response_text = " ".join(result_parts) if result_parts else "" - - crewai_event_bus.emit( - None, - A2AResponseReceivedEvent( - response=response_text, - turn_number=turn_number, - context_id=event.context_id, - message_id=event.message_id, - is_multiturn=is_multiturn, - status="completed", - final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - - return TaskStateResult( - status=TaskState.completed, - result=response_text, - history=new_messages, - agent_card=agent_card.model_dump(exclude_none=True), - ) - - if isinstance(event, tuple): - a2a_task, _ = event - - if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES: - result = process_task_state( - a2a_task=a2a_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - endpoint=endpoint, - 
a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ) - if result: - return result - - return a2a_task.id - - return TaskStateResult( - status=TaskState.failed, - error="No task ID received from initial message", - history=new_messages, - ) - - except A2AClientHTTPError as e: - error_msg = f"HTTP Error {e.status_code}: {e!s}" - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=context_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - None, - A2AConnectionErrorEvent( - endpoint=endpoint or "", - error=str(e), - error_type="http_error", - status_code=e.status_code, - a2a_agent_name=a2a_agent_name, - operation="send_message", - context_id=context_id, - from_task=from_task, - from_agent=from_agent, - ), - ) - crewai_event_bus.emit( - None, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=turn_number, - context_id=context_id, - is_multiturn=is_multiturn, - status="failed", - final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - except Exception as e: - error_msg = f"Unexpected error during send_message: {e!s}" - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=context_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - None, - A2AConnectionErrorEvent( - endpoint=endpoint or "", - error=str(e), - error_type="unexpected_error", - a2a_agent_name=a2a_agent_name, - operation="send_message", - context_id=context_id, - from_task=from_task, - from_agent=from_agent, - ), - ) - crewai_event_bus.emit( - None, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=turn_number, - context_id=context_id, - 
is_multiturn=is_multiturn, - status="failed", - final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - finally: - aclose = getattr(event_stream, "aclose", None) - if aclose: - await aclose() +from crewai_a2a.task_helpers import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/templates.py b/lib/crewai/src/crewai/a2a/templates.py index 16f0c479e..c73b6f843 100644 --- a/lib/crewai/src/crewai/a2a/templates.py +++ b/lib/crewai/src/crewai/a2a/templates.py @@ -1,55 +1,13 @@ -"""String templates for A2A (Agent-to-Agent) protocol messaging and status.""" +"""Backward-compatibility shim — use ``crewai_a2a.templates`` instead.""" -from string import Template -from typing import Final +import warnings -AVAILABLE_AGENTS_TEMPLATE: Final[Template] = Template( - "\n\n $available_a2a_agents\n\n" +warnings.warn( + "'crewai.a2a.templates' has been moved to 'crewai_a2a.templates'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) -PREVIOUS_A2A_CONVERSATION_TEMPLATE: Final[Template] = Template( - "\n\n" - " $previous_a2a_conversation" - "\n\n" -) -CONVERSATION_TURN_INFO_TEMPLATE: Final[Template] = Template( - "\n\n" - ' turn="$turn_count"\n' - ' max_turns="$max_turns"\n' - " $warning" - "\n\n" -) -UNAVAILABLE_AGENTS_NOTICE_TEMPLATE: Final[Template] = Template( - "\n\n" - " NOTE: A2A agents were configured but are currently unavailable.\n" - " You cannot delegate to remote agents for this task.\n\n" - " Unavailable Agents:\n" - " $unavailable_agents" - "\n\n" -) -REMOTE_AGENT_COMPLETED_NOTICE: Final[str] = """ - -STATUS: COMPLETED -The remote agent has finished processing your request. Their response is in the conversation history above. -You MUST now: -1. Extract the answer from the conversation history -2. 
Set is_a2a=false -3. Return the answer as your final message -DO NOT send another request - the task is already done. - -""" -REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """ - -STATUS: RESPONSE_RECEIVED -The remote agent has responded. Their response is in the conversation history above. - -You MUST now: -1. Set is_a2a=false (the remote task is complete and cannot receive more messages) -2. Provide YOUR OWN response to the original task based on the information received - -IMPORTANT: Your response should be addressed to the USER who gave you the original task. -Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that..."). -Do NOT address the remote agent directly or use "you" to refer to them. - -""" +from crewai_a2a.templates import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/types.py b/lib/crewai/src/crewai/a2a/types.py index 5a4a7672a..b3e8293da 100644 --- a/lib/crewai/src/crewai/a2a/types.py +++ b/lib/crewai/src/crewai/a2a/types.py @@ -1,104 +1,13 @@ -"""Type definitions for A2A protocol message parts.""" +"""Backward-compatibility shim — use ``crewai_a2a.types`` instead.""" -from __future__ import annotations +import warnings -from typing import ( - Annotated, - Any, - Literal, - Protocol, - TypedDict, - runtime_checkable, + +warnings.warn( + "'crewai.a2a.types' has been moved to 'crewai_a2a.types'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) -from pydantic import BeforeValidator, HttpUrl, TypeAdapter -from typing_extensions import NotRequired - - -try: - from crewai.a2a.updates import ( - PollingConfig, - PollingHandler, - PushNotificationConfig, - PushNotificationHandler, - StreamingConfig, - StreamingHandler, - UpdateConfig, - ) -except ImportError: - PollingConfig = Any # type: ignore[misc,assignment] - PollingHandler = Any # type: ignore[misc,assignment] - PushNotificationConfig = Any # type: ignore[misc,assignment] - PushNotificationHandler = Any # type: ignore[misc,assignment] - StreamingConfig = Any # type: ignore[misc,assignment] - StreamingHandler = Any # type: ignore[misc,assignment] - UpdateConfig = Any # type: ignore[misc,assignment] - - -TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"] -ProtocolVersion = Literal[ - "0.2.0", - "0.2.1", - "0.2.2", - "0.2.3", - "0.2.4", - "0.2.5", - "0.2.6", - "0.3.0", - "0.4.0", -] - -http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl) - -Url = Annotated[ - str, - BeforeValidator( - lambda value: str(http_url_adapter.validate_python(value, strict=True)) - ), -] - - -@runtime_checkable -class AgentResponseProtocol(Protocol): - """Protocol for the dynamically created AgentResponse model.""" - - a2a_ids: tuple[str, ...] - message: str - is_a2a: bool - - -class PartsMetadataDict(TypedDict, total=False): - """Metadata for A2A message parts. - - Attributes: - mimeType: MIME type for the part content. - schema: JSON schema for the part content. - """ - - mimeType: Literal["application/json"] - schema: dict[str, Any] - - -class PartsDict(TypedDict): - """A2A message part containing text and optional metadata. - - Attributes: - text: The text content of the message part. - metadata: Optional metadata describing the part content. 
- """ - - text: str - metadata: NotRequired[PartsMetadataDict] - - -PollingHandlerType = type[PollingHandler] -StreamingHandlerType = type[StreamingHandler] -PushNotificationHandlerType = type[PushNotificationHandler] - -HandlerType = PollingHandlerType | StreamingHandlerType | PushNotificationHandlerType - -HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = { - PollingConfig: PollingHandler, - StreamingConfig: StreamingHandler, - PushNotificationConfig: PushNotificationHandler, -} +from crewai_a2a.types import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/__init__.py b/lib/crewai/src/crewai/a2a/updates/__init__.py index 953eb48c3..75a310bac 100644 --- a/lib/crewai/src/crewai/a2a/updates/__init__.py +++ b/lib/crewai/src/crewai/a2a/updates/__init__.py @@ -1,35 +1,13 @@ -"""A2A update mechanism configuration types.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates`` instead.""" -from crewai.a2a.updates.base import ( - BaseHandlerKwargs, - PollingHandlerKwargs, - PushNotificationHandlerKwargs, - PushNotificationResultStore, - StreamingHandlerKwargs, - UpdateHandler, +import warnings + + +warnings.warn( + "'crewai.a2a.updates' has been moved to 'crewai_a2a.updates'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) -from crewai.a2a.updates.polling.config import PollingConfig -from crewai.a2a.updates.polling.handler import PollingHandler -from crewai.a2a.updates.push_notifications.config import PushNotificationConfig -from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler -from crewai.a2a.updates.streaming.config import StreamingConfig -from crewai.a2a.updates.streaming.handler import StreamingHandler - -UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig - -__all__ = [ - "BaseHandlerKwargs", - "PollingConfig", - "PollingHandler", - "PollingHandlerKwargs", - "PushNotificationConfig", - "PushNotificationHandler", - "PushNotificationHandlerKwargs", - "PushNotificationResultStore", - "StreamingConfig", - "StreamingHandler", - "StreamingHandlerKwargs", - "UpdateConfig", - "UpdateHandler", -] +from crewai_a2a.updates import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/base.py b/lib/crewai/src/crewai/a2a/updates/base.py index 8a6a53aa3..2dea408ca 100644 --- a/lib/crewai/src/crewai/a2a/updates/base.py +++ b/lib/crewai/src/crewai/a2a/updates/base.py @@ -1,176 +1,13 @@ -"""Base types for A2A update mechanism handlers.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.base`` instead.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict - -from pydantic import GetCoreSchemaHandler -from pydantic_core import CoreSchema, core_schema +import warnings -class CommonParams(NamedTuple): - """Common parameters shared across all update handlers. +warnings.warn( + "'crewai.a2a.updates.base' has been moved to 'crewai_a2a.updates.base'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) - Groups the frequently-passed parameters to reduce duplication. 
- """ - - turn_number: int - is_multiturn: bool - agent_role: str | None - endpoint: str - a2a_agent_name: str | None - context_id: str | None - from_task: Any - from_agent: Any - - -if TYPE_CHECKING: - from a2a.client import Client - from a2a.types import AgentCard, Message, Task - - from crewai.a2a.task_helpers import TaskStateResult - from crewai.a2a.updates.push_notifications.config import PushNotificationConfig - - -class BaseHandlerKwargs(TypedDict, total=False): - """Base kwargs shared by all handlers.""" - - turn_number: int - is_multiturn: bool - agent_role: str | None - context_id: str | None - task_id: str | None - endpoint: str | None - agent_branch: Any - a2a_agent_name: str | None - from_task: Any - from_agent: Any - - -class PollingHandlerKwargs(BaseHandlerKwargs, total=False): - """Kwargs for polling handler.""" - - polling_interval: float - polling_timeout: float - history_length: int - max_polls: int | None - - -class StreamingHandlerKwargs(BaseHandlerKwargs, total=False): - """Kwargs for streaming handler.""" - - -class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False): - """Kwargs for push notification handler.""" - - config: PushNotificationConfig - result_store: PushNotificationResultStore - polling_timeout: float - polling_interval: float - - -class PushNotificationResultStore(Protocol): - """Protocol for storing and retrieving push notification results. - - This protocol defines the interface for a result store that the - PushNotificationHandler uses to wait for task completion. - """ - - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, - ) -> CoreSchema: - return core_schema.any_schema() - - async def wait_for_result( - self, - task_id: str, - timeout: float, - poll_interval: float = 1.0, - ) -> Task | None: - """Wait for a task result to be available. - - Args: - task_id: The task ID to wait for. - timeout: Max seconds to wait before returning None. 
- poll_interval: Seconds between polling attempts. - - Returns: - The completed Task object, or None if timeout. - """ - ... - - async def get_result(self, task_id: str) -> Task | None: - """Get a task result if available. - - Args: - task_id: The task ID to retrieve. - - Returns: - The Task object if available, None otherwise. - """ - ... - - async def store_result(self, task: Task) -> None: - """Store a task result. - - Args: - task: The Task object to store. - """ - ... - - -class UpdateHandler(Protocol): - """Protocol for A2A update mechanism handlers.""" - - @staticmethod - async def execute( - client: Client, - message: Message, - new_messages: list[Message], - agent_card: AgentCard, - **kwargs: Any, - ) -> TaskStateResult: - """Execute the update mechanism and return result. - - Args: - client: A2A client instance. - message: Message to send. - new_messages: List to collect messages (modified in place). - agent_card: The agent card. - **kwargs: Additional handler-specific parameters. - - Returns: - Result dictionary with status, result/error, and history. - """ - ... - - -def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams: - """Extract common parameters from handler kwargs. - - Args: - kwargs: Handler kwargs dict. - - Returns: - CommonParams with extracted values. - - Raises: - ValueError: If endpoint is not provided. 
- """ - endpoint = kwargs.get("endpoint") - if endpoint is None: - raise ValueError("endpoint is required for update handlers") - - return CommonParams( - turn_number=kwargs.get("turn_number", 0), - is_multiturn=kwargs.get("is_multiturn", False), - agent_role=kwargs.get("agent_role"), - endpoint=endpoint, - a2a_agent_name=kwargs.get("a2a_agent_name"), - context_id=kwargs.get("context_id"), - from_task=kwargs.get("from_task"), - from_agent=kwargs.get("from_agent"), - ) +from crewai_a2a.updates.base import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/polling/__init__.py b/lib/crewai/src/crewai/a2a/updates/polling/__init__.py index 7199db700..7a3be1924 100644 --- a/lib/crewai/src/crewai/a2a/updates/polling/__init__.py +++ b/lib/crewai/src/crewai/a2a/updates/polling/__init__.py @@ -1 +1,13 @@ -"""Polling update mechanism module.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.polling`` instead.""" + +import warnings + + +warnings.warn( + "'crewai.a2a.updates.polling' has been moved to 'crewai_a2a.updates.polling'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) + +from crewai_a2a.updates.polling import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/polling/config.py b/lib/crewai/src/crewai/a2a/updates/polling/config.py index 1dcf970a6..233fbe73b 100644 --- a/lib/crewai/src/crewai/a2a/updates/polling/config.py +++ b/lib/crewai/src/crewai/a2a/updates/polling/config.py @@ -1,25 +1,13 @@ -"""Polling update mechanism configuration.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.polling.config`` instead.""" -from __future__ import annotations - -from pydantic import BaseModel, Field +import warnings -class PollingConfig(BaseModel): - """Configuration for polling-based task updates. +warnings.warn( + "'crewai.a2a.updates.polling.config' has been moved to 'crewai_a2a.updates.polling.config'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) - Attributes: - interval: Seconds between poll attempts. - timeout: Max seconds to poll before raising timeout error. - max_polls: Max number of poll attempts. - history_length: Number of messages to retrieve per poll. - """ - - interval: float = Field( - default=2.0, gt=0, description="Seconds between poll attempts" - ) - timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll") - max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts") - history_length: int = Field( - default=100, gt=0, description="Messages to retrieve per poll" - ) +from crewai_a2a.updates.polling.config import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/polling/handler.py b/lib/crewai/src/crewai/a2a/updates/polling/handler.py index dad5bca57..f7298be76 100644 --- a/lib/crewai/src/crewai/a2a/updates/polling/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/polling/handler.py @@ -1,359 +1,13 @@ -"""Polling update mechanism handler.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.polling.handler`` instead.""" -from __future__ import annotations +import warnings -import asyncio -import time -from typing import TYPE_CHECKING, Any -import uuid -from a2a.client import Client -from a2a.client.errors import A2AClientHTTPError -from a2a.types import ( - AgentCard, - Message, - Part, - Role, - TaskQueryParams, - TaskState, - TextPart, -) -from typing_extensions import Unpack - -from crewai.a2a.errors import A2APollingTimeoutError -from crewai.a2a.task_helpers import ( - ACTIONABLE_STATES, - TERMINAL_STATES, - TaskStateResult, - process_task_state, - send_message_and_get_task_id, -) -from crewai.a2a.updates.base import PollingHandlerKwargs -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import ( - A2AConnectionErrorEvent, - A2APollingStartedEvent, - A2APollingStatusEvent, - 
A2AResponseReceivedEvent, +warnings.warn( + "'crewai.a2a.updates.polling.handler' has been moved to 'crewai_a2a.updates.polling.handler'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -if TYPE_CHECKING: - from a2a.types import Task as A2ATask - - -async def _poll_task_until_complete( - client: Client, - task_id: str, - polling_interval: float, - polling_timeout: float, - agent_branch: Any | None = None, - history_length: int = 100, - max_polls: int | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, - context_id: str | None = None, - endpoint: str | None = None, - a2a_agent_name: str | None = None, -) -> A2ATask: - """Poll task status until terminal state reached. - - Args: - client: A2A client instance. - task_id: Task ID to poll. - polling_interval: Seconds between poll attempts. - polling_timeout: Max seconds before timeout. - agent_branch: Agent tree branch for logging. - history_length: Number of messages to retrieve per poll. - max_polls: Max number of poll attempts (None = unlimited). - from_task: Optional CrewAI Task object for event metadata. - from_agent: Optional CrewAI Agent object for event metadata. - context_id: A2A context ID for correlation. - endpoint: A2A agent endpoint URL. - a2a_agent_name: Name of the A2A agent from agent card. - - Returns: - Final task object in terminal state. - - Raises: - A2APollingTimeoutError: If polling exceeds timeout or max_polls. 
- """ - start_time = time.monotonic() - poll_count = 0 - - while True: - poll_count += 1 - task = await client.get_task( - TaskQueryParams(id=task_id, history_length=history_length) - ) - - elapsed = time.monotonic() - start_time - effective_context_id = task.context_id or context_id - crewai_event_bus.emit( - agent_branch, - A2APollingStatusEvent( - task_id=task_id, - context_id=effective_context_id, - state=str(task.status.state.value), - elapsed_seconds=elapsed, - poll_count=poll_count, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - - if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES: - return task - - if elapsed > polling_timeout: - raise A2APollingTimeoutError( - f"Polling timeout after {polling_timeout}s ({poll_count} polls)" - ) - - if max_polls and poll_count >= max_polls: - raise A2APollingTimeoutError( - f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s" - ) - - await asyncio.sleep(polling_interval) - - -class PollingHandler: - """Polling-based update handler.""" - - @staticmethod - async def execute( - client: Client, - message: Message, - new_messages: list[Message], - agent_card: AgentCard, - **kwargs: Unpack[PollingHandlerKwargs], - ) -> TaskStateResult: - """Execute A2A delegation using polling for updates. - - Args: - client: A2A client instance. - message: Message to send. - new_messages: List to collect messages. - agent_card: The agent card. - **kwargs: Polling-specific parameters. - - Returns: - Dictionary with status, result/error, and history. 
- """ - polling_interval = kwargs.get("polling_interval", 2.0) - polling_timeout = kwargs.get("polling_timeout", 300.0) - endpoint = kwargs.get("endpoint", "") - agent_branch = kwargs.get("agent_branch") - turn_number = kwargs.get("turn_number", 0) - is_multiturn = kwargs.get("is_multiturn", False) - agent_role = kwargs.get("agent_role") - history_length = kwargs.get("history_length", 100) - max_polls = kwargs.get("max_polls") - context_id = kwargs.get("context_id") - task_id = kwargs.get("task_id") - a2a_agent_name = kwargs.get("a2a_agent_name") - from_task = kwargs.get("from_task") - from_agent = kwargs.get("from_agent") - - try: - result_or_task_id = await send_message_and_get_task_id( - event_stream=client.send_message(message), - new_messages=new_messages, - agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - context_id=context_id, - ) - - if not isinstance(result_or_task_id, str): - return result_or_task_id - - task_id = result_or_task_id - - crewai_event_bus.emit( - agent_branch, - A2APollingStartedEvent( - task_id=task_id, - context_id=context_id, - polling_interval=polling_interval, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - - final_task = await _poll_task_until_complete( - client=client, - task_id=task_id, - polling_interval=polling_interval, - polling_timeout=polling_timeout, - agent_branch=agent_branch, - history_length=history_length, - max_polls=max_polls, - from_task=from_task, - from_agent=from_agent, - context_id=context_id, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - ) - - result = process_task_state( - a2a_task=final_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - 
from_task=from_task, - from_agent=from_agent, - ) - if result: - return result - - return TaskStateResult( - status=TaskState.failed, - error=f"Unexpected task state: {final_task.status.state}", - history=new_messages, - ) - - except A2APollingTimeoutError as e: - error_msg = str(e) - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=context_id, - task_id=task_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - agent_branch, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=turn_number, - context_id=context_id, - is_multiturn=is_multiturn, - status="failed", - final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - except A2AClientHTTPError as e: - error_msg = f"HTTP Error {e.status_code}: {e!s}" - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=context_id, - task_id=task_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=endpoint, - error=str(e), - error_type="http_error", - status_code=e.status_code, - a2a_agent_name=a2a_agent_name, - operation="polling", - context_id=context_id, - task_id=task_id, - from_task=from_task, - from_agent=from_agent, - ), - ) - crewai_event_bus.emit( - agent_branch, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=turn_number, - context_id=context_id, - is_multiturn=is_multiturn, - status="failed", - final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) 
- - except Exception as e: - error_msg = f"Unexpected error during polling: {e!s}" - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=context_id, - task_id=task_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=endpoint, - error=str(e), - error_type="unexpected_error", - a2a_agent_name=a2a_agent_name, - operation="polling", - context_id=context_id, - task_id=task_id, - from_task=from_task, - from_agent=from_agent, - ), - ) - crewai_event_bus.emit( - agent_branch, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=turn_number, - context_id=context_id, - is_multiturn=is_multiturn, - status="failed", - final=True, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) +from crewai_a2a.updates.polling.handler import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/__init__.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/__init__.py index abb3c2f23..599e86a12 100644 --- a/lib/crewai/src/crewai/a2a/updates/push_notifications/__init__.py +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/__init__.py @@ -1 +1,13 @@ -"""Push notification update mechanism module.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications`` instead.""" + +import warnings + + +warnings.warn( + "'crewai.a2a.updates.push_notifications' has been moved to 'crewai_a2a.updates.push_notifications'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) + +from crewai_a2a.updates.push_notifications import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/config.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/config.py index de81dbe80..20bd3e80b 100644 --- a/lib/crewai/src/crewai/a2a/updates/push_notifications/config.py +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/config.py @@ -1,65 +1,13 @@ -"""Push notification update mechanism configuration.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications.config`` instead.""" -from __future__ import annotations - -from typing import Annotated - -from a2a.types import PushNotificationAuthenticationInfo -from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field - -from crewai.a2a.updates.base import PushNotificationResultStore -from crewai.a2a.updates.push_notifications.signature import WebhookSignatureConfig +import warnings -def _coerce_signature( - value: str | WebhookSignatureConfig | None, -) -> WebhookSignatureConfig | None: - """Convert string secret to WebhookSignatureConfig.""" - if value is None: - return None - if isinstance(value, str): - return WebhookSignatureConfig.hmac_sha256(secret=value) - return value +warnings.warn( + "'crewai.a2a.updates.push_notifications.config' has been moved to 'crewai_a2a.updates.push_notifications.config'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) - -SignatureInput = Annotated[ - WebhookSignatureConfig | None, - BeforeValidator(_coerce_signature), -] - - -class PushNotificationConfig(BaseModel): - """Configuration for webhook-based task updates. - - Attributes: - url: Callback URL where agent sends push notifications. - id: Unique identifier for this config. - token: Token to validate incoming notifications. - authentication: Auth info for agent to use when calling webhook. 
- timeout: Max seconds to wait for task completion. - interval: Seconds between result polling attempts. - result_store: Store for receiving push notification results. - signature: HMAC signature config. Pass a string (secret) for defaults, - or WebhookSignatureConfig for custom settings. - """ - - url: AnyHttpUrl = Field(description="Callback URL for push notifications") - id: str | None = Field(default=None, description="Unique config identifier") - token: str | None = Field(default=None, description="Validation token") - authentication: PushNotificationAuthenticationInfo | None = Field( - default=None, description="Auth info for agent to use when calling webhook" - ) - timeout: float | None = Field( - default=300.0, gt=0, description="Max seconds to wait for task completion" - ) - interval: float = Field( - default=2.0, gt=0, description="Seconds between result polling attempts" - ) - result_store: PushNotificationResultStore | None = Field( - default=None, description="Result store for push notification handling" - ) - signature: SignatureInput = Field( - default=None, - description="HMAC signature config. 
Pass a string (secret) for simple usage, " - "or WebhookSignatureConfig for custom headers/tolerance.", - ) +from crewai_a2a.updates.push_notifications.config import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py index 783bf6483..1c2f459ed 100644 --- a/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py @@ -1,354 +1,13 @@ -"""Push notification (webhook) update mechanism handler.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications.handler`` instead.""" -from __future__ import annotations +import warnings -import logging -from typing import TYPE_CHECKING, Any -import uuid -from a2a.client import Client -from a2a.client.errors import A2AClientHTTPError -from a2a.types import ( - AgentCard, - Message, - Part, - Role, - TaskState, - TextPart, -) -from typing_extensions import Unpack - -from crewai.a2a.task_helpers import ( - TaskStateResult, - process_task_state, - send_message_and_get_task_id, -) -from crewai.a2a.updates.base import ( - CommonParams, - PushNotificationHandlerKwargs, - PushNotificationResultStore, - extract_common_params, -) -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import ( - A2AConnectionErrorEvent, - A2APushNotificationRegisteredEvent, - A2APushNotificationTimeoutEvent, - A2AResponseReceivedEvent, +warnings.warn( + "'crewai.a2a.updates.push_notifications.handler' has been moved to 'crewai_a2a.updates.push_notifications.handler'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -if TYPE_CHECKING: - from a2a.types import Task as A2ATask - -logger = logging.getLogger(__name__) - - -def _handle_push_error( - error: Exception, - error_msg: str, - error_type: str, - new_messages: list[Message], - agent_branch: Any | None, - params: CommonParams, - task_id: str | None, - status_code: int | None = None, -) -> TaskStateResult: - """Handle push notification errors with consistent event emission. - - Args: - error: The exception that occurred. - error_msg: Formatted error message for the result. - error_type: Type of error for the event. - new_messages: List to append error message to. - agent_branch: Agent tree branch for events. - params: Common handler parameters. - task_id: A2A task ID. - status_code: HTTP status code if applicable. - - Returns: - TaskStateResult with failed status. - """ - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=params.context_id, - task_id=task_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=params.endpoint, - error=str(error), - error_type=error_type, - status_code=status_code, - a2a_agent_name=params.a2a_agent_name, - operation="push_notification", - context_id=params.context_id, - task_id=task_id, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - crewai_event_bus.emit( - agent_branch, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=params.turn_number, - context_id=params.context_id, - is_multiturn=params.is_multiturn, - status="failed", - final=True, - agent_role=params.agent_role, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - -async def 
_wait_for_push_result( - task_id: str, - result_store: PushNotificationResultStore, - timeout: float, - poll_interval: float, - agent_branch: Any | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, - context_id: str | None = None, - endpoint: str | None = None, - a2a_agent_name: str | None = None, -) -> A2ATask | None: - """Wait for push notification result. - - Args: - task_id: Task ID to wait for. - result_store: Store to retrieve results from. - timeout: Max seconds to wait. - poll_interval: Seconds between polling attempts. - agent_branch: Agent tree branch for logging. - from_task: Optional CrewAI Task object for event metadata. - from_agent: Optional CrewAI Agent object for event metadata. - context_id: A2A context ID for correlation. - endpoint: A2A agent endpoint URL. - a2a_agent_name: Name of the A2A agent. - - Returns: - Final task object, or None if timeout. - """ - task = await result_store.wait_for_result( - task_id=task_id, - timeout=timeout, - poll_interval=poll_interval, - ) - - if task is None: - crewai_event_bus.emit( - agent_branch, - A2APushNotificationTimeoutEvent( - task_id=task_id, - context_id=context_id, - timeout_seconds=timeout, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - from_task=from_task, - from_agent=from_agent, - ), - ) - - return task - - -class PushNotificationHandler: - """Push notification (webhook) based update handler.""" - - @staticmethod - async def execute( - client: Client, - message: Message, - new_messages: list[Message], - agent_card: AgentCard, - **kwargs: Unpack[PushNotificationHandlerKwargs], - ) -> TaskStateResult: - """Execute A2A delegation using push notifications for updates. - - Args: - client: A2A client instance. - message: Message to send. - new_messages: List to collect messages. - agent_card: The agent card. - **kwargs: Push notification-specific parameters. - - Returns: - Dictionary with status, result/error, and history. 
- - Raises: - ValueError: If result_store or config not provided. - """ - config = kwargs.get("config") - result_store = kwargs.get("result_store") - polling_timeout = kwargs.get("polling_timeout", 300.0) - polling_interval = kwargs.get("polling_interval", 2.0) - agent_branch = kwargs.get("agent_branch") - task_id = kwargs.get("task_id") - params = extract_common_params(kwargs) - - if config is None: - error_msg = ( - "PushNotificationConfig is required for push notification handler" - ) - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=params.endpoint, - error=error_msg, - error_type="configuration_error", - a2a_agent_name=params.a2a_agent_name, - operation="push_notification", - context_id=params.context_id, - task_id=task_id, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - if result_store is None: - error_msg = ( - "PushNotificationResultStore is required for push notification handler" - ) - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=params.endpoint, - error=error_msg, - error_type="configuration_error", - a2a_agent_name=params.a2a_agent_name, - operation="push_notification", - context_id=params.context_id, - task_id=task_id, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - try: - result_or_task_id = await send_message_and_get_task_id( - event_stream=client.send_message(message), - new_messages=new_messages, - agent_card=agent_card, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - agent_role=params.agent_role, - from_task=params.from_task, - from_agent=params.from_agent, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - context_id=params.context_id, - ) - - if not isinstance(result_or_task_id, str): - return 
result_or_task_id - - task_id = result_or_task_id - - crewai_event_bus.emit( - agent_branch, - A2APushNotificationRegisteredEvent( - task_id=task_id, - context_id=params.context_id, - callback_url=str(config.url), - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - - logger.debug( - "Push notification callback for task %s configured at %s (via initial request)", - task_id, - config.url, - ) - - final_task = await _wait_for_push_result( - task_id=task_id, - result_store=result_store, - timeout=polling_timeout, - poll_interval=polling_interval, - agent_branch=agent_branch, - from_task=params.from_task, - from_agent=params.from_agent, - context_id=params.context_id, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - ) - - if final_task is None: - return TaskStateResult( - status=TaskState.failed, - error=f"Push notification timeout after {polling_timeout}s", - history=new_messages, - ) - - result = process_task_state( - a2a_task=final_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - agent_role=params.agent_role, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - ) - if result: - return result - - return TaskStateResult( - status=TaskState.failed, - error=f"Unexpected task state: {final_task.status.state}", - history=new_messages, - ) - - except A2AClientHTTPError as e: - return _handle_push_error( - error=e, - error_msg=f"HTTP Error {e.status_code}: {e!s}", - error_type="http_error", - new_messages=new_messages, - agent_branch=agent_branch, - params=params, - task_id=task_id, - status_code=e.status_code, - ) - - except Exception as e: - return _handle_push_error( - error=e, - error_msg=f"Unexpected error during push notification: {e!s}", - error_type="unexpected_error", - new_messages=new_messages, - 
agent_branch=agent_branch, - params=params, - task_id=task_id, - ) +from crewai_a2a.updates.push_notifications.handler import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/signature.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/signature.py index 9cac929ec..2ccbdc1e8 100644 --- a/lib/crewai/src/crewai/a2a/updates/push_notifications/signature.py +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/signature.py @@ -1,87 +1,13 @@ -"""Webhook signature configuration for push notifications.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications.signature`` instead.""" -from __future__ import annotations - -from enum import Enum -import secrets - -from pydantic import BaseModel, Field, SecretStr +import warnings -class WebhookSignatureMode(str, Enum): - """Signature mode for webhook push notifications.""" +warnings.warn( + "'crewai.a2a.updates.push_notifications.signature' has been moved to 'crewai_a2a.updates.push_notifications.signature'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) - NONE = "none" - HMAC_SHA256 = "hmac_sha256" - - -class WebhookSignatureConfig(BaseModel): - """Configuration for webhook signature verification. - - Provides cryptographic integrity verification and replay attack protection - for A2A push notifications. - - Attributes: - mode: Signature mode (none or hmac_sha256). - secret: Shared secret for HMAC computation (required for hmac_sha256 mode). - timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection. - header_name: HTTP header name for the signature. - timestamp_header_name: HTTP header name for the timestamp. 
- """ - - mode: WebhookSignatureMode = Field( - default=WebhookSignatureMode.NONE, - description="Signature verification mode", - ) - secret: SecretStr | None = Field( - default=None, - description="Shared secret for HMAC computation", - ) - timestamp_tolerance_seconds: int = Field( - default=300, - ge=0, - description="Max allowed timestamp age in seconds (5 min default)", - ) - header_name: str = Field( - default="X-A2A-Signature", - description="HTTP header name for the signature", - ) - timestamp_header_name: str = Field( - default="X-A2A-Signature-Timestamp", - description="HTTP header name for the timestamp", - ) - - @classmethod - def generate_secret(cls, length: int = 32) -> str: - """Generate a cryptographically secure random secret. - - Args: - length: Number of random bytes to generate (default 32). - - Returns: - URL-safe base64-encoded secret string. - """ - return secrets.token_urlsafe(length) - - @classmethod - def hmac_sha256( - cls, - secret: str | SecretStr, - timestamp_tolerance_seconds: int = 300, - ) -> WebhookSignatureConfig: - """Create an HMAC-SHA256 signature configuration. - - Args: - secret: Shared secret for HMAC computation. - timestamp_tolerance_seconds: Max allowed timestamp age in seconds. - - Returns: - Configured WebhookSignatureConfig for HMAC-SHA256. 
- """ - if isinstance(secret, str): - secret = SecretStr(secret) - return cls( - mode=WebhookSignatureMode.HMAC_SHA256, - secret=secret, - timestamp_tolerance_seconds=timestamp_tolerance_seconds, - ) +from crewai_a2a.updates.push_notifications.signature import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/streaming/__init__.py b/lib/crewai/src/crewai/a2a/updates/streaming/__init__.py index 7adada8b5..35b3232a5 100644 --- a/lib/crewai/src/crewai/a2a/updates/streaming/__init__.py +++ b/lib/crewai/src/crewai/a2a/updates/streaming/__init__.py @@ -1 +1,13 @@ -"""Streaming update mechanism module.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming`` instead.""" + +import warnings + + +warnings.warn( + "'crewai.a2a.updates.streaming' has been moved to 'crewai_a2a.updates.streaming'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) + +from crewai_a2a.updates.streaming import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/streaming/config.py b/lib/crewai/src/crewai/a2a/updates/streaming/config.py index 6098bf550..eba202e3f 100644 --- a/lib/crewai/src/crewai/a2a/updates/streaming/config.py +++ b/lib/crewai/src/crewai/a2a/updates/streaming/config.py @@ -1,9 +1,13 @@ -"""Streaming update mechanism configuration.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming.config`` instead.""" -from __future__ import annotations - -from pydantic import BaseModel +import warnings -class StreamingConfig(BaseModel): - """Configuration for SSE-based task updates.""" +warnings.warn( + "'crewai.a2a.updates.streaming.config' has been moved to 'crewai_a2a.updates.streaming.config'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) + +from crewai_a2a.updates.streaming.config import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/streaming/handler.py b/lib/crewai/src/crewai/a2a/updates/streaming/handler.py index 9b0c21d12..099991178 100644 --- a/lib/crewai/src/crewai/a2a/updates/streaming/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/streaming/handler.py @@ -1,646 +1,13 @@ -"""Streaming (SSE) update mechanism handler.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming.handler`` instead.""" -from __future__ import annotations +import warnings -import asyncio -import logging -from typing import Final -import uuid -from a2a.client import Client -from a2a.client.errors import A2AClientHTTPError -from a2a.types import ( - AgentCard, - Message, - Part, - Role, - Task, - TaskArtifactUpdateEvent, - TaskIdParams, - TaskQueryParams, - TaskState, - TaskStatusUpdateEvent, - TextPart, -) -from typing_extensions import Unpack - -from crewai.a2a.task_helpers import ( - ACTIONABLE_STATES, - TERMINAL_STATES, - TaskStateResult, - process_task_state, -) -from crewai.a2a.updates.base import StreamingHandlerKwargs, extract_common_params -from crewai.a2a.updates.streaming.params import ( - process_status_update, -) -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import ( - A2AArtifactReceivedEvent, - A2AConnectionErrorEvent, - A2AResponseReceivedEvent, - A2AStreamingChunkEvent, - A2AStreamingStartedEvent, +warnings.warn( + "'crewai.a2a.updates.streaming.handler' has been moved to 'crewai_a2a.updates.streaming.handler'. " + "Please update your imports. 
The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -logger = logging.getLogger(__name__) - -MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3 -RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0 - - -class StreamingHandler: - """SSE streaming-based update handler.""" - - @staticmethod - async def _try_recover_from_interruption( # type: ignore[misc] - client: Client, - task_id: str, - new_messages: list[Message], - agent_card: AgentCard, - result_parts: list[str], - **kwargs: Unpack[StreamingHandlerKwargs], - ) -> TaskStateResult | None: - """Attempt to recover from a stream interruption by checking task state. - - If the task completed while we were disconnected, returns the result. - If the task is still running, attempts to resubscribe and continue. - - Args: - client: A2A client instance. - task_id: The task ID to recover. - new_messages: List of collected messages. - agent_card: The agent card. - result_parts: Accumulated result text parts. - **kwargs: Handler parameters. - - Returns: - TaskStateResult if recovery succeeded (task finished or resubscribe worked). - None if recovery not possible (caller should handle failure). - - Note: - When None is returned, recovery failed and the original exception should - be handled by the caller. All recovery attempts are logged. 
- """ - params = extract_common_params(kwargs) # type: ignore[arg-type] - - try: - a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id)) - - if a2a_task.status.state in TERMINAL_STATES: - logger.info( - "Task completed during stream interruption", - extra={"task_id": task_id, "state": str(a2a_task.status.state)}, - ) - return process_task_state( - a2a_task=a2a_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - agent_role=params.agent_role, - result_parts=result_parts, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - ) - - if a2a_task.status.state in ACTIONABLE_STATES: - logger.info( - "Task in actionable state during stream interruption", - extra={"task_id": task_id, "state": str(a2a_task.status.state)}, - ) - return process_task_state( - a2a_task=a2a_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - agent_role=params.agent_role, - result_parts=result_parts, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - is_final=False, - ) - - logger.info( - "Task still running, attempting resubscribe", - extra={"task_id": task_id, "state": str(a2a_task.status.state)}, - ) - - for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS): - try: - backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt) - if attempt > 0: - await asyncio.sleep(backoff) - - event_stream = client.resubscribe(TaskIdParams(id=task_id)) - - async for event in event_stream: - if isinstance(event, tuple): - resubscribed_task, update = event - - is_final_update = ( - process_status_update(update, result_parts) - if isinstance(update, TaskStatusUpdateEvent) - else False - ) - - if isinstance(update, TaskArtifactUpdateEvent): - artifact = update.artifact - result_parts.extend( - 
part.root.text - for part in artifact.parts - if part.root.kind == "text" - ) - - if ( - is_final_update - or resubscribed_task.status.state - in TERMINAL_STATES | ACTIONABLE_STATES - ): - return process_task_state( - a2a_task=resubscribed_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - agent_role=params.agent_role, - result_parts=result_parts, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - is_final=is_final_update, - ) - - elif isinstance(event, Message): - new_messages.append(event) - result_parts.extend( - part.root.text - for part in event.parts - if part.root.kind == "text" - ) - - final_task = await client.get_task(TaskQueryParams(id=task_id)) - return process_task_state( - a2a_task=final_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - agent_role=params.agent_role, - result_parts=result_parts, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - ) - - except Exception as resubscribe_error: # noqa: PERF203 - logger.warning( - "Resubscribe attempt failed", - extra={ - "task_id": task_id, - "attempt": attempt + 1, - "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS, - "error": str(resubscribe_error), - }, - ) - if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1: - return None - - except Exception as e: - logger.warning( - "Failed to recover from stream interruption due to unexpected error", - extra={ - "task_id": task_id, - "error": str(e), - "error_type": type(e).__name__, - }, - exc_info=True, - ) - return None - - logger.warning( - "Recovery exhausted all resubscribe attempts without success", - extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS}, - ) - return None - - @staticmethod - async def execute( - client: Client, - message: 
Message, - new_messages: list[Message], - agent_card: AgentCard, - **kwargs: Unpack[StreamingHandlerKwargs], - ) -> TaskStateResult: - """Execute A2A delegation using SSE streaming for updates. - - Args: - client: A2A client instance. - message: Message to send. - new_messages: List to collect messages. - agent_card: The agent card. - **kwargs: Streaming-specific parameters. - - Returns: - Dictionary with status, result/error, and history. - """ - task_id = kwargs.get("task_id") - agent_branch = kwargs.get("agent_branch") - params = extract_common_params(kwargs) - - result_parts: list[str] = [] - final_result: TaskStateResult | None = None - event_stream = client.send_message(message) - chunk_index = 0 - current_task_id: str | None = task_id - - crewai_event_bus.emit( - agent_branch, - A2AStreamingStartedEvent( - task_id=task_id, - context_id=params.context_id, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - agent_role=params.agent_role, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - - try: - async for event in event_stream: - if isinstance(event, tuple): - a2a_task, _ = event - current_task_id = a2a_task.id - - if isinstance(event, Message): - new_messages.append(event) - message_context_id = event.context_id or params.context_id - for part in event.parts: - if part.root.kind == "text": - text = part.root.text - result_parts.append(text) - crewai_event_bus.emit( - agent_branch, - A2AStreamingChunkEvent( - task_id=event.task_id or task_id, - context_id=message_context_id, - chunk=text, - chunk_index=chunk_index, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - chunk_index += 1 - - elif isinstance(event, tuple): - a2a_task, update = event - - if isinstance(update, TaskArtifactUpdateEvent): - 
artifact = update.artifact - result_parts.extend( - part.root.text - for part in artifact.parts - if part.root.kind == "text" - ) - artifact_size = None - if artifact.parts: - artifact_size = sum( - len(p.root.text.encode()) - if p.root.kind == "text" - else len(getattr(p.root, "data", b"")) - for p in artifact.parts - ) - effective_context_id = a2a_task.context_id or params.context_id - crewai_event_bus.emit( - agent_branch, - A2AArtifactReceivedEvent( - task_id=a2a_task.id, - artifact_id=artifact.artifact_id, - artifact_name=artifact.name, - artifact_description=artifact.description, - mime_type=artifact.parts[0].root.kind - if artifact.parts - else None, - size_bytes=artifact_size, - append=update.append or False, - last_chunk=update.last_chunk or False, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - context_id=effective_context_id, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - - is_final_update = ( - process_status_update(update, result_parts) - if isinstance(update, TaskStatusUpdateEvent) - else False - ) - - if ( - not is_final_update - and a2a_task.status.state - not in TERMINAL_STATES | ACTIONABLE_STATES - ): - continue - - final_result = process_task_state( - a2a_task=a2a_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=params.turn_number, - is_multiturn=params.is_multiturn, - agent_role=params.agent_role, - result_parts=result_parts, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - is_final=is_final_update, - ) - if final_result: - break - - except A2AClientHTTPError as e: - if current_task_id: - logger.info( - "Stream interrupted with HTTP error, attempting recovery", - extra={ - "task_id": current_task_id, - "error": str(e), - "status_code": e.status_code, - }, - ) - recovery_kwargs = {k: v for k, v in kwargs.items() if k != 
"task_id"} - recovered_result = ( - await StreamingHandler._try_recover_from_interruption( - client=client, - task_id=current_task_id, - new_messages=new_messages, - agent_card=agent_card, - result_parts=result_parts, - **recovery_kwargs, - ) - ) - if recovered_result: - logger.info( - "Successfully recovered task after HTTP error", - extra={ - "task_id": current_task_id, - "status": str(recovered_result.get("status")), - }, - ) - return recovered_result - - logger.warning( - "Failed to recover from HTTP error, returning failure", - extra={ - "task_id": current_task_id, - "status_code": e.status_code, - "original_error": str(e), - }, - ) - - error_msg = f"HTTP Error {e.status_code}: {e!s}" - error_type = "http_error" - status_code = e.status_code - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=params.context_id, - task_id=task_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=params.endpoint, - error=str(e), - error_type=error_type, - status_code=status_code, - a2a_agent_name=params.a2a_agent_name, - operation="streaming", - context_id=params.context_id, - task_id=task_id, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - crewai_event_bus.emit( - agent_branch, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=params.turn_number, - context_id=params.context_id, - is_multiturn=params.is_multiturn, - status="failed", - final=True, - agent_role=params.agent_role, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e: - error_type = type(e).__name__.lower() - if current_task_id: - logger.info( - f"Stream 
interrupted with {error_type}, attempting recovery", - extra={"task_id": current_task_id, "error": str(e)}, - ) - recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"} - recovered_result = ( - await StreamingHandler._try_recover_from_interruption( - client=client, - task_id=current_task_id, - new_messages=new_messages, - agent_card=agent_card, - result_parts=result_parts, - **recovery_kwargs, - ) - ) - if recovered_result: - logger.info( - f"Successfully recovered task after {error_type}", - extra={ - "task_id": current_task_id, - "status": str(recovered_result.get("status")), - }, - ) - return recovered_result - - logger.warning( - f"Failed to recover from {error_type}, returning failure", - extra={ - "task_id": current_task_id, - "error_type": error_type, - "original_error": str(e), - }, - ) - - error_msg = f"Connection error during streaming: {e!s}" - status_code = None - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=params.context_id, - task_id=task_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=params.endpoint, - error=str(e), - error_type=error_type, - status_code=status_code, - a2a_agent_name=params.a2a_agent_name, - operation="streaming", - context_id=params.context_id, - task_id=task_id, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - crewai_event_bus.emit( - agent_branch, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=params.turn_number, - context_id=params.context_id, - is_multiturn=params.is_multiturn, - status="failed", - final=True, - agent_role=params.agent_role, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - except Exception as e: - 
logger.exception( - "Unexpected error during streaming", - extra={ - "task_id": current_task_id, - "error_type": type(e).__name__, - "endpoint": params.endpoint, - }, - ) - error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}" - error_type = "unexpected_error" - status_code = None - - error_message = Message( - role=Role.agent, - message_id=str(uuid.uuid4()), - parts=[Part(root=TextPart(text=error_msg))], - context_id=params.context_id, - task_id=task_id, - ) - new_messages.append(error_message) - - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=params.endpoint, - error=str(e), - error_type=error_type, - status_code=status_code, - a2a_agent_name=params.a2a_agent_name, - operation="streaming", - context_id=params.context_id, - task_id=task_id, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - crewai_event_bus.emit( - agent_branch, - A2AResponseReceivedEvent( - response=error_msg, - turn_number=params.turn_number, - context_id=params.context_id, - is_multiturn=params.is_multiturn, - status="failed", - final=True, - agent_role=params.agent_role, - endpoint=params.endpoint, - a2a_agent_name=params.a2a_agent_name, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - return TaskStateResult( - status=TaskState.failed, - error=error_msg, - history=new_messages, - ) - - finally: - aclose = getattr(event_stream, "aclose", None) - if aclose: - try: - await aclose() - except Exception as close_error: - crewai_event_bus.emit( - agent_branch, - A2AConnectionErrorEvent( - endpoint=params.endpoint, - error=str(close_error), - error_type="stream_close_error", - a2a_agent_name=params.a2a_agent_name, - operation="stream_close", - context_id=params.context_id, - task_id=task_id, - from_task=params.from_task, - from_agent=params.from_agent, - ), - ) - - if final_result: - return final_result - - return TaskStateResult( - status=TaskState.completed, - result=" ".join(result_parts) if 
result_parts else "", - history=new_messages, - agent_card=agent_card.model_dump(exclude_none=True), - ) +from crewai_a2a.updates.streaming.handler import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/updates/streaming/params.py b/lib/crewai/src/crewai/a2a/updates/streaming/params.py index a4bf8c0a2..2e3bbb602 100644 --- a/lib/crewai/src/crewai/a2a/updates/streaming/params.py +++ b/lib/crewai/src/crewai/a2a/updates/streaming/params.py @@ -1,28 +1,13 @@ -"""Common parameter extraction for streaming handlers.""" +"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming.params`` instead.""" -from __future__ import annotations - -from a2a.types import TaskStatusUpdateEvent +import warnings -def process_status_update( - update: TaskStatusUpdateEvent, - result_parts: list[str], -) -> bool: - """Process a status update event and extract text parts. +warnings.warn( + "'crewai.a2a.updates.streaming.params' has been moved to 'crewai_a2a.updates.streaming.params'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) - Args: - update: The status update event. - result_parts: List to append text parts to (modified in place). - - Returns: - True if this is a final update, False otherwise. 
- """ - is_final = update.final - if update.status and update.status.message and update.status.message.parts: - result_parts.extend( - part.root.text - for part in update.status.message.parts - if part.root.kind == "text" and part.root.text - ) - return is_final +from crewai_a2a.updates.streaming.params import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/utils/__init__.py b/lib/crewai/src/crewai/a2a/utils/__init__.py index bdb7bed62..42ad51fbe 100644 --- a/lib/crewai/src/crewai/a2a/utils/__init__.py +++ b/lib/crewai/src/crewai/a2a/utils/__init__.py @@ -1 +1,13 @@ -"""A2A utility modules for client operations.""" +"""Backward-compatibility shim — use ``crewai_a2a.utils`` instead.""" + +import warnings + + +warnings.warn( + "'crewai.a2a.utils' has been moved to 'crewai_a2a.utils'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) + +from crewai_a2a.utils import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/utils/agent_card.py b/lib/crewai/src/crewai/a2a/utils/agent_card.py index c548cd1e7..0f5525835 100644 --- a/lib/crewai/src/crewai/a2a/utils/agent_card.py +++ b/lib/crewai/src/crewai/a2a/utils/agent_card.py @@ -1,586 +1,13 @@ -"""AgentCard utilities for A2A client and server operations.""" +"""Backward-compatibility shim — use ``crewai_a2a.utils.agent_card`` instead.""" -from __future__ import annotations +import warnings -import asyncio -from collections.abc import MutableMapping -from functools import lru_cache -import ssl -import time -from types import MethodType -from typing import TYPE_CHECKING -from a2a.client.errors import A2AClientHTTPError -from a2a.types import AgentCapabilities, AgentCard, AgentSkill -from aiocache import cached # type: ignore[import-untyped] -from aiocache.serializers import PickleSerializer # type: ignore[import-untyped] -import httpx - -from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth -from crewai.a2a.auth.utils import ( - 
_auth_store, - configure_auth_client, - retry_on_401, -) -from crewai.a2a.config import A2AServerConfig -from crewai.crew import Crew -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import ( - A2AAgentCardFetchedEvent, - A2AAuthenticationFailedEvent, - A2AConnectionErrorEvent, +warnings.warn( + "'crewai.a2a.utils.agent_card' has been moved to 'crewai_a2a.utils.agent_card'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -if TYPE_CHECKING: - from crewai.a2a.auth.client_schemes import ClientAuthScheme - from crewai.agent import Agent - from crewai.task import Task - - -def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str: - """Get TLS verify parameter from auth scheme. - - Args: - auth: Optional authentication scheme with TLS config. - - Returns: - SSL context, CA cert path, True for default verification, - or False if verification disabled. - """ - if auth and auth.tls: - return auth.tls.get_httpx_ssl_context() - return True - - -async def _prepare_auth_headers( - auth: ClientAuthScheme | None, - timeout: int, -) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]: - """Prepare authentication headers and TLS verification settings. - - Args: - auth: Optional authentication scheme. - timeout: Request timeout in seconds. - - Returns: - Tuple of (headers dict, TLS verify setting). - """ - headers: MutableMapping[str, str] = {} - verify = _get_tls_verify(auth) - if auth: - async with httpx.AsyncClient( - timeout=timeout, verify=verify - ) as temp_auth_client: - if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): - configure_auth_client(auth, temp_auth_client) - headers = await auth.apply_auth(temp_auth_client, {}) - return headers, verify - - -def _get_server_config(agent: Agent) -> A2AServerConfig | None: - """Get A2AServerConfig from an agent's a2a configuration. - - Args: - agent: The Agent instance to check. 
- - Returns: - A2AServerConfig if present, None otherwise. - """ - if agent.a2a is None: - return None - if isinstance(agent.a2a, A2AServerConfig): - return agent.a2a - if isinstance(agent.a2a, list): - for config in agent.a2a: - if isinstance(config, A2AServerConfig): - return config - return None - - -def fetch_agent_card( - endpoint: str, - auth: ClientAuthScheme | None = None, - timeout: int = 30, - use_cache: bool = True, - cache_ttl: int = 300, -) -> AgentCard: - """Fetch AgentCard from an A2A endpoint with optional caching. - - Args: - endpoint: A2A agent endpoint URL (AgentCard URL). - auth: Optional ClientAuthScheme for authentication. - timeout: Request timeout in seconds. - use_cache: Whether to use caching (default True). - cache_ttl: Cache TTL in seconds (default 300 = 5 minutes). - - Returns: - AgentCard object with agent capabilities and skills. - - Raises: - httpx.HTTPStatusError: If the request fails. - A2AClientHTTPError: If authentication fails. - """ - if use_cache: - if auth: - auth_data = auth.model_dump_json( - exclude={ - "_access_token", - "_token_expires_at", - "_refresh_token", - "_authorization_callback", - } - ) - auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data) - else: - auth_hash = _auth_store.compute_key("none", "") - _auth_store.set(auth_hash, auth) - ttl_hash = int(time.time() // cache_ttl) - return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash) - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete( - afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout) - ) - finally: - loop.close() - - -async def afetch_agent_card( - endpoint: str, - auth: ClientAuthScheme | None = None, - timeout: int = 30, - use_cache: bool = True, -) -> AgentCard: - """Fetch AgentCard from an A2A endpoint asynchronously. - - Native async implementation. Use this when running in an async context. - - Args: - endpoint: A2A agent endpoint URL (AgentCard URL). 
- auth: Optional ClientAuthScheme for authentication. - timeout: Request timeout in seconds. - use_cache: Whether to use caching (default True). - - Returns: - AgentCard object with agent capabilities and skills. - - Raises: - httpx.HTTPStatusError: If the request fails. - A2AClientHTTPError: If authentication fails. - """ - if use_cache: - if auth: - auth_data = auth.model_dump_json( - exclude={ - "_access_token", - "_token_expires_at", - "_refresh_token", - "_authorization_callback", - } - ) - auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data) - else: - auth_hash = _auth_store.compute_key("none", "") - _auth_store.set(auth_hash, auth) - agent_card: AgentCard = await _afetch_agent_card_cached( - endpoint, auth_hash, timeout - ) - return agent_card - - return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout) - - -@lru_cache() -def _fetch_agent_card_cached( - endpoint: str, - auth_hash: str, - timeout: int, - _ttl_hash: int, -) -> AgentCard: - """Cached sync version of fetch_agent_card.""" - auth = _auth_store.get(auth_hash) - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete( - _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout) - ) - finally: - loop.close() - - -@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator] -async def _afetch_agent_card_cached( - endpoint: str, - auth_hash: str, - timeout: int, -) -> AgentCard: - """Cached async implementation of AgentCard fetching.""" - auth = _auth_store.get(auth_hash) - return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout) - - -async def _afetch_agent_card_impl( - endpoint: str, - auth: ClientAuthScheme | None, - timeout: int, -) -> AgentCard: - """Internal async implementation of AgentCard fetching.""" - start_time = time.perf_counter() - - if "/.well-known/agent-card.json" in endpoint: - base_url = 
endpoint.replace("/.well-known/agent-card.json", "") - agent_card_path = "/.well-known/agent-card.json" - else: - url_parts = endpoint.split("/", 3) - base_url = f"{url_parts[0]}//{url_parts[2]}" - agent_card_path = ( - f"/{url_parts[3]}" - if len(url_parts) > 3 and url_parts[3] - else "/.well-known/agent-card.json" - ) - - headers, verify = await _prepare_auth_headers(auth, timeout) - - async with httpx.AsyncClient( - timeout=timeout, headers=headers, verify=verify - ) as temp_client: - if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): - configure_auth_client(auth, temp_client) - - agent_card_url = f"{base_url}{agent_card_path}" - - async def _fetch_agent_card_request() -> httpx.Response: - return await temp_client.get(agent_card_url) - - try: - response = await retry_on_401( - request_func=_fetch_agent_card_request, - auth_scheme=auth, - client=temp_client, - headers=temp_client.headers, - max_retries=2, - ) - response.raise_for_status() - - agent_card = AgentCard.model_validate(response.json()) - fetch_time_ms = (time.perf_counter() - start_time) * 1000 - agent_card_dict = agent_card.model_dump(exclude_none=True) - - crewai_event_bus.emit( - None, - A2AAgentCardFetchedEvent( - endpoint=endpoint, - a2a_agent_name=agent_card.name, - agent_card=agent_card_dict, - protocol_version=agent_card.protocol_version, - provider=agent_card_dict.get("provider"), - cached=False, - fetch_time_ms=fetch_time_ms, - ), - ) - - return agent_card - - except httpx.HTTPStatusError as e: - elapsed_ms = (time.perf_counter() - start_time) * 1000 - response_body = e.response.text[:1000] if e.response.text else None - - if e.response.status_code == 401: - error_details = ["Authentication failed"] - www_auth = e.response.headers.get("WWW-Authenticate") - if www_auth: - error_details.append(f"WWW-Authenticate: {www_auth}") - if not auth: - error_details.append("No auth scheme provided") - msg = " | ".join(error_details) - - auth_type = type(auth).__name__ if auth else None - 
crewai_event_bus.emit( - None, - A2AAuthenticationFailedEvent( - endpoint=endpoint, - auth_type=auth_type, - error=msg, - status_code=401, - metadata={ - "elapsed_ms": elapsed_ms, - "response_body": response_body, - "www_authenticate": www_auth, - "request_url": str(e.request.url), - }, - ), - ) - - raise A2AClientHTTPError(401, msg) from e - - crewai_event_bus.emit( - None, - A2AConnectionErrorEvent( - endpoint=endpoint, - error=str(e), - error_type="http_error", - status_code=e.response.status_code, - operation="fetch_agent_card", - metadata={ - "elapsed_ms": elapsed_ms, - "response_body": response_body, - "request_url": str(e.request.url), - }, - ), - ) - raise - - except httpx.TimeoutException as e: - elapsed_ms = (time.perf_counter() - start_time) * 1000 - crewai_event_bus.emit( - None, - A2AConnectionErrorEvent( - endpoint=endpoint, - error=str(e), - error_type="timeout", - operation="fetch_agent_card", - metadata={ - "elapsed_ms": elapsed_ms, - "timeout_config": timeout, - "request_url": str(e.request.url) if e.request else None, - }, - ), - ) - raise - - except httpx.ConnectError as e: - elapsed_ms = (time.perf_counter() - start_time) * 1000 - crewai_event_bus.emit( - None, - A2AConnectionErrorEvent( - endpoint=endpoint, - error=str(e), - error_type="connection_error", - operation="fetch_agent_card", - metadata={ - "elapsed_ms": elapsed_ms, - "request_url": str(e.request.url) if e.request else None, - }, - ), - ) - raise - - except httpx.RequestError as e: - elapsed_ms = (time.perf_counter() - start_time) * 1000 - crewai_event_bus.emit( - None, - A2AConnectionErrorEvent( - endpoint=endpoint, - error=str(e), - error_type="request_error", - operation="fetch_agent_card", - metadata={ - "elapsed_ms": elapsed_ms, - "request_url": str(e.request.url) if e.request else None, - }, - ), - ) - raise - - -def _task_to_skill(task: Task) -> AgentSkill: - """Convert a CrewAI Task to an A2A AgentSkill. - - Args: - task: The CrewAI Task to convert. 
- - Returns: - AgentSkill representing the task's capability. - """ - task_name = task.name or task.description[:50] - task_id = task_name.lower().replace(" ", "_") - - tags: list[str] = [] - if task.agent: - tags.append(task.agent.role.lower().replace(" ", "-")) - - return AgentSkill( - id=task_id, - name=task_name, - description=task.description, - tags=tags, - examples=[task.expected_output] if task.expected_output else None, - ) - - -def _tool_to_skill(tool_name: str, tool_description: str) -> AgentSkill: - """Convert an Agent's tool to an A2A AgentSkill. - - Args: - tool_name: Name of the tool. - tool_description: Description of what the tool does. - - Returns: - AgentSkill representing the tool's capability. - """ - tool_id = tool_name.lower().replace(" ", "_") - - return AgentSkill( - id=tool_id, - name=tool_name, - description=tool_description, - tags=[tool_name.lower().replace(" ", "-")], - ) - - -def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard: - """Generate an A2A AgentCard from a Crew instance. - - Args: - crew: The Crew instance to generate a card for. - url: The base URL where this crew will be exposed. - - Returns: - AgentCard describing the crew's capabilities. 
- """ - crew_name = getattr(crew, "name", None) or crew.__class__.__name__ - - description_parts: list[str] = [] - crew_description = getattr(crew, "description", None) - if crew_description: - description_parts.append(crew_description) - else: - agent_roles = [agent.role for agent in crew.agents] - description_parts.append( - f"A crew of {len(crew.agents)} agents: {', '.join(agent_roles)}" - ) - - skills = [_task_to_skill(task) for task in crew.tasks] - - return AgentCard( - name=crew_name, - description=" ".join(description_parts), - url=url, - version="1.0.0", - capabilities=AgentCapabilities( - streaming=True, - push_notifications=True, - ), - default_input_modes=["text/plain", "application/json"], - default_output_modes=["text/plain", "application/json"], - skills=skills, - ) - - -def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard: - """Generate an A2A AgentCard from an Agent instance. - - Uses A2AServerConfig values when available, falling back to agent properties. - If signing_config is provided, the card will be signed with JWS. - - Args: - agent: The Agent instance to generate a card for. - url: The base URL where this agent will be exposed. - - Returns: - AgentCard describing the agent's capabilities. 
- """ - from crewai.a2a.utils.agent_card_signing import sign_agent_card - - server_config = _get_server_config(agent) or A2AServerConfig() - - name = server_config.name or agent.role - - description_parts = [agent.goal] - if agent.backstory: - description_parts.append(agent.backstory) - description = server_config.description or " ".join(description_parts) - - skills: list[AgentSkill] = ( - server_config.skills.copy() if server_config.skills else [] - ) - - if not skills: - if agent.tools: - for tool in agent.tools: - tool_name = getattr(tool, "name", None) or tool.__class__.__name__ - tool_desc = getattr(tool, "description", None) or f"Tool: {tool_name}" - skills.append(_tool_to_skill(tool_name, tool_desc)) - - if not skills: - skills.append( - AgentSkill( - id=agent.role.lower().replace(" ", "_"), - name=agent.role, - description=agent.goal, - tags=[agent.role.lower().replace(" ", "-")], - ) - ) - - capabilities = server_config.capabilities - if server_config.server_extensions: - from crewai.a2a.extensions.server import ServerExtensionRegistry - - registry = ServerExtensionRegistry(server_config.server_extensions) - ext_list = registry.get_agent_extensions() - - existing_exts = list(capabilities.extensions) if capabilities.extensions else [] - existing_uris = {e.uri for e in existing_exts} - for ext in ext_list: - if ext.uri not in existing_uris: - existing_exts.append(ext) - - capabilities = capabilities.model_copy(update={"extensions": existing_exts}) - - card = AgentCard( - name=name, - description=description, - url=server_config.url or url, - version=server_config.version, - capabilities=capabilities, - default_input_modes=server_config.default_input_modes, - default_output_modes=server_config.default_output_modes, - skills=skills, - preferred_transport=server_config.transport.preferred, - protocol_version=server_config.protocol_version, - provider=server_config.provider, - documentation_url=server_config.documentation_url, - icon_url=server_config.icon_url, 
- additional_interfaces=server_config.additional_interfaces, - security=server_config.security, - security_schemes=server_config.security_schemes, - supports_authenticated_extended_card=server_config.supports_authenticated_extended_card, - ) - - if server_config.signing_config: - signature = sign_agent_card( - card, - private_key=server_config.signing_config.get_private_key(), - key_id=server_config.signing_config.key_id, - algorithm=server_config.signing_config.algorithm, - ) - card = card.model_copy(update={"signatures": [signature]}) - elif server_config.signatures: - card = card.model_copy(update={"signatures": server_config.signatures}) - - return card - - -def inject_a2a_server_methods(agent: Agent) -> None: - """Inject A2A server methods onto an Agent instance. - - Adds a `to_agent_card(url: str) -> AgentCard` method to the agent - that generates an A2A-compliant AgentCard. - - Only injects if the agent has an A2AServerConfig. - - Args: - agent: The Agent instance to inject methods onto. - """ - if _get_server_config(agent) is None: - return - - def _to_agent_card(self: Agent, url: str) -> AgentCard: - return _agent_to_agent_card(self, url) - - object.__setattr__(agent, "to_agent_card", MethodType(_to_agent_card, agent)) +from crewai_a2a.utils.agent_card import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/utils/agent_card_signing.py b/lib/crewai/src/crewai/a2a/utils/agent_card_signing.py index d869020af..75efbb9b4 100644 --- a/lib/crewai/src/crewai/a2a/utils/agent_card_signing.py +++ b/lib/crewai/src/crewai/a2a/utils/agent_card_signing.py @@ -1,236 +1,13 @@ -"""AgentCard JWS signing utilities. +"""Backward-compatibility shim — use ``crewai_a2a.utils.agent_card_signing`` instead.""" -This module provides functions for signing and verifying AgentCards using -JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients -to verify the authenticity and integrity of agent card information. 
- -Example: - >>> from crewai.a2a.utils.agent_card_signing import sign_agent_card - >>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1") - >>> card_with_sig = card.model_copy(update={"signatures": [signature]}) -""" - -from __future__ import annotations - -import base64 -import json -import logging -from typing import Any, Literal - -from a2a.types import AgentCard, AgentCardSignature -import jwt -from pydantic import SecretStr +import warnings -logger = logging.getLogger(__name__) +warnings.warn( + "'crewai.a2a.utils.agent_card_signing' has been moved to 'crewai_a2a.utils.agent_card_signing'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) - -SigningAlgorithm = Literal[ - "RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512" -] - - -def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes: - """Normalize private key to bytes format. - - Args: - private_key: PEM-encoded private key as string, bytes, or SecretStr. - - Returns: - Private key as bytes. - """ - if isinstance(private_key, SecretStr): - private_key = private_key.get_secret_value() - if isinstance(private_key, str): - private_key = private_key.encode() - return private_key - - -def _serialize_agent_card(agent_card: AgentCard) -> str: - """Serialize AgentCard to canonical JSON for signing. - - Excludes the signatures field to avoid circular reference during signing. - Uses sorted keys and compact separators for deterministic output. - - Args: - agent_card: The AgentCard to serialize. - - Returns: - Canonical JSON string representation. - """ - card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True) - return json.dumps(card_dict, sort_keys=True, separators=(",", ":")) - - -def _base64url_encode(data: bytes | str) -> str: - """Encode data to URL-safe base64 without padding. - - Args: - data: Data to encode. 
- - Returns: - URL-safe base64 encoded string without padding. - """ - if isinstance(data, str): - data = data.encode() - return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") - - -def sign_agent_card( - agent_card: AgentCard, - private_key: str | bytes | SecretStr, - key_id: str | None = None, - algorithm: SigningAlgorithm = "RS256", -) -> AgentCardSignature: - """Sign an AgentCard using JWS (RFC 7515). - - Creates a detached JWS signature for the AgentCard. The signature covers - all fields except the signatures field itself. - - Args: - agent_card: The AgentCard to sign. - private_key: PEM-encoded private key (RSA, EC, or RSA-PSS). - key_id: Optional key identifier for the JWS header (kid claim). - algorithm: Signing algorithm (RS256, ES256, PS256, etc.). - - Returns: - AgentCardSignature with protected header and signature. - - Raises: - jwt.exceptions.InvalidKeyError: If the private key is invalid. - ValueError: If the algorithm is not supported for the key type. - - Example: - >>> signature = sign_agent_card( - ... agent_card, - ... private_key_pem="-----BEGIN PRIVATE KEY-----...", - ... key_id="my-key-id", - ... ) - """ - key_bytes = _normalize_private_key(private_key) - payload = _serialize_agent_card(agent_card) - - protected_header: dict[str, Any] = {"typ": "JWS"} - if key_id: - protected_header["kid"] = key_id - - jws_token = jwt.api_jws.encode( - payload.encode(), - key_bytes, - algorithm=algorithm, - headers=protected_header, - ) - - parts = jws_token.split(".") - protected_b64 = parts[0] - signature_b64 = parts[2] - - header: dict[str, Any] | None = None - if key_id: - header = {"kid": key_id} - - return AgentCardSignature( - protected=protected_b64, - signature=signature_b64, - header=header, - ) - - -def verify_agent_card_signature( - agent_card: AgentCard, - signature: AgentCardSignature, - public_key: str | bytes, - algorithms: list[str] | None = None, -) -> bool: - """Verify an AgentCard JWS signature. 
- - Validates that the signature was created with the corresponding private key - and that the AgentCard content has not been modified. - - Args: - agent_card: The AgentCard to verify. - signature: The AgentCardSignature to validate. - public_key: PEM-encoded public key (RSA, EC, or RSA-PSS). - algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms. - - Returns: - True if signature is valid, False otherwise. - - Example: - >>> is_valid = verify_agent_card_signature( - ... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..." - ... ) - """ - if algorithms is None: - algorithms = [ - "RS256", - "RS384", - "RS512", - "ES256", - "ES384", - "ES512", - "PS256", - "PS384", - "PS512", - ] - - if isinstance(public_key, str): - public_key = public_key.encode() - - payload = _serialize_agent_card(agent_card) - payload_b64 = _base64url_encode(payload) - jws_token = f"{signature.protected}.{payload_b64}.{signature.signature}" - - try: - jwt.api_jws.decode( - jws_token, - public_key, - algorithms=algorithms, - ) - return True - except jwt.InvalidSignatureError: - logger.debug( - "AgentCard signature verification failed", - extra={"reason": "invalid_signature"}, - ) - return False - except jwt.DecodeError as e: - logger.debug( - "AgentCard signature verification failed", - extra={"reason": "decode_error", "error": str(e)}, - ) - return False - except jwt.InvalidAlgorithmError as e: - logger.debug( - "AgentCard signature verification failed", - extra={"reason": "algorithm_error", "error": str(e)}, - ) - return False - - -def get_key_id_from_signature(signature: AgentCardSignature) -> str | None: - """Extract the key ID (kid) from an AgentCardSignature. - - Checks both the unprotected header and the protected header for the kid claim. - - Args: - signature: The AgentCardSignature to extract from. - - Returns: - The key ID if present, None otherwise. 
- """ - if signature.header and "kid" in signature.header: - kid: str = signature.header["kid"] - return kid - - try: - protected = signature.protected - padding_needed = 4 - (len(protected) % 4) - if padding_needed != 4: - protected += "=" * padding_needed - - protected_json = base64.urlsafe_b64decode(protected).decode() - protected_header: dict[str, Any] = json.loads(protected_json) - return protected_header.get("kid") - except (ValueError, json.JSONDecodeError): - return None +from crewai_a2a.utils.agent_card_signing import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/utils/content_type.py b/lib/crewai/src/crewai/a2a/utils/content_type.py index f063fef19..521e0c3f8 100644 --- a/lib/crewai/src/crewai/a2a/utils/content_type.py +++ b/lib/crewai/src/crewai/a2a/utils/content_type.py @@ -1,339 +1,13 @@ -"""Content type negotiation for A2A protocol. +"""Backward-compatibility shim — use ``crewai_a2a.utils.content_type`` instead.""" -This module handles negotiation of input/output MIME types between A2A clients -and servers based on AgentCard capabilities. 
-""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING, Annotated, Final, Literal, cast - -from a2a.types import Part - -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent +import warnings -if TYPE_CHECKING: - from a2a.types import AgentCard, AgentSkill - - -TEXT_PLAIN: Literal["text/plain"] = "text/plain" -APPLICATION_JSON: Literal["application/json"] = "application/json" -IMAGE_PNG: Literal["image/png"] = "image/png" -IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg" -IMAGE_WILDCARD: Literal["image/*"] = "image/*" -APPLICATION_PDF: Literal["application/pdf"] = "application/pdf" -APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = ( - "application/octet-stream" +warnings.warn( + "'crewai.a2a.utils.content_type' has been moved to 'crewai_a2a.utils.content_type'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) -DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [ - TEXT_PLAIN, - APPLICATION_JSON, -] -DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [ - TEXT_PLAIN, - APPLICATION_JSON, -] - - -@dataclass -class NegotiatedContentTypes: - """Result of content type negotiation.""" - - input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"] - output_modes: Annotated[ - list[str], "Negotiated output MIME types the server will produce" - ] - effective_input_modes: Annotated[list[str], "Server's effective input modes"] - effective_output_modes: Annotated[list[str], "Server's effective output modes"] - skill_name: Annotated[ - str | None, "Skill name if negotiation was skill-specific" - ] = None - - -class ContentTypeNegotiationError(Exception): - """Raised when no compatible content types can be negotiated.""" - - def __init__( - self, - client_input_modes: list[str], - 
client_output_modes: list[str], - server_input_modes: list[str], - server_output_modes: list[str], - direction: str = "both", - message: str | None = None, - ) -> None: - self.client_input_modes = client_input_modes - self.client_output_modes = client_output_modes - self.server_input_modes = server_input_modes - self.server_output_modes = server_output_modes - self.direction = direction - - if message is None: - if direction == "input": - message = ( - f"No compatible input content types. " - f"Client supports: {client_input_modes}, " - f"Server accepts: {server_input_modes}" - ) - elif direction == "output": - message = ( - f"No compatible output content types. " - f"Client accepts: {client_output_modes}, " - f"Server produces: {server_output_modes}" - ) - else: - message = ( - f"No compatible content types. " - f"Input - Client: {client_input_modes}, Server: {server_input_modes}. " - f"Output - Client: {client_output_modes}, Server: {server_output_modes}" - ) - - super().__init__(message) - - -def _normalize_mime_type(mime_type: str) -> str: - """Normalize MIME type for comparison (lowercase, strip whitespace).""" - return mime_type.lower().strip() - - -def _mime_types_compatible(client_type: str, server_type: str) -> bool: - """Check if two MIME types are compatible. - - Handles wildcards like image/* matching image/png. 
- """ - client_normalized = _normalize_mime_type(client_type) - server_normalized = _normalize_mime_type(server_type) - - if client_normalized == server_normalized: - return True - - if "*" in client_normalized or "*" in server_normalized: - client_parts = client_normalized.split("/") - server_parts = server_normalized.split("/") - - if len(client_parts) == 2 and len(server_parts) == 2: - type_match = ( - client_parts[0] == server_parts[0] - or client_parts[0] == "*" - or server_parts[0] == "*" - ) - subtype_match = ( - client_parts[1] == server_parts[1] - or client_parts[1] == "*" - or server_parts[1] == "*" - ) - return type_match and subtype_match - - return False - - -def _find_compatible_modes( - client_modes: list[str], server_modes: list[str] -) -> list[str]: - """Find compatible MIME types between client and server. - - Returns modes in client preference order. - """ - compatible = [] - for client_mode in client_modes: - for server_mode in server_modes: - if _mime_types_compatible(client_mode, server_mode): - if "*" in client_mode and "*" not in server_mode: - if server_mode not in compatible: - compatible.append(server_mode) - else: - if client_mode not in compatible: - compatible.append(client_mode) - break - return compatible - - -def _get_effective_modes( - agent_card: AgentCard, - skill_name: str | None = None, -) -> tuple[list[str], list[str], AgentSkill | None]: - """Get effective input/output modes from agent card. - - If skill_name is provided and the skill has custom modes, those are used. - Otherwise, falls back to agent card defaults. 
- """ - skill: AgentSkill | None = None - - if skill_name and agent_card.skills: - for s in agent_card.skills: - if s.name == skill_name or s.id == skill_name: - skill = s - break - - if skill: - input_modes = ( - skill.input_modes if skill.input_modes else agent_card.default_input_modes - ) - output_modes = ( - skill.output_modes - if skill.output_modes - else agent_card.default_output_modes - ) - else: - input_modes = agent_card.default_input_modes - output_modes = agent_card.default_output_modes - - return input_modes, output_modes, skill - - -def negotiate_content_types( - agent_card: AgentCard, - client_input_modes: list[str] | None = None, - client_output_modes: list[str] | None = None, - skill_name: str | None = None, - emit_event: bool = True, - endpoint: str | None = None, - a2a_agent_name: str | None = None, - strict: bool = False, -) -> NegotiatedContentTypes: - """Negotiate content types between client and server. - - Args: - agent_card: The remote agent's card with capability info. - client_input_modes: MIME types the client can send. Defaults to text/plain and application/json. - client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json. - skill_name: Optional skill to use for mode lookup. - emit_event: Whether to emit a content type negotiation event. - endpoint: Agent endpoint (for event metadata). - a2a_agent_name: Agent name (for event metadata). - strict: If True, raises error when no compatible types found. - If False, returns empty lists for incompatible directions. - - Returns: - NegotiatedContentTypes with compatible input and output modes. - - Raises: - ContentTypeNegotiationError: If strict=True and no compatible types found. 
- """ - if client_input_modes is None: - client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy()) - if client_output_modes is None: - client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy()) - - server_input_modes, server_output_modes, skill = _get_effective_modes( - agent_card, skill_name - ) - - compatible_input = _find_compatible_modes(client_input_modes, server_input_modes) - compatible_output = _find_compatible_modes(client_output_modes, server_output_modes) - - if strict: - if not compatible_input and not compatible_output: - raise ContentTypeNegotiationError( - client_input_modes=client_input_modes, - client_output_modes=client_output_modes, - server_input_modes=server_input_modes, - server_output_modes=server_output_modes, - ) - if not compatible_input: - raise ContentTypeNegotiationError( - client_input_modes=client_input_modes, - client_output_modes=client_output_modes, - server_input_modes=server_input_modes, - server_output_modes=server_output_modes, - direction="input", - ) - if not compatible_output: - raise ContentTypeNegotiationError( - client_input_modes=client_input_modes, - client_output_modes=client_output_modes, - server_input_modes=server_input_modes, - server_output_modes=server_output_modes, - direction="output", - ) - - result = NegotiatedContentTypes( - input_modes=compatible_input, - output_modes=compatible_output, - effective_input_modes=server_input_modes, - effective_output_modes=server_output_modes, - skill_name=skill.name if skill else None, - ) - - if emit_event: - crewai_event_bus.emit( - None, - A2AContentTypeNegotiatedEvent( - endpoint=endpoint or agent_card.url, - a2a_agent_name=a2a_agent_name or agent_card.name, - skill_name=skill_name, - client_input_modes=client_input_modes, - client_output_modes=client_output_modes, - server_input_modes=server_input_modes, - server_output_modes=server_output_modes, - negotiated_input_modes=compatible_input, - negotiated_output_modes=compatible_output, - 
negotiation_success=bool(compatible_input and compatible_output), - ), - ) - - return result - - -def validate_content_type( - content_type: str, - allowed_modes: list[str], -) -> bool: - """Validate that a content type is allowed by a list of modes. - - Args: - content_type: The MIME type to validate. - allowed_modes: List of allowed MIME types (may include wildcards). - - Returns: - True if content_type is compatible with any allowed mode. - """ - for mode in allowed_modes: - if _mime_types_compatible(content_type, mode): - return True - return False - - -def get_part_content_type(part: Part) -> str: - """Extract MIME type from an A2A Part. - - Args: - part: A Part object containing TextPart, DataPart, or FilePart. - - Returns: - The MIME type string for this part. - """ - root = part.root - if root.kind == "text": - return TEXT_PLAIN - if root.kind == "data": - return APPLICATION_JSON - if root.kind == "file": - return root.file.mime_type or APPLICATION_OCTET_STREAM - return APPLICATION_OCTET_STREAM - - -def validate_message_parts( - parts: list[Part], - allowed_modes: list[str], -) -> list[str]: - """Validate that all message parts have allowed content types. - - Args: - parts: List of Parts from the incoming message. - allowed_modes: List of allowed MIME types (from default_input_modes). - - Returns: - List of invalid content types found (empty if all valid). 
- """ - invalid_types: list[str] = [] - for part in parts: - content_type = get_part_content_type(part) - if not validate_content_type(content_type, allowed_modes): - if content_type not in invalid_types: - invalid_types.append(content_type) - return invalid_types +from crewai_a2a.utils.content_type import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/utils/delegation.py b/lib/crewai/src/crewai/a2a/utils/delegation.py index cfcf51f36..e8f45a742 100644 --- a/lib/crewai/src/crewai/a2a/utils/delegation.py +++ b/lib/crewai/src/crewai/a2a/utils/delegation.py @@ -1,980 +1,13 @@ -"""A2A delegation utilities for executing tasks on remote agents.""" +"""Backward-compatibility shim — use ``crewai_a2a.utils.delegation`` instead.""" -from __future__ import annotations +import warnings -import asyncio -import base64 -from collections.abc import AsyncIterator, Callable, MutableMapping -from contextlib import asynccontextmanager -import logging -from typing import TYPE_CHECKING, Any, Final, Literal -import uuid -from a2a.client import Client, ClientConfig, ClientFactory -from a2a.types import ( - AgentCard, - FilePart, - FileWithBytes, - Message, - Part, - PushNotificationConfig as A2APushNotificationConfig, - Role, - TextPart, -) -import httpx -from pydantic import BaseModel - -from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth -from crewai.a2a.auth.utils import ( - _auth_store, - configure_auth_client, - validate_auth_against_agent_card, -) -from crewai.a2a.config import ClientTransportConfig, GRPCClientConfig -from crewai.a2a.extensions.registry import ( - ExtensionsMiddleware, - validate_required_extensions, -) -from crewai.a2a.task_helpers import TaskStateResult -from crewai.a2a.types import ( - HANDLER_REGISTRY, - HandlerType, - PartsDict, - PartsMetadataDict, - TransportType, -) -from crewai.a2a.updates import ( - PollingConfig, - PushNotificationConfig, - StreamingHandler, - UpdateConfig, -) -from crewai.a2a.utils.agent_card import ( - 
_afetch_agent_card_cached, - _get_tls_verify, - _prepare_auth_headers, -) -from crewai.a2a.utils.content_type import ( - DEFAULT_CLIENT_OUTPUT_MODES, - negotiate_content_types, -) -from crewai.a2a.utils.transport import ( - NegotiatedTransport, - TransportNegotiationError, - negotiate_transport, -) -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import ( - A2AConversationStartedEvent, - A2ADelegationCompletedEvent, - A2ADelegationStartedEvent, - A2AMessageSentEvent, +warnings.warn( + "'crewai.a2a.utils.delegation' has been moved to 'crewai_a2a.utils.delegation'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -logger = logging.getLogger(__name__) - - -if TYPE_CHECKING: - from a2a.types import Message - - from crewai.a2a.auth.client_schemes import ClientAuthScheme - - -_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC" - - -def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]: - """Convert FileInput dictionary to FilePart objects. - - Args: - input_files: Dictionary mapping names to FileInput objects. - - Returns: - List of Part objects containing FilePart data. - """ - if not input_files: - return [] - - try: - import crewai_files # noqa: F401 - except ImportError: - logger.debug("crewai_files not installed, skipping file parts") - return [] - - parts: list[Part] = [] - for name, file_input in input_files.items(): - content_bytes = file_input.read() - content_base64 = base64.b64encode(content_bytes).decode() - file_with_bytes = FileWithBytes( - bytes=content_base64, - mimeType=file_input.content_type, - name=file_input.filename or name, - ) - parts.append(Part(root=FilePart(file=file_with_bytes))) - - return parts - - -def get_handler(config: UpdateConfig | None) -> HandlerType: - """Get the handler class for a given update config. - - Args: - config: Update mechanism configuration. 
- - Returns: - Handler class for the config type, defaults to StreamingHandler. - """ - if config is None: - return StreamingHandler - return HANDLER_REGISTRY.get(type(config), StreamingHandler) - - -def execute_a2a_delegation( - endpoint: str, - auth: ClientAuthScheme | None, - timeout: int, - task_description: str, - context: str | None = None, - context_id: str | None = None, - task_id: str | None = None, - reference_task_ids: list[str] | None = None, - metadata: dict[str, Any] | None = None, - extensions: dict[str, Any] | None = None, - conversation_history: list[Message] | None = None, - agent_id: str | None = None, - agent_role: Role | None = None, - agent_branch: Any | None = None, - response_model: type[BaseModel] | None = None, - turn_number: int | None = None, - updates: UpdateConfig | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, - skill_id: str | None = None, - client_extensions: list[str] | None = None, - transport: ClientTransportConfig | None = None, - accepted_output_modes: list[str] | None = None, - input_files: dict[str, Any] | None = None, -) -> TaskStateResult: - """Execute a task delegation to a remote A2A agent synchronously. - - WARNING: This function blocks the entire thread by creating and running a new - event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts - for better performance and resource efficiency. - - This is a synchronous wrapper around aexecute_a2a_delegation that creates a - new event loop to run the async implementation. It is provided for compatibility - with synchronous code paths only. - - Args: - endpoint: A2A agent endpoint URL (AgentCard URL). - auth: Optional ClientAuthScheme for authentication. - timeout: Request timeout in seconds. - task_description: The task to delegate. - context: Optional context information. - context_id: Context ID for correlating messages/tasks. - task_id: Specific task identifier. - reference_task_ids: List of related task IDs. 
- metadata: Additional metadata. - extensions: Protocol extensions for custom fields. - conversation_history: Previous Message objects from conversation. - agent_id: Agent identifier for logging. - agent_role: Role of the CrewAI agent delegating the task. - agent_branch: Optional agent tree branch for logging. - response_model: Optional Pydantic model for structured outputs. - turn_number: Optional turn number for multi-turn conversations. - updates: Update mechanism config from A2AConfig.updates. - from_task: Optional CrewAI Task object for event metadata. - from_agent: Optional CrewAI Agent object for event metadata. - skill_id: Optional skill ID to target a specific agent capability. - client_extensions: A2A protocol extension URIs the client supports. - transport: Transport configuration (preferred, supported transports, gRPC settings). - accepted_output_modes: MIME types the client can accept in responses. - input_files: Optional dictionary of files to send to remote agent. - - Returns: - TaskStateResult with status, result/error, history, and agent_card. - - Raises: - RuntimeError: If called from an async context with a running event loop. - """ - try: - asyncio.get_running_loop() - raise RuntimeError( - "execute_a2a_delegation() cannot be called from an async context. " - "Use 'await aexecute_a2a_delegation()' instead." 
- ) - except RuntimeError as e: - if "no running event loop" not in str(e).lower(): - raise - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete( - aexecute_a2a_delegation( - endpoint=endpoint, - auth=auth, - timeout=timeout, - task_description=task_description, - context=context, - context_id=context_id, - task_id=task_id, - reference_task_ids=reference_task_ids, - metadata=metadata, - extensions=extensions, - conversation_history=conversation_history, - agent_id=agent_id, - agent_role=agent_role, - agent_branch=agent_branch, - response_model=response_model, - turn_number=turn_number, - updates=updates, - from_task=from_task, - from_agent=from_agent, - skill_id=skill_id, - client_extensions=client_extensions, - transport=transport, - accepted_output_modes=accepted_output_modes, - input_files=input_files, - ) - ) - finally: - try: - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - loop.close() - - -async def aexecute_a2a_delegation( - endpoint: str, - auth: ClientAuthScheme | None, - timeout: int, - task_description: str, - context: str | None = None, - context_id: str | None = None, - task_id: str | None = None, - reference_task_ids: list[str] | None = None, - metadata: dict[str, Any] | None = None, - extensions: dict[str, Any] | None = None, - conversation_history: list[Message] | None = None, - agent_id: str | None = None, - agent_role: Role | None = None, - agent_branch: Any | None = None, - response_model: type[BaseModel] | None = None, - turn_number: int | None = None, - updates: UpdateConfig | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, - skill_id: str | None = None, - client_extensions: list[str] | None = None, - transport: ClientTransportConfig | None = None, - accepted_output_modes: list[str] | None = None, - input_files: dict[str, Any] | None = None, -) -> TaskStateResult: - """Execute a task delegation to a remote A2A agent asynchronously. 
- - Native async implementation with multi-turn support. Use this when running - in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()). - - Args: - endpoint: A2A agent endpoint URL. - auth: Optional ClientAuthScheme for authentication. - timeout: Request timeout in seconds. - task_description: The task to delegate. - context: Optional context information. - context_id: Context ID for correlating messages/tasks. - task_id: Specific task identifier. - reference_task_ids: List of related task IDs. - metadata: Additional metadata. - extensions: Protocol extensions for custom fields. - conversation_history: Previous Message objects from conversation. - agent_id: Agent identifier for logging. - agent_role: Role of the CrewAI agent delegating the task. - agent_branch: Optional agent tree branch for logging. - response_model: Optional Pydantic model for structured outputs. - turn_number: Optional turn number for multi-turn conversations. - updates: Update mechanism config from A2AConfig.updates. - from_task: Optional CrewAI Task object for event metadata. - from_agent: Optional CrewAI Agent object for event metadata. - skill_id: Optional skill ID to target a specific agent capability. - client_extensions: A2A protocol extension URIs the client supports. - transport: Transport configuration (preferred, supported transports, gRPC settings). - accepted_output_modes: MIME types the client can accept in responses. - input_files: Optional dictionary of files to send to remote agent. - - Returns: - TaskStateResult with status, result/error, history, and agent_card. 
- """ - if conversation_history is None: - conversation_history = [] - - is_multiturn = len(conversation_history) > 0 - if turn_number is None: - turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1 - - try: - result = await _aexecute_a2a_delegation_impl( - endpoint=endpoint, - auth=auth, - timeout=timeout, - task_description=task_description, - context=context, - context_id=context_id, - task_id=task_id, - reference_task_ids=reference_task_ids, - metadata=metadata, - extensions=extensions, - conversation_history=conversation_history, - is_multiturn=is_multiturn, - turn_number=turn_number, - agent_branch=agent_branch, - agent_id=agent_id, - agent_role=agent_role, - response_model=response_model, - updates=updates, - from_task=from_task, - from_agent=from_agent, - skill_id=skill_id, - client_extensions=client_extensions, - transport=transport, - accepted_output_modes=accepted_output_modes, - input_files=input_files, - ) - except Exception as e: - crewai_event_bus.emit( - agent_branch, - A2ADelegationCompletedEvent( - status="failed", - result=None, - error=str(e), - context_id=context_id, - is_multiturn=is_multiturn, - endpoint=endpoint, - metadata=metadata, - extensions=list(extensions.keys()) if extensions else None, - from_task=from_task, - from_agent=from_agent, - ), - ) - raise - - agent_card_data = result.get("agent_card") - crewai_event_bus.emit( - agent_branch, - A2ADelegationCompletedEvent( - status=result["status"], - result=result.get("result"), - error=result.get("error"), - context_id=context_id, - is_multiturn=is_multiturn, - endpoint=endpoint, - a2a_agent_name=result.get("a2a_agent_name"), - agent_card=agent_card_data, - provider=agent_card_data.get("provider") if agent_card_data else None, - metadata=metadata, - extensions=list(extensions.keys()) if extensions else None, - from_task=from_task, - from_agent=from_agent, - ), - ) - - return result - - -async def _aexecute_a2a_delegation_impl( - endpoint: str, - auth: 
ClientAuthScheme | None, - timeout: int, - task_description: str, - context: str | None, - context_id: str | None, - task_id: str | None, - reference_task_ids: list[str] | None, - metadata: dict[str, Any] | None, - extensions: dict[str, Any] | None, - conversation_history: list[Message], - is_multiturn: bool, - turn_number: int, - agent_branch: Any | None, - agent_id: str | None, - agent_role: str | None, - response_model: type[BaseModel] | None, - updates: UpdateConfig | None, - from_task: Any | None = None, - from_agent: Any | None = None, - skill_id: str | None = None, - client_extensions: list[str] | None = None, - transport: ClientTransportConfig | None = None, - accepted_output_modes: list[str] | None = None, - input_files: dict[str, Any] | None = None, -) -> TaskStateResult: - """Internal async implementation of A2A delegation.""" - if transport is None: - transport = ClientTransportConfig() - if auth: - auth_data = auth.model_dump_json( - exclude={ - "_access_token", - "_token_expires_at", - "_refresh_token", - "_authorization_callback", - } - ) - auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data) - else: - auth_hash = _auth_store.compute_key("none", endpoint) - _auth_store.set(auth_hash, auth) - agent_card = await _afetch_agent_card_cached( - endpoint=endpoint, auth_hash=auth_hash, timeout=timeout - ) - - validate_auth_against_agent_card(agent_card, auth) - - unsupported_exts = validate_required_extensions(agent_card, client_extensions) - if unsupported_exts: - ext_uris = [ext.uri for ext in unsupported_exts] - raise ValueError( - f"Agent requires extensions not supported by client: {ext_uris}" - ) - - negotiated: NegotiatedTransport | None = None - effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT - effective_url = endpoint - - client_transports: list[str] = ( - list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT] - ) - - try: - negotiated = negotiate_transport( - agent_card=agent_card, - 
client_supported_transports=client_transports, - client_preferred_transport=transport.preferred, - endpoint=endpoint, - a2a_agent_name=agent_card.name, - ) - effective_transport = negotiated.transport # type: ignore[assignment] - effective_url = negotiated.url - except TransportNegotiationError as e: - logger.warning( - "Transport negotiation failed, using fallback", - extra={ - "error": str(e), - "fallback_transport": effective_transport, - "fallback_url": effective_url, - "endpoint": endpoint, - "client_transports": client_transports, - "server_transports": [ - iface.transport for iface in agent_card.additional_interfaces or [] - ] - + [agent_card.preferred_transport or "JSONRPC"], - }, - ) - - effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy() - - content_negotiated = negotiate_content_types( - agent_card=agent_card, - client_output_modes=accepted_output_modes, - skill_name=skill_id, - endpoint=endpoint, - a2a_agent_name=agent_card.name, - ) - if content_negotiated.output_modes: - effective_output_modes = content_negotiated.output_modes - - headers, _ = await _prepare_auth_headers(auth, timeout) - - a2a_agent_name = None - if agent_card.name: - a2a_agent_name = agent_card.name - - agent_card_dict = agent_card.model_dump(exclude_none=True) - crewai_event_bus.emit( - agent_branch, - A2ADelegationStartedEvent( - endpoint=endpoint, - task_description=task_description, - agent_id=agent_id or endpoint, - context_id=context_id, - is_multiturn=is_multiturn, - turn_number=turn_number, - a2a_agent_name=a2a_agent_name, - agent_card=agent_card_dict, - protocol_version=agent_card.protocol_version, - provider=agent_card_dict.get("provider"), - skill_id=skill_id, - metadata=metadata, - extensions=list(extensions.keys()) if extensions else None, - from_task=from_task, - from_agent=from_agent, - ), - ) - - if turn_number == 1: - agent_id_for_event = agent_id or endpoint - crewai_event_bus.emit( - agent_branch, - A2AConversationStartedEvent( - 
agent_id=agent_id_for_event, - endpoint=endpoint, - context_id=context_id, - a2a_agent_name=a2a_agent_name, - agent_card=agent_card_dict, - protocol_version=agent_card.protocol_version, - provider=agent_card_dict.get("provider"), - skill_id=skill_id, - reference_task_ids=reference_task_ids, - metadata=metadata, - extensions=list(extensions.keys()) if extensions else None, - from_task=from_task, - from_agent=from_agent, - ), - ) - - message_parts = [] - - if context: - message_parts.append(f"Context:\n{context}\n\n") - message_parts.append(f"{task_description}") - message_text = "".join(message_parts) - - if is_multiturn and conversation_history and not task_id: - if first_task_id := conversation_history[0].task_id: - task_id = first_task_id - - parts: PartsDict = {"text": message_text} - if response_model: - parts.update( - { - "metadata": PartsMetadataDict( - mimeType="application/json", - schema=response_model.model_json_schema(), - ) - } - ) - - message_metadata = metadata.copy() if metadata else {} - if skill_id: - message_metadata["skill_id"] = skill_id - - parts_list: list[Part] = [Part(root=TextPart(**parts))] - parts_list.extend(_create_file_parts(input_files)) - - message = Message( - role=Role.user, - message_id=str(uuid.uuid4()), - parts=parts_list, - context_id=context_id, - task_id=task_id, - reference_task_ids=reference_task_ids, - metadata=message_metadata if message_metadata else None, - extensions=extensions, - ) - - new_messages: list[Message] = [*conversation_history, message] - crewai_event_bus.emit( - None, - A2AMessageSentEvent( - message=message_text, - turn_number=turn_number, - context_id=context_id, - message_id=message.message_id, - is_multiturn=is_multiturn, - agent_role=agent_role, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - skill_id=skill_id, - metadata=message_metadata if message_metadata else None, - extensions=list(extensions.keys()) if extensions else None, - from_task=from_task, - from_agent=from_agent, - ), - ) - - 
handler = get_handler(updates) - use_polling = isinstance(updates, PollingConfig) - - handler_kwargs: dict[str, Any] = { - "turn_number": turn_number, - "is_multiturn": is_multiturn, - "agent_role": agent_role, - "context_id": context_id, - "task_id": task_id, - "endpoint": endpoint, - "agent_branch": agent_branch, - "a2a_agent_name": a2a_agent_name, - "from_task": from_task, - "from_agent": from_agent, - } - - if isinstance(updates, PollingConfig): - handler_kwargs.update( - { - "polling_interval": updates.interval, - "polling_timeout": updates.timeout or float(timeout), - "history_length": updates.history_length, - "max_polls": updates.max_polls, - } - ) - elif isinstance(updates, PushNotificationConfig): - handler_kwargs.update( - { - "config": updates, - "result_store": updates.result_store, - "polling_timeout": updates.timeout or float(timeout), - "polling_interval": updates.interval, - } - ) - - push_config_for_client = ( - updates if isinstance(updates, PushNotificationConfig) else None - ) - - use_streaming = not use_polling and push_config_for_client is None - - client_agent_card = agent_card - if effective_url != agent_card.url: - client_agent_card = agent_card.model_copy(update={"url": effective_url}) - - async with _create_a2a_client( - agent_card=client_agent_card, - transport_protocol=effective_transport, - timeout=timeout, - headers=headers, - streaming=use_streaming, - auth=auth, - use_polling=use_polling, - push_notification_config=push_config_for_client, - client_extensions=client_extensions, - accepted_output_modes=effective_output_modes, # type: ignore[arg-type] - grpc_config=transport.grpc, - ) as client: - result = await handler.execute( - client=client, - message=message, - new_messages=new_messages, - agent_card=agent_card, - **handler_kwargs, - ) - result["a2a_agent_name"] = a2a_agent_name - result["agent_card"] = agent_card.model_dump(exclude_none=True) - return result - - -def _normalize_grpc_metadata( - metadata: tuple[tuple[str, str], 
...] | None, -) -> tuple[tuple[str, str], ...] | None: - """Lowercase all gRPC metadata keys. - - gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK) - use mixed-case headers like 'X-A2A-Extensions'. This normalizes them. - """ - if metadata is None: - return None - return tuple((key.lower(), value) for key, value in metadata) - - -def _create_grpc_interceptors( - auth_metadata: list[tuple[str, str]] | None = None, -) -> list[Any]: - """Create gRPC interceptors for metadata normalization and auth injection. - - Args: - auth_metadata: Optional auth metadata to inject into all calls. - Used for insecure channels that need auth (non-localhost without TLS). - - Returns a list of interceptors that lowercase metadata keys for gRPC - compatibility. Must be called after grpc is imported. - """ - import grpc.aio # type: ignore[import-untyped] - - def _merge_metadata( - existing: tuple[tuple[str, str], ...] | None, - auth: list[tuple[str, str]] | None, - ) -> tuple[tuple[str, str], ...] 
| None: - """Merge existing metadata with auth metadata and normalize keys.""" - merged: list[tuple[str, str]] = [] - if existing: - merged.extend(existing) - if auth: - merged.extend(auth) - if not merged: - return None - return tuple((key.lower(), value) for key, value in merged) - - def _inject_metadata(client_call_details: Any) -> Any: - """Inject merged metadata into call details.""" - return client_call_details._replace( - metadata=_merge_metadata(client_call_details.metadata, auth_metadata) - ) - - class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported] - """Interceptor for unary-unary calls that injects auth metadata.""" - - async def intercept_unary_unary( # type: ignore[no-untyped-def] - self, continuation, client_call_details, request - ): - """Intercept unary-unary call and inject metadata.""" - return await continuation(_inject_metadata(client_call_details), request) - - class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported] - """Interceptor for unary-stream calls that injects auth metadata.""" - - async def intercept_unary_stream( # type: ignore[no-untyped-def] - self, continuation, client_call_details, request - ): - """Intercept unary-stream call and inject metadata.""" - return await continuation(_inject_metadata(client_call_details), request) - - class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported] - """Interceptor for stream-unary calls that injects auth metadata.""" - - async def intercept_stream_unary( # type: ignore[no-untyped-def] - self, continuation, client_call_details, request_iterator - ): - """Intercept stream-unary call and inject metadata.""" - return await continuation( - _inject_metadata(client_call_details), request_iterator - ) - - class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor): # type: ignore[misc,no-any-unimported] - """Interceptor for stream-stream calls that 
injects auth metadata.""" - - async def intercept_stream_stream( # type: ignore[no-untyped-def] - self, continuation, client_call_details, request_iterator - ): - """Intercept stream-stream call and inject metadata.""" - return await continuation( - _inject_metadata(client_call_details), request_iterator - ) - - return [ - MetadataUnaryUnary(), - MetadataUnaryStream(), - MetadataStreamUnary(), - MetadataStreamStream(), - ] - - -def _create_grpc_channel_factory( - grpc_config: GRPCClientConfig, - auth: ClientAuthScheme | None = None, -) -> Callable[[str], Any]: - """Create a gRPC channel factory with the given configuration. - - Args: - grpc_config: gRPC client configuration with channel options. - auth: Optional ClientAuthScheme for TLS and auth configuration. - - Returns: - A callable that creates gRPC channels from URLs. - """ - try: - import grpc - except ImportError as e: - raise ImportError( - "gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]" - ) from e - - auth_metadata: list[tuple[str, str]] = [] - - if auth is not None: - from crewai.a2a.auth.client_schemes import ( - APIKeyAuth, - BearerTokenAuth, - HTTPBasicAuth, - HTTPDigestAuth, - OAuth2AuthorizationCode, - OAuth2ClientCredentials, - ) - - if isinstance(auth, HTTPDigestAuth): - raise ValueError( - "HTTPDigestAuth is not supported with gRPC transport. " - "Digest authentication requires HTTP challenge-response flow. " - "Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead." - ) - if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"): - raise ValueError( - f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. " - "gRPC only supports header-based authentication. " - "Use APIKeyAuth with location='header' instead." 
- ) - - if isinstance(auth, BearerTokenAuth): - auth_metadata.append(("authorization", f"Bearer {auth.token}")) - elif isinstance(auth, HTTPBasicAuth): - import base64 - - basic_credentials = f"{auth.username}:{auth.password}" - encoded = base64.b64encode(basic_credentials.encode()).decode() - auth_metadata.append(("authorization", f"Basic {encoded}")) - elif isinstance(auth, APIKeyAuth) and auth.location == "header": - header_name = auth.name.lower() - auth_metadata.append((header_name, auth.api_key)) - elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)): - if auth._access_token: - auth_metadata.append(("authorization", f"Bearer {auth._access_token}")) - - def factory(url: str) -> Any: - """Create a gRPC channel for the given URL.""" - target = url - use_tls = False - - if url.startswith("grpcs://"): - target = url[8:] - use_tls = True - elif url.startswith("grpc://"): - target = url[7:] - elif url.startswith("https://"): - target = url[8:] - use_tls = True - elif url.startswith("http://"): - target = url[7:] - - options: list[tuple[str, Any]] = [] - if grpc_config.max_send_message_length is not None: - options.append( - ("grpc.max_send_message_length", grpc_config.max_send_message_length) - ) - if grpc_config.max_receive_message_length is not None: - options.append( - ( - "grpc.max_receive_message_length", - grpc_config.max_receive_message_length, - ) - ) - if grpc_config.keepalive_time_ms is not None: - options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms)) - if grpc_config.keepalive_timeout_ms is not None: - options.append( - ("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms) - ) - - channel_credentials = None - if auth and hasattr(auth, "tls") and auth.tls: - channel_credentials = auth.tls.get_grpc_credentials() - elif use_tls: - channel_credentials = grpc.ssl_channel_credentials() - - if channel_credentials and auth_metadata: - - class AuthMetadataPlugin(grpc.AuthMetadataPlugin): # type: 
ignore[misc,no-any-unimported] - """gRPC auth metadata plugin that adds auth headers as metadata.""" - - def __init__(self, metadata: list[tuple[str, str]]) -> None: - self._metadata = tuple(metadata) - - def __call__( # type: ignore[no-any-unimported] - self, - context: grpc.AuthMetadataContext, - callback: grpc.AuthMetadataPluginCallback, - ) -> None: - callback(self._metadata, None) - - call_creds = grpc.metadata_call_credentials( - AuthMetadataPlugin(auth_metadata) - ) - credentials = grpc.composite_channel_credentials( - channel_credentials, call_creds - ) - interceptors = _create_grpc_interceptors() - return grpc.aio.secure_channel( - target, credentials, options=options or None, interceptors=interceptors - ) - if channel_credentials: - interceptors = _create_grpc_interceptors() - return grpc.aio.secure_channel( - target, - channel_credentials, - options=options or None, - interceptors=interceptors, - ) - interceptors = _create_grpc_interceptors( - auth_metadata=auth_metadata if auth_metadata else None - ) - return grpc.aio.insecure_channel( - target, options=options or None, interceptors=interceptors - ) - - return factory - - -@asynccontextmanager -async def _create_a2a_client( - agent_card: AgentCard, - transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"], - timeout: int, - headers: MutableMapping[str, str], - streaming: bool, - auth: ClientAuthScheme | None = None, - use_polling: bool = False, - push_notification_config: PushNotificationConfig | None = None, - client_extensions: list[str] | None = None, - accepted_output_modes: list[str] | None = None, - grpc_config: GRPCClientConfig | None = None, -) -> AsyncIterator[Client]: - """Create and configure an A2A client. - - Args: - agent_card: The A2A agent card. - transport_protocol: Transport protocol to use. - timeout: Request timeout in seconds. - headers: HTTP headers (already with auth applied). - streaming: Enable streaming responses. - auth: Optional ClientAuthScheme for client configuration. 
- use_polling: Enable polling mode. - push_notification_config: Optional push notification config. - client_extensions: A2A protocol extension URIs to declare support for. - accepted_output_modes: MIME types the client can accept in responses. - grpc_config: Optional gRPC client configuration. - - Yields: - Configured A2A client instance. - """ - verify = _get_tls_verify(auth) - async with httpx.AsyncClient( - timeout=timeout, - headers=headers, - verify=verify, - ) as httpx_client: - if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)): - configure_auth_client(auth, httpx_client) - - push_configs: list[A2APushNotificationConfig] = [] - if push_notification_config is not None: - push_configs.append( - A2APushNotificationConfig( - url=str(push_notification_config.url), - id=push_notification_config.id, - token=push_notification_config.token, - authentication=push_notification_config.authentication, - ) - ) - - grpc_channel_factory = None - if transport_protocol == "GRPC": - grpc_channel_factory = _create_grpc_channel_factory( - grpc_config or GRPCClientConfig(), - auth=auth, - ) - - config = ClientConfig( - httpx_client=httpx_client, - supported_transports=[transport_protocol], - streaming=streaming and not use_polling, - polling=use_polling, - accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type] - push_notification_configs=push_configs, - grpc_channel_factory=grpc_channel_factory, - ) - - factory = ClientFactory(config) - client = factory.create(agent_card) - - if client_extensions: - await client.add_request_middleware(ExtensionsMiddleware(client_extensions)) - - yield client +from crewai_a2a.utils.delegation import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/utils/logging.py b/lib/crewai/src/crewai/a2a/utils/logging.py index 585d1d8f3..a654df7a9 100644 --- a/lib/crewai/src/crewai/a2a/utils/logging.py +++ b/lib/crewai/src/crewai/a2a/utils/logging.py @@ -1,131 +1,13 @@ -"""Structured JSON logging 
utilities for A2A module.""" +"""Backward-compatibility shim — use ``crewai_a2a.utils.logging`` instead.""" -from __future__ import annotations - -from contextvars import ContextVar -from datetime import datetime, timezone -import json -import logging -from typing import Any +import warnings -_log_context: ContextVar[dict[str, Any] | None] = ContextVar( - "log_context", default=None +warnings.warn( + "'crewai.a2a.utils.logging' has been moved to 'crewai_a2a.utils.logging'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -class JSONFormatter(logging.Formatter): - """JSON formatter for structured logging. - - Outputs logs as JSON with consistent fields for log aggregators. - """ - - def format(self, record: logging.LogRecord) -> str: - """Format log record as JSON string.""" - log_data: dict[str, Any] = { - "timestamp": datetime.now(timezone.utc).isoformat(), - "level": record.levelname, - "logger": record.name, - "message": record.getMessage(), - } - - if record.exc_info: - log_data["exception"] = self.formatException(record.exc_info) - - context = _log_context.get() - if context is not None: - log_data.update(context) - - if hasattr(record, "task_id"): - log_data["task_id"] = record.task_id - if hasattr(record, "context_id"): - log_data["context_id"] = record.context_id - if hasattr(record, "agent"): - log_data["agent"] = record.agent - if hasattr(record, "endpoint"): - log_data["endpoint"] = record.endpoint - if hasattr(record, "extension"): - log_data["extension"] = record.extension - if hasattr(record, "error"): - log_data["error"] = record.error - - for key, value in record.__dict__.items(): - if key.startswith("_") or key in ( - "name", - "msg", - "args", - "created", - "filename", - "funcName", - "levelname", - "levelno", - "lineno", - "module", - "msecs", - "pathname", - "process", - "processName", - "relativeCreated", - "stack_info", - "exc_info", - "exc_text", - "thread", - "threadName", - 
"taskName", - "message", - ): - continue - if key not in log_data: - log_data[key] = value - - return json.dumps(log_data, default=str) - - -class LogContext: - """Context manager for adding fields to all logs within a scope. - - Example: - with LogContext(task_id="abc", context_id="xyz"): - logger.info("Processing task") # Includes task_id and context_id - """ - - def __init__(self, **fields: Any) -> None: - self._fields = fields - self._token: Any = None - - def __enter__(self) -> LogContext: - current = _log_context.get() or {} - new_context = {**current, **self._fields} - self._token = _log_context.set(new_context) - return self - - def __exit__(self, *args: Any) -> None: - _log_context.reset(self._token) - - -def configure_json_logging(logger_name: str = "crewai.a2a") -> None: - """Configure JSON logging for the A2A module. - - Args: - logger_name: Logger name to configure. - """ - logger = logging.getLogger(logger_name) - - for handler in logger.handlers[:]: - logger.removeHandler(handler) - - handler = logging.StreamHandler() - handler.setFormatter(JSONFormatter()) - logger.addHandler(handler) - - -def get_logger(name: str) -> logging.Logger: - """Get a logger configured for structured JSON output. - - Args: - name: Logger name. - - Returns: - Configured logger instance. 
- """ - return logging.getLogger(name) +from crewai_a2a.utils.logging import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/utils/response_model.py b/lib/crewai/src/crewai/a2a/utils/response_model.py index 4e65ef2b7..6cf9bc12f 100644 --- a/lib/crewai/src/crewai/a2a/utils/response_model.py +++ b/lib/crewai/src/crewai/a2a/utils/response_model.py @@ -1,101 +1,13 @@ -"""Response model utilities for A2A agent interactions.""" +"""Backward-compatibility shim — use ``crewai_a2a.utils.response_model`` instead.""" -from __future__ import annotations - -from typing import TypeAlias - -from pydantic import BaseModel, Field, create_model - -from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig -from crewai.types.utils import create_literals_from_strings +import warnings -A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig -A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig +warnings.warn( + "'crewai.a2a.utils.response_model' has been moved to 'crewai_a2a.utils.response_model'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) - -def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None: - """Create a dynamic AgentResponse model with Literal types for agent IDs. - - Args: - agent_ids: List of available A2A agent IDs. - - Returns: - Dynamically created Pydantic model with Literal-constrained a2a_ids field, - or None if agent_ids is empty. - """ - if not agent_ids: - return None - - DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806 - - return create_model( - "AgentResponse", - a2a_ids=( - tuple[DynamicLiteral, ...], # type: ignore[valid-type] - Field( - default_factory=tuple, - max_length=len(agent_ids), - description="A2A agent IDs to delegate to.", - ), - ), - message=( - str, - Field( - description="The message content. If is_a2a=true, this is sent to the A2A agent. 
If is_a2a=false, this is your final answer ending the conversation." - ), - ), - is_a2a=( - bool, - Field( - description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately." - ), - ), - __base__=BaseModel, - ) - - -def extract_a2a_agent_ids_from_config( - a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None, -) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]: - """Extract A2A agent IDs from A2A configuration. - - Filters out A2AServerConfig since it doesn't have an endpoint for delegation. - - Args: - a2a_config: A2A configuration (any type). - - Returns: - Tuple of client A2A configs list and agent endpoint IDs. - """ - if a2a_config is None: - return [], () - - configs: list[A2AConfigTypes] - if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)): - configs = [a2a_config] - else: - configs = a2a_config - - # Filter to only client configs (those with endpoint) - client_configs: list[A2AClientConfigTypes] = [ - config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig)) - ] - - return client_configs, tuple(config.endpoint for config in client_configs) - - -def get_a2a_agents_and_response_model( - a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None, -) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]: - """Get A2A agent configs and response model. - - Args: - a2a_config: A2A configuration (any type). - - Returns: - Tuple of client A2A configs and response model. 
- """ - a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config) - - return a2a_agents, create_agent_response_model(agent_ids) +from crewai_a2a.utils.response_model import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/utils/task.py b/lib/crewai/src/crewai/a2a/utils/task.py index d73556875..1a0ff956b 100644 --- a/lib/crewai/src/crewai/a2a/utils/task.py +++ b/lib/crewai/src/crewai/a2a/utils/task.py @@ -1,584 +1,13 @@ -"""A2A task utilities for server-side task management.""" +"""Backward-compatibility shim — use ``crewai_a2a.utils.task`` instead.""" -from __future__ import annotations - -import asyncio -import base64 -from collections.abc import Callable, Coroutine -from datetime import datetime -from functools import wraps -import json -import logging -import os -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast -from urllib.parse import urlparse - -from a2a.server.agent_execution import RequestContext -from a2a.server.events import EventQueue -from a2a.types import ( - Artifact, - FileWithBytes, - FileWithUri, - InternalError, - InvalidParamsError, - Message, - Part, - Task as A2ATask, - TaskState, - TaskStatus, - TaskStatusUpdateEvent, -) -from a2a.utils import ( - get_data_parts, - get_file_parts, - new_agent_text_message, - new_data_artifact, - new_text_artifact, -) -from a2a.utils.errors import ServerError -from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped] -from pydantic import BaseModel - -from crewai.a2a.utils.agent_card import _get_server_config -from crewai.a2a.utils.content_type import validate_message_parts -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import ( - A2AServerTaskCanceledEvent, - A2AServerTaskCompletedEvent, - A2AServerTaskFailedEvent, - A2AServerTaskStartedEvent, -) -from crewai.task import Task -from crewai.utilities.pydantic_schema_utils import create_model_from_schema +import warnings -if TYPE_CHECKING: - 
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry - from crewai.agent import Agent - - -logger = logging.getLogger(__name__) - -P = ParamSpec("P") -T = TypeVar("T") - - -class RedisCacheConfig(TypedDict, total=False): - """Configuration for aiocache Redis backend.""" - - cache: str - endpoint: str - port: int - db: int - password: str - - -def _parse_redis_url(url: str) -> RedisCacheConfig: - """Parse a Redis URL into aiocache configuration. - - Args: - url: Redis connection URL (e.g., redis://localhost:6379/0). - - Returns: - Configuration dict for aiocache.RedisCache. - """ - parsed = urlparse(url) - config: RedisCacheConfig = { - "cache": "aiocache.RedisCache", - "endpoint": parsed.hostname or "localhost", - "port": parsed.port or 6379, - } - if parsed.path and parsed.path != "/": - try: - config["db"] = int(parsed.path.lstrip("/")) - except ValueError: - pass - if parsed.password: - config["password"] = parsed.password - return config - - -_redis_url = os.environ.get("REDIS_URL") - -caches.set_config( - { - "default": _parse_redis_url(_redis_url) - if _redis_url - else { - "cache": "aiocache.SimpleMemoryCache", - } - } +warnings.warn( + "'crewai.a2a.utils.task' has been moved to 'crewai_a2a.utils.task'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) - -def cancellable( - fn: Callable[P, Coroutine[Any, Any, T]], -) -> Callable[P, Coroutine[Any, Any, T]]: - """Decorator that enables cancellation for A2A task execution. - - Runs a cancellation watcher concurrently with the wrapped function. - When a cancel event is published, the execution is cancelled. - - Args: - fn: The async function to wrap. - - Returns: - Wrapped function with cancellation support. 
- """ - - @wraps(fn) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - """Wrap function with cancellation monitoring.""" - context: RequestContext | None = None - for arg in args: - if isinstance(arg, RequestContext): - context = arg - break - if context is None: - context = cast(RequestContext | None, kwargs.get("context")) - - if context is None: - return await fn(*args, **kwargs) - - task_id = context.task_id - cache = caches.get("default") - - async def poll_for_cancel() -> bool: - """Poll cache for cancellation flag.""" - while True: - if await cache.get(f"cancel:{task_id}"): - return True - await asyncio.sleep(0.1) - - async def watch_for_cancel() -> bool: - """Watch for cancellation events via pub/sub or polling.""" - if isinstance(cache, SimpleMemoryCache): - return await poll_for_cancel() - - try: - client = cache.client - pubsub = client.pubsub() - await pubsub.subscribe(f"cancel:{task_id}") - async for message in pubsub.listen(): - if message["type"] == "message": - return True - except (OSError, ConnectionError) as e: - logger.warning( - "Cancel watcher Redis error, falling back to polling", - extra={"task_id": task_id, "error": str(e)}, - ) - return await poll_for_cancel() - return False - - execute_task = asyncio.create_task(fn(*args, **kwargs)) - cancel_watch = asyncio.create_task(watch_for_cancel()) - - try: - done, _ = await asyncio.wait( - [execute_task, cancel_watch], - return_when=asyncio.FIRST_COMPLETED, - ) - - if cancel_watch in done: - execute_task.cancel() - try: - await execute_task - except asyncio.CancelledError: - pass - raise asyncio.CancelledError(f"Task {task_id} was cancelled") - cancel_watch.cancel() - return execute_task.result() - finally: - await cache.delete(f"cancel:{task_id}") - - return wrapper - - -def _convert_a2a_files_to_file_inputs( - a2a_files: list[FileWithBytes | FileWithUri], -) -> dict[str, Any]: - """Convert a2a file types to crewai FileInput dict. 
- - Args: - a2a_files: List of FileWithBytes or FileWithUri from a2a SDK. - - Returns: - Dictionary mapping file names to FileInput objects. - """ - try: - from crewai_files import File, FileBytes - except ImportError: - logger.debug("crewai_files not installed, returning empty file dict") - return {} - - file_dict: dict[str, Any] = {} - for idx, a2a_file in enumerate(a2a_files): - if isinstance(a2a_file, FileWithBytes): - file_bytes = base64.b64decode(a2a_file.bytes) - name = a2a_file.name or f"file_{idx}" - file_source = FileBytes(data=file_bytes, filename=a2a_file.name) - file_dict[name] = File(source=file_source) - elif isinstance(a2a_file, FileWithUri): - name = a2a_file.name or f"file_{idx}" - file_dict[name] = File(source=a2a_file.uri) - - return file_dict - - -def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None: - """Extract response schema from message parts metadata. - - The client may include a JSON schema in TextPart metadata to specify - the expected response format (see delegation.py line 463). - - Args: - parts: List of message parts. - - Returns: - JSON schema dict if found, None otherwise. - """ - for part in parts: - if part.root.kind == "text" and part.root.metadata: - schema = part.root.metadata.get("schema") - if schema and isinstance(schema, dict): - return schema # type: ignore[no-any-return] - return None - - -def _create_result_artifact( - result: Any, - task_id: str, -) -> Artifact: - """Create artifact from task result, using DataPart for structured data. - - Args: - result: The task execution result. - task_id: The task ID for naming the artifact. - - Returns: - Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings). 
- """ - artifact_name = f"result_{task_id}" - if isinstance(result, dict): - return new_data_artifact(artifact_name, result) - if isinstance(result, BaseModel): - return new_data_artifact(artifact_name, result.model_dump()) - return new_text_artifact(artifact_name, str(result)) - - -def _build_task_description( - user_message: str, - structured_inputs: list[dict[str, Any]], -) -> str: - """Build task description including structured data if present. - - Args: - user_message: The original user message text. - structured_inputs: List of structured data from DataParts. - - Returns: - Task description with structured data appended if present. - """ - if not structured_inputs: - return user_message - - structured_json = json.dumps(structured_inputs, indent=2) - return f"{user_message}\n\nStructured Data:\n{structured_json}" - - -async def execute( - agent: Agent, - context: RequestContext, - event_queue: EventQueue, -) -> None: - """Execute an A2A task using a CrewAI agent. - - Args: - agent: The CrewAI agent to execute the task. - context: The A2A request context containing the user's message. - event_queue: The event queue for sending responses back. - """ - await _execute_impl(agent, context, event_queue, None, None) - - -@cancellable -async def _execute_impl( - agent: Agent, - context: RequestContext, - event_queue: EventQueue, - extension_registry: ServerExtensionRegistry | None, - extension_context: ExtensionContext | None, -) -> None: - """Internal implementation for task execution with optional extensions.""" - server_config = _get_server_config(agent) - if context.message and context.message.parts and server_config: - allowed_modes = server_config.default_input_modes - invalid_types = validate_message_parts(context.message.parts, allowed_modes) - if invalid_types: - raise ServerError( - InvalidParamsError( - message=f"Unsupported content type(s): {', '.join(invalid_types)}. 
" - f"Supported: {', '.join(allowed_modes)}" - ) - ) - - if extension_registry and extension_context: - await extension_registry.invoke_on_request(extension_context) - - user_message = context.get_user_input() - - response_model: type[BaseModel] | None = None - structured_inputs: list[dict[str, Any]] = [] - a2a_files: list[FileWithBytes | FileWithUri] = [] - - if context.message and context.message.parts: - schema = _extract_response_schema(context.message.parts) - if schema: - try: - response_model = create_model_from_schema(schema) - except Exception as e: - logger.debug( - "Failed to create response model from schema", - extra={"error": str(e), "schema_title": schema.get("title")}, - ) - - structured_inputs = get_data_parts(context.message.parts) - a2a_files = get_file_parts(context.message.parts) - - task_id = context.task_id - context_id = context.context_id - if task_id is None or context_id is None: - msg = "task_id and context_id are required" - crewai_event_bus.emit( - agent, - A2AServerTaskFailedEvent( - task_id="", - context_id="", - error=msg, - from_agent=agent, - ), - ) - raise ServerError(InvalidParamsError(message=msg)) from None - - task = Task( - description=_build_task_description(user_message, structured_inputs), - expected_output="Response to the user's request", - agent=agent, - response_model=response_model, - input_files=_convert_a2a_files_to_file_inputs(a2a_files), - ) - - crewai_event_bus.emit( - agent, - A2AServerTaskStartedEvent( - task_id=task_id, - context_id=context_id, - from_task=task, - from_agent=agent, - ), - ) - - try: - result = await agent.aexecute_task(task=task, tools=agent.tools) - if extension_registry and extension_context: - result = await extension_registry.invoke_on_response( - extension_context, result - ) - result_str = str(result) - history: list[Message] = [context.message] if context.message else [] - history.append(new_agent_text_message(result_str, context_id, task_id)) - await event_queue.enqueue_event( - 
A2ATask( - id=task_id, - context_id=context_id, - status=TaskStatus(state=TaskState.completed), - artifacts=[_create_result_artifact(result, task_id)], - history=history, - ) - ) - crewai_event_bus.emit( - agent, - A2AServerTaskCompletedEvent( - task_id=task_id, - context_id=context_id, - result=str(result), - from_task=task, - from_agent=agent, - ), - ) - except asyncio.CancelledError: - crewai_event_bus.emit( - agent, - A2AServerTaskCanceledEvent( - task_id=task_id, - context_id=context_id, - from_task=task, - from_agent=agent, - ), - ) - raise - except Exception as e: - crewai_event_bus.emit( - agent, - A2AServerTaskFailedEvent( - task_id=task_id, - context_id=context_id, - error=str(e), - from_task=task, - from_agent=agent, - ), - ) - raise ServerError( - error=InternalError(message=f"Task execution failed: {e}") - ) from e - - -async def execute_with_extensions( - agent: Agent, - context: RequestContext, - event_queue: EventQueue, - extension_registry: ServerExtensionRegistry, - extension_context: ExtensionContext, -) -> None: - """Execute an A2A task with extension hooks. - - Args: - agent: The CrewAI agent to execute the task. - context: The A2A request context containing the user's message. - event_queue: The event queue for sending responses back. - extension_registry: Registry of server extensions. - extension_context: Context for extension hooks. - """ - await _execute_impl( - agent, context, event_queue, extension_registry, extension_context - ) - - -async def cancel( - context: RequestContext, - event_queue: EventQueue, -) -> A2ATask | None: - """Cancel an A2A task. - - Publishes a cancel event that the cancellable decorator listens for. - - Args: - context: The A2A request context containing task information. - event_queue: The event queue for sending the cancellation status. - - Returns: - The canceled task with updated status. 
- """ - task_id = context.task_id - context_id = context.context_id - if task_id is None or context_id is None: - raise ServerError(InvalidParamsError(message="task_id and context_id required")) - - if context.current_task and context.current_task.status.state in ( - TaskState.completed, - TaskState.failed, - TaskState.canceled, - ): - return context.current_task - - cache = caches.get("default") - - await cache.set(f"cancel:{task_id}", True, ttl=3600) - if not isinstance(cache, SimpleMemoryCache): - await cache.client.publish(f"cancel:{task_id}", "cancel") - - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=task_id, - context_id=context_id, - status=TaskStatus(state=TaskState.canceled), - final=True, - ) - ) - - if context.current_task: - context.current_task.status = TaskStatus(state=TaskState.canceled) - return context.current_task - return None - - -def list_tasks( - tasks: list[A2ATask], - context_id: str | None = None, - status: TaskState | None = None, - status_timestamp_after: datetime | None = None, - page_size: int = 50, - page_token: str | None = None, - history_length: int | None = None, - include_artifacts: bool = False, -) -> tuple[list[A2ATask], str | None, int]: - """Filter and paginate A2A tasks. - - Provides filtering by context, status, and timestamp, along with - cursor-based pagination. This is a pure utility function that operates - on an in-memory list of tasks - storage retrieval is handled separately. - - Args: - tasks: All tasks to filter. - context_id: Filter by context ID to get tasks in a conversation. - status: Filter by task state (e.g., completed, working). - status_timestamp_after: Filter to tasks updated after this time. - page_size: Maximum tasks per page (default 50). - page_token: Base64-encoded cursor from previous response. - history_length: Limit history messages per task (None = full history). - include_artifacts: Whether to include task artifacts (default False). 
- - Returns: - Tuple of (filtered_tasks, next_page_token, total_count). - - filtered_tasks: Tasks matching filters, paginated and trimmed. - - next_page_token: Token for next page, or None if no more pages. - - total_count: Total number of tasks matching filters (before pagination). - """ - filtered: list[A2ATask] = [] - for task in tasks: - if context_id and task.context_id != context_id: - continue - if status and task.status.state != status: - continue - if status_timestamp_after and task.status.timestamp: - ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00")) - if ts <= status_timestamp_after: - continue - filtered.append(task) - - def get_timestamp(t: A2ATask) -> datetime: - """Extract timestamp from task status for sorting.""" - if t.status.timestamp is None: - return datetime.min - return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00")) - - filtered.sort(key=get_timestamp, reverse=True) - total = len(filtered) - - start = 0 - if page_token: - try: - cursor_id = base64.b64decode(page_token).decode() - for idx, task in enumerate(filtered): - if task.id == cursor_id: - start = idx + 1 - break - except (ValueError, UnicodeDecodeError): - pass - - page = filtered[start : start + page_size] - - result: list[A2ATask] = [] - for task in page: - task = task.model_copy(deep=True) - if history_length is not None and task.history: - task.history = task.history[-history_length:] - if not include_artifacts: - task.artifacts = None - result.append(task) - - next_token: str | None = None - if result and len(result) == page_size: - next_token = base64.b64encode(result[-1].id.encode()).decode() - - return result, next_token, total +from crewai_a2a.utils.task import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/utils/transport.py b/lib/crewai/src/crewai/a2a/utils/transport.py index cc57ba20c..7d3530774 100644 --- a/lib/crewai/src/crewai/a2a/utils/transport.py +++ b/lib/crewai/src/crewai/a2a/utils/transport.py @@ -1,215 +1,13 
@@ -"""Transport negotiation utilities for A2A protocol. +"""Backward-compatibility shim — use ``crewai_a2a.utils.transport`` instead.""" -This module provides functionality for negotiating the transport protocol -between an A2A client and server based on their respective capabilities -and preferences. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Final, Literal - -from a2a.types import AgentCard, AgentInterface - -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent +import warnings -TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"] -NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"] +warnings.warn( + "'crewai.a2a.utils.transport' has been moved to 'crewai_a2a.utils.transport'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, +) -JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC" -GRPC_TRANSPORT: Literal["GRPC"] = "GRPC" -HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON" - -DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [ - JSONRPC_TRANSPORT, - GRPC_TRANSPORT, - HTTP_JSON_TRANSPORT, -] - - -@dataclass -class NegotiatedTransport: - """Result of transport negotiation. - - Attributes: - transport: The negotiated transport protocol. - url: The URL to use for this transport. - source: How the transport was selected ('preferred', 'additional', 'fallback'). - """ - - transport: str - url: str - source: NegotiationSource - - -class TransportNegotiationError(Exception): - """Raised when no compatible transport can be negotiated.""" - - def __init__( - self, - client_transports: list[str], - server_transports: list[str], - message: str | None = None, - ) -> None: - """Initialize the error with negotiation details. - - Args: - client_transports: Transports supported by the client. 
- server_transports: Transports supported by the server. - message: Optional custom error message. - """ - self.client_transports = client_transports - self.server_transports = server_transports - if message is None: - message = ( - f"No compatible transport found. " - f"Client supports: {client_transports}. " - f"Server supports: {server_transports}." - ) - super().__init__(message) - - -def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]: - """Extract all available interfaces from an AgentCard. - - Creates a unified list of interfaces including the primary URL and - any additional interfaces declared by the agent. - - Args: - agent_card: The agent's card containing transport information. - - Returns: - List of AgentInterface objects representing all available endpoints. - """ - interfaces: list[AgentInterface] = [] - - primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT - interfaces.append( - AgentInterface( - transport=primary_transport, - url=agent_card.url, - ) - ) - - if agent_card.additional_interfaces: - for interface in agent_card.additional_interfaces: - is_duplicate = any( - i.url == interface.url and i.transport == interface.transport - for i in interfaces - ) - if not is_duplicate: - interfaces.append(interface) - - return interfaces - - -def negotiate_transport( - agent_card: AgentCard, - client_supported_transports: list[str] | None = None, - client_preferred_transport: str | None = None, - emit_event: bool = True, - endpoint: str | None = None, - a2a_agent_name: str | None = None, -) -> NegotiatedTransport: - """Negotiate the transport protocol between client and server. - - Compares the client's supported transports with the server's available - interfaces to find a compatible transport and URL. - - Negotiation logic: - 1. If client_preferred_transport is set and server supports it → use it - 2. Otherwise, if server's preferred is in client's supported → use server's - 3. 
Otherwise, find first match from client's supported in server's interfaces - - Args: - agent_card: The server's AgentCard with transport information. - client_supported_transports: Transports the client can use. - Defaults to ["JSONRPC"] if not specified. - client_preferred_transport: Client's preferred transport. If set and - server supports it, takes priority over server preference. - emit_event: Whether to emit a transport negotiation event. - endpoint: Original endpoint URL for event metadata. - a2a_agent_name: Agent name for event metadata. - - Returns: - NegotiatedTransport with the selected transport, URL, and source. - - Raises: - TransportNegotiationError: If no compatible transport is found. - """ - if client_supported_transports is None: - client_supported_transports = [JSONRPC_TRANSPORT] - - client_transports = [t.upper() for t in client_supported_transports] - client_preferred = ( - client_preferred_transport.upper() if client_preferred_transport else None - ) - - server_interfaces = _get_server_interfaces(agent_card) - server_transports = [i.transport.upper() for i in server_interfaces] - - transport_to_interface: dict[str, AgentInterface] = {} - for interface in server_interfaces: - transport_upper = interface.transport.upper() - if transport_upper not in transport_to_interface: - transport_to_interface[transport_upper] = interface - - result: NegotiatedTransport | None = None - - if client_preferred and client_preferred in transport_to_interface: - interface = transport_to_interface[client_preferred] - result = NegotiatedTransport( - transport=interface.transport, - url=interface.url, - source="client_preferred", - ) - else: - server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper() - if ( - server_preferred in client_transports - and server_preferred in transport_to_interface - ): - interface = transport_to_interface[server_preferred] - result = NegotiatedTransport( - transport=interface.transport, - url=interface.url, - 
source="server_preferred", - ) - else: - for transport in client_transports: - if transport in transport_to_interface: - interface = transport_to_interface[transport] - result = NegotiatedTransport( - transport=interface.transport, - url=interface.url, - source="fallback", - ) - break - - if result is None: - raise TransportNegotiationError( - client_transports=client_transports, - server_transports=server_transports, - ) - - if emit_event: - crewai_event_bus.emit( - None, - A2ATransportNegotiatedEvent( - endpoint=endpoint or agent_card.url, - a2a_agent_name=a2a_agent_name or agent_card.name, - negotiated_transport=result.transport, - negotiated_url=result.url, - source=result.source, - client_supported_transports=client_transports, - server_supported_transports=server_transports, - server_preferred_transport=agent_card.preferred_transport - or JSONRPC_TRANSPORT, - client_preferred_transport=client_preferred, - ), - ) - - return result +from crewai_a2a.utils.transport import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/a2a/wrapper.py b/lib/crewai/src/crewai/a2a/wrapper.py index 307ba0c90..e3c60641f 100644 --- a/lib/crewai/src/crewai/a2a/wrapper.py +++ b/lib/crewai/src/crewai/a2a/wrapper.py @@ -1,1753 +1,13 @@ -"""A2A agent wrapping logic for metaclass integration. +"""Backward-compatibility shim — use ``crewai_a2a.wrapper`` instead.""" -Wraps agent classes with A2A delegation capabilities. 
-""" +import warnings -from __future__ import annotations -import asyncio -from collections.abc import Callable, Coroutine, Mapping -from concurrent.futures import ThreadPoolExecutor, as_completed -from functools import wraps -import json -from types import MethodType -from typing import TYPE_CHECKING, Any, NamedTuple - -from a2a.types import Role, TaskState -from pydantic import BaseModel, ValidationError - -from crewai.a2a.config import A2AClientConfig, A2AConfig -from crewai.a2a.extensions.base import ( - A2AExtension, - ConversationState, - ExtensionRegistry, +warnings.warn( + "'crewai.a2a.wrapper' has been moved to 'crewai_a2a.wrapper'. " + "Please update your imports. The old path will be removed in v2.0.0.", + FutureWarning, + stacklevel=2, ) -from crewai.a2a.task_helpers import TaskStateResult -from crewai.a2a.templates import ( - AVAILABLE_AGENTS_TEMPLATE, - CONVERSATION_TURN_INFO_TEMPLATE, - PREVIOUS_A2A_CONVERSATION_TEMPLATE, - REMOTE_AGENT_RESPONSE_NOTICE, - UNAVAILABLE_AGENTS_NOTICE_TEMPLATE, -) -from crewai.a2a.types import AgentResponseProtocol -from crewai.a2a.utils.agent_card import ( - afetch_agent_card, - fetch_agent_card, - inject_a2a_server_methods, -) -from crewai.a2a.utils.delegation import ( - aexecute_a2a_delegation, - execute_a2a_delegation, -) -from crewai.a2a.utils.response_model import get_a2a_agents_and_response_model -from crewai.events.event_bus import crewai_event_bus -from crewai.events.types.a2a_events import ( - A2AConversationCompletedEvent, - A2AMessageSentEvent, -) -from crewai.lite_agent_output import LiteAgentOutput -from crewai.task import Task - -if TYPE_CHECKING: - from a2a.types import AgentCard, Message - - from crewai.agent.core import Agent - from crewai.tools.base_tool import BaseTool - - -class DelegationContext(NamedTuple): - """Context prepared for A2A delegation. - - Groups all the values needed to execute a delegation to a remote A2A agent. 
- """ - - a2a_agents: list[A2AConfig | A2AClientConfig] - agent_response_model: type[BaseModel] | None - current_request: str - agent_id: str - agent_config: A2AConfig | A2AClientConfig - context_id: str | None - task_id: str | None - metadata: dict[str, Any] | None - extensions: dict[str, Any] | None - reference_task_ids: list[str] - original_task_description: str - max_turns: int - - -class DelegationState(NamedTuple): - """Mutable state for A2A delegation loop. - - Groups values that may change during delegation turns. - """ - - current_request: str - context_id: str | None - task_id: str | None - reference_task_ids: list[str] - conversation_history: list[Message] - agent_card: AgentCard | None - agent_card_dict: dict[str, Any] | None - agent_name: str | None - - -def wrap_agent_with_a2a_instance( - agent: Agent, extension_registry: ExtensionRegistry | None = None -) -> None: - """Wrap an agent instance's task execution and kickoff methods with A2A support. - - This function modifies the agent instance by wrapping its execute_task, - aexecute_task, kickoff, and kickoff_async methods to add A2A delegation - capabilities. Should only be called when the agent has a2a configuration set. - - Args: - agent: The agent instance to wrap. - extension_registry: Optional registry of A2A extensions. 
- """ - if extension_registry is None: - extension_registry = ExtensionRegistry() - - extension_registry.inject_all_tools(agent) - - original_execute_task = agent.execute_task.__func__ # type: ignore[attr-defined] - original_aexecute_task = agent.aexecute_task.__func__ # type: ignore[attr-defined] - - @wraps(original_execute_task) - def execute_task_with_a2a( - self: Agent, - task: Task, - context: str | None = None, - tools: list[BaseTool] | None = None, - ) -> str: - """Execute task with A2A delegation support (sync).""" - if not self.a2a: - return original_execute_task(self, task, context, tools) # type: ignore[no-any-return] - - a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) - - return _execute_task_with_a2a( - self=self, - a2a_agents=a2a_agents, - original_fn=original_execute_task, - task=task, - agent_response_model=agent_response_model, - context=context, - tools=tools, - extension_registry=extension_registry, - ) - - @wraps(original_aexecute_task) - async def aexecute_task_with_a2a( - self: Agent, - task: Task, - context: str | None = None, - tools: list[BaseTool] | None = None, - ) -> str: - """Execute task with A2A delegation support (async).""" - if not self.a2a: - return await original_aexecute_task(self, task, context, tools) # type: ignore[no-any-return] - - a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) - - return await _aexecute_task_with_a2a( - self=self, - a2a_agents=a2a_agents, - original_fn=original_aexecute_task, - task=task, - agent_response_model=agent_response_model, - context=context, - tools=tools, - extension_registry=extension_registry, - ) - - object.__setattr__(agent, "execute_task", MethodType(execute_task_with_a2a, agent)) - object.__setattr__( - agent, "aexecute_task", MethodType(aexecute_task_with_a2a, agent) - ) - - original_kickoff = agent.kickoff.__func__ # type: ignore[attr-defined] - original_kickoff_async = agent.kickoff_async.__func__ # type: 
ignore[attr-defined] - - @wraps(original_kickoff) - def kickoff_with_a2a( - self: Agent, - messages: str | list[Any], - response_format: type[Any] | None = None, - input_files: dict[str, Any] | None = None, - ) -> Any: - """Execute agent kickoff with A2A delegation support.""" - if not self.a2a: - return original_kickoff(self, messages, response_format, input_files) - - a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) - - if not a2a_agents: - return original_kickoff(self, messages, response_format, input_files) - - return _kickoff_with_a2a( - self=self, - a2a_agents=a2a_agents, - original_kickoff=original_kickoff, - messages=messages, - response_format=response_format, - input_files=input_files, - agent_response_model=agent_response_model, - extension_registry=extension_registry, - ) - - @wraps(original_kickoff_async) - async def kickoff_async_with_a2a( - self: Agent, - messages: str | list[Any], - response_format: type[Any] | None = None, - input_files: dict[str, Any] | None = None, - ) -> Any: - """Execute agent kickoff with A2A delegation support.""" - if not self.a2a: - return await original_kickoff_async( - self, messages, response_format, input_files - ) - - a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) - - if not a2a_agents: - return await original_kickoff_async( - self, messages, response_format, input_files - ) - - return await _akickoff_with_a2a( - self=self, - a2a_agents=a2a_agents, - original_kickoff_async=original_kickoff_async, - messages=messages, - response_format=response_format, - input_files=input_files, - agent_response_model=agent_response_model, - extension_registry=extension_registry, - ) - - object.__setattr__(agent, "kickoff", MethodType(kickoff_with_a2a, agent)) - object.__setattr__( - agent, "kickoff_async", MethodType(kickoff_async_with_a2a, agent) - ) - - inject_a2a_server_methods(agent) - - -def _fetch_card_from_config( - config: A2AConfig | A2AClientConfig, -) -> 
tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]: - """Fetch agent card from A2A config. - - Args: - config: A2A configuration - - Returns: - Tuple of (config, card or exception) - """ - try: - card = fetch_agent_card( - endpoint=config.endpoint, - auth=config.auth, - timeout=config.timeout, - ) - return config, card - except Exception as e: - return config, e - - -def _fetch_agent_cards_concurrently( - a2a_agents: list[A2AConfig | A2AClientConfig], -) -> tuple[dict[str, AgentCard], dict[str, str]]: - """Fetch agent cards concurrently for multiple A2A agents. - - Args: - a2a_agents: List of A2A agent configurations - - Returns: - Tuple of (agent_cards dict, failed_agents dict mapping endpoint to error message) - """ - agent_cards: dict[str, AgentCard] = {} - failed_agents: dict[str, str] = {} - - if not a2a_agents: - return agent_cards, failed_agents - - max_workers = min(len(a2a_agents), 10) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = { - executor.submit(_fetch_card_from_config, config): config - for config in a2a_agents - } - for future in as_completed(futures): - config, result = future.result() - if isinstance(result, Exception): - if config.fail_fast: - raise RuntimeError( - f"Failed to fetch agent card from {config.endpoint}. " - f"Ensure the A2A agent is running and accessible. Error: {result}" - ) from result - failed_agents[config.endpoint] = str(result) - else: - agent_cards[config.endpoint] = result - - return agent_cards, failed_agents - - -def _execute_task_with_a2a( - self: Agent, - a2a_agents: list[A2AConfig | A2AClientConfig], - original_fn: Callable[..., str], - task: Task, - agent_response_model: type[BaseModel] | None, - context: str | None, - tools: list[BaseTool] | None, - extension_registry: ExtensionRegistry, -) -> str: - """Wrap execute_task with A2A delegation logic. 
- - Args: - self: The agent instance - a2a_agents: Dictionary of A2A agent configurations - original_fn: The original execute_task method - task: The task to execute - context: Optional context for task execution - tools: Optional tools available to the agent - agent_response_model: Optional agent response model - extension_registry: Registry of A2A extensions - - Returns: - Task execution result (either from LLM or A2A agent) - """ - original_description: str = task.description - original_output_pydantic = task.output_pydantic - original_response_model = task.response_model - - agent_cards, failed_agents = _fetch_agent_cards_concurrently(a2a_agents) - - if not agent_cards and a2a_agents and failed_agents: - unavailable_agents_text = "" - for endpoint, error in failed_agents.items(): - unavailable_agents_text += f" - {endpoint}: {error}\n" - - notice = UNAVAILABLE_AGENTS_NOTICE_TEMPLATE.substitute( - unavailable_agents=unavailable_agents_text - ) - task.description = f"{original_description}{notice}" - - try: - return original_fn(self, task, context, tools) - finally: - task.description = original_description - - task.description, _, extension_states = _augment_prompt_with_a2a( - a2a_agents=a2a_agents, - task_description=original_description, - agent_cards=agent_cards, - failed_agents=failed_agents, - extension_registry=extension_registry, - ) - task.response_model = agent_response_model - - try: - raw_result = original_fn(self, task, context, tools) - agent_response = _parse_agent_response( - raw_result=raw_result, agent_response_model=agent_response_model - ) - - if extension_registry and isinstance(agent_response, BaseModel): - agent_response = extension_registry.process_response_with_all( - agent_response, extension_states - ) - - if isinstance(agent_response, BaseModel) and isinstance( - agent_response, AgentResponseProtocol - ): - if agent_response.is_a2a: - return _delegate_to_a2a( - self, - agent_response=agent_response, - task=task, - 
original_fn=original_fn, - context=context, - tools=tools, - agent_cards=agent_cards, - original_task_description=original_description, - _extension_registry=extension_registry, - ) - task.output_pydantic = None - return agent_response.message - - return raw_result - finally: - task.description = original_description - if task.output_pydantic is not None: - task.output_pydantic = original_output_pydantic - task.response_model = original_response_model - - -def _kickoff_with_a2a( - self: Agent, - a2a_agents: list[A2AConfig | A2AClientConfig], - original_kickoff: Callable[..., LiteAgentOutput], - messages: str | list[Any], - response_format: type[Any] | None, - input_files: dict[str, Any] | None, - agent_response_model: type[BaseModel] | None, - extension_registry: ExtensionRegistry, -) -> LiteAgentOutput: - """Execute kickoff with A2A delegation support (sync). - - Args: - self: The agent instance. - a2a_agents: List of A2A agent configurations. - original_kickoff: The original kickoff method. - messages: Messages to send to the agent. - response_format: Optional response format. - input_files: Optional input files. - agent_response_model: Optional agent response model. - extension_registry: Registry of A2A extensions. - - Returns: - LiteAgentOutput from kickoff or A2A delegation. 
- """ - if isinstance(messages, str): - description = messages - else: - content = next( - (m["content"] for m in reversed(messages) if m["role"] == "user"), - None, - ) - description = content if isinstance(content, str) else "" - - if not description: - return original_kickoff(self, messages, response_format, input_files) - - fake_task = Task( - description=description, - agent=self, - expected_output="Result from A2A delegation", - input_files=input_files or {}, - ) - - agent_cards, failed_agents = _fetch_agent_cards_concurrently(a2a_agents) - - if not agent_cards and a2a_agents and failed_agents: - return original_kickoff(self, messages, response_format, input_files) - - fake_task.description, _, extension_states = _augment_prompt_with_a2a( - a2a_agents=a2a_agents, - task_description=description, - agent_cards=agent_cards, - failed_agents=failed_agents, - extension_registry=extension_registry, - ) - fake_task.response_model = agent_response_model - - try: - result: LiteAgentOutput = original_kickoff( - self, messages, agent_response_model or response_format, input_files - ) - agent_response = _parse_agent_response( - raw_result=result.raw, agent_response_model=agent_response_model - ) - - if extension_registry and isinstance(agent_response, BaseModel): - agent_response = extension_registry.process_response_with_all( - agent_response, extension_states - ) - - if isinstance(agent_response, BaseModel) and isinstance( - agent_response, AgentResponseProtocol - ): - if agent_response.is_a2a: - - def _kickoff_adapter( - self_: Agent, - _task: Task, - _context: str | None, - _tools: list[Any] | None, - ) -> str: - fmt = ( - _task.response_model or agent_response_model or response_format - ) - output: LiteAgentOutput = original_kickoff( - self_, messages, fmt, input_files - ) - return output.raw - - result_str = _delegate_to_a2a( - self, - agent_response=agent_response, - task=fake_task, - original_fn=_kickoff_adapter, - context=None, - tools=None, - 
agent_cards=agent_cards, - original_task_description=description, - _extension_registry=extension_registry, - ) - return LiteAgentOutput( - raw=result_str, - pydantic=None, - agent_role=self.role, - usage_metrics=None, - messages=[], - ) - return LiteAgentOutput( - raw=agent_response.message, - pydantic=None, - agent_role=self.role, - usage_metrics=result.usage_metrics, - messages=result.messages, - ) - - return result - finally: - fake_task.description = description - - -async def _akickoff_with_a2a( - self: Agent, - a2a_agents: list[A2AConfig | A2AClientConfig], - original_kickoff_async: Callable[..., Coroutine[Any, Any, LiteAgentOutput]], - messages: str | list[Any], - response_format: type[Any] | None, - input_files: dict[str, Any] | None, - agent_response_model: type[BaseModel] | None, - extension_registry: ExtensionRegistry, -) -> LiteAgentOutput: - """Execute kickoff with A2A delegation support (async). - - Args: - self: The agent instance. - a2a_agents: List of A2A agent configurations. - original_kickoff_async: The original kickoff_async method. - messages: Messages to send to the agent. - response_format: Optional response format. - input_files: Optional input files. - agent_response_model: Optional agent response model. - extension_registry: Registry of A2A extensions. - - Returns: - LiteAgentOutput from kickoff or A2A delegation. 
- """ - if isinstance(messages, str): - description = messages - else: - content = next( - (m["content"] for m in reversed(messages) if m["role"] == "user"), - None, - ) - description = content if isinstance(content, str) else "" - - if not description: - return await original_kickoff_async( - self, messages, response_format, input_files - ) - - fake_task = Task( - description=description, - agent=self, - expected_output="Result from A2A delegation", - input_files=input_files or {}, - ) - - agent_cards, failed_agents = await _afetch_agent_cards_concurrently(a2a_agents) - - if not agent_cards and a2a_agents and failed_agents: - return await original_kickoff_async( - self, messages, response_format, input_files - ) - - fake_task.description, _, extension_states = _augment_prompt_with_a2a( - a2a_agents=a2a_agents, - task_description=description, - agent_cards=agent_cards, - failed_agents=failed_agents, - extension_registry=extension_registry, - ) - fake_task.response_model = agent_response_model - - try: - result: LiteAgentOutput = await original_kickoff_async( - self, messages, agent_response_model or response_format, input_files - ) - agent_response = _parse_agent_response( - raw_result=result.raw, agent_response_model=agent_response_model - ) - - if extension_registry and isinstance(agent_response, BaseModel): - agent_response = extension_registry.process_response_with_all( - agent_response, extension_states - ) - - if isinstance(agent_response, BaseModel) and isinstance( - agent_response, AgentResponseProtocol - ): - if agent_response.is_a2a: - - async def _kickoff_adapter( - self_: Agent, - _task: Task, - _context: str | None, - _tools: list[Any] | None, - ) -> str: - fmt = ( - _task.response_model or agent_response_model or response_format - ) - output: LiteAgentOutput = await original_kickoff_async( - self_, messages, fmt, input_files - ) - return output.raw - - result_str = await _adelegate_to_a2a( - self, - agent_response=agent_response, - task=fake_task, - 
original_fn=_kickoff_adapter, - context=None, - tools=None, - agent_cards=agent_cards, - original_task_description=description, - _extension_registry=extension_registry, - ) - return LiteAgentOutput( - raw=result_str, - pydantic=None, - agent_role=self.role, - usage_metrics=None, - messages=[], - ) - return LiteAgentOutput( - raw=agent_response.message, - pydantic=None, - agent_role=self.role, - usage_metrics=result.usage_metrics, - messages=result.messages, - ) - - return result - finally: - fake_task.description = description - - -def _augment_prompt_with_a2a( - a2a_agents: list[A2AConfig | A2AClientConfig], - task_description: str, - agent_cards: Mapping[str, AgentCard | dict[str, Any]], - conversation_history: list[Message] | None = None, - turn_num: int = 0, - max_turns: int | None = None, - failed_agents: dict[str, str] | None = None, - extension_registry: ExtensionRegistry | None = None, - remote_status_notice: str = "", -) -> tuple[str, bool, dict[type[A2AExtension], ConversationState]]: - """Add A2A delegation instructions to prompt. 
- - Args: - a2a_agents: Dictionary of A2A agent configurations - task_description: Original task description - agent_cards: dictionary mapping agent IDs to AgentCards - conversation_history: Previous A2A Messages from conversation - turn_num: Current turn number (0-indexed) - max_turns: Maximum allowed turns (from config) - failed_agents: Dictionary mapping failed agent endpoints to error messages - extension_registry: Optional registry of A2A extensions - remote_status_notice: Optional notice about remote agent status to append - - Returns: - Tuple of (augmented prompt, disable_structured_output flag, extension_states dict) - """ - - if not agent_cards: - return task_description, False, {} - - agents_text = "" - - for config in a2a_agents: - if config.endpoint in agent_cards: - card = agent_cards[config.endpoint] - if isinstance(card, dict): - filtered = { - k: v - for k, v in card.items() - if k in {"description", "url", "skills"} and v is not None - } - agents_text += f"\n{json.dumps(filtered, indent=2)}\n" - else: - agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n" - - failed_agents = failed_agents or {} - if failed_agents: - agents_text += "\n\n" - for endpoint, error in failed_agents.items(): - agents_text += f"\n\n" - - agents_text = AVAILABLE_AGENTS_TEMPLATE.substitute(available_a2a_agents=agents_text) - - history_text = "" - - if conversation_history: - for msg in conversation_history: - history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n" - - history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute( - previous_a2a_conversation=history_text - ) - - extension_states = {} - disable_structured_output = False - if extension_registry and conversation_history: - extension_states = extension_registry.extract_all_states(conversation_history) - for state in extension_states.values(): - if state.is_ready(): - disable_structured_output = True - break - 
turn_info = "" - - if max_turns is not None and conversation_history: - turn_count = turn_num + 1 - warning = "" - if turn_count >= max_turns: - warning = ( - "CRITICAL: This is the FINAL turn. You MUST conclude the conversation now.\n" - "Set is_a2a=false and provide your final response to complete the task." - ) - elif turn_count == max_turns - 1: - warning = "WARNING: Next turn will be the last. Consider wrapping up the conversation." - - turn_info = CONVERSATION_TURN_INFO_TEMPLATE.substitute( - turn_count=turn_count, - max_turns=max_turns, - warning=warning, - ) - - augmented_prompt = f"""{task_description} - -IMPORTANT: You have the ability to delegate this task to remote A2A agents. -{agents_text} -{history_text}{turn_info}{remote_status_notice} - -""" - - if extension_registry: - augmented_prompt = extension_registry.augment_prompt_with_all( - augmented_prompt, extension_states - ) - - return augmented_prompt, disable_structured_output, extension_states - - -def _parse_agent_response( - raw_result: str | dict[str, Any], agent_response_model: type[BaseModel] | None -) -> BaseModel | str | dict[str, Any]: - """Parse LLM output as AgentResponse or return raw agent response.""" - if agent_response_model: - try: - if isinstance(raw_result, str): - return agent_response_model.model_validate_json(raw_result) - if isinstance(raw_result, dict): - return agent_response_model.model_validate(raw_result) - except ValidationError: - return raw_result - return raw_result - - -def _handle_max_turns_exceeded( - conversation_history: list[Message], - max_turns: int, - from_task: Any | None = None, - from_agent: Any | None = None, - endpoint: str | None = None, - a2a_agent_name: str | None = None, - agent_card: dict[str, Any] | None = None, -) -> str: - """Handle the case when max turns is exceeded. - - Shared logic for both sync and async delegation. - - Returns: - Final message if found in history. - - Raises: - Exception: If no final message found and max turns exceeded. 
- """ - if conversation_history: - for msg in reversed(conversation_history): - if msg.role == Role.agent: - text_parts = [ - part.root.text for part in msg.parts if part.root.kind == "text" - ] - final_message = ( - " ".join(text_parts) if text_parts else "Conversation completed" - ) - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="completed", - final_result=final_message, - error=None, - total_turns=max_turns, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - agent_card=agent_card, - ), - ) - return final_message - - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="failed", - final_result=None, - error=f"Conversation exceeded maximum turns ({max_turns})", - total_turns=max_turns, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - agent_card=agent_card, - ), - ) - raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})") - - -def _emit_delegation_failed( - error_msg: str, - turn_num: int, - from_task: Any | None, - from_agent: Any | None, - endpoint: str | None, - a2a_agent_name: str | None, - agent_card: dict[str, Any] | None, -) -> str: - """Emit failure event and return formatted error message.""" - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="failed", - final_result=None, - error=error_msg, - total_turns=turn_num + 1, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - agent_card=agent_card, - ), - ) - return f"A2A delegation failed: {error_msg}" - - -def _process_response_result( - raw_result: str, - disable_structured_output: bool, - turn_num: int, - agent_role: str, - agent_response_model: type[BaseModel] | None, - extension_registry: ExtensionRegistry | None = None, - extension_states: dict[type[A2AExtension], ConversationState] | None = None, - from_task: Any | None = None, - from_agent: Any | None = 
None, - endpoint: str | None = None, - a2a_agent_name: str | None = None, - agent_card: dict[str, Any] | None = None, -) -> tuple[str | None, str | None]: - """Process LLM response and determine next action. - - Shared logic for both sync and async handlers. - - Returns: - Tuple of (final_result, next_request). - """ - if disable_structured_output: - final_turn_number = turn_num + 1 - result_text = str(raw_result) - crewai_event_bus.emit( - None, - A2AMessageSentEvent( - message=result_text, - turn_number=final_turn_number, - is_multiturn=True, - agent_role=agent_role, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - ), - ) - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="completed", - final_result=result_text, - error=None, - total_turns=final_turn_number, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - agent_card=agent_card, - ), - ) - return result_text, None - - llm_response = _parse_agent_response( - raw_result=raw_result, agent_response_model=agent_response_model - ) - - if extension_registry and isinstance(llm_response, BaseModel): - llm_response = extension_registry.process_response_with_all( - llm_response, extension_states or {} - ) - - if isinstance(llm_response, BaseModel) and isinstance( - llm_response, AgentResponseProtocol - ): - if not llm_response.is_a2a: - final_turn_number = turn_num + 1 - crewai_event_bus.emit( - None, - A2AMessageSentEvent( - message=str(llm_response.message), - turn_number=final_turn_number, - is_multiturn=True, - agent_role=agent_role, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - ), - ) - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="completed", - final_result=str(llm_response.message), - error=None, - total_turns=final_turn_number, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - 
a2a_agent_name=a2a_agent_name, - agent_card=agent_card, - ), - ) - return llm_response.message, None - return None, llm_response.message - - return str(raw_result), None - - -def _prepare_agent_cards_dict( - a2a_result: TaskStateResult, - agent_id: str, - agent_cards: Mapping[str, AgentCard | dict[str, Any]] | None, -) -> dict[str, AgentCard | dict[str, Any]]: - """Prepare agent cards dictionary from result and existing cards. - - Shared logic for both sync and async response handlers. - """ - agent_cards_dict: dict[str, AgentCard | dict[str, Any]] = ( - dict(agent_cards) if agent_cards else {} - ) - if "agent_card" in a2a_result and agent_id not in agent_cards_dict: - agent_cards_dict[agent_id] = a2a_result["agent_card"] - return agent_cards_dict - - -def _init_delegation_state( - ctx: DelegationContext, - agent_cards: dict[str, AgentCard] | None, -) -> DelegationState: - """Initialize delegation state from context and agent cards. - - Args: - ctx: Delegation context with config and settings. - agent_cards: Pre-fetched agent cards. - - Returns: - Initial delegation state for the conversation loop. - """ - current_agent_card = agent_cards.get(ctx.agent_id) if agent_cards else None - return DelegationState( - current_request=ctx.current_request, - context_id=ctx.context_id, - task_id=ctx.task_id, - reference_task_ids=list(ctx.reference_task_ids), - conversation_history=[], - agent_card=current_agent_card, - agent_card_dict=current_agent_card.model_dump() if current_agent_card else None, - agent_name=current_agent_card.name if current_agent_card else None, - ) - - -def _get_turn_context( - agent_config: A2AConfig | A2AClientConfig, -) -> tuple[Any | None, list[str] | None]: - """Get context for a delegation turn. - - Returns: - Tuple of (agent_branch, accepted_output_modes). 
- """ - console_formatter = getattr(crewai_event_bus, "_console", None) - agent_branch = None - if console_formatter: - agent_branch = getattr( - console_formatter, "current_agent_branch", None - ) or getattr(console_formatter, "current_task_branch", None) - - accepted_output_modes = None - if isinstance(agent_config, A2AClientConfig): - accepted_output_modes = agent_config.accepted_output_modes - - return agent_branch, accepted_output_modes - - -def _prepare_delegation_context( - self: Agent, - agent_response: AgentResponseProtocol, - task: Task, - original_task_description: str | None, -) -> DelegationContext: - """Prepare delegation context from agent response and task. - - Shared logic for both sync and async delegation. - - Returns: - DelegationContext with all values needed for delegation. - """ - a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) - agent_ids = tuple(config.endpoint for config in a2a_agents) - current_request = str(agent_response.message) - - if not a2a_agents: - raise ValueError("No A2A agents configured for delegation") - - if isinstance(agent_response, AgentResponseProtocol) and agent_response.a2a_ids: - agent_id = agent_response.a2a_ids[0] - else: - agent_id = agent_ids[0] - - if agent_id not in agent_ids: - raise ValueError(f"Unknown A2A agent ID: {agent_id} not in {agent_ids}") - - agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents), None) - if agent_config is None: - raise ValueError(f"Agent configuration not found for endpoint: {agent_id}") - task_config = task.config or {} - - if original_task_description is None: - original_task_description = task.description - - return DelegationContext( - a2a_agents=a2a_agents, - agent_response_model=agent_response_model, - current_request=current_request, - agent_id=agent_id, - agent_config=agent_config, - context_id=task_config.get("context_id"), - task_id=task_config.get("task_id"), - metadata=task_config.get("metadata"), - 
extensions=task_config.get("extensions"), - reference_task_ids=task_config.get("reference_task_ids", []), - original_task_description=original_task_description, - max_turns=agent_config.max_turns, - ) - - -def _handle_task_completion( - a2a_result: TaskStateResult, - task: Task, - task_id_config: str | None, - reference_task_ids: list[str], - agent_config: A2AConfig | A2AClientConfig, - turn_num: int, - from_task: Any | None = None, - from_agent: Any | None = None, - endpoint: str | None = None, - a2a_agent_name: str | None = None, - agent_card: dict[str, Any] | None = None, -) -> tuple[str | None, str | None, list[str], str]: - """Handle task completion state including reference task updates. - - When a remote task completes, this function: - 1. Adds the completed task_id to reference_task_ids (if not already present) - 2. Clears task_id_config to signal that a new task ID should be generated for next turn - 3. Updates task.config with the reference list for subsequent A2A calls - - The reference_task_ids list tracks all completed tasks in this conversation chain, - allowing the remote agent to maintain context across multi-turn interactions. - - Shared logic for both sync and async delegation. - - Args: - a2a_result: Result from A2A delegation containing task status. - task: CrewAI Task object to update with reference IDs. - task_id_config: Current task ID (will be added to references if task completed). - reference_task_ids: Mutable list of completed task IDs (updated in place). - agent_config: A2A configuration with trust settings. - turn_num: Current turn number. - from_task: Optional CrewAI Task for event metadata. - from_agent: Optional CrewAI Agent for event metadata. - endpoint: A2A endpoint URL. - a2a_agent_name: Name of remote A2A agent. - agent_card: Agent card dict for event metadata. - - Returns: - Tuple of (result_if_trusted, updated_task_id, updated_reference_task_ids, remote_notice). 
- - result_if_trusted: Final result if trust_remote_completion_status=True, else None - - updated_task_id: None (cleared to generate new ID for next turn) - - updated_reference_task_ids: The mutated list with completed task added - - remote_notice: Template notice about remote agent response - """ - remote_notice = "" - if a2a_result["status"] == TaskState.completed: - remote_notice = REMOTE_AGENT_RESPONSE_NOTICE - - if task_id_config is not None and task_id_config not in reference_task_ids: - reference_task_ids.append(task_id_config) - - if task.config is None: - task.config = {} - task.config["reference_task_ids"] = list(reference_task_ids) - - task_id_config = None - - if agent_config.trust_remote_completion_status: - result_text = a2a_result.get("result", "") - final_turn_number = turn_num + 1 - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="completed", - final_result=result_text, - error=None, - total_turns=final_turn_number, - from_task=from_task, - from_agent=from_agent, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - agent_card=agent_card, - ), - ) - return str(result_text), task_id_config, reference_task_ids, remote_notice - - return None, task_id_config, reference_task_ids, remote_notice - - -def _handle_agent_response_and_continue( - self: Agent, - a2a_result: TaskStateResult, - agent_id: str, - agent_cards: dict[str, AgentCard] | None, - a2a_agents: list[A2AConfig | A2AClientConfig], - original_task_description: str, - conversation_history: list[Message], - turn_num: int, - max_turns: int, - task: Task, - original_fn: Callable[..., str], - context: str | None, - tools: list[BaseTool] | None, - agent_response_model: type[BaseModel] | None, - extension_registry: ExtensionRegistry | None = None, - remote_status_notice: str = "", - endpoint: str | None = None, - a2a_agent_name: str | None = None, - agent_card: dict[str, Any] | None = None, -) -> tuple[str | None, str | None]: - """Handle A2A result and get CrewAI agent's 
response. - - Args: - self: The agent instance - a2a_result: Result from A2A delegation - agent_id: ID of the A2A agent - agent_cards: Pre-fetched agent cards - a2a_agents: List of A2A configurations - original_task_description: Original task description - conversation_history: Conversation history - turn_num: Current turn number - max_turns: Maximum turns allowed - task: The task being executed - original_fn: Original execute_task method - context: Optional context - tools: Optional tools - agent_response_model: Response model for parsing - - Returns: - Tuple of (final_result, current_request) where: - - final_result is not None if conversation should end - - current_request is the next message to send if continuing - """ - agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards) - - ( - task.description, - disable_structured_output, - extension_states, - ) = _augment_prompt_with_a2a( - a2a_agents=a2a_agents, - task_description=original_task_description, - conversation_history=conversation_history, - turn_num=turn_num, - max_turns=max_turns, - agent_cards=agent_cards_dict, - remote_status_notice=remote_status_notice, - ) - - original_response_model = task.response_model - if disable_structured_output: - task.response_model = None - - raw_result = original_fn(self, task, context, tools) - - if disable_structured_output: - task.response_model = original_response_model - - return _process_response_result( - raw_result=raw_result, - disable_structured_output=disable_structured_output, - turn_num=turn_num, - agent_role=self.role, - agent_response_model=agent_response_model, - extension_registry=extension_registry, - extension_states=extension_states, - from_task=task, - from_agent=self, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - agent_card=agent_card, - ) - - -def _delegate_to_a2a( - self: Agent, - agent_response: AgentResponseProtocol, - task: Task, - original_fn: Callable[..., str], - context: str | None, - tools: list[BaseTool] | 
None, - agent_cards: dict[str, AgentCard] | None = None, - original_task_description: str | None = None, - _extension_registry: ExtensionRegistry | None = None, -) -> str: - """Delegate to A2A agent with multi-turn conversation support. - - Args: - self: The agent instance - agent_response: The AgentResponse indicating delegation - task: The task being executed (for extracting A2A fields) - original_fn: The original execute_task method for follow-ups - context: Optional context for task execution - tools: Optional tools available to the agent - agent_cards: Pre-fetched agent cards from _execute_task_with_a2a - original_task_description: The original task description before A2A augmentation - _extension_registry: Optional registry of A2A extensions (unused, reserved for future use) - - Returns: - Result from A2A agent - - Raises: - ImportError: If a2a-sdk is not installed - """ - ctx = _prepare_delegation_context( - self, agent_response, task, original_task_description - ) - state = _init_delegation_state(ctx, agent_cards) - current_request = state.current_request - context_id = state.context_id - task_id = state.task_id - reference_task_ids = state.reference_task_ids - conversation_history = state.conversation_history - - try: - for turn_num in range(ctx.max_turns): - agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config) - - a2a_result = execute_a2a_delegation( - endpoint=ctx.agent_config.endpoint, - auth=ctx.agent_config.auth, - timeout=ctx.agent_config.timeout, - task_description=current_request, - context_id=context_id, - task_id=task_id, - reference_task_ids=reference_task_ids, - metadata=ctx.metadata, - extensions=ctx.extensions, - conversation_history=conversation_history, - agent_id=ctx.agent_id, - agent_role=Role.user, - agent_branch=agent_branch, - response_model=ctx.agent_config.response_model, - turn_number=turn_num + 1, - updates=ctx.agent_config.updates, - transport=ctx.agent_config.transport, - from_task=task, - from_agent=self, - 
client_extensions=getattr(ctx.agent_config, "extensions", None), - accepted_output_modes=accepted_output_modes, - input_files=task.input_files, - ) - - conversation_history = a2a_result.get("history", []) - - if conversation_history: - latest_message = conversation_history[-1] - if latest_message.task_id is not None: - task_id = latest_message.task_id - if latest_message.context_id is not None: - context_id = latest_message.context_id - - if a2a_result["status"] in [TaskState.completed, TaskState.input_required]: - trusted_result, task_id, reference_task_ids, remote_notice = ( - _handle_task_completion( - a2a_result, - task, - task_id, - reference_task_ids, - ctx.agent_config, - turn_num, - from_task=task, - from_agent=self, - endpoint=ctx.agent_config.endpoint, - a2a_agent_name=state.agent_name, - agent_card=state.agent_card_dict, - ) - ) - if trusted_result is not None: - return trusted_result - - final_result, next_request = _handle_agent_response_and_continue( - self=self, - a2a_result=a2a_result, - agent_id=ctx.agent_id, - agent_cards=agent_cards, - a2a_agents=ctx.a2a_agents, - original_task_description=ctx.original_task_description, - conversation_history=conversation_history, - turn_num=turn_num, - max_turns=ctx.max_turns, - task=task, - original_fn=original_fn, - context=context, - tools=tools, - agent_response_model=ctx.agent_response_model, - extension_registry=_extension_registry, - remote_status_notice=remote_notice, - endpoint=ctx.agent_config.endpoint, - a2a_agent_name=state.agent_name, - agent_card=state.agent_card_dict, - ) - - if final_result is not None: - return final_result - - if next_request is not None: - current_request = next_request - - continue - - error_msg = a2a_result.get("error", "Unknown error") - - final_result, next_request = _handle_agent_response_and_continue( - self=self, - a2a_result=a2a_result, - agent_id=ctx.agent_id, - agent_cards=agent_cards, - a2a_agents=ctx.a2a_agents, - 
original_task_description=ctx.original_task_description, - conversation_history=conversation_history, - turn_num=turn_num, - max_turns=ctx.max_turns, - task=task, - original_fn=original_fn, - context=context, - tools=tools, - agent_response_model=ctx.agent_response_model, - extension_registry=_extension_registry, - endpoint=ctx.agent_config.endpoint, - a2a_agent_name=state.agent_name, - agent_card=state.agent_card_dict, - ) - - if final_result is not None: - return final_result - - if next_request is not None: - current_request = next_request - continue - - return _emit_delegation_failed( - error_msg, - turn_num, - task, - self, - ctx.agent_config.endpoint, - state.agent_name, - state.agent_card_dict, - ) - - return _handle_max_turns_exceeded( - conversation_history, - ctx.max_turns, - from_task=task, - from_agent=self, - endpoint=ctx.agent_config.endpoint, - a2a_agent_name=state.agent_name, - agent_card=state.agent_card_dict, - ) - - finally: - task.description = ctx.original_task_description - - -async def _afetch_card_from_config( - config: A2AConfig | A2AClientConfig, -) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]: - """Fetch agent card from A2A config asynchronously.""" - try: - card = await afetch_agent_card( - endpoint=config.endpoint, - auth=config.auth, - timeout=config.timeout, - ) - return config, card - except Exception as e: - return config, e - - -async def _afetch_agent_cards_concurrently( - a2a_agents: list[A2AConfig | A2AClientConfig], -) -> tuple[dict[str, AgentCard], dict[str, str]]: - """Fetch agent cards concurrently for multiple A2A agents using asyncio.""" - agent_cards: dict[str, AgentCard] = {} - failed_agents: dict[str, str] = {} - - if not a2a_agents: - return agent_cards, failed_agents - - tasks = [_afetch_card_from_config(config) for config in a2a_agents] - results = await asyncio.gather(*tasks) - - for config, result in results: - if isinstance(result, Exception): - if config.fail_fast: - raise RuntimeError( - f"Failed 
to fetch agent card from {config.endpoint}. " - f"Ensure the A2A agent is running and accessible. Error: {result}" - ) from result - failed_agents[config.endpoint] = str(result) - else: - agent_cards[config.endpoint] = result - - return agent_cards, failed_agents - - -async def _aexecute_task_with_a2a( - self: Agent, - a2a_agents: list[A2AConfig | A2AClientConfig], - original_fn: Callable[..., Coroutine[Any, Any, str]], - task: Task, - agent_response_model: type[BaseModel] | None, - context: str | None, - tools: list[BaseTool] | None, - extension_registry: ExtensionRegistry, -) -> str: - """Async version of _execute_task_with_a2a.""" - original_description: str = task.description - original_output_pydantic = task.output_pydantic - original_response_model = task.response_model - - agent_cards, failed_agents = await _afetch_agent_cards_concurrently(a2a_agents) - - if not agent_cards and a2a_agents and failed_agents: - unavailable_agents_text = "" - for endpoint, error in failed_agents.items(): - unavailable_agents_text += f" - {endpoint}: {error}\n" - - notice = UNAVAILABLE_AGENTS_NOTICE_TEMPLATE.substitute( - unavailable_agents=unavailable_agents_text - ) - task.description = f"{original_description}{notice}" - - try: - return await original_fn(self, task, context, tools) - finally: - task.description = original_description - - task.description, _, extension_states = _augment_prompt_with_a2a( - a2a_agents=a2a_agents, - task_description=original_description, - agent_cards=agent_cards, - failed_agents=failed_agents, - extension_registry=extension_registry, - ) - task.response_model = agent_response_model - - try: - raw_result = await original_fn(self, task, context, tools) - agent_response = _parse_agent_response( - raw_result=raw_result, agent_response_model=agent_response_model - ) - - if extension_registry and isinstance(agent_response, BaseModel): - agent_response = extension_registry.process_response_with_all( - agent_response, extension_states - ) - - if 
isinstance(agent_response, BaseModel) and isinstance( - agent_response, AgentResponseProtocol - ): - if agent_response.is_a2a: - return await _adelegate_to_a2a( - self, - agent_response=agent_response, - task=task, - original_fn=original_fn, - context=context, - tools=tools, - agent_cards=agent_cards, - original_task_description=original_description, - _extension_registry=extension_registry, - ) - task.output_pydantic = None - return agent_response.message - - return raw_result - finally: - task.description = original_description - if task.output_pydantic is not None: - task.output_pydantic = original_output_pydantic - task.response_model = original_response_model - - -async def _ahandle_agent_response_and_continue( - self: Agent, - a2a_result: TaskStateResult, - agent_id: str, - agent_cards: dict[str, AgentCard] | None, - a2a_agents: list[A2AConfig | A2AClientConfig], - original_task_description: str, - conversation_history: list[Message], - turn_num: int, - max_turns: int, - task: Task, - original_fn: Callable[..., Coroutine[Any, Any, str]], - context: str | None, - tools: list[BaseTool] | None, - agent_response_model: type[BaseModel] | None, - extension_registry: ExtensionRegistry | None = None, - remote_status_notice: str = "", - endpoint: str | None = None, - a2a_agent_name: str | None = None, - agent_card: dict[str, Any] | None = None, -) -> tuple[str | None, str | None]: - """Async version of _handle_agent_response_and_continue.""" - agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards) - - ( - task.description, - disable_structured_output, - extension_states, - ) = _augment_prompt_with_a2a( - a2a_agents=a2a_agents, - task_description=original_task_description, - conversation_history=conversation_history, - turn_num=turn_num, - max_turns=max_turns, - agent_cards=agent_cards_dict, - remote_status_notice=remote_status_notice, - ) - - original_response_model = task.response_model - if disable_structured_output: - task.response_model = 
None - - raw_result = await original_fn(self, task, context, tools) - - if disable_structured_output: - task.response_model = original_response_model - - return _process_response_result( - raw_result=raw_result, - disable_structured_output=disable_structured_output, - turn_num=turn_num, - agent_role=self.role, - agent_response_model=agent_response_model, - extension_registry=extension_registry, - extension_states=extension_states, - from_task=task, - from_agent=self, - endpoint=endpoint, - a2a_agent_name=a2a_agent_name, - agent_card=agent_card, - ) - - -async def _adelegate_to_a2a( - self: Agent, - agent_response: AgentResponseProtocol, - task: Task, - original_fn: Callable[..., Coroutine[Any, Any, str]], - context: str | None, - tools: list[BaseTool] | None, - agent_cards: dict[str, AgentCard] | None = None, - original_task_description: str | None = None, - _extension_registry: ExtensionRegistry | None = None, -) -> str: - """Async version of _delegate_to_a2a.""" - ctx = _prepare_delegation_context( - self, agent_response, task, original_task_description - ) - state = _init_delegation_state(ctx, agent_cards) - current_request = state.current_request - context_id = state.context_id - task_id = state.task_id - reference_task_ids = state.reference_task_ids - conversation_history = state.conversation_history - - try: - for turn_num in range(ctx.max_turns): - agent_branch, accepted_output_modes = _get_turn_context(ctx.agent_config) - - a2a_result = await aexecute_a2a_delegation( - endpoint=ctx.agent_config.endpoint, - auth=ctx.agent_config.auth, - timeout=ctx.agent_config.timeout, - task_description=current_request, - context_id=context_id, - task_id=task_id, - reference_task_ids=reference_task_ids, - metadata=ctx.metadata, - extensions=ctx.extensions, - conversation_history=conversation_history, - agent_id=ctx.agent_id, - agent_role=Role.user, - agent_branch=agent_branch, - response_model=ctx.agent_config.response_model, - turn_number=turn_num + 1, - 
transport=ctx.agent_config.transport, - updates=ctx.agent_config.updates, - from_task=task, - from_agent=self, - client_extensions=getattr(ctx.agent_config, "extensions", None), - accepted_output_modes=accepted_output_modes, - input_files=task.input_files, - ) - - conversation_history = a2a_result.get("history", []) - - if conversation_history: - latest_message = conversation_history[-1] - if latest_message.task_id is not None: - task_id = latest_message.task_id - if latest_message.context_id is not None: - context_id = latest_message.context_id - - if a2a_result["status"] in [TaskState.completed, TaskState.input_required]: - trusted_result, task_id, reference_task_ids, remote_notice = ( - _handle_task_completion( - a2a_result, - task, - task_id, - reference_task_ids, - ctx.agent_config, - turn_num, - from_task=task, - from_agent=self, - endpoint=ctx.agent_config.endpoint, - a2a_agent_name=state.agent_name, - agent_card=state.agent_card_dict, - ) - ) - if trusted_result is not None: - return trusted_result - - final_result, next_request = await _ahandle_agent_response_and_continue( - self=self, - a2a_result=a2a_result, - agent_id=ctx.agent_id, - agent_cards=agent_cards, - a2a_agents=ctx.a2a_agents, - original_task_description=ctx.original_task_description, - conversation_history=conversation_history, - turn_num=turn_num, - max_turns=ctx.max_turns, - task=task, - original_fn=original_fn, - context=context, - tools=tools, - agent_response_model=ctx.agent_response_model, - extension_registry=_extension_registry, - remote_status_notice=remote_notice, - endpoint=ctx.agent_config.endpoint, - a2a_agent_name=state.agent_name, - agent_card=state.agent_card_dict, - ) - - if final_result is not None: - return final_result - - if next_request is not None: - current_request = next_request - - continue - - error_msg = a2a_result.get("error", "Unknown error") - - final_result, next_request = await _ahandle_agent_response_and_continue( - self=self, - a2a_result=a2a_result, - 
agent_id=ctx.agent_id, - agent_cards=agent_cards, - a2a_agents=ctx.a2a_agents, - original_task_description=ctx.original_task_description, - conversation_history=conversation_history, - turn_num=turn_num, - max_turns=ctx.max_turns, - task=task, - original_fn=original_fn, - context=context, - tools=tools, - agent_response_model=ctx.agent_response_model, - extension_registry=_extension_registry, - endpoint=ctx.agent_config.endpoint, - a2a_agent_name=state.agent_name, - agent_card=state.agent_card_dict, - ) - - if final_result is not None: - return final_result - - if next_request is not None: - current_request = next_request - continue - - return _emit_delegation_failed( - error_msg, - turn_num, - task, - self, - ctx.agent_config.endpoint, - state.agent_name, - state.agent_card_dict, - ) - - return _handle_max_turns_exceeded( - conversation_history, - ctx.max_turns, - from_task=task, - from_agent=self, - endpoint=ctx.agent_config.endpoint, - a2a_agent_name=state.agent_name, - agent_card=state.agent_card_dict, - ) - - finally: - task.description = ctx.original_task_description +from crewai_a2a.wrapper import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 21edbd160..e978f49e3 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -84,16 +84,16 @@ from crewai.utilities.training_handler import CrewTrainingHandler try: - from crewai.a2a.types import AgentResponseProtocol + from crewai_a2a.types import AgentResponseProtocol except ImportError: AgentResponseProtocol = None # type: ignore[assignment, misc] if TYPE_CHECKING: + from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig from crewai_files import FileInput from crewai_tools import CodeInterpreterTool - from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig from crewai.agents.agent_builder.base_agent import PlatformAppOrAction from crewai.task import Task from 
crewai.tools.base_tool import BaseTool @@ -1740,7 +1740,7 @@ class Agent(BaseAgent): # Rebuild Agent model to resolve A2A type forward references try: - from crewai.a2a.config import ( + from crewai_a2a.config import ( A2AClientConfig as _A2AClientConfig, A2AConfig as _A2AConfig, A2AServerConfig as _A2AServerConfig, diff --git a/lib/crewai/src/crewai/agent/internal/meta.py b/lib/crewai/src/crewai/agent/internal/meta.py index 7ecea9b35..1a7892f9f 100644 --- a/lib/crewai/src/crewai/agent/internal/meta.py +++ b/lib/crewai/src/crewai/agent/internal/meta.py @@ -58,10 +58,10 @@ class AgentMeta(ModelMetaclass): a2a_value = getattr(self, "a2a", None) if a2a_value is not None: - from crewai.a2a.extensions.registry import ( + from crewai_a2a.extensions.registry import ( create_extension_registry_from_config, ) - from crewai.a2a.wrapper import wrap_agent_with_a2a_instance + from crewai_a2a.wrapper import wrap_agent_with_a2a_instance extension_registry = create_extension_registry_from_config( a2a_value diff --git a/lib/crewai/src/crewai/lite_agent.py b/lib/crewai/src/crewai/lite_agent.py index 66b710890..f865bc5f2 100644 --- a/lib/crewai/src/crewai/lite_agent.py +++ b/lib/crewai/src/crewai/lite_agent.py @@ -31,10 +31,9 @@ from typing_extensions import Self if TYPE_CHECKING: + from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig from crewai_files import FileInput - from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig - from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.agents.cache.cache_handler import CacheHandler @@ -120,8 +119,9 @@ def _kickoff_with_a2a_support( Returns: LiteAgentOutput from either local execution or A2A delegation. 
""" - from crewai.a2a.utils.response_model import get_a2a_agents_and_response_model - from crewai.a2a.wrapper import _execute_task_with_a2a + from crewai_a2a.utils.response_model import get_a2a_agents_and_response_model + from crewai_a2a.wrapper import _execute_task_with_a2a + from crewai.task import Task a2a_agents, agent_response_model = get_a2a_agents_and_response_model(agent.a2a) @@ -319,11 +319,11 @@ class LiteAgent(FlowTrackable, BaseModel): def setup_a2a_support(self) -> Self: """Setup A2A extensions and server methods if a2a config exists.""" if self.a2a: - from crewai.a2a.config import A2AClientConfig, A2AConfig - from crewai.a2a.extensions.registry import ( + from crewai_a2a.config import A2AClientConfig, A2AConfig + from crewai_a2a.extensions.registry import ( create_extension_registry_from_config, ) - from crewai.a2a.utils.agent_card import inject_a2a_server_methods + from crewai_a2a.utils.agent_card import inject_a2a_server_methods configs = self.a2a if isinstance(self.a2a, list) else [self.a2a] client_configs = [ @@ -995,7 +995,7 @@ class LiteAgent(FlowTrackable, BaseModel): try: - from crewai.a2a.config import ( + from crewai_a2a.config import ( A2AClientConfig as _A2AClientConfig, A2AConfig as _A2AConfig, A2AServerConfig as _A2AServerConfig, diff --git a/lib/crewai/tests/agents/test_a2a_trust_completion_status.py b/lib/crewai/tests/agents/test_a2a_trust_completion_status.py index 6347f8e1c..f3d6ec0fb 100644 --- a/lib/crewai/tests/agents/test_a2a_trust_completion_status.py +++ b/lib/crewai/tests/agents/test_a2a_trust_completion_status.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch import pytest -from crewai.a2a.config import A2AConfig +from crewai_a2a.config import A2AConfig try: from a2a.types import Message, Role @@ -27,8 +27,8 @@ def _create_mock_agent_card(name: str = "Test", url: str = "http://test-endpoint @pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed") def 
test_trust_remote_completion_status_true_returns_directly(): """When trust_remote_completion_status=True and A2A returns completed, return result directly.""" - from crewai.a2a.wrapper import _delegate_to_a2a - from crewai.a2a.types import AgentResponseProtocol + from crewai_a2a.wrapper import _delegate_to_a2a + from crewai_a2a.types import AgentResponseProtocol from crewai import Agent, Task a2a_config = A2AConfig( @@ -51,8 +51,8 @@ def test_trust_remote_completion_status_true_returns_directly(): a2a_ids = ["http://test-endpoint.com/"] with ( - patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute, - patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch, + patch("crewai_a2a.wrapper.execute_a2a_delegation") as mock_execute, + patch("crewai_a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch, ): mock_card = _create_mock_agent_card() mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {}) @@ -83,7 +83,7 @@ def test_trust_remote_completion_status_true_returns_directly(): @pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed") def test_trust_remote_completion_status_false_continues_conversation(): """When trust_remote_completion_status=False and A2A returns completed, ask server agent.""" - from crewai.a2a.wrapper import _delegate_to_a2a + from crewai_a2a.wrapper import _delegate_to_a2a from crewai import Agent, Task a2a_config = A2AConfig( @@ -116,8 +116,8 @@ def test_trust_remote_completion_status_false_continues_conversation(): return "unexpected" with ( - patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute, - patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch, + patch("crewai_a2a.wrapper.execute_a2a_delegation") as mock_execute, + patch("crewai_a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch, ): mock_card = _create_mock_agent_card() mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {}) diff --git 
a/lib/crewai/tests/agents/test_agent_a2a_kickoff.py b/lib/crewai/tests/agents/test_agent_a2a_kickoff.py index 00123c4cf..c3b34d21d 100644 --- a/lib/crewai/tests/agents/test_agent_a2a_kickoff.py +++ b/lib/crewai/tests/agents/test_agent_a2a_kickoff.py @@ -7,7 +7,7 @@ import os import pytest from crewai import Agent -from crewai.a2a.config import A2AClientConfig +from crewai_a2a.config import A2AClientConfig A2A_TEST_ENDPOINT = os.getenv( diff --git a/lib/crewai/tests/agents/test_agent_a2a_wrapping.py b/lib/crewai/tests/agents/test_agent_a2a_wrapping.py index 6b7d4be9f..b6e08f0bc 100644 --- a/lib/crewai/tests/agents/test_agent_a2a_wrapping.py +++ b/lib/crewai/tests/agents/test_agent_a2a_wrapping.py @@ -5,7 +5,7 @@ from unittest.mock import patch import pytest from crewai import Agent -from crewai.a2a.config import A2AConfig +from crewai_a2a.config import A2AConfig try: import a2a # noqa: F401 diff --git a/pyproject.toml b/pyproject.toml index 657c15eaa..b5664e49e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ ignore-decorators = ["typing.overload"] "lib/crewai/tests/**/*.py" = ["S101", "RET504", "S105", "S106"] # Allow assert statements, unnecessary assignments, and hardcoded passwords in tests "lib/crewai-tools/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "RUF012", "N818", "E402", "RUF043", "S110", "B017"] # Allow various test-specific patterns "lib/crewai-files/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "B017", "F841"] # Allow assert statements and blind exception assertions in tests +"lib/crewai-a2a/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "B017", "F821"] # Allow assert statements, unnecessary assignments, hardcoded passwords, blind exceptions, and forward refs in tests [tool.mypy] @@ -118,7 +119,7 @@ warn_return_any = true show_error_codes = true warn_unused_ignores = true python_version = "3.12" -exclude = 
"(?x)(^lib/crewai/src/crewai/cli/templates/|^lib/crewai/tests/|^lib/crewai-tools/tests/|^lib/crewai-files/tests/)" +exclude = "(?x)(^lib/crewai/src/crewai/cli/templates/|^lib/crewai/tests/|^lib/crewai-tools/tests/|^lib/crewai-files/tests/|^lib/crewai-a2a/tests/)" plugins = ["pydantic.mypy"] @@ -134,6 +135,7 @@ testpaths = [ "lib/crewai/tests", "lib/crewai-tools/tests", "lib/crewai-files/tests", + "lib/crewai-a2a/tests", ] asyncio_mode = "strict" asyncio_default_fixture_loop_scope = "function" @@ -157,6 +159,7 @@ members = [ "lib/crewai-tools", "lib/devtools", "lib/crewai-files", + "lib/crewai-a2a", ] @@ -165,3 +168,4 @@ crewai = { workspace = true } crewai-tools = { workspace = true } crewai-devtools = { workspace = true } crewai-files = { workspace = true } +crewai-a2a = { workspace = true } diff --git a/uv.lock b/uv.lock index dba6ab30c..61d7b0a1c 100644 --- a/uv.lock +++ b/uv.lock @@ -15,6 +15,7 @@ resolution-markers = [ [manifest] members = [ "crewai", + "crewai-a2a", "crewai-devtools", "crewai-files", "crewai-tools", @@ -1124,10 +1125,7 @@ dependencies = [ [package.optional-dependencies] a2a = [ - { name = "a2a-sdk" }, - { name = "aiocache", extra = ["memcached", "redis"] }, - { name = "httpx-auth" }, - { name = "httpx-sse" }, + { name = "crewai-a2a" }, ] anthropic = [ { name = "anthropic" }, @@ -1181,9 +1179,7 @@ watson = [ [package.metadata] requires-dist = [ - { name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" }, { name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" }, - { name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" }, { name = "aiosqlite", specifier = "~=0.21.0" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" }, { name = "appdirs", specifier = "~=1.4.4" }, @@ -1192,13 +1188,12 @@ requires-dist = [ { name = "boto3", marker = "extra == 'bedrock'", specifier = "~=1.40.45" }, { name = "chromadb", specifier = "~=1.1.0" }, { name = 
"click", specifier = "~=8.1.7" }, + { name = "crewai-a2a", marker = "extra == 'a2a'", editable = "lib/crewai-a2a" }, { name = "crewai-files", marker = "extra == 'file-processing'", editable = "lib/crewai-files" }, { name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" }, { name = "docling", marker = "extra == 'docling'", specifier = "~=2.63.0" }, { name = "google-genai", marker = "extra == 'google-genai'", specifier = "~=1.49.0" }, { name = "httpx", specifier = "~=0.28.1" }, - { name = "httpx-auth", marker = "extra == 'a2a'", specifier = "~=0.23.1" }, - { name = "httpx-sse", marker = "extra == 'a2a'", specifier = "~=0.4.0" }, { name = "ibm-watsonx-ai", marker = "extra == 'watson'", specifier = "~=1.3.39" }, { name = "instructor", specifier = ">=1.3.3" }, { name = "json-repair", specifier = "~=0.25.2" }, @@ -1233,6 +1228,26 @@ requires-dist = [ ] provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "tools", "voyageai", "watson"] +[[package]] +name = "crewai-a2a" +source = { editable = "lib/crewai-a2a" } +dependencies = [ + { name = "a2a-sdk" }, + { name = "aiocache", extra = ["memcached", "redis"] }, + { name = "crewai" }, + { name = "httpx-auth" }, + { name = "httpx-sse" }, +] + +[package.metadata] +requires-dist = [ + { name = "a2a-sdk", specifier = "~=0.3.10" }, + { name = "aiocache", extras = ["memcached", "redis"], specifier = "~=0.12.3" }, + { name = "crewai", editable = "lib/crewai" }, + { name = "httpx-auth", specifier = "~=0.23.1" }, + { name = "httpx-sse", specifier = "~=0.4.0" }, +] + [[package]] name = "crewai-devtools" source = { editable = "lib/devtools" }