mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-30 02:28:13 +00:00
Compare commits
1 Commits
gl/feat/a2
...
devin/1769
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
78de2038d7 |
@@ -10,7 +10,7 @@ requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
# Core Dependencies
|
||||
"pydantic~=2.11.9",
|
||||
"openai~=1.83.0",
|
||||
"openai>=1.83.0,<2",
|
||||
"instructor>=1.3.3",
|
||||
# Text Processing
|
||||
"pdfplumber~=0.11.4",
|
||||
|
||||
@@ -1,36 +1,20 @@
|
||||
"""A2A authentication schemas."""
|
||||
|
||||
from crewai.a2a.auth.client_schemes import (
|
||||
from crewai.a2a.auth.schemas import (
|
||||
APIKeyAuth,
|
||||
AuthScheme,
|
||||
BearerTokenAuth,
|
||||
ClientAuthScheme,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
TLSConfig,
|
||||
)
|
||||
from crewai.a2a.auth.server_schemes import (
|
||||
AuthenticatedUser,
|
||||
OIDCAuth,
|
||||
ServerAuthScheme,
|
||||
SimpleTokenAuth,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"APIKeyAuth",
|
||||
"AuthScheme",
|
||||
"AuthenticatedUser",
|
||||
"BearerTokenAuth",
|
||||
"ClientAuthScheme",
|
||||
"HTTPBasicAuth",
|
||||
"HTTPDigestAuth",
|
||||
"OAuth2AuthorizationCode",
|
||||
"OAuth2ClientCredentials",
|
||||
"OIDCAuth",
|
||||
"ServerAuthScheme",
|
||||
"SimpleTokenAuth",
|
||||
"TLSConfig",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Authentication schemes for A2A protocol clients.
|
||||
"""Authentication schemes for A2A protocol agents.
|
||||
|
||||
Supported authentication methods:
|
||||
- Bearer tokens
|
||||
@@ -6,135 +6,24 @@ Supported authentication methods:
|
||||
- API Keys (header, query, cookie)
|
||||
- HTTP Basic authentication
|
||||
- HTTP Digest authentication
|
||||
- mTLS (mutual TLS) client certificate authentication
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Awaitable, Callable, MutableMapping
|
||||
from pathlib import Path
|
||||
import ssl
|
||||
import time
|
||||
from typing import TYPE_CHECKING, ClassVar, Literal
|
||||
from typing import Literal
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from httpx import DigestAuth
|
||||
from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr
|
||||
from typing_extensions import deprecated
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import grpc # type: ignore[import-untyped]
|
||||
|
||||
|
||||
class TLSConfig(BaseModel):
|
||||
"""TLS/mTLS configuration for secure client connections.
|
||||
|
||||
Supports mutual TLS (mTLS) where the client presents a certificate to the server,
|
||||
and standard TLS with custom CA verification.
|
||||
|
||||
Attributes:
|
||||
client_cert_path: Path to client certificate file (PEM format) for mTLS.
|
||||
client_key_path: Path to client private key file (PEM format) for mTLS.
|
||||
ca_cert_path: Path to CA certificate bundle for server verification.
|
||||
verify: Whether to verify server certificates. Set False only for development.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
client_cert_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to client certificate file (PEM format) for mTLS",
|
||||
)
|
||||
client_key_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to client private key file (PEM format) for mTLS",
|
||||
)
|
||||
ca_cert_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to CA certificate bundle for server verification",
|
||||
)
|
||||
verify: bool = Field(
|
||||
default=True,
|
||||
description="Whether to verify server certificates. Set False only for development.",
|
||||
)
|
||||
|
||||
def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
|
||||
"""Build SSL context for httpx client.
|
||||
|
||||
Returns:
|
||||
SSL context if certificates configured, True for default verification,
|
||||
False if verification disabled, or path to CA bundle.
|
||||
"""
|
||||
if not self.verify:
|
||||
return False
|
||||
|
||||
if self.client_cert_path and self.client_key_path:
|
||||
context = ssl.create_default_context()
|
||||
|
||||
if self.ca_cert_path:
|
||||
context.load_verify_locations(cafile=str(self.ca_cert_path))
|
||||
|
||||
context.load_cert_chain(
|
||||
certfile=str(self.client_cert_path),
|
||||
keyfile=str(self.client_key_path),
|
||||
)
|
||||
return context
|
||||
|
||||
if self.ca_cert_path:
|
||||
return str(self.ca_cert_path)
|
||||
|
||||
return True
|
||||
|
||||
def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported]
|
||||
"""Build gRPC channel credentials for secure connections.
|
||||
|
||||
Returns:
|
||||
gRPC SSL credentials if certificates configured, None otherwise.
|
||||
"""
|
||||
try:
|
||||
import grpc
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
if not self.verify and not self.client_cert_path:
|
||||
return None
|
||||
|
||||
root_certs: bytes | None = None
|
||||
private_key: bytes | None = None
|
||||
certificate_chain: bytes | None = None
|
||||
|
||||
if self.ca_cert_path:
|
||||
root_certs = Path(self.ca_cert_path).read_bytes()
|
||||
|
||||
if self.client_cert_path and self.client_key_path:
|
||||
private_key = Path(self.client_key_path).read_bytes()
|
||||
certificate_chain = Path(self.client_cert_path).read_bytes()
|
||||
|
||||
return grpc.ssl_channel_credentials(
|
||||
root_certificates=root_certs,
|
||||
private_key=private_key,
|
||||
certificate_chain=certificate_chain,
|
||||
)
|
||||
|
||||
|
||||
class ClientAuthScheme(ABC, BaseModel):
|
||||
"""Base class for client-side authentication schemes.
|
||||
|
||||
Client auth schemes apply credentials to outgoing requests.
|
||||
|
||||
Attributes:
|
||||
tls: Optional TLS/mTLS configuration for secure connections.
|
||||
"""
|
||||
|
||||
tls: TLSConfig | None = Field(
|
||||
default=None,
|
||||
description="TLS/mTLS configuration for secure connections",
|
||||
)
|
||||
class AuthScheme(ABC, BaseModel):
|
||||
"""Base class for authentication schemes."""
|
||||
|
||||
@abstractmethod
|
||||
async def apply_auth(
|
||||
@@ -152,12 +41,7 @@ class ClientAuthScheme(ABC, BaseModel):
|
||||
...
|
||||
|
||||
|
||||
@deprecated("Use ClientAuthScheme instead", category=FutureWarning)
|
||||
class AuthScheme(ClientAuthScheme):
|
||||
"""Deprecated: Use ClientAuthScheme instead."""
|
||||
|
||||
|
||||
class BearerTokenAuth(ClientAuthScheme):
|
||||
class BearerTokenAuth(AuthScheme):
|
||||
"""Bearer token authentication (Authorization: Bearer <token>).
|
||||
|
||||
Attributes:
|
||||
@@ -182,7 +66,7 @@ class BearerTokenAuth(ClientAuthScheme):
|
||||
return headers
|
||||
|
||||
|
||||
class HTTPBasicAuth(ClientAuthScheme):
|
||||
class HTTPBasicAuth(AuthScheme):
|
||||
"""HTTP Basic authentication.
|
||||
|
||||
Attributes:
|
||||
@@ -211,7 +95,7 @@ class HTTPBasicAuth(ClientAuthScheme):
|
||||
return headers
|
||||
|
||||
|
||||
class HTTPDigestAuth(ClientAuthScheme):
|
||||
class HTTPDigestAuth(AuthScheme):
|
||||
"""HTTP Digest authentication.
|
||||
|
||||
Note: Uses httpx-auth library for digest implementation.
|
||||
@@ -224,8 +108,6 @@ class HTTPDigestAuth(ClientAuthScheme):
|
||||
username: str = Field(description="Username")
|
||||
password: str = Field(description="Password")
|
||||
|
||||
_configured_client_id: int | None = PrivateAttr(default=None)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
@@ -243,21 +125,13 @@ class HTTPDigestAuth(ClientAuthScheme):
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client with Digest auth.
|
||||
|
||||
Idempotent: Only configures the client once. Subsequent calls on the same
|
||||
client instance are no-ops to prevent overwriting auth configuration.
|
||||
|
||||
Args:
|
||||
client: HTTP client to configure with Digest authentication.
|
||||
"""
|
||||
client_id = id(client)
|
||||
if self._configured_client_id == client_id:
|
||||
return
|
||||
|
||||
client.auth = DigestAuth(self.username, self.password)
|
||||
self._configured_client_id = client_id
|
||||
|
||||
|
||||
class APIKeyAuth(ClientAuthScheme):
|
||||
class APIKeyAuth(AuthScheme):
|
||||
"""API Key authentication (header, query, or cookie).
|
||||
|
||||
Attributes:
|
||||
@@ -272,8 +146,6 @@ class APIKeyAuth(ClientAuthScheme):
|
||||
)
|
||||
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
|
||||
|
||||
_configured_client_ids: set[int] = PrivateAttr(default_factory=set)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
@@ -295,31 +167,21 @@ class APIKeyAuth(ClientAuthScheme):
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client for query param API keys.
|
||||
|
||||
Idempotent: Only adds the request hook once per client instance.
|
||||
Subsequent calls on the same client are no-ops to prevent hook accumulation.
|
||||
|
||||
Args:
|
||||
client: HTTP client to configure with query param API key hook.
|
||||
"""
|
||||
if self.location == "query":
|
||||
client_id = id(client)
|
||||
if client_id in self._configured_client_ids:
|
||||
return
|
||||
|
||||
async def _add_api_key_param(request: httpx.Request) -> None:
|
||||
url = httpx.URL(request.url)
|
||||
request.url = url.copy_add_param(self.name, self.api_key)
|
||||
|
||||
client.event_hooks["request"].append(_add_api_key_param)
|
||||
self._configured_client_ids.add(client_id)
|
||||
|
||||
|
||||
class OAuth2ClientCredentials(ClientAuthScheme):
|
||||
class OAuth2ClientCredentials(AuthScheme):
|
||||
"""OAuth2 Client Credentials flow authentication.
|
||||
|
||||
Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches
|
||||
when multiple requests share the same auth instance.
|
||||
|
||||
Attributes:
|
||||
token_url: OAuth2 token endpoint URL.
|
||||
client_id: OAuth2 client identifier.
|
||||
@@ -336,17 +198,12 @@ class OAuth2ClientCredentials(ClientAuthScheme):
|
||||
|
||||
_access_token: str | None = PrivateAttr(default=None)
|
||||
_token_expires_at: float | None = PrivateAttr(default=None)
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header.
|
||||
|
||||
Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time,
|
||||
preventing race conditions when multiple concurrent requests use the same
|
||||
auth instance.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token requests.
|
||||
headers: Current request headers.
|
||||
@@ -359,13 +216,7 @@ class OAuth2ClientCredentials(ClientAuthScheme):
|
||||
or self._token_expires_at is None
|
||||
or time.time() >= self._token_expires_at
|
||||
):
|
||||
async with self._lock:
|
||||
if (
|
||||
self._access_token is None
|
||||
or self._token_expires_at is None
|
||||
or time.time() >= self._token_expires_at
|
||||
):
|
||||
await self._fetch_token(client)
|
||||
await self._fetch_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
@@ -399,11 +250,9 @@ class OAuth2ClientCredentials(ClientAuthScheme):
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(ClientAuthScheme):
|
||||
class OAuth2AuthorizationCode(AuthScheme):
|
||||
"""OAuth2 Authorization Code flow authentication.
|
||||
|
||||
Thread-safe implementation with asyncio.Lock to prevent concurrent token operations.
|
||||
|
||||
Note: Requires interactive authorization.
|
||||
|
||||
Attributes:
|
||||
@@ -430,7 +279,6 @@ class OAuth2AuthorizationCode(ClientAuthScheme):
|
||||
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
|
||||
default=None
|
||||
)
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
def set_authorization_callback(
|
||||
self, callback: Callable[[str], Awaitable[str]] | None
|
||||
@@ -447,9 +295,6 @@ class OAuth2AuthorizationCode(ClientAuthScheme):
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header.
|
||||
|
||||
Uses asyncio.Lock to ensure only one coroutine handles token operations
|
||||
(initial fetch or refresh) at a time.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token requests.
|
||||
headers: Current request headers.
|
||||
@@ -460,17 +305,14 @@ class OAuth2AuthorizationCode(ClientAuthScheme):
|
||||
Raises:
|
||||
ValueError: If authorization callback is not set.
|
||||
"""
|
||||
|
||||
if self._access_token is None:
|
||||
if self._authorization_callback is None:
|
||||
msg = "Authorization callback not set. Use set_authorization_callback()"
|
||||
raise ValueError(msg)
|
||||
async with self._lock:
|
||||
if self._access_token is None:
|
||||
await self._fetch_initial_token(client)
|
||||
await self._fetch_initial_token(client)
|
||||
elif self._token_expires_at and time.time() >= self._token_expires_at:
|
||||
async with self._lock:
|
||||
if self._token_expires_at and time.time() >= self._token_expires_at:
|
||||
await self._refresh_access_token(client)
|
||||
await self._refresh_access_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
@@ -1,739 +0,0 @@
|
||||
"""Server-side authentication schemes for A2A protocol.
|
||||
|
||||
These schemes validate incoming requests to A2A server endpoints.
|
||||
|
||||
Supported authentication methods:
|
||||
- Simple token validation with static bearer tokens
|
||||
- OpenID Connect with JWT validation using JWKS
|
||||
- OAuth2 with JWT validation or token introspection
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
|
||||
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
Field,
|
||||
HttpUrl,
|
||||
PrivateAttr,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import OAuth2SecurityScheme
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from fastapi import HTTPException, status as http_status
|
||||
|
||||
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
except ImportError:
|
||||
|
||||
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
|
||||
"""Fallback HTTPException when FastAPI is not installed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
detail: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
self.headers = headers
|
||||
super().__init__(detail)
|
||||
|
||||
HTTP_401_UNAUTHORIZED = 401
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = 500
|
||||
HTTP_503_SERVICE_UNAVAILABLE = 503
|
||||
|
||||
|
||||
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
|
||||
"""Coerce string to SecretStr."""
|
||||
if v is None or isinstance(v, SecretStr):
|
||||
return v
|
||||
return SecretStr(v)
|
||||
|
||||
|
||||
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
|
||||
|
||||
JWTAlgorithm = Literal[
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthenticatedUser:
|
||||
"""Result of successful authentication.
|
||||
|
||||
Attributes:
|
||||
token: The original token that was validated.
|
||||
scheme: Name of the authentication scheme used.
|
||||
claims: JWT claims from OIDC or OAuth2 authentication.
|
||||
"""
|
||||
|
||||
token: str
|
||||
scheme: str
|
||||
claims: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ServerAuthScheme(ABC, BaseModel):
|
||||
"""Base class for server-side authentication schemes.
|
||||
|
||||
Each scheme validates incoming requests and returns an AuthenticatedUser
|
||||
on success, or raises HTTPException on failure.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate the provided token.
|
||||
|
||||
Args:
|
||||
token: The bearer token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SimpleTokenAuth(ServerAuthScheme):
|
||||
"""Simple bearer token authentication.
|
||||
|
||||
Validates tokens against a configured static token or AUTH_TOKEN env var.
|
||||
|
||||
Attributes:
|
||||
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
|
||||
"""
|
||||
|
||||
token: CoercedSecretStr | None = Field(
|
||||
default=None,
|
||||
description="Expected token. Falls back to AUTH_TOKEN env var.",
|
||||
)
|
||||
|
||||
def _get_expected_token(self) -> str | None:
|
||||
"""Get the expected token value."""
|
||||
if self.token:
|
||||
return self.token.get_secret_value()
|
||||
return os.environ.get("AUTH_TOKEN")
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using simple token comparison.
|
||||
|
||||
Args:
|
||||
token: The bearer token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
expected = self._get_expected_token()
|
||||
|
||||
if expected is None:
|
||||
logger.warning(
|
||||
"Simple token authentication failed",
|
||||
extra={"reason": "no_token_configured"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication not configured",
|
||||
)
|
||||
|
||||
if token != expected:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="simple_token",
|
||||
)
|
||||
|
||||
|
||||
class OIDCAuth(ServerAuthScheme):
|
||||
"""OpenID Connect authentication.
|
||||
|
||||
Validates JWTs using JWKS with caching support via PyJWT.
|
||||
|
||||
Attributes:
|
||||
issuer: The OpenID Connect issuer URL.
|
||||
audience: The expected audience claim.
|
||||
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
|
||||
algorithms: List of allowed signing algorithms.
|
||||
required_claims: List of claims that must be present in the token.
|
||||
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||
clock_skew_seconds: Allowed clock skew for token validation.
|
||||
"""
|
||||
|
||||
issuer: HttpUrl = Field(
|
||||
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
|
||||
)
|
||||
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
|
||||
jwks_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="Explicit JWKS URL. Derived from issuer if not set.",
|
||||
)
|
||||
algorithms: list[str] = Field(
|
||||
default_factory=lambda: ["RS256"],
|
||||
description="List of allowed signing algorithms (RS256, ES256, etc.)",
|
||||
)
|
||||
required_claims: list[str] = Field(
|
||||
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
|
||||
description="List of claims that must be present in the token",
|
||||
)
|
||||
jwks_cache_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL for JWKS cache in seconds",
|
||||
ge=60,
|
||||
)
|
||||
clock_skew_seconds: float = Field(
|
||||
default=30.0,
|
||||
description="Allowed clock skew for token validation",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_jwk_client(self) -> Self:
|
||||
"""Initialize the JWK client after model creation."""
|
||||
jwks_url = (
|
||||
str(self.jwks_url)
|
||||
if self.jwks_url
|
||||
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
|
||||
)
|
||||
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
|
||||
return self
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OIDC JWT validation.
|
||||
|
||||
Args:
|
||||
token: The JWT to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if self._jwk_client is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OIDC not initialized",
|
||||
)
|
||||
|
||||
try:
|
||||
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=self.algorithms,
|
||||
audience=self.audience,
|
||||
issuer=str(self.issuer).rstrip("/"),
|
||||
leeway=self.clock_skew_seconds,
|
||||
options={
|
||||
"require": self.required_claims,
|
||||
},
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oidc",
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "token_expired", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
) from None
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_audience", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token audience",
|
||||
) from None
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_issuer", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token issuer",
|
||||
) from None
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Missing required claim: {e.claim}",
|
||||
) from None
|
||||
except jwt.PyJWKClientError as e:
|
||||
logger.error(
|
||||
"OIDC authentication failed",
|
||||
extra={
|
||||
"reason": "jwks_client_error",
|
||||
"error": str(e),
|
||||
"scheme": "oidc",
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Unable to fetch signing keys",
|
||||
) from None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
) from None
|
||||
|
||||
|
||||
class OAuth2ServerAuth(ServerAuthScheme):
|
||||
"""OAuth2 authentication for A2A server.
|
||||
|
||||
Declares OAuth2 security scheme in AgentCard and validates tokens using
|
||||
either JWKS for JWT tokens or token introspection for opaque tokens.
|
||||
|
||||
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
|
||||
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
|
||||
|
||||
Attributes:
|
||||
token_url: OAuth2 token endpoint URL for client_credentials flow.
|
||||
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
|
||||
refresh_url: Optional refresh token endpoint URL.
|
||||
scopes: Available OAuth2 scopes with descriptions.
|
||||
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
|
||||
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
|
||||
introspection_client_id: Client ID for introspection endpoint authentication.
|
||||
introspection_client_secret: Client secret for introspection endpoint.
|
||||
audience: Expected audience claim for JWT validation.
|
||||
issuer: Expected issuer claim for JWT validation.
|
||||
algorithms: Allowed JWT signing algorithms.
|
||||
required_claims: Claims that must be present in the token.
|
||||
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||
clock_skew_seconds: Allowed clock skew for token validation.
|
||||
"""
|
||||
|
||||
token_url: HttpUrl = Field(
|
||||
description="OAuth2 token endpoint URL",
|
||||
)
|
||||
authorization_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="OAuth2 authorization endpoint URL for authorization_code flow",
|
||||
)
|
||||
refresh_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="OAuth2 refresh token endpoint URL",
|
||||
)
|
||||
scopes: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Available OAuth2 scopes with descriptions",
|
||||
)
|
||||
jwks_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="JWKS URL for JWT validation. Required if not using introspection.",
|
||||
)
|
||||
introspection_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
|
||||
)
|
||||
introspection_client_id: str | None = Field(
|
||||
default=None,
|
||||
description="Client ID for introspection endpoint authentication",
|
||||
)
|
||||
introspection_client_secret: CoercedSecretStr | None = Field(
|
||||
default=None,
|
||||
description="Client secret for introspection endpoint authentication",
|
||||
)
|
||||
audience: str | None = Field(
|
||||
default=None,
|
||||
description="Expected audience claim for JWT validation",
|
||||
)
|
||||
issuer: str | None = Field(
|
||||
default=None,
|
||||
description="Expected issuer claim for JWT validation",
|
||||
)
|
||||
algorithms: list[str] = Field(
|
||||
default_factory=lambda: ["RS256"],
|
||||
description="Allowed JWT signing algorithms",
|
||||
)
|
||||
required_claims: list[str] = Field(
|
||||
default_factory=lambda: ["exp", "iat"],
|
||||
description="Claims that must be present in the token",
|
||||
)
|
||||
jwks_cache_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL for JWKS cache in seconds",
|
||||
ge=60,
|
||||
)
|
||||
clock_skew_seconds: float = Field(
|
||||
default=30.0,
|
||||
description="Allowed clock skew for token validation",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_and_init(self) -> Self:
|
||||
"""Validate configuration and initialize JWKS client if needed."""
|
||||
if not self.jwks_url and not self.introspection_url:
|
||||
raise ValueError(
|
||||
"Either jwks_url or introspection_url must be provided for token validation"
|
||||
)
|
||||
|
||||
if self.introspection_url:
|
||||
if not self.introspection_client_id or not self.introspection_client_secret:
|
||||
raise ValueError(
|
||||
"introspection_client_id and introspection_client_secret are required "
|
||||
"when using token introspection"
|
||||
)
|
||||
|
||||
if self.jwks_url:
|
||||
self._jwk_client = PyJWKClient(
|
||||
str(self.jwks_url), lifespan=self.jwks_cache_ttl
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OAuth2 token validation.
|
||||
|
||||
Uses JWKS validation if jwks_url is configured, otherwise falls back
|
||||
to token introspection.
|
||||
|
||||
Args:
|
||||
token: The OAuth2 access token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if self._jwk_client:
|
||||
return await self._authenticate_jwt(token)
|
||||
return await self._authenticate_introspection(token)
|
||||
|
||||
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using JWKS JWT validation."""
|
||||
if self._jwk_client is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth2 JWKS not initialized",
|
||||
)
|
||||
|
||||
try:
|
||||
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||
|
||||
decode_options: dict[str, Any] = {
|
||||
"require": self.required_claims,
|
||||
}
|
||||
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=self.algorithms,
|
||||
audience=self.audience,
|
||||
issuer=self.issuer,
|
||||
leeway=self.clock_skew_seconds,
|
||||
options=decode_options,
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oauth2",
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "token_expired", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
) from None
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_audience", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token audience",
|
||||
) from None
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token issuer",
|
||||
) from None
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Missing required claim: {e.claim}",
|
||||
) from None
|
||||
except jwt.PyJWKClientError as e:
|
||||
logger.error(
|
||||
"OAuth2 authentication failed",
|
||||
extra={
|
||||
"reason": "jwks_client_error",
|
||||
"error": str(e),
|
||||
"scheme": "oauth2",
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Unable to fetch signing keys",
|
||||
) from None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
) from None
|
||||
|
||||
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
|
||||
import httpx
|
||||
|
||||
if not self.introspection_url:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth2 introspection not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
str(self.introspection_url),
|
||||
data={"token": token},
|
||||
auth=(
|
||||
self.introspection_client_id or "",
|
||||
self.introspection_client_secret.get_secret_value()
|
||||
if self.introspection_client_secret
|
||||
else "",
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
introspection_result = response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
"OAuth2 introspection failed",
|
||||
extra={"reason": "http_error", "status_code": e.response.status_code},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token introspection service unavailable",
|
||||
) from None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"OAuth2 introspection failed",
|
||||
extra={"reason": "unexpected_error", "error": str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token introspection failed",
|
||||
) from None
|
||||
|
||||
if not introspection_result.get("active", False):
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "token_not_active", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token is not active",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oauth2",
|
||||
claims=introspection_result,
|
||||
)
|
||||
|
||||
def to_security_scheme(self) -> OAuth2SecurityScheme:
|
||||
"""Generate OAuth2SecurityScheme for AgentCard declaration.
|
||||
|
||||
Creates an OAuth2SecurityScheme with appropriate flows based on
|
||||
the configured URLs. Includes client_credentials flow if token_url
|
||||
is set, and authorization_code flow if authorization_url is set.
|
||||
|
||||
Returns:
|
||||
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
|
||||
"""
|
||||
from a2a.types import (
|
||||
AuthorizationCodeOAuthFlow,
|
||||
ClientCredentialsOAuthFlow,
|
||||
OAuth2SecurityScheme,
|
||||
OAuthFlows,
|
||||
)
|
||||
|
||||
client_credentials = None
|
||||
authorization_code = None
|
||||
|
||||
if self.token_url:
|
||||
client_credentials = ClientCredentialsOAuthFlow(
|
||||
token_url=str(self.token_url),
|
||||
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
|
||||
if self.authorization_url:
|
||||
authorization_code = AuthorizationCodeOAuthFlow(
|
||||
authorization_url=str(self.authorization_url),
|
||||
token_url=str(self.token_url),
|
||||
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
|
||||
return OAuth2SecurityScheme(
|
||||
flows=OAuthFlows(
|
||||
client_credentials=client_credentials,
|
||||
authorization_code=authorization_code,
|
||||
),
|
||||
description="OAuth2 authentication",
|
||||
)
|
||||
|
||||
|
||||
class APIKeyServerAuth(ServerAuthScheme):
|
||||
"""API Key authentication for A2A server.
|
||||
|
||||
Validates requests using an API key in a header, query parameter, or cookie.
|
||||
|
||||
Attributes:
|
||||
name: The name of the API key parameter (default: X-API-Key).
|
||||
location: Where to look for the API key (header, query, or cookie).
|
||||
api_key: The expected API key value.
|
||||
"""
|
||||
|
||||
name: str = Field(
|
||||
default="X-API-Key",
|
||||
description="Name of the API key parameter",
|
||||
)
|
||||
location: Literal["header", "query", "cookie"] = Field(
|
||||
default="header",
|
||||
description="Where to look for the API key",
|
||||
)
|
||||
api_key: CoercedSecretStr = Field(
|
||||
description="Expected API key value",
|
||||
)
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using API key comparison.
|
||||
|
||||
Args:
|
||||
token: The API key to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if token != self.api_key.get_secret_value():
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="api_key",
|
||||
)
|
||||
|
||||
|
||||
class MTLSServerAuth(ServerAuthScheme):
|
||||
"""Mutual TLS authentication marker for AgentCard declaration.
|
||||
|
||||
This scheme is primarily for AgentCard security_schemes declaration.
|
||||
Actual mTLS verification happens at the TLS/transport layer, not
|
||||
at the application layer via token validation.
|
||||
|
||||
When configured, this signals to clients that the server requires
|
||||
client certificates for authentication.
|
||||
"""
|
||||
|
||||
description: str = Field(
|
||||
default="Mutual TLS certificate authentication",
|
||||
description="Description for the security scheme",
|
||||
)
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Return authenticated user for mTLS.
|
||||
|
||||
mTLS verification happens at the transport layer before this is called.
|
||||
If we reach this point, the TLS handshake with client cert succeeded.
|
||||
|
||||
Args:
|
||||
token: Certificate subject or identifier (from TLS layer).
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser indicating mTLS authentication.
|
||||
"""
|
||||
return AuthenticatedUser(
|
||||
token=token or "mtls-verified",
|
||||
scheme="mtls",
|
||||
)
|
||||
@@ -6,10 +6,8 @@ OAuth2, API keys, and HTTP authentication methods.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable, MutableMapping
|
||||
import hashlib
|
||||
import re
|
||||
import threading
|
||||
from typing import Final, Literal, cast
|
||||
from typing import Final
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
@@ -20,10 +18,10 @@ from a2a.types import (
|
||||
)
|
||||
from httpx import AsyncClient, Response
|
||||
|
||||
from crewai.a2a.auth.client_schemes import (
|
||||
from crewai.a2a.auth.schemas import (
|
||||
APIKeyAuth,
|
||||
AuthScheme,
|
||||
BearerTokenAuth,
|
||||
ClientAuthScheme,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
@@ -31,44 +29,12 @@ from crewai.a2a.auth.client_schemes import (
|
||||
)
|
||||
|
||||
|
||||
class _AuthStore:
|
||||
"""Store for authentication schemes with safe concurrent access."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, ClientAuthScheme | None] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@staticmethod
|
||||
def compute_key(auth_type: str, auth_data: str) -> str:
|
||||
"""Compute a collision-resistant key using SHA-256."""
|
||||
content = f"{auth_type}:{auth_data}"
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
|
||||
"""Store an auth scheme."""
|
||||
with self._lock:
|
||||
self._store[key] = auth
|
||||
|
||||
def get(self, key: str) -> ClientAuthScheme | None:
|
||||
"""Retrieve an auth scheme by key."""
|
||||
with self._lock:
|
||||
return self._store.get(key)
|
||||
|
||||
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
|
||||
with self._lock:
|
||||
self._store[key] = value
|
||||
|
||||
def __getitem__(self, key: str) -> ClientAuthScheme | None:
|
||||
with self._lock:
|
||||
return self._store[key]
|
||||
|
||||
|
||||
_auth_store = _AuthStore()
|
||||
_auth_store: dict[int, AuthScheme | None] = {}
|
||||
|
||||
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
|
||||
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
|
||||
|
||||
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
|
||||
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[AuthScheme], ...]]] = {
|
||||
OAuth2SecurityScheme: (
|
||||
OAuth2ClientCredentials,
|
||||
OAuth2AuthorizationCode,
|
||||
@@ -77,9 +43,7 @@ _SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
|
||||
APIKeySecurityScheme: (APIKeyAuth,),
|
||||
}
|
||||
|
||||
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
|
||||
|
||||
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
|
||||
_HTTP_SCHEME_MAPPING: Final[dict[str, type[AuthScheme]]] = {
|
||||
"basic": HTTPBasicAuth,
|
||||
"digest": HTTPDigestAuth,
|
||||
"bearer": BearerTokenAuth,
|
||||
@@ -87,8 +51,8 @@ _HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
|
||||
|
||||
|
||||
def _raise_auth_mismatch(
|
||||
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
|
||||
provided_auth: ClientAuthScheme,
|
||||
expected_classes: type[AuthScheme] | tuple[type[AuthScheme], ...],
|
||||
provided_auth: AuthScheme,
|
||||
) -> None:
|
||||
"""Raise authentication mismatch error.
|
||||
|
||||
@@ -147,7 +111,7 @@ def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
|
||||
|
||||
|
||||
def validate_auth_against_agent_card(
|
||||
agent_card: AgentCard, auth: ClientAuthScheme | None
|
||||
agent_card: AgentCard, auth: AuthScheme | None
|
||||
) -> None:
|
||||
"""Validate that provided auth matches AgentCard security requirements.
|
||||
|
||||
@@ -181,8 +145,7 @@ def validate_auth_against_agent_card(
|
||||
return
|
||||
|
||||
if isinstance(scheme, HTTPAuthSecurityScheme):
|
||||
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
|
||||
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
|
||||
if required_class := _HTTP_SCHEME_MAPPING.get(scheme.scheme.lower()):
|
||||
if not isinstance(auth, required_class):
|
||||
_raise_auth_mismatch(required_class, auth)
|
||||
return
|
||||
@@ -193,7 +156,7 @@ def validate_auth_against_agent_card(
|
||||
|
||||
async def retry_on_401(
|
||||
request_func: Callable[[], Awaitable[Response]],
|
||||
auth_scheme: ClientAuthScheme | None,
|
||||
auth_scheme: AuthScheme | None,
|
||||
client: AsyncClient,
|
||||
headers: MutableMapping[str, str],
|
||||
max_retries: int = 3,
|
||||
|
||||
@@ -5,25 +5,14 @@ This module is separate from experimental.a2a to avoid circular imports.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
import warnings
|
||||
from importlib.metadata import version
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
FilePath,
|
||||
PrivateAttr,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self, deprecated
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from crewai.a2a.auth.client_schemes import ClientAuthScheme
|
||||
from crewai.a2a.auth.server_schemes import ServerAuthScheme
|
||||
from crewai.a2a.extensions.base import ValidatedA2AExtension
|
||||
from crewai.a2a.types import ProtocolVersion, TransportType, Url
|
||||
from crewai.a2a.auth.schemas import AuthScheme
|
||||
from crewai.a2a.types import TransportType, Url
|
||||
|
||||
|
||||
try:
|
||||
@@ -36,17 +25,16 @@ try:
|
||||
SecurityScheme,
|
||||
)
|
||||
|
||||
from crewai.a2a.extensions.server import ServerExtension
|
||||
from crewai.a2a.updates import UpdateConfig
|
||||
except ImportError:
|
||||
UpdateConfig: Any = Any # type: ignore[no-redef]
|
||||
AgentCapabilities: Any = Any # type: ignore[no-redef]
|
||||
AgentCardSignature: Any = Any # type: ignore[no-redef]
|
||||
AgentInterface: Any = Any # type: ignore[no-redef]
|
||||
AgentProvider: Any = Any # type: ignore[no-redef]
|
||||
SecurityScheme: Any = Any # type: ignore[no-redef]
|
||||
AgentSkill: Any = Any # type: ignore[no-redef]
|
||||
ServerExtension: Any = Any # type: ignore[no-redef]
|
||||
UpdateConfig = Any
|
||||
AgentCapabilities = Any
|
||||
AgentCardSignature = Any
|
||||
AgentInterface = Any
|
||||
AgentProvider = Any
|
||||
SecurityScheme = Any
|
||||
AgentSkill = Any
|
||||
UpdateConfig = Any # type: ignore[misc,assignment]
|
||||
|
||||
|
||||
def _get_default_update_config() -> UpdateConfig:
|
||||
@@ -55,309 +43,6 @@ def _get_default_update_config() -> UpdateConfig:
|
||||
return StreamingConfig()
|
||||
|
||||
|
||||
SigningAlgorithm = Literal[
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
|
||||
class AgentCardSigningConfig(BaseModel):
|
||||
"""Configuration for AgentCard JWS signing.
|
||||
|
||||
Provides the private key and algorithm settings for signing AgentCards.
|
||||
Either private_key_path or private_key_pem must be provided, but not both.
|
||||
|
||||
Attributes:
|
||||
private_key_path: Path to a PEM-encoded private key file.
|
||||
private_key_pem: PEM-encoded private key as a secret string.
|
||||
key_id: Optional key identifier for the JWS header (kid claim).
|
||||
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
private_key_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to PEM-encoded private key file",
|
||||
)
|
||||
private_key_pem: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="PEM-encoded private key",
|
||||
)
|
||||
key_id: str | None = Field(
|
||||
default=None,
|
||||
description="Key identifier for JWS header (kid claim)",
|
||||
)
|
||||
algorithm: SigningAlgorithm = Field(
|
||||
default="RS256",
|
||||
description="Signing algorithm (RS256, ES256, PS256, etc.)",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_key_source(self) -> Self:
|
||||
"""Ensure exactly one key source is provided."""
|
||||
has_path = self.private_key_path is not None
|
||||
has_pem = self.private_key_pem is not None
|
||||
|
||||
if not has_path and not has_pem:
|
||||
raise ValueError(
|
||||
"Either private_key_path or private_key_pem must be provided"
|
||||
)
|
||||
if has_path and has_pem:
|
||||
raise ValueError(
|
||||
"Only one of private_key_path or private_key_pem should be provided"
|
||||
)
|
||||
return self
|
||||
|
||||
def get_private_key(self) -> str:
|
||||
"""Get the private key content.
|
||||
|
||||
Returns:
|
||||
The PEM-encoded private key as a string.
|
||||
"""
|
||||
if self.private_key_pem:
|
||||
return self.private_key_pem.get_secret_value()
|
||||
if self.private_key_path:
|
||||
return Path(self.private_key_path).read_text()
|
||||
raise ValueError("No private key configured")
|
||||
|
||||
|
||||
class GRPCServerConfig(BaseModel):
|
||||
"""gRPC server transport configuration.
|
||||
|
||||
Presence of this config in ServerTransportConfig.grpc enables gRPC transport.
|
||||
|
||||
Attributes:
|
||||
host: Hostname to advertise in agent cards (default: localhost).
|
||||
Use docker service name (e.g., 'web') for docker-compose setups.
|
||||
port: Port for the gRPC server.
|
||||
tls_cert_path: Path to TLS certificate file for gRPC.
|
||||
tls_key_path: Path to TLS private key file for gRPC.
|
||||
max_workers: Maximum number of workers for the gRPC thread pool.
|
||||
reflection_enabled: Whether to enable gRPC reflection for debugging.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
host: str = Field(
|
||||
default="localhost",
|
||||
description="Hostname to advertise in agent cards for gRPC connections",
|
||||
)
|
||||
port: int = Field(
|
||||
default=50051,
|
||||
description="Port for the gRPC server",
|
||||
)
|
||||
tls_cert_path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS certificate file for gRPC",
|
||||
)
|
||||
tls_key_path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS private key file for gRPC",
|
||||
)
|
||||
max_workers: int = Field(
|
||||
default=10,
|
||||
description="Maximum number of workers for the gRPC thread pool",
|
||||
)
|
||||
reflection_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable gRPC reflection for debugging",
|
||||
)
|
||||
|
||||
|
||||
class GRPCClientConfig(BaseModel):
|
||||
"""gRPC client transport configuration.
|
||||
|
||||
Attributes:
|
||||
max_send_message_length: Maximum size for outgoing messages in bytes.
|
||||
max_receive_message_length: Maximum size for incoming messages in bytes.
|
||||
keepalive_time_ms: Time between keepalive pings in milliseconds.
|
||||
keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
max_send_message_length: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum size for outgoing messages in bytes",
|
||||
)
|
||||
max_receive_message_length: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum size for incoming messages in bytes",
|
||||
)
|
||||
keepalive_time_ms: int | None = Field(
|
||||
default=None,
|
||||
description="Time between keepalive pings in milliseconds",
|
||||
)
|
||||
keepalive_timeout_ms: int | None = Field(
|
||||
default=None,
|
||||
description="Timeout for keepalive ping response in milliseconds",
|
||||
)
|
||||
|
||||
|
||||
class JSONRPCServerConfig(BaseModel):
|
||||
"""JSON-RPC server transport configuration.
|
||||
|
||||
Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport.
|
||||
|
||||
Attributes:
|
||||
rpc_path: URL path for the JSON-RPC endpoint.
|
||||
agent_card_path: URL path for the agent card endpoint.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
rpc_path: str = Field(
|
||||
default="/a2a",
|
||||
description="URL path for the JSON-RPC endpoint",
|
||||
)
|
||||
agent_card_path: str = Field(
|
||||
default="/.well-known/agent-card.json",
|
||||
description="URL path for the agent card endpoint",
|
||||
)
|
||||
|
||||
|
||||
class JSONRPCClientConfig(BaseModel):
|
||||
"""JSON-RPC client transport configuration.
|
||||
|
||||
Attributes:
|
||||
max_request_size: Maximum request body size in bytes.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
max_request_size: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum request body size in bytes",
|
||||
)
|
||||
|
||||
|
||||
class HTTPJSONConfig(BaseModel):
|
||||
"""HTTP+JSON transport configuration.
|
||||
|
||||
Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ServerPushNotificationConfig(BaseModel):
|
||||
"""Configuration for outgoing webhook push notifications.
|
||||
|
||||
Controls how the server signs and delivers push notifications to clients.
|
||||
|
||||
Attributes:
|
||||
signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
signature_secret: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Shared secret for HMAC-SHA256 signing of outgoing push notifications",
|
||||
)
|
||||
|
||||
|
||||
class ServerTransportConfig(BaseModel):
|
||||
"""Transport configuration for A2A server.
|
||||
|
||||
Groups all transport-related settings including preferred transport
|
||||
and protocol-specific configurations.
|
||||
|
||||
Attributes:
|
||||
preferred: Transport protocol for the preferred endpoint.
|
||||
jsonrpc: JSON-RPC server transport configuration.
|
||||
grpc: gRPC server transport configuration.
|
||||
http_json: HTTP+JSON transport configuration.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
preferred: TransportType = Field(
|
||||
default="JSONRPC",
|
||||
description="Transport protocol for the preferred endpoint",
|
||||
)
|
||||
jsonrpc: JSONRPCServerConfig = Field(
|
||||
default_factory=JSONRPCServerConfig,
|
||||
description="JSON-RPC server transport configuration",
|
||||
)
|
||||
grpc: GRPCServerConfig | None = Field(
|
||||
default=None,
|
||||
description="gRPC server transport configuration",
|
||||
)
|
||||
http_json: HTTPJSONConfig | None = Field(
|
||||
default=None,
|
||||
description="HTTP+JSON transport configuration",
|
||||
)
|
||||
|
||||
|
||||
def _migrate_client_transport_fields(
|
||||
transport: ClientTransportConfig,
|
||||
transport_protocol: TransportType | None,
|
||||
supported_transports: list[TransportType] | None,
|
||||
) -> None:
|
||||
"""Migrate deprecated transport fields to new config."""
|
||||
if transport_protocol is not None:
|
||||
warnings.warn(
|
||||
"transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead",
|
||||
FutureWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
object.__setattr__(transport, "preferred", transport_protocol)
|
||||
if supported_transports is not None:
|
||||
warnings.warn(
|
||||
"supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead",
|
||||
FutureWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
object.__setattr__(transport, "supported", supported_transports)
|
||||
|
||||
|
||||
class ClientTransportConfig(BaseModel):
|
||||
"""Transport configuration for A2A client.
|
||||
|
||||
Groups all client transport-related settings including preferred transport,
|
||||
supported transports for negotiation, and protocol-specific configurations.
|
||||
|
||||
Transport negotiation logic:
|
||||
1. If `preferred` is set and server supports it → use client's preferred
|
||||
2. Otherwise, if server's preferred is in client's `supported` → use server's preferred
|
||||
3. Otherwise, find first match from client's `supported` in server's interfaces
|
||||
|
||||
Attributes:
|
||||
preferred: Client's preferred transport. If set, client preference takes priority.
|
||||
supported: Transports the client can use, in order of preference.
|
||||
jsonrpc: JSON-RPC client transport configuration.
|
||||
grpc: gRPC client transport configuration.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
preferred: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Client's preferred transport. If set, takes priority over server preference.",
|
||||
)
|
||||
supported: list[TransportType] = Field(
|
||||
default_factory=lambda: cast(list[TransportType], ["JSONRPC"]),
|
||||
description="Transports the client can use, in order of preference",
|
||||
)
|
||||
jsonrpc: JSONRPCClientConfig = Field(
|
||||
default_factory=JSONRPCClientConfig,
|
||||
description="JSON-RPC client transport configuration",
|
||||
)
|
||||
grpc: GRPCClientConfig = Field(
|
||||
default_factory=GRPCClientConfig,
|
||||
description="gRPC client transport configuration",
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"""
|
||||
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
|
||||
@@ -380,14 +65,13 @@ class A2AConfig(BaseModel):
|
||||
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
|
||||
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
|
||||
updates: Update mechanism config.
|
||||
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
transport_protocol: A2A transport protocol (grpc, jsonrpc, http+json).
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
endpoint: Url = Field(description="A2A agent endpoint URL")
|
||||
auth: ClientAuthScheme | None = Field(
|
||||
auth: AuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme",
|
||||
)
|
||||
@@ -411,48 +95,10 @@ class A2AConfig(BaseModel):
|
||||
default_factory=_get_default_update_config,
|
||||
description="Update mechanism config",
|
||||
)
|
||||
client_extensions: list[ValidatedA2AExtension] = Field(
|
||||
default_factory=list,
|
||||
description="Client-side processing hooks for tool injection and prompt augmentation",
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
|
||||
default="JSONRPC",
|
||||
description="Specified mode of A2A transport protocol",
|
||||
)
|
||||
transport: ClientTransportConfig = Field(
|
||||
default_factory=ClientTransportConfig,
|
||||
description="Transport configuration (preferred, supported transports, gRPC settings)",
|
||||
)
|
||||
transport_protocol: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.preferred instead",
|
||||
exclude=True,
|
||||
)
|
||||
supported_transports: list[TransportType] | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.supported instead",
|
||||
exclude=True,
|
||||
)
|
||||
use_client_preference: bool | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Set transport.preferred to enable client preference",
|
||||
exclude=True,
|
||||
)
|
||||
_parallel_delegation: bool = PrivateAttr(default=False)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _migrate_deprecated_transport_fields(self) -> Self:
|
||||
"""Migrate deprecated transport fields to new config."""
|
||||
_migrate_client_transport_fields(
|
||||
self.transport, self.transport_protocol, self.supported_transports
|
||||
)
|
||||
if self.use_client_preference is not None:
|
||||
warnings.warn(
|
||||
"use_client_preference is deprecated, set transport.preferred to enable client preference",
|
||||
FutureWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
if self.use_client_preference and self.transport.supported:
|
||||
object.__setattr__(
|
||||
self.transport, "preferred", self.transport.supported[0]
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class A2AClientConfig(BaseModel):
|
||||
@@ -468,15 +114,15 @@ class A2AClientConfig(BaseModel):
|
||||
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
|
||||
updates: Update mechanism config.
|
||||
accepted_output_modes: Media types the client can accept in responses.
|
||||
extensions: Extension URIs the client supports (A2A protocol extensions).
|
||||
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
supported_transports: Ordered list of transport protocols the client supports.
|
||||
use_client_preference: Whether to prioritize client transport preferences over server.
|
||||
extensions: Extension URIs the client supports.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
endpoint: Url = Field(description="A2A agent endpoint URL")
|
||||
auth: ClientAuthScheme | None = Field(
|
||||
auth: AuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme",
|
||||
)
|
||||
@@ -504,37 +150,22 @@ class A2AClientConfig(BaseModel):
|
||||
default_factory=lambda: ["application/json"],
|
||||
description="Media types the client can accept in responses",
|
||||
)
|
||||
supported_transports: list[str] = Field(
|
||||
default_factory=lambda: ["JSONRPC"],
|
||||
description="Ordered list of transport protocols the client supports",
|
||||
)
|
||||
use_client_preference: bool = Field(
|
||||
default=False,
|
||||
description="Whether to prioritize client transport preferences over server",
|
||||
)
|
||||
extensions: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Extension URIs the client supports",
|
||||
)
|
||||
client_extensions: list[ValidatedA2AExtension] = Field(
|
||||
default_factory=list,
|
||||
description="Client-side processing hooks for tool injection and prompt augmentation",
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
|
||||
default="JSONRPC",
|
||||
description="Specified mode of A2A transport protocol",
|
||||
)
|
||||
transport: ClientTransportConfig = Field(
|
||||
default_factory=ClientTransportConfig,
|
||||
description="Transport configuration (preferred, supported transports, gRPC settings)",
|
||||
)
|
||||
transport_protocol: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.preferred instead",
|
||||
exclude=True,
|
||||
)
|
||||
supported_transports: list[TransportType] | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.supported instead",
|
||||
exclude=True,
|
||||
)
|
||||
_parallel_delegation: bool = PrivateAttr(default=False)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _migrate_deprecated_transport_fields(self) -> Self:
|
||||
"""Migrate deprecated transport fields to new config."""
|
||||
_migrate_client_transport_fields(
|
||||
self.transport, self.transport_protocol, self.supported_transports
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class A2AServerConfig(BaseModel):
|
||||
@@ -551,6 +182,7 @@ class A2AServerConfig(BaseModel):
|
||||
default_input_modes: Default supported input MIME types.
|
||||
default_output_modes: Default supported output MIME types.
|
||||
capabilities: Declaration of optional capabilities.
|
||||
preferred_transport: Transport protocol for the preferred endpoint.
|
||||
protocol_version: A2A protocol version this agent supports.
|
||||
provider: Information about the agent's service provider.
|
||||
documentation_url: URL to the agent's documentation.
|
||||
@@ -560,12 +192,7 @@ class A2AServerConfig(BaseModel):
|
||||
security_schemes: Security schemes available to authorize requests.
|
||||
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
|
||||
url: Preferred endpoint URL for the agent.
|
||||
signing_config: Configuration for signing the AgentCard with JWS.
|
||||
signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
|
||||
server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
|
||||
push_notifications: Configuration for outgoing push notifications.
|
||||
transport: Transport configuration (preferred transport, gRPC, REST settings).
|
||||
auth: Authentication scheme for A2A endpoints.
|
||||
signatures: JSON Web Signatures for the AgentCard.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
@@ -601,8 +228,12 @@ class A2AServerConfig(BaseModel):
|
||||
),
|
||||
description="Declaration of optional capabilities supported by the agent",
|
||||
)
|
||||
protocol_version: ProtocolVersion = Field(
|
||||
default="0.3.0",
|
||||
preferred_transport: TransportType = Field(
|
||||
default="JSONRPC",
|
||||
description="Transport protocol for the preferred endpoint",
|
||||
)
|
||||
protocol_version: str = Field(
|
||||
default_factory=lambda: version("a2a-sdk"),
|
||||
description="A2A protocol version this agent supports",
|
||||
)
|
||||
provider: AgentProvider | None = Field(
|
||||
@@ -619,7 +250,7 @@ class A2AServerConfig(BaseModel):
|
||||
)
|
||||
additional_interfaces: list[AgentInterface] = Field(
|
||||
default_factory=list,
|
||||
description="Additional supported interfaces.",
|
||||
description="Additional supported interfaces (transport and URL combinations)",
|
||||
)
|
||||
security: list[dict[str, list[str]]] = Field(
|
||||
default_factory=list,
|
||||
@@ -637,54 +268,7 @@ class A2AServerConfig(BaseModel):
|
||||
default=None,
|
||||
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
|
||||
)
|
||||
signing_config: AgentCardSigningConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for signing the AgentCard with JWS",
|
||||
)
|
||||
signatures: list[AgentCardSignature] | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
|
||||
exclude=True,
|
||||
deprecated=True,
|
||||
)
|
||||
server_extensions: list[ServerExtension] = Field(
|
||||
signatures: list[AgentCardSignature] = Field(
|
||||
default_factory=list,
|
||||
description="Server-side A2A protocol extensions that modify agent behavior",
|
||||
description="JSON Web Signatures for the AgentCard",
|
||||
)
|
||||
push_notifications: ServerPushNotificationConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for outgoing push notifications",
|
||||
)
|
||||
transport: ServerTransportConfig = Field(
|
||||
default_factory=ServerTransportConfig,
|
||||
description="Transport configuration (preferred transport, gRPC, REST settings)",
|
||||
)
|
||||
preferred_transport: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.preferred instead",
|
||||
exclude=True,
|
||||
deprecated=True,
|
||||
)
|
||||
auth: ServerAuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _migrate_deprecated_fields(self) -> Self:
|
||||
"""Migrate deprecated fields to new config."""
|
||||
if self.preferred_transport is not None:
|
||||
warnings.warn(
|
||||
"preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead",
|
||||
FutureWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
object.__setattr__(self.transport, "preferred", self.preferred_transport)
|
||||
if self.signatures is not None:
|
||||
warnings.warn(
|
||||
"signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
|
||||
"The signatures field will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -1,491 +1,7 @@
|
||||
"""A2A error codes and error response utilities.
|
||||
|
||||
This module provides a centralized mapping of all A2A protocol error codes
|
||||
as defined in the A2A specification, plus custom CrewAI extensions.
|
||||
|
||||
Error codes follow JSON-RPC 2.0 conventions:
|
||||
- -32700 to -32600: Standard JSON-RPC errors
|
||||
- -32099 to -32000: Server errors (A2A-specific)
|
||||
- -32768 to -32100: Reserved for implementation-defined errors
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
"""A2A protocol error types."""
|
||||
|
||||
from a2a.client.errors import A2AClientTimeoutError
|
||||
|
||||
|
||||
class A2APollingTimeoutError(A2AClientTimeoutError):
|
||||
"""Raised when polling exceeds the configured timeout."""
|
||||
|
||||
|
||||
class A2AErrorCode(IntEnum):
|
||||
"""A2A protocol error codes.
|
||||
|
||||
Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
|
||||
"""
|
||||
|
||||
# JSON-RPC 2.0 Standard Errors (-32700 to -32600)
|
||||
JSON_PARSE_ERROR = -32700
|
||||
"""Invalid JSON was received by the server."""
|
||||
|
||||
INVALID_REQUEST = -32600
|
||||
"""The JSON sent is not a valid Request object."""
|
||||
|
||||
METHOD_NOT_FOUND = -32601
|
||||
"""The method does not exist / is not available."""
|
||||
|
||||
INVALID_PARAMS = -32602
|
||||
"""Invalid method parameter(s)."""
|
||||
|
||||
INTERNAL_ERROR = -32603
|
||||
"""Internal JSON-RPC error."""
|
||||
|
||||
# A2A-Specific Errors (-32099 to -32000)
|
||||
TASK_NOT_FOUND = -32001
|
||||
"""The specified task was not found."""
|
||||
|
||||
TASK_NOT_CANCELABLE = -32002
|
||||
"""The task cannot be canceled (already completed/failed)."""
|
||||
|
||||
PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
|
||||
"""Push notifications are not supported by this agent."""
|
||||
|
||||
UNSUPPORTED_OPERATION = -32004
|
||||
"""The requested operation is not supported."""
|
||||
|
||||
CONTENT_TYPE_NOT_SUPPORTED = -32005
|
||||
"""Incompatible content types between client and server."""
|
||||
|
||||
INVALID_AGENT_RESPONSE = -32006
|
||||
"""The agent produced an invalid response."""
|
||||
|
||||
# CrewAI Custom Extensions (-32768 to -32100)
|
||||
UNSUPPORTED_VERSION = -32009
|
||||
"""The requested A2A protocol version is not supported."""
|
||||
|
||||
UNSUPPORTED_EXTENSION = -32010
|
||||
"""Client does not support required protocol extensions."""
|
||||
|
||||
AUTHENTICATION_REQUIRED = -32011
|
||||
"""Authentication is required for this operation."""
|
||||
|
||||
AUTHORIZATION_FAILED = -32012
|
||||
"""Authorization check failed (insufficient permissions)."""
|
||||
|
||||
RATE_LIMIT_EXCEEDED = -32013
|
||||
"""Rate limit exceeded for this client/operation."""
|
||||
|
||||
TASK_TIMEOUT = -32014
|
||||
"""Task execution timed out."""
|
||||
|
||||
TRANSPORT_NEGOTIATION_FAILED = -32015
|
||||
"""Failed to negotiate a compatible transport protocol."""
|
||||
|
||||
CONTEXT_NOT_FOUND = -32016
|
||||
"""The specified context was not found."""
|
||||
|
||||
SKILL_NOT_FOUND = -32017
|
||||
"""The specified skill was not found."""
|
||||
|
||||
ARTIFACT_NOT_FOUND = -32018
|
||||
"""The specified artifact was not found."""
|
||||
|
||||
|
||||
# Error code to default message mapping
|
||||
ERROR_MESSAGES: dict[int, str] = {
|
||||
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
|
||||
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
|
||||
A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
|
||||
A2AErrorCode.INVALID_PARAMS: "Invalid params",
|
||||
A2AErrorCode.INTERNAL_ERROR: "Internal error",
|
||||
A2AErrorCode.TASK_NOT_FOUND: "Task not found",
|
||||
A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
|
||||
A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
|
||||
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
|
||||
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
|
||||
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
|
||||
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
|
||||
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
|
||||
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
|
||||
A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
|
||||
A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
|
||||
A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
|
||||
A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
|
||||
A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
|
||||
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
|
||||
A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class A2AError(Exception):
|
||||
"""Base exception for A2A protocol errors.
|
||||
|
||||
Attributes:
|
||||
code: The A2A/JSON-RPC error code.
|
||||
message: Human-readable error message.
|
||||
data: Optional additional error data.
|
||||
"""
|
||||
|
||||
code: int
|
||||
message: str | None = None
|
||||
data: Any = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
|
||||
super().__init__(self.message)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to JSON-RPC error object format."""
|
||||
error: dict[str, Any] = {
|
||||
"code": self.code,
|
||||
"message": self.message,
|
||||
}
|
||||
if self.data is not None:
|
||||
error["data"] = self.data
|
||||
return error
|
||||
|
||||
def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
|
||||
"""Convert to full JSON-RPC error response."""
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"error": self.to_dict(),
|
||||
"id": request_id,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONParseError(A2AError):
|
||||
"""Invalid JSON was received."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidRequestError(A2AError):
|
||||
"""The JSON sent is not a valid Request object."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MethodNotFoundError(A2AError):
|
||||
"""The method does not exist / is not available."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
|
||||
method: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.method:
|
||||
self.message = f"Method not found: {self.method}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidParamsError(A2AError):
|
||||
"""Invalid method parameter(s)."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
|
||||
param: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
if self.param and self.reason:
|
||||
self.message = f"Invalid parameter '{self.param}': {self.reason}"
|
||||
elif self.param:
|
||||
self.message = f"Invalid parameter: {self.param}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InternalError(A2AError):
|
||||
"""Internal JSON-RPC error."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskNotFoundError(A2AError):
|
||||
"""The specified task was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
|
||||
task_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.task_id:
|
||||
self.message = f"Task not found: {self.task_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskNotCancelableError(A2AError):
|
||||
"""The task cannot be canceled."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
|
||||
task_id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
if self.task_id and self.reason:
|
||||
self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
|
||||
elif self.task_id:
|
||||
self.message = f"Task {self.task_id} cannot be canceled"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PushNotificationNotSupportedError(A2AError):
|
||||
"""Push notifications are not supported."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedOperationError(A2AError):
|
||||
"""The requested operation is not supported."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
|
||||
operation: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.operation:
|
||||
self.message = f"Operation not supported: {self.operation}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentTypeNotSupportedError(A2AError):
|
||||
"""Incompatible content types."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
|
||||
requested_types: list[str] | None = None
|
||||
supported_types: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.requested_types and self.supported_types:
|
||||
self.message = (
|
||||
f"Content type not supported. Requested: {self.requested_types}, "
|
||||
f"Supported: {self.supported_types}"
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidAgentResponseError(A2AError):
|
||||
"""The agent produced an invalid response."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedVersionError(A2AError):
|
||||
"""The requested A2A version is not supported."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
|
||||
requested_version: str | None = None
|
||||
supported_versions: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.requested_version:
|
||||
msg = f"Unsupported A2A version: {self.requested_version}"
|
||||
if self.supported_versions:
|
||||
msg += f". Supported versions: {', '.join(self.supported_versions)}"
|
||||
self.message = msg
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedExtensionError(A2AError):
|
||||
"""Client does not support required extensions."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
|
||||
required_extensions: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.required_extensions:
|
||||
self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthenticationRequiredError(A2AError):
|
||||
"""Authentication is required."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationFailedError(A2AError):
|
||||
"""Authorization check failed."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
|
||||
required_scope: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.required_scope:
|
||||
self.message = (
|
||||
f"Authorization failed. Required scope: {self.required_scope}"
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitExceededError(A2AError):
|
||||
"""Rate limit exceeded."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False)
|
||||
retry_after: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.retry_after:
|
||||
self.message = (
|
||||
f"Rate limit exceeded. Retry after {self.retry_after} seconds"
|
||||
)
|
||||
if self.retry_after:
|
||||
self.data = {"retry_after": self.retry_after}
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskTimeoutError(A2AError):
|
||||
"""Task execution timed out."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False)
|
||||
task_id: str | None = None
|
||||
timeout_seconds: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
if self.task_id and self.timeout_seconds:
|
||||
self.message = (
|
||||
f"Task {self.task_id} timed out after {self.timeout_seconds}s"
|
||||
)
|
||||
elif self.task_id:
|
||||
self.message = f"Task {self.task_id} timed out"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransportNegotiationFailedError(A2AError):
|
||||
"""Failed to negotiate a compatible transport protocol."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False)
|
||||
client_transports: list[str] | None = None
|
||||
server_transports: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.client_transports and self.server_transports:
|
||||
self.message = (
|
||||
f"Transport negotiation failed. Client: {self.client_transports}, "
|
||||
f"Server: {self.server_transports}"
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextNotFoundError(A2AError):
|
||||
"""The specified context was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
|
||||
context_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.context_id:
|
||||
self.message = f"Context not found: {self.context_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillNotFoundError(A2AError):
|
||||
"""The specified skill was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
|
||||
skill_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.skill_id:
|
||||
self.message = f"Skill not found: {self.skill_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArtifactNotFoundError(A2AError):
|
||||
"""The specified artifact was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
|
||||
artifact_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.artifact_id:
|
||||
self.message = f"Artifact not found: {self.artifact_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
def create_error_response(
|
||||
code: int | A2AErrorCode,
|
||||
message: str | None = None,
|
||||
data: Any = None,
|
||||
request_id: str | int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a JSON-RPC error response.
|
||||
|
||||
Args:
|
||||
code: Error code (A2AErrorCode or int).
|
||||
message: Optional error message (uses default if not provided).
|
||||
data: Optional additional error data.
|
||||
request_id: Request ID for correlation.
|
||||
|
||||
Returns:
|
||||
Dict in JSON-RPC error response format.
|
||||
"""
|
||||
error = A2AError(code=int(code), message=message, data=data)
|
||||
return error.to_response(request_id)
|
||||
|
||||
|
||||
def is_retryable_error(code: int) -> bool:
|
||||
"""Check if an error is potentially retryable.
|
||||
|
||||
Args:
|
||||
code: Error code to check.
|
||||
|
||||
Returns:
|
||||
True if the error might be resolved by retrying.
|
||||
"""
|
||||
retryable_codes = {
|
||||
A2AErrorCode.INTERNAL_ERROR,
|
||||
A2AErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
A2AErrorCode.TASK_TIMEOUT,
|
||||
}
|
||||
return code in retryable_codes
|
||||
|
||||
|
||||
def is_client_error(code: int) -> bool:
|
||||
"""Check if an error is a client-side error.
|
||||
|
||||
Args:
|
||||
code: Error code to check.
|
||||
|
||||
Returns:
|
||||
True if the error is due to client request issues.
|
||||
"""
|
||||
client_error_codes = {
|
||||
A2AErrorCode.JSON_PARSE_ERROR,
|
||||
A2AErrorCode.INVALID_REQUEST,
|
||||
A2AErrorCode.METHOD_NOT_FOUND,
|
||||
A2AErrorCode.INVALID_PARAMS,
|
||||
A2AErrorCode.TASK_NOT_FOUND,
|
||||
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
|
||||
A2AErrorCode.UNSUPPORTED_VERSION,
|
||||
A2AErrorCode.UNSUPPORTED_EXTENSION,
|
||||
A2AErrorCode.CONTEXT_NOT_FOUND,
|
||||
A2AErrorCode.SKILL_NOT_FOUND,
|
||||
A2AErrorCode.ARTIFACT_NOT_FOUND,
|
||||
}
|
||||
return code in client_error_codes
|
||||
|
||||
@@ -1,37 +1,4 @@
|
||||
"""A2A Protocol Extensions for CrewAI.
|
||||
|
||||
This module contains extensions to the A2A (Agent-to-Agent) protocol.
|
||||
|
||||
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
|
||||
processes requests and responses during delegation to remote agents. These provide
|
||||
hooks for tool injection, prompt augmentation, and response processing.
|
||||
|
||||
**Server-side extensions** (ServerExtension) allow agents to offer additional
|
||||
functionality beyond the core A2A specification. Clients activate extensions
|
||||
via the X-A2A-Extensions header.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from crewai.a2a.extensions.base import (
|
||||
A2AExtension,
|
||||
ConversationState,
|
||||
ExtensionRegistry,
|
||||
ValidatedA2AExtension,
|
||||
)
|
||||
from crewai.a2a.extensions.server import (
|
||||
ExtensionContext,
|
||||
ServerExtension,
|
||||
ServerExtensionRegistry,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"A2AExtension",
|
||||
"ConversationState",
|
||||
"ExtensionContext",
|
||||
"ExtensionRegistry",
|
||||
"ServerExtension",
|
||||
"ServerExtensionRegistry",
|
||||
"ValidatedA2AExtension",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
"""Base extension interface for CrewAI A2A wrapper processing hooks.
|
||||
"""Base extension interface for A2A wrapper integrations.
|
||||
|
||||
This module defines the protocol for extending CrewAI's A2A wrapper functionality
|
||||
with custom logic for tool injection, prompt augmentation, and response processing.
|
||||
|
||||
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
|
||||
A2A protocol extensions are capability declarations using AgentExtension objects
|
||||
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
This module defines the protocol for extending A2A wrapper functionality
|
||||
with custom logic for conversation processing, prompt augmentation, and
|
||||
agent response handling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BeforeValidator
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -23,20 +17,6 @@ if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
|
||||
|
||||
def _validate_a2a_extension(v: Any) -> Any:
|
||||
"""Validate that value implements A2AExtension protocol."""
|
||||
if not isinstance(v, A2AExtension):
|
||||
raise ValueError(
|
||||
f"Value must implement A2AExtension protocol. "
|
||||
f"Got {type(v).__name__} which is missing required methods."
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ConversationState(Protocol):
|
||||
"""Protocol for extension-specific conversation state.
|
||||
|
||||
@@ -53,36 +33,11 @@ class ConversationState(Protocol):
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class A2AExtension(Protocol):
|
||||
"""Protocol for A2A wrapper extensions.
|
||||
|
||||
Extensions can implement this protocol to inject custom logic into
|
||||
the A2A conversation flow at various integration points.
|
||||
|
||||
Example:
|
||||
class MyExtension:
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
# Add custom tools to the agent
|
||||
pass
|
||||
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> ConversationState | None:
|
||||
# Extract state from conversation
|
||||
return None
|
||||
|
||||
def augment_prompt(
|
||||
self, base_prompt: str, conversation_state: ConversationState | None
|
||||
) -> str:
|
||||
# Add custom instructions
|
||||
return base_prompt
|
||||
|
||||
def process_response(
|
||||
self, agent_response: Any, conversation_state: ConversationState | None
|
||||
) -> Any:
|
||||
# Modify response if needed
|
||||
return agent_response
|
||||
"""
|
||||
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
|
||||
@@ -1,170 +1,34 @@
|
||||
"""A2A Protocol extension utilities.
|
||||
"""Extension registry factory for A2A configurations.
|
||||
|
||||
This module provides utilities for working with A2A protocol extensions as
|
||||
defined in the A2A specification. Extensions are capability declarations in
|
||||
AgentCard.capabilities.extensions using AgentExtension objects, activated
|
||||
via the X-A2A-Extensions HTTP header.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
This module provides utilities for creating extension registries from A2A configurations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
|
||||
from a2a.extensions.common import (
|
||||
HTTP_EXTENSION_HEADER,
|
||||
)
|
||||
from a2a.types import AgentCard, AgentExtension
|
||||
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig
|
||||
from crewai.a2a.extensions.base import ExtensionRegistry
|
||||
|
||||
|
||||
def get_extensions_from_config(
|
||||
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||
) -> list[str]:
|
||||
"""Extract extension URIs from A2A configuration.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (single or list).
|
||||
|
||||
Returns:
|
||||
Deduplicated list of extension URIs from all configs.
|
||||
"""
|
||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
|
||||
for config in configs:
|
||||
if not isinstance(config, A2AClientConfig):
|
||||
continue
|
||||
for uri in config.extensions:
|
||||
if uri not in seen:
|
||||
seen.add(uri)
|
||||
result.append(uri)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ExtensionsMiddleware(ClientCallInterceptor):
|
||||
"""Middleware to add X-A2A-Extensions header to requests.
|
||||
|
||||
This middleware adds the extensions header to all outgoing requests,
|
||||
declaring which A2A protocol extensions the client supports.
|
||||
"""
|
||||
|
||||
def __init__(self, extensions: list[str]) -> None:
|
||||
"""Initialize with extension URIs.
|
||||
|
||||
Args:
|
||||
extensions: List of extension URIs the client supports.
|
||||
"""
|
||||
self._extensions = extensions
|
||||
|
||||
async def intercept(
|
||||
self,
|
||||
method_name: str,
|
||||
request_payload: dict[str, Any],
|
||||
http_kwargs: dict[str, Any],
|
||||
agent_card: AgentCard | None,
|
||||
context: ClientCallContext | None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Add extensions header to the request.
|
||||
|
||||
Args:
|
||||
method_name: The A2A method being called.
|
||||
request_payload: The JSON-RPC request payload.
|
||||
http_kwargs: HTTP request kwargs (headers, etc).
|
||||
agent_card: The target agent's card.
|
||||
context: Optional call context.
|
||||
|
||||
Returns:
|
||||
Tuple of (request_payload, modified_http_kwargs).
|
||||
"""
|
||||
if self._extensions:
|
||||
headers = http_kwargs.setdefault("headers", {})
|
||||
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
|
||||
return request_payload, http_kwargs
|
||||
|
||||
|
||||
def validate_required_extensions(
|
||||
agent_card: AgentCard,
|
||||
client_extensions: list[str] | None,
|
||||
) -> list[AgentExtension]:
|
||||
"""Validate that client supports all required extensions from agent.
|
||||
|
||||
Args:
|
||||
agent_card: The agent's card with declared extensions.
|
||||
client_extensions: Extension URIs the client supports.
|
||||
|
||||
Returns:
|
||||
List of unsupported required extensions.
|
||||
|
||||
Raises:
|
||||
None - returns list of unsupported extensions for caller to handle.
|
||||
"""
|
||||
unsupported: list[AgentExtension] = []
|
||||
client_set = set(client_extensions or [])
|
||||
|
||||
if not agent_card.capabilities or not agent_card.capabilities.extensions:
|
||||
return unsupported
|
||||
|
||||
unsupported.extend(
|
||||
ext
|
||||
for ext in agent_card.capabilities.extensions
|
||||
if ext.required and ext.uri not in client_set
|
||||
)
|
||||
|
||||
return unsupported
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
|
||||
def create_extension_registry_from_config(
|
||||
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||
a2a_config: list[A2AConfig] | A2AConfig,
|
||||
) -> ExtensionRegistry:
|
||||
"""Create an extension registry from A2A client configuration.
|
||||
|
||||
Extracts client_extensions from each A2AClientConfig and registers them
|
||||
with the ExtensionRegistry. These extensions provide CrewAI-specific
|
||||
processing hooks (tool injection, prompt augmentation, response processing).
|
||||
|
||||
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
|
||||
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
|
||||
"""Create an extension registry from A2A configuration.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (single or list).
|
||||
a2a_config: A2A configuration (single or list)
|
||||
|
||||
Returns:
|
||||
Extension registry with all client_extensions registered.
|
||||
|
||||
Example:
|
||||
class LoggingExtension:
|
||||
def inject_tools(self, agent): pass
|
||||
def extract_state_from_history(self, history): return None
|
||||
def augment_prompt(self, prompt, state): return prompt
|
||||
def process_response(self, response, state):
|
||||
print(f"Response: {response}")
|
||||
return response
|
||||
|
||||
config = A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
client_extensions=[LoggingExtension()],
|
||||
)
|
||||
registry = create_extension_registry_from_config(config)
|
||||
Configured extension registry with all applicable extensions
|
||||
"""
|
||||
registry = ExtensionRegistry()
|
||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||
|
||||
seen: set[int] = set()
|
||||
|
||||
for config in configs:
|
||||
if isinstance(config, (A2AConfig, A2AClientConfig)):
|
||||
client_exts = getattr(config, "client_extensions", [])
|
||||
for extension in client_exts:
|
||||
ext_id = id(extension)
|
||||
if ext_id not in seen:
|
||||
seen.add(ext_id)
|
||||
registry.register(extension)
|
||||
for _ in configs:
|
||||
pass
|
||||
|
||||
return registry
|
||||
|
||||
@@ -1,305 +0,0 @@
|
||||
"""A2A protocol server extensions for CrewAI agents.
|
||||
|
||||
This module provides the base class and context for implementing A2A protocol
|
||||
extensions on the server side. Extensions allow agents to offer additional
|
||||
functionality beyond the core A2A specification.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
|
||||
from a2a.types import AgentExtension
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.server.context import ServerCallContext
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtensionContext:
|
||||
"""Context passed to extension hooks during request processing.
|
||||
|
||||
Provides access to request metadata, client extensions, and shared state
|
||||
that extensions can read from and write to.
|
||||
|
||||
Attributes:
|
||||
metadata: Request metadata dict, includes extension-namespaced keys.
|
||||
client_extensions: Set of extension URIs the client declared support for.
|
||||
state: Mutable dict for extensions to share data during request lifecycle.
|
||||
server_context: The underlying A2A server call context.
|
||||
"""
|
||||
|
||||
metadata: dict[str, Any]
|
||||
client_extensions: set[str]
|
||||
state: dict[str, Any] = field(default_factory=dict)
|
||||
server_context: ServerCallContext | None = None
|
||||
|
||||
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
|
||||
"""Get extension-specific metadata value.
|
||||
|
||||
Extension metadata uses namespaced keys in the format:
|
||||
"{extension_uri}/{key}"
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
key: The metadata key within the extension namespace.
|
||||
|
||||
Returns:
|
||||
The metadata value, or None if not present.
|
||||
"""
|
||||
full_key = f"{uri}/{key}"
|
||||
return self.metadata.get(full_key)
|
||||
|
||||
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
|
||||
"""Set extension-specific metadata value.
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
key: The metadata key within the extension namespace.
|
||||
value: The value to set.
|
||||
"""
|
||||
full_key = f"{uri}/{key}"
|
||||
self.metadata[full_key] = value
|
||||
|
||||
|
||||
class ServerExtension(ABC):
|
||||
"""Base class for A2A protocol server extensions.
|
||||
|
||||
Subclass this to create custom extensions that modify agent behavior
|
||||
when clients activate them. Extensions are identified by URI and can
|
||||
be marked as required.
|
||||
|
||||
Example:
|
||||
class SamplingExtension(ServerExtension):
|
||||
uri = "urn:crewai:ext:sampling/v1"
|
||||
required = True
|
||||
|
||||
def __init__(self, max_tokens: int = 4096):
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@property
|
||||
def params(self) -> dict[str, Any]:
|
||||
return {"max_tokens": self.max_tokens}
|
||||
|
||||
async def on_request(self, context: ExtensionContext) -> None:
|
||||
limit = context.get_extension_metadata(self.uri, "limit")
|
||||
if limit:
|
||||
context.state["token_limit"] = int(limit)
|
||||
|
||||
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
return result
|
||||
"""
|
||||
|
||||
uri: Annotated[str, "Extension URI identifier. Must be unique."]
|
||||
required: Annotated[bool, "Whether clients must support this extension."] = False
|
||||
description: Annotated[
|
||||
str | None, "Human-readable description of the extension."
|
||||
] = None
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: Any,
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> CoreSchema:
|
||||
"""Tell Pydantic how to validate ServerExtension instances."""
|
||||
return core_schema.is_instance_schema(cls)
|
||||
|
||||
@property
|
||||
def params(self) -> dict[str, Any] | None:
|
||||
"""Extension parameters to advertise in AgentCard.
|
||||
|
||||
Override this property to expose configuration that clients can read.
|
||||
|
||||
Returns:
|
||||
Dict of parameter names to values, or None.
|
||||
"""
|
||||
return None
|
||||
|
||||
def agent_extension(self) -> AgentExtension:
|
||||
"""Generate the AgentExtension object for the AgentCard.
|
||||
|
||||
Returns:
|
||||
AgentExtension with this extension's URI, required flag, and params.
|
||||
"""
|
||||
return AgentExtension(
|
||||
uri=self.uri,
|
||||
required=self.required if self.required else None,
|
||||
description=self.description,
|
||||
params=self.params,
|
||||
)
|
||||
|
||||
def is_active(self, context: ExtensionContext) -> bool:
|
||||
"""Check if this extension is active for the current request.
|
||||
|
||||
An extension is active if the client declared support for it.
|
||||
|
||||
Args:
|
||||
context: The extension context for the current request.
|
||||
|
||||
Returns:
|
||||
True if the client supports this extension.
|
||||
"""
|
||||
return self.uri in context.client_extensions
|
||||
|
||||
@abstractmethod
|
||||
async def on_request(self, context: ExtensionContext) -> None:
|
||||
"""Called before agent execution if extension is active.
|
||||
|
||||
Use this hook to:
|
||||
- Read extension-specific metadata from the request
|
||||
- Set up state for the execution
|
||||
- Modify execution parameters via context.state
|
||||
|
||||
Args:
|
||||
context: The extension context with request metadata and state.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
"""Called after agent execution if extension is active.
|
||||
|
||||
Use this hook to:
|
||||
- Modify or enhance the result
|
||||
- Add extension-specific metadata to the response
|
||||
- Clean up any resources
|
||||
|
||||
Args:
|
||||
context: The extension context with request metadata and state.
|
||||
result: The agent execution result.
|
||||
|
||||
Returns:
|
||||
The result, potentially modified.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ServerExtensionRegistry:
|
||||
"""Registry for managing server-side A2A protocol extensions.
|
||||
|
||||
Collects extensions and provides methods to generate AgentCapabilities
|
||||
and invoke extension hooks during request processing.
|
||||
"""
|
||||
|
||||
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
|
||||
"""Initialize the registry with optional extensions.
|
||||
|
||||
Args:
|
||||
extensions: Initial list of extensions to register.
|
||||
"""
|
||||
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
|
||||
self._by_uri: dict[str, ServerExtension] = {
|
||||
ext.uri: ext for ext in self._extensions
|
||||
}
|
||||
|
||||
def register(self, extension: ServerExtension) -> None:
|
||||
"""Register an extension.
|
||||
|
||||
Args:
|
||||
extension: The extension to register.
|
||||
|
||||
Raises:
|
||||
ValueError: If an extension with the same URI is already registered.
|
||||
"""
|
||||
if extension.uri in self._by_uri:
|
||||
raise ValueError(f"Extension already registered: {extension.uri}")
|
||||
self._extensions.append(extension)
|
||||
self._by_uri[extension.uri] = extension
|
||||
|
||||
def get_agent_extensions(self) -> list[AgentExtension]:
|
||||
"""Get AgentExtension objects for all registered extensions.
|
||||
|
||||
Returns:
|
||||
List of AgentExtension objects for the AgentCard.
|
||||
"""
|
||||
return [ext.agent_extension() for ext in self._extensions]
|
||||
|
||||
def get_extension(self, uri: str) -> ServerExtension | None:
|
||||
"""Get an extension by URI.
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
|
||||
Returns:
|
||||
The extension, or None if not found.
|
||||
"""
|
||||
return self._by_uri.get(uri)
|
||||
|
||||
@staticmethod
|
||||
def create_context(
|
||||
metadata: dict[str, Any],
|
||||
client_extensions: set[str],
|
||||
server_context: ServerCallContext | None = None,
|
||||
) -> ExtensionContext:
|
||||
"""Create an ExtensionContext for a request.
|
||||
|
||||
Args:
|
||||
metadata: Request metadata dict.
|
||||
client_extensions: Set of extension URIs from client.
|
||||
server_context: Optional server call context.
|
||||
|
||||
Returns:
|
||||
ExtensionContext for use in hooks.
|
||||
"""
|
||||
return ExtensionContext(
|
||||
metadata=metadata,
|
||||
client_extensions=client_extensions,
|
||||
server_context=server_context,
|
||||
)
|
||||
|
||||
async def invoke_on_request(self, context: ExtensionContext) -> None:
|
||||
"""Invoke on_request hooks for all active extensions.
|
||||
|
||||
Tracks activated extensions and isolates errors from individual hooks.
|
||||
|
||||
Args:
|
||||
context: The extension context for the request.
|
||||
"""
|
||||
for extension in self._extensions:
|
||||
if extension.is_active(context):
|
||||
try:
|
||||
await extension.on_request(context)
|
||||
if context.server_context is not None:
|
||||
context.server_context.activated_extensions.add(extension.uri)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Extension on_request hook failed",
|
||||
extra={"extension": extension.uri},
|
||||
)
|
||||
|
||||
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
"""Invoke on_response hooks for all active extensions.
|
||||
|
||||
Isolates errors from individual hooks to prevent one failing extension
|
||||
from breaking the entire response.
|
||||
|
||||
Args:
|
||||
context: The extension context for the request.
|
||||
result: The agent execution result.
|
||||
|
||||
Returns:
|
||||
The result after all extensions have processed it.
|
||||
"""
|
||||
processed = result
|
||||
for extension in self._extensions:
|
||||
if extension.is_active(context):
|
||||
try:
|
||||
processed = await extension.on_response(context, processed)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Extension on_response hook failed",
|
||||
extra={"extension": extension.uri},
|
||||
)
|
||||
return processed
|
||||
@@ -51,13 +51,6 @@ ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
|
||||
}
|
||||
)
|
||||
|
||||
PENDING_STATES: frozenset[TaskState] = frozenset(
|
||||
{
|
||||
TaskState.submitted,
|
||||
TaskState.working,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TaskStateResult(TypedDict):
|
||||
"""Result dictionary from processing A2A task state."""
|
||||
@@ -279,9 +272,6 @@ def process_task_state(
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state in PENDING_STATES:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -38,18 +38,3 @@ You MUST now:
|
||||
DO NOT send another request - the task is already done.
|
||||
</REMOTE_AGENT_STATUS>
|
||||
"""
|
||||
|
||||
REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """
|
||||
<REMOTE_AGENT_STATUS>
|
||||
STATUS: RESPONSE_RECEIVED
|
||||
The remote agent has responded. Their response is in the conversation history above.
|
||||
|
||||
You MUST now:
|
||||
1. Set is_a2a=false (the remote task is complete and cannot receive more messages)
|
||||
2. Provide YOUR OWN response to the original task based on the information received
|
||||
|
||||
IMPORTANT: Your response should be addressed to the USER who gave you the original task.
|
||||
Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that...").
|
||||
Do NOT address the remote agent directly or use "you" to refer to them.
|
||||
</REMOTE_AGENT_STATUS>
|
||||
"""
|
||||
|
||||
@@ -36,17 +36,6 @@ except ImportError:
|
||||
|
||||
|
||||
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
|
||||
ProtocolVersion = Literal[
|
||||
"0.2.0",
|
||||
"0.2.1",
|
||||
"0.2.2",
|
||||
"0.2.3",
|
||||
"0.2.4",
|
||||
"0.2.5",
|
||||
"0.2.6",
|
||||
"0.3.0",
|
||||
"0.4.0",
|
||||
]
|
||||
|
||||
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
|
||||
|
||||
|
||||
@@ -2,28 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
class CommonParams(NamedTuple):
|
||||
"""Common parameters shared across all update handlers.
|
||||
|
||||
Groups the frequently-passed parameters to reduce duplication.
|
||||
"""
|
||||
|
||||
turn_number: int
|
||||
is_multiturn: bool
|
||||
agent_role: str | None
|
||||
endpoint: str
|
||||
a2a_agent_name: str | None
|
||||
context_id: str | None
|
||||
from_task: Any
|
||||
from_agent: Any
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import Client
|
||||
from a2a.types import AgentCard, Message, Task
|
||||
@@ -79,8 +63,8 @@ class PushNotificationResultStore(Protocol):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: Any,
|
||||
_handler: GetCoreSchemaHandler,
|
||||
source_type: Any,
|
||||
handler: GetCoreSchemaHandler,
|
||||
) -> CoreSchema:
|
||||
return core_schema.any_schema()
|
||||
|
||||
@@ -146,31 +130,3 @@ class UpdateHandler(Protocol):
|
||||
Result dictionary with status, result/error, and history.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams:
|
||||
"""Extract common parameters from handler kwargs.
|
||||
|
||||
Args:
|
||||
kwargs: Handler kwargs dict.
|
||||
|
||||
Returns:
|
||||
CommonParams with extracted values.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint is not provided.
|
||||
"""
|
||||
endpoint = kwargs.get("endpoint")
|
||||
if endpoint is None:
|
||||
raise ValueError("endpoint is required for update handlers")
|
||||
|
||||
return CommonParams(
|
||||
turn_number=kwargs.get("turn_number", 0),
|
||||
is_multiturn=kwargs.get("is_multiturn", False),
|
||||
agent_role=kwargs.get("agent_role"),
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=kwargs.get("a2a_agent_name"),
|
||||
context_id=kwargs.get("context_id"),
|
||||
from_task=kwargs.get("from_task"),
|
||||
from_agent=kwargs.get("from_agent"),
|
||||
)
|
||||
|
||||
@@ -94,7 +94,7 @@ async def _poll_task_until_complete(
|
||||
A2APollingStatusEvent(
|
||||
task_id=task_id,
|
||||
context_id=effective_context_id,
|
||||
state=str(task.status.state.value),
|
||||
state=str(task.status.state.value) if task.status.state else "unknown",
|
||||
elapsed_seconds=elapsed,
|
||||
poll_count=poll_count,
|
||||
endpoint=endpoint,
|
||||
@@ -325,7 +325,7 @@ class PollingHandler:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
|
||||
@@ -2,30 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from a2a.types import PushNotificationAuthenticationInfo
|
||||
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
from crewai.a2a.updates.base import PushNotificationResultStore
|
||||
from crewai.a2a.updates.push_notifications.signature import WebhookSignatureConfig
|
||||
|
||||
|
||||
def _coerce_signature(
|
||||
value: str | WebhookSignatureConfig | None,
|
||||
) -> WebhookSignatureConfig | None:
|
||||
"""Convert string secret to WebhookSignatureConfig."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return WebhookSignatureConfig.hmac_sha256(secret=value)
|
||||
return value
|
||||
|
||||
|
||||
SignatureInput = Annotated[
|
||||
WebhookSignatureConfig | None,
|
||||
BeforeValidator(_coerce_signature),
|
||||
]
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
@@ -39,8 +19,6 @@ class PushNotificationConfig(BaseModel):
|
||||
timeout: Max seconds to wait for task completion.
|
||||
interval: Seconds between result polling attempts.
|
||||
result_store: Store for receiving push notification results.
|
||||
signature: HMAC signature config. Pass a string (secret) for defaults,
|
||||
or WebhookSignatureConfig for custom settings.
|
||||
"""
|
||||
|
||||
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
|
||||
@@ -58,8 +36,3 @@ class PushNotificationConfig(BaseModel):
|
||||
result_store: PushNotificationResultStore | None = Field(
|
||||
default=None, description="Result store for push notification handling"
|
||||
)
|
||||
signature: SignatureInput = Field(
|
||||
default=None,
|
||||
description="HMAC signature config. Pass a string (secret) for simple usage, "
|
||||
"or WebhookSignatureConfig for custom headers/tolerance.",
|
||||
)
|
||||
|
||||
@@ -24,10 +24,8 @@ from crewai.a2a.task_helpers import (
|
||||
send_message_and_get_task_id,
|
||||
)
|
||||
from crewai.a2a.updates.base import (
|
||||
CommonParams,
|
||||
PushNotificationHandlerKwargs,
|
||||
PushNotificationResultStore,
|
||||
extract_common_params,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
@@ -41,81 +39,10 @@ from crewai.events.types.a2a_events import (
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _handle_push_error(
|
||||
error: Exception,
|
||||
error_msg: str,
|
||||
error_type: str,
|
||||
new_messages: list[Message],
|
||||
agent_branch: Any | None,
|
||||
params: CommonParams,
|
||||
task_id: str | None,
|
||||
status_code: int | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Handle push notification errors with consistent event emission.
|
||||
|
||||
Args:
|
||||
error: The exception that occurred.
|
||||
error_msg: Formatted error message for the result.
|
||||
error_type: Type of error for the event.
|
||||
new_messages: List to append error message to.
|
||||
agent_branch: Agent tree branch for events.
|
||||
params: Common handler parameters.
|
||||
task_id: A2A task ID.
|
||||
status_code: HTTP status code if applicable.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with failed status.
|
||||
"""
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(error),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
|
||||
async def _wait_for_push_result(
|
||||
task_id: str,
|
||||
result_store: PushNotificationResultStore,
|
||||
@@ -199,8 +126,15 @@ class PushNotificationHandler:
|
||||
polling_timeout = kwargs.get("polling_timeout", 300.0)
|
||||
polling_interval = kwargs.get("polling_interval", 2.0)
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
turn_number = kwargs.get("turn_number", 0)
|
||||
is_multiturn = kwargs.get("is_multiturn", False)
|
||||
agent_role = kwargs.get("agent_role")
|
||||
context_id = kwargs.get("context_id")
|
||||
task_id = kwargs.get("task_id")
|
||||
params = extract_common_params(kwargs)
|
||||
endpoint = kwargs.get("endpoint")
|
||||
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||
from_task = kwargs.get("from_task")
|
||||
from_agent = kwargs.get("from_agent")
|
||||
|
||||
if config is None:
|
||||
error_msg = (
|
||||
@@ -209,15 +143,15 @@ class PushNotificationHandler:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
endpoint=endpoint or "",
|
||||
error=error_msg,
|
||||
error_type="configuration_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=params.context_id,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
@@ -233,15 +167,15 @@ class PushNotificationHandler:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
endpoint=endpoint or "",
|
||||
error=error_msg,
|
||||
error_type="configuration_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=params.context_id,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
@@ -255,14 +189,14 @@ class PushNotificationHandler:
|
||||
event_stream=client.send_message(message),
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
context_id=params.context_id,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
if not isinstance(result_or_task_id, str):
|
||||
@@ -274,12 +208,12 @@ class PushNotificationHandler:
|
||||
agent_branch,
|
||||
A2APushNotificationRegisteredEvent(
|
||||
task_id=task_id,
|
||||
context_id=params.context_id,
|
||||
context_id=context_id,
|
||||
callback_url=str(config.url),
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -295,11 +229,11 @@ class PushNotificationHandler:
|
||||
timeout=polling_timeout,
|
||||
poll_interval=polling_interval,
|
||||
agent_branch=agent_branch,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
context_id=params.context_id,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
context_id=context_id,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
)
|
||||
|
||||
if final_task is None:
|
||||
@@ -313,13 +247,13 @@ class PushNotificationHandler:
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
@@ -331,24 +265,98 @@ class PushNotificationHandler:
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
return _handle_push_error(
|
||||
error=e,
|
||||
error_msg=f"HTTP Error {e.status_code}: {e!s}",
|
||||
error_type="http_error",
|
||||
new_messages=new_messages,
|
||||
agent_branch=agent_branch,
|
||||
params=params,
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
status_code=e.status_code,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return _handle_push_error(
|
||||
error=e,
|
||||
error_msg=f"Unexpected error during push notification: {e!s}",
|
||||
error_type="unexpected_error",
|
||||
new_messages=new_messages,
|
||||
agent_branch=agent_branch,
|
||||
params=params,
|
||||
error_msg = f"Unexpected error during push notification: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
"""Webhook signature configuration for push notifications."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import secrets
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class WebhookSignatureMode(str, Enum):
|
||||
"""Signature mode for webhook push notifications."""
|
||||
|
||||
NONE = "none"
|
||||
HMAC_SHA256 = "hmac_sha256"
|
||||
|
||||
|
||||
class WebhookSignatureConfig(BaseModel):
|
||||
"""Configuration for webhook signature verification.
|
||||
|
||||
Provides cryptographic integrity verification and replay attack protection
|
||||
for A2A push notifications.
|
||||
|
||||
Attributes:
|
||||
mode: Signature mode (none or hmac_sha256).
|
||||
secret: Shared secret for HMAC computation (required for hmac_sha256 mode).
|
||||
timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection.
|
||||
header_name: HTTP header name for the signature.
|
||||
timestamp_header_name: HTTP header name for the timestamp.
|
||||
"""
|
||||
|
||||
mode: WebhookSignatureMode = Field(
|
||||
default=WebhookSignatureMode.NONE,
|
||||
description="Signature verification mode",
|
||||
)
|
||||
secret: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Shared secret for HMAC computation",
|
||||
)
|
||||
timestamp_tolerance_seconds: int = Field(
|
||||
default=300,
|
||||
ge=0,
|
||||
description="Max allowed timestamp age in seconds (5 min default)",
|
||||
)
|
||||
header_name: str = Field(
|
||||
default="X-A2A-Signature",
|
||||
description="HTTP header name for the signature",
|
||||
)
|
||||
timestamp_header_name: str = Field(
|
||||
default="X-A2A-Signature-Timestamp",
|
||||
description="HTTP header name for the timestamp",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_secret(cls, length: int = 32) -> str:
|
||||
"""Generate a cryptographically secure random secret.
|
||||
|
||||
Args:
|
||||
length: Number of random bytes to generate (default 32).
|
||||
|
||||
Returns:
|
||||
URL-safe base64-encoded secret string.
|
||||
"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
@classmethod
|
||||
def hmac_sha256(
|
||||
cls,
|
||||
secret: str | SecretStr,
|
||||
timestamp_tolerance_seconds: int = 300,
|
||||
) -> WebhookSignatureConfig:
|
||||
"""Create an HMAC-SHA256 signature configuration.
|
||||
|
||||
Args:
|
||||
secret: Shared secret for HMAC computation.
|
||||
timestamp_tolerance_seconds: Max allowed timestamp age in seconds.
|
||||
|
||||
Returns:
|
||||
Configured WebhookSignatureConfig for HMAC-SHA256.
|
||||
"""
|
||||
if isinstance(secret, str):
|
||||
secret = SecretStr(secret)
|
||||
return cls(
|
||||
mode=WebhookSignatureMode.HMAC_SHA256,
|
||||
secret=secret,
|
||||
timestamp_tolerance_seconds=timestamp_tolerance_seconds,
|
||||
)
|
||||
@@ -2,9 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Final
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
@@ -14,10 +11,7 @@ from a2a.types import (
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskIdParams,
|
||||
TaskQueryParams,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
@@ -30,10 +24,7 @@ from crewai.a2a.task_helpers import (
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
)
|
||||
from crewai.a2a.updates.base import StreamingHandlerKwargs, extract_common_params
|
||||
from crewai.a2a.updates.streaming.params import (
|
||||
process_status_update,
|
||||
)
|
||||
from crewai.a2a.updates.base import StreamingHandlerKwargs
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AArtifactReceivedEvent,
|
||||
@@ -44,194 +35,9 @@ from crewai.events.types.a2a_events import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
|
||||
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
|
||||
|
||||
|
||||
class StreamingHandler:
|
||||
"""SSE streaming-based update handler."""
|
||||
|
||||
@staticmethod
|
||||
async def _try_recover_from_interruption( # type: ignore[misc]
|
||||
client: Client,
|
||||
task_id: str,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
result_parts: list[str],
|
||||
**kwargs: Unpack[StreamingHandlerKwargs],
|
||||
) -> TaskStateResult | None:
|
||||
"""Attempt to recover from a stream interruption by checking task state.
|
||||
|
||||
If the task completed while we were disconnected, returns the result.
|
||||
If the task is still running, attempts to resubscribe and continue.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
task_id: The task ID to recover.
|
||||
new_messages: List of collected messages.
|
||||
agent_card: The agent card.
|
||||
result_parts: Accumulated result text parts.
|
||||
**kwargs: Handler parameters.
|
||||
|
||||
Returns:
|
||||
TaskStateResult if recovery succeeded (task finished or resubscribe worked).
|
||||
None if recovery not possible (caller should handle failure).
|
||||
|
||||
Note:
|
||||
When None is returned, recovery failed and the original exception should
|
||||
be handled by the caller. All recovery attempts are logged.
|
||||
"""
|
||||
params = extract_common_params(kwargs) # type: ignore[arg-type]
|
||||
|
||||
try:
|
||||
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
|
||||
|
||||
if a2a_task.status.state in TERMINAL_STATES:
|
||||
logger.info(
|
||||
"Task completed during stream interruption",
|
||||
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
|
||||
)
|
||||
return process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
)
|
||||
|
||||
if a2a_task.status.state in ACTIONABLE_STATES:
|
||||
logger.info(
|
||||
"Task in actionable state during stream interruption",
|
||||
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
|
||||
)
|
||||
return process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Task still running, attempting resubscribe",
|
||||
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
|
||||
)
|
||||
|
||||
for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS):
|
||||
try:
|
||||
backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt)
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(backoff)
|
||||
|
||||
event_stream = client.resubscribe(TaskIdParams(id=task_id))
|
||||
|
||||
async for event in event_stream:
|
||||
if isinstance(event, tuple):
|
||||
resubscribed_task, update = event
|
||||
|
||||
is_final_update = (
|
||||
process_status_update(update, result_parts)
|
||||
if isinstance(update, TaskStatusUpdateEvent)
|
||||
else False
|
||||
)
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
if (
|
||||
is_final_update
|
||||
or resubscribed_task.status.state
|
||||
in TERMINAL_STATES | ACTIONABLE_STATES
|
||||
):
|
||||
return process_task_state(
|
||||
a2a_task=resubscribed_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=is_final_update,
|
||||
)
|
||||
|
||||
elif isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in event.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
final_task = await client.get_task(TaskQueryParams(id=task_id))
|
||||
return process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
)
|
||||
|
||||
except Exception as resubscribe_error: # noqa: PERF203
|
||||
logger.warning(
|
||||
"Resubscribe attempt failed",
|
||||
extra={
|
||||
"task_id": task_id,
|
||||
"attempt": attempt + 1,
|
||||
"max_attempts": MAX_RESUBSCRIBE_ATTEMPTS,
|
||||
"error": str(resubscribe_error),
|
||||
},
|
||||
)
|
||||
if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to recover from stream interruption due to unexpected error",
|
||||
extra={
|
||||
"task_id": task_id,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.warning(
|
||||
"Recovery exhausted all resubscribe attempts without success",
|
||||
extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS},
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
@@ -252,40 +58,42 @@ class StreamingHandler:
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
"""
|
||||
context_id = kwargs.get("context_id")
|
||||
task_id = kwargs.get("task_id")
|
||||
turn_number = kwargs.get("turn_number", 0)
|
||||
is_multiturn = kwargs.get("is_multiturn", False)
|
||||
agent_role = kwargs.get("agent_role")
|
||||
endpoint = kwargs.get("endpoint")
|
||||
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||
from_task = kwargs.get("from_task")
|
||||
from_agent = kwargs.get("from_agent")
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
params = extract_common_params(kwargs)
|
||||
|
||||
result_parts: list[str] = []
|
||||
final_result: TaskStateResult | None = None
|
||||
event_stream = client.send_message(message)
|
||||
chunk_index = 0
|
||||
current_task_id: str | None = task_id
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AStreamingStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=params.context_id,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
context_id=context_id,
|
||||
endpoint=endpoint or "",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, tuple):
|
||||
a2a_task, _ = event
|
||||
current_task_id = a2a_task.id
|
||||
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
message_context_id = event.context_id or params.context_id
|
||||
message_context_id = event.context_id or context_id
|
||||
for part in event.parts:
|
||||
if part.root.kind == "text":
|
||||
text = part.root.text
|
||||
@@ -297,12 +105,12 @@ class StreamingHandler:
|
||||
context_id=message_context_id,
|
||||
chunk=text,
|
||||
chunk_index=chunk_index,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
chunk_index += 1
|
||||
@@ -320,12 +128,12 @@ class StreamingHandler:
|
||||
artifact_size = None
|
||||
if artifact.parts:
|
||||
artifact_size = sum(
|
||||
len(p.root.text.encode())
|
||||
len(p.root.text.encode("utf-8"))
|
||||
if p.root.kind == "text"
|
||||
else len(getattr(p.root, "data", b""))
|
||||
for p in artifact.parts
|
||||
)
|
||||
effective_context_id = a2a_task.context_id or params.context_id
|
||||
effective_context_id = a2a_task.context_id or context_id
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AArtifactReceivedEvent(
|
||||
@@ -339,21 +147,29 @@ class StreamingHandler:
|
||||
size_bytes=artifact_size,
|
||||
append=update.append or False,
|
||||
last_chunk=update.last_chunk or False,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
context_id=effective_context_id,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
is_final_update = (
|
||||
process_status_update(update, result_parts)
|
||||
if isinstance(update, TaskStatusUpdateEvent)
|
||||
else False
|
||||
)
|
||||
is_final_update = False
|
||||
if isinstance(update, TaskStatusUpdateEvent):
|
||||
is_final_update = update.final
|
||||
if (
|
||||
update.status
|
||||
and update.status.message
|
||||
and update.status.message.parts
|
||||
):
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in update.status.message.parts
|
||||
if part.root.kind == "text" and part.root.text
|
||||
)
|
||||
|
||||
if (
|
||||
not is_final_update
|
||||
@@ -366,68 +182,27 @@ class StreamingHandler:
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
is_final=is_final_update,
|
||||
)
|
||||
if final_result:
|
||||
break
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
if current_task_id:
|
||||
logger.info(
|
||||
"Stream interrupted with HTTP error, attempting recovery",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error": str(e),
|
||||
"status_code": e.status_code,
|
||||
},
|
||||
)
|
||||
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
|
||||
recovered_result = (
|
||||
await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**recovery_kwargs,
|
||||
)
|
||||
)
|
||||
if recovered_result:
|
||||
logger.info(
|
||||
"Successfully recovered task after HTTP error",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status": str(recovered_result.get("status")),
|
||||
},
|
||||
)
|
||||
return recovered_result
|
||||
|
||||
logger.warning(
|
||||
"Failed to recover from HTTP error, returning failure",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status_code": e.status_code,
|
||||
"original_error": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
error_type = "http_error"
|
||||
status_code = e.status_code
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
@@ -435,118 +210,32 @@ class StreamingHandler:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
|
||||
error_type = type(e).__name__.lower()
|
||||
if current_task_id:
|
||||
logger.info(
|
||||
f"Stream interrupted with {error_type}, attempting recovery",
|
||||
extra={"task_id": current_task_id, "error": str(e)},
|
||||
)
|
||||
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
|
||||
recovered_result = (
|
||||
await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**recovery_kwargs,
|
||||
)
|
||||
)
|
||||
if recovered_result:
|
||||
logger.info(
|
||||
f"Successfully recovered task after {error_type}",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status": str(recovered_result.get("status")),
|
||||
},
|
||||
)
|
||||
return recovered_result
|
||||
|
||||
logger.warning(
|
||||
f"Failed to recover from {error_type}, returning failure",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error_type": error_type,
|
||||
"original_error": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
error_msg = f"Connection error during streaming: {e!s}"
|
||||
status_code = None
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
@@ -556,23 +245,13 @@ class StreamingHandler:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Unexpected error during streaming",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error_type": type(e).__name__,
|
||||
"endpoint": params.endpoint,
|
||||
},
|
||||
)
|
||||
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
|
||||
error_type = "unexpected_error"
|
||||
status_code = None
|
||||
error_msg = f"Unexpected error during streaming: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
@@ -580,32 +259,31 @@ class StreamingHandler:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
@@ -623,15 +301,15 @@ class StreamingHandler:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
endpoint=endpoint or "",
|
||||
error=str(close_error),
|
||||
error_type="stream_close_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="stream_close",
|
||||
context_id=params.context_id,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Common parameter extraction for streaming handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
|
||||
|
||||
def process_status_update(
|
||||
update: TaskStatusUpdateEvent,
|
||||
result_parts: list[str],
|
||||
) -> bool:
|
||||
"""Process a status update event and extract text parts.
|
||||
|
||||
Args:
|
||||
update: The status update event.
|
||||
result_parts: List to append text parts to (modified in place).
|
||||
|
||||
Returns:
|
||||
True if this is a final update, False otherwise.
|
||||
"""
|
||||
is_final = update.final
|
||||
if update.status and update.status.message and update.status.message.parts:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in update.status.message.parts
|
||||
if part.root.kind == "text" and part.root.text
|
||||
)
|
||||
return is_final
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import MutableMapping
|
||||
from functools import lru_cache
|
||||
import ssl
|
||||
import time
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -16,7 +15,7 @@ from aiocache import cached # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
import httpx
|
||||
|
||||
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai.a2a.auth.utils import (
|
||||
_auth_store,
|
||||
configure_auth_client,
|
||||
@@ -33,51 +32,11 @@ from crewai.events.types.a2a_events import (
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.auth.client_schemes import ClientAuthScheme
|
||||
from crewai.a2a.auth.schemas import AuthScheme
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
|
||||
"""Get TLS verify parameter from auth scheme.
|
||||
|
||||
Args:
|
||||
auth: Optional authentication scheme with TLS config.
|
||||
|
||||
Returns:
|
||||
SSL context, CA cert path, True for default verification,
|
||||
or False if verification disabled.
|
||||
"""
|
||||
if auth and auth.tls:
|
||||
return auth.tls.get_httpx_ssl_context()
|
||||
return True
|
||||
|
||||
|
||||
async def _prepare_auth_headers(
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]:
|
||||
"""Prepare authentication headers and TLS verification settings.
|
||||
|
||||
Args:
|
||||
auth: Optional authentication scheme.
|
||||
timeout: Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
Tuple of (headers dict, TLS verify setting).
|
||||
"""
|
||||
headers: MutableMapping[str, str] = {}
|
||||
verify = _get_tls_verify(auth)
|
||||
if auth:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout, verify=verify
|
||||
) as temp_auth_client:
|
||||
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_auth_client)
|
||||
headers = await auth.apply_auth(temp_auth_client, {})
|
||||
return headers, verify
|
||||
|
||||
|
||||
def _get_server_config(agent: Agent) -> A2AServerConfig | None:
|
||||
"""Get A2AServerConfig from an agent's a2a configuration.
|
||||
|
||||
@@ -100,7 +59,7 @@ def _get_server_config(agent: Agent) -> A2AServerConfig | None:
|
||||
|
||||
def fetch_agent_card(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
auth: AuthScheme | None = None,
|
||||
timeout: int = 30,
|
||||
use_cache: bool = True,
|
||||
cache_ttl: int = 300,
|
||||
@@ -109,7 +68,7 @@ def fetch_agent_card(
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
auth: Optional AuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
use_cache: Whether to use caching (default True).
|
||||
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).
|
||||
@@ -131,10 +90,10 @@ def fetch_agent_card(
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
|
||||
auth_hash = hash((type(auth).__name__, auth_data))
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", "")
|
||||
_auth_store.set(auth_hash, auth)
|
||||
auth_hash = 0
|
||||
_auth_store[auth_hash] = auth
|
||||
ttl_hash = int(time.time() // cache_ttl)
|
||||
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
|
||||
|
||||
@@ -150,7 +109,7 @@ def fetch_agent_card(
|
||||
|
||||
async def afetch_agent_card(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
auth: AuthScheme | None = None,
|
||||
timeout: int = 30,
|
||||
use_cache: bool = True,
|
||||
) -> AgentCard:
|
||||
@@ -160,7 +119,7 @@ async def afetch_agent_card(
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
auth: Optional AuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
use_cache: Whether to use caching (default True).
|
||||
|
||||
@@ -181,10 +140,10 @@ async def afetch_agent_card(
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
|
||||
auth_hash = hash((type(auth).__name__, auth_data))
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", "")
|
||||
_auth_store.set(auth_hash, auth)
|
||||
auth_hash = 0
|
||||
_auth_store[auth_hash] = auth
|
||||
agent_card: AgentCard = await _afetch_agent_card_cached(
|
||||
endpoint, auth_hash, timeout
|
||||
)
|
||||
@@ -196,7 +155,7 @@ async def afetch_agent_card(
|
||||
@lru_cache()
|
||||
def _fetch_agent_card_cached(
|
||||
endpoint: str,
|
||||
auth_hash: str,
|
||||
auth_hash: int,
|
||||
timeout: int,
|
||||
_ttl_hash: int,
|
||||
) -> AgentCard:
|
||||
@@ -216,7 +175,7 @@ def _fetch_agent_card_cached(
|
||||
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
|
||||
async def _afetch_agent_card_cached(
|
||||
endpoint: str,
|
||||
auth_hash: str,
|
||||
auth_hash: int,
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Cached async implementation of AgentCard fetching."""
|
||||
@@ -226,7 +185,7 @@ async def _afetch_agent_card_cached(
|
||||
|
||||
async def _afetch_agent_card_impl(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Internal async implementation of AgentCard fetching."""
|
||||
@@ -238,17 +197,16 @@ async def _afetch_agent_card_impl(
|
||||
else:
|
||||
url_parts = endpoint.split("/", 3)
|
||||
base_url = f"{url_parts[0]}//{url_parts[2]}"
|
||||
agent_card_path = (
|
||||
f"/{url_parts[3]}"
|
||||
if len(url_parts) > 3 and url_parts[3]
|
||||
else "/.well-known/agent-card.json"
|
||||
)
|
||||
agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/"
|
||||
|
||||
headers, verify = await _prepare_auth_headers(auth, timeout)
|
||||
headers: MutableMapping[str, str] = {}
|
||||
if auth:
|
||||
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
|
||||
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_auth_client)
|
||||
headers = await auth.apply_auth(temp_auth_client, {})
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout, headers=headers, verify=verify
|
||||
) as temp_client:
|
||||
async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client:
|
||||
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_client)
|
||||
|
||||
@@ -476,7 +434,6 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
|
||||
"""Generate an A2A AgentCard from an Agent instance.
|
||||
|
||||
Uses A2AServerConfig values when available, falling back to agent properties.
|
||||
If signing_config is provided, the card will be signed with JWS.
|
||||
|
||||
Args:
|
||||
agent: The Agent instance to generate a card for.
|
||||
@@ -485,8 +442,6 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
|
||||
Returns:
|
||||
AgentCard describing the agent's capabilities.
|
||||
"""
|
||||
from crewai.a2a.utils.agent_card_signing import sign_agent_card
|
||||
|
||||
server_config = _get_server_config(agent) or A2AServerConfig()
|
||||
|
||||
name = server_config.name or agent.role
|
||||
@@ -517,31 +472,15 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
|
||||
)
|
||||
)
|
||||
|
||||
capabilities = server_config.capabilities
|
||||
if server_config.server_extensions:
|
||||
from crewai.a2a.extensions.server import ServerExtensionRegistry
|
||||
|
||||
registry = ServerExtensionRegistry(server_config.server_extensions)
|
||||
ext_list = registry.get_agent_extensions()
|
||||
|
||||
existing_exts = list(capabilities.extensions) if capabilities.extensions else []
|
||||
existing_uris = {e.uri for e in existing_exts}
|
||||
for ext in ext_list:
|
||||
if ext.uri not in existing_uris:
|
||||
existing_exts.append(ext)
|
||||
|
||||
capabilities = capabilities.model_copy(update={"extensions": existing_exts})
|
||||
|
||||
card = AgentCard(
|
||||
return AgentCard(
|
||||
name=name,
|
||||
description=description,
|
||||
url=server_config.url or url,
|
||||
version=server_config.version,
|
||||
capabilities=capabilities,
|
||||
capabilities=server_config.capabilities,
|
||||
default_input_modes=server_config.default_input_modes,
|
||||
default_output_modes=server_config.default_output_modes,
|
||||
skills=skills,
|
||||
preferred_transport=server_config.transport.preferred,
|
||||
protocol_version=server_config.protocol_version,
|
||||
provider=server_config.provider,
|
||||
documentation_url=server_config.documentation_url,
|
||||
@@ -550,21 +489,9 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
|
||||
security=server_config.security,
|
||||
security_schemes=server_config.security_schemes,
|
||||
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
|
||||
signatures=server_config.signatures,
|
||||
)
|
||||
|
||||
if server_config.signing_config:
|
||||
signature = sign_agent_card(
|
||||
card,
|
||||
private_key=server_config.signing_config.get_private_key(),
|
||||
key_id=server_config.signing_config.key_id,
|
||||
algorithm=server_config.signing_config.algorithm,
|
||||
)
|
||||
card = card.model_copy(update={"signatures": [signature]})
|
||||
elif server_config.signatures:
|
||||
card = card.model_copy(update={"signatures": server_config.signatures})
|
||||
|
||||
return card
|
||||
|
||||
|
||||
def inject_a2a_server_methods(agent: Agent) -> None:
|
||||
"""Inject A2A server methods onto an Agent instance.
|
||||
|
||||
@@ -1,236 +0,0 @@
|
||||
"""AgentCard JWS signing utilities.
|
||||
|
||||
This module provides functions for signing and verifying AgentCards using
|
||||
JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients
|
||||
to verify the authenticity and integrity of agent card information.
|
||||
|
||||
Example:
|
||||
>>> from crewai.a2a.utils.agent_card_signing import sign_agent_card
|
||||
>>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1")
|
||||
>>> card_with_sig = card.model_copy(update={"signatures": [signature]})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from a2a.types import AgentCard, AgentCardSignature
|
||||
import jwt
|
||||
from pydantic import SecretStr
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SigningAlgorithm = Literal[
|
||||
"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512"
|
||||
]
|
||||
|
||||
|
||||
def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes:
|
||||
"""Normalize private key to bytes format.
|
||||
|
||||
Args:
|
||||
private_key: PEM-encoded private key as string, bytes, or SecretStr.
|
||||
|
||||
Returns:
|
||||
Private key as bytes.
|
||||
"""
|
||||
if isinstance(private_key, SecretStr):
|
||||
private_key = private_key.get_secret_value()
|
||||
if isinstance(private_key, str):
|
||||
private_key = private_key.encode()
|
||||
return private_key
|
||||
|
||||
|
||||
def _serialize_agent_card(agent_card: AgentCard) -> str:
|
||||
"""Serialize AgentCard to canonical JSON for signing.
|
||||
|
||||
Excludes the signatures field to avoid circular reference during signing.
|
||||
Uses sorted keys and compact separators for deterministic output.
|
||||
|
||||
Args:
|
||||
agent_card: The AgentCard to serialize.
|
||||
|
||||
Returns:
|
||||
Canonical JSON string representation.
|
||||
"""
|
||||
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
|
||||
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _base64url_encode(data: bytes | str) -> str:
|
||||
"""Encode data to URL-safe base64 without padding.
|
||||
|
||||
Args:
|
||||
data: Data to encode.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded string without padding.
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode()
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def sign_agent_card(
|
||||
agent_card: AgentCard,
|
||||
private_key: str | bytes | SecretStr,
|
||||
key_id: str | None = None,
|
||||
algorithm: SigningAlgorithm = "RS256",
|
||||
) -> AgentCardSignature:
|
||||
"""Sign an AgentCard using JWS (RFC 7515).
|
||||
|
||||
Creates a detached JWS signature for the AgentCard. The signature covers
|
||||
all fields except the signatures field itself.
|
||||
|
||||
Args:
|
||||
agent_card: The AgentCard to sign.
|
||||
private_key: PEM-encoded private key (RSA, EC, or RSA-PSS).
|
||||
key_id: Optional key identifier for the JWS header (kid claim).
|
||||
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
|
||||
|
||||
Returns:
|
||||
AgentCardSignature with protected header and signature.
|
||||
|
||||
Raises:
|
||||
jwt.exceptions.InvalidKeyError: If the private key is invalid.
|
||||
ValueError: If the algorithm is not supported for the key type.
|
||||
|
||||
Example:
|
||||
>>> signature = sign_agent_card(
|
||||
... agent_card,
|
||||
... private_key_pem="-----BEGIN PRIVATE KEY-----...",
|
||||
... key_id="my-key-id",
|
||||
... )
|
||||
"""
|
||||
key_bytes = _normalize_private_key(private_key)
|
||||
payload = _serialize_agent_card(agent_card)
|
||||
|
||||
protected_header: dict[str, Any] = {"typ": "JWS"}
|
||||
if key_id:
|
||||
protected_header["kid"] = key_id
|
||||
|
||||
jws_token = jwt.api_jws.encode(
|
||||
payload.encode(),
|
||||
key_bytes,
|
||||
algorithm=algorithm,
|
||||
headers=protected_header,
|
||||
)
|
||||
|
||||
parts = jws_token.split(".")
|
||||
protected_b64 = parts[0]
|
||||
signature_b64 = parts[2]
|
||||
|
||||
header: dict[str, Any] | None = None
|
||||
if key_id:
|
||||
header = {"kid": key_id}
|
||||
|
||||
return AgentCardSignature(
|
||||
protected=protected_b64,
|
||||
signature=signature_b64,
|
||||
header=header,
|
||||
)
|
||||
|
||||
|
||||
def verify_agent_card_signature(
|
||||
agent_card: AgentCard,
|
||||
signature: AgentCardSignature,
|
||||
public_key: str | bytes,
|
||||
algorithms: list[str] | None = None,
|
||||
) -> bool:
|
||||
"""Verify an AgentCard JWS signature.
|
||||
|
||||
Validates that the signature was created with the corresponding private key
|
||||
and that the AgentCard content has not been modified.
|
||||
|
||||
Args:
|
||||
agent_card: The AgentCard to verify.
|
||||
signature: The AgentCardSignature to validate.
|
||||
public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
|
||||
algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms.
|
||||
|
||||
Returns:
|
||||
True if signature is valid, False otherwise.
|
||||
|
||||
Example:
|
||||
>>> is_valid = verify_agent_card_signature(
|
||||
... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..."
|
||||
... )
|
||||
"""
|
||||
if algorithms is None:
|
||||
algorithms = [
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
if isinstance(public_key, str):
|
||||
public_key = public_key.encode()
|
||||
|
||||
payload = _serialize_agent_card(agent_card)
|
||||
payload_b64 = _base64url_encode(payload)
|
||||
jws_token = f"{signature.protected}.{payload_b64}.{signature.signature}"
|
||||
|
||||
try:
|
||||
jwt.api_jws.decode(
|
||||
jws_token,
|
||||
public_key,
|
||||
algorithms=algorithms,
|
||||
)
|
||||
return True
|
||||
except jwt.InvalidSignatureError:
|
||||
logger.debug(
|
||||
"AgentCard signature verification failed",
|
||||
extra={"reason": "invalid_signature"},
|
||||
)
|
||||
return False
|
||||
except jwt.DecodeError as e:
|
||||
logger.debug(
|
||||
"AgentCard signature verification failed",
|
||||
extra={"reason": "decode_error", "error": str(e)},
|
||||
)
|
||||
return False
|
||||
except jwt.InvalidAlgorithmError as e:
|
||||
logger.debug(
|
||||
"AgentCard signature verification failed",
|
||||
extra={"reason": "algorithm_error", "error": str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def get_key_id_from_signature(signature: AgentCardSignature) -> str | None:
|
||||
"""Extract the key ID (kid) from an AgentCardSignature.
|
||||
|
||||
Checks both the unprotected header and the protected header for the kid claim.
|
||||
|
||||
Args:
|
||||
signature: The AgentCardSignature to extract from.
|
||||
|
||||
Returns:
|
||||
The key ID if present, None otherwise.
|
||||
"""
|
||||
if signature.header and "kid" in signature.header:
|
||||
kid: str = signature.header["kid"]
|
||||
return kid
|
||||
|
||||
try:
|
||||
protected = signature.protected
|
||||
padding_needed = 4 - (len(protected) % 4)
|
||||
if padding_needed != 4:
|
||||
protected += "=" * padding_needed
|
||||
|
||||
protected_json = base64.urlsafe_b64decode(protected).decode()
|
||||
protected_header: dict[str, Any] = json.loads(protected_json)
|
||||
return protected_header.get("kid")
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
return None
|
||||
@@ -1,339 +0,0 @@
|
||||
"""Content type negotiation for A2A protocol.
|
||||
|
||||
This module handles negotiation of input/output MIME types between A2A clients
|
||||
and servers based on AgentCard capabilities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Annotated, Final, Literal, cast
|
||||
|
||||
from a2a.types import Part
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard, AgentSkill
|
||||
|
||||
|
||||
TEXT_PLAIN: Literal["text/plain"] = "text/plain"
|
||||
APPLICATION_JSON: Literal["application/json"] = "application/json"
|
||||
IMAGE_PNG: Literal["image/png"] = "image/png"
|
||||
IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg"
|
||||
IMAGE_WILDCARD: Literal["image/*"] = "image/*"
|
||||
APPLICATION_PDF: Literal["application/pdf"] = "application/pdf"
|
||||
APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = (
|
||||
"application/octet-stream"
|
||||
)
|
||||
|
||||
DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
|
||||
TEXT_PLAIN,
|
||||
APPLICATION_JSON,
|
||||
]
|
||||
DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
|
||||
TEXT_PLAIN,
|
||||
APPLICATION_JSON,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class NegotiatedContentTypes:
|
||||
"""Result of content type negotiation."""
|
||||
|
||||
input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"]
|
||||
output_modes: Annotated[
|
||||
list[str], "Negotiated output MIME types the server will produce"
|
||||
]
|
||||
effective_input_modes: Annotated[list[str], "Server's effective input modes"]
|
||||
effective_output_modes: Annotated[list[str], "Server's effective output modes"]
|
||||
skill_name: Annotated[
|
||||
str | None, "Skill name if negotiation was skill-specific"
|
||||
] = None
|
||||
|
||||
|
||||
class ContentTypeNegotiationError(Exception):
|
||||
"""Raised when no compatible content types can be negotiated."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_input_modes: list[str],
|
||||
client_output_modes: list[str],
|
||||
server_input_modes: list[str],
|
||||
server_output_modes: list[str],
|
||||
direction: str = "both",
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
self.client_input_modes = client_input_modes
|
||||
self.client_output_modes = client_output_modes
|
||||
self.server_input_modes = server_input_modes
|
||||
self.server_output_modes = server_output_modes
|
||||
self.direction = direction
|
||||
|
||||
if message is None:
|
||||
if direction == "input":
|
||||
message = (
|
||||
f"No compatible input content types. "
|
||||
f"Client supports: {client_input_modes}, "
|
||||
f"Server accepts: {server_input_modes}"
|
||||
)
|
||||
elif direction == "output":
|
||||
message = (
|
||||
f"No compatible output content types. "
|
||||
f"Client accepts: {client_output_modes}, "
|
||||
f"Server produces: {server_output_modes}"
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"No compatible content types. "
|
||||
f"Input - Client: {client_input_modes}, Server: {server_input_modes}. "
|
||||
f"Output - Client: {client_output_modes}, Server: {server_output_modes}"
|
||||
)
|
||||
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def _normalize_mime_type(mime_type: str) -> str:
|
||||
"""Normalize MIME type for comparison (lowercase, strip whitespace)."""
|
||||
return mime_type.lower().strip()
|
||||
|
||||
|
||||
def _mime_types_compatible(client_type: str, server_type: str) -> bool:
|
||||
"""Check if two MIME types are compatible.
|
||||
|
||||
Handles wildcards like image/* matching image/png.
|
||||
"""
|
||||
client_normalized = _normalize_mime_type(client_type)
|
||||
server_normalized = _normalize_mime_type(server_type)
|
||||
|
||||
if client_normalized == server_normalized:
|
||||
return True
|
||||
|
||||
if "*" in client_normalized or "*" in server_normalized:
|
||||
client_parts = client_normalized.split("/")
|
||||
server_parts = server_normalized.split("/")
|
||||
|
||||
if len(client_parts) == 2 and len(server_parts) == 2:
|
||||
type_match = (
|
||||
client_parts[0] == server_parts[0]
|
||||
or client_parts[0] == "*"
|
||||
or server_parts[0] == "*"
|
||||
)
|
||||
subtype_match = (
|
||||
client_parts[1] == server_parts[1]
|
||||
or client_parts[1] == "*"
|
||||
or server_parts[1] == "*"
|
||||
)
|
||||
return type_match and subtype_match
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _find_compatible_modes(
|
||||
client_modes: list[str], server_modes: list[str]
|
||||
) -> list[str]:
|
||||
"""Find compatible MIME types between client and server.
|
||||
|
||||
Returns modes in client preference order.
|
||||
"""
|
||||
compatible = []
|
||||
for client_mode in client_modes:
|
||||
for server_mode in server_modes:
|
||||
if _mime_types_compatible(client_mode, server_mode):
|
||||
if "*" in client_mode and "*" not in server_mode:
|
||||
if server_mode not in compatible:
|
||||
compatible.append(server_mode)
|
||||
else:
|
||||
if client_mode not in compatible:
|
||||
compatible.append(client_mode)
|
||||
break
|
||||
return compatible
|
||||
|
||||
|
||||
def _get_effective_modes(
|
||||
agent_card: AgentCard,
|
||||
skill_name: str | None = None,
|
||||
) -> tuple[list[str], list[str], AgentSkill | None]:
|
||||
"""Get effective input/output modes from agent card.
|
||||
|
||||
If skill_name is provided and the skill has custom modes, those are used.
|
||||
Otherwise, falls back to agent card defaults.
|
||||
"""
|
||||
skill: AgentSkill | None = None
|
||||
|
||||
if skill_name and agent_card.skills:
|
||||
for s in agent_card.skills:
|
||||
if s.name == skill_name or s.id == skill_name:
|
||||
skill = s
|
||||
break
|
||||
|
||||
if skill:
|
||||
input_modes = (
|
||||
skill.input_modes if skill.input_modes else agent_card.default_input_modes
|
||||
)
|
||||
output_modes = (
|
||||
skill.output_modes
|
||||
if skill.output_modes
|
||||
else agent_card.default_output_modes
|
||||
)
|
||||
else:
|
||||
input_modes = agent_card.default_input_modes
|
||||
output_modes = agent_card.default_output_modes
|
||||
|
||||
return input_modes, output_modes, skill
|
||||
|
||||
|
||||
def negotiate_content_types(
|
||||
agent_card: AgentCard,
|
||||
client_input_modes: list[str] | None = None,
|
||||
client_output_modes: list[str] | None = None,
|
||||
skill_name: str | None = None,
|
||||
emit_event: bool = True,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
strict: bool = False,
|
||||
) -> NegotiatedContentTypes:
|
||||
"""Negotiate content types between client and server.
|
||||
|
||||
Args:
|
||||
agent_card: The remote agent's card with capability info.
|
||||
client_input_modes: MIME types the client can send. Defaults to text/plain and application/json.
|
||||
client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json.
|
||||
skill_name: Optional skill to use for mode lookup.
|
||||
emit_event: Whether to emit a content type negotiation event.
|
||||
endpoint: Agent endpoint (for event metadata).
|
||||
a2a_agent_name: Agent name (for event metadata).
|
||||
strict: If True, raises error when no compatible types found.
|
||||
If False, returns empty lists for incompatible directions.
|
||||
|
||||
Returns:
|
||||
NegotiatedContentTypes with compatible input and output modes.
|
||||
|
||||
Raises:
|
||||
ContentTypeNegotiationError: If strict=True and no compatible types found.
|
||||
"""
|
||||
if client_input_modes is None:
|
||||
client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy())
|
||||
if client_output_modes is None:
|
||||
client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy())
|
||||
|
||||
server_input_modes, server_output_modes, skill = _get_effective_modes(
|
||||
agent_card, skill_name
|
||||
)
|
||||
|
||||
compatible_input = _find_compatible_modes(client_input_modes, server_input_modes)
|
||||
compatible_output = _find_compatible_modes(client_output_modes, server_output_modes)
|
||||
|
||||
if strict:
|
||||
if not compatible_input and not compatible_output:
|
||||
raise ContentTypeNegotiationError(
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
)
|
||||
if not compatible_input:
|
||||
raise ContentTypeNegotiationError(
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
direction="input",
|
||||
)
|
||||
if not compatible_output:
|
||||
raise ContentTypeNegotiationError(
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
direction="output",
|
||||
)
|
||||
|
||||
result = NegotiatedContentTypes(
|
||||
input_modes=compatible_input,
|
||||
output_modes=compatible_output,
|
||||
effective_input_modes=server_input_modes,
|
||||
effective_output_modes=server_output_modes,
|
||||
skill_name=skill.name if skill else None,
|
||||
)
|
||||
|
||||
if emit_event:
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AContentTypeNegotiatedEvent(
|
||||
endpoint=endpoint or agent_card.url,
|
||||
a2a_agent_name=a2a_agent_name or agent_card.name,
|
||||
skill_name=skill_name,
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
negotiated_input_modes=compatible_input,
|
||||
negotiated_output_modes=compatible_output,
|
||||
negotiation_success=bool(compatible_input and compatible_output),
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def validate_content_type(
|
||||
content_type: str,
|
||||
allowed_modes: list[str],
|
||||
) -> bool:
|
||||
"""Validate that a content type is allowed by a list of modes.
|
||||
|
||||
Args:
|
||||
content_type: The MIME type to validate.
|
||||
allowed_modes: List of allowed MIME types (may include wildcards).
|
||||
|
||||
Returns:
|
||||
True if content_type is compatible with any allowed mode.
|
||||
"""
|
||||
for mode in allowed_modes:
|
||||
if _mime_types_compatible(content_type, mode):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_part_content_type(part: Part) -> str:
|
||||
"""Extract MIME type from an A2A Part.
|
||||
|
||||
Args:
|
||||
part: A Part object containing TextPart, DataPart, or FilePart.
|
||||
|
||||
Returns:
|
||||
The MIME type string for this part.
|
||||
"""
|
||||
root = part.root
|
||||
if root.kind == "text":
|
||||
return TEXT_PLAIN
|
||||
if root.kind == "data":
|
||||
return APPLICATION_JSON
|
||||
if root.kind == "file":
|
||||
return root.file.mime_type or APPLICATION_OCTET_STREAM
|
||||
return APPLICATION_OCTET_STREAM
|
||||
|
||||
|
||||
def validate_message_parts(
|
||||
parts: list[Part],
|
||||
allowed_modes: list[str],
|
||||
) -> list[str]:
|
||||
"""Validate that all message parts have allowed content types.
|
||||
|
||||
Args:
|
||||
parts: List of Parts from the incoming message.
|
||||
allowed_modes: List of allowed MIME types (from default_input_modes).
|
||||
|
||||
Returns:
|
||||
List of invalid content types found (empty if all valid).
|
||||
"""
|
||||
invalid_types: list[str] = []
|
||||
for part in parts:
|
||||
content_type = get_part_content_type(part)
|
||||
if not validate_content_type(content_type, allowed_modes):
|
||||
if content_type not in invalid_types:
|
||||
invalid_types.append(content_type)
|
||||
return invalid_types
|
||||
@@ -3,18 +3,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import AsyncIterator, Callable, MutableMapping
|
||||
from collections.abc import AsyncIterator, MutableMapping
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client, ClientConfig, ClientFactory
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
FilePart,
|
||||
FileWithBytes,
|
||||
Message,
|
||||
Part,
|
||||
PushNotificationConfig as A2APushNotificationConfig,
|
||||
@@ -24,24 +20,18 @@ from a2a.types import (
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai.a2a.auth.utils import (
|
||||
_auth_store,
|
||||
configure_auth_client,
|
||||
validate_auth_against_agent_card,
|
||||
)
|
||||
from crewai.a2a.config import ClientTransportConfig, GRPCClientConfig
|
||||
from crewai.a2a.extensions.registry import (
|
||||
ExtensionsMiddleware,
|
||||
validate_required_extensions,
|
||||
)
|
||||
from crewai.a2a.task_helpers import TaskStateResult
|
||||
from crewai.a2a.types import (
|
||||
HANDLER_REGISTRY,
|
||||
HandlerType,
|
||||
PartsDict,
|
||||
PartsMetadataDict,
|
||||
TransportType,
|
||||
)
|
||||
from crewai.a2a.updates import (
|
||||
PollingConfig,
|
||||
@@ -49,20 +39,7 @@ from crewai.a2a.updates import (
|
||||
StreamingHandler,
|
||||
UpdateConfig,
|
||||
)
|
||||
from crewai.a2a.utils.agent_card import (
|
||||
_afetch_agent_card_cached,
|
||||
_get_tls_verify,
|
||||
_prepare_auth_headers,
|
||||
)
|
||||
from crewai.a2a.utils.content_type import (
|
||||
DEFAULT_CLIENT_OUTPUT_MODES,
|
||||
negotiate_content_types,
|
||||
)
|
||||
from crewai.a2a.utils.transport import (
|
||||
NegotiatedTransport,
|
||||
TransportNegotiationError,
|
||||
negotiate_transport,
|
||||
)
|
||||
from crewai.a2a.utils.agent_card import _afetch_agent_card_cached
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConversationStartedEvent,
|
||||
@@ -72,48 +49,10 @@ from crewai.events.types.a2a_events import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Message
|
||||
|
||||
from crewai.a2a.auth.client_schemes import ClientAuthScheme
|
||||
|
||||
|
||||
_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC"
|
||||
|
||||
|
||||
def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]:
|
||||
"""Convert FileInput dictionary to FilePart objects.
|
||||
|
||||
Args:
|
||||
input_files: Dictionary mapping names to FileInput objects.
|
||||
|
||||
Returns:
|
||||
List of Part objects containing FilePart data.
|
||||
"""
|
||||
if not input_files:
|
||||
return []
|
||||
|
||||
try:
|
||||
import crewai_files # noqa: F401
|
||||
except ImportError:
|
||||
logger.debug("crewai_files not installed, skipping file parts")
|
||||
return []
|
||||
|
||||
parts: list[Part] = []
|
||||
for name, file_input in input_files.items():
|
||||
content_bytes = file_input.read()
|
||||
content_base64 = base64.b64encode(content_bytes).decode()
|
||||
file_with_bytes = FileWithBytes(
|
||||
bytes=content_base64,
|
||||
mimeType=file_input.content_type,
|
||||
name=file_input.filename or name,
|
||||
)
|
||||
parts.append(Part(root=FilePart(file=file_with_bytes)))
|
||||
|
||||
return parts
|
||||
from crewai.a2a.auth.schemas import AuthScheme
|
||||
|
||||
|
||||
def get_handler(config: UpdateConfig | None) -> HandlerType:
|
||||
@@ -132,7 +71,8 @@ def get_handler(config: UpdateConfig | None) -> HandlerType:
|
||||
|
||||
def execute_a2a_delegation(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None = None,
|
||||
@@ -151,24 +91,32 @@ def execute_a2a_delegation(
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
transport: ClientTransportConfig | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
input_files: dict[str, Any] | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Execute a task delegation to a remote A2A agent synchronously.
|
||||
|
||||
WARNING: This function blocks the entire thread by creating and running a new
|
||||
event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts
|
||||
for better performance and resource efficiency.
|
||||
|
||||
This is a synchronous wrapper around aexecute_a2a_delegation that creates a
|
||||
new event loop to run the async implementation. It is provided for compatibility
|
||||
with synchronous code paths only.
|
||||
This is the sync wrapper around aexecute_a2a_delegation. For async contexts,
|
||||
use aexecute_a2a_delegation directly.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL)
|
||||
transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
|
||||
auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest)
|
||||
timeout: Request timeout in seconds
|
||||
task_description: The task to delegate
|
||||
context: Optional context information
|
||||
context_id: Context ID for correlating messages/tasks
|
||||
task_id: Specific task identifier
|
||||
reference_task_ids: List of related task IDs
|
||||
metadata: Additional metadata (external_id, request_id, etc.)
|
||||
extensions: Protocol extensions for custom fields
|
||||
conversation_history: Previous Message objects from conversation
|
||||
agent_id: Agent identifier for logging
|
||||
agent_role: Role of the CrewAI agent delegating the task
|
||||
agent_branch: Optional agent tree branch for logging
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
turn_number: Optional turn number for multi-turn conversations
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Optional AuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
task_description: The task to delegate.
|
||||
context: Optional context information.
|
||||
@@ -187,27 +135,10 @@ def execute_a2a_delegation(
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
skill_id: Optional skill ID to target a specific agent capability.
|
||||
client_extensions: A2A protocol extension URIs the client supports.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
accepted_output_modes: MIME types the client can accept in responses.
|
||||
input_files: Optional dictionary of files to send to remote agent.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with status, result/error, history, and agent_card.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If called from an async context with a running event loop.
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
raise RuntimeError(
|
||||
"execute_a2a_delegation() cannot be called from an async context. "
|
||||
"Use 'await aexecute_a2a_delegation()' instead."
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "no running event loop" not in str(e).lower():
|
||||
raise
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
@@ -228,15 +159,12 @@ def execute_a2a_delegation(
|
||||
agent_role=agent_role,
|
||||
agent_branch=agent_branch,
|
||||
response_model=response_model,
|
||||
transport_protocol=transport_protocol,
|
||||
turn_number=turn_number,
|
||||
updates=updates,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
skill_id=skill_id,
|
||||
client_extensions=client_extensions,
|
||||
transport=transport,
|
||||
accepted_output_modes=accepted_output_modes,
|
||||
input_files=input_files,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
@@ -248,7 +176,8 @@ def execute_a2a_delegation(
|
||||
|
||||
async def aexecute_a2a_delegation(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None = None,
|
||||
@@ -267,10 +196,6 @@ async def aexecute_a2a_delegation(
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
transport: ClientTransportConfig | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
input_files: dict[str, Any] | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Execute a task delegation to a remote A2A agent asynchronously.
|
||||
|
||||
@@ -278,8 +203,25 @@ async def aexecute_a2a_delegation(
|
||||
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL
|
||||
transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
|
||||
auth: Optional AuthScheme for authentication
|
||||
timeout: Request timeout in seconds
|
||||
task_description: Task to delegate
|
||||
context: Optional context
|
||||
context_id: Context ID for correlation
|
||||
task_id: Specific task identifier
|
||||
reference_task_ids: Related task IDs
|
||||
metadata: Additional metadata
|
||||
extensions: Protocol extensions
|
||||
conversation_history: Previous Message objects
|
||||
turn_number: Current turn number
|
||||
agent_branch: Agent tree branch for logging
|
||||
agent_id: Agent identifier for logging
|
||||
agent_role: Agent role for logging
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
auth: Optional AuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
task_description: The task to delegate.
|
||||
context: Optional context information.
|
||||
@@ -298,10 +240,6 @@ async def aexecute_a2a_delegation(
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
skill_id: Optional skill ID to target a specific agent capability.
|
||||
client_extensions: A2A protocol extension URIs the client supports.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
accepted_output_modes: MIME types the client can accept in responses.
|
||||
input_files: Optional dictionary of files to send to remote agent.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with status, result/error, history, and agent_card.
|
||||
@@ -333,13 +271,10 @@ async def aexecute_a2a_delegation(
|
||||
agent_role=agent_role,
|
||||
response_model=response_model,
|
||||
updates=updates,
|
||||
transport_protocol=transport_protocol,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
skill_id=skill_id,
|
||||
client_extensions=client_extensions,
|
||||
transport=transport,
|
||||
accepted_output_modes=accepted_output_modes,
|
||||
input_files=input_files,
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
@@ -359,7 +294,7 @@ async def aexecute_a2a_delegation(
|
||||
)
|
||||
raise
|
||||
|
||||
agent_card_data = result.get("agent_card")
|
||||
agent_card_data: dict[str, Any] = result.get("agent_card") or {}
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationCompletedEvent(
|
||||
@@ -371,7 +306,7 @@ async def aexecute_a2a_delegation(
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=result.get("a2a_agent_name"),
|
||||
agent_card=agent_card_data,
|
||||
provider=agent_card_data.get("provider") if agent_card_data else None,
|
||||
provider=agent_card_data.get("provider"),
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
@@ -384,7 +319,8 @@ async def aexecute_a2a_delegation(
|
||||
|
||||
async def _aexecute_a2a_delegation_impl(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
|
||||
auth: AuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None,
|
||||
@@ -404,14 +340,8 @@ async def _aexecute_a2a_delegation_impl(
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
transport: ClientTransportConfig | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
input_files: dict[str, Any] | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Internal async implementation of A2A delegation."""
|
||||
if transport is None:
|
||||
transport = ClientTransportConfig()
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
@@ -421,70 +351,22 @@ async def _aexecute_a2a_delegation_impl(
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
|
||||
auth_hash = hash((type(auth).__name__, auth_data))
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", endpoint)
|
||||
_auth_store.set(auth_hash, auth)
|
||||
auth_hash = 0
|
||||
_auth_store[auth_hash] = auth
|
||||
agent_card = await _afetch_agent_card_cached(
|
||||
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
|
||||
)
|
||||
|
||||
validate_auth_against_agent_card(agent_card, auth)
|
||||
|
||||
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
|
||||
if unsupported_exts:
|
||||
ext_uris = [ext.uri for ext in unsupported_exts]
|
||||
raise ValueError(
|
||||
f"Agent requires extensions not supported by client: {ext_uris}"
|
||||
)
|
||||
|
||||
negotiated: NegotiatedTransport | None = None
|
||||
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
|
||||
effective_url = endpoint
|
||||
|
||||
client_transports: list[str] = (
|
||||
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
|
||||
)
|
||||
|
||||
try:
|
||||
negotiated = negotiate_transport(
|
||||
agent_card=agent_card,
|
||||
client_supported_transports=client_transports,
|
||||
client_preferred_transport=transport.preferred,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
)
|
||||
effective_transport = negotiated.transport # type: ignore[assignment]
|
||||
effective_url = negotiated.url
|
||||
except TransportNegotiationError as e:
|
||||
logger.warning(
|
||||
"Transport negotiation failed, using fallback",
|
||||
extra={
|
||||
"error": str(e),
|
||||
"fallback_transport": effective_transport,
|
||||
"fallback_url": effective_url,
|
||||
"endpoint": endpoint,
|
||||
"client_transports": client_transports,
|
||||
"server_transports": [
|
||||
iface.transport for iface in agent_card.additional_interfaces or []
|
||||
]
|
||||
+ [agent_card.preferred_transport or "JSONRPC"],
|
||||
},
|
||||
)
|
||||
|
||||
effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy()
|
||||
|
||||
content_negotiated = negotiate_content_types(
|
||||
agent_card=agent_card,
|
||||
client_output_modes=accepted_output_modes,
|
||||
skill_name=skill_id,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
)
|
||||
if content_negotiated.output_modes:
|
||||
effective_output_modes = content_negotiated.output_modes
|
||||
|
||||
headers, _ = await _prepare_auth_headers(auth, timeout)
|
||||
headers: MutableMapping[str, str] = {}
|
||||
if auth:
|
||||
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
|
||||
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_auth_client)
|
||||
headers = await auth.apply_auth(temp_auth_client, {})
|
||||
|
||||
a2a_agent_name = None
|
||||
if agent_card.name:
|
||||
@@ -559,13 +441,10 @@ async def _aexecute_a2a_delegation_impl(
|
||||
if skill_id:
|
||||
message_metadata["skill_id"] = skill_id
|
||||
|
||||
parts_list: list[Part] = [Part(root=TextPart(**parts))]
|
||||
parts_list.extend(_create_file_parts(input_files))
|
||||
|
||||
message = Message(
|
||||
role=Role.user,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=parts_list,
|
||||
parts=[Part(root=TextPart(**parts))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
@@ -634,22 +513,15 @@ async def _aexecute_a2a_delegation_impl(
|
||||
|
||||
use_streaming = not use_polling and push_config_for_client is None
|
||||
|
||||
client_agent_card = agent_card
|
||||
if effective_url != agent_card.url:
|
||||
client_agent_card = agent_card.model_copy(update={"url": effective_url})
|
||||
|
||||
async with _create_a2a_client(
|
||||
agent_card=client_agent_card,
|
||||
transport_protocol=effective_transport,
|
||||
agent_card=agent_card,
|
||||
transport_protocol=transport_protocol,
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
streaming=use_streaming,
|
||||
auth=auth,
|
||||
use_polling=use_polling,
|
||||
push_notification_config=push_config_for_client,
|
||||
client_extensions=client_extensions,
|
||||
accepted_output_modes=effective_output_modes, # type: ignore[arg-type]
|
||||
grpc_config=transport.grpc,
|
||||
) as client:
|
||||
result = await handler.execute(
|
||||
client=client,
|
||||
@@ -663,245 +535,6 @@ async def _aexecute_a2a_delegation_impl(
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_grpc_metadata(
|
||||
metadata: tuple[tuple[str, str], ...] | None,
|
||||
) -> tuple[tuple[str, str], ...] | None:
|
||||
"""Lowercase all gRPC metadata keys.
|
||||
|
||||
gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK)
|
||||
use mixed-case headers like 'X-A2A-Extensions'. This normalizes them.
|
||||
"""
|
||||
if metadata is None:
|
||||
return None
|
||||
return tuple((key.lower(), value) for key, value in metadata)
|
||||
|
||||
|
||||
def _create_grpc_interceptors(
|
||||
auth_metadata: list[tuple[str, str]] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Create gRPC interceptors for metadata normalization and auth injection.
|
||||
|
||||
Args:
|
||||
auth_metadata: Optional auth metadata to inject into all calls.
|
||||
Used for insecure channels that need auth (non-localhost without TLS).
|
||||
|
||||
Returns a list of interceptors that lowercase metadata keys for gRPC
|
||||
compatibility. Must be called after grpc is imported.
|
||||
"""
|
||||
import grpc.aio # type: ignore[import-untyped]
|
||||
|
||||
def _merge_metadata(
|
||||
existing: tuple[tuple[str, str], ...] | None,
|
||||
auth: list[tuple[str, str]] | None,
|
||||
) -> tuple[tuple[str, str], ...] | None:
|
||||
"""Merge existing metadata with auth metadata and normalize keys."""
|
||||
merged: list[tuple[str, str]] = []
|
||||
if existing:
|
||||
merged.extend(existing)
|
||||
if auth:
|
||||
merged.extend(auth)
|
||||
if not merged:
|
||||
return None
|
||||
return tuple((key.lower(), value) for key, value in merged)
|
||||
|
||||
def _inject_metadata(client_call_details: Any) -> Any:
|
||||
"""Inject merged metadata into call details."""
|
||||
return client_call_details._replace(
|
||||
metadata=_merge_metadata(client_call_details.metadata, auth_metadata)
|
||||
)
|
||||
|
||||
class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for unary-unary calls that injects auth metadata."""
|
||||
|
||||
async def intercept_unary_unary( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request
|
||||
):
|
||||
"""Intercept unary-unary call and inject metadata."""
|
||||
return await continuation(_inject_metadata(client_call_details), request)
|
||||
|
||||
class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for unary-stream calls that injects auth metadata."""
|
||||
|
||||
async def intercept_unary_stream( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request
|
||||
):
|
||||
"""Intercept unary-stream call and inject metadata."""
|
||||
return await continuation(_inject_metadata(client_call_details), request)
|
||||
|
||||
class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for stream-unary calls that injects auth metadata."""
|
||||
|
||||
async def intercept_stream_unary( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request_iterator
|
||||
):
|
||||
"""Intercept stream-unary call and inject metadata."""
|
||||
return await continuation(
|
||||
_inject_metadata(client_call_details), request_iterator
|
||||
)
|
||||
|
||||
class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for stream-stream calls that injects auth metadata."""
|
||||
|
||||
async def intercept_stream_stream( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request_iterator
|
||||
):
|
||||
"""Intercept stream-stream call and inject metadata."""
|
||||
return await continuation(
|
||||
_inject_metadata(client_call_details), request_iterator
|
||||
)
|
||||
|
||||
return [
|
||||
MetadataUnaryUnary(),
|
||||
MetadataUnaryStream(),
|
||||
MetadataStreamUnary(),
|
||||
MetadataStreamStream(),
|
||||
]
|
||||
|
||||
|
||||
def _create_grpc_channel_factory(
|
||||
grpc_config: GRPCClientConfig,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
) -> Callable[[str], Any]:
|
||||
"""Create a gRPC channel factory with the given configuration.
|
||||
|
||||
Args:
|
||||
grpc_config: gRPC client configuration with channel options.
|
||||
auth: Optional ClientAuthScheme for TLS and auth configuration.
|
||||
|
||||
Returns:
|
||||
A callable that creates gRPC channels from URLs.
|
||||
"""
|
||||
try:
|
||||
import grpc
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
|
||||
) from e
|
||||
|
||||
auth_metadata: list[tuple[str, str]] = []
|
||||
|
||||
if auth is not None:
|
||||
from crewai.a2a.auth.client_schemes import (
|
||||
APIKeyAuth,
|
||||
BearerTokenAuth,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
)
|
||||
|
||||
if isinstance(auth, HTTPDigestAuth):
|
||||
raise ValueError(
|
||||
"HTTPDigestAuth is not supported with gRPC transport. "
|
||||
"Digest authentication requires HTTP challenge-response flow. "
|
||||
"Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
|
||||
)
|
||||
if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
|
||||
raise ValueError(
|
||||
f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
|
||||
"gRPC only supports header-based authentication. "
|
||||
"Use APIKeyAuth with location='header' instead."
|
||||
)
|
||||
|
||||
if isinstance(auth, BearerTokenAuth):
|
||||
auth_metadata.append(("authorization", f"Bearer {auth.token}"))
|
||||
elif isinstance(auth, HTTPBasicAuth):
|
||||
import base64
|
||||
|
||||
basic_credentials = f"{auth.username}:{auth.password}"
|
||||
encoded = base64.b64encode(basic_credentials.encode()).decode()
|
||||
auth_metadata.append(("authorization", f"Basic {encoded}"))
|
||||
elif isinstance(auth, APIKeyAuth) and auth.location == "header":
|
||||
header_name = auth.name.lower()
|
||||
auth_metadata.append((header_name, auth.api_key))
|
||||
elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
|
||||
if auth._access_token:
|
||||
auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))
|
||||
|
||||
def factory(url: str) -> Any:
|
||||
"""Create a gRPC channel for the given URL."""
|
||||
target = url
|
||||
use_tls = False
|
||||
|
||||
if url.startswith("grpcs://"):
|
||||
target = url[8:]
|
||||
use_tls = True
|
||||
elif url.startswith("grpc://"):
|
||||
target = url[7:]
|
||||
elif url.startswith("https://"):
|
||||
target = url[8:]
|
||||
use_tls = True
|
||||
elif url.startswith("http://"):
|
||||
target = url[7:]
|
||||
|
||||
options: list[tuple[str, Any]] = []
|
||||
if grpc_config.max_send_message_length is not None:
|
||||
options.append(
|
||||
("grpc.max_send_message_length", grpc_config.max_send_message_length)
|
||||
)
|
||||
if grpc_config.max_receive_message_length is not None:
|
||||
options.append(
|
||||
(
|
||||
"grpc.max_receive_message_length",
|
||||
grpc_config.max_receive_message_length,
|
||||
)
|
||||
)
|
||||
if grpc_config.keepalive_time_ms is not None:
|
||||
options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
|
||||
if grpc_config.keepalive_timeout_ms is not None:
|
||||
options.append(
|
||||
("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
|
||||
)
|
||||
|
||||
channel_credentials = None
|
||||
if auth and hasattr(auth, "tls") and auth.tls:
|
||||
channel_credentials = auth.tls.get_grpc_credentials()
|
||||
elif use_tls:
|
||||
channel_credentials = grpc.ssl_channel_credentials()
|
||||
|
||||
if channel_credentials and auth_metadata:
|
||||
|
||||
class AuthMetadataPlugin(grpc.AuthMetadataPlugin): # type: ignore[misc,no-any-unimported]
|
||||
"""gRPC auth metadata plugin that adds auth headers as metadata."""
|
||||
|
||||
def __init__(self, metadata: list[tuple[str, str]]) -> None:
|
||||
self._metadata = tuple(metadata)
|
||||
|
||||
def __call__( # type: ignore[no-any-unimported]
|
||||
self,
|
||||
context: grpc.AuthMetadataContext,
|
||||
callback: grpc.AuthMetadataPluginCallback,
|
||||
) -> None:
|
||||
callback(self._metadata, None)
|
||||
|
||||
call_creds = grpc.metadata_call_credentials(
|
||||
AuthMetadataPlugin(auth_metadata)
|
||||
)
|
||||
credentials = grpc.composite_channel_credentials(
|
||||
channel_credentials, call_creds
|
||||
)
|
||||
interceptors = _create_grpc_interceptors()
|
||||
return grpc.aio.secure_channel(
|
||||
target, credentials, options=options or None, interceptors=interceptors
|
||||
)
|
||||
if channel_credentials:
|
||||
interceptors = _create_grpc_interceptors()
|
||||
return grpc.aio.secure_channel(
|
||||
target,
|
||||
channel_credentials,
|
||||
options=options or None,
|
||||
interceptors=interceptors,
|
||||
)
|
||||
interceptors = _create_grpc_interceptors(
|
||||
auth_metadata=auth_metadata if auth_metadata else None
|
||||
)
|
||||
return grpc.aio.insecure_channel(
|
||||
target, options=options or None, interceptors=interceptors
|
||||
)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_a2a_client(
|
||||
agent_card: AgentCard,
|
||||
@@ -909,12 +542,9 @@ async def _create_a2a_client(
|
||||
timeout: int,
|
||||
headers: MutableMapping[str, str],
|
||||
streaming: bool,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
auth: AuthScheme | None = None,
|
||||
use_polling: bool = False,
|
||||
push_notification_config: PushNotificationConfig | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
grpc_config: GRPCClientConfig | None = None,
|
||||
) -> AsyncIterator[Client]:
|
||||
"""Create and configure an A2A client.
|
||||
|
||||
@@ -924,21 +554,16 @@ async def _create_a2a_client(
|
||||
timeout: Request timeout in seconds.
|
||||
headers: HTTP headers (already with auth applied).
|
||||
streaming: Enable streaming responses.
|
||||
auth: Optional ClientAuthScheme for client configuration.
|
||||
auth: Optional AuthScheme for client configuration.
|
||||
use_polling: Enable polling mode.
|
||||
push_notification_config: Optional push notification config.
|
||||
client_extensions: A2A protocol extension URIs to declare support for.
|
||||
accepted_output_modes: MIME types the client can accept in responses.
|
||||
grpc_config: Optional gRPC client configuration.
|
||||
|
||||
Yields:
|
||||
Configured A2A client instance.
|
||||
"""
|
||||
verify = _get_tls_verify(auth)
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
verify=verify,
|
||||
) as httpx_client:
|
||||
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, httpx_client)
|
||||
@@ -954,27 +579,15 @@ async def _create_a2a_client(
|
||||
)
|
||||
)
|
||||
|
||||
grpc_channel_factory = None
|
||||
if transport_protocol == "GRPC":
|
||||
grpc_channel_factory = _create_grpc_channel_factory(
|
||||
grpc_config or GRPCClientConfig(),
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
config = ClientConfig(
|
||||
httpx_client=httpx_client,
|
||||
supported_transports=[transport_protocol],
|
||||
streaming=streaming and not use_polling,
|
||||
polling=use_polling,
|
||||
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
|
||||
accepted_output_modes=["application/json"],
|
||||
push_notification_configs=push_configs,
|
||||
grpc_channel_factory=grpc_channel_factory,
|
||||
)
|
||||
|
||||
factory = ClientFactory(config)
|
||||
client = factory.create(agent_card)
|
||||
|
||||
if client_extensions:
|
||||
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
|
||||
|
||||
yield client
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
"""Structured JSON logging utilities for A2A module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
|
||||
_log_context: ContextVar[dict[str, Any] | None] = ContextVar(
|
||||
"log_context", default=None
|
||||
)
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""JSON formatter for structured logging.
|
||||
|
||||
Outputs logs as JSON with consistent fields for log aggregators.
|
||||
"""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Format log record as JSON string."""
|
||||
log_data: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
|
||||
if record.exc_info:
|
||||
log_data["exception"] = self.formatException(record.exc_info)
|
||||
|
||||
context = _log_context.get()
|
||||
if context is not None:
|
||||
log_data.update(context)
|
||||
|
||||
if hasattr(record, "task_id"):
|
||||
log_data["task_id"] = record.task_id
|
||||
if hasattr(record, "context_id"):
|
||||
log_data["context_id"] = record.context_id
|
||||
if hasattr(record, "agent"):
|
||||
log_data["agent"] = record.agent
|
||||
if hasattr(record, "endpoint"):
|
||||
log_data["endpoint"] = record.endpoint
|
||||
if hasattr(record, "extension"):
|
||||
log_data["extension"] = record.extension
|
||||
if hasattr(record, "error"):
|
||||
log_data["error"] = record.error
|
||||
|
||||
for key, value in record.__dict__.items():
|
||||
if key.startswith("_") or key in (
|
||||
"name",
|
||||
"msg",
|
||||
"args",
|
||||
"created",
|
||||
"filename",
|
||||
"funcName",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"lineno",
|
||||
"module",
|
||||
"msecs",
|
||||
"pathname",
|
||||
"process",
|
||||
"processName",
|
||||
"relativeCreated",
|
||||
"stack_info",
|
||||
"exc_info",
|
||||
"exc_text",
|
||||
"thread",
|
||||
"threadName",
|
||||
"taskName",
|
||||
"message",
|
||||
):
|
||||
continue
|
||||
if key not in log_data:
|
||||
log_data[key] = value
|
||||
|
||||
return json.dumps(log_data, default=str)
|
||||
|
||||
|
||||
class LogContext:
|
||||
"""Context manager for adding fields to all logs within a scope.
|
||||
|
||||
Example:
|
||||
with LogContext(task_id="abc", context_id="xyz"):
|
||||
logger.info("Processing task") # Includes task_id and context_id
|
||||
"""
|
||||
|
||||
def __init__(self, **fields: Any) -> None:
|
||||
self._fields = fields
|
||||
self._token: Any = None
|
||||
|
||||
def __enter__(self) -> LogContext:
|
||||
current = _log_context.get() or {}
|
||||
new_context = {**current, **self._fields}
|
||||
self._token = _log_context.set(new_context)
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
_log_context.reset(self._token)
|
||||
|
||||
|
||||
def configure_json_logging(logger_name: str = "crewai.a2a") -> None:
|
||||
"""Configure JSON logging for the A2A module.
|
||||
|
||||
Args:
|
||||
logger_name: Logger name to configure.
|
||||
"""
|
||||
logger = logging.getLogger(logger_name)
|
||||
|
||||
for handler in logger.handlers[:]:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(JSONFormatter())
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a logger configured for structured JSON output.
|
||||
|
||||
Args:
|
||||
name: Logger name.
|
||||
|
||||
Returns:
|
||||
Configured logger instance.
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
@@ -7,40 +7,26 @@ import base64
|
||||
from collections.abc import Callable, Coroutine
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from a2a.server.agent_execution import RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
from a2a.types import (
|
||||
Artifact,
|
||||
FileWithBytes,
|
||||
FileWithUri,
|
||||
InternalError,
|
||||
InvalidParamsError,
|
||||
Message,
|
||||
Part,
|
||||
Task as A2ATask,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from a2a.utils import (
|
||||
get_data_parts,
|
||||
get_file_parts,
|
||||
new_agent_text_message,
|
||||
new_data_artifact,
|
||||
new_text_artifact,
|
||||
)
|
||||
from a2a.utils import new_agent_text_message, new_text_artifact
|
||||
from a2a.utils.errors import ServerError
|
||||
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.a2a.utils.agent_card import _get_server_config
|
||||
from crewai.a2a.utils.content_type import validate_message_parts
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AServerTaskCanceledEvent,
|
||||
@@ -49,11 +35,9 @@ from crewai.events.types.a2a_events import (
|
||||
A2AServerTaskStartedEvent,
|
||||
)
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
|
||||
from crewai.agent import Agent
|
||||
|
||||
|
||||
@@ -63,17 +47,7 @@ P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedisCacheConfig(TypedDict, total=False):
|
||||
"""Configuration for aiocache Redis backend."""
|
||||
|
||||
cache: str
|
||||
endpoint: str
|
||||
port: int
|
||||
db: int
|
||||
password: str
|
||||
|
||||
|
||||
def _parse_redis_url(url: str) -> RedisCacheConfig:
|
||||
def _parse_redis_url(url: str) -> dict[str, Any]:
|
||||
"""Parse a Redis URL into aiocache configuration.
|
||||
|
||||
Args:
|
||||
@@ -82,8 +56,9 @@ def _parse_redis_url(url: str) -> RedisCacheConfig:
|
||||
Returns:
|
||||
Configuration dict for aiocache.RedisCache.
|
||||
"""
|
||||
|
||||
parsed = urlparse(url)
|
||||
config: RedisCacheConfig = {
|
||||
config: dict[str, Any] = {
|
||||
"cache": "aiocache.RedisCache",
|
||||
"endpoint": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 6379,
|
||||
@@ -163,10 +138,7 @@ def cancellable(
|
||||
if message["type"] == "message":
|
||||
return True
|
||||
except (OSError, ConnectionError) as e:
|
||||
logger.warning(
|
||||
"Cancel watcher Redis error, falling back to polling",
|
||||
extra={"task_id": task_id, "error": str(e)},
|
||||
)
|
||||
logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
|
||||
return await poll_for_cancel()
|
||||
return False
|
||||
|
||||
@@ -194,98 +166,7 @@ def cancellable(
|
||||
return wrapper
|
||||
|
||||
|
||||
def _convert_a2a_files_to_file_inputs(
|
||||
a2a_files: list[FileWithBytes | FileWithUri],
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a2a file types to crewai FileInput dict.
|
||||
|
||||
Args:
|
||||
a2a_files: List of FileWithBytes or FileWithUri from a2a SDK.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping file names to FileInput objects.
|
||||
"""
|
||||
try:
|
||||
from crewai_files import File, FileBytes
|
||||
except ImportError:
|
||||
logger.debug("crewai_files not installed, returning empty file dict")
|
||||
return {}
|
||||
|
||||
file_dict: dict[str, Any] = {}
|
||||
for idx, a2a_file in enumerate(a2a_files):
|
||||
if isinstance(a2a_file, FileWithBytes):
|
||||
file_bytes = base64.b64decode(a2a_file.bytes)
|
||||
name = a2a_file.name or f"file_{idx}"
|
||||
file_source = FileBytes(data=file_bytes, filename=a2a_file.name)
|
||||
file_dict[name] = File(source=file_source)
|
||||
elif isinstance(a2a_file, FileWithUri):
|
||||
name = a2a_file.name or f"file_{idx}"
|
||||
file_dict[name] = File(source=a2a_file.uri)
|
||||
|
||||
return file_dict
|
||||
|
||||
|
||||
def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
|
||||
"""Extract response schema from message parts metadata.
|
||||
|
||||
The client may include a JSON schema in TextPart metadata to specify
|
||||
the expected response format (see delegation.py line 463).
|
||||
|
||||
Args:
|
||||
parts: List of message parts.
|
||||
|
||||
Returns:
|
||||
JSON schema dict if found, None otherwise.
|
||||
"""
|
||||
for part in parts:
|
||||
if part.root.kind == "text" and part.root.metadata:
|
||||
schema = part.root.metadata.get("schema")
|
||||
if schema and isinstance(schema, dict):
|
||||
return schema # type: ignore[no-any-return]
|
||||
return None
|
||||
|
||||
|
||||
def _create_result_artifact(
|
||||
result: Any,
|
||||
task_id: str,
|
||||
) -> Artifact:
|
||||
"""Create artifact from task result, using DataPart for structured data.
|
||||
|
||||
Args:
|
||||
result: The task execution result.
|
||||
task_id: The task ID for naming the artifact.
|
||||
|
||||
Returns:
|
||||
Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings).
|
||||
"""
|
||||
artifact_name = f"result_{task_id}"
|
||||
if isinstance(result, dict):
|
||||
return new_data_artifact(artifact_name, result)
|
||||
if isinstance(result, BaseModel):
|
||||
return new_data_artifact(artifact_name, result.model_dump())
|
||||
return new_text_artifact(artifact_name, str(result))
|
||||
|
||||
|
||||
def _build_task_description(
|
||||
user_message: str,
|
||||
structured_inputs: list[dict[str, Any]],
|
||||
) -> str:
|
||||
"""Build task description including structured data if present.
|
||||
|
||||
Args:
|
||||
user_message: The original user message text.
|
||||
structured_inputs: List of structured data from DataParts.
|
||||
|
||||
Returns:
|
||||
Task description with structured data appended if present.
|
||||
"""
|
||||
if not structured_inputs:
|
||||
return user_message
|
||||
|
||||
structured_json = json.dumps(structured_inputs, indent=2)
|
||||
return f"{user_message}\n\nStructured Data:\n{structured_json}"
|
||||
|
||||
|
||||
@cancellable
|
||||
async def execute(
|
||||
agent: Agent,
|
||||
context: RequestContext,
|
||||
@@ -297,54 +178,15 @@ async def execute(
|
||||
agent: The CrewAI agent to execute the task.
|
||||
context: The A2A request context containing the user's message.
|
||||
event_queue: The event queue for sending responses back.
|
||||
|
||||
TODOs:
|
||||
* need to impl both of structured output and file inputs, depends on `file_inputs` for
|
||||
`crewai.task.Task`, pass the below two to Task. both utils in `a2a.utils.parts`
|
||||
* structured outputs ingestion, `structured_inputs = get_data_parts(parts=context.message.parts)`
|
||||
* file inputs ingestion, `file_inputs = get_file_parts(parts=context.message.parts)`
|
||||
"""
|
||||
await _execute_impl(agent, context, event_queue, None, None)
|
||||
|
||||
|
||||
@cancellable
|
||||
async def _execute_impl(
|
||||
agent: Agent,
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
extension_registry: ServerExtensionRegistry | None,
|
||||
extension_context: ExtensionContext | None,
|
||||
) -> None:
|
||||
"""Internal implementation for task execution with optional extensions."""
|
||||
server_config = _get_server_config(agent)
|
||||
if context.message and context.message.parts and server_config:
|
||||
allowed_modes = server_config.default_input_modes
|
||||
invalid_types = validate_message_parts(context.message.parts, allowed_modes)
|
||||
if invalid_types:
|
||||
raise ServerError(
|
||||
InvalidParamsError(
|
||||
message=f"Unsupported content type(s): {', '.join(invalid_types)}. "
|
||||
f"Supported: {', '.join(allowed_modes)}"
|
||||
)
|
||||
)
|
||||
|
||||
if extension_registry and extension_context:
|
||||
await extension_registry.invoke_on_request(extension_context)
|
||||
|
||||
user_message = context.get_user_input()
|
||||
|
||||
response_model: type[BaseModel] | None = None
|
||||
structured_inputs: list[dict[str, Any]] = []
|
||||
a2a_files: list[FileWithBytes | FileWithUri] = []
|
||||
|
||||
if context.message and context.message.parts:
|
||||
schema = _extract_response_schema(context.message.parts)
|
||||
if schema:
|
||||
try:
|
||||
response_model = create_model_from_schema(schema)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to create response model from schema",
|
||||
extra={"error": str(e), "schema_title": schema.get("title")},
|
||||
)
|
||||
|
||||
structured_inputs = get_data_parts(context.message.parts)
|
||||
a2a_files = get_file_parts(context.message.parts)
|
||||
|
||||
task_id = context.task_id
|
||||
context_id = context.context_id
|
||||
if task_id is None or context_id is None:
|
||||
@@ -361,11 +203,9 @@ async def _execute_impl(
|
||||
raise ServerError(InvalidParamsError(message=msg)) from None
|
||||
|
||||
task = Task(
|
||||
description=_build_task_description(user_message, structured_inputs),
|
||||
description=user_message,
|
||||
expected_output="Response to the user's request",
|
||||
agent=agent,
|
||||
response_model=response_model,
|
||||
input_files=_convert_a2a_files_to_file_inputs(a2a_files),
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -380,10 +220,6 @@ async def _execute_impl(
|
||||
|
||||
try:
|
||||
result = await agent.aexecute_task(task=task, tools=agent.tools)
|
||||
if extension_registry and extension_context:
|
||||
result = await extension_registry.invoke_on_response(
|
||||
extension_context, result
|
||||
)
|
||||
result_str = str(result)
|
||||
history: list[Message] = [context.message] if context.message else []
|
||||
history.append(new_agent_text_message(result_str, context_id, task_id))
|
||||
@@ -391,8 +227,8 @@ async def _execute_impl(
|
||||
A2ATask(
|
||||
id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
artifacts=[_create_result_artifact(result, task_id)],
|
||||
status=TaskStatus(state=TaskState.input_required),
|
||||
artifacts=[new_text_artifact(result_str, f"result_{task_id}")],
|
||||
history=history,
|
||||
)
|
||||
)
|
||||
@@ -433,27 +269,6 @@ async def _execute_impl(
|
||||
) from e
|
||||
|
||||
|
||||
async def execute_with_extensions(
|
||||
agent: Agent,
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
extension_registry: ServerExtensionRegistry,
|
||||
extension_context: ExtensionContext,
|
||||
) -> None:
|
||||
"""Execute an A2A task with extension hooks.
|
||||
|
||||
Args:
|
||||
agent: The CrewAI agent to execute the task.
|
||||
context: The A2A request context containing the user's message.
|
||||
event_queue: The event queue for sending responses back.
|
||||
extension_registry: Registry of server extensions.
|
||||
extension_context: Context for extension hooks.
|
||||
"""
|
||||
await _execute_impl(
|
||||
agent, context, event_queue, extension_registry, extension_context
|
||||
)
|
||||
|
||||
|
||||
async def cancel(
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
"""Transport negotiation utilities for A2A protocol.
|
||||
|
||||
This module provides functionality for negotiating the transport protocol
|
||||
between an A2A client and server based on their respective capabilities
|
||||
and preferences.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final, Literal
|
||||
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
|
||||
|
||||
|
||||
TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
|
||||
NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"]
|
||||
|
||||
JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC"
|
||||
GRPC_TRANSPORT: Literal["GRPC"] = "GRPC"
|
||||
HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON"
|
||||
|
||||
DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [
|
||||
JSONRPC_TRANSPORT,
|
||||
GRPC_TRANSPORT,
|
||||
HTTP_JSON_TRANSPORT,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class NegotiatedTransport:
|
||||
"""Result of transport negotiation.
|
||||
|
||||
Attributes:
|
||||
transport: The negotiated transport protocol.
|
||||
url: The URL to use for this transport.
|
||||
source: How the transport was selected ('preferred', 'additional', 'fallback').
|
||||
"""
|
||||
|
||||
transport: str
|
||||
url: str
|
||||
source: NegotiationSource
|
||||
|
||||
|
||||
class TransportNegotiationError(Exception):
|
||||
"""Raised when no compatible transport can be negotiated."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_transports: list[str],
|
||||
server_transports: list[str],
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the error with negotiation details.
|
||||
|
||||
Args:
|
||||
client_transports: Transports supported by the client.
|
||||
server_transports: Transports supported by the server.
|
||||
message: Optional custom error message.
|
||||
"""
|
||||
self.client_transports = client_transports
|
||||
self.server_transports = server_transports
|
||||
if message is None:
|
||||
message = (
|
||||
f"No compatible transport found. "
|
||||
f"Client supports: {client_transports}. "
|
||||
f"Server supports: {server_transports}."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
|
||||
"""Extract all available interfaces from an AgentCard.
|
||||
|
||||
Creates a unified list of interfaces including the primary URL and
|
||||
any additional interfaces declared by the agent.
|
||||
|
||||
Args:
|
||||
agent_card: The agent's card containing transport information.
|
||||
|
||||
Returns:
|
||||
List of AgentInterface objects representing all available endpoints.
|
||||
"""
|
||||
interfaces: list[AgentInterface] = []
|
||||
|
||||
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
|
||||
interfaces.append(
|
||||
AgentInterface(
|
||||
transport=primary_transport,
|
||||
url=agent_card.url,
|
||||
)
|
||||
)
|
||||
|
||||
if agent_card.additional_interfaces:
|
||||
for interface in agent_card.additional_interfaces:
|
||||
is_duplicate = any(
|
||||
i.url == interface.url and i.transport == interface.transport
|
||||
for i in interfaces
|
||||
)
|
||||
if not is_duplicate:
|
||||
interfaces.append(interface)
|
||||
|
||||
return interfaces
|
||||
|
||||
|
||||
def negotiate_transport(
|
||||
agent_card: AgentCard,
|
||||
client_supported_transports: list[str] | None = None,
|
||||
client_preferred_transport: str | None = None,
|
||||
emit_event: bool = True,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
) -> NegotiatedTransport:
|
||||
"""Negotiate the transport protocol between client and server.
|
||||
|
||||
Compares the client's supported transports with the server's available
|
||||
interfaces to find a compatible transport and URL.
|
||||
|
||||
Negotiation logic:
|
||||
1. If client_preferred_transport is set and server supports it → use it
|
||||
2. Otherwise, if server's preferred is in client's supported → use server's
|
||||
3. Otherwise, find first match from client's supported in server's interfaces
|
||||
|
||||
Args:
|
||||
agent_card: The server's AgentCard with transport information.
|
||||
client_supported_transports: Transports the client can use.
|
||||
Defaults to ["JSONRPC"] if not specified.
|
||||
client_preferred_transport: Client's preferred transport. If set and
|
||||
server supports it, takes priority over server preference.
|
||||
emit_event: Whether to emit a transport negotiation event.
|
||||
endpoint: Original endpoint URL for event metadata.
|
||||
a2a_agent_name: Agent name for event metadata.
|
||||
|
||||
Returns:
|
||||
NegotiatedTransport with the selected transport, URL, and source.
|
||||
|
||||
Raises:
|
||||
TransportNegotiationError: If no compatible transport is found.
|
||||
"""
|
||||
if client_supported_transports is None:
|
||||
client_supported_transports = [JSONRPC_TRANSPORT]
|
||||
|
||||
client_transports = [t.upper() for t in client_supported_transports]
|
||||
client_preferred = (
|
||||
client_preferred_transport.upper() if client_preferred_transport else None
|
||||
)
|
||||
|
||||
server_interfaces = _get_server_interfaces(agent_card)
|
||||
server_transports = [i.transport.upper() for i in server_interfaces]
|
||||
|
||||
transport_to_interface: dict[str, AgentInterface] = {}
|
||||
for interface in server_interfaces:
|
||||
transport_upper = interface.transport.upper()
|
||||
if transport_upper not in transport_to_interface:
|
||||
transport_to_interface[transport_upper] = interface
|
||||
|
||||
result: NegotiatedTransport | None = None
|
||||
|
||||
if client_preferred and client_preferred in transport_to_interface:
|
||||
interface = transport_to_interface[client_preferred]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
url=interface.url,
|
||||
source="client_preferred",
|
||||
)
|
||||
else:
|
||||
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
|
||||
if (
|
||||
server_preferred in client_transports
|
||||
and server_preferred in transport_to_interface
|
||||
):
|
||||
interface = transport_to_interface[server_preferred]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
url=interface.url,
|
||||
source="server_preferred",
|
||||
)
|
||||
else:
|
||||
for transport in client_transports:
|
||||
if transport in transport_to_interface:
|
||||
interface = transport_to_interface[transport]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
url=interface.url,
|
||||
source="fallback",
|
||||
)
|
||||
break
|
||||
|
||||
if result is None:
|
||||
raise TransportNegotiationError(
|
||||
client_transports=client_transports,
|
||||
server_transports=server_transports,
|
||||
)
|
||||
|
||||
if emit_event:
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2ATransportNegotiatedEvent(
|
||||
endpoint=endpoint or agent_card.url,
|
||||
a2a_agent_name=a2a_agent_name or agent_card.name,
|
||||
negotiated_transport=result.transport,
|
||||
negotiated_url=result.url,
|
||||
source=result.source,
|
||||
client_supported_transports=client_transports,
|
||||
server_supported_transports=server_transports,
|
||||
server_preferred_transport=agent_card.preferred_transport
|
||||
or JSONRPC_TRANSPORT,
|
||||
client_preferred_transport=client_preferred,
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
File diff suppressed because it is too large
Load Diff
@@ -654,165 +654,3 @@ class A2AParallelDelegationCompletedEvent(A2AEventBase):
|
||||
success_count: int
|
||||
failure_count: int
|
||||
results: dict[str, str] | None = None
|
||||
|
||||
|
||||
class A2ATransportNegotiatedEvent(A2AEventBase):
|
||||
"""Event emitted when transport protocol is negotiated with an A2A agent.
|
||||
|
||||
This event is emitted after comparing client and server transport capabilities
|
||||
to select the optimal transport protocol and endpoint URL.
|
||||
|
||||
Attributes:
|
||||
endpoint: Original A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
negotiated_transport: The transport protocol selected (JSONRPC, GRPC, HTTP+JSON).
|
||||
negotiated_url: The URL to use for the selected transport.
|
||||
source: How the transport was selected ('client_preferred', 'server_preferred', 'fallback').
|
||||
client_supported_transports: Transports the client can use.
|
||||
server_supported_transports: Transports the server supports.
|
||||
server_preferred_transport: The server's preferred transport from AgentCard.
|
||||
client_preferred_transport: The client's preferred transport if set.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_transport_negotiated"
|
||||
endpoint: str
|
||||
a2a_agent_name: str | None = None
|
||||
negotiated_transport: str
|
||||
negotiated_url: str
|
||||
source: str
|
||||
client_supported_transports: list[str]
|
||||
server_supported_transports: list[str]
|
||||
server_preferred_transport: str
|
||||
client_preferred_transport: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AContentTypeNegotiatedEvent(A2AEventBase):
|
||||
"""Event emitted when content types are negotiated with an A2A agent.
|
||||
|
||||
This event is emitted after comparing client and server input/output mode
|
||||
capabilities to determine compatible MIME types for communication.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
skill_name: Skill name if negotiation was skill-specific.
|
||||
client_input_modes: MIME types the client can send.
|
||||
client_output_modes: MIME types the client can accept.
|
||||
server_input_modes: MIME types the server accepts.
|
||||
server_output_modes: MIME types the server produces.
|
||||
negotiated_input_modes: Compatible input MIME types selected.
|
||||
negotiated_output_modes: Compatible output MIME types selected.
|
||||
negotiation_success: Whether compatible types were found for both directions.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_content_type_negotiated"
|
||||
endpoint: str
|
||||
a2a_agent_name: str | None = None
|
||||
skill_name: str | None = None
|
||||
client_input_modes: list[str]
|
||||
client_output_modes: list[str]
|
||||
server_input_modes: list[str]
|
||||
server_output_modes: list[str]
|
||||
negotiated_input_modes: list[str]
|
||||
negotiated_output_modes: list[str]
|
||||
negotiation_success: bool = True
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Context Lifecycle Events
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class A2AContextCreatedEvent(A2AEventBase):
|
||||
"""Event emitted when an A2A context is created.
|
||||
|
||||
Contexts group related tasks in a conversation or workflow.
|
||||
|
||||
Attributes:
|
||||
context_id: Unique identifier for the context.
|
||||
created_at: Unix timestamp when context was created.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_context_created"
|
||||
context_id: str
|
||||
created_at: float
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AContextExpiredEvent(A2AEventBase):
|
||||
"""Event emitted when an A2A context expires due to TTL.
|
||||
|
||||
Attributes:
|
||||
context_id: The expired context identifier.
|
||||
created_at: Unix timestamp when context was created.
|
||||
age_seconds: How long the context existed before expiring.
|
||||
task_count: Number of tasks in the context when expired.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_context_expired"
|
||||
context_id: str
|
||||
created_at: float
|
||||
age_seconds: float
|
||||
task_count: int
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AContextIdleEvent(A2AEventBase):
|
||||
"""Event emitted when an A2A context becomes idle.
|
||||
|
||||
Idle contexts have had no activity for the configured threshold.
|
||||
|
||||
Attributes:
|
||||
context_id: The idle context identifier.
|
||||
idle_seconds: Seconds since last activity.
|
||||
task_count: Number of tasks in the context.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_context_idle"
|
||||
context_id: str
|
||||
idle_seconds: float
|
||||
task_count: int
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AContextCompletedEvent(A2AEventBase):
|
||||
"""Event emitted when all tasks in an A2A context complete.
|
||||
|
||||
Attributes:
|
||||
context_id: The completed context identifier.
|
||||
total_tasks: Total number of tasks that were in the context.
|
||||
duration_seconds: Total context lifetime in seconds.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_context_completed"
|
||||
context_id: str
|
||||
total_tasks: int
|
||||
duration_seconds: float
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class A2AContextPrunedEvent(A2AEventBase):
|
||||
"""Event emitted when an A2A context is pruned (deleted).
|
||||
|
||||
Pruning removes the context metadata and optionally associated tasks.
|
||||
|
||||
Attributes:
|
||||
context_id: The pruned context identifier.
|
||||
task_count: Number of tasks that were in the context.
|
||||
age_seconds: How long the context existed before pruning.
|
||||
metadata: Custom A2A metadata key-value pairs.
|
||||
"""
|
||||
|
||||
type: str = "a2a_context_pruned"
|
||||
context_id: str
|
||||
task_count: int
|
||||
age_seconds: float
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
@@ -2,10 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
import inspect
|
||||
import json
|
||||
from types import MethodType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -32,8 +30,6 @@ from typing_extensions import Self
|
||||
if TYPE_CHECKING:
|
||||
from crewai_files import FileInput
|
||||
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
@@ -88,81 +84,6 @@ from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
def _kickoff_with_a2a_support(
|
||||
agent: LiteAgent,
|
||||
original_kickoff: Callable[..., LiteAgentOutput],
|
||||
messages: str | list[LLMMessage],
|
||||
response_format: type[BaseModel] | None,
|
||||
input_files: dict[str, FileInput] | None,
|
||||
extension_registry: Any,
|
||||
) -> LiteAgentOutput:
|
||||
"""Wrap kickoff with A2A delegation using Task adapter.
|
||||
|
||||
Args:
|
||||
agent: The LiteAgent instance.
|
||||
original_kickoff: The original kickoff method.
|
||||
messages: Input messages.
|
||||
response_format: Optional response format.
|
||||
input_files: Optional input files.
|
||||
extension_registry: A2A extension registry.
|
||||
|
||||
Returns:
|
||||
LiteAgentOutput from either local execution or A2A delegation.
|
||||
"""
|
||||
from crewai.a2a.utils.response_model import get_a2a_agents_and_response_model
|
||||
from crewai.a2a.wrapper import _execute_task_with_a2a
|
||||
from crewai.task import Task
|
||||
|
||||
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(agent.a2a)
|
||||
|
||||
if not a2a_agents:
|
||||
return original_kickoff(messages, response_format, input_files)
|
||||
|
||||
if isinstance(messages, str):
|
||||
description = messages
|
||||
else:
|
||||
content = next(
|
||||
(m["content"] for m in reversed(messages) if m["role"] == "user"),
|
||||
None,
|
||||
)
|
||||
description = content if isinstance(content, str) else ""
|
||||
|
||||
if not description:
|
||||
return original_kickoff(messages, response_format, input_files)
|
||||
|
||||
fake_task = Task(
|
||||
description=description,
|
||||
agent=agent,
|
||||
expected_output="Result from A2A delegation",
|
||||
input_files=input_files or {},
|
||||
)
|
||||
|
||||
def task_to_kickoff_adapter(
|
||||
self: Any, task: Task, context: str | None, tools: list[Any] | None
|
||||
) -> str:
|
||||
result = original_kickoff(messages, response_format, input_files)
|
||||
return result.raw
|
||||
|
||||
result_str = _execute_task_with_a2a(
|
||||
self=agent, # type: ignore[arg-type]
|
||||
a2a_agents=a2a_agents,
|
||||
original_fn=task_to_kickoff_adapter,
|
||||
task=fake_task,
|
||||
agent_response_model=agent_response_model,
|
||||
context=None,
|
||||
tools=None,
|
||||
extension_registry=extension_registry,
|
||||
)
|
||||
|
||||
return LiteAgentOutput(
|
||||
raw=result_str,
|
||||
pydantic=None,
|
||||
agent_role=agent.role,
|
||||
usage_metrics=None,
|
||||
messages=[],
|
||||
)
|
||||
|
||||
|
||||
class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""
|
||||
A lightweight agent that can process messages and use tools.
|
||||
@@ -233,17 +154,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
a2a: (
|
||||
list[A2AConfig | A2AServerConfig | A2AClientConfig]
|
||||
| A2AConfig
|
||||
| A2AServerConfig
|
||||
| A2AClientConfig
|
||||
| None
|
||||
) = Field(
|
||||
default=None,
|
||||
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. "
|
||||
"Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of configurations.",
|
||||
)
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Results of the tools used by the agent."
|
||||
)
|
||||
@@ -299,52 +209,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_a2a_support(self) -> Self:
|
||||
"""Setup A2A extensions and server methods if a2a config exists."""
|
||||
if self.a2a:
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig
|
||||
from crewai.a2a.extensions.registry import (
|
||||
create_extension_registry_from_config,
|
||||
)
|
||||
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
|
||||
|
||||
configs = self.a2a if isinstance(self.a2a, list) else [self.a2a]
|
||||
client_configs = [
|
||||
config
|
||||
for config in configs
|
||||
if isinstance(config, (A2AConfig, A2AClientConfig))
|
||||
]
|
||||
|
||||
extension_registry = (
|
||||
create_extension_registry_from_config(client_configs)
|
||||
if client_configs
|
||||
else create_extension_registry_from_config([])
|
||||
)
|
||||
extension_registry.inject_all_tools(self) # type: ignore[arg-type]
|
||||
inject_a2a_server_methods(self) # type: ignore[arg-type]
|
||||
|
||||
original_kickoff = self.kickoff
|
||||
|
||||
@wraps(original_kickoff)
|
||||
def kickoff_with_a2a(
|
||||
messages: str | list[LLMMessage],
|
||||
response_format: type[BaseModel] | None = None,
|
||||
input_files: dict[str, FileInput] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
return _kickoff_with_a2a_support(
|
||||
self,
|
||||
original_kickoff,
|
||||
messages,
|
||||
response_format,
|
||||
input_files,
|
||||
extension_registry,
|
||||
)
|
||||
|
||||
object.__setattr__(self, "kickoff", MethodType(kickoff_with_a2a, self))
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_guardrail_is_callable(self) -> Self:
|
||||
if callable(self.guardrail):
|
||||
@@ -762,9 +626,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
formatted_answer = process_llm_response(
|
||||
cast(str, answer), self.use_stop_words
|
||||
)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
try:
|
||||
@@ -847,21 +709,3 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
) -> None:
|
||||
"""Append a message to the message list with the given role."""
|
||||
self._messages.append(format_message_for_llm(text, role=role))
|
||||
|
||||
|
||||
try:
|
||||
from crewai.a2a.config import (
|
||||
A2AClientConfig as _A2AClientConfig,
|
||||
A2AConfig as _A2AConfig,
|
||||
A2AServerConfig as _A2AServerConfig,
|
||||
)
|
||||
|
||||
LiteAgent.model_rebuild(
|
||||
_types_namespace={
|
||||
"A2AConfig": _A2AConfig,
|
||||
"A2AClientConfig": _A2AClientConfig,
|
||||
"A2AServerConfig": _A2AServerConfig,
|
||||
}
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -104,7 +104,6 @@ class TestA2AStreamingIntegration:
|
||||
message=test_message,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
endpoint=agent_card.url,
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
@@ -226,7 +225,6 @@ class TestA2APushNotificationHandler:
|
||||
result_store=mock_store,
|
||||
polling_timeout=30.0,
|
||||
polling_interval=1.0,
|
||||
endpoint=mock_agent_card.url,
|
||||
)
|
||||
|
||||
mock_store.wait_for_result.assert_called_once_with(
|
||||
@@ -289,7 +287,6 @@ class TestA2APushNotificationHandler:
|
||||
result_store=mock_store,
|
||||
polling_timeout=5.0,
|
||||
polling_interval=0.5,
|
||||
endpoint=mock_agent_card.url,
|
||||
)
|
||||
|
||||
assert result["status"] == TaskState.failed
|
||||
@@ -320,7 +317,6 @@ class TestA2APushNotificationHandler:
|
||||
message=test_msg,
|
||||
new_messages=new_messages,
|
||||
agent_card=mock_agent_card,
|
||||
endpoint=mock_agent_card.url,
|
||||
)
|
||||
|
||||
assert result["status"] == TaskState.failed
|
||||
|
||||
@@ -43,7 +43,6 @@ def mock_context() -> MagicMock:
|
||||
context.context_id = "test-context-456"
|
||||
context.get_user_input.return_value = "Test user message"
|
||||
context.message = MagicMock(spec=Message)
|
||||
context.message.parts = []
|
||||
context.current_task = None
|
||||
return context
|
||||
|
||||
|
||||
@@ -1397,3 +1397,56 @@ def test_openai_responses_api_both_auto_chains_work_together():
|
||||
assert params.get("previous_response_id") == "resp_123"
|
||||
assert "reasoning.encrypted_content" in params["include"]
|
||||
assert len(params["input"]) == 2 # Reasoning item + message
|
||||
|
||||
|
||||
def test_openai_sdk_imports_compatibility():
|
||||
"""
|
||||
Test that all OpenAI SDK imports used by CrewAI are available.
|
||||
|
||||
This test verifies that the OpenAI SDK version installed provides all the
|
||||
types and classes that CrewAI depends on. If this test fails after updating
|
||||
the OpenAI SDK, it indicates a breaking change in the SDK that needs to be
|
||||
addressed.
|
||||
|
||||
Related to issue #4300: Dependency constraints in pyproject.toml are overly strict
|
||||
"""
|
||||
from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI, Stream
|
||||
from openai.lib.streaming.chat import ChatCompletionStream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.responses import Response
|
||||
|
||||
assert OpenAI is not None
|
||||
assert AsyncOpenAI is not None
|
||||
assert Stream is not None
|
||||
assert APIConnectionError is not None
|
||||
assert NotFoundError is not None
|
||||
assert ChatCompletionStream is not None
|
||||
assert ChatCompletion is not None
|
||||
assert ChatCompletionChunk is not None
|
||||
assert Choice is not None
|
||||
assert ChoiceDelta is not None
|
||||
assert Response is not None
|
||||
|
||||
|
||||
def test_openai_sdk_client_instantiation():
|
||||
"""
|
||||
Test that OpenAI client can be instantiated with the current SDK version.
|
||||
|
||||
This test verifies that the OpenAI client initialization works correctly
|
||||
with the installed SDK version, ensuring compatibility with newer versions.
|
||||
|
||||
Related to issue #4300: Dependency constraints in pyproject.toml are overly strict
|
||||
"""
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
client = OpenAI(api_key="test-key")
|
||||
async_client = AsyncOpenAI(api_key="test-key")
|
||||
|
||||
assert client is not None
|
||||
assert async_client is not None
|
||||
assert hasattr(client, "chat")
|
||||
assert hasattr(client.chat, "completions")
|
||||
assert hasattr(async_client, "chat")
|
||||
assert hasattr(async_client.chat, "completions")
|
||||
|
||||
Reference in New Issue
Block a user