feat: add server-side auth schemes and protocol extensions

- add server auth scheme base class and implementations (api key, bearer token, basic/digest auth, mtls)
- add server-side extension system for a2a protocol extensions
- add extensions middleware for x-a2a-extensions header management
- add extension validation and registry utilities
- enhance auth utilities with server-side support
- add async intercept method to match client call interceptor protocol
- fix type_checking import to resolve mypy errors with a2aconfig
This commit is contained in:
Greyson LaLonde
2026-01-29 04:52:36 -05:00
parent e291a97bdd
commit e4be1329a0
7 changed files with 1339 additions and 28 deletions

View File

@@ -1,20 +1,36 @@
"""A2A authentication schemas."""
from crewai.a2a.auth.schemas import (
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
TLSConfig,
)
from crewai.a2a.auth.server_schemes import (
AuthenticatedUser,
OIDCAuth,
ServerAuthScheme,
SimpleTokenAuth,
)
__all__ = [
"APIKeyAuth",
"AuthScheme",
"AuthenticatedUser",
"BearerTokenAuth",
"ClientAuthScheme",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
"OIDCAuth",
"ServerAuthScheme",
"SimpleTokenAuth",
"TLSConfig",
]

View File

@@ -0,0 +1,739 @@
"""Server-side authentication schemes for A2A protocol.
These schemes validate incoming requests to A2A server endpoints.
Supported authentication methods:
- Simple token validation with static bearer tokens
- OpenID Connect with JWT validation using JWKS
- OAuth2 with JWT validation or token introspection
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
import os
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
import jwt
from jwt import PyJWKClient
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
HttpUrl,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import Self
if TYPE_CHECKING:
from a2a.types import OAuth2SecurityScheme
logger = logging.getLogger(__name__)
try:
from fastapi import HTTPException, status as http_status
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
except ImportError:
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
"""Fallback HTTPException when FastAPI is not installed."""
def __init__(
self,
status_code: int,
detail: str | None = None,
headers: dict[str, str] | None = None,
) -> None:
self.status_code = status_code
self.detail = detail
self.headers = headers
super().__init__(detail)
HTTP_401_UNAUTHORIZED = 401
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_503_SERVICE_UNAVAILABLE = 503
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
"""Coerce string to SecretStr."""
if v is None or isinstance(v, SecretStr):
return v
return SecretStr(v)
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
JWTAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
@dataclass
class AuthenticatedUser:
"""Result of successful authentication.
Attributes:
token: The original token that was validated.
scheme: Name of the authentication scheme used.
claims: JWT claims from OIDC or OAuth2 authentication.
"""
token: str
scheme: str
claims: dict[str, Any] | None = None
class ServerAuthScheme(ABC, BaseModel):
"""Base class for server-side authentication schemes.
Each scheme validates incoming requests and returns an AuthenticatedUser
on success, or raises HTTPException on failure.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@abstractmethod
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate the provided token.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
...
class SimpleTokenAuth(ServerAuthScheme):
"""Simple bearer token authentication.
Validates tokens against a configured static token or AUTH_TOKEN env var.
Attributes:
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
"""
token: CoercedSecretStr | None = Field(
default=None,
description="Expected token. Falls back to AUTH_TOKEN env var.",
)
def _get_expected_token(self) -> str | None:
"""Get the expected token value."""
if self.token:
return self.token.get_secret_value()
return os.environ.get("AUTH_TOKEN")
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using simple token comparison.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
expected = self._get_expected_token()
if expected is None:
logger.warning(
"Simple token authentication failed",
extra={"reason": "no_token_configured"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Authentication not configured",
)
if token != expected:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
)
return AuthenticatedUser(
token=token,
scheme="simple_token",
)
class OIDCAuth(ServerAuthScheme):
"""OpenID Connect authentication.
Validates JWTs using JWKS with caching support via PyJWT.
Attributes:
issuer: The OpenID Connect issuer URL.
audience: The expected audience claim.
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
algorithms: List of allowed signing algorithms.
required_claims: List of claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
issuer: HttpUrl = Field(
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
)
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
jwks_url: HttpUrl | None = Field(
default=None,
description="Explicit JWKS URL. Derived from issuer if not set.",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="List of allowed signing algorithms (RS256, ES256, etc.)",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
description="List of claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _init_jwk_client(self) -> Self:
"""Initialize the JWK client after model creation."""
jwks_url = (
str(self.jwks_url)
if self.jwks_url
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
)
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OIDC JWT validation.
Args:
token: The JWT to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OIDC not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=str(self.issuer).rstrip("/"),
leeway=self.clock_skew_seconds,
options={
"require": self.required_claims,
},
)
return AuthenticatedUser(
token=token,
scheme="oidc",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "token_expired", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_audience", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OIDC authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oidc",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
class OAuth2ServerAuth(ServerAuthScheme):
"""OAuth2 authentication for A2A server.
Declares OAuth2 security scheme in AgentCard and validates tokens using
either JWKS for JWT tokens or token introspection for opaque tokens.
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
Attributes:
token_url: OAuth2 token endpoint URL for client_credentials flow.
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
refresh_url: Optional refresh token endpoint URL.
scopes: Available OAuth2 scopes with descriptions.
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
introspection_client_id: Client ID for introspection endpoint authentication.
introspection_client_secret: Client secret for introspection endpoint.
audience: Expected audience claim for JWT validation.
issuer: Expected issuer claim for JWT validation.
algorithms: Allowed JWT signing algorithms.
required_claims: Claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
token_url: HttpUrl = Field(
description="OAuth2 token endpoint URL",
)
authorization_url: HttpUrl | None = Field(
default=None,
description="OAuth2 authorization endpoint URL for authorization_code flow",
)
refresh_url: HttpUrl | None = Field(
default=None,
description="OAuth2 refresh token endpoint URL",
)
scopes: dict[str, str] = Field(
default_factory=dict,
description="Available OAuth2 scopes with descriptions",
)
jwks_url: HttpUrl | None = Field(
default=None,
description="JWKS URL for JWT validation. Required if not using introspection.",
)
introspection_url: HttpUrl | None = Field(
default=None,
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
)
introspection_client_id: str | None = Field(
default=None,
description="Client ID for introspection endpoint authentication",
)
introspection_client_secret: CoercedSecretStr | None = Field(
default=None,
description="Client secret for introspection endpoint authentication",
)
audience: str | None = Field(
default=None,
description="Expected audience claim for JWT validation",
)
issuer: str | None = Field(
default=None,
description="Expected issuer claim for JWT validation",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="Allowed JWT signing algorithms",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat"],
description="Claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _validate_and_init(self) -> Self:
"""Validate configuration and initialize JWKS client if needed."""
if not self.jwks_url and not self.introspection_url:
raise ValueError(
"Either jwks_url or introspection_url must be provided for token validation"
)
if self.introspection_url:
if not self.introspection_client_id or not self.introspection_client_secret:
raise ValueError(
"introspection_client_id and introspection_client_secret are required "
"when using token introspection"
)
if self.jwks_url:
self._jwk_client = PyJWKClient(
str(self.jwks_url), lifespan=self.jwks_cache_ttl
)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token validation.
Uses JWKS validation if jwks_url is configured, otherwise falls back
to token introspection.
Args:
token: The OAuth2 access token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client:
return await self._authenticate_jwt(token)
return await self._authenticate_introspection(token)
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
"""Authenticate using JWKS JWT validation."""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 JWKS not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
decode_options: dict[str, Any] = {
"require": self.required_claims,
}
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=self.issuer,
leeway=self.clock_skew_seconds,
options=decode_options,
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_expired", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_audience", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OAuth2 authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oauth2",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
import httpx
if not self.introspection_url:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 introspection not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.post(
str(self.introspection_url),
data={"token": token},
auth=(
self.introspection_client_id or "",
self.introspection_client_secret.get_secret_value()
if self.introspection_client_secret
else "",
),
)
response.raise_for_status()
introspection_result = response.json()
except httpx.HTTPStatusError as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "http_error", "status_code": e.response.status_code},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection service unavailable",
) from None
except Exception as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "unexpected_error", "error": str(e)},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection failed",
) from None
if not introspection_result.get("active", False):
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_not_active", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token is not active",
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=introspection_result,
)
def to_security_scheme(self) -> OAuth2SecurityScheme:
"""Generate OAuth2SecurityScheme for AgentCard declaration.
Creates an OAuth2SecurityScheme with appropriate flows based on
the configured URLs. Includes client_credentials flow if token_url
is set, and authorization_code flow if authorization_url is set.
Returns:
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
"""
from a2a.types import (
AuthorizationCodeOAuthFlow,
ClientCredentialsOAuthFlow,
OAuth2SecurityScheme,
OAuthFlows,
)
client_credentials = None
authorization_code = None
if self.token_url:
client_credentials = ClientCredentialsOAuthFlow(
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
if self.authorization_url:
authorization_code = AuthorizationCodeOAuthFlow(
authorization_url=str(self.authorization_url),
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
return OAuth2SecurityScheme(
flows=OAuthFlows(
client_credentials=client_credentials,
authorization_code=authorization_code,
),
description="OAuth2 authentication",
)
class APIKeyServerAuth(ServerAuthScheme):
"""API Key authentication for A2A server.
Validates requests using an API key in a header, query parameter, or cookie.
Attributes:
name: The name of the API key parameter (default: X-API-Key).
location: Where to look for the API key (header, query, or cookie).
api_key: The expected API key value.
"""
name: str = Field(
default="X-API-Key",
description="Name of the API key parameter",
)
location: Literal["header", "query", "cookie"] = Field(
default="header",
description="Where to look for the API key",
)
api_key: CoercedSecretStr = Field(
description="Expected API key value",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using API key comparison.
Args:
token: The API key to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if token != self.api_key.get_secret_value():
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return AuthenticatedUser(
token=token,
scheme="api_key",
)
class MTLSServerAuth(ServerAuthScheme):
"""Mutual TLS authentication marker for AgentCard declaration.
This scheme is primarily for AgentCard security_schemes declaration.
Actual mTLS verification happens at the TLS/transport layer, not
at the application layer via token validation.
When configured, this signals to clients that the server requires
client certificates for authentication.
"""
description: str = Field(
default="Mutual TLS certificate authentication",
description="Description for the security scheme",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Return authenticated user for mTLS.
mTLS verification happens at the transport layer before this is called.
If we reach this point, the TLS handshake with client cert succeeded.
Args:
token: Certificate subject or identifier (from TLS layer).
Returns:
AuthenticatedUser indicating mTLS authentication.
"""
return AuthenticatedUser(
token=token or "mtls-verified",
scheme="mtls",
)

View File

@@ -6,8 +6,10 @@ OAuth2, API keys, and HTTP authentication methods.
import asyncio
from collections.abc import Awaitable, Callable, MutableMapping
import hashlib
import re
from typing import Final
import threading
from typing import Final, Literal, cast
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
@@ -18,10 +20,10 @@ from a2a.types import (
)
from httpx import AsyncClient, Response
from crewai.a2a.auth.schemas import (
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
@@ -29,12 +31,44 @@ from crewai.a2a.auth.schemas import (
)
_auth_store: dict[int, AuthScheme | None] = {}
class _AuthStore:
"""Store for authentication schemes with safe concurrent access."""
def __init__(self) -> None:
self._store: dict[str, ClientAuthScheme | None] = {}
self._lock = threading.RLock()
@staticmethod
def compute_key(auth_type: str, auth_data: str) -> str:
"""Compute a collision-resistant key using SHA-256."""
content = f"{auth_type}:{auth_data}"
return hashlib.sha256(content.encode()).hexdigest()
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
"""Store an auth scheme."""
with self._lock:
self._store[key] = auth
def get(self, key: str) -> ClientAuthScheme | None:
"""Retrieve an auth scheme by key."""
with self._lock:
return self._store.get(key)
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
with self._lock:
self._store[key] = value
def __getitem__(self, key: str) -> ClientAuthScheme | None:
with self._lock:
return self._store[key]
_auth_store = _AuthStore()
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
OAuth2SecurityScheme: (
OAuth2ClientCredentials,
OAuth2AuthorizationCode,
@@ -43,7 +77,9 @@ _SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
APIKeySecurityScheme: (APIKeyAuth,),
}
_HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
"basic": HTTPBasicAuth,
"digest": HTTPDigestAuth,
"bearer": BearerTokenAuth,
@@ -51,8 +87,8 @@ _HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
def _raise_auth_mismatch(
expected_classes: type[AuthScheme] | tuple[type[AuthScheme], ...],
provided_auth: AuthScheme,
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
provided_auth: ClientAuthScheme,
) -> None:
"""Raise authentication mismatch error.
@@ -111,7 +147,7 @@ def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
def validate_auth_against_agent_card(
agent_card: AgentCard, auth: AuthScheme | None
agent_card: AgentCard, auth: ClientAuthScheme | None
) -> None:
"""Validate that provided auth matches AgentCard security requirements.
@@ -145,7 +181,8 @@ def validate_auth_against_agent_card(
return
if isinstance(scheme, HTTPAuthSecurityScheme):
if required_class := _HTTP_SCHEME_MAPPING.get(scheme.scheme.lower()):
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
if not isinstance(auth, required_class):
_raise_auth_mismatch(required_class, auth)
return
@@ -156,7 +193,7 @@ def validate_auth_against_agent_card(
async def retry_on_401(
request_func: Callable[[], Awaitable[Response]],
auth_scheme: AuthScheme | None,
auth_scheme: ClientAuthScheme | None,
client: AsyncClient,
headers: MutableMapping[str, str],
max_retries: int = 3,

View File

@@ -1,4 +1,37 @@
"""A2A Protocol Extensions for CrewAI.
This module contains extensions to the A2A (Agent-to-Agent) protocol.
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
processes requests and responses during delegation to remote agents. These provide
hooks for tool injection, prompt augmentation, and response processing.
**Server-side extensions** (ServerExtension) allow agents to offer additional
functionality beyond the core A2A specification. Clients activate extensions
via the X-A2A-Extensions header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from crewai.a2a.extensions.base import (
A2AExtension,
ConversationState,
ExtensionRegistry,
ValidatedA2AExtension,
)
from crewai.a2a.extensions.server import (
ExtensionContext,
ServerExtension,
ServerExtensionRegistry,
)
__all__ = [
"A2AExtension",
"ConversationState",
"ExtensionContext",
"ExtensionRegistry",
"ServerExtension",
"ServerExtensionRegistry",
"ValidatedA2AExtension",
]

View File

@@ -1,14 +1,20 @@
"""Base extension interface for A2A wrapper integrations.
"""Base extension interface for CrewAI A2A wrapper processing hooks.
This module defines the protocol for extending A2A wrapper functionality
with custom logic for conversation processing, prompt augmentation, and
agent response handling.
This module defines the protocol for extending CrewAI's A2A wrapper functionality
with custom logic for tool injection, prompt augmentation, and response processing.
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
A2A protocol extensions are capability declarations using AgentExtension objects
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
from pydantic import BeforeValidator
if TYPE_CHECKING:
@@ -17,6 +23,20 @@ if TYPE_CHECKING:
from crewai.agent.core import Agent
def _validate_a2a_extension(v: Any) -> Any:
"""Validate that value implements A2AExtension protocol."""
if not isinstance(v, A2AExtension):
raise ValueError(
f"Value must implement A2AExtension protocol. "
f"Got {type(v).__name__} which is missing required methods."
)
return v
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
@runtime_checkable
class ConversationState(Protocol):
"""Protocol for extension-specific conversation state.
@@ -33,11 +53,36 @@ class ConversationState(Protocol):
...
@runtime_checkable
class A2AExtension(Protocol):
"""Protocol for A2A wrapper extensions.
Extensions can implement this protocol to inject custom logic into
the A2A conversation flow at various integration points.
Example:
class MyExtension:
def inject_tools(self, agent: Agent) -> None:
# Add custom tools to the agent
pass
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
# Extract state from conversation
return None
def augment_prompt(
self, base_prompt: str, conversation_state: ConversationState | None
) -> str:
# Add custom instructions
return base_prompt
def process_response(
self, agent_response: Any, conversation_state: ConversationState | None
) -> Any:
# Modify response if needed
return agent_response
"""
def inject_tools(self, agent: Agent) -> None:

View File

@@ -1,34 +1,170 @@
"""Extension registry factory for A2A configurations.
"""A2A Protocol extension utilities.
This module provides utilities for creating extension registries from A2A configurations.
This module provides utilities for working with A2A protocol extensions as
defined in the A2A specification. Extensions are capability declarations in
AgentCard.capabilities.extensions using AgentExtension objects, activated
via the X-A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
)
from a2a.types import AgentCard, AgentExtension
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
def get_extensions_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> list[str]:
"""Extract extension URIs from A2A configuration.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Deduplicated list of extension URIs from all configs.
"""
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[str] = set()
result: list[str] = []
for config in configs:
if not isinstance(config, A2AClientConfig):
continue
for uri in config.extensions:
if uri not in seen:
seen.add(uri)
result.append(uri)
return result
class ExtensionsMiddleware(ClientCallInterceptor):
"""Middleware to add X-A2A-Extensions header to requests.
This middleware adds the extensions header to all outgoing requests,
declaring which A2A protocol extensions the client supports.
"""
def __init__(self, extensions: list[str]) -> None:
"""Initialize with extension URIs.
Args:
extensions: List of extension URIs the client supports.
"""
self._extensions = extensions
async def intercept(
self,
method_name: str,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any],
agent_card: AgentCard | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Add extensions header to the request.
Args:
method_name: The A2A method being called.
request_payload: The JSON-RPC request payload.
http_kwargs: HTTP request kwargs (headers, etc).
agent_card: The target agent's card.
context: Optional call context.
Returns:
Tuple of (request_payload, modified_http_kwargs).
"""
if self._extensions:
headers = http_kwargs.setdefault("headers", {})
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
return request_payload, http_kwargs
def validate_required_extensions(
agent_card: AgentCard,
client_extensions: list[str] | None,
) -> list[AgentExtension]:
"""Validate that client supports all required extensions from agent.
Args:
agent_card: The agent's card with declared extensions.
client_extensions: Extension URIs the client supports.
Returns:
List of unsupported required extensions.
Raises:
None - returns list of unsupported extensions for caller to handle.
"""
unsupported: list[AgentExtension] = []
client_set = set(client_extensions or [])
if not agent_card.capabilities or not agent_card.capabilities.extensions:
return unsupported
unsupported.extend(
ext
for ext in agent_card.capabilities.extensions
if ext.required and ext.uri not in client_set
)
return unsupported
def create_extension_registry_from_config(
a2a_config: list[A2AConfig] | A2AConfig,
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> ExtensionRegistry:
"""Create an extension registry from A2A configuration.
"""Create an extension registry from A2A client configuration.
Extracts client_extensions from each A2AClientConfig and registers them
with the ExtensionRegistry. These extensions provide CrewAI-specific
processing hooks (tool injection, prompt augmentation, response processing).
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
Args:
a2a_config: A2A configuration (single or list)
a2a_config: A2A configuration (single or list).
Returns:
Configured extension registry with all applicable extensions
Extension registry with all client_extensions registered.
Example:
class LoggingExtension:
def inject_tools(self, agent): pass
def extract_state_from_history(self, history): return None
def augment_prompt(self, prompt, state): return prompt
def process_response(self, response, state):
print(f"Response: {response}")
return response
config = A2AClientConfig(
endpoint="https://agent.example.com",
client_extensions=[LoggingExtension()],
)
registry = create_extension_registry_from_config(config)
"""
registry = ExtensionRegistry()
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
for _ in configs:
pass
seen: set[int] = set()
for config in configs:
if isinstance(config, (A2AConfig, A2AClientConfig)):
client_exts = getattr(config, "client_extensions", [])
for extension in client_exts:
ext_id = id(extension)
if ext_id not in seen:
seen.add(ext_id)
registry.register(extension)
return registry

View File

@@ -0,0 +1,305 @@
"""A2A protocol server extensions for CrewAI agents.
This module provides the base class and context for implementing A2A protocol
extensions on the server side. Extensions allow agents to offer additional
functionality beyond the core A2A specification.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
from typing import TYPE_CHECKING, Annotated, Any
from a2a.types import AgentExtension
from pydantic_core import CoreSchema, core_schema
if TYPE_CHECKING:
from a2a.server.context import ServerCallContext
from pydantic import GetCoreSchemaHandler
logger = logging.getLogger(__name__)
@dataclass
class ExtensionContext:
"""Context passed to extension hooks during request processing.
Provides access to request metadata, client extensions, and shared state
that extensions can read from and write to.
Attributes:
metadata: Request metadata dict, includes extension-namespaced keys.
client_extensions: Set of extension URIs the client declared support for.
state: Mutable dict for extensions to share data during request lifecycle.
server_context: The underlying A2A server call context.
"""
metadata: dict[str, Any]
client_extensions: set[str]
state: dict[str, Any] = field(default_factory=dict)
server_context: ServerCallContext | None = None
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
"""Get extension-specific metadata value.
Extension metadata uses namespaced keys in the format:
"{extension_uri}/{key}"
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
Returns:
The metadata value, or None if not present.
"""
full_key = f"{uri}/{key}"
return self.metadata.get(full_key)
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
"""Set extension-specific metadata value.
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
value: The value to set.
"""
full_key = f"{uri}/{key}"
self.metadata[full_key] = value
class ServerExtension(ABC):
"""Base class for A2A protocol server extensions.
Subclass this to create custom extensions that modify agent behavior
when clients activate them. Extensions are identified by URI and can
be marked as required.
Example:
class SamplingExtension(ServerExtension):
uri = "urn:crewai:ext:sampling/v1"
required = True
def __init__(self, max_tokens: int = 4096):
self.max_tokens = max_tokens
@property
def params(self) -> dict[str, Any]:
return {"max_tokens": self.max_tokens}
async def on_request(self, context: ExtensionContext) -> None:
limit = context.get_extension_metadata(self.uri, "limit")
if limit:
context.state["token_limit"] = int(limit)
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
return result
"""
uri: Annotated[str, "Extension URI identifier. Must be unique."]
required: Annotated[bool, "Whether clients must support this extension."] = False
description: Annotated[
str | None, "Human-readable description of the extension."
] = None
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
"""Tell Pydantic how to validate ServerExtension instances."""
return core_schema.is_instance_schema(cls)
@property
def params(self) -> dict[str, Any] | None:
"""Extension parameters to advertise in AgentCard.
Override this property to expose configuration that clients can read.
Returns:
Dict of parameter names to values, or None.
"""
return None
def agent_extension(self) -> AgentExtension:
"""Generate the AgentExtension object for the AgentCard.
Returns:
AgentExtension with this extension's URI, required flag, and params.
"""
return AgentExtension(
uri=self.uri,
required=self.required if self.required else None,
description=self.description,
params=self.params,
)
def is_active(self, context: ExtensionContext) -> bool:
"""Check if this extension is active for the current request.
An extension is active if the client declared support for it.
Args:
context: The extension context for the current request.
Returns:
True if the client supports this extension.
"""
return self.uri in context.client_extensions
@abstractmethod
async def on_request(self, context: ExtensionContext) -> None:
"""Called before agent execution if extension is active.
Use this hook to:
- Read extension-specific metadata from the request
- Set up state for the execution
- Modify execution parameters via context.state
Args:
context: The extension context with request metadata and state.
"""
...
@abstractmethod
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Called after agent execution if extension is active.
Use this hook to:
- Modify or enhance the result
- Add extension-specific metadata to the response
- Clean up any resources
Args:
context: The extension context with request metadata and state.
result: The agent execution result.
Returns:
The result, potentially modified.
"""
...
class ServerExtensionRegistry:
"""Registry for managing server-side A2A protocol extensions.
Collects extensions and provides methods to generate AgentCapabilities
and invoke extension hooks during request processing.
"""
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
"""Initialize the registry with optional extensions.
Args:
extensions: Initial list of extensions to register.
"""
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
self._by_uri: dict[str, ServerExtension] = {
ext.uri: ext for ext in self._extensions
}
def register(self, extension: ServerExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
Raises:
ValueError: If an extension with the same URI is already registered.
"""
if extension.uri in self._by_uri:
raise ValueError(f"Extension already registered: {extension.uri}")
self._extensions.append(extension)
self._by_uri[extension.uri] = extension
def get_agent_extensions(self) -> list[AgentExtension]:
"""Get AgentExtension objects for all registered extensions.
Returns:
List of AgentExtension objects for the AgentCard.
"""
return [ext.agent_extension() for ext in self._extensions]
def get_extension(self, uri: str) -> ServerExtension | None:
"""Get an extension by URI.
Args:
uri: The extension URI.
Returns:
The extension, or None if not found.
"""
return self._by_uri.get(uri)
@staticmethod
def create_context(
metadata: dict[str, Any],
client_extensions: set[str],
server_context: ServerCallContext | None = None,
) -> ExtensionContext:
"""Create an ExtensionContext for a request.
Args:
metadata: Request metadata dict.
client_extensions: Set of extension URIs from client.
server_context: Optional server call context.
Returns:
ExtensionContext for use in hooks.
"""
return ExtensionContext(
metadata=metadata,
client_extensions=client_extensions,
server_context=server_context,
)
async def invoke_on_request(self, context: ExtensionContext) -> None:
"""Invoke on_request hooks for all active extensions.
Tracks activated extensions and isolates errors from individual hooks.
Args:
context: The extension context for the request.
"""
for extension in self._extensions:
if extension.is_active(context):
try:
await extension.on_request(context)
if context.server_context is not None:
context.server_context.activated_extensions.add(extension.uri)
except Exception:
logger.exception(
"Extension on_request hook failed",
extra={"extension": extension.uri},
)
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Invoke on_response hooks for all active extensions.
Isolates errors from individual hooks to prevent one failing extension
from breaking the entire response.
Args:
context: The extension context for the request.
result: The agent execution result.
Returns:
The result after all extensions have processed it.
"""
processed = result
for extension in self._extensions:
if extension.is_active(context):
try:
processed = await extension.on_response(context, processed)
except Exception:
logger.exception(
"Extension on_response hook failed",
extra={"extension": extension.uri},
)
return processed