mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-15 23:42:37 +00:00
feat: add server-side auth schemes and protocol extensions
- add server auth scheme base class and implementations (api key, bearer token, basic/digest auth, mtls) - add server-side extension system for a2a protocol extensions - add extensions middleware for x-a2a-extensions header management - add extension validation and registry utilities - enhance auth utilities with server-side support - add async intercept method to match client call interceptor protocol - fix type_checking import to resolve mypy errors with a2aconfig
This commit is contained in:
@@ -1,20 +1,36 @@
|
||||
"""A2A authentication schemas."""
|
||||
|
||||
from crewai.a2a.auth.schemas import (
|
||||
from crewai.a2a.auth.client_schemes import (
|
||||
APIKeyAuth,
|
||||
AuthScheme,
|
||||
BearerTokenAuth,
|
||||
ClientAuthScheme,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
TLSConfig,
|
||||
)
|
||||
from crewai.a2a.auth.server_schemes import (
|
||||
AuthenticatedUser,
|
||||
OIDCAuth,
|
||||
ServerAuthScheme,
|
||||
SimpleTokenAuth,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"APIKeyAuth",
|
||||
"AuthScheme",
|
||||
"AuthenticatedUser",
|
||||
"BearerTokenAuth",
|
||||
"ClientAuthScheme",
|
||||
"HTTPBasicAuth",
|
||||
"HTTPDigestAuth",
|
||||
"OAuth2AuthorizationCode",
|
||||
"OAuth2ClientCredentials",
|
||||
"OIDCAuth",
|
||||
"ServerAuthScheme",
|
||||
"SimpleTokenAuth",
|
||||
"TLSConfig",
|
||||
]
|
||||
|
||||
739
lib/crewai/src/crewai/a2a/auth/server_schemes.py
Normal file
739
lib/crewai/src/crewai/a2a/auth/server_schemes.py
Normal file
@@ -0,0 +1,739 @@
|
||||
"""Server-side authentication schemes for A2A protocol.
|
||||
|
||||
These schemes validate incoming requests to A2A server endpoints.
|
||||
|
||||
Supported authentication methods:
|
||||
- Simple token validation with static bearer tokens
|
||||
- OpenID Connect with JWT validation using JWKS
|
||||
- OAuth2 with JWT validation or token introspection
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
|
||||
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
Field,
|
||||
HttpUrl,
|
||||
PrivateAttr,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import OAuth2SecurityScheme
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from fastapi import HTTPException, status as http_status
|
||||
|
||||
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
except ImportError:
|
||||
|
||||
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
|
||||
"""Fallback HTTPException when FastAPI is not installed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
detail: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
self.headers = headers
|
||||
super().__init__(detail)
|
||||
|
||||
HTTP_401_UNAUTHORIZED = 401
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = 500
|
||||
HTTP_503_SERVICE_UNAVAILABLE = 503
|
||||
|
||||
|
||||
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
|
||||
"""Coerce string to SecretStr."""
|
||||
if v is None or isinstance(v, SecretStr):
|
||||
return v
|
||||
return SecretStr(v)
|
||||
|
||||
|
||||
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
|
||||
|
||||
JWTAlgorithm = Literal[
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthenticatedUser:
|
||||
"""Result of successful authentication.
|
||||
|
||||
Attributes:
|
||||
token: The original token that was validated.
|
||||
scheme: Name of the authentication scheme used.
|
||||
claims: JWT claims from OIDC or OAuth2 authentication.
|
||||
"""
|
||||
|
||||
token: str
|
||||
scheme: str
|
||||
claims: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ServerAuthScheme(ABC, BaseModel):
|
||||
"""Base class for server-side authentication schemes.
|
||||
|
||||
Each scheme validates incoming requests and returns an AuthenticatedUser
|
||||
on success, or raises HTTPException on failure.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate the provided token.
|
||||
|
||||
Args:
|
||||
token: The bearer token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SimpleTokenAuth(ServerAuthScheme):
|
||||
"""Simple bearer token authentication.
|
||||
|
||||
Validates tokens against a configured static token or AUTH_TOKEN env var.
|
||||
|
||||
Attributes:
|
||||
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
|
||||
"""
|
||||
|
||||
token: CoercedSecretStr | None = Field(
|
||||
default=None,
|
||||
description="Expected token. Falls back to AUTH_TOKEN env var.",
|
||||
)
|
||||
|
||||
def _get_expected_token(self) -> str | None:
|
||||
"""Get the expected token value."""
|
||||
if self.token:
|
||||
return self.token.get_secret_value()
|
||||
return os.environ.get("AUTH_TOKEN")
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using simple token comparison.
|
||||
|
||||
Args:
|
||||
token: The bearer token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
expected = self._get_expected_token()
|
||||
|
||||
if expected is None:
|
||||
logger.warning(
|
||||
"Simple token authentication failed",
|
||||
extra={"reason": "no_token_configured"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication not configured",
|
||||
)
|
||||
|
||||
if token != expected:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="simple_token",
|
||||
)
|
||||
|
||||
|
||||
class OIDCAuth(ServerAuthScheme):
|
||||
"""OpenID Connect authentication.
|
||||
|
||||
Validates JWTs using JWKS with caching support via PyJWT.
|
||||
|
||||
Attributes:
|
||||
issuer: The OpenID Connect issuer URL.
|
||||
audience: The expected audience claim.
|
||||
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
|
||||
algorithms: List of allowed signing algorithms.
|
||||
required_claims: List of claims that must be present in the token.
|
||||
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||
clock_skew_seconds: Allowed clock skew for token validation.
|
||||
"""
|
||||
|
||||
issuer: HttpUrl = Field(
|
||||
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
|
||||
)
|
||||
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
|
||||
jwks_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="Explicit JWKS URL. Derived from issuer if not set.",
|
||||
)
|
||||
algorithms: list[str] = Field(
|
||||
default_factory=lambda: ["RS256"],
|
||||
description="List of allowed signing algorithms (RS256, ES256, etc.)",
|
||||
)
|
||||
required_claims: list[str] = Field(
|
||||
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
|
||||
description="List of claims that must be present in the token",
|
||||
)
|
||||
jwks_cache_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL for JWKS cache in seconds",
|
||||
ge=60,
|
||||
)
|
||||
clock_skew_seconds: float = Field(
|
||||
default=30.0,
|
||||
description="Allowed clock skew for token validation",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_jwk_client(self) -> Self:
|
||||
"""Initialize the JWK client after model creation."""
|
||||
jwks_url = (
|
||||
str(self.jwks_url)
|
||||
if self.jwks_url
|
||||
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
|
||||
)
|
||||
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
|
||||
return self
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OIDC JWT validation.
|
||||
|
||||
Args:
|
||||
token: The JWT to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if self._jwk_client is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OIDC not initialized",
|
||||
)
|
||||
|
||||
try:
|
||||
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=self.algorithms,
|
||||
audience=self.audience,
|
||||
issuer=str(self.issuer).rstrip("/"),
|
||||
leeway=self.clock_skew_seconds,
|
||||
options={
|
||||
"require": self.required_claims,
|
||||
},
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oidc",
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "token_expired", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
) from None
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_audience", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token audience",
|
||||
) from None
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_issuer", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token issuer",
|
||||
) from None
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Missing required claim: {e.claim}",
|
||||
) from None
|
||||
except jwt.PyJWKClientError as e:
|
||||
logger.error(
|
||||
"OIDC authentication failed",
|
||||
extra={
|
||||
"reason": "jwks_client_error",
|
||||
"error": str(e),
|
||||
"scheme": "oidc",
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Unable to fetch signing keys",
|
||||
) from None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
) from None
|
||||
|
||||
|
||||
class OAuth2ServerAuth(ServerAuthScheme):
|
||||
"""OAuth2 authentication for A2A server.
|
||||
|
||||
Declares OAuth2 security scheme in AgentCard and validates tokens using
|
||||
either JWKS for JWT tokens or token introspection for opaque tokens.
|
||||
|
||||
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
|
||||
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
|
||||
|
||||
Attributes:
|
||||
token_url: OAuth2 token endpoint URL for client_credentials flow.
|
||||
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
|
||||
refresh_url: Optional refresh token endpoint URL.
|
||||
scopes: Available OAuth2 scopes with descriptions.
|
||||
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
|
||||
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
|
||||
introspection_client_id: Client ID for introspection endpoint authentication.
|
||||
introspection_client_secret: Client secret for introspection endpoint.
|
||||
audience: Expected audience claim for JWT validation.
|
||||
issuer: Expected issuer claim for JWT validation.
|
||||
algorithms: Allowed JWT signing algorithms.
|
||||
required_claims: Claims that must be present in the token.
|
||||
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||
clock_skew_seconds: Allowed clock skew for token validation.
|
||||
"""
|
||||
|
||||
token_url: HttpUrl = Field(
|
||||
description="OAuth2 token endpoint URL",
|
||||
)
|
||||
authorization_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="OAuth2 authorization endpoint URL for authorization_code flow",
|
||||
)
|
||||
refresh_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="OAuth2 refresh token endpoint URL",
|
||||
)
|
||||
scopes: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Available OAuth2 scopes with descriptions",
|
||||
)
|
||||
jwks_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="JWKS URL for JWT validation. Required if not using introspection.",
|
||||
)
|
||||
introspection_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
|
||||
)
|
||||
introspection_client_id: str | None = Field(
|
||||
default=None,
|
||||
description="Client ID for introspection endpoint authentication",
|
||||
)
|
||||
introspection_client_secret: CoercedSecretStr | None = Field(
|
||||
default=None,
|
||||
description="Client secret for introspection endpoint authentication",
|
||||
)
|
||||
audience: str | None = Field(
|
||||
default=None,
|
||||
description="Expected audience claim for JWT validation",
|
||||
)
|
||||
issuer: str | None = Field(
|
||||
default=None,
|
||||
description="Expected issuer claim for JWT validation",
|
||||
)
|
||||
algorithms: list[str] = Field(
|
||||
default_factory=lambda: ["RS256"],
|
||||
description="Allowed JWT signing algorithms",
|
||||
)
|
||||
required_claims: list[str] = Field(
|
||||
default_factory=lambda: ["exp", "iat"],
|
||||
description="Claims that must be present in the token",
|
||||
)
|
||||
jwks_cache_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL for JWKS cache in seconds",
|
||||
ge=60,
|
||||
)
|
||||
clock_skew_seconds: float = Field(
|
||||
default=30.0,
|
||||
description="Allowed clock skew for token validation",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_and_init(self) -> Self:
|
||||
"""Validate configuration and initialize JWKS client if needed."""
|
||||
if not self.jwks_url and not self.introspection_url:
|
||||
raise ValueError(
|
||||
"Either jwks_url or introspection_url must be provided for token validation"
|
||||
)
|
||||
|
||||
if self.introspection_url:
|
||||
if not self.introspection_client_id or not self.introspection_client_secret:
|
||||
raise ValueError(
|
||||
"introspection_client_id and introspection_client_secret are required "
|
||||
"when using token introspection"
|
||||
)
|
||||
|
||||
if self.jwks_url:
|
||||
self._jwk_client = PyJWKClient(
|
||||
str(self.jwks_url), lifespan=self.jwks_cache_ttl
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OAuth2 token validation.
|
||||
|
||||
Uses JWKS validation if jwks_url is configured, otherwise falls back
|
||||
to token introspection.
|
||||
|
||||
Args:
|
||||
token: The OAuth2 access token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if self._jwk_client:
|
||||
return await self._authenticate_jwt(token)
|
||||
return await self._authenticate_introspection(token)
|
||||
|
||||
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using JWKS JWT validation."""
|
||||
if self._jwk_client is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth2 JWKS not initialized",
|
||||
)
|
||||
|
||||
try:
|
||||
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||
|
||||
decode_options: dict[str, Any] = {
|
||||
"require": self.required_claims,
|
||||
}
|
||||
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=self.algorithms,
|
||||
audience=self.audience,
|
||||
issuer=self.issuer,
|
||||
leeway=self.clock_skew_seconds,
|
||||
options=decode_options,
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oauth2",
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "token_expired", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
) from None
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_audience", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token audience",
|
||||
) from None
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token issuer",
|
||||
) from None
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Missing required claim: {e.claim}",
|
||||
) from None
|
||||
except jwt.PyJWKClientError as e:
|
||||
logger.error(
|
||||
"OAuth2 authentication failed",
|
||||
extra={
|
||||
"reason": "jwks_client_error",
|
||||
"error": str(e),
|
||||
"scheme": "oauth2",
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Unable to fetch signing keys",
|
||||
) from None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
) from None
|
||||
|
||||
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
|
||||
import httpx
|
||||
|
||||
if not self.introspection_url:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth2 introspection not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
str(self.introspection_url),
|
||||
data={"token": token},
|
||||
auth=(
|
||||
self.introspection_client_id or "",
|
||||
self.introspection_client_secret.get_secret_value()
|
||||
if self.introspection_client_secret
|
||||
else "",
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
introspection_result = response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
"OAuth2 introspection failed",
|
||||
extra={"reason": "http_error", "status_code": e.response.status_code},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token introspection service unavailable",
|
||||
) from None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"OAuth2 introspection failed",
|
||||
extra={"reason": "unexpected_error", "error": str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token introspection failed",
|
||||
) from None
|
||||
|
||||
if not introspection_result.get("active", False):
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "token_not_active", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token is not active",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oauth2",
|
||||
claims=introspection_result,
|
||||
)
|
||||
|
||||
def to_security_scheme(self) -> OAuth2SecurityScheme:
|
||||
"""Generate OAuth2SecurityScheme for AgentCard declaration.
|
||||
|
||||
Creates an OAuth2SecurityScheme with appropriate flows based on
|
||||
the configured URLs. Includes client_credentials flow if token_url
|
||||
is set, and authorization_code flow if authorization_url is set.
|
||||
|
||||
Returns:
|
||||
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
|
||||
"""
|
||||
from a2a.types import (
|
||||
AuthorizationCodeOAuthFlow,
|
||||
ClientCredentialsOAuthFlow,
|
||||
OAuth2SecurityScheme,
|
||||
OAuthFlows,
|
||||
)
|
||||
|
||||
client_credentials = None
|
||||
authorization_code = None
|
||||
|
||||
if self.token_url:
|
||||
client_credentials = ClientCredentialsOAuthFlow(
|
||||
token_url=str(self.token_url),
|
||||
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
|
||||
if self.authorization_url:
|
||||
authorization_code = AuthorizationCodeOAuthFlow(
|
||||
authorization_url=str(self.authorization_url),
|
||||
token_url=str(self.token_url),
|
||||
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
|
||||
return OAuth2SecurityScheme(
|
||||
flows=OAuthFlows(
|
||||
client_credentials=client_credentials,
|
||||
authorization_code=authorization_code,
|
||||
),
|
||||
description="OAuth2 authentication",
|
||||
)
|
||||
|
||||
|
||||
class APIKeyServerAuth(ServerAuthScheme):
|
||||
"""API Key authentication for A2A server.
|
||||
|
||||
Validates requests using an API key in a header, query parameter, or cookie.
|
||||
|
||||
Attributes:
|
||||
name: The name of the API key parameter (default: X-API-Key).
|
||||
location: Where to look for the API key (header, query, or cookie).
|
||||
api_key: The expected API key value.
|
||||
"""
|
||||
|
||||
name: str = Field(
|
||||
default="X-API-Key",
|
||||
description="Name of the API key parameter",
|
||||
)
|
||||
location: Literal["header", "query", "cookie"] = Field(
|
||||
default="header",
|
||||
description="Where to look for the API key",
|
||||
)
|
||||
api_key: CoercedSecretStr = Field(
|
||||
description="Expected API key value",
|
||||
)
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using API key comparison.
|
||||
|
||||
Args:
|
||||
token: The API key to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if token != self.api_key.get_secret_value():
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="api_key",
|
||||
)
|
||||
|
||||
|
||||
class MTLSServerAuth(ServerAuthScheme):
|
||||
"""Mutual TLS authentication marker for AgentCard declaration.
|
||||
|
||||
This scheme is primarily for AgentCard security_schemes declaration.
|
||||
Actual mTLS verification happens at the TLS/transport layer, not
|
||||
at the application layer via token validation.
|
||||
|
||||
When configured, this signals to clients that the server requires
|
||||
client certificates for authentication.
|
||||
"""
|
||||
|
||||
description: str = Field(
|
||||
default="Mutual TLS certificate authentication",
|
||||
description="Description for the security scheme",
|
||||
)
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Return authenticated user for mTLS.
|
||||
|
||||
mTLS verification happens at the transport layer before this is called.
|
||||
If we reach this point, the TLS handshake with client cert succeeded.
|
||||
|
||||
Args:
|
||||
token: Certificate subject or identifier (from TLS layer).
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser indicating mTLS authentication.
|
||||
"""
|
||||
return AuthenticatedUser(
|
||||
token=token or "mtls-verified",
|
||||
scheme="mtls",
|
||||
)
|
||||
@@ -6,8 +6,10 @@ OAuth2, API keys, and HTTP authentication methods.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable, MutableMapping
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Final
|
||||
import threading
|
||||
from typing import Final, Literal, cast
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
@@ -18,10 +20,10 @@ from a2a.types import (
|
||||
)
|
||||
from httpx import AsyncClient, Response
|
||||
|
||||
from crewai.a2a.auth.schemas import (
|
||||
from crewai.a2a.auth.client_schemes import (
|
||||
APIKeyAuth,
|
||||
AuthScheme,
|
||||
BearerTokenAuth,
|
||||
ClientAuthScheme,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
@@ -29,12 +31,44 @@ from crewai.a2a.auth.schemas import (
|
||||
)
|
||||
|
||||
|
||||
_auth_store: dict[int, AuthScheme | None] = {}
|
||||
class _AuthStore:
|
||||
"""Store for authentication schemes with safe concurrent access."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, ClientAuthScheme | None] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@staticmethod
|
||||
def compute_key(auth_type: str, auth_data: str) -> str:
|
||||
"""Compute a collision-resistant key using SHA-256."""
|
||||
content = f"{auth_type}:{auth_data}"
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
|
||||
"""Store an auth scheme."""
|
||||
with self._lock:
|
||||
self._store[key] = auth
|
||||
|
||||
def get(self, key: str) -> ClientAuthScheme | None:
|
||||
"""Retrieve an auth scheme by key."""
|
||||
with self._lock:
|
||||
return self._store.get(key)
|
||||
|
||||
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
|
||||
with self._lock:
|
||||
self._store[key] = value
|
||||
|
||||
def __getitem__(self, key: str) -> ClientAuthScheme | None:
|
||||
with self._lock:
|
||||
return self._store[key]
|
||||
|
||||
|
||||
_auth_store = _AuthStore()
|
||||
|
||||
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
|
||||
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
|
||||
|
||||
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
|
||||
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
|
||||
OAuth2SecurityScheme: (
|
||||
OAuth2ClientCredentials,
|
||||
OAuth2AuthorizationCode,
|
||||
@@ -43,7 +77,9 @@ _SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
|
||||
APIKeySecurityScheme: (APIKeyAuth,),
|
||||
}
|
||||
|
||||
_HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
|
||||
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
|
||||
|
||||
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
|
||||
"basic": HTTPBasicAuth,
|
||||
"digest": HTTPDigestAuth,
|
||||
"bearer": BearerTokenAuth,
|
||||
@@ -51,8 +87,8 @@ _HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
|
||||
|
||||
|
||||
def _raise_auth_mismatch(
|
||||
expected_classes: type[AuthScheme] | tuple[type[AuthScheme], ...],
|
||||
provided_auth: AuthScheme,
|
||||
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
|
||||
provided_auth: ClientAuthScheme,
|
||||
) -> None:
|
||||
"""Raise authentication mismatch error.
|
||||
|
||||
@@ -111,7 +147,7 @@ def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
|
||||
|
||||
|
||||
def validate_auth_against_agent_card(
|
||||
agent_card: AgentCard, auth: AuthScheme | None
|
||||
agent_card: AgentCard, auth: ClientAuthScheme | None
|
||||
) -> None:
|
||||
"""Validate that provided auth matches AgentCard security requirements.
|
||||
|
||||
@@ -145,7 +181,8 @@ def validate_auth_against_agent_card(
|
||||
return
|
||||
|
||||
if isinstance(scheme, HTTPAuthSecurityScheme):
|
||||
if required_class := _HTTP_SCHEME_MAPPING.get(scheme.scheme.lower()):
|
||||
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
|
||||
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
|
||||
if not isinstance(auth, required_class):
|
||||
_raise_auth_mismatch(required_class, auth)
|
||||
return
|
||||
@@ -156,7 +193,7 @@ def validate_auth_against_agent_card(
|
||||
|
||||
async def retry_on_401(
|
||||
request_func: Callable[[], Awaitable[Response]],
|
||||
auth_scheme: AuthScheme | None,
|
||||
auth_scheme: ClientAuthScheme | None,
|
||||
client: AsyncClient,
|
||||
headers: MutableMapping[str, str],
|
||||
max_retries: int = 3,
|
||||
|
||||
@@ -1,4 +1,37 @@
|
||||
"""A2A Protocol Extensions for CrewAI.
|
||||
|
||||
This module contains extensions to the A2A (Agent-to-Agent) protocol.
|
||||
|
||||
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
|
||||
processes requests and responses during delegation to remote agents. These provide
|
||||
hooks for tool injection, prompt augmentation, and response processing.
|
||||
|
||||
**Server-side extensions** (ServerExtension) allow agents to offer additional
|
||||
functionality beyond the core A2A specification. Clients activate extensions
|
||||
via the X-A2A-Extensions header.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from crewai.a2a.extensions.base import (
|
||||
A2AExtension,
|
||||
ConversationState,
|
||||
ExtensionRegistry,
|
||||
ValidatedA2AExtension,
|
||||
)
|
||||
from crewai.a2a.extensions.server import (
|
||||
ExtensionContext,
|
||||
ServerExtension,
|
||||
ServerExtensionRegistry,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"A2AExtension",
|
||||
"ConversationState",
|
||||
"ExtensionContext",
|
||||
"ExtensionRegistry",
|
||||
"ServerExtension",
|
||||
"ServerExtensionRegistry",
|
||||
"ValidatedA2AExtension",
|
||||
]
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
"""Base extension interface for A2A wrapper integrations.
|
||||
"""Base extension interface for CrewAI A2A wrapper processing hooks.
|
||||
|
||||
This module defines the protocol for extending A2A wrapper functionality
|
||||
with custom logic for conversation processing, prompt augmentation, and
|
||||
agent response handling.
|
||||
This module defines the protocol for extending CrewAI's A2A wrapper functionality
|
||||
with custom logic for tool injection, prompt augmentation, and response processing.
|
||||
|
||||
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
|
||||
A2A protocol extensions are capability declarations using AgentExtension objects
|
||||
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BeforeValidator
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -17,6 +23,20 @@ if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
|
||||
|
||||
def _validate_a2a_extension(v: Any) -> Any:
|
||||
"""Validate that value implements A2AExtension protocol."""
|
||||
if not isinstance(v, A2AExtension):
|
||||
raise ValueError(
|
||||
f"Value must implement A2AExtension protocol. "
|
||||
f"Got {type(v).__name__} which is missing required methods."
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ConversationState(Protocol):
|
||||
"""Protocol for extension-specific conversation state.
|
||||
|
||||
@@ -33,11 +53,36 @@ class ConversationState(Protocol):
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class A2AExtension(Protocol):
|
||||
"""Protocol for A2A wrapper extensions.
|
||||
|
||||
Extensions can implement this protocol to inject custom logic into
|
||||
the A2A conversation flow at various integration points.
|
||||
|
||||
Example:
|
||||
class MyExtension:
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
# Add custom tools to the agent
|
||||
pass
|
||||
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> ConversationState | None:
|
||||
# Extract state from conversation
|
||||
return None
|
||||
|
||||
def augment_prompt(
|
||||
self, base_prompt: str, conversation_state: ConversationState | None
|
||||
) -> str:
|
||||
# Add custom instructions
|
||||
return base_prompt
|
||||
|
||||
def process_response(
|
||||
self, agent_response: Any, conversation_state: ConversationState | None
|
||||
) -> Any:
|
||||
# Modify response if needed
|
||||
return agent_response
|
||||
"""
|
||||
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
|
||||
@@ -1,34 +1,170 @@
|
||||
"""Extension registry factory for A2A configurations.
|
||||
"""A2A Protocol extension utilities.
|
||||
|
||||
This module provides utilities for creating extension registries from A2A configurations.
|
||||
This module provides utilities for working with A2A protocol extensions as
|
||||
defined in the A2A specification. Extensions are capability declarations in
|
||||
AgentCard.capabilities.extensions using AgentExtension objects, activated
|
||||
via the X-A2A-Extensions HTTP header.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any
|
||||
|
||||
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
|
||||
from a2a.extensions.common import (
|
||||
HTTP_EXTENSION_HEADER,
|
||||
)
|
||||
from a2a.types import AgentCard, AgentExtension
|
||||
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig
|
||||
from crewai.a2a.extensions.base import ExtensionRegistry
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
def get_extensions_from_config(
|
||||
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||
) -> list[str]:
|
||||
"""Extract extension URIs from A2A configuration.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (single or list).
|
||||
|
||||
Returns:
|
||||
Deduplicated list of extension URIs from all configs.
|
||||
"""
|
||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
|
||||
for config in configs:
|
||||
if not isinstance(config, A2AClientConfig):
|
||||
continue
|
||||
for uri in config.extensions:
|
||||
if uri not in seen:
|
||||
seen.add(uri)
|
||||
result.append(uri)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ExtensionsMiddleware(ClientCallInterceptor):
|
||||
"""Middleware to add X-A2A-Extensions header to requests.
|
||||
|
||||
This middleware adds the extensions header to all outgoing requests,
|
||||
declaring which A2A protocol extensions the client supports.
|
||||
"""
|
||||
|
||||
def __init__(self, extensions: list[str]) -> None:
|
||||
"""Initialize with extension URIs.
|
||||
|
||||
Args:
|
||||
extensions: List of extension URIs the client supports.
|
||||
"""
|
||||
self._extensions = extensions
|
||||
|
||||
async def intercept(
|
||||
self,
|
||||
method_name: str,
|
||||
request_payload: dict[str, Any],
|
||||
http_kwargs: dict[str, Any],
|
||||
agent_card: AgentCard | None,
|
||||
context: ClientCallContext | None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Add extensions header to the request.
|
||||
|
||||
Args:
|
||||
method_name: The A2A method being called.
|
||||
request_payload: The JSON-RPC request payload.
|
||||
http_kwargs: HTTP request kwargs (headers, etc).
|
||||
agent_card: The target agent's card.
|
||||
context: Optional call context.
|
||||
|
||||
Returns:
|
||||
Tuple of (request_payload, modified_http_kwargs).
|
||||
"""
|
||||
if self._extensions:
|
||||
headers = http_kwargs.setdefault("headers", {})
|
||||
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
|
||||
return request_payload, http_kwargs
|
||||
|
||||
|
||||
def validate_required_extensions(
|
||||
agent_card: AgentCard,
|
||||
client_extensions: list[str] | None,
|
||||
) -> list[AgentExtension]:
|
||||
"""Validate that client supports all required extensions from agent.
|
||||
|
||||
Args:
|
||||
agent_card: The agent's card with declared extensions.
|
||||
client_extensions: Extension URIs the client supports.
|
||||
|
||||
Returns:
|
||||
List of unsupported required extensions.
|
||||
|
||||
Raises:
|
||||
None - returns list of unsupported extensions for caller to handle.
|
||||
"""
|
||||
unsupported: list[AgentExtension] = []
|
||||
client_set = set(client_extensions or [])
|
||||
|
||||
if not agent_card.capabilities or not agent_card.capabilities.extensions:
|
||||
return unsupported
|
||||
|
||||
unsupported.extend(
|
||||
ext
|
||||
for ext in agent_card.capabilities.extensions
|
||||
if ext.required and ext.uri not in client_set
|
||||
)
|
||||
|
||||
return unsupported
|
||||
|
||||
|
||||
def create_extension_registry_from_config(
|
||||
a2a_config: list[A2AConfig] | A2AConfig,
|
||||
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||
) -> ExtensionRegistry:
|
||||
"""Create an extension registry from A2A configuration.
|
||||
"""Create an extension registry from A2A client configuration.
|
||||
|
||||
Extracts client_extensions from each A2AClientConfig and registers them
|
||||
with the ExtensionRegistry. These extensions provide CrewAI-specific
|
||||
processing hooks (tool injection, prompt augmentation, response processing).
|
||||
|
||||
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
|
||||
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (single or list)
|
||||
a2a_config: A2A configuration (single or list).
|
||||
|
||||
Returns:
|
||||
Configured extension registry with all applicable extensions
|
||||
Extension registry with all client_extensions registered.
|
||||
|
||||
Example:
|
||||
class LoggingExtension:
|
||||
def inject_tools(self, agent): pass
|
||||
def extract_state_from_history(self, history): return None
|
||||
def augment_prompt(self, prompt, state): return prompt
|
||||
def process_response(self, response, state):
|
||||
print(f"Response: {response}")
|
||||
return response
|
||||
|
||||
config = A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
client_extensions=[LoggingExtension()],
|
||||
)
|
||||
registry = create_extension_registry_from_config(config)
|
||||
"""
|
||||
registry = ExtensionRegistry()
|
||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||
|
||||
for _ in configs:
|
||||
pass
|
||||
seen: set[int] = set()
|
||||
|
||||
for config in configs:
|
||||
if isinstance(config, (A2AConfig, A2AClientConfig)):
|
||||
client_exts = getattr(config, "client_extensions", [])
|
||||
for extension in client_exts:
|
||||
ext_id = id(extension)
|
||||
if ext_id not in seen:
|
||||
seen.add(ext_id)
|
||||
registry.register(extension)
|
||||
|
||||
return registry
|
||||
|
||||
305
lib/crewai/src/crewai/a2a/extensions/server.py
Normal file
305
lib/crewai/src/crewai/a2a/extensions/server.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""A2A protocol server extensions for CrewAI agents.
|
||||
|
||||
This module provides the base class and context for implementing A2A protocol
|
||||
extensions on the server side. Extensions allow agents to offer additional
|
||||
functionality beyond the core A2A specification.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
|
||||
from a2a.types import AgentExtension
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.server.context import ServerCallContext
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtensionContext:
|
||||
"""Context passed to extension hooks during request processing.
|
||||
|
||||
Provides access to request metadata, client extensions, and shared state
|
||||
that extensions can read from and write to.
|
||||
|
||||
Attributes:
|
||||
metadata: Request metadata dict, includes extension-namespaced keys.
|
||||
client_extensions: Set of extension URIs the client declared support for.
|
||||
state: Mutable dict for extensions to share data during request lifecycle.
|
||||
server_context: The underlying A2A server call context.
|
||||
"""
|
||||
|
||||
metadata: dict[str, Any]
|
||||
client_extensions: set[str]
|
||||
state: dict[str, Any] = field(default_factory=dict)
|
||||
server_context: ServerCallContext | None = None
|
||||
|
||||
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
|
||||
"""Get extension-specific metadata value.
|
||||
|
||||
Extension metadata uses namespaced keys in the format:
|
||||
"{extension_uri}/{key}"
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
key: The metadata key within the extension namespace.
|
||||
|
||||
Returns:
|
||||
The metadata value, or None if not present.
|
||||
"""
|
||||
full_key = f"{uri}/{key}"
|
||||
return self.metadata.get(full_key)
|
||||
|
||||
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
|
||||
"""Set extension-specific metadata value.
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
key: The metadata key within the extension namespace.
|
||||
value: The value to set.
|
||||
"""
|
||||
full_key = f"{uri}/{key}"
|
||||
self.metadata[full_key] = value
|
||||
|
||||
|
||||
class ServerExtension(ABC):
|
||||
"""Base class for A2A protocol server extensions.
|
||||
|
||||
Subclass this to create custom extensions that modify agent behavior
|
||||
when clients activate them. Extensions are identified by URI and can
|
||||
be marked as required.
|
||||
|
||||
Example:
|
||||
class SamplingExtension(ServerExtension):
|
||||
uri = "urn:crewai:ext:sampling/v1"
|
||||
required = True
|
||||
|
||||
def __init__(self, max_tokens: int = 4096):
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@property
|
||||
def params(self) -> dict[str, Any]:
|
||||
return {"max_tokens": self.max_tokens}
|
||||
|
||||
async def on_request(self, context: ExtensionContext) -> None:
|
||||
limit = context.get_extension_metadata(self.uri, "limit")
|
||||
if limit:
|
||||
context.state["token_limit"] = int(limit)
|
||||
|
||||
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
return result
|
||||
"""
|
||||
|
||||
uri: Annotated[str, "Extension URI identifier. Must be unique."]
|
||||
required: Annotated[bool, "Whether clients must support this extension."] = False
|
||||
description: Annotated[
|
||||
str | None, "Human-readable description of the extension."
|
||||
] = None
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: Any,
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> CoreSchema:
|
||||
"""Tell Pydantic how to validate ServerExtension instances."""
|
||||
return core_schema.is_instance_schema(cls)
|
||||
|
||||
@property
|
||||
def params(self) -> dict[str, Any] | None:
|
||||
"""Extension parameters to advertise in AgentCard.
|
||||
|
||||
Override this property to expose configuration that clients can read.
|
||||
|
||||
Returns:
|
||||
Dict of parameter names to values, or None.
|
||||
"""
|
||||
return None
|
||||
|
||||
def agent_extension(self) -> AgentExtension:
|
||||
"""Generate the AgentExtension object for the AgentCard.
|
||||
|
||||
Returns:
|
||||
AgentExtension with this extension's URI, required flag, and params.
|
||||
"""
|
||||
return AgentExtension(
|
||||
uri=self.uri,
|
||||
required=self.required if self.required else None,
|
||||
description=self.description,
|
||||
params=self.params,
|
||||
)
|
||||
|
||||
def is_active(self, context: ExtensionContext) -> bool:
|
||||
"""Check if this extension is active for the current request.
|
||||
|
||||
An extension is active if the client declared support for it.
|
||||
|
||||
Args:
|
||||
context: The extension context for the current request.
|
||||
|
||||
Returns:
|
||||
True if the client supports this extension.
|
||||
"""
|
||||
return self.uri in context.client_extensions
|
||||
|
||||
@abstractmethod
|
||||
async def on_request(self, context: ExtensionContext) -> None:
|
||||
"""Called before agent execution if extension is active.
|
||||
|
||||
Use this hook to:
|
||||
- Read extension-specific metadata from the request
|
||||
- Set up state for the execution
|
||||
- Modify execution parameters via context.state
|
||||
|
||||
Args:
|
||||
context: The extension context with request metadata and state.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
"""Called after agent execution if extension is active.
|
||||
|
||||
Use this hook to:
|
||||
- Modify or enhance the result
|
||||
- Add extension-specific metadata to the response
|
||||
- Clean up any resources
|
||||
|
||||
Args:
|
||||
context: The extension context with request metadata and state.
|
||||
result: The agent execution result.
|
||||
|
||||
Returns:
|
||||
The result, potentially modified.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ServerExtensionRegistry:
|
||||
"""Registry for managing server-side A2A protocol extensions.
|
||||
|
||||
Collects extensions and provides methods to generate AgentCapabilities
|
||||
and invoke extension hooks during request processing.
|
||||
"""
|
||||
|
||||
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
|
||||
"""Initialize the registry with optional extensions.
|
||||
|
||||
Args:
|
||||
extensions: Initial list of extensions to register.
|
||||
"""
|
||||
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
|
||||
self._by_uri: dict[str, ServerExtension] = {
|
||||
ext.uri: ext for ext in self._extensions
|
||||
}
|
||||
|
||||
def register(self, extension: ServerExtension) -> None:
|
||||
"""Register an extension.
|
||||
|
||||
Args:
|
||||
extension: The extension to register.
|
||||
|
||||
Raises:
|
||||
ValueError: If an extension with the same URI is already registered.
|
||||
"""
|
||||
if extension.uri in self._by_uri:
|
||||
raise ValueError(f"Extension already registered: {extension.uri}")
|
||||
self._extensions.append(extension)
|
||||
self._by_uri[extension.uri] = extension
|
||||
|
||||
def get_agent_extensions(self) -> list[AgentExtension]:
|
||||
"""Get AgentExtension objects for all registered extensions.
|
||||
|
||||
Returns:
|
||||
List of AgentExtension objects for the AgentCard.
|
||||
"""
|
||||
return [ext.agent_extension() for ext in self._extensions]
|
||||
|
||||
def get_extension(self, uri: str) -> ServerExtension | None:
|
||||
"""Get an extension by URI.
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
|
||||
Returns:
|
||||
The extension, or None if not found.
|
||||
"""
|
||||
return self._by_uri.get(uri)
|
||||
|
||||
@staticmethod
|
||||
def create_context(
|
||||
metadata: dict[str, Any],
|
||||
client_extensions: set[str],
|
||||
server_context: ServerCallContext | None = None,
|
||||
) -> ExtensionContext:
|
||||
"""Create an ExtensionContext for a request.
|
||||
|
||||
Args:
|
||||
metadata: Request metadata dict.
|
||||
client_extensions: Set of extension URIs from client.
|
||||
server_context: Optional server call context.
|
||||
|
||||
Returns:
|
||||
ExtensionContext for use in hooks.
|
||||
"""
|
||||
return ExtensionContext(
|
||||
metadata=metadata,
|
||||
client_extensions=client_extensions,
|
||||
server_context=server_context,
|
||||
)
|
||||
|
||||
async def invoke_on_request(self, context: ExtensionContext) -> None:
|
||||
"""Invoke on_request hooks for all active extensions.
|
||||
|
||||
Tracks activated extensions and isolates errors from individual hooks.
|
||||
|
||||
Args:
|
||||
context: The extension context for the request.
|
||||
"""
|
||||
for extension in self._extensions:
|
||||
if extension.is_active(context):
|
||||
try:
|
||||
await extension.on_request(context)
|
||||
if context.server_context is not None:
|
||||
context.server_context.activated_extensions.add(extension.uri)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Extension on_request hook failed",
|
||||
extra={"extension": extension.uri},
|
||||
)
|
||||
|
||||
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
"""Invoke on_response hooks for all active extensions.
|
||||
|
||||
Isolates errors from individual hooks to prevent one failing extension
|
||||
from breaking the entire response.
|
||||
|
||||
Args:
|
||||
context: The extension context for the request.
|
||||
result: The agent execution result.
|
||||
|
||||
Returns:
|
||||
The result after all extensions have processed it.
|
||||
"""
|
||||
processed = result
|
||||
for extension in self._extensions:
|
||||
if extension.is_active(context):
|
||||
try:
|
||||
processed = await extension.on_response(context, processed)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Extension on_response hook failed",
|
||||
extra={"extension": extension.uri},
|
||||
)
|
||||
return processed
|
||||
Reference in New Issue
Block a user