mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
feat: add server-side auth schemes and protocol extensions
- add server auth scheme base class and implementations (api key, bearer token, basic/digest auth, mtls) - add server-side extension system for a2a protocol extensions - add extensions middleware for x-a2a-extensions header management - add extension validation and registry utilities - enhance auth utilities with server-side support - add async intercept method to match client call interceptor protocol - fix type_checking import to resolve mypy errors with a2aconfig
This commit is contained in:
@@ -1,20 +1,36 @@
|
|||||||
"""A2A authentication schemas."""
|
"""A2A authentication schemas."""
|
||||||
|
|
||||||
from crewai.a2a.auth.schemas import (
|
from crewai.a2a.auth.client_schemes import (
|
||||||
APIKeyAuth,
|
APIKeyAuth,
|
||||||
|
AuthScheme,
|
||||||
BearerTokenAuth,
|
BearerTokenAuth,
|
||||||
|
ClientAuthScheme,
|
||||||
HTTPBasicAuth,
|
HTTPBasicAuth,
|
||||||
HTTPDigestAuth,
|
HTTPDigestAuth,
|
||||||
OAuth2AuthorizationCode,
|
OAuth2AuthorizationCode,
|
||||||
OAuth2ClientCredentials,
|
OAuth2ClientCredentials,
|
||||||
|
TLSConfig,
|
||||||
|
)
|
||||||
|
from crewai.a2a.auth.server_schemes import (
|
||||||
|
AuthenticatedUser,
|
||||||
|
OIDCAuth,
|
||||||
|
ServerAuthScheme,
|
||||||
|
SimpleTokenAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"APIKeyAuth",
|
"APIKeyAuth",
|
||||||
|
"AuthScheme",
|
||||||
|
"AuthenticatedUser",
|
||||||
"BearerTokenAuth",
|
"BearerTokenAuth",
|
||||||
|
"ClientAuthScheme",
|
||||||
"HTTPBasicAuth",
|
"HTTPBasicAuth",
|
||||||
"HTTPDigestAuth",
|
"HTTPDigestAuth",
|
||||||
"OAuth2AuthorizationCode",
|
"OAuth2AuthorizationCode",
|
||||||
"OAuth2ClientCredentials",
|
"OAuth2ClientCredentials",
|
||||||
|
"OIDCAuth",
|
||||||
|
"ServerAuthScheme",
|
||||||
|
"SimpleTokenAuth",
|
||||||
|
"TLSConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
739
lib/crewai/src/crewai/a2a/auth/server_schemes.py
Normal file
739
lib/crewai/src/crewai/a2a/auth/server_schemes.py
Normal file
@@ -0,0 +1,739 @@
|
|||||||
|
"""Server-side authentication schemes for A2A protocol.
|
||||||
|
|
||||||
|
These schemes validate incoming requests to A2A server endpoints.
|
||||||
|
|
||||||
|
Supported authentication methods:
|
||||||
|
- Simple token validation with static bearer tokens
|
||||||
|
- OpenID Connect with JWT validation using JWKS
|
||||||
|
- OAuth2 with JWT validation or token introspection
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from jwt import PyJWKClient
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
BeforeValidator,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
HttpUrl,
|
||||||
|
PrivateAttr,
|
||||||
|
SecretStr,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from a2a.types import OAuth2SecurityScheme
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastapi import HTTPException, status as http_status
|
||||||
|
|
||||||
|
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
|
||||||
|
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
|
||||||
|
"""Fallback HTTPException when FastAPI is not installed."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
status_code: int,
|
||||||
|
detail: str | None = None,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.status_code = status_code
|
||||||
|
self.detail = detail
|
||||||
|
self.headers = headers
|
||||||
|
super().__init__(detail)
|
||||||
|
|
||||||
|
HTTP_401_UNAUTHORIZED = 401
|
||||||
|
HTTP_500_INTERNAL_SERVER_ERROR = 500
|
||||||
|
HTTP_503_SERVICE_UNAVAILABLE = 503
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
|
||||||
|
"""Coerce string to SecretStr."""
|
||||||
|
if v is None or isinstance(v, SecretStr):
|
||||||
|
return v
|
||||||
|
return SecretStr(v)
|
||||||
|
|
||||||
|
|
||||||
|
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
|
||||||
|
|
||||||
|
JWTAlgorithm = Literal[
|
||||||
|
"RS256",
|
||||||
|
"RS384",
|
||||||
|
"RS512",
|
||||||
|
"ES256",
|
||||||
|
"ES384",
|
||||||
|
"ES512",
|
||||||
|
"PS256",
|
||||||
|
"PS384",
|
||||||
|
"PS512",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AuthenticatedUser:
|
||||||
|
"""Result of successful authentication.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
token: The original token that was validated.
|
||||||
|
scheme: Name of the authentication scheme used.
|
||||||
|
claims: JWT claims from OIDC or OAuth2 authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
token: str
|
||||||
|
scheme: str
|
||||||
|
claims: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ServerAuthScheme(ABC, BaseModel):
|
||||||
|
"""Base class for server-side authentication schemes.
|
||||||
|
|
||||||
|
Each scheme validates incoming requests and returns an AuthenticatedUser
|
||||||
|
on success, or raises HTTPException on failure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||||
|
"""Authenticate the provided token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The bearer token to authenticate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthenticatedUser on successful authentication.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If authentication fails.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleTokenAuth(ServerAuthScheme):
|
||||||
|
"""Simple bearer token authentication.
|
||||||
|
|
||||||
|
Validates tokens against a configured static token or AUTH_TOKEN env var.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
token: CoercedSecretStr | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Expected token. Falls back to AUTH_TOKEN env var.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_expected_token(self) -> str | None:
|
||||||
|
"""Get the expected token value."""
|
||||||
|
if self.token:
|
||||||
|
return self.token.get_secret_value()
|
||||||
|
return os.environ.get("AUTH_TOKEN")
|
||||||
|
|
||||||
|
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||||
|
"""Authenticate using simple token comparison.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The bearer token to authenticate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthenticatedUser on successful authentication.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If authentication fails.
|
||||||
|
"""
|
||||||
|
expected = self._get_expected_token()
|
||||||
|
|
||||||
|
if expected is None:
|
||||||
|
logger.warning(
|
||||||
|
"Simple token authentication failed",
|
||||||
|
extra={"reason": "no_token_configured"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
if token != expected:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or missing authentication credentials",
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthenticatedUser(
|
||||||
|
token=token,
|
||||||
|
scheme="simple_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCAuth(ServerAuthScheme):
|
||||||
|
"""OpenID Connect authentication.
|
||||||
|
|
||||||
|
Validates JWTs using JWKS with caching support via PyJWT.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
issuer: The OpenID Connect issuer URL.
|
||||||
|
audience: The expected audience claim.
|
||||||
|
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
|
||||||
|
algorithms: List of allowed signing algorithms.
|
||||||
|
required_claims: List of claims that must be present in the token.
|
||||||
|
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||||
|
clock_skew_seconds: Allowed clock skew for token validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
issuer: HttpUrl = Field(
|
||||||
|
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
|
||||||
|
)
|
||||||
|
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
|
||||||
|
jwks_url: HttpUrl | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Explicit JWKS URL. Derived from issuer if not set.",
|
||||||
|
)
|
||||||
|
algorithms: list[str] = Field(
|
||||||
|
default_factory=lambda: ["RS256"],
|
||||||
|
description="List of allowed signing algorithms (RS256, ES256, etc.)",
|
||||||
|
)
|
||||||
|
required_claims: list[str] = Field(
|
||||||
|
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
|
||||||
|
description="List of claims that must be present in the token",
|
||||||
|
)
|
||||||
|
jwks_cache_ttl: int = Field(
|
||||||
|
default=3600,
|
||||||
|
description="TTL for JWKS cache in seconds",
|
||||||
|
ge=60,
|
||||||
|
)
|
||||||
|
clock_skew_seconds: float = Field(
|
||||||
|
default=30.0,
|
||||||
|
description="Allowed clock skew for token validation",
|
||||||
|
ge=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _init_jwk_client(self) -> Self:
|
||||||
|
"""Initialize the JWK client after model creation."""
|
||||||
|
jwks_url = (
|
||||||
|
str(self.jwks_url)
|
||||||
|
if self.jwks_url
|
||||||
|
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
|
||||||
|
)
|
||||||
|
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||||
|
"""Authenticate using OIDC JWT validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The JWT to authenticate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthenticatedUser on successful authentication.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If authentication fails.
|
||||||
|
"""
|
||||||
|
if self._jwk_client is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OIDC not initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||||
|
|
||||||
|
claims = jwt.decode(
|
||||||
|
token,
|
||||||
|
signing_key.key,
|
||||||
|
algorithms=self.algorithms,
|
||||||
|
audience=self.audience,
|
||||||
|
issuer=str(self.issuer).rstrip("/"),
|
||||||
|
leeway=self.clock_skew_seconds,
|
||||||
|
options={
|
||||||
|
"require": self.required_claims,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthenticatedUser(
|
||||||
|
token=token,
|
||||||
|
scheme="oidc",
|
||||||
|
claims=claims,
|
||||||
|
)
|
||||||
|
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
logger.debug(
|
||||||
|
"OIDC authentication failed",
|
||||||
|
extra={"reason": "token_expired", "scheme": "oidc"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token has expired",
|
||||||
|
) from None
|
||||||
|
except jwt.InvalidAudienceError:
|
||||||
|
logger.debug(
|
||||||
|
"OIDC authentication failed",
|
||||||
|
extra={"reason": "invalid_audience", "scheme": "oidc"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token audience",
|
||||||
|
) from None
|
||||||
|
except jwt.InvalidIssuerError:
|
||||||
|
logger.debug(
|
||||||
|
"OIDC authentication failed",
|
||||||
|
extra={"reason": "invalid_issuer", "scheme": "oidc"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token issuer",
|
||||||
|
) from None
|
||||||
|
except jwt.MissingRequiredClaimError as e:
|
||||||
|
logger.debug(
|
||||||
|
"OIDC authentication failed",
|
||||||
|
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=f"Missing required claim: {e.claim}",
|
||||||
|
) from None
|
||||||
|
except jwt.PyJWKClientError as e:
|
||||||
|
logger.error(
|
||||||
|
"OIDC authentication failed",
|
||||||
|
extra={
|
||||||
|
"reason": "jwks_client_error",
|
||||||
|
"error": str(e),
|
||||||
|
"scheme": "oidc",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Unable to fetch signing keys",
|
||||||
|
) from None
|
||||||
|
except jwt.InvalidTokenError as e:
|
||||||
|
logger.debug(
|
||||||
|
"OIDC authentication failed",
|
||||||
|
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or missing authentication credentials",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2ServerAuth(ServerAuthScheme):
|
||||||
|
"""OAuth2 authentication for A2A server.
|
||||||
|
|
||||||
|
Declares OAuth2 security scheme in AgentCard and validates tokens using
|
||||||
|
either JWKS for JWT tokens or token introspection for opaque tokens.
|
||||||
|
|
||||||
|
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
|
||||||
|
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
token_url: OAuth2 token endpoint URL for client_credentials flow.
|
||||||
|
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
|
||||||
|
refresh_url: Optional refresh token endpoint URL.
|
||||||
|
scopes: Available OAuth2 scopes with descriptions.
|
||||||
|
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
|
||||||
|
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
|
||||||
|
introspection_client_id: Client ID for introspection endpoint authentication.
|
||||||
|
introspection_client_secret: Client secret for introspection endpoint.
|
||||||
|
audience: Expected audience claim for JWT validation.
|
||||||
|
issuer: Expected issuer claim for JWT validation.
|
||||||
|
algorithms: Allowed JWT signing algorithms.
|
||||||
|
required_claims: Claims that must be present in the token.
|
||||||
|
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||||
|
clock_skew_seconds: Allowed clock skew for token validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
token_url: HttpUrl = Field(
|
||||||
|
description="OAuth2 token endpoint URL",
|
||||||
|
)
|
||||||
|
authorization_url: HttpUrl | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="OAuth2 authorization endpoint URL for authorization_code flow",
|
||||||
|
)
|
||||||
|
refresh_url: HttpUrl | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="OAuth2 refresh token endpoint URL",
|
||||||
|
)
|
||||||
|
scopes: dict[str, str] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Available OAuth2 scopes with descriptions",
|
||||||
|
)
|
||||||
|
jwks_url: HttpUrl | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="JWKS URL for JWT validation. Required if not using introspection.",
|
||||||
|
)
|
||||||
|
introspection_url: HttpUrl | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
|
||||||
|
)
|
||||||
|
introspection_client_id: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Client ID for introspection endpoint authentication",
|
||||||
|
)
|
||||||
|
introspection_client_secret: CoercedSecretStr | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Client secret for introspection endpoint authentication",
|
||||||
|
)
|
||||||
|
audience: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Expected audience claim for JWT validation",
|
||||||
|
)
|
||||||
|
issuer: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Expected issuer claim for JWT validation",
|
||||||
|
)
|
||||||
|
algorithms: list[str] = Field(
|
||||||
|
default_factory=lambda: ["RS256"],
|
||||||
|
description="Allowed JWT signing algorithms",
|
||||||
|
)
|
||||||
|
required_claims: list[str] = Field(
|
||||||
|
default_factory=lambda: ["exp", "iat"],
|
||||||
|
description="Claims that must be present in the token",
|
||||||
|
)
|
||||||
|
jwks_cache_ttl: int = Field(
|
||||||
|
default=3600,
|
||||||
|
description="TTL for JWKS cache in seconds",
|
||||||
|
ge=60,
|
||||||
|
)
|
||||||
|
clock_skew_seconds: float = Field(
|
||||||
|
default=30.0,
|
||||||
|
description="Allowed clock skew for token validation",
|
||||||
|
ge=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _validate_and_init(self) -> Self:
|
||||||
|
"""Validate configuration and initialize JWKS client if needed."""
|
||||||
|
if not self.jwks_url and not self.introspection_url:
|
||||||
|
raise ValueError(
|
||||||
|
"Either jwks_url or introspection_url must be provided for token validation"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.introspection_url:
|
||||||
|
if not self.introspection_client_id or not self.introspection_client_secret:
|
||||||
|
raise ValueError(
|
||||||
|
"introspection_client_id and introspection_client_secret are required "
|
||||||
|
"when using token introspection"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.jwks_url:
|
||||||
|
self._jwk_client = PyJWKClient(
|
||||||
|
str(self.jwks_url), lifespan=self.jwks_cache_ttl
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||||
|
"""Authenticate using OAuth2 token validation.
|
||||||
|
|
||||||
|
Uses JWKS validation if jwks_url is configured, otherwise falls back
|
||||||
|
to token introspection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The OAuth2 access token to authenticate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthenticatedUser on successful authentication.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If authentication fails.
|
||||||
|
"""
|
||||||
|
if self._jwk_client:
|
||||||
|
return await self._authenticate_jwt(token)
|
||||||
|
return await self._authenticate_introspection(token)
|
||||||
|
|
||||||
|
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
|
||||||
|
"""Authenticate using JWKS JWT validation."""
|
||||||
|
if self._jwk_client is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth2 JWKS not initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||||
|
|
||||||
|
decode_options: dict[str, Any] = {
|
||||||
|
"require": self.required_claims,
|
||||||
|
}
|
||||||
|
|
||||||
|
claims = jwt.decode(
|
||||||
|
token,
|
||||||
|
signing_key.key,
|
||||||
|
algorithms=self.algorithms,
|
||||||
|
audience=self.audience,
|
||||||
|
issuer=self.issuer,
|
||||||
|
leeway=self.clock_skew_seconds,
|
||||||
|
options=decode_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthenticatedUser(
|
||||||
|
token=token,
|
||||||
|
scheme="oauth2",
|
||||||
|
claims=claims,
|
||||||
|
)
|
||||||
|
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
logger.debug(
|
||||||
|
"OAuth2 authentication failed",
|
||||||
|
extra={"reason": "token_expired", "scheme": "oauth2"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token has expired",
|
||||||
|
) from None
|
||||||
|
except jwt.InvalidAudienceError:
|
||||||
|
logger.debug(
|
||||||
|
"OAuth2 authentication failed",
|
||||||
|
extra={"reason": "invalid_audience", "scheme": "oauth2"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token audience",
|
||||||
|
) from None
|
||||||
|
except jwt.InvalidIssuerError:
|
||||||
|
logger.debug(
|
||||||
|
"OAuth2 authentication failed",
|
||||||
|
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token issuer",
|
||||||
|
) from None
|
||||||
|
except jwt.MissingRequiredClaimError as e:
|
||||||
|
logger.debug(
|
||||||
|
"OAuth2 authentication failed",
|
||||||
|
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=f"Missing required claim: {e.claim}",
|
||||||
|
) from None
|
||||||
|
except jwt.PyJWKClientError as e:
|
||||||
|
logger.error(
|
||||||
|
"OAuth2 authentication failed",
|
||||||
|
extra={
|
||||||
|
"reason": "jwks_client_error",
|
||||||
|
"error": str(e),
|
||||||
|
"scheme": "oauth2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Unable to fetch signing keys",
|
||||||
|
) from None
|
||||||
|
except jwt.InvalidTokenError as e:
|
||||||
|
logger.debug(
|
||||||
|
"OAuth2 authentication failed",
|
||||||
|
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or missing authentication credentials",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
|
||||||
|
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
if not self.introspection_url:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth2 introspection not configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
str(self.introspection_url),
|
||||||
|
data={"token": token},
|
||||||
|
auth=(
|
||||||
|
self.introspection_client_id or "",
|
||||||
|
self.introspection_client_secret.get_secret_value()
|
||||||
|
if self.introspection_client_secret
|
||||||
|
else "",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
introspection_result = response.json()
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(
|
||||||
|
"OAuth2 introspection failed",
|
||||||
|
extra={"reason": "http_error", "status_code": e.response.status_code},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Token introspection service unavailable",
|
||||||
|
) from None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"OAuth2 introspection failed",
|
||||||
|
extra={"reason": "unexpected_error", "error": str(e)},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Token introspection failed",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
if not introspection_result.get("active", False):
|
||||||
|
logger.debug(
|
||||||
|
"OAuth2 authentication failed",
|
||||||
|
extra={"reason": "token_not_active", "scheme": "oauth2"},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token is not active",
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthenticatedUser(
|
||||||
|
token=token,
|
||||||
|
scheme="oauth2",
|
||||||
|
claims=introspection_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_security_scheme(self) -> OAuth2SecurityScheme:
|
||||||
|
"""Generate OAuth2SecurityScheme for AgentCard declaration.
|
||||||
|
|
||||||
|
Creates an OAuth2SecurityScheme with appropriate flows based on
|
||||||
|
the configured URLs. Includes client_credentials flow if token_url
|
||||||
|
is set, and authorization_code flow if authorization_url is set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
|
||||||
|
"""
|
||||||
|
from a2a.types import (
|
||||||
|
AuthorizationCodeOAuthFlow,
|
||||||
|
ClientCredentialsOAuthFlow,
|
||||||
|
OAuth2SecurityScheme,
|
||||||
|
OAuthFlows,
|
||||||
|
)
|
||||||
|
|
||||||
|
client_credentials = None
|
||||||
|
authorization_code = None
|
||||||
|
|
||||||
|
if self.token_url:
|
||||||
|
client_credentials = ClientCredentialsOAuthFlow(
|
||||||
|
token_url=str(self.token_url),
|
||||||
|
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||||
|
scopes=self.scopes,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.authorization_url:
|
||||||
|
authorization_code = AuthorizationCodeOAuthFlow(
|
||||||
|
authorization_url=str(self.authorization_url),
|
||||||
|
token_url=str(self.token_url),
|
||||||
|
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||||
|
scopes=self.scopes,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OAuth2SecurityScheme(
|
||||||
|
flows=OAuthFlows(
|
||||||
|
client_credentials=client_credentials,
|
||||||
|
authorization_code=authorization_code,
|
||||||
|
),
|
||||||
|
description="OAuth2 authentication",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyServerAuth(ServerAuthScheme):
|
||||||
|
"""API Key authentication for A2A server.
|
||||||
|
|
||||||
|
Validates requests using an API key in a header, query parameter, or cookie.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: The name of the API key parameter (default: X-API-Key).
|
||||||
|
location: Where to look for the API key (header, query, or cookie).
|
||||||
|
api_key: The expected API key value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = Field(
|
||||||
|
default="X-API-Key",
|
||||||
|
description="Name of the API key parameter",
|
||||||
|
)
|
||||||
|
location: Literal["header", "query", "cookie"] = Field(
|
||||||
|
default="header",
|
||||||
|
description="Where to look for the API key",
|
||||||
|
)
|
||||||
|
api_key: CoercedSecretStr = Field(
|
||||||
|
description="Expected API key value",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||||
|
"""Authenticate using API key comparison.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The API key to authenticate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthenticatedUser on successful authentication.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If authentication fails.
|
||||||
|
"""
|
||||||
|
if token != self.api_key.get_secret_value():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid API key",
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthenticatedUser(
|
||||||
|
token=token,
|
||||||
|
scheme="api_key",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MTLSServerAuth(ServerAuthScheme):
|
||||||
|
"""Mutual TLS authentication marker for AgentCard declaration.
|
||||||
|
|
||||||
|
This scheme is primarily for AgentCard security_schemes declaration.
|
||||||
|
Actual mTLS verification happens at the TLS/transport layer, not
|
||||||
|
at the application layer via token validation.
|
||||||
|
|
||||||
|
When configured, this signals to clients that the server requires
|
||||||
|
client certificates for authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
description: str = Field(
|
||||||
|
default="Mutual TLS certificate authentication",
|
||||||
|
description="Description for the security scheme",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||||
|
"""Return authenticated user for mTLS.
|
||||||
|
|
||||||
|
mTLS verification happens at the transport layer before this is called.
|
||||||
|
If we reach this point, the TLS handshake with client cert succeeded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: Certificate subject or identifier (from TLS layer).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuthenticatedUser indicating mTLS authentication.
|
||||||
|
"""
|
||||||
|
return AuthenticatedUser(
|
||||||
|
token=token or "mtls-verified",
|
||||||
|
scheme="mtls",
|
||||||
|
)
|
||||||
@@ -6,8 +6,10 @@ OAuth2, API keys, and HTTP authentication methods.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Callable, MutableMapping
|
from collections.abc import Awaitable, Callable, MutableMapping
|
||||||
|
import hashlib
|
||||||
import re
|
import re
|
||||||
from typing import Final
|
import threading
|
||||||
|
from typing import Final, Literal, cast
|
||||||
|
|
||||||
from a2a.client.errors import A2AClientHTTPError
|
from a2a.client.errors import A2AClientHTTPError
|
||||||
from a2a.types import (
|
from a2a.types import (
|
||||||
@@ -18,10 +20,10 @@ from a2a.types import (
|
|||||||
)
|
)
|
||||||
from httpx import AsyncClient, Response
|
from httpx import AsyncClient, Response
|
||||||
|
|
||||||
from crewai.a2a.auth.schemas import (
|
from crewai.a2a.auth.client_schemes import (
|
||||||
APIKeyAuth,
|
APIKeyAuth,
|
||||||
AuthScheme,
|
|
||||||
BearerTokenAuth,
|
BearerTokenAuth,
|
||||||
|
ClientAuthScheme,
|
||||||
HTTPBasicAuth,
|
HTTPBasicAuth,
|
||||||
HTTPDigestAuth,
|
HTTPDigestAuth,
|
||||||
OAuth2AuthorizationCode,
|
OAuth2AuthorizationCode,
|
||||||
@@ -29,12 +31,44 @@ from crewai.a2a.auth.schemas import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_auth_store: dict[int, AuthScheme | None] = {}
|
class _AuthStore:
|
||||||
|
"""Store for authentication schemes with safe concurrent access."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._store: dict[str, ClientAuthScheme | None] = {}
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_key(auth_type: str, auth_data: str) -> str:
|
||||||
|
"""Compute a collision-resistant key using SHA-256."""
|
||||||
|
content = f"{auth_type}:{auth_data}"
|
||||||
|
return hashlib.sha256(content.encode()).hexdigest()
|
||||||
|
|
||||||
|
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
|
||||||
|
"""Store an auth scheme."""
|
||||||
|
with self._lock:
|
||||||
|
self._store[key] = auth
|
||||||
|
|
||||||
|
def get(self, key: str) -> ClientAuthScheme | None:
|
||||||
|
"""Retrieve an auth scheme by key."""
|
||||||
|
with self._lock:
|
||||||
|
return self._store.get(key)
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._store[key] = value
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> ClientAuthScheme | None:
|
||||||
|
with self._lock:
|
||||||
|
return self._store[key]
|
||||||
|
|
||||||
|
|
||||||
|
_auth_store = _AuthStore()
|
||||||
|
|
||||||
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
|
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
|
||||||
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
|
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
|
||||||
|
|
||||||
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
|
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
|
||||||
OAuth2SecurityScheme: (
|
OAuth2SecurityScheme: (
|
||||||
OAuth2ClientCredentials,
|
OAuth2ClientCredentials,
|
||||||
OAuth2AuthorizationCode,
|
OAuth2AuthorizationCode,
|
||||||
@@ -43,7 +77,9 @@ _SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
|
|||||||
APIKeySecurityScheme: (APIKeyAuth,),
|
APIKeySecurityScheme: (APIKeyAuth,),
|
||||||
}
|
}
|
||||||
|
|
||||||
_HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
|
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
|
||||||
|
|
||||||
|
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
|
||||||
"basic": HTTPBasicAuth,
|
"basic": HTTPBasicAuth,
|
||||||
"digest": HTTPDigestAuth,
|
"digest": HTTPDigestAuth,
|
||||||
"bearer": BearerTokenAuth,
|
"bearer": BearerTokenAuth,
|
||||||
@@ -51,8 +87,8 @@ _HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
|
|||||||
|
|
||||||
|
|
||||||
def _raise_auth_mismatch(
|
def _raise_auth_mismatch(
|
||||||
expected_classes: type[AuthScheme] | tuple[type[AuthScheme], ...],
|
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
|
||||||
provided_auth: AuthScheme,
|
provided_auth: ClientAuthScheme,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Raise authentication mismatch error.
|
"""Raise authentication mismatch error.
|
||||||
|
|
||||||
@@ -111,7 +147,7 @@ def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
|
|||||||
|
|
||||||
|
|
||||||
def validate_auth_against_agent_card(
|
def validate_auth_against_agent_card(
|
||||||
agent_card: AgentCard, auth: AuthScheme | None
|
agent_card: AgentCard, auth: ClientAuthScheme | None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validate that provided auth matches AgentCard security requirements.
|
"""Validate that provided auth matches AgentCard security requirements.
|
||||||
|
|
||||||
@@ -145,7 +181,8 @@ def validate_auth_against_agent_card(
|
|||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(scheme, HTTPAuthSecurityScheme):
|
if isinstance(scheme, HTTPAuthSecurityScheme):
|
||||||
if required_class := _HTTP_SCHEME_MAPPING.get(scheme.scheme.lower()):
|
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
|
||||||
|
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
|
||||||
if not isinstance(auth, required_class):
|
if not isinstance(auth, required_class):
|
||||||
_raise_auth_mismatch(required_class, auth)
|
_raise_auth_mismatch(required_class, auth)
|
||||||
return
|
return
|
||||||
@@ -156,7 +193,7 @@ def validate_auth_against_agent_card(
|
|||||||
|
|
||||||
async def retry_on_401(
|
async def retry_on_401(
|
||||||
request_func: Callable[[], Awaitable[Response]],
|
request_func: Callable[[], Awaitable[Response]],
|
||||||
auth_scheme: AuthScheme | None,
|
auth_scheme: ClientAuthScheme | None,
|
||||||
client: AsyncClient,
|
client: AsyncClient,
|
||||||
headers: MutableMapping[str, str],
|
headers: MutableMapping[str, str],
|
||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
|
|||||||
@@ -1,4 +1,37 @@
|
|||||||
"""A2A Protocol Extensions for CrewAI.
|
"""A2A Protocol Extensions for CrewAI.
|
||||||
|
|
||||||
This module contains extensions to the A2A (Agent-to-Agent) protocol.
|
This module contains extensions to the A2A (Agent-to-Agent) protocol.
|
||||||
|
|
||||||
|
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
|
||||||
|
processes requests and responses during delegation to remote agents. These provide
|
||||||
|
hooks for tool injection, prompt augmentation, and response processing.
|
||||||
|
|
||||||
|
**Server-side extensions** (ServerExtension) allow agents to offer additional
|
||||||
|
functionality beyond the core A2A specification. Clients activate extensions
|
||||||
|
via the X-A2A-Extensions header.
|
||||||
|
|
||||||
|
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from crewai.a2a.extensions.base import (
|
||||||
|
A2AExtension,
|
||||||
|
ConversationState,
|
||||||
|
ExtensionRegistry,
|
||||||
|
ValidatedA2AExtension,
|
||||||
|
)
|
||||||
|
from crewai.a2a.extensions.server import (
|
||||||
|
ExtensionContext,
|
||||||
|
ServerExtension,
|
||||||
|
ServerExtensionRegistry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"A2AExtension",
|
||||||
|
"ConversationState",
|
||||||
|
"ExtensionContext",
|
||||||
|
"ExtensionRegistry",
|
||||||
|
"ServerExtension",
|
||||||
|
"ServerExtensionRegistry",
|
||||||
|
"ValidatedA2AExtension",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
"""Base extension interface for A2A wrapper integrations.
|
"""Base extension interface for CrewAI A2A wrapper processing hooks.
|
||||||
|
|
||||||
This module defines the protocol for extending A2A wrapper functionality
|
This module defines the protocol for extending CrewAI's A2A wrapper functionality
|
||||||
with custom logic for conversation processing, prompt augmentation, and
|
with custom logic for tool injection, prompt augmentation, and response processing.
|
||||||
agent response handling.
|
|
||||||
|
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
|
||||||
|
A2A protocol extensions are capability declarations using AgentExtension objects
|
||||||
|
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
|
||||||
|
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Protocol
|
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from pydantic import BeforeValidator
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -17,6 +23,20 @@ if TYPE_CHECKING:
|
|||||||
from crewai.agent.core import Agent
|
from crewai.agent.core import Agent
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_a2a_extension(v: Any) -> Any:
|
||||||
|
"""Validate that value implements A2AExtension protocol."""
|
||||||
|
if not isinstance(v, A2AExtension):
|
||||||
|
raise ValueError(
|
||||||
|
f"Value must implement A2AExtension protocol. "
|
||||||
|
f"Got {type(v).__name__} which is missing required methods."
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class ConversationState(Protocol):
|
class ConversationState(Protocol):
|
||||||
"""Protocol for extension-specific conversation state.
|
"""Protocol for extension-specific conversation state.
|
||||||
|
|
||||||
@@ -33,11 +53,36 @@ class ConversationState(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class A2AExtension(Protocol):
|
class A2AExtension(Protocol):
|
||||||
"""Protocol for A2A wrapper extensions.
|
"""Protocol for A2A wrapper extensions.
|
||||||
|
|
||||||
Extensions can implement this protocol to inject custom logic into
|
Extensions can implement this protocol to inject custom logic into
|
||||||
the A2A conversation flow at various integration points.
|
the A2A conversation flow at various integration points.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class MyExtension:
|
||||||
|
def inject_tools(self, agent: Agent) -> None:
|
||||||
|
# Add custom tools to the agent
|
||||||
|
pass
|
||||||
|
|
||||||
|
def extract_state_from_history(
|
||||||
|
self, conversation_history: Sequence[Message]
|
||||||
|
) -> ConversationState | None:
|
||||||
|
# Extract state from conversation
|
||||||
|
return None
|
||||||
|
|
||||||
|
def augment_prompt(
|
||||||
|
self, base_prompt: str, conversation_state: ConversationState | None
|
||||||
|
) -> str:
|
||||||
|
# Add custom instructions
|
||||||
|
return base_prompt
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self, agent_response: Any, conversation_state: ConversationState | None
|
||||||
|
) -> Any:
|
||||||
|
# Modify response if needed
|
||||||
|
return agent_response
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def inject_tools(self, agent: Agent) -> None:
|
def inject_tools(self, agent: Agent) -> None:
|
||||||
|
|||||||
@@ -1,34 +1,170 @@
|
|||||||
"""Extension registry factory for A2A configurations.
|
"""A2A Protocol extension utilities.
|
||||||
|
|
||||||
This module provides utilities for creating extension registries from A2A configurations.
|
This module provides utilities for working with A2A protocol extensions as
|
||||||
|
defined in the A2A specification. Extensions are capability declarations in
|
||||||
|
AgentCard.capabilities.extensions using AgentExtension objects, activated
|
||||||
|
via the X-A2A-Extensions HTTP header.
|
||||||
|
|
||||||
|
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import Any
|
||||||
|
|
||||||
|
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
|
||||||
|
from a2a.extensions.common import (
|
||||||
|
HTTP_EXTENSION_HEADER,
|
||||||
|
)
|
||||||
|
from a2a.types import AgentCard, AgentExtension
|
||||||
|
|
||||||
|
from crewai.a2a.config import A2AClientConfig, A2AConfig
|
||||||
from crewai.a2a.extensions.base import ExtensionRegistry
|
from crewai.a2a.extensions.base import ExtensionRegistry
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
def get_extensions_from_config(
|
||||||
from crewai.a2a.config import A2AConfig
|
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Extract extension URIs from A2A configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a2a_config: A2A configuration (single or list).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deduplicated list of extension URIs from all configs.
|
||||||
|
"""
|
||||||
|
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||||
|
seen: set[str] = set()
|
||||||
|
result: list[str] = []
|
||||||
|
|
||||||
|
for config in configs:
|
||||||
|
if not isinstance(config, A2AClientConfig):
|
||||||
|
continue
|
||||||
|
for uri in config.extensions:
|
||||||
|
if uri not in seen:
|
||||||
|
seen.add(uri)
|
||||||
|
result.append(uri)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionsMiddleware(ClientCallInterceptor):
|
||||||
|
"""Middleware to add X-A2A-Extensions header to requests.
|
||||||
|
|
||||||
|
This middleware adds the extensions header to all outgoing requests,
|
||||||
|
declaring which A2A protocol extensions the client supports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, extensions: list[str]) -> None:
|
||||||
|
"""Initialize with extension URIs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extensions: List of extension URIs the client supports.
|
||||||
|
"""
|
||||||
|
self._extensions = extensions
|
||||||
|
|
||||||
|
async def intercept(
|
||||||
|
self,
|
||||||
|
method_name: str,
|
||||||
|
request_payload: dict[str, Any],
|
||||||
|
http_kwargs: dict[str, Any],
|
||||||
|
agent_card: AgentCard | None,
|
||||||
|
context: ClientCallContext | None,
|
||||||
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||||
|
"""Add extensions header to the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method_name: The A2A method being called.
|
||||||
|
request_payload: The JSON-RPC request payload.
|
||||||
|
http_kwargs: HTTP request kwargs (headers, etc).
|
||||||
|
agent_card: The target agent's card.
|
||||||
|
context: Optional call context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (request_payload, modified_http_kwargs).
|
||||||
|
"""
|
||||||
|
if self._extensions:
|
||||||
|
headers = http_kwargs.setdefault("headers", {})
|
||||||
|
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
|
||||||
|
return request_payload, http_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def validate_required_extensions(
|
||||||
|
agent_card: AgentCard,
|
||||||
|
client_extensions: list[str] | None,
|
||||||
|
) -> list[AgentExtension]:
|
||||||
|
"""Validate that client supports all required extensions from agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_card: The agent's card with declared extensions.
|
||||||
|
client_extensions: Extension URIs the client supports.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of unsupported required extensions.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
None - returns list of unsupported extensions for caller to handle.
|
||||||
|
"""
|
||||||
|
unsupported: list[AgentExtension] = []
|
||||||
|
client_set = set(client_extensions or [])
|
||||||
|
|
||||||
|
if not agent_card.capabilities or not agent_card.capabilities.extensions:
|
||||||
|
return unsupported
|
||||||
|
|
||||||
|
unsupported.extend(
|
||||||
|
ext
|
||||||
|
for ext in agent_card.capabilities.extensions
|
||||||
|
if ext.required and ext.uri not in client_set
|
||||||
|
)
|
||||||
|
|
||||||
|
return unsupported
|
||||||
|
|
||||||
|
|
||||||
def create_extension_registry_from_config(
|
def create_extension_registry_from_config(
|
||||||
a2a_config: list[A2AConfig] | A2AConfig,
|
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||||
) -> ExtensionRegistry:
|
) -> ExtensionRegistry:
|
||||||
"""Create an extension registry from A2A configuration.
|
"""Create an extension registry from A2A client configuration.
|
||||||
|
|
||||||
|
Extracts client_extensions from each A2AClientConfig and registers them
|
||||||
|
with the ExtensionRegistry. These extensions provide CrewAI-specific
|
||||||
|
processing hooks (tool injection, prompt augmentation, response processing).
|
||||||
|
|
||||||
|
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
|
||||||
|
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a2a_config: A2A configuration (single or list)
|
a2a_config: A2A configuration (single or list).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configured extension registry with all applicable extensions
|
Extension registry with all client_extensions registered.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class LoggingExtension:
|
||||||
|
def inject_tools(self, agent): pass
|
||||||
|
def extract_state_from_history(self, history): return None
|
||||||
|
def augment_prompt(self, prompt, state): return prompt
|
||||||
|
def process_response(self, response, state):
|
||||||
|
print(f"Response: {response}")
|
||||||
|
return response
|
||||||
|
|
||||||
|
config = A2AClientConfig(
|
||||||
|
endpoint="https://agent.example.com",
|
||||||
|
client_extensions=[LoggingExtension()],
|
||||||
|
)
|
||||||
|
registry = create_extension_registry_from_config(config)
|
||||||
"""
|
"""
|
||||||
registry = ExtensionRegistry()
|
registry = ExtensionRegistry()
|
||||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||||
|
|
||||||
for _ in configs:
|
seen: set[int] = set()
|
||||||
pass
|
|
||||||
|
for config in configs:
|
||||||
|
if isinstance(config, (A2AConfig, A2AClientConfig)):
|
||||||
|
client_exts = getattr(config, "client_extensions", [])
|
||||||
|
for extension in client_exts:
|
||||||
|
ext_id = id(extension)
|
||||||
|
if ext_id not in seen:
|
||||||
|
seen.add(ext_id)
|
||||||
|
registry.register(extension)
|
||||||
|
|
||||||
return registry
|
return registry
|
||||||
|
|||||||
305
lib/crewai/src/crewai/a2a/extensions/server.py
Normal file
305
lib/crewai/src/crewai/a2a/extensions/server.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
"""A2A protocol server extensions for CrewAI agents.
|
||||||
|
|
||||||
|
This module provides the base class and context for implementing A2A protocol
|
||||||
|
extensions on the server side. Extensions allow agents to offer additional
|
||||||
|
functionality beyond the core A2A specification.
|
||||||
|
|
||||||
|
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Annotated, Any
|
||||||
|
|
||||||
|
from a2a.types import AgentExtension
|
||||||
|
from pydantic_core import CoreSchema, core_schema
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from a2a.server.context import ServerCallContext
|
||||||
|
from pydantic import GetCoreSchemaHandler
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExtensionContext:
|
||||||
|
"""Context passed to extension hooks during request processing.
|
||||||
|
|
||||||
|
Provides access to request metadata, client extensions, and shared state
|
||||||
|
that extensions can read from and write to.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
metadata: Request metadata dict, includes extension-namespaced keys.
|
||||||
|
client_extensions: Set of extension URIs the client declared support for.
|
||||||
|
state: Mutable dict for extensions to share data during request lifecycle.
|
||||||
|
server_context: The underlying A2A server call context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata: dict[str, Any]
|
||||||
|
client_extensions: set[str]
|
||||||
|
state: dict[str, Any] = field(default_factory=dict)
|
||||||
|
server_context: ServerCallContext | None = None
|
||||||
|
|
||||||
|
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
|
||||||
|
"""Get extension-specific metadata value.
|
||||||
|
|
||||||
|
Extension metadata uses namespaced keys in the format:
|
||||||
|
"{extension_uri}/{key}"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri: The extension URI.
|
||||||
|
key: The metadata key within the extension namespace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The metadata value, or None if not present.
|
||||||
|
"""
|
||||||
|
full_key = f"{uri}/{key}"
|
||||||
|
return self.metadata.get(full_key)
|
||||||
|
|
||||||
|
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
|
||||||
|
"""Set extension-specific metadata value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri: The extension URI.
|
||||||
|
key: The metadata key within the extension namespace.
|
||||||
|
value: The value to set.
|
||||||
|
"""
|
||||||
|
full_key = f"{uri}/{key}"
|
||||||
|
self.metadata[full_key] = value
|
||||||
|
|
||||||
|
|
||||||
|
class ServerExtension(ABC):
|
||||||
|
"""Base class for A2A protocol server extensions.
|
||||||
|
|
||||||
|
Subclass this to create custom extensions that modify agent behavior
|
||||||
|
when clients activate them. Extensions are identified by URI and can
|
||||||
|
be marked as required.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class SamplingExtension(ServerExtension):
|
||||||
|
uri = "urn:crewai:ext:sampling/v1"
|
||||||
|
required = True
|
||||||
|
|
||||||
|
def __init__(self, max_tokens: int = 4096):
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def params(self) -> dict[str, Any]:
|
||||||
|
return {"max_tokens": self.max_tokens}
|
||||||
|
|
||||||
|
async def on_request(self, context: ExtensionContext) -> None:
|
||||||
|
limit = context.get_extension_metadata(self.uri, "limit")
|
||||||
|
if limit:
|
||||||
|
context.state["token_limit"] = int(limit)
|
||||||
|
|
||||||
|
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||||
|
return result
|
||||||
|
"""
|
||||||
|
|
||||||
|
uri: Annotated[str, "Extension URI identifier. Must be unique."]
|
||||||
|
required: Annotated[bool, "Whether clients must support this extension."] = False
|
||||||
|
description: Annotated[
|
||||||
|
str | None, "Human-readable description of the extension."
|
||||||
|
] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(
|
||||||
|
cls,
|
||||||
|
_source_type: Any,
|
||||||
|
_handler: GetCoreSchemaHandler,
|
||||||
|
) -> CoreSchema:
|
||||||
|
"""Tell Pydantic how to validate ServerExtension instances."""
|
||||||
|
return core_schema.is_instance_schema(cls)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def params(self) -> dict[str, Any] | None:
|
||||||
|
"""Extension parameters to advertise in AgentCard.
|
||||||
|
|
||||||
|
Override this property to expose configuration that clients can read.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict of parameter names to values, or None.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def agent_extension(self) -> AgentExtension:
|
||||||
|
"""Generate the AgentExtension object for the AgentCard.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentExtension with this extension's URI, required flag, and params.
|
||||||
|
"""
|
||||||
|
return AgentExtension(
|
||||||
|
uri=self.uri,
|
||||||
|
required=self.required if self.required else None,
|
||||||
|
description=self.description,
|
||||||
|
params=self.params,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_active(self, context: ExtensionContext) -> bool:
|
||||||
|
"""Check if this extension is active for the current request.
|
||||||
|
|
||||||
|
An extension is active if the client declared support for it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The extension context for the current request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the client supports this extension.
|
||||||
|
"""
|
||||||
|
return self.uri in context.client_extensions
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_request(self, context: ExtensionContext) -> None:
|
||||||
|
"""Called before agent execution if extension is active.
|
||||||
|
|
||||||
|
Use this hook to:
|
||||||
|
- Read extension-specific metadata from the request
|
||||||
|
- Set up state for the execution
|
||||||
|
- Modify execution parameters via context.state
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The extension context with request metadata and state.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||||
|
"""Called after agent execution if extension is active.
|
||||||
|
|
||||||
|
Use this hook to:
|
||||||
|
- Modify or enhance the result
|
||||||
|
- Add extension-specific metadata to the response
|
||||||
|
- Clean up any resources
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The extension context with request metadata and state.
|
||||||
|
result: The agent execution result.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result, potentially modified.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ServerExtensionRegistry:
|
||||||
|
"""Registry for managing server-side A2A protocol extensions.
|
||||||
|
|
||||||
|
Collects extensions and provides methods to generate AgentCapabilities
|
||||||
|
and invoke extension hooks during request processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
|
||||||
|
"""Initialize the registry with optional extensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extensions: Initial list of extensions to register.
|
||||||
|
"""
|
||||||
|
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
|
||||||
|
self._by_uri: dict[str, ServerExtension] = {
|
||||||
|
ext.uri: ext for ext in self._extensions
|
||||||
|
}
|
||||||
|
|
||||||
|
def register(self, extension: ServerExtension) -> None:
|
||||||
|
"""Register an extension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extension: The extension to register.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If an extension with the same URI is already registered.
|
||||||
|
"""
|
||||||
|
if extension.uri in self._by_uri:
|
||||||
|
raise ValueError(f"Extension already registered: {extension.uri}")
|
||||||
|
self._extensions.append(extension)
|
||||||
|
self._by_uri[extension.uri] = extension
|
||||||
|
|
||||||
|
def get_agent_extensions(self) -> list[AgentExtension]:
|
||||||
|
"""Get AgentExtension objects for all registered extensions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of AgentExtension objects for the AgentCard.
|
||||||
|
"""
|
||||||
|
return [ext.agent_extension() for ext in self._extensions]
|
||||||
|
|
||||||
|
def get_extension(self, uri: str) -> ServerExtension | None:
|
||||||
|
"""Get an extension by URI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri: The extension URI.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The extension, or None if not found.
|
||||||
|
"""
|
||||||
|
return self._by_uri.get(uri)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_context(
|
||||||
|
metadata: dict[str, Any],
|
||||||
|
client_extensions: set[str],
|
||||||
|
server_context: ServerCallContext | None = None,
|
||||||
|
) -> ExtensionContext:
|
||||||
|
"""Create an ExtensionContext for a request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: Request metadata dict.
|
||||||
|
client_extensions: Set of extension URIs from client.
|
||||||
|
server_context: Optional server call context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExtensionContext for use in hooks.
|
||||||
|
"""
|
||||||
|
return ExtensionContext(
|
||||||
|
metadata=metadata,
|
||||||
|
client_extensions=client_extensions,
|
||||||
|
server_context=server_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def invoke_on_request(self, context: ExtensionContext) -> None:
|
||||||
|
"""Invoke on_request hooks for all active extensions.
|
||||||
|
|
||||||
|
Tracks activated extensions and isolates errors from individual hooks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The extension context for the request.
|
||||||
|
"""
|
||||||
|
for extension in self._extensions:
|
||||||
|
if extension.is_active(context):
|
||||||
|
try:
|
||||||
|
await extension.on_request(context)
|
||||||
|
if context.server_context is not None:
|
||||||
|
context.server_context.activated_extensions.add(extension.uri)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Extension on_request hook failed",
|
||||||
|
extra={"extension": extension.uri},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||||
|
"""Invoke on_response hooks for all active extensions.
|
||||||
|
|
||||||
|
Isolates errors from individual hooks to prevent one failing extension
|
||||||
|
from breaking the entire response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The extension context for the request.
|
||||||
|
result: The agent execution result.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result after all extensions have processed it.
|
||||||
|
"""
|
||||||
|
processed = result
|
||||||
|
for extension in self._extensions:
|
||||||
|
if extension.is_active(context):
|
||||||
|
try:
|
||||||
|
processed = await extension.on_response(context, processed)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Extension on_response hook failed",
|
||||||
|
extra={"extension": extension.uri},
|
||||||
|
)
|
||||||
|
return processed
|
||||||
Reference in New Issue
Block a user