From e4be1329a085bc5c34900e78f76e8c548ef676fc Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 29 Jan 2026 04:52:36 -0500 Subject: [PATCH] feat: add server-side auth schemes and protocol extensions - add server auth scheme base class and implementations (api key, bearer token, basic/digest auth, mtls) - add server-side extension system for a2a protocol extensions - add extensions middleware for x-a2a-extensions header management - add extension validation and registry utilities - enhance auth utilities with server-side support - add async intercept method to match client call interceptor protocol - fix type_checking import to resolve mypy errors with a2aconfig --- lib/crewai/src/crewai/a2a/auth/__init__.py | 18 +- .../src/crewai/a2a/auth/server_schemes.py | 739 ++++++++++++++++++ lib/crewai/src/crewai/a2a/auth/utils.py | 59 +- .../src/crewai/a2a/extensions/__init__.py | 33 + lib/crewai/src/crewai/a2a/extensions/base.py | 55 +- .../src/crewai/a2a/extensions/registry.py | 158 +++- .../src/crewai/a2a/extensions/server.py | 305 ++++++++ 7 files changed, 1339 insertions(+), 28 deletions(-) create mode 100644 lib/crewai/src/crewai/a2a/auth/server_schemes.py create mode 100644 lib/crewai/src/crewai/a2a/extensions/server.py diff --git a/lib/crewai/src/crewai/a2a/auth/__init__.py b/lib/crewai/src/crewai/a2a/auth/__init__.py index 3cc2f446f..093193a8e 100644 --- a/lib/crewai/src/crewai/a2a/auth/__init__.py +++ b/lib/crewai/src/crewai/a2a/auth/__init__.py @@ -1,20 +1,36 @@ """A2A authentication schemas.""" -from crewai.a2a.auth.schemas import ( +from crewai.a2a.auth.client_schemes import ( APIKeyAuth, + AuthScheme, BearerTokenAuth, + ClientAuthScheme, HTTPBasicAuth, HTTPDigestAuth, OAuth2AuthorizationCode, OAuth2ClientCredentials, + TLSConfig, +) +from crewai.a2a.auth.server_schemes import ( + AuthenticatedUser, + OIDCAuth, + ServerAuthScheme, + SimpleTokenAuth, ) __all__ = [ "APIKeyAuth", + "AuthScheme", + "AuthenticatedUser", "BearerTokenAuth", + "ClientAuthScheme", "HTTPBasicAuth", "HTTPDigestAuth", "OAuth2AuthorizationCode", "OAuth2ClientCredentials", + "OIDCAuth", + "ServerAuthScheme", + "SimpleTokenAuth", + "TLSConfig", ] diff --git a/lib/crewai/src/crewai/a2a/auth/server_schemes.py b/lib/crewai/src/crewai/a2a/auth/server_schemes.py new file mode 100644 index 000000000..25ad597be --- /dev/null +++ b/lib/crewai/src/crewai/a2a/auth/server_schemes.py @@ -0,0 +1,739 @@ +"""Server-side authentication schemes for A2A protocol. + +These schemes validate incoming requests to A2A server endpoints. + +Supported authentication methods: +- Simple token validation with static bearer tokens +- OpenID Connect with JWT validation using JWKS +- OAuth2 with JWT validation or token introspection +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +import logging +import os +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal + +import jwt +from jwt import PyJWKClient +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + HttpUrl, + PrivateAttr, + SecretStr, + model_validator, +) +from typing_extensions import Self + + +if TYPE_CHECKING: + from a2a.types import OAuth2SecurityScheme + + +logger = logging.getLogger(__name__) + + +try: + from fastapi import HTTPException, status as http_status + + HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED + HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR + HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE +except ImportError: + + class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818 + """Fallback HTTPException when FastAPI is not installed.""" + + def __init__( + self, + status_code: int, + detail: str | None = None, + headers: dict[str, str] | None = None, + ) -> None: + self.status_code = status_code + self.detail = detail + self.headers = headers + super().__init__(detail) + + HTTP_401_UNAUTHORIZED = 401 + HTTP_500_INTERNAL_SERVER_ERROR = 500 + HTTP_503_SERVICE_UNAVAILABLE = 503 + + +def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None: + """Coerce string to SecretStr.""" + if v is None or isinstance(v, SecretStr): + return v + return SecretStr(v) + + +CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)] + +JWTAlgorithm = Literal[ + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", +] + + +@dataclass +class AuthenticatedUser: + """Result of successful authentication. + + Attributes: + token: The original token that was validated. + scheme: Name of the authentication scheme used. + claims: JWT claims from OIDC or OAuth2 authentication. + """ + + token: str + scheme: str + claims: dict[str, Any] | None = None + + +class ServerAuthScheme(ABC, BaseModel): + """Base class for server-side authentication schemes. + + Each scheme validates incoming requests and returns an AuthenticatedUser + on success, or raises HTTPException on failure. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + @abstractmethod + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate the provided token. + + Args: + token: The bearer token to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. + """ + ... + + +class SimpleTokenAuth(ServerAuthScheme): + """Simple bearer token authentication. + + Validates tokens against a configured static token or AUTH_TOKEN env var. + + Attributes: + token: Expected token value. Falls back to AUTH_TOKEN env var if not set. + """ + + token: CoercedSecretStr | None = Field( + default=None, + description="Expected token. Falls back to AUTH_TOKEN env var.", + ) + + def _get_expected_token(self) -> str | None: + """Get the expected token value.""" + if self.token: + return self.token.get_secret_value() + return os.environ.get("AUTH_TOKEN") + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate using simple token comparison. + + Args: + token: The bearer token to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. + """ + expected = self._get_expected_token() + + if expected is None: + logger.warning( + "Simple token authentication failed", + extra={"reason": "no_token_configured"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Authentication not configured", + ) + + if token != expected: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid or missing authentication credentials", + ) + + return AuthenticatedUser( + token=token, + scheme="simple_token", + ) + + +class OIDCAuth(ServerAuthScheme): + """OpenID Connect authentication. + + Validates JWTs using JWKS with caching support via PyJWT. + + Attributes: + issuer: The OpenID Connect issuer URL. + audience: The expected audience claim. + jwks_url: Optional explicit JWKS URL. Derived from issuer if not set. + algorithms: List of allowed signing algorithms. + required_claims: List of claims that must be present in the token. + jwks_cache_ttl: TTL for JWKS cache in seconds. + clock_skew_seconds: Allowed clock skew for token validation. + """ + + issuer: HttpUrl = Field( + description="OpenID Connect issuer URL (e.g., https://auth.example.com)" + ) + audience: str = Field(description="Expected audience claim (e.g., api://my-agent)") + jwks_url: HttpUrl | None = Field( + default=None, + description="Explicit JWKS URL. Derived from issuer if not set.", + ) + algorithms: list[str] = Field( + default_factory=lambda: ["RS256"], + description="List of allowed signing algorithms (RS256, ES256, etc.)", + ) + required_claims: list[str] = Field( + default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"], + description="List of claims that must be present in the token", + ) + jwks_cache_ttl: int = Field( + default=3600, + description="TTL for JWKS cache in seconds", + ge=60, + ) + clock_skew_seconds: float = Field( + default=30.0, + description="Allowed clock skew for token validation", + ge=0.0, + ) + + _jwk_client: PyJWKClient | None = PrivateAttr(default=None) + + @model_validator(mode="after") + def _init_jwk_client(self) -> Self: + """Initialize the JWK client after model creation.""" + jwks_url = ( + str(self.jwks_url) + if self.jwks_url + else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json" + ) + self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl) + return self + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate using OIDC JWT validation. + + Args: + token: The JWT to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. + """ + if self._jwk_client is None: + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail="OIDC not initialized", + ) + + try: + signing_key = self._jwk_client.get_signing_key_from_jwt(token) + + claims = jwt.decode( + token, + signing_key.key, + algorithms=self.algorithms, + audience=self.audience, + issuer=str(self.issuer).rstrip("/"), + leeway=self.clock_skew_seconds, + options={ + "require": self.required_claims, + }, + ) + + return AuthenticatedUser( + token=token, + scheme="oidc", + claims=claims, + ) + + except jwt.ExpiredSignatureError: + logger.debug( + "OIDC authentication failed", + extra={"reason": "token_expired", "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Token has expired", + ) from None + except jwt.InvalidAudienceError: + logger.debug( + "OIDC authentication failed", + extra={"reason": "invalid_audience", "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid token audience", + ) from None + except jwt.InvalidIssuerError: + logger.debug( + "OIDC authentication failed", + extra={"reason": "invalid_issuer", "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid token issuer", + ) from None + except jwt.MissingRequiredClaimError as e: + logger.debug( + "OIDC authentication failed", + extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=f"Missing required claim: {e.claim}", + ) from None + except jwt.PyJWKClientError as e: + logger.error( + "OIDC authentication failed", + extra={ + "reason": "jwks_client_error", + "error": str(e), + "scheme": "oidc", + }, + ) + raise HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, + detail="Unable to fetch signing keys", + ) from None + except jwt.InvalidTokenError as e: + logger.debug( + "OIDC authentication failed", + extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid or missing authentication credentials", + ) from None + + +class OAuth2ServerAuth(ServerAuthScheme): + """OAuth2 authentication for A2A server. + + Declares OAuth2 security scheme in AgentCard and validates tokens using + either JWKS for JWT tokens or token introspection for opaque tokens. + + This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme + with flows, rather than an OpenIdConnectSecurityScheme with discovery URL. + + Attributes: + token_url: OAuth2 token endpoint URL for client_credentials flow. + authorization_url: OAuth2 authorization endpoint for authorization_code flow. + refresh_url: Optional refresh token endpoint URL. + scopes: Available OAuth2 scopes with descriptions. + jwks_url: JWKS URL for JWT validation. Required if not using introspection. + introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS. + introspection_client_id: Client ID for introspection endpoint authentication. + introspection_client_secret: Client secret for introspection endpoint. + audience: Expected audience claim for JWT validation. + issuer: Expected issuer claim for JWT validation. + algorithms: Allowed JWT signing algorithms. + required_claims: Claims that must be present in the token. + jwks_cache_ttl: TTL for JWKS cache in seconds. + clock_skew_seconds: Allowed clock skew for token validation. + """ + + token_url: HttpUrl = Field( + description="OAuth2 token endpoint URL", + ) + authorization_url: HttpUrl | None = Field( + default=None, + description="OAuth2 authorization endpoint URL for authorization_code flow", + ) + refresh_url: HttpUrl | None = Field( + default=None, + description="OAuth2 refresh token endpoint URL", + ) + scopes: dict[str, str] = Field( + default_factory=dict, + description="Available OAuth2 scopes with descriptions", + ) + jwks_url: HttpUrl | None = Field( + default=None, + description="JWKS URL for JWT validation. Required if not using introspection.", + ) + introspection_url: HttpUrl | None = Field( + default=None, + description="Token introspection endpoint (RFC 7662). Alternative to JWKS.", + ) + introspection_client_id: str | None = Field( + default=None, + description="Client ID for introspection endpoint authentication", + ) + introspection_client_secret: CoercedSecretStr | None = Field( + default=None, + description="Client secret for introspection endpoint authentication", + ) + audience: str | None = Field( + default=None, + description="Expected audience claim for JWT validation", + ) + issuer: str | None = Field( + default=None, + description="Expected issuer claim for JWT validation", + ) + algorithms: list[str] = Field( + default_factory=lambda: ["RS256"], + description="Allowed JWT signing algorithms", + ) + required_claims: list[str] = Field( + default_factory=lambda: ["exp", "iat"], + description="Claims that must be present in the token", + ) + jwks_cache_ttl: int = Field( + default=3600, + description="TTL for JWKS cache in seconds", + ge=60, + ) + clock_skew_seconds: float = Field( + default=30.0, + description="Allowed clock skew for token validation", + ge=0.0, + ) + + _jwk_client: PyJWKClient | None = PrivateAttr(default=None) + + @model_validator(mode="after") + def _validate_and_init(self) -> Self: + """Validate configuration and initialize JWKS client if needed.""" + if not self.jwks_url and not self.introspection_url: + raise ValueError( + "Either jwks_url or introspection_url must be provided for token validation" + ) + + if self.introspection_url: + if not self.introspection_client_id or not self.introspection_client_secret: + raise ValueError( + "introspection_client_id and introspection_client_secret are required " + "when using token introspection" + ) + + if self.jwks_url: + self._jwk_client = PyJWKClient( + str(self.jwks_url), lifespan=self.jwks_cache_ttl + ) + + return self + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate using OAuth2 token validation. + + Uses JWKS validation if jwks_url is configured, otherwise falls back + to token introspection. + + Args: + token: The OAuth2 access token to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. + """ + if self._jwk_client: + return await self._authenticate_jwt(token) + return await self._authenticate_introspection(token) + + async def _authenticate_jwt(self, token: str) -> AuthenticatedUser: + """Authenticate using JWKS JWT validation.""" + if self._jwk_client is None: + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth2 JWKS not initialized", + ) + + try: + signing_key = self._jwk_client.get_signing_key_from_jwt(token) + + decode_options: dict[str, Any] = { + "require": self.required_claims, + } + + claims = jwt.decode( + token, + signing_key.key, + algorithms=self.algorithms, + audience=self.audience, + issuer=self.issuer, + leeway=self.clock_skew_seconds, + options=decode_options, + ) + + return AuthenticatedUser( + token=token, + scheme="oauth2", + claims=claims, + ) + + except jwt.ExpiredSignatureError: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "token_expired", "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Token has expired", + ) from None + except jwt.InvalidAudienceError: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "invalid_audience", "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid token audience", + ) from None + except jwt.InvalidIssuerError: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "invalid_issuer", "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid token issuer", + ) from None + except jwt.MissingRequiredClaimError as e: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=f"Missing required claim: {e.claim}", + ) from None + except jwt.PyJWKClientError as e: + logger.error( + "OAuth2 authentication failed", + extra={ + "reason": "jwks_client_error", + "error": str(e), + "scheme": "oauth2", + }, + ) + raise HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, + detail="Unable to fetch signing keys", + ) from None + except jwt.InvalidTokenError as e: + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid or missing authentication credentials", + ) from None + + async def _authenticate_introspection(self, token: str) -> AuthenticatedUser: + """Authenticate using OAuth2 token introspection (RFC 7662).""" + import httpx + + if not self.introspection_url: + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth2 introspection not configured", + ) + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + str(self.introspection_url), + data={"token": token}, + auth=( + self.introspection_client_id or "", + self.introspection_client_secret.get_secret_value() + if self.introspection_client_secret + else "", + ), + ) + response.raise_for_status() + introspection_result = response.json() + + except httpx.HTTPStatusError as e: + logger.error( + "OAuth2 introspection failed", + extra={"reason": "http_error", "status_code": e.response.status_code}, + ) + raise HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, + detail="Token introspection service unavailable", + ) from None + except Exception as e: + logger.error( + "OAuth2 introspection failed", + extra={"reason": "unexpected_error", "error": str(e)}, + ) + raise HTTPException( + status_code=HTTP_503_SERVICE_UNAVAILABLE, + detail="Token introspection failed", + ) from None + + if not introspection_result.get("active", False): + logger.debug( + "OAuth2 authentication failed", + extra={"reason": "token_not_active", "scheme": "oauth2"}, + ) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Token is not active", + ) + + return AuthenticatedUser( + token=token, + scheme="oauth2", + claims=introspection_result, + ) + + def to_security_scheme(self) -> OAuth2SecurityScheme: + """Generate OAuth2SecurityScheme for AgentCard declaration. + + Creates an OAuth2SecurityScheme with appropriate flows based on + the configured URLs. Includes client_credentials flow if token_url + is set, and authorization_code flow if authorization_url is set. + + Returns: + OAuth2SecurityScheme suitable for use in AgentCard security_schemes. + """ + from a2a.types import ( + AuthorizationCodeOAuthFlow, + ClientCredentialsOAuthFlow, + OAuth2SecurityScheme, + OAuthFlows, + ) + + client_credentials = None + authorization_code = None + + if self.token_url: + client_credentials = ClientCredentialsOAuthFlow( + token_url=str(self.token_url), + refresh_url=str(self.refresh_url) if self.refresh_url else None, + scopes=self.scopes, + ) + + if self.authorization_url: + authorization_code = AuthorizationCodeOAuthFlow( + authorization_url=str(self.authorization_url), + token_url=str(self.token_url), + refresh_url=str(self.refresh_url) if self.refresh_url else None, + scopes=self.scopes, + ) + + return OAuth2SecurityScheme( + flows=OAuthFlows( + client_credentials=client_credentials, + authorization_code=authorization_code, + ), + description="OAuth2 authentication", + ) + + +class APIKeyServerAuth(ServerAuthScheme): + """API Key authentication for A2A server. + + Validates requests using an API key in a header, query parameter, or cookie. + + Attributes: + name: The name of the API key parameter (default: X-API-Key). + location: Where to look for the API key (header, query, or cookie). + api_key: The expected API key value. + """ + + name: str = Field( + default="X-API-Key", + description="Name of the API key parameter", + ) + location: Literal["header", "query", "cookie"] = Field( + default="header", + description="Where to look for the API key", + ) + api_key: CoercedSecretStr = Field( + description="Expected API key value", + ) + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Authenticate using API key comparison. + + Args: + token: The API key to authenticate. + + Returns: + AuthenticatedUser on successful authentication. + + Raises: + HTTPException: If authentication fails. + """ + if token != self.api_key.get_secret_value(): + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + + return AuthenticatedUser( + token=token, + scheme="api_key", + ) + + +class MTLSServerAuth(ServerAuthScheme): + """Mutual TLS authentication marker for AgentCard declaration. + + This scheme is primarily for AgentCard security_schemes declaration. + Actual mTLS verification happens at the TLS/transport layer, not + at the application layer via token validation. + + When configured, this signals to clients that the server requires + client certificates for authentication. + """ + + description: str = Field( + default="Mutual TLS certificate authentication", + description="Description for the security scheme", + ) + + async def authenticate(self, token: str) -> AuthenticatedUser: + """Return authenticated user for mTLS. + + mTLS verification happens at the transport layer before this is called. + If we reach this point, the TLS handshake with client cert succeeded. + + Args: + token: Certificate subject or identifier (from TLS layer). + + Returns: + AuthenticatedUser indicating mTLS authentication. + """ + return AuthenticatedUser( + token=token or "mtls-verified", + scheme="mtls", + ) diff --git a/lib/crewai/src/crewai/a2a/auth/utils.py b/lib/crewai/src/crewai/a2a/auth/utils.py index 2dddaf00a..3e8de3e0d 100644 --- a/lib/crewai/src/crewai/a2a/auth/utils.py +++ b/lib/crewai/src/crewai/a2a/auth/utils.py @@ -6,8 +6,10 @@ OAuth2, API keys, and HTTP authentication methods. import asyncio from collections.abc import Awaitable, Callable, MutableMapping +import hashlib import re -from typing import Final +import threading +from typing import Final, Literal, cast from a2a.client.errors import A2AClientHTTPError from a2a.types import ( @@ -18,10 +20,10 @@ from a2a.types import ( ) from httpx import AsyncClient, Response -from crewai.a2a.auth.schemas import ( +from crewai.a2a.auth.client_schemes import ( APIKeyAuth, - AuthScheme, BearerTokenAuth, + ClientAuthScheme, HTTPBasicAuth, HTTPDigestAuth, OAuth2AuthorizationCode, @@ -29,12 +31,44 @@ from crewai.a2a.auth.schemas import ( ) -_auth_store: dict[int, AuthScheme | None] = {} +class _AuthStore: + """Store for authentication schemes with safe concurrent access.""" + + def __init__(self) -> None: + self._store: dict[str, ClientAuthScheme | None] = {} + self._lock = threading.RLock() + + @staticmethod + def compute_key(auth_type: str, auth_data: str) -> str: + """Compute a collision-resistant key using SHA-256.""" + content = f"{auth_type}:{auth_data}" + return hashlib.sha256(content.encode()).hexdigest() + + def set(self, key: str, auth: ClientAuthScheme | None) -> None: + """Store an auth scheme.""" + with self._lock: + self._store[key] = auth + + def get(self, key: str) -> ClientAuthScheme | None: + """Retrieve an auth scheme by key.""" + with self._lock: + return self._store.get(key) + + def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None: + with self._lock: + self._store[key] = value + + def __getitem__(self, key: str) -> ClientAuthScheme | None: + with self._lock: + return self._store[key] + + +_auth_store = _AuthStore() _SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)") _PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))') -_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = { +_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = { OAuth2SecurityScheme: ( OAuth2ClientCredentials, OAuth2AuthorizationCode, @@ -43,7 +77,9 @@ _SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = { APIKeySecurityScheme: (APIKeyAuth,), } -_HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = { +_HTTPSchemeType = Literal["basic", "digest", "bearer"] + +_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = { "basic": HTTPBasicAuth, "digest": HTTPDigestAuth, "bearer": BearerTokenAuth, @@ -51,8 +87,8 @@ _HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = { def _raise_auth_mismatch( - expected_classes: type[AuthScheme] | tuple[type[AuthScheme], ...], - provided_auth: AuthScheme, + expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...], + provided_auth: ClientAuthScheme, ) -> None: """Raise authentication mismatch error. @@ -111,7 +147,7 @@ def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]: def validate_auth_against_agent_card( - agent_card: AgentCard, auth: AuthScheme | None + agent_card: AgentCard, auth: ClientAuthScheme | None ) -> None: """Validate that provided auth matches AgentCard security requirements. @@ -145,7 +181,8 @@ def validate_auth_against_agent_card( return if isinstance(scheme, HTTPAuthSecurityScheme): - if required_class := _HTTP_SCHEME_MAPPING.get(scheme.scheme.lower()): + scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower()) + if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key): if not isinstance(auth, required_class): _raise_auth_mismatch(required_class, auth) return @@ -156,7 +193,7 @@ def validate_auth_against_agent_card( async def retry_on_401( request_func: Callable[[], Awaitable[Response]], - auth_scheme: AuthScheme | None, + auth_scheme: ClientAuthScheme | None, client: AsyncClient, headers: MutableMapping[str, str], max_retries: int = 3, diff --git a/lib/crewai/src/crewai/a2a/extensions/__init__.py b/lib/crewai/src/crewai/a2a/extensions/__init__.py index 1d0e81e91..b21ae10ad 100644 --- a/lib/crewai/src/crewai/a2a/extensions/__init__.py +++ b/lib/crewai/src/crewai/a2a/extensions/__init__.py @@ -1,4 +1,37 @@ """A2A Protocol Extensions for CrewAI. This module contains extensions to the A2A (Agent-to-Agent) protocol. + +**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper +processes requests and responses during delegation to remote agents. These provide +hooks for tool injection, prompt augmentation, and response processing. + +**Server-side extensions** (ServerExtension) allow agents to offer additional +functionality beyond the core A2A specification. Clients activate extensions +via the X-A2A-Extensions header. + +See: https://a2a-protocol.org/latest/topics/extensions/ """ + +from crewai.a2a.extensions.base import ( + A2AExtension, + ConversationState, + ExtensionRegistry, + ValidatedA2AExtension, +) +from crewai.a2a.extensions.server import ( + ExtensionContext, + ServerExtension, + ServerExtensionRegistry, +) + + +__all__ = [ + "A2AExtension", + "ConversationState", + "ExtensionContext", + "ExtensionRegistry", + "ServerExtension", + "ServerExtensionRegistry", + "ValidatedA2AExtension", +] diff --git a/lib/crewai/src/crewai/a2a/extensions/base.py b/lib/crewai/src/crewai/a2a/extensions/base.py index 23b09305e..2d7a81a22 100644 --- a/lib/crewai/src/crewai/a2a/extensions/base.py +++ b/lib/crewai/src/crewai/a2a/extensions/base.py @@ -1,14 +1,20 @@ -"""Base extension interface for A2A wrapper integrations. +"""Base extension interface for CrewAI A2A wrapper processing hooks. -This module defines the protocol for extending A2A wrapper functionality -with custom logic for conversation processing, prompt augmentation, and -agent response handling. +This module defines the protocol for extending CrewAI's A2A wrapper functionality +with custom logic for tool injection, prompt augmentation, and response processing. + +Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions. +A2A protocol extensions are capability declarations using AgentExtension objects +in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header. +See: https://a2a-protocol.org/latest/topics/extensions/ """ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable + +from pydantic import BeforeValidator if TYPE_CHECKING: @@ -17,6 +23,20 @@ if TYPE_CHECKING: from crewai.agent.core import Agent +def _validate_a2a_extension(v: Any) -> Any: + """Validate that value implements A2AExtension protocol.""" + if not isinstance(v, A2AExtension): + raise ValueError( + f"Value must implement A2AExtension protocol. " + f"Got {type(v).__name__} which is missing required methods." + ) + return v + + +ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)] + + +@runtime_checkable class ConversationState(Protocol): """Protocol for extension-specific conversation state. @@ -33,11 +53,36 @@ class ConversationState(Protocol): ... +@runtime_checkable class A2AExtension(Protocol): """Protocol for A2A wrapper extensions. Extensions can implement this protocol to inject custom logic into the A2A conversation flow at various integration points. + + Example: + class MyExtension: + def inject_tools(self, agent: Agent) -> None: + # Add custom tools to the agent + pass + + def extract_state_from_history( + self, conversation_history: Sequence[Message] + ) -> ConversationState | None: + # Extract state from conversation + return None + + def augment_prompt( + self, base_prompt: str, conversation_state: ConversationState | None + ) -> str: + # Add custom instructions + return base_prompt + + def process_response( + self, agent_response: Any, conversation_state: ConversationState | None + ) -> Any: + # Modify response if needed + return agent_response """ def inject_tools(self, agent: Agent) -> None: diff --git a/lib/crewai/src/crewai/a2a/extensions/registry.py b/lib/crewai/src/crewai/a2a/extensions/registry.py index ca4824911..4d195961b 100644 --- a/lib/crewai/src/crewai/a2a/extensions/registry.py +++ b/lib/crewai/src/crewai/a2a/extensions/registry.py @@ -1,34 +1,170 @@ -"""Extension registry factory for A2A configurations. +"""A2A Protocol extension utilities. -This module provides utilities for creating extension registries from A2A configurations. +This module provides utilities for working with A2A protocol extensions as +defined in the A2A specification. Extensions are capability declarations in +AgentCard.capabilities.extensions using AgentExtension objects, activated +via the X-A2A-Extensions HTTP header. + +See: https://a2a-protocol.org/latest/topics/extensions/ """ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Any +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, +) +from a2a.types import AgentCard, AgentExtension + +from crewai.a2a.config import A2AClientConfig, A2AConfig from crewai.a2a.extensions.base import ExtensionRegistry -if TYPE_CHECKING: - from crewai.a2a.config import A2AConfig +def get_extensions_from_config( + a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig, +) -> list[str]: + """Extract extension URIs from A2A configuration. + + Args: + a2a_config: A2A configuration (single or list). + + Returns: + Deduplicated list of extension URIs from all configs. + """ + configs = a2a_config if isinstance(a2a_config, list) else [a2a_config] + seen: set[str] = set() + result: list[str] = [] + + for config in configs: + if not isinstance(config, A2AClientConfig): + continue + for uri in config.extensions: + if uri not in seen: + seen.add(uri) + result.append(uri) + + return result + + +class ExtensionsMiddleware(ClientCallInterceptor): + """Middleware to add X-A2A-Extensions header to requests. + + This middleware adds the extensions header to all outgoing requests, + declaring which A2A protocol extensions the client supports. + """ + + def __init__(self, extensions: list[str]) -> None: + """Initialize with extension URIs. + + Args: + extensions: List of extension URIs the client supports. + """ + self._extensions = extensions + + async def intercept( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any], + agent_card: AgentCard | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Add extensions header to the request. + + Args: + method_name: The A2A method being called. + request_payload: The JSON-RPC request payload. + http_kwargs: HTTP request kwargs (headers, etc). + agent_card: The target agent's card. + context: Optional call context. + + Returns: + Tuple of (request_payload, modified_http_kwargs). + """ + if self._extensions: + headers = http_kwargs.setdefault("headers", {}) + headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions) + return request_payload, http_kwargs + + +def validate_required_extensions( + agent_card: AgentCard, + client_extensions: list[str] | None, +) -> list[AgentExtension]: + """Validate that client supports all required extensions from agent. + + Args: + agent_card: The agent's card with declared extensions. + client_extensions: Extension URIs the client supports. + + Returns: + List of unsupported required extensions. + + Raises: + None - returns list of unsupported extensions for caller to handle. + """ + unsupported: list[AgentExtension] = [] + client_set = set(client_extensions or []) + + if not agent_card.capabilities or not agent_card.capabilities.extensions: + return unsupported + + unsupported.extend( + ext + for ext in agent_card.capabilities.extensions + if ext.required and ext.uri not in client_set + ) + + return unsupported def create_extension_registry_from_config( - a2a_config: list[A2AConfig] | A2AConfig, + a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig, ) -> ExtensionRegistry: - """Create an extension registry from A2A configuration. + """Create an extension registry from A2A client configuration. + + Extracts client_extensions from each A2AClientConfig and registers them + with the ExtensionRegistry. These extensions provide CrewAI-specific + processing hooks (tool injection, prompt augmentation, response processing). + + Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header) + are handled separately via get_extensions_from_config() and ExtensionsMiddleware. Args: - a2a_config: A2A configuration (single or list) + a2a_config: A2A configuration (single or list). Returns: - Configured extension registry with all applicable extensions + Extension registry with all client_extensions registered. + + Example: + class LoggingExtension: + def inject_tools(self, agent): pass + def extract_state_from_history(self, history): return None + def augment_prompt(self, prompt, state): return prompt + def process_response(self, response, state): + print(f"Response: {response}") + return response + + config = A2AClientConfig( + endpoint="https://agent.example.com", + client_extensions=[LoggingExtension()], + ) + registry = create_extension_registry_from_config(config) """ registry = ExtensionRegistry() configs = a2a_config if isinstance(a2a_config, list) else [a2a_config] - for _ in configs: - pass + seen: set[int] = set() + + for config in configs: + if isinstance(config, (A2AConfig, A2AClientConfig)): + client_exts = getattr(config, "client_extensions", []) + for extension in client_exts: + ext_id = id(extension) + if ext_id not in seen: + seen.add(ext_id) + registry.register(extension) return registry diff --git a/lib/crewai/src/crewai/a2a/extensions/server.py b/lib/crewai/src/crewai/a2a/extensions/server.py new file mode 100644 index 000000000..9bbc9c08b --- /dev/null +++ b/lib/crewai/src/crewai/a2a/extensions/server.py @@ -0,0 +1,305 @@ +"""A2A protocol server extensions for CrewAI agents. + +This module provides the base class and context for implementing A2A protocol +extensions on the server side. Extensions allow agents to offer additional +functionality beyond the core A2A specification. + +See: https://a2a-protocol.org/latest/topics/extensions/ +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +import logging +from typing import TYPE_CHECKING, Annotated, Any + +from a2a.types import AgentExtension +from pydantic_core import CoreSchema, core_schema + + +if TYPE_CHECKING: + from a2a.server.context import ServerCallContext + from pydantic import GetCoreSchemaHandler + + +logger = logging.getLogger(__name__) + + +@dataclass +class ExtensionContext: + """Context passed to extension hooks during request processing. + + Provides access to request metadata, client extensions, and shared state + that extensions can read from and write to. + + Attributes: + metadata: Request metadata dict, includes extension-namespaced keys. + client_extensions: Set of extension URIs the client declared support for. + state: Mutable dict for extensions to share data during request lifecycle. + server_context: The underlying A2A server call context. + """ + + metadata: dict[str, Any] + client_extensions: set[str] + state: dict[str, Any] = field(default_factory=dict) + server_context: ServerCallContext | None = None + + def get_extension_metadata(self, uri: str, key: str) -> Any | None: + """Get extension-specific metadata value. + + Extension metadata uses namespaced keys in the format: + "{extension_uri}/{key}" + + Args: + uri: The extension URI. + key: The metadata key within the extension namespace. + + Returns: + The metadata value, or None if not present. + """ + full_key = f"{uri}/{key}" + return self.metadata.get(full_key) + + def set_extension_metadata(self, uri: str, key: str, value: Any) -> None: + """Set extension-specific metadata value. + + Args: + uri: The extension URI. + key: The metadata key within the extension namespace. + value: The value to set. + """ + full_key = f"{uri}/{key}" + self.metadata[full_key] = value + + +class ServerExtension(ABC): + """Base class for A2A protocol server extensions. + + Subclass this to create custom extensions that modify agent behavior + when clients activate them. Extensions are identified by URI and can + be marked as required. + + Example: + class SamplingExtension(ServerExtension): + uri = "urn:crewai:ext:sampling/v1" + required = True + + def __init__(self, max_tokens: int = 4096): + self.max_tokens = max_tokens + + @property + def params(self) -> dict[str, Any]: + return {"max_tokens": self.max_tokens} + + async def on_request(self, context: ExtensionContext) -> None: + limit = context.get_extension_metadata(self.uri, "limit") + if limit: + context.state["token_limit"] = int(limit) + + async def on_response(self, context: ExtensionContext, result: Any) -> Any: + return result + """ + + uri: Annotated[str, "Extension URI identifier. Must be unique."] + required: Annotated[bool, "Whether clients must support this extension."] = False + description: Annotated[ + str | None, "Human-readable description of the extension." + ] = None + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> CoreSchema: + """Tell Pydantic how to validate ServerExtension instances.""" + return core_schema.is_instance_schema(cls) + + @property + def params(self) -> dict[str, Any] | None: + """Extension parameters to advertise in AgentCard. + + Override this property to expose configuration that clients can read. + + Returns: + Dict of parameter names to values, or None. + """ + return None + + def agent_extension(self) -> AgentExtension: + """Generate the AgentExtension object for the AgentCard. + + Returns: + AgentExtension with this extension's URI, required flag, and params. + """ + return AgentExtension( + uri=self.uri, + required=self.required if self.required else None, + description=self.description, + params=self.params, + ) + + def is_active(self, context: ExtensionContext) -> bool: + """Check if this extension is active for the current request. + + An extension is active if the client declared support for it. + + Args: + context: The extension context for the current request. + + Returns: + True if the client supports this extension. + """ + return self.uri in context.client_extensions + + @abstractmethod + async def on_request(self, context: ExtensionContext) -> None: + """Called before agent execution if extension is active. + + Use this hook to: + - Read extension-specific metadata from the request + - Set up state for the execution + - Modify execution parameters via context.state + + Args: + context: The extension context with request metadata and state. + """ + ... + + @abstractmethod + async def on_response(self, context: ExtensionContext, result: Any) -> Any: + """Called after agent execution if extension is active. + + Use this hook to: + - Modify or enhance the result + - Add extension-specific metadata to the response + - Clean up any resources + + Args: + context: The extension context with request metadata and state. + result: The agent execution result. + + Returns: + The result, potentially modified. + """ + ... + + +class ServerExtensionRegistry: + """Registry for managing server-side A2A protocol extensions. + + Collects extensions and provides methods to generate AgentCapabilities + and invoke extension hooks during request processing. + """ + + def __init__(self, extensions: list[ServerExtension] | None = None) -> None: + """Initialize the registry with optional extensions. + + Args: + extensions: Initial list of extensions to register. + """ + self._extensions: list[ServerExtension] = list(extensions) if extensions else [] + self._by_uri: dict[str, ServerExtension] = { + ext.uri: ext for ext in self._extensions + } + + def register(self, extension: ServerExtension) -> None: + """Register an extension. + + Args: + extension: The extension to register. + + Raises: + ValueError: If an extension with the same URI is already registered. + """ + if extension.uri in self._by_uri: + raise ValueError(f"Extension already registered: {extension.uri}") + self._extensions.append(extension) + self._by_uri[extension.uri] = extension + + def get_agent_extensions(self) -> list[AgentExtension]: + """Get AgentExtension objects for all registered extensions. + + Returns: + List of AgentExtension objects for the AgentCard. + """ + return [ext.agent_extension() for ext in self._extensions] + + def get_extension(self, uri: str) -> ServerExtension | None: + """Get an extension by URI. + + Args: + uri: The extension URI. + + Returns: + The extension, or None if not found. + """ + return self._by_uri.get(uri) + + @staticmethod + def create_context( + metadata: dict[str, Any], + client_extensions: set[str], + server_context: ServerCallContext | None = None, + ) -> ExtensionContext: + """Create an ExtensionContext for a request. + + Args: + metadata: Request metadata dict. + client_extensions: Set of extension URIs from client. + server_context: Optional server call context. + + Returns: + ExtensionContext for use in hooks. + """ + return ExtensionContext( + metadata=metadata, + client_extensions=client_extensions, + server_context=server_context, + ) + + async def invoke_on_request(self, context: ExtensionContext) -> None: + """Invoke on_request hooks for all active extensions. + + Tracks activated extensions and isolates errors from individual hooks. + + Args: + context: The extension context for the request. + """ + for extension in self._extensions: + if extension.is_active(context): + try: + await extension.on_request(context) + if context.server_context is not None: + context.server_context.activated_extensions.add(extension.uri) + except Exception: + logger.exception( + "Extension on_request hook failed", + extra={"extension": extension.uri}, + ) + + async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any: + """Invoke on_response hooks for all active extensions. + + Isolates errors from individual hooks to prevent one failing extension + from breaking the entire response. + + Args: + context: The extension context for the request. + result: The agent execution result. + + Returns: + The result after all extensions have processed it. + """ + processed = result + for extension in self._extensions: + if extension.is_active(context): + try: + processed = await extension.on_response(context, processed) + except Exception: + logger.exception( + "Extension on_response hook failed", + extra={"extension": extension.uri}, + ) + return processed