Compare commits

...

2 Commits

Author SHA1 Message Date
Greyson LaLonde
8b3acb58a4 chore: add crewai-a2a package README 2026-02-27 09:23:55 -05:00
Greyson LaLonde
8a3c2d5ca6 refactor: extract crewai.a2a to crewai-a2a workspace package 2026-02-27 09:12:57 -05:00
98 changed files with 11990 additions and 10888 deletions

View File

@@ -19,7 +19,7 @@ repos:
language: system
pass_filenames: true
types: [python]
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/|lib/crewai-a2a/tests/)
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.9.3
hooks:

View File

@@ -12,6 +12,7 @@ from dotenv import load_dotenv
import pytest
from vcr.request import Request # type: ignore[import-untyped]
try:
import vcr.stubs.httpx_stubs as httpx_stubs # type: ignore[import-untyped]
except ModuleNotFoundError:
@@ -225,7 +226,7 @@ def vcr_cassette_dir(request: Any) -> str:
for parent in test_file.parents:
if (
parent.name in ("crewai", "crewai-tools", "crewai-files")
parent.name in ("crewai", "crewai-tools", "crewai-files", "crewai-a2a")
and parent.parent.name == "lib"
):
package_root = parent

189
lib/crewai-a2a/README.md Normal file
View File

@@ -0,0 +1,189 @@
# crewai-a2a
Agent-to-Agent (A2A) protocol support for CrewAI. Enables agents to discover, authenticate, and communicate with remote A2A-compatible agents.
## Quick Links
[Homepage](https://www.crewai.com/) | [Documentation](https://docs.crewai.com/) | [Community](https://community.crewai.com/)
## Installation
```bash
uv pip install crewai[a2a]
# or
uv add 'crewai[a2a]'
```
## Usage
### Connecting to a Remote A2A Agent
```python
from crewai import Agent
from crewai_a2a import A2AClientConfig
agent = Agent(
role="Coordinator",
goal="Delegate research tasks",
a2a=[
A2AClientConfig(endpoint="https://research-agent.example.com"),
],
)
```
### Exposing an Agent as an A2A Server
```python
from crewai import Agent
from crewai_a2a import A2AServerConfig
agent = Agent(
role="Researcher",
goal="Answer research questions",
a2a_server=A2AServerConfig(
name="Research Agent",
description="Answers research questions using web search",
),
)
```
## Authentication
### Client Schemes
```python
from crewai_a2a.auth import (
BearerTokenAuth,
HTTPBasicAuth,
APIKeyAuth,
OAuth2ClientCredentials,
)
from crewai_a2a.config import A2AClientConfig
# Bearer token
A2AClientConfig(
endpoint="https://agent.example.com",
auth=BearerTokenAuth(token="my-token"),
)
# API key
A2AClientConfig(
endpoint="https://agent.example.com",
auth=APIKeyAuth(api_key="key", location="header", name="X-API-Key"),
)
# OAuth2 client credentials
A2AClientConfig(
endpoint="https://agent.example.com",
auth=OAuth2ClientCredentials(
token_url="https://auth.example.com/token",
client_id="id",
client_secret="secret",
),
)
```
### Server Schemes
```python
from crewai_a2a.auth import SimpleTokenAuth, OIDCAuth
from crewai_a2a.config import A2AServerConfig
# Simple token validation
A2AServerConfig(auth=SimpleTokenAuth(token="expected-token"))
# OpenID Connect
A2AServerConfig(
auth=OIDCAuth(
issuer="https://auth.example.com",
audience="my-agent",
),
)
```
## Update Mechanisms
Control how the client receives task updates from remote agents.
```python
from crewai_a2a.updates import PollingConfig, StreamingConfig, PushNotificationConfig
from crewai_a2a.config import A2AClientConfig
# Polling
A2AClientConfig(
endpoint="https://agent.example.com",
updates=PollingConfig(interval=2.0, timeout=60),
)
# Server-Sent Events streaming
A2AClientConfig(
endpoint="https://agent.example.com",
updates=StreamingConfig(),
)
# Webhook push notifications
A2AClientConfig(
endpoint="https://agent.example.com",
updates=PushNotificationConfig(
url="https://my-server.example.com/webhook",
timeout=300,
),
)
```
## Extensions
### Client Extensions
Client extensions inject tools, augment prompts, and process responses.
```python
from crewai_a2a.extensions import A2AExtension
class MyExtension(A2AExtension):
def inject_tools(self, agent):
...
def augment_prompt(self, base_prompt, conversation_state):
return f"{base_prompt}\n\nAdditional context from extension."
```
### Server Extensions
Server extensions add protocol-level capabilities to your A2A server.
```python
from crewai_a2a.extensions import ServerExtension
class MyServerExtension(ServerExtension):
uri = "urn:my-org:my-extension"
description = "Custom protocol extension"
async def on_request(self, context):
...
async def on_response(self, context, result):
...
```
## Transport
Three transport protocols are supported: JSON-RPC (default), gRPC, and HTTP+JSON.
```python
from crewai_a2a.config import ClientTransportConfig, GRPCClientConfig
from crewai_a2a.config import A2AClientConfig
A2AClientConfig(
endpoint="https://agent.example.com",
transport=ClientTransportConfig(
preferred="GRPC",
grpc=GRPCClientConfig(
max_send_message_length=4 * 1024 * 1024,
),
),
)
```

View File

@@ -0,0 +1,25 @@
[project]
name = "crewai-a2a"
dynamic = ["version"]
description = "A2A (Agent-to-Agent) protocol support for CrewAI"
readme = "README.md"
authors = [{ name = "Greyson LaLonde", email = "greyson@crewai.com" }]
requires-python = ">=3.10, <3.14"
dependencies = [
"crewai==1.10.1b1",
"a2a-sdk~=0.3.10",
"httpx-auth~=0.23.1",
"httpx-sse~=0.4.0",
"aiocache[redis,memcached]~=0.12.3",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.version]
path = "src/crewai_a2a/__init__.py"
[tool.uv.sources]
crewai = { workspace = true }
crewai-files = { workspace = true }

View File

@@ -0,0 +1,12 @@
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""
__version__ = "1.10.1b1"
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
__all__ = [
"A2AClientConfig",
"A2AConfig",
"A2AServerConfig",
]

View File

@@ -0,0 +1,36 @@
"""A2A authentication schemas."""
from crewai_a2a.auth.client_schemes import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
TLSConfig,
)
from crewai_a2a.auth.server_schemes import (
AuthenticatedUser,
OIDCAuth,
ServerAuthScheme,
SimpleTokenAuth,
)
__all__ = [
"APIKeyAuth",
"AuthScheme",
"AuthenticatedUser",
"BearerTokenAuth",
"ClientAuthScheme",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
"OIDCAuth",
"ServerAuthScheme",
"SimpleTokenAuth",
"TLSConfig",
]

View File

@@ -0,0 +1,550 @@
"""Authentication schemes for A2A protocol clients.
Supported authentication methods:
- Bearer tokens
- OAuth2 (Client Credentials, Authorization Code)
- API Keys (header, query, cookie)
- HTTP Basic authentication
- HTTP Digest authentication
- mTLS (mutual TLS) client certificate authentication
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
import base64
from collections.abc import Awaitable, Callable, MutableMapping
from pathlib import Path
import ssl
import time
from typing import TYPE_CHECKING, ClassVar, Literal
import urllib.parse
import httpx
from httpx import DigestAuth
from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr
from typing_extensions import deprecated
if TYPE_CHECKING:
import grpc # type: ignore[import-untyped]
class TLSConfig(BaseModel):
"""TLS/mTLS configuration for secure client connections.
Supports mutual TLS (mTLS) where the client presents a certificate to the server,
and standard TLS with custom CA verification.
Attributes:
client_cert_path: Path to client certificate file (PEM format) for mTLS.
client_key_path: Path to client private key file (PEM format) for mTLS.
ca_cert_path: Path to CA certificate bundle for server verification.
verify: Whether to verify server certificates. Set False only for development.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
client_cert_path: FilePath | None = Field(
default=None,
description="Path to client certificate file (PEM format) for mTLS",
)
client_key_path: FilePath | None = Field(
default=None,
description="Path to client private key file (PEM format) for mTLS",
)
ca_cert_path: FilePath | None = Field(
default=None,
description="Path to CA certificate bundle for server verification",
)
verify: bool = Field(
default=True,
description="Whether to verify server certificates. Set False only for development.",
)
def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
"""Build SSL context for httpx client.
Returns:
SSL context if certificates configured, True for default verification,
False if verification disabled, or path to CA bundle.
"""
if not self.verify:
return False
if self.client_cert_path and self.client_key_path:
context = ssl.create_default_context()
if self.ca_cert_path:
context.load_verify_locations(cafile=str(self.ca_cert_path))
context.load_cert_chain(
certfile=str(self.client_cert_path),
keyfile=str(self.client_key_path),
)
return context
if self.ca_cert_path:
return str(self.ca_cert_path)
return True
def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported]
"""Build gRPC channel credentials for secure connections.
Returns:
gRPC SSL credentials if certificates configured, None otherwise.
"""
try:
import grpc
except ImportError:
return None
if not self.verify and not self.client_cert_path:
return None
root_certs: bytes | None = None
private_key: bytes | None = None
certificate_chain: bytes | None = None
if self.ca_cert_path:
root_certs = Path(self.ca_cert_path).read_bytes()
if self.client_cert_path and self.client_key_path:
private_key = Path(self.client_key_path).read_bytes()
certificate_chain = Path(self.client_cert_path).read_bytes()
return grpc.ssl_channel_credentials(
root_certificates=root_certs,
private_key=private_key,
certificate_chain=certificate_chain,
)
class ClientAuthScheme(ABC, BaseModel):
"""Base class for client-side authentication schemes.
Client auth schemes apply credentials to outgoing requests.
Attributes:
tls: Optional TLS/mTLS configuration for secure connections.
"""
tls: TLSConfig | None = Field(
default=None,
description="TLS/mTLS configuration for secure connections",
)
@abstractmethod
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply authentication to request headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with authentication applied.
"""
...
@deprecated("Use ClientAuthScheme instead", category=FutureWarning)
class AuthScheme(ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme instead."""
class BearerTokenAuth(ClientAuthScheme):
"""Bearer token authentication (Authorization: Bearer <token>).
Attributes:
token: Bearer token for authentication.
"""
token: str = Field(description="Bearer token")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply Bearer token to Authorization header.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Bearer token in Authorization header.
"""
headers["Authorization"] = f"Bearer {self.token}"
return headers
class HTTPBasicAuth(ClientAuthScheme):
"""HTTP Basic authentication.
Attributes:
username: Username for Basic authentication.
password: Password for Basic authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply HTTP Basic authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Basic auth in Authorization header.
"""
credentials = f"{self.username}:{self.password}"
encoded = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded}"
return headers
class HTTPDigestAuth(ClientAuthScheme):
"""HTTP Digest authentication.
Note: Uses httpx-auth library for digest implementation.
Attributes:
username: Username for Digest authentication.
password: Password for Digest authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
_configured_client_id: int | None = PrivateAttr(default=None)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Digest auth is handled by httpx auth flow, not headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Unchanged headers (Digest auth handled by httpx auth flow).
"""
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client with Digest auth.
Idempotent: Only configures the client once. Subsequent calls on the same
client instance are no-ops to prevent overwriting auth configuration.
Args:
client: HTTP client to configure with Digest authentication.
"""
client_id = id(client)
if self._configured_client_id == client_id:
return
client.auth = DigestAuth(self.username, self.password)
self._configured_client_id = client_id
class APIKeyAuth(ClientAuthScheme):
"""API Key authentication (header, query, or cookie).
Attributes:
api_key: API key value for authentication.
location: Where to send the API key (header, query, or cookie).
name: Parameter name for the API key (default: X-API-Key).
"""
api_key: str = Field(description="API key value")
location: Literal["header", "query", "cookie"] = Field(
default="header", description="Where to send the API key"
)
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
_configured_client_ids: set[int] = PrivateAttr(default_factory=set)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply API key authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with API key (for header/cookie locations).
"""
if self.location == "header":
headers[self.name] = self.api_key
elif self.location == "cookie":
headers["Cookie"] = f"{self.name}={self.api_key}"
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client for query param API keys.
Idempotent: Only adds the request hook once per client instance.
Subsequent calls on the same client are no-ops to prevent hook accumulation.
Args:
client: HTTP client to configure with query param API key hook.
"""
if self.location == "query":
client_id = id(client)
if client_id in self._configured_client_ids:
return
async def _add_api_key_param(request: httpx.Request) -> None:
url = httpx.URL(request.url)
request.url = url.copy_add_param(self.name, self.api_key)
client.event_hooks["request"].append(_add_api_key_param)
self._configured_client_ids.add(client_id)
class OAuth2ClientCredentials(ClientAuthScheme):
"""OAuth2 Client Credentials flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches
when multiple requests share the same auth instance.
Attributes:
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
scopes: List of required OAuth2 scopes.
"""
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time,
preventing race conditions when multiple concurrent requests use the same
auth instance.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
"""
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
async with self._lock:
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
"""Fetch OAuth2 access token using client credentials flow.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token request fails.
"""
data = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
}
if self.scopes:
data["scope"] = " ".join(self.scopes)
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
class OAuth2AuthorizationCode(ClientAuthScheme):
"""OAuth2 Authorization Code flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token operations.
Note: Requires interactive authorization.
Attributes:
authorization_url: OAuth2 authorization endpoint URL.
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
redirect_uri: OAuth2 redirect URI for callback.
scopes: List of required OAuth2 scopes.
"""
authorization_url: str = Field(description="OAuth2 authorization endpoint")
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
redirect_uri: str = Field(description="OAuth2 redirect URI")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_refresh_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
default=None
)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
def set_authorization_callback(
self, callback: Callable[[str], Awaitable[str]] | None
) -> None:
"""Set callback to handle authorization URL.
Args:
callback: Async function that receives authorization URL and returns auth code.
"""
self._authorization_callback = callback
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine handles token operations
(initial fetch or refresh) at a time.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
Raises:
ValueError: If authorization callback is not set.
"""
if self._access_token is None:
if self._authorization_callback is None:
msg = "Authorization callback not set. Use set_authorization_callback()"
raise ValueError(msg)
async with self._lock:
if self._access_token is None:
await self._fetch_initial_token(client)
elif self._token_expires_at and time.time() >= self._token_expires_at:
async with self._lock:
if self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
"""Fetch initial access token using authorization code flow.
Args:
client: HTTP client for making token request.
Raises:
ValueError: If authorization callback is not set.
httpx.HTTPStatusError: If token request fails.
"""
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(self.scopes),
}
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
if self._authorization_callback is None:
msg = "Authorization callback not set"
raise ValueError(msg)
auth_code = await self._authorization_callback(auth_url)
data = {
"grant_type": "authorization_code",
"code": auth_code,
"client_id": self.client_id,
"client_secret": self.client_secret,
"redirect_uri": self.redirect_uri,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
self._refresh_token = token_data.get("refresh_token")
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
"""Refresh the access token using refresh token.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token refresh request fails.
"""
if not self._refresh_token:
await self._fetch_initial_token(client)
return
data = {
"grant_type": "refresh_token",
"refresh_token": self._refresh_token,
"client_id": self.client_id,
"client_secret": self.client_secret,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
if "refresh_token" in token_data:
self._refresh_token = token_data["refresh_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60

View File

@@ -0,0 +1,71 @@
"""Deprecated: Authentication schemes for A2A protocol agents.
This module is deprecated. Import from crewai_a2a.auth instead:
- crewai_a2a.auth.ClientAuthScheme (replaces AuthScheme)
- crewai_a2a.auth.BearerTokenAuth
- crewai_a2a.auth.HTTPBasicAuth
- crewai_a2a.auth.HTTPDigestAuth
- crewai_a2a.auth.APIKeyAuth
- crewai_a2a.auth.OAuth2ClientCredentials
- crewai_a2a.auth.OAuth2AuthorizationCode
"""
from __future__ import annotations
from typing_extensions import deprecated
from crewai_a2a.auth.client_schemes import (
APIKeyAuth as _APIKeyAuth,
BearerTokenAuth as _BearerTokenAuth,
ClientAuthScheme as _ClientAuthScheme,
HTTPBasicAuth as _HTTPBasicAuth,
HTTPDigestAuth as _HTTPDigestAuth,
OAuth2AuthorizationCode as _OAuth2AuthorizationCode,
OAuth2ClientCredentials as _OAuth2ClientCredentials,
)
@deprecated("Use ClientAuthScheme from crewai_a2a.auth instead", category=FutureWarning)
class AuthScheme(_ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class BearerTokenAuth(_BearerTokenAuth):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class HTTPBasicAuth(_HTTPBasicAuth):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class HTTPDigestAuth(_HTTPDigestAuth):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class APIKeyAuth(_APIKeyAuth):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class OAuth2ClientCredentials(_OAuth2ClientCredentials):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class OAuth2AuthorizationCode(_OAuth2AuthorizationCode):
"""Deprecated: Import from crewai_a2a.auth instead."""
__all__ = [
"APIKeyAuth",
"AuthScheme",
"BearerTokenAuth",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
]

View File

@@ -0,0 +1,742 @@
"""Server-side authentication schemes for A2A protocol.
These schemes validate incoming requests to A2A server endpoints.
Supported authentication methods:
- Simple token validation with static bearer tokens
- OpenID Connect with JWT validation using JWKS
- OAuth2 with JWT validation or token introspection
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
import os
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
import jwt
from jwt import PyJWKClient
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
HttpUrl,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import Self
if TYPE_CHECKING:
from a2a.types import OAuth2SecurityScheme
logger = logging.getLogger(__name__)
try:
from fastapi import ( # type: ignore[import-not-found]
HTTPException,
status as http_status,
)
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
except ImportError:
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
"""Fallback HTTPException when FastAPI is not installed."""
def __init__(
self,
status_code: int,
detail: str | None = None,
headers: dict[str, str] | None = None,
) -> None:
self.status_code = status_code
self.detail = detail
self.headers = headers
super().__init__(detail)
HTTP_401_UNAUTHORIZED = 401
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_503_SERVICE_UNAVAILABLE = 503
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
"""Coerce string to SecretStr."""
if v is None or isinstance(v, SecretStr):
return v
return SecretStr(v)
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
JWTAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
@dataclass
class AuthenticatedUser:
"""Result of successful authentication.
Attributes:
token: The original token that was validated.
scheme: Name of the authentication scheme used.
claims: JWT claims from OIDC or OAuth2 authentication.
"""
token: str
scheme: str
claims: dict[str, Any] | None = None
class ServerAuthScheme(ABC, BaseModel):
"""Base class for server-side authentication schemes.
Each scheme validates incoming requests and returns an AuthenticatedUser
on success, or raises HTTPException on failure.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@abstractmethod
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate the provided token.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
...
class SimpleTokenAuth(ServerAuthScheme):
"""Simple bearer token authentication.
Validates tokens against a configured static token or AUTH_TOKEN env var.
Attributes:
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
"""
token: CoercedSecretStr | None = Field(
default=None,
description="Expected token. Falls back to AUTH_TOKEN env var.",
)
def _get_expected_token(self) -> str | None:
"""Get the expected token value."""
if self.token:
return self.token.get_secret_value()
return os.environ.get("AUTH_TOKEN")
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using simple token comparison.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
expected = self._get_expected_token()
if expected is None:
logger.warning(
"Simple token authentication failed",
extra={"reason": "no_token_configured"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Authentication not configured",
)
if token != expected:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
)
return AuthenticatedUser(
token=token,
scheme="simple_token",
)
class OIDCAuth(ServerAuthScheme):
"""OpenID Connect authentication.
Validates JWTs using JWKS with caching support via PyJWT.
Attributes:
issuer: The OpenID Connect issuer URL.
audience: The expected audience claim.
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
algorithms: List of allowed signing algorithms.
required_claims: List of claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
issuer: HttpUrl = Field(
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
)
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
jwks_url: HttpUrl | None = Field(
default=None,
description="Explicit JWKS URL. Derived from issuer if not set.",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="List of allowed signing algorithms (RS256, ES256, etc.)",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
description="List of claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _init_jwk_client(self) -> Self:
"""Initialize the JWK client after model creation."""
jwks_url = (
str(self.jwks_url)
if self.jwks_url
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
)
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OIDC JWT validation.
Args:
token: The JWT to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OIDC not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=str(self.issuer).rstrip("/"),
leeway=self.clock_skew_seconds,
options={
"require": self.required_claims,
},
)
return AuthenticatedUser(
token=token,
scheme="oidc",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "token_expired", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_audience", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OIDC authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oidc",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
class OAuth2ServerAuth(ServerAuthScheme):
"""OAuth2 authentication for A2A server.
Declares OAuth2 security scheme in AgentCard and validates tokens using
either JWKS for JWT tokens or token introspection for opaque tokens.
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
Attributes:
token_url: OAuth2 token endpoint URL for client_credentials flow.
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
refresh_url: Optional refresh token endpoint URL.
scopes: Available OAuth2 scopes with descriptions.
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
introspection_client_id: Client ID for introspection endpoint authentication.
introspection_client_secret: Client secret for introspection endpoint.
audience: Expected audience claim for JWT validation.
issuer: Expected issuer claim for JWT validation.
algorithms: Allowed JWT signing algorithms.
required_claims: Claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
token_url: HttpUrl = Field(
description="OAuth2 token endpoint URL",
)
authorization_url: HttpUrl | None = Field(
default=None,
description="OAuth2 authorization endpoint URL for authorization_code flow",
)
refresh_url: HttpUrl | None = Field(
default=None,
description="OAuth2 refresh token endpoint URL",
)
scopes: dict[str, str] = Field(
default_factory=dict,
description="Available OAuth2 scopes with descriptions",
)
jwks_url: HttpUrl | None = Field(
default=None,
description="JWKS URL for JWT validation. Required if not using introspection.",
)
introspection_url: HttpUrl | None = Field(
default=None,
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
)
introspection_client_id: str | None = Field(
default=None,
description="Client ID for introspection endpoint authentication",
)
introspection_client_secret: CoercedSecretStr | None = Field(
default=None,
description="Client secret for introspection endpoint authentication",
)
audience: str | None = Field(
default=None,
description="Expected audience claim for JWT validation",
)
issuer: str | None = Field(
default=None,
description="Expected issuer claim for JWT validation",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="Allowed JWT signing algorithms",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat"],
description="Claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _validate_and_init(self) -> Self:
"""Validate configuration and initialize JWKS client if needed."""
if not self.jwks_url and not self.introspection_url:
raise ValueError(
"Either jwks_url or introspection_url must be provided for token validation"
)
if self.introspection_url:
if not self.introspection_client_id or not self.introspection_client_secret:
raise ValueError(
"introspection_client_id and introspection_client_secret are required "
"when using token introspection"
)
if self.jwks_url:
self._jwk_client = PyJWKClient(
str(self.jwks_url), lifespan=self.jwks_cache_ttl
)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token validation.
Uses JWKS validation if jwks_url is configured, otherwise falls back
to token introspection.
Args:
token: The OAuth2 access token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client:
return await self._authenticate_jwt(token)
return await self._authenticate_introspection(token)
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
"""Authenticate using JWKS JWT validation."""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 JWKS not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
decode_options: dict[str, Any] = {
"require": self.required_claims,
}
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=self.issuer,
leeway=self.clock_skew_seconds,
options=decode_options, # type: ignore[arg-type]
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_expired", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_audience", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OAuth2 authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oauth2",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
import httpx
if not self.introspection_url:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 introspection not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.post(
str(self.introspection_url),
data={"token": token},
auth=(
self.introspection_client_id or "",
self.introspection_client_secret.get_secret_value()
if self.introspection_client_secret
else "",
),
)
response.raise_for_status()
introspection_result = response.json()
except httpx.HTTPStatusError as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "http_error", "status_code": e.response.status_code},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection service unavailable",
) from None
except Exception as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "unexpected_error", "error": str(e)},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection failed",
) from None
if not introspection_result.get("active", False):
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_not_active", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token is not active",
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=introspection_result,
)
def to_security_scheme(self) -> OAuth2SecurityScheme:
"""Generate OAuth2SecurityScheme for AgentCard declaration.
Creates an OAuth2SecurityScheme with appropriate flows based on
the configured URLs. Includes client_credentials flow if token_url
is set, and authorization_code flow if authorization_url is set.
Returns:
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
"""
from a2a.types import (
AuthorizationCodeOAuthFlow,
ClientCredentialsOAuthFlow,
OAuth2SecurityScheme,
OAuthFlows,
)
client_credentials = None
authorization_code = None
if self.token_url:
client_credentials = ClientCredentialsOAuthFlow(
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
if self.authorization_url:
authorization_code = AuthorizationCodeOAuthFlow(
authorization_url=str(self.authorization_url),
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
return OAuth2SecurityScheme(
flows=OAuthFlows(
client_credentials=client_credentials,
authorization_code=authorization_code,
),
description="OAuth2 authentication",
)
class APIKeyServerAuth(ServerAuthScheme):
"""API Key authentication for A2A server.
Validates requests using an API key in a header, query parameter, or cookie.
Attributes:
name: The name of the API key parameter (default: X-API-Key).
location: Where to look for the API key (header, query, or cookie).
api_key: The expected API key value.
"""
name: str = Field(
default="X-API-Key",
description="Name of the API key parameter",
)
location: Literal["header", "query", "cookie"] = Field(
default="header",
description="Where to look for the API key",
)
api_key: CoercedSecretStr = Field(
description="Expected API key value",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using API key comparison.
Args:
token: The API key to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if token != self.api_key.get_secret_value():
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return AuthenticatedUser(
token=token,
scheme="api_key",
)
class MTLSServerAuth(ServerAuthScheme):
"""Mutual TLS authentication marker for AgentCard declaration.
This scheme is primarily for AgentCard security_schemes declaration.
Actual mTLS verification happens at the TLS/transport layer, not
at the application layer via token validation.
When configured, this signals to clients that the server requires
client certificates for authentication.
"""
description: str = Field(
default="Mutual TLS certificate authentication",
description="Description for the security scheme",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Return authenticated user for mTLS.
mTLS verification happens at the transport layer before this is called.
If we reach this point, the TLS handshake with client cert succeeded.
Args:
token: Certificate subject or identifier (from TLS layer).
Returns:
AuthenticatedUser indicating mTLS authentication.
"""
return AuthenticatedUser(
token=token or "mtls-verified",
scheme="mtls",
)

View File

@@ -0,0 +1,273 @@
"""Authentication utilities for A2A protocol agent communication.
Provides validation and retry logic for various authentication schemes including
OAuth2, API keys, and HTTP authentication methods.
"""
import asyncio
from collections.abc import Awaitable, Callable, MutableMapping
import hashlib
import re
import threading
from typing import Final, Literal, cast
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
APIKeySecurityScheme,
AgentCard,
HTTPAuthSecurityScheme,
OAuth2SecurityScheme,
)
from httpx import AsyncClient, Response
from crewai_a2a.auth.client_schemes import (
APIKeyAuth,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
)
class _AuthStore:
"""Store for authentication schemes with safe concurrent access."""
def __init__(self) -> None:
self._store: dict[str, ClientAuthScheme | None] = {}
self._lock = threading.RLock()
@staticmethod
def compute_key(auth_type: str, auth_data: str) -> str:
"""Compute a collision-resistant key using SHA-256."""
content = f"{auth_type}:{auth_data}"
return hashlib.sha256(content.encode()).hexdigest()
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
"""Store an auth scheme."""
with self._lock:
self._store[key] = auth
def get(self, key: str) -> ClientAuthScheme | None:
"""Retrieve an auth scheme by key."""
with self._lock:
return self._store.get(key)
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
with self._lock:
self._store[key] = value
def __getitem__(self, key: str) -> ClientAuthScheme | None:
with self._lock:
return self._store[key]
_auth_store = _AuthStore()
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
OAuth2SecurityScheme: (
OAuth2ClientCredentials,
OAuth2AuthorizationCode,
BearerTokenAuth,
),
APIKeySecurityScheme: (APIKeyAuth,),
}
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
"basic": HTTPBasicAuth,
"digest": HTTPDigestAuth,
"bearer": BearerTokenAuth,
}
def _raise_auth_mismatch(
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
provided_auth: ClientAuthScheme,
) -> None:
"""Raise authentication mismatch error.
Args:
expected_classes: Expected authentication class or tuple of classes.
provided_auth: Actually provided authentication instance.
Raises:
A2AClientHTTPError: Always raises with 401 status code.
"""
if isinstance(expected_classes, tuple):
if len(expected_classes) == 1:
required = expected_classes[0].__name__
else:
names = [cls.__name__ for cls in expected_classes]
required = f"one of ({', '.join(names)})"
else:
required = expected_classes.__name__
msg = (
f"AgentCard requires {required} authentication, "
f"but {type(provided_auth).__name__} was provided"
)
raise A2AClientHTTPError(401, msg)
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
"""Parse WWW-Authenticate header into auth challenges.
Args:
header_value: The WWW-Authenticate header value.
Returns:
Dictionary mapping auth scheme to its parameters.
Example: {"Bearer": {"realm": "api", "scope": "read write"}}
"""
if not header_value:
return {}
challenges: dict[str, dict[str, str]] = {}
for match in _SCHEME_PATTERN.finditer(header_value):
scheme = match.group(1)
params_str = match.group(2)
params: dict[str, str] = {}
for param_match in _PARAM_PATTERN.finditer(params_str):
key = param_match.group(1)
value = param_match.group(2) or param_match.group(3)
params[key] = value
challenges[scheme] = params
return challenges
def validate_auth_against_agent_card(
agent_card: AgentCard, auth: ClientAuthScheme | None
) -> None:
"""Validate that provided auth matches AgentCard security requirements.
Args:
agent_card: The A2A AgentCard containing security requirements.
auth: User-provided authentication scheme (or None).
Raises:
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
"""
if not agent_card.security or not agent_card.security_schemes:
return
if not auth:
msg = "AgentCard requires authentication but no auth scheme provided"
raise A2AClientHTTPError(401, msg)
first_security_req = agent_card.security[0] if agent_card.security else {}
for scheme_name in first_security_req.keys():
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
if not security_scheme_wrapper:
continue
scheme = security_scheme_wrapper.root
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
if not isinstance(auth, allowed_classes):
_raise_auth_mismatch(allowed_classes, auth)
return
if isinstance(scheme, HTTPAuthSecurityScheme):
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
if not isinstance(auth, required_class):
_raise_auth_mismatch(required_class, auth)
return
msg = "Could not validate auth against AgentCard security requirements"
raise A2AClientHTTPError(401, msg)
async def retry_on_401(
request_func: Callable[[], Awaitable[Response]],
auth_scheme: ClientAuthScheme | None,
client: AsyncClient,
headers: MutableMapping[str, str],
max_retries: int = 3,
) -> Response:
"""Retry a request on 401 authentication error.
Handles 401 errors by:
1. Parsing WWW-Authenticate header
2. Re-acquiring credentials
3. Retrying the request
Args:
request_func: Async function that makes the HTTP request.
auth_scheme: Authentication scheme to refresh credentials with.
client: HTTP client for making requests.
headers: Request headers to update with new auth.
max_retries: Maximum number of retry attempts (default: 3).
Returns:
HTTP response from the request.
Raises:
httpx.HTTPStatusError: If retries are exhausted or auth scheme is None.
"""
last_response: Response | None = None
last_challenges: dict[str, dict[str, str]] = {}
for attempt in range(max_retries):
response = await request_func()
if response.status_code != 401:
return response
last_response = response
if auth_scheme is None:
response.raise_for_status()
return response
www_authenticate = response.headers.get("WWW-Authenticate", "")
challenges = parse_www_authenticate(www_authenticate)
last_challenges = challenges
if attempt >= max_retries - 1:
break
backoff_time = 2**attempt
await asyncio.sleep(backoff_time)
await auth_scheme.apply_auth(client, headers)
if last_response:
last_response.raise_for_status()
return last_response
msg = "retry_on_401 failed without making any requests"
if last_challenges:
challenge_info = ", ".join(
f"{scheme} (realm={params.get('realm', 'N/A')})"
for scheme, params in last_challenges.items()
)
msg = f"{msg}. Server challenges: {challenge_info}"
raise RuntimeError(msg)
def configure_auth_client(
auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient
) -> None:
"""Configure HTTP client with auth-specific settings.
Only HTTPDigestAuth and APIKeyAuth need client configuration.
Args:
auth: Authentication scheme that requires client configuration.
client: HTTP client to configure.
"""
auth.configure_client(client)

View File

@@ -0,0 +1,690 @@
"""A2A configuration types.
This module is separate from experimental.a2a to avoid circular imports.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, ClassVar, Literal, cast
import warnings
from pydantic import (
BaseModel,
ConfigDict,
Field,
FilePath,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import Self, deprecated
from crewai_a2a.auth.client_schemes import ClientAuthScheme
from crewai_a2a.auth.server_schemes import ServerAuthScheme
from crewai_a2a.extensions.base import ValidatedA2AExtension
from crewai_a2a.types import ProtocolVersion, TransportType, Url
try:
from a2a.types import (
AgentCapabilities,
AgentCardSignature,
AgentInterface,
AgentProvider,
AgentSkill,
SecurityScheme,
)
from crewai_a2a.extensions.server import ServerExtension
from crewai_a2a.updates import UpdateConfig
except ImportError:
UpdateConfig: Any = Any # type: ignore[no-redef]
AgentCapabilities: Any = Any # type: ignore[no-redef]
AgentCardSignature: Any = Any # type: ignore[no-redef]
AgentInterface: Any = Any # type: ignore[no-redef]
AgentProvider: Any = Any # type: ignore[no-redef]
SecurityScheme: Any = Any # type: ignore[no-redef]
AgentSkill: Any = Any # type: ignore[no-redef]
ServerExtension: Any = Any # type: ignore[no-redef]
def _get_default_update_config() -> UpdateConfig:
from crewai_a2a.updates import StreamingConfig
return StreamingConfig()
SigningAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
class AgentCardSigningConfig(BaseModel):
"""Configuration for AgentCard JWS signing.
Provides the private key and algorithm settings for signing AgentCards.
Either private_key_path or private_key_pem must be provided, but not both.
Attributes:
private_key_path: Path to a PEM-encoded private key file.
private_key_pem: PEM-encoded private key as a secret string.
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
private_key_path: FilePath | None = Field(
default=None,
description="Path to PEM-encoded private key file",
)
private_key_pem: SecretStr | None = Field(
default=None,
description="PEM-encoded private key",
)
key_id: str | None = Field(
default=None,
description="Key identifier for JWS header (kid claim)",
)
algorithm: SigningAlgorithm = Field(
default="RS256",
description="Signing algorithm (RS256, ES256, PS256, etc.)",
)
@model_validator(mode="after")
def _validate_key_source(self) -> Self:
"""Ensure exactly one key source is provided."""
has_path = self.private_key_path is not None
has_pem = self.private_key_pem is not None
if not has_path and not has_pem:
raise ValueError(
"Either private_key_path or private_key_pem must be provided"
)
if has_path and has_pem:
raise ValueError(
"Only one of private_key_path or private_key_pem should be provided"
)
return self
def get_private_key(self) -> str:
"""Get the private key content.
Returns:
The PEM-encoded private key as a string.
"""
if self.private_key_pem:
return self.private_key_pem.get_secret_value()
if self.private_key_path:
return Path(self.private_key_path).read_text()
raise ValueError("No private key configured")
class GRPCServerConfig(BaseModel):
"""gRPC server transport configuration.
Presence of this config in ServerTransportConfig.grpc enables gRPC transport.
Attributes:
host: Hostname to advertise in agent cards (default: localhost).
Use docker service name (e.g., 'web') for docker-compose setups.
port: Port for the gRPC server.
tls_cert_path: Path to TLS certificate file for gRPC.
tls_key_path: Path to TLS private key file for gRPC.
max_workers: Maximum number of workers for the gRPC thread pool.
reflection_enabled: Whether to enable gRPC reflection for debugging.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
host: str = Field(
default="localhost",
description="Hostname to advertise in agent cards for gRPC connections",
)
port: int = Field(
default=50051,
description="Port for the gRPC server",
)
tls_cert_path: str | None = Field(
default=None,
description="Path to TLS certificate file for gRPC",
)
tls_key_path: str | None = Field(
default=None,
description="Path to TLS private key file for gRPC",
)
max_workers: int = Field(
default=10,
description="Maximum number of workers for the gRPC thread pool",
)
reflection_enabled: bool = Field(
default=False,
description="Whether to enable gRPC reflection for debugging",
)
class GRPCClientConfig(BaseModel):
"""gRPC client transport configuration.
Attributes:
max_send_message_length: Maximum size for outgoing messages in bytes.
max_receive_message_length: Maximum size for incoming messages in bytes.
keepalive_time_ms: Time between keepalive pings in milliseconds.
keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_send_message_length: int | None = Field(
default=None,
description="Maximum size for outgoing messages in bytes",
)
max_receive_message_length: int | None = Field(
default=None,
description="Maximum size for incoming messages in bytes",
)
keepalive_time_ms: int | None = Field(
default=None,
description="Time between keepalive pings in milliseconds",
)
keepalive_timeout_ms: int | None = Field(
default=None,
description="Timeout for keepalive ping response in milliseconds",
)
class JSONRPCServerConfig(BaseModel):
"""JSON-RPC server transport configuration.
Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport.
Attributes:
rpc_path: URL path for the JSON-RPC endpoint.
agent_card_path: URL path for the agent card endpoint.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
rpc_path: str = Field(
default="/a2a",
description="URL path for the JSON-RPC endpoint",
)
agent_card_path: str = Field(
default="/.well-known/agent-card.json",
description="URL path for the agent card endpoint",
)
class JSONRPCClientConfig(BaseModel):
"""JSON-RPC client transport configuration.
Attributes:
max_request_size: Maximum request body size in bytes.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_request_size: int | None = Field(
default=None,
description="Maximum request body size in bytes",
)
class HTTPJSONConfig(BaseModel):
"""HTTP+JSON transport configuration.
Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
class ServerPushNotificationConfig(BaseModel):
"""Configuration for outgoing webhook push notifications.
Controls how the server signs and delivers push notifications to clients.
Attributes:
signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
signature_secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC-SHA256 signing of outgoing push notifications",
)
class ServerTransportConfig(BaseModel):
"""Transport configuration for A2A server.
Groups all transport-related settings including preferred transport
and protocol-specific configurations.
Attributes:
preferred: Transport protocol for the preferred endpoint.
jsonrpc: JSON-RPC server transport configuration.
grpc: gRPC server transport configuration.
http_json: HTTP+JSON transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType = Field(
default="JSONRPC",
description="Transport protocol for the preferred endpoint",
)
jsonrpc: JSONRPCServerConfig = Field(
default_factory=JSONRPCServerConfig,
description="JSON-RPC server transport configuration",
)
grpc: GRPCServerConfig | None = Field(
default=None,
description="gRPC server transport configuration",
)
http_json: HTTPJSONConfig | None = Field(
default=None,
description="HTTP+JSON transport configuration",
)
def _migrate_client_transport_fields(
transport: ClientTransportConfig,
transport_protocol: TransportType | None,
supported_transports: list[TransportType] | None,
) -> None:
"""Migrate deprecated transport fields to new config."""
if transport_protocol is not None:
warnings.warn(
"transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "preferred", transport_protocol)
if supported_transports is not None:
warnings.warn(
"supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "supported", supported_transports)
class ClientTransportConfig(BaseModel):
"""Transport configuration for A2A client.
Groups all client transport-related settings including preferred transport,
supported transports for negotiation, and protocol-specific configurations.
Transport negotiation logic:
1. If `preferred` is set and server supports it → use client's preferred
2. Otherwise, if server's preferred is in client's `supported` → use server's preferred
3. Otherwise, find first match from client's `supported` in server's interfaces
Attributes:
preferred: Client's preferred transport. If set, client preference takes priority.
supported: Transports the client can use, in order of preference.
jsonrpc: JSON-RPC client transport configuration.
grpc: gRPC client transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType | None = Field(
default=None,
description="Client's preferred transport. If set, takes priority over server preference.",
)
supported: list[TransportType] = Field(
default_factory=lambda: cast(list[TransportType], ["JSONRPC"]),
description="Transports the client can use, in order of preference",
)
jsonrpc: JSONRPCClientConfig = Field(
default_factory=JSONRPCClientConfig,
description="JSON-RPC client transport configuration",
)
grpc: GRPCClientConfig = Field(
default_factory=GRPCClientConfig,
description="gRPC client transport configuration",
)
@deprecated(
"""
`crewai_a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
use `crewai_a2a.config.A2AClientConfig` or `crewai_a2a.config.A2AServerConfig` instead.
""",
category=FutureWarning,
)
class A2AConfig(BaseModel):
"""Configuration for A2A protocol integration.
Deprecated:
Use A2AClientConfig instead. This class will be removed in a future version.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=_get_default_update_config,
description="Update mechanism config",
)
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
use_client_preference: bool | None = Field(
default=None,
description="Deprecated: Set transport.preferred to enable client preference",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
if self.use_client_preference is not None:
warnings.warn(
"use_client_preference is deprecated, set transport.preferred to enable client preference",
FutureWarning,
stacklevel=4,
)
if self.use_client_preference and self.transport.supported:
object.__setattr__(
self.transport, "preferred", self.transport.supported[0]
)
return self
class A2AClientConfig(BaseModel):
"""Configuration for connecting to remote A2A agents.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
accepted_output_modes: Media types the client can accept in responses.
extensions: Extension URIs the client supports (A2A protocol extensions).
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=_get_default_update_config,
description="Update mechanism config",
)
accepted_output_modes: list[str] = Field(
default_factory=lambda: ["application/json"],
description="Media types the client can accept in responses",
)
extensions: list[str] = Field(
default_factory=list,
description="Extension URIs the client supports",
)
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
return self
class A2AServerConfig(BaseModel):
"""Configuration for exposing a Crew or Agent as an A2A server.
All fields correspond to A2A AgentCard fields. Fields like name, description,
and skills can be auto-derived from the Crew/Agent if not provided.
Attributes:
name: Human-readable name for the agent.
description: Human-readable description of the agent.
version: Version string for the agent card.
skills: List of agent skills/capabilities.
default_input_modes: Default supported input MIME types.
default_output_modes: Default supported output MIME types.
capabilities: Declaration of optional capabilities.
protocol_version: A2A protocol version this agent supports.
provider: Information about the agent's service provider.
documentation_url: URL to the agent's documentation.
icon_url: URL to an icon for the agent.
additional_interfaces: Additional supported interfaces.
security: Security requirement objects for all interactions.
security_schemes: Security schemes available to authorize requests.
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
url: Preferred endpoint URL for the agent.
signing_config: Configuration for signing the AgentCard with JWS.
signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
push_notifications: Configuration for outgoing push notifications.
transport: Transport configuration (preferred transport, gRPC, REST settings).
auth: Authentication scheme for A2A endpoints.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
name: str | None = Field(
default=None,
description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.",
)
description: str | None = Field(
default=None,
description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.",
)
version: str = Field(
default="1.0.0",
description="Version string for the agent card",
)
skills: list[AgentSkill] = Field(
default_factory=list,
description="List of agent skills. Auto-derived from tasks/tools if not provided.",
)
default_input_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported input MIME types",
)
default_output_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported output MIME types",
)
capabilities: AgentCapabilities = Field(
default_factory=lambda: AgentCapabilities(
streaming=True,
push_notifications=False,
),
description="Declaration of optional capabilities supported by the agent",
)
protocol_version: ProtocolVersion = Field(
default="0.3.0",
description="A2A protocol version this agent supports",
)
provider: AgentProvider | None = Field(
default=None,
description="Information about the agent's service provider",
)
documentation_url: Url | None = Field(
default=None,
description="URL to the agent's documentation",
)
icon_url: Url | None = Field(
default=None,
description="URL to an icon for the agent",
)
additional_interfaces: list[AgentInterface] = Field(
default_factory=list,
description="Additional supported interfaces.",
)
security: list[dict[str, list[str]]] = Field(
default_factory=list,
description="Security requirement objects for all agent interactions",
)
security_schemes: dict[str, SecurityScheme] = Field(
default_factory=dict,
description="Security schemes available to authorize requests",
)
supports_authenticated_extended_card: bool = Field(
default=False,
description="Whether agent provides extended card to authenticated users",
)
url: Url | None = Field(
default=None,
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
)
signing_config: AgentCardSigningConfig | None = Field(
default=None,
description="Configuration for signing the AgentCard with JWS",
)
signatures: list[AgentCardSignature] | None = Field(
default=None,
description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
exclude=True,
deprecated=True,
)
server_extensions: list[ServerExtension] = Field(
default_factory=list,
description="Server-side A2A protocol extensions that modify agent behavior",
)
push_notifications: ServerPushNotificationConfig | None = Field(
default=None,
description="Configuration for outgoing push notifications",
)
transport: ServerTransportConfig = Field(
default_factory=ServerTransportConfig,
description="Transport configuration (preferred transport, gRPC, REST settings)",
)
preferred_transport: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
deprecated=True,
)
auth: ServerAuthScheme | None = Field(
default=None,
description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
)
@model_validator(mode="after")
def _migrate_deprecated_fields(self) -> Self:
"""Migrate deprecated fields to new config."""
if self.preferred_transport is not None:
warnings.warn(
"preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=4,
)
object.__setattr__(self.transport, "preferred", self.preferred_transport)
if self.signatures is not None:
warnings.warn(
"signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
"The signatures field will be removed in v2.0.0.",
FutureWarning,
stacklevel=4,
)
return self

View File

@@ -0,0 +1,491 @@
"""A2A error codes and error response utilities.
This module provides a centralized mapping of all A2A protocol error codes
as defined in the A2A specification, plus custom CrewAI extensions.
Error codes follow JSON-RPC 2.0 conventions:
- -32700 to -32600: Standard JSON-RPC errors
- -32099 to -32000: Server errors (A2A-specific)
- -32768 to -32100: Reserved for implementation-defined errors
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any
from a2a.client.errors import A2AClientTimeoutError
class A2APollingTimeoutError(A2AClientTimeoutError):
"""Raised when polling exceeds the configured timeout."""
class A2AErrorCode(IntEnum):
"""A2A protocol error codes.
Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
"""
# JSON-RPC 2.0 Standard Errors (-32700 to -32600)
JSON_PARSE_ERROR = -32700
"""Invalid JSON was received by the server."""
INVALID_REQUEST = -32600
"""The JSON sent is not a valid Request object."""
METHOD_NOT_FOUND = -32601
"""The method does not exist / is not available."""
INVALID_PARAMS = -32602
"""Invalid method parameter(s)."""
INTERNAL_ERROR = -32603
"""Internal JSON-RPC error."""
# A2A-Specific Errors (-32099 to -32000)
TASK_NOT_FOUND = -32001
"""The specified task was not found."""
TASK_NOT_CANCELABLE = -32002
"""The task cannot be canceled (already completed/failed)."""
PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
"""Push notifications are not supported by this agent."""
UNSUPPORTED_OPERATION = -32004
"""The requested operation is not supported."""
CONTENT_TYPE_NOT_SUPPORTED = -32005
"""Incompatible content types between client and server."""
INVALID_AGENT_RESPONSE = -32006
"""The agent produced an invalid response."""
# CrewAI Custom Extensions (-32768 to -32100)
UNSUPPORTED_VERSION = -32009
"""The requested A2A protocol version is not supported."""
UNSUPPORTED_EXTENSION = -32010
"""Client does not support required protocol extensions."""
AUTHENTICATION_REQUIRED = -32011
"""Authentication is required for this operation."""
AUTHORIZATION_FAILED = -32012
"""Authorization check failed (insufficient permissions)."""
RATE_LIMIT_EXCEEDED = -32013
"""Rate limit exceeded for this client/operation."""
TASK_TIMEOUT = -32014
"""Task execution timed out."""
TRANSPORT_NEGOTIATION_FAILED = -32015
"""Failed to negotiate a compatible transport protocol."""
CONTEXT_NOT_FOUND = -32016
"""The specified context was not found."""
SKILL_NOT_FOUND = -32017
"""The specified skill was not found."""
ARTIFACT_NOT_FOUND = -32018
"""The specified artifact was not found."""
# Error code to default message mapping
ERROR_MESSAGES: dict[int, str] = {
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
A2AErrorCode.INVALID_PARAMS: "Invalid params",
A2AErrorCode.INTERNAL_ERROR: "Internal error",
A2AErrorCode.TASK_NOT_FOUND: "Task not found",
A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
}
@dataclass
class A2AError(Exception):
"""Base exception for A2A protocol errors.
Attributes:
code: The A2A/JSON-RPC error code.
message: Human-readable error message.
data: Optional additional error data.
"""
code: int
message: str | None = None
data: Any = None
def __post_init__(self) -> None:
if self.message is None:
self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
super().__init__(self.message)
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-RPC error object format."""
error: dict[str, Any] = {
"code": self.code,
"message": self.message,
}
if self.data is not None:
error["data"] = self.data
return error
def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
"""Convert to full JSON-RPC error response."""
return {
"jsonrpc": "2.0",
"error": self.to_dict(),
"id": request_id,
}
@dataclass
class JSONParseError(A2AError):
"""Invalid JSON was received."""
code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)
@dataclass
class InvalidRequestError(A2AError):
"""The JSON sent is not a valid Request object."""
code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)
@dataclass
class MethodNotFoundError(A2AError):
"""The method does not exist / is not available."""
code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
method: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.method:
self.message = f"Method not found: {self.method}"
super().__post_init__()
@dataclass
class InvalidParamsError(A2AError):
"""Invalid method parameter(s)."""
code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
param: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.param and self.reason:
self.message = f"Invalid parameter '{self.param}': {self.reason}"
elif self.param:
self.message = f"Invalid parameter: {self.param}"
super().__post_init__()
@dataclass
class InternalError(A2AError):
"""Internal JSON-RPC error."""
code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False)
@dataclass
class TaskNotFoundError(A2AError):
"""The specified task was not found."""
code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
task_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.task_id:
self.message = f"Task not found: {self.task_id}"
super().__post_init__()
@dataclass
class TaskNotCancelableError(A2AError):
"""The task cannot be canceled."""
code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
task_id: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.reason:
self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
elif self.task_id:
self.message = f"Task {self.task_id} cannot be canceled"
super().__post_init__()
@dataclass
class PushNotificationNotSupportedError(A2AError):
"""Push notifications are not supported."""
code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)
@dataclass
class UnsupportedOperationError(A2AError):
"""The requested operation is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
operation: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.operation:
self.message = f"Operation not supported: {self.operation}"
super().__post_init__()
@dataclass
class ContentTypeNotSupportedError(A2AError):
"""Incompatible content types."""
code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
requested_types: list[str] | None = None
supported_types: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_types and self.supported_types:
self.message = (
f"Content type not supported. Requested: {self.requested_types}, "
f"Supported: {self.supported_types}"
)
super().__post_init__()
@dataclass
class InvalidAgentResponseError(A2AError):
"""The agent produced an invalid response."""
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
@dataclass
class UnsupportedVersionError(A2AError):
"""The requested A2A version is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
requested_version: str | None = None
supported_versions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_version:
msg = f"Unsupported A2A version: {self.requested_version}"
if self.supported_versions:
msg += f". Supported versions: {', '.join(self.supported_versions)}"
self.message = msg
super().__post_init__()
@dataclass
class UnsupportedExtensionError(A2AError):
"""Client does not support required extensions."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
required_extensions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_extensions:
self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
super().__post_init__()
@dataclass
class AuthenticationRequiredError(A2AError):
"""Authentication is required."""
code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)
@dataclass
class AuthorizationFailedError(A2AError):
"""Authorization check failed."""
code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
required_scope: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_scope:
self.message = (
f"Authorization failed. Required scope: {self.required_scope}"
)
super().__post_init__()
@dataclass
class RateLimitExceededError(A2AError):
"""Rate limit exceeded."""
code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False)
retry_after: int | None = None
def __post_init__(self) -> None:
if self.message is None and self.retry_after:
self.message = (
f"Rate limit exceeded. Retry after {self.retry_after} seconds"
)
if self.retry_after:
self.data = {"retry_after": self.retry_after}
super().__post_init__()
@dataclass
class TaskTimeoutError(A2AError):
"""Task execution timed out."""
code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False)
task_id: str | None = None
timeout_seconds: float | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.timeout_seconds:
self.message = (
f"Task {self.task_id} timed out after {self.timeout_seconds}s"
)
elif self.task_id:
self.message = f"Task {self.task_id} timed out"
super().__post_init__()
@dataclass
class TransportNegotiationFailedError(A2AError):
"""Failed to negotiate a compatible transport protocol."""
code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False)
client_transports: list[str] | None = None
server_transports: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.client_transports and self.server_transports:
self.message = (
f"Transport negotiation failed. Client: {self.client_transports}, "
f"Server: {self.server_transports}"
)
super().__post_init__()
@dataclass
class ContextNotFoundError(A2AError):
"""The specified context was not found."""
code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
context_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.context_id:
self.message = f"Context not found: {self.context_id}"
super().__post_init__()
@dataclass
class SkillNotFoundError(A2AError):
"""The specified skill was not found."""
code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
skill_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.skill_id:
self.message = f"Skill not found: {self.skill_id}"
super().__post_init__()
@dataclass
class ArtifactNotFoundError(A2AError):
"""The specified artifact was not found."""
code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
artifact_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.artifact_id:
self.message = f"Artifact not found: {self.artifact_id}"
super().__post_init__()
def create_error_response(
code: int | A2AErrorCode,
message: str | None = None,
data: Any = None,
request_id: str | int | None = None,
) -> dict[str, Any]:
"""Create a JSON-RPC error response.
Args:
code: Error code (A2AErrorCode or int).
message: Optional error message (uses default if not provided).
data: Optional additional error data.
request_id: Request ID for correlation.
Returns:
Dict in JSON-RPC error response format.
"""
error = A2AError(code=int(code), message=message, data=data)
return error.to_response(request_id)
def is_retryable_error(code: int) -> bool:
"""Check if an error is potentially retryable.
Args:
code: Error code to check.
Returns:
True if the error might be resolved by retrying.
"""
retryable_codes = {
A2AErrorCode.INTERNAL_ERROR,
A2AErrorCode.RATE_LIMIT_EXCEEDED,
A2AErrorCode.TASK_TIMEOUT,
}
return code in retryable_codes
def is_client_error(code: int) -> bool:
"""Check if an error is a client-side error.
Args:
code: Error code to check.
Returns:
True if the error is due to client request issues.
"""
client_error_codes = {
A2AErrorCode.JSON_PARSE_ERROR,
A2AErrorCode.INVALID_REQUEST,
A2AErrorCode.METHOD_NOT_FOUND,
A2AErrorCode.INVALID_PARAMS,
A2AErrorCode.TASK_NOT_FOUND,
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
A2AErrorCode.UNSUPPORTED_VERSION,
A2AErrorCode.UNSUPPORTED_EXTENSION,
A2AErrorCode.CONTEXT_NOT_FOUND,
A2AErrorCode.SKILL_NOT_FOUND,
A2AErrorCode.ARTIFACT_NOT_FOUND,
}
return code in client_error_codes

View File

@@ -0,0 +1,37 @@
"""A2A Protocol Extensions for CrewAI.
This module contains extensions to the A2A (Agent-to-Agent) protocol.
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
processes requests and responses during delegation to remote agents. These provide
hooks for tool injection, prompt augmentation, and response processing.
**Server-side extensions** (ServerExtension) allow agents to offer additional
functionality beyond the core A2A specification. Clients activate extensions
via the X-A2A-Extensions header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from crewai_a2a.extensions.base import (
A2AExtension,
ConversationState,
ExtensionRegistry,
ValidatedA2AExtension,
)
from crewai_a2a.extensions.server import (
ExtensionContext,
ServerExtension,
ServerExtensionRegistry,
)
__all__ = [
"A2AExtension",
"ConversationState",
"ExtensionContext",
"ExtensionRegistry",
"ServerExtension",
"ServerExtensionRegistry",
"ValidatedA2AExtension",
]

View File

@@ -0,0 +1,237 @@
"""Base extension interface for CrewAI A2A wrapper processing hooks.
This module defines the protocol for extending CrewAI's A2A wrapper functionality
with custom logic for tool injection, prompt augmentation, and response processing.
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
A2A protocol extensions are capability declarations using AgentExtension objects
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
from pydantic import BeforeValidator
if TYPE_CHECKING:
from a2a.types import Message
from crewai.agent.core import Agent
def _validate_a2a_extension(v: Any) -> Any:
"""Validate that value implements A2AExtension protocol."""
if not isinstance(v, A2AExtension):
raise ValueError(
f"Value must implement A2AExtension protocol. "
f"Got {type(v).__name__} which is missing required methods."
)
return v
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
@runtime_checkable
class ConversationState(Protocol):
"""Protocol for extension-specific conversation state.
Extensions can define their own state classes that implement this protocol
to track conversation-specific data extracted from message history.
"""
def is_ready(self) -> bool:
"""Check if the state indicates readiness for some action.
Returns:
True if the state is ready, False otherwise.
"""
...
@runtime_checkable
class A2AExtension(Protocol):
"""Protocol for A2A wrapper extensions.
Extensions can implement this protocol to inject custom logic into
the A2A conversation flow at various integration points.
Example:
class MyExtension:
def inject_tools(self, agent: Agent) -> None:
# Add custom tools to the agent
pass
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
# Extract state from conversation
return None
def augment_prompt(
self, base_prompt: str, conversation_state: ConversationState | None
) -> str:
# Add custom instructions
return base_prompt
def process_response(
self, agent_response: Any, conversation_state: ConversationState | None
) -> Any:
# Modify response if needed
return agent_response
"""
def inject_tools(self, agent: Agent) -> None:
"""Inject extension-specific tools into the agent.
Called when an agent is wrapped with A2A capabilities. Extensions
can add tools that enable extension-specific functionality.
Args:
agent: The agent instance to inject tools into.
"""
...
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
"""Extract extension-specific state from conversation history.
Called during prompt augmentation to allow extensions to analyze
the conversation history and extract relevant state information.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Extension-specific conversation state, or None if no relevant state.
"""
...
def augment_prompt(
self,
base_prompt: str,
conversation_state: ConversationState | None,
) -> str:
"""Augment the task prompt with extension-specific instructions.
Called during prompt augmentation to allow extensions to add
custom instructions based on conversation state.
Args:
base_prompt: The base prompt to augment.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The augmented prompt with extension-specific instructions.
"""
...
def process_response(
self,
agent_response: Any,
conversation_state: ConversationState | None,
) -> Any:
"""Process and potentially modify the agent response.
Called after parsing the agent's response, allowing extensions to
enhance or modify the response based on conversation state.
Args:
agent_response: The parsed agent response.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The processed agent response (may be modified or original).
"""
...
class ExtensionRegistry:
"""Registry for managing A2A extensions.
Maintains a collection of extensions and provides methods to invoke
their hooks at various integration points.
"""
def __init__(self) -> None:
"""Initialize the extension registry."""
self._extensions: list[A2AExtension] = []
def register(self, extension: A2AExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
"""
self._extensions.append(extension)
def inject_all_tools(self, agent: Agent) -> None:
"""Inject tools from all registered extensions.
Args:
agent: The agent instance to inject tools into.
"""
for extension in self._extensions:
extension.inject_tools(agent)
def extract_all_states(
self, conversation_history: Sequence[Message]
) -> dict[type[A2AExtension], ConversationState]:
"""Extract conversation states from all registered extensions.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Mapping of extension types to their conversation states.
"""
states: dict[type[A2AExtension], ConversationState] = {}
for extension in self._extensions:
state = extension.extract_state_from_history(conversation_history)
if state is not None:
states[type(extension)] = state
return states
def augment_prompt_with_all(
self,
base_prompt: str,
extension_states: dict[type[A2AExtension], ConversationState],
) -> str:
"""Augment prompt with instructions from all registered extensions.
Args:
base_prompt: The base prompt to augment.
extension_states: Mapping of extension types to conversation states.
Returns:
The fully augmented prompt.
"""
augmented = base_prompt
for extension in self._extensions:
state = extension_states.get(type(extension))
augmented = extension.augment_prompt(augmented, state)
return augmented
def process_response_with_all(
self,
agent_response: Any,
extension_states: dict[type[A2AExtension], ConversationState],
) -> Any:
"""Process response through all registered extensions.
Args:
agent_response: The parsed agent response.
extension_states: Mapping of extension types to conversation states.
Returns:
The processed agent response.
"""
processed = agent_response
for extension in self._extensions:
state = extension_states.get(type(extension))
processed = extension.process_response(processed, state)
return processed

View File

@@ -0,0 +1,170 @@
"""A2A Protocol extension utilities.
This module provides utilities for working with A2A protocol extensions as
defined in the A2A specification. Extensions are capability declarations in
AgentCard.capabilities.extensions using AgentExtension objects, activated
via the X-A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from typing import Any
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
)
from a2a.types import AgentCard, AgentExtension
from crewai_a2a.config import A2AClientConfig, A2AConfig
from crewai_a2a.extensions.base import ExtensionRegistry
def get_extensions_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> list[str]:
"""Extract extension URIs from A2A configuration.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Deduplicated list of extension URIs from all configs.
"""
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[str] = set()
result: list[str] = []
for config in configs:
if not isinstance(config, A2AClientConfig):
continue
for uri in config.extensions:
if uri not in seen:
seen.add(uri)
result.append(uri)
return result
class ExtensionsMiddleware(ClientCallInterceptor):
"""Middleware to add X-A2A-Extensions header to requests.
This middleware adds the extensions header to all outgoing requests,
declaring which A2A protocol extensions the client supports.
"""
def __init__(self, extensions: list[str]) -> None:
"""Initialize with extension URIs.
Args:
extensions: List of extension URIs the client supports.
"""
self._extensions = extensions
async def intercept(
self,
method_name: str,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any],
agent_card: AgentCard | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Add extensions header to the request.
Args:
method_name: The A2A method being called.
request_payload: The JSON-RPC request payload.
http_kwargs: HTTP request kwargs (headers, etc).
agent_card: The target agent's card.
context: Optional call context.
Returns:
Tuple of (request_payload, modified_http_kwargs).
"""
if self._extensions:
headers = http_kwargs.setdefault("headers", {})
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
return request_payload, http_kwargs
def validate_required_extensions(
agent_card: AgentCard,
client_extensions: list[str] | None,
) -> list[AgentExtension]:
"""Validate that client supports all required extensions from agent.
Args:
agent_card: The agent's card with declared extensions.
client_extensions: Extension URIs the client supports.
Returns:
List of unsupported required extensions.
Raises:
None - returns list of unsupported extensions for caller to handle.
"""
unsupported: list[AgentExtension] = []
client_set = set(client_extensions or [])
if not agent_card.capabilities or not agent_card.capabilities.extensions:
return unsupported
unsupported.extend(
ext
for ext in agent_card.capabilities.extensions
if ext.required and ext.uri not in client_set
)
return unsupported
def create_extension_registry_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> ExtensionRegistry:
"""Create an extension registry from A2A client configuration.
Extracts client_extensions from each A2AClientConfig and registers them
with the ExtensionRegistry. These extensions provide CrewAI-specific
processing hooks (tool injection, prompt augmentation, response processing).
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Extension registry with all client_extensions registered.
Example:
class LoggingExtension:
def inject_tools(self, agent): pass
def extract_state_from_history(self, history): return None
def augment_prompt(self, prompt, state): return prompt
def process_response(self, response, state):
print(f"Response: {response}")
return response
config = A2AClientConfig(
endpoint="https://agent.example.com",
client_extensions=[LoggingExtension()],
)
registry = create_extension_registry_from_config(config)
"""
registry = ExtensionRegistry()
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[int] = set()
for config in configs:
if isinstance(config, (A2AConfig, A2AClientConfig)):
client_exts = getattr(config, "client_extensions", [])
for extension in client_exts:
ext_id = id(extension)
if ext_id not in seen:
seen.add(ext_id)
registry.register(extension)
return registry

View File

@@ -0,0 +1,305 @@
"""A2A protocol server extensions for CrewAI agents.
This module provides the base class and context for implementing A2A protocol
extensions on the server side. Extensions allow agents to offer additional
functionality beyond the core A2A specification.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
from typing import TYPE_CHECKING, Annotated, Any
from a2a.types import AgentExtension
from pydantic_core import CoreSchema, core_schema
if TYPE_CHECKING:
from a2a.server.context import ServerCallContext
from pydantic import GetCoreSchemaHandler
logger = logging.getLogger(__name__)
@dataclass
class ExtensionContext:
"""Context passed to extension hooks during request processing.
Provides access to request metadata, client extensions, and shared state
that extensions can read from and write to.
Attributes:
metadata: Request metadata dict, includes extension-namespaced keys.
client_extensions: Set of extension URIs the client declared support for.
state: Mutable dict for extensions to share data during request lifecycle.
server_context: The underlying A2A server call context.
"""
metadata: dict[str, Any]
client_extensions: set[str]
state: dict[str, Any] = field(default_factory=dict)
server_context: ServerCallContext | None = None
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
"""Get extension-specific metadata value.
Extension metadata uses namespaced keys in the format:
"{extension_uri}/{key}"
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
Returns:
The metadata value, or None if not present.
"""
full_key = f"{uri}/{key}"
return self.metadata.get(full_key)
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
"""Set extension-specific metadata value.
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
value: The value to set.
"""
full_key = f"{uri}/{key}"
self.metadata[full_key] = value
class ServerExtension(ABC):
"""Base class for A2A protocol server extensions.
Subclass this to create custom extensions that modify agent behavior
when clients activate them. Extensions are identified by URI and can
be marked as required.
Example:
class SamplingExtension(ServerExtension):
uri = "urn:crewai:ext:sampling/v1"
required = True
def __init__(self, max_tokens: int = 4096):
self.max_tokens = max_tokens
@property
def params(self) -> dict[str, Any]:
return {"max_tokens": self.max_tokens}
async def on_request(self, context: ExtensionContext) -> None:
limit = context.get_extension_metadata(self.uri, "limit")
if limit:
context.state["token_limit"] = int(limit)
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
return result
"""
uri: Annotated[str, "Extension URI identifier. Must be unique."]
required: Annotated[bool, "Whether clients must support this extension."] = False
description: Annotated[
str | None, "Human-readable description of the extension."
] = None
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
"""Tell Pydantic how to validate ServerExtension instances."""
return core_schema.is_instance_schema(cls)
@property
def params(self) -> dict[str, Any] | None:
"""Extension parameters to advertise in AgentCard.
Override this property to expose configuration that clients can read.
Returns:
Dict of parameter names to values, or None.
"""
return None
def agent_extension(self) -> AgentExtension:
"""Generate the AgentExtension object for the AgentCard.
Returns:
AgentExtension with this extension's URI, required flag, and params.
"""
return AgentExtension(
uri=self.uri,
required=self.required if self.required else None,
description=self.description,
params=self.params,
)
def is_active(self, context: ExtensionContext) -> bool:
"""Check if this extension is active for the current request.
An extension is active if the client declared support for it.
Args:
context: The extension context for the current request.
Returns:
True if the client supports this extension.
"""
return self.uri in context.client_extensions
@abstractmethod
async def on_request(self, context: ExtensionContext) -> None:
"""Called before agent execution if extension is active.
Use this hook to:
- Read extension-specific metadata from the request
- Set up state for the execution
- Modify execution parameters via context.state
Args:
context: The extension context with request metadata and state.
"""
...
@abstractmethod
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Called after agent execution if extension is active.
Use this hook to:
- Modify or enhance the result
- Add extension-specific metadata to the response
- Clean up any resources
Args:
context: The extension context with request metadata and state.
result: The agent execution result.
Returns:
The result, potentially modified.
"""
...
class ServerExtensionRegistry:
"""Registry for managing server-side A2A protocol extensions.
Collects extensions and provides methods to generate AgentCapabilities
and invoke extension hooks during request processing.
"""
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
"""Initialize the registry with optional extensions.
Args:
extensions: Initial list of extensions to register.
"""
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
self._by_uri: dict[str, ServerExtension] = {
ext.uri: ext for ext in self._extensions
}
def register(self, extension: ServerExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
Raises:
ValueError: If an extension with the same URI is already registered.
"""
if extension.uri in self._by_uri:
raise ValueError(f"Extension already registered: {extension.uri}")
self._extensions.append(extension)
self._by_uri[extension.uri] = extension
def get_agent_extensions(self) -> list[AgentExtension]:
"""Get AgentExtension objects for all registered extensions.
Returns:
List of AgentExtension objects for the AgentCard.
"""
return [ext.agent_extension() for ext in self._extensions]
def get_extension(self, uri: str) -> ServerExtension | None:
"""Get an extension by URI.
Args:
uri: The extension URI.
Returns:
The extension, or None if not found.
"""
return self._by_uri.get(uri)
@staticmethod
def create_context(
metadata: dict[str, Any],
client_extensions: set[str],
server_context: ServerCallContext | None = None,
) -> ExtensionContext:
"""Create an ExtensionContext for a request.
Args:
metadata: Request metadata dict.
client_extensions: Set of extension URIs from client.
server_context: Optional server call context.
Returns:
ExtensionContext for use in hooks.
"""
return ExtensionContext(
metadata=metadata,
client_extensions=client_extensions,
server_context=server_context,
)
async def invoke_on_request(self, context: ExtensionContext) -> None:
"""Invoke on_request hooks for all active extensions.
Tracks activated extensions and isolates errors from individual hooks.
Args:
context: The extension context for the request.
"""
for extension in self._extensions:
if extension.is_active(context):
try:
await extension.on_request(context)
if context.server_context is not None:
context.server_context.activated_extensions.add(extension.uri)
except Exception:
logger.exception(
"Extension on_request hook failed",
extra={"extension": extension.uri},
)
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Invoke on_response hooks for all active extensions.
Isolates errors from individual hooks to prevent one failing extension
from breaking the entire response.
Args:
context: The extension context for the request.
result: The agent execution result.
Returns:
The result after all extensions have processed it.
"""
processed = result
for extension in self._extensions:
if extension.is_active(context):
try:
processed = await extension.on_response(context, processed)
except Exception:
logger.exception(
"Extension on_response hook failed",
extra={"extension": extension.uri},
)
return processed

View File

View File

@@ -0,0 +1,479 @@
"""Helper functions for processing A2A task results."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypedDict
import uuid
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
)
from typing_extensions import NotRequired
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
SendMessageEvent = (
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
)
TERMINAL_STATES: frozenset[TaskState] = frozenset(
{
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
}
)
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
{
TaskState.input_required,
TaskState.auth_required,
}
)
PENDING_STATES: frozenset[TaskState] = frozenset(
{
TaskState.submitted,
TaskState.working,
}
)
class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state."""
status: TaskState
history: list[Message]
result: NotRequired[str]
error: NotRequired[str]
agent_card: NotRequired[dict[str, Any]]
a2a_agent_name: NotRequired[str | None]
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
"""Extract result parts from A2A task status message, history, and artifacts.
Args:
a2a_task: A2A Task object with status, history, and artifacts
Returns:
List of result text parts
"""
result_parts: list[str] = []
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
result_parts.extend(
part.root.text for part in msg.parts if part.root.kind == "text"
)
if not result_parts and a2a_task.history:
for history_msg in reversed(a2a_task.history):
if history_msg.role == Role.agent:
result_parts.extend(
part.root.text
for part in history_msg.parts
if part.root.kind == "text"
)
break
if a2a_task.artifacts:
result_parts.extend(
part.root.text
for artifact in a2a_task.artifacts
for part in artifact.parts
if part.root.kind == "text"
)
return result_parts
def extract_error_message(a2a_task: A2ATask, default: str) -> str:
"""Extract error message from A2A task.
Args:
a2a_task: A2A Task object
default: Default message if no error found
Returns:
Error message string
"""
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
if msg:
for part in msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return str(msg)
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
for part in history_msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return default
def process_task_state(
a2a_task: A2ATask,
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
result_parts: list[str] | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
is_final: bool = True,
) -> TaskStateResult | None:
"""Process A2A task state and return result dictionary.
Shared logic for both polling and streaming handlers.
Args:
a2a_task: The A2A task to process.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
turn_number: Current turn number.
is_multiturn: Whether multi-turn conversation.
agent_role: Agent role for logging.
result_parts: Accumulated result parts (streaming passes accumulated,
polling passes None to extract from task).
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
from_task: Optional CrewAI Task for event metadata.
from_agent: Optional CrewAI Agent for event metadata.
is_final: Whether this is the final response in the stream.
Returns:
Result dictionary if terminal/actionable state, None otherwise.
"""
if result_parts is None:
result_parts = []
if a2a_task.status.state == TaskState.completed:
if not result_parts:
extracted_parts = extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = " ".join(result_parts) if result_parts else ""
message_id = None
if a2a_task.status and a2a_task.status.message:
message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=message_id,
is_multiturn=is_multiturn,
status="completed",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.completed,
agent_card=agent_card.model_dump(exclude_none=True),
result=response_text,
history=new_messages,
)
if a2a_task.status.state == TaskState.input_required:
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = extract_error_message(a2a_task, "Additional input required")
if response_text and not a2a_task.history:
agent_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))],
context_id=a2a_task.context_id,
task_id=a2a_task.id,
)
new_messages.append(agent_message)
input_message_id = None
if a2a_task.status and a2a_task.status.message:
input_message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=input_message_id,
is_multiturn=is_multiturn,
status="input_required",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.input_required,
error=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
error_msg = extract_error_message(a2a_task, "Task failed without error message")
if a2a_task.history:
new_messages.extend(a2a_task.history)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.auth_required:
error_msg = extract_error_message(a2a_task, "Authentication required")
return TaskStateResult(
status=TaskState.auth_required,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.canceled:
error_msg = extract_error_message(a2a_task, "Task was canceled")
return TaskStateResult(
status=TaskState.canceled,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state in PENDING_STATES:
return None
return None
async def send_message_and_get_task_id(
event_stream: AsyncIterator[SendMessageEvent],
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
context_id: str | None = None,
) -> str | TaskStateResult:
"""Send message and process initial response.
Handles the common pattern of sending a message and either:
- Getting an immediate Message response (task completed synchronously)
- Getting a Task that needs polling/waiting for completion
Args:
event_stream: Async iterator from client.send_message()
new_messages: List to collect messages (modified in place)
agent_card: The agent card
turn_number: Current turn number
is_multiturn: Whether multi-turn conversation
agent_role: Agent role for logging
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
endpoint: Optional A2A endpoint URL.
a2a_agent_name: Optional A2A agent name.
context_id: Optional A2A context ID for correlation.
Returns:
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
"""
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
]
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=event.context_id,
message_id=event.message_id,
is_multiturn=is_multiturn,
status="completed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)
if isinstance(event, tuple):
a2a_task, _ = event
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
return a2a_task.id
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during send_message: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:
await aclose()

View File

@@ -0,0 +1,55 @@
"""String templates for A2A (Agent-to-Agent) protocol messaging and status."""
from string import Template
from typing import Final
AVAILABLE_AGENTS_TEMPLATE: Final[Template] = Template(
"\n<AVAILABLE_A2A_AGENTS>\n $available_a2a_agents\n</AVAILABLE_A2A_AGENTS>\n"
)
PREVIOUS_A2A_CONVERSATION_TEMPLATE: Final[Template] = Template(
"\n<PREVIOUS_A2A_CONVERSATION>\n"
" $previous_a2a_conversation"
"\n</PREVIOUS_A2A_CONVERSATION>\n"
)
CONVERSATION_TURN_INFO_TEMPLATE: Final[Template] = Template(
"\n<CONVERSATION_PROGRESS>\n"
' turn="$turn_count"\n'
' max_turns="$max_turns"\n'
" $warning"
"\n</CONVERSATION_PROGRESS>\n"
)
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE: Final[Template] = Template(
"\n<A2A_AGENTS_STATUS>\n"
" NOTE: A2A agents were configured but are currently unavailable.\n"
" You cannot delegate to remote agents for this task.\n\n"
" Unavailable Agents:\n"
" $unavailable_agents"
"\n</A2A_AGENTS_STATUS>\n"
)
REMOTE_AGENT_COMPLETED_NOTICE: Final[str] = """
<REMOTE_AGENT_STATUS>
STATUS: COMPLETED
The remote agent has finished processing your request. Their response is in the conversation history above.
You MUST now:
1. Extract the answer from the conversation history
2. Set is_a2a=false
3. Return the answer as your final message
DO NOT send another request - the task is already done.
</REMOTE_AGENT_STATUS>
"""
REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """
<REMOTE_AGENT_STATUS>
STATUS: RESPONSE_RECEIVED
The remote agent has responded. Their response is in the conversation history above.
You MUST now:
1. Set is_a2a=false (the remote task is complete and cannot receive more messages)
2. Provide YOUR OWN response to the original task based on the information received
IMPORTANT: Your response should be addressed to the USER who gave you the original task.
Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that...").
Do NOT address the remote agent directly or use "you" to refer to them.
</REMOTE_AGENT_STATUS>
"""

View File

@@ -0,0 +1,104 @@
"""Type definitions for A2A protocol message parts."""
from __future__ import annotations
from typing import (
Annotated,
Any,
Literal,
Protocol,
TypedDict,
runtime_checkable,
)
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired
try:
from crewai_a2a.updates import (
PollingConfig,
PollingHandler,
PushNotificationConfig,
PushNotificationHandler,
StreamingConfig,
StreamingHandler,
UpdateConfig,
)
except ImportError:
PollingConfig = Any # type: ignore[misc,assignment]
PollingHandler = Any # type: ignore[misc,assignment]
PushNotificationConfig = Any # type: ignore[misc,assignment]
PushNotificationHandler = Any # type: ignore[misc,assignment]
StreamingConfig = Any # type: ignore[misc,assignment]
StreamingHandler = Any # type: ignore[misc,assignment]
UpdateConfig = Any # type: ignore[misc,assignment]
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
ProtocolVersion = Literal[
"0.2.0",
"0.2.1",
"0.2.2",
"0.2.3",
"0.2.4",
"0.2.5",
"0.2.6",
"0.3.0",
"0.4.0",
]
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
Url = Annotated[
str,
BeforeValidator(
lambda value: str(http_url_adapter.validate_python(value, strict=True))
),
]
@runtime_checkable
class AgentResponseProtocol(Protocol):
"""Protocol for the dynamically created AgentResponse model."""
a2a_ids: tuple[str, ...]
message: str
is_a2a: bool
class PartsMetadataDict(TypedDict, total=False):
"""Metadata for A2A message parts.
Attributes:
mimeType: MIME type for the part content.
schema: JSON schema for the part content.
"""
mimeType: Literal["application/json"]
schema: dict[str, Any]
class PartsDict(TypedDict):
"""A2A message part containing text and optional metadata.
Attributes:
text: The text content of the message part.
metadata: Optional metadata describing the part content.
"""
text: str
metadata: NotRequired[PartsMetadataDict]
PollingHandlerType = type[PollingHandler]
StreamingHandlerType = type[StreamingHandler]
PushNotificationHandlerType = type[PushNotificationHandler]
HandlerType = PollingHandlerType | StreamingHandlerType | PushNotificationHandlerType
HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
PollingConfig: PollingHandler,
StreamingConfig: StreamingHandler,
PushNotificationConfig: PushNotificationHandler,
}

View File

@@ -0,0 +1,35 @@
"""A2A update mechanism configuration types."""
from crewai_a2a.updates.base import (
BaseHandlerKwargs,
PollingHandlerKwargs,
PushNotificationHandlerKwargs,
PushNotificationResultStore,
StreamingHandlerKwargs,
UpdateHandler,
)
from crewai_a2a.updates.polling.config import PollingConfig
from crewai_a2a.updates.polling.handler import PollingHandler
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
from crewai_a2a.updates.push_notifications.handler import PushNotificationHandler
from crewai_a2a.updates.streaming.config import StreamingConfig
from crewai_a2a.updates.streaming.handler import StreamingHandler
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
__all__ = [
"BaseHandlerKwargs",
"PollingConfig",
"PollingHandler",
"PollingHandlerKwargs",
"PushNotificationConfig",
"PushNotificationHandler",
"PushNotificationHandlerKwargs",
"PushNotificationResultStore",
"StreamingConfig",
"StreamingHandler",
"StreamingHandlerKwargs",
"UpdateConfig",
"UpdateHandler",
]

View File

@@ -0,0 +1,176 @@
"""Base types for A2A update mechanism handlers."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
class CommonParams(NamedTuple):
"""Common parameters shared across all update handlers.
Groups the frequently-passed parameters to reduce duplication.
"""
turn_number: int
is_multiturn: bool
agent_role: str | None
endpoint: str
a2a_agent_name: str | None
context_id: str | None
from_task: Any
from_agent: Any
if TYPE_CHECKING:
from a2a.client import Client
from a2a.types import AgentCard, Message, Task
from crewai_a2a.task_helpers import TaskStateResult
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
class BaseHandlerKwargs(TypedDict, total=False):
"""Base kwargs shared by all handlers."""
turn_number: int
is_multiturn: bool
agent_role: str | None
context_id: str | None
task_id: str | None
endpoint: str | None
agent_branch: Any
a2a_agent_name: str | None
from_task: Any
from_agent: Any
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for polling handler."""
polling_interval: float
polling_timeout: float
history_length: int
max_polls: int | None
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for streaming handler."""
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for push notification handler."""
config: PushNotificationConfig
result_store: PushNotificationResultStore
polling_timeout: float
polling_interval: float
class PushNotificationResultStore(Protocol):
"""Protocol for storing and retrieving push notification results.
This protocol defines the interface for a result store that the
PushNotificationHandler uses to wait for task completion.
"""
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
return core_schema.any_schema()
async def wait_for_result(
self,
task_id: str,
timeout: float,
poll_interval: float = 1.0,
) -> Task | None:
"""Wait for a task result to be available.
Args:
task_id: The task ID to wait for.
timeout: Max seconds to wait before returning None.
poll_interval: Seconds between polling attempts.
Returns:
The completed Task object, or None if timeout.
"""
...
async def get_result(self, task_id: str) -> Task | None:
"""Get a task result if available.
Args:
task_id: The task ID to retrieve.
Returns:
The Task object if available, None otherwise.
"""
...
async def store_result(self, task: Task) -> None:
"""Store a task result.
Args:
task: The Task object to store.
"""
...
class UpdateHandler(Protocol):
"""Protocol for A2A update mechanism handlers."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Any,
) -> TaskStateResult:
"""Execute the update mechanism and return result.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
**kwargs: Additional handler-specific parameters.
Returns:
Result dictionary with status, result/error, and history.
"""
...
def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams:
"""Extract common parameters from handler kwargs.
Args:
kwargs: Handler kwargs dict.
Returns:
CommonParams with extracted values.
Raises:
ValueError: If endpoint is not provided.
"""
endpoint = kwargs.get("endpoint")
if endpoint is None:
raise ValueError("endpoint is required for update handlers")
return CommonParams(
turn_number=kwargs.get("turn_number", 0),
is_multiturn=kwargs.get("is_multiturn", False),
agent_role=kwargs.get("agent_role"),
endpoint=endpoint,
a2a_agent_name=kwargs.get("a2a_agent_name"),
context_id=kwargs.get("context_id"),
from_task=kwargs.get("from_task"),
from_agent=kwargs.get("from_agent"),
)

View File

@@ -0,0 +1 @@
"""Polling update mechanism module."""

View File

@@ -0,0 +1,25 @@
"""Polling update mechanism configuration."""
from __future__ import annotations
from pydantic import BaseModel, Field
class PollingConfig(BaseModel):
"""Configuration for polling-based task updates.
Attributes:
interval: Seconds between poll attempts.
timeout: Max seconds to poll before raising timeout error.
max_polls: Max number of poll attempts.
history_length: Number of messages to retrieve per poll.
"""
interval: float = Field(
default=2.0, gt=0, description="Seconds between poll attempts"
)
timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll")
max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts")
history_length: int = Field(
default=100, gt=0, description="Messages to retrieve per poll"
)

View File

@@ -0,0 +1,359 @@
"""Polling update mechanism handler."""
from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskQueryParams,
TaskState,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2APollingStartedEvent,
A2APollingStatusEvent,
A2AResponseReceivedEvent,
)
from typing_extensions import Unpack
from crewai_a2a.errors import A2APollingTimeoutError
from crewai_a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
send_message_and_get_task_id,
)
from crewai_a2a.updates.base import PollingHandlerKwargs
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
async def _poll_task_until_complete(
client: Client,
task_id: str,
polling_interval: float,
polling_timeout: float,
agent_branch: Any | None = None,
history_length: int = 100,
max_polls: int | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
context_id: str | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> A2ATask:
"""Poll task status until terminal state reached.
Args:
client: A2A client instance.
task_id: Task ID to poll.
polling_interval: Seconds between poll attempts.
polling_timeout: Max seconds before timeout.
agent_branch: Agent tree branch for logging.
history_length: Number of messages to retrieve per poll.
max_polls: Max number of poll attempts (None = unlimited).
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
context_id: A2A context ID for correlation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
Returns:
Final task object in terminal state.
Raises:
A2APollingTimeoutError: If polling exceeds timeout or max_polls.
"""
start_time = time.monotonic()
poll_count = 0
while True:
poll_count += 1
task = await client.get_task(
TaskQueryParams(id=task_id, history_length=history_length)
)
elapsed = time.monotonic() - start_time
effective_context_id = task.context_id or context_id
crewai_event_bus.emit(
agent_branch,
A2APollingStatusEvent(
task_id=task_id,
context_id=effective_context_id,
state=str(task.status.state.value),
elapsed_seconds=elapsed,
poll_count=poll_count,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
return task
if elapsed > polling_timeout:
raise A2APollingTimeoutError(
f"Polling timeout after {polling_timeout}s ({poll_count} polls)"
)
if max_polls and poll_count >= max_polls:
raise A2APollingTimeoutError(
f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s"
)
await asyncio.sleep(polling_interval)
class PollingHandler:
"""Polling-based update handler."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[PollingHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using polling for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Polling-specific parameters.
Returns:
Dictionary with status, result/error, and history.
"""
polling_interval = kwargs.get("polling_interval", 2.0)
polling_timeout = kwargs.get("polling_timeout", 300.0)
endpoint = kwargs.get("endpoint", "")
agent_branch = kwargs.get("agent_branch")
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
history_length = kwargs.get("history_length", 100)
max_polls = kwargs.get("max_polls")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
a2a_agent_name = kwargs.get("a2a_agent_name")
from_task = kwargs.get("from_task")
from_agent = kwargs.get("from_agent")
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=context_id,
)
if not isinstance(result_or_task_id, str):
return result_or_task_id
task_id = result_or_task_id
crewai_event_bus.emit(
agent_branch,
A2APollingStartedEvent(
task_id=task_id,
context_id=context_id,
polling_interval=polling_interval,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
final_task = await _poll_task_until_complete(
client=client,
task_id=task_id,
polling_interval=polling_interval,
polling_timeout=polling_timeout,
agent_branch=agent_branch,
history_length=history_length,
max_polls=max_polls,
from_task=from_task,
from_agent=from_agent,
context_id=context_id,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
except A2APollingTimeoutError as e:
error_msg = str(e)
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="polling",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during polling: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="polling",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)

View File

@@ -0,0 +1 @@
"""Push notification update mechanism module."""

View File

@@ -0,0 +1,65 @@
"""Push notification update mechanism configuration."""
from __future__ import annotations
from typing import Annotated
from a2a.types import PushNotificationAuthenticationInfo
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
from crewai_a2a.updates.base import PushNotificationResultStore
from crewai_a2a.updates.push_notifications.signature import WebhookSignatureConfig
def _coerce_signature(
value: str | WebhookSignatureConfig | None,
) -> WebhookSignatureConfig | None:
"""Convert string secret to WebhookSignatureConfig."""
if value is None:
return None
if isinstance(value, str):
return WebhookSignatureConfig.hmac_sha256(secret=value)
return value
SignatureInput = Annotated[
WebhookSignatureConfig | None,
BeforeValidator(_coerce_signature),
]
class PushNotificationConfig(BaseModel):
"""Configuration for webhook-based task updates.
Attributes:
url: Callback URL where agent sends push notifications.
id: Unique identifier for this config.
token: Token to validate incoming notifications.
authentication: Auth info for agent to use when calling webhook.
timeout: Max seconds to wait for task completion.
interval: Seconds between result polling attempts.
result_store: Store for receiving push notification results.
signature: HMAC signature config. Pass a string (secret) for defaults,
or WebhookSignatureConfig for custom settings.
"""
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
id: str | None = Field(default=None, description="Unique config identifier")
token: str | None = Field(default=None, description="Validation token")
authentication: PushNotificationAuthenticationInfo | None = Field(
default=None, description="Auth info for agent to use when calling webhook"
)
timeout: float | None = Field(
default=300.0, gt=0, description="Max seconds to wait for task completion"
)
interval: float = Field(
default=2.0, gt=0, description="Seconds between result polling attempts"
)
result_store: PushNotificationResultStore | None = Field(
default=None, description="Result store for push notification handling"
)
signature: SignatureInput = Field(
default=None,
description="HMAC signature config. Pass a string (secret) for simple usage, "
"or WebhookSignatureConfig for custom headers/tolerance.",
)

View File

@@ -0,0 +1,354 @@
"""Push notification (webhook) update mechanism handler."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskState,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2APushNotificationRegisteredEvent,
A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
)
from typing_extensions import Unpack
from crewai_a2a.task_helpers import (
TaskStateResult,
process_task_state,
send_message_and_get_task_id,
)
from crewai_a2a.updates.base import (
CommonParams,
PushNotificationHandlerKwargs,
PushNotificationResultStore,
extract_common_params,
)
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
logger = logging.getLogger(__name__)
def _handle_push_error(
error: Exception,
error_msg: str,
error_type: str,
new_messages: list[Message],
agent_branch: Any | None,
params: CommonParams,
task_id: str | None,
status_code: int | None = None,
) -> TaskStateResult:
"""Handle push notification errors with consistent event emission.
Args:
error: The exception that occurred.
error_msg: Formatted error message for the result.
error_type: Type of error for the event.
new_messages: List to append error message to.
agent_branch: Agent tree branch for events.
params: Common handler parameters.
task_id: A2A task ID.
status_code: HTTP status code if applicable.
Returns:
TaskStateResult with failed status.
"""
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(error),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
async def _wait_for_push_result(
task_id: str,
result_store: PushNotificationResultStore,
timeout: float,
poll_interval: float,
agent_branch: Any | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
context_id: str | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> A2ATask | None:
"""Wait for push notification result.
Args:
task_id: Task ID to wait for.
result_store: Store to retrieve results from.
timeout: Max seconds to wait.
poll_interval: Seconds between polling attempts.
agent_branch: Agent tree branch for logging.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
context_id: A2A context ID for correlation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent.
Returns:
Final task object, or None if timeout.
"""
task = await result_store.wait_for_result(
task_id=task_id,
timeout=timeout,
poll_interval=poll_interval,
)
if task is None:
crewai_event_bus.emit(
agent_branch,
A2APushNotificationTimeoutEvent(
task_id=task_id,
context_id=context_id,
timeout_seconds=timeout,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return task
class PushNotificationHandler:
"""Push notification (webhook) based update handler."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[PushNotificationHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using push notifications for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Push notification-specific parameters.
Returns:
Dictionary with status, result/error, and history.
Raises:
ValueError: If result_store or config not provided.
"""
config = kwargs.get("config")
result_store = kwargs.get("result_store")
polling_timeout = kwargs.get("polling_timeout", 300.0)
polling_interval = kwargs.get("polling_interval", 2.0)
agent_branch = kwargs.get("agent_branch")
task_id = kwargs.get("task_id")
params = extract_common_params(kwargs)
if config is None:
error_msg = (
"PushNotificationConfig is required for push notification handler"
)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
if result_store is None:
error_msg = (
"PushNotificationResultStore is required for push notification handler"
)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=params.context_id,
)
if not isinstance(result_or_task_id, str):
return result_or_task_id
task_id = result_or_task_id
crewai_event_bus.emit(
agent_branch,
A2APushNotificationRegisteredEvent(
task_id=task_id,
context_id=params.context_id,
callback_url=str(config.url),
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
logger.debug(
"Push notification callback for task %s configured at %s (via initial request)",
task_id,
config.url,
)
final_task = await _wait_for_push_result(
task_id=task_id,
result_store=result_store,
timeout=polling_timeout,
poll_interval=polling_interval,
agent_branch=agent_branch,
from_task=params.from_task,
from_agent=params.from_agent,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
)
if final_task is None:
return TaskStateResult(
status=TaskState.failed,
error=f"Push notification timeout after {polling_timeout}s",
history=new_messages,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
except A2AClientHTTPError as e:
return _handle_push_error(
error=e,
error_msg=f"HTTP Error {e.status_code}: {e!s}",
error_type="http_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
status_code=e.status_code,
)
except Exception as e:
return _handle_push_error(
error=e,
error_msg=f"Unexpected error during push notification: {e!s}",
error_type="unexpected_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
)

View File

@@ -0,0 +1,87 @@
"""Webhook signature configuration for push notifications."""
from __future__ import annotations
from enum import Enum
import secrets
from pydantic import BaseModel, Field, SecretStr
class WebhookSignatureMode(str, Enum):
"""Signature mode for webhook push notifications."""
NONE = "none"
HMAC_SHA256 = "hmac_sha256"
class WebhookSignatureConfig(BaseModel):
"""Configuration for webhook signature verification.
Provides cryptographic integrity verification and replay attack protection
for A2A push notifications.
Attributes:
mode: Signature mode (none or hmac_sha256).
secret: Shared secret for HMAC computation (required for hmac_sha256 mode).
timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection.
header_name: HTTP header name for the signature.
timestamp_header_name: HTTP header name for the timestamp.
"""
mode: WebhookSignatureMode = Field(
default=WebhookSignatureMode.NONE,
description="Signature verification mode",
)
secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC computation",
)
timestamp_tolerance_seconds: int = Field(
default=300,
ge=0,
description="Max allowed timestamp age in seconds (5 min default)",
)
header_name: str = Field(
default="X-A2A-Signature",
description="HTTP header name for the signature",
)
timestamp_header_name: str = Field(
default="X-A2A-Signature-Timestamp",
description="HTTP header name for the timestamp",
)
@classmethod
def generate_secret(cls, length: int = 32) -> str:
"""Generate a cryptographically secure random secret.
Args:
length: Number of random bytes to generate (default 32).
Returns:
URL-safe base64-encoded secret string.
"""
return secrets.token_urlsafe(length)
@classmethod
def hmac_sha256(
cls,
secret: str | SecretStr,
timestamp_tolerance_seconds: int = 300,
) -> WebhookSignatureConfig:
"""Create an HMAC-SHA256 signature configuration.
Args:
secret: Shared secret for HMAC computation.
timestamp_tolerance_seconds: Max allowed timestamp age in seconds.
Returns:
Configured WebhookSignatureConfig for HMAC-SHA256.
"""
if isinstance(secret, str):
secret = SecretStr(secret)
return cls(
mode=WebhookSignatureMode.HMAC_SHA256,
secret=secret,
timestamp_tolerance_seconds=timestamp_tolerance_seconds,
)

View File

@@ -0,0 +1 @@
"""Streaming update mechanism module."""

View File

@@ -0,0 +1,9 @@
"""Streaming update mechanism configuration."""
from __future__ import annotations
from pydantic import BaseModel
class StreamingConfig(BaseModel):
"""Configuration for SSE-based task updates."""

View File

@@ -0,0 +1,646 @@
"""Streaming (SSE) update mechanism handler."""
from __future__ import annotations
import asyncio
import logging
from typing import Final
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskQueryParams,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AArtifactReceivedEvent,
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
A2AStreamingChunkEvent,
A2AStreamingStartedEvent,
)
from typing_extensions import Unpack
from crewai_a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
)
from crewai_a2a.updates.base import StreamingHandlerKwargs, extract_common_params
from crewai_a2a.updates.streaming.params import (
process_status_update,
)
logger = logging.getLogger(__name__)
MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
class StreamingHandler:
"""SSE streaming-based update handler."""
@staticmethod
async def _try_recover_from_interruption( # type: ignore[misc]
client: Client,
task_id: str,
new_messages: list[Message],
agent_card: AgentCard,
result_parts: list[str],
**kwargs: Unpack[StreamingHandlerKwargs],
) -> TaskStateResult | None:
"""Attempt to recover from a stream interruption by checking task state.
If the task completed while we were disconnected, returns the result.
If the task is still running, attempts to resubscribe and continue.
Args:
client: A2A client instance.
task_id: The task ID to recover.
new_messages: List of collected messages.
agent_card: The agent card.
result_parts: Accumulated result text parts.
**kwargs: Handler parameters.
Returns:
TaskStateResult if recovery succeeded (task finished or resubscribe worked).
None if recovery not possible (caller should handle failure).
Note:
When None is returned, recovery failed and the original exception should
be handled by the caller. All recovery attempts are logged.
"""
params = extract_common_params(kwargs) # type: ignore[arg-type]
try:
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
if a2a_task.status.state in TERMINAL_STATES:
logger.info(
"Task completed during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if a2a_task.status.state in ACTIONABLE_STATES:
logger.info(
"Task in actionable state during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=False,
)
logger.info(
"Task still running, attempting resubscribe",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS):
try:
backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt)
if attempt > 0:
await asyncio.sleep(backoff)
event_stream = client.resubscribe(TaskIdParams(id=task_id))
async for event in event_stream:
if isinstance(event, tuple):
resubscribed_task, update = event
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
if (
is_final_update
or resubscribed_task.status.state
in TERMINAL_STATES | ACTIONABLE_STATES
):
return process_task_state(
a2a_task=resubscribed_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
elif isinstance(event, Message):
new_messages.append(event)
result_parts.extend(
part.root.text
for part in event.parts
if part.root.kind == "text"
)
final_task = await client.get_task(TaskQueryParams(id=task_id))
return process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
except Exception as resubscribe_error: # noqa: PERF203
logger.warning(
"Resubscribe attempt failed",
extra={
"task_id": task_id,
"attempt": attempt + 1,
"max_attempts": MAX_RESUBSCRIBE_ATTEMPTS,
"error": str(resubscribe_error),
},
)
if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1:
return None
except Exception as e:
logger.warning(
"Failed to recover from stream interruption due to unexpected error",
extra={
"task_id": task_id,
"error": str(e),
"error_type": type(e).__name__,
},
exc_info=True,
)
return None
logger.warning(
"Recovery exhausted all resubscribe attempts without success",
extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS},
)
return None
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[StreamingHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using SSE streaming for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Streaming-specific parameters.
Returns:
Dictionary with status, result/error, and history.
"""
task_id = kwargs.get("task_id")
agent_branch = kwargs.get("agent_branch")
params = extract_common_params(kwargs)
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
chunk_index = 0
current_task_id: str | None = task_id
crewai_event_bus.emit(
agent_branch,
A2AStreamingStartedEvent(
task_id=task_id,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
try:
async for event in event_stream:
if isinstance(event, tuple):
a2a_task, _ = event
current_task_id = a2a_task.id
if isinstance(event, Message):
new_messages.append(event)
message_context_id = event.context_id or params.context_id
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
crewai_event_bus.emit(
agent_branch,
A2AStreamingChunkEvent(
task_id=event.task_id or task_id,
context_id=message_context_id,
chunk=text,
chunk_index=chunk_index,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
chunk_index += 1
elif isinstance(event, tuple):
a2a_task, update = event
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
artifact_size = None
if artifact.parts:
artifact_size = sum(
len(p.root.text.encode())
if p.root.kind == "text"
else len(getattr(p.root, "data", b""))
for p in artifact.parts
)
effective_context_id = a2a_task.context_id or params.context_id
crewai_event_bus.emit(
agent_branch,
A2AArtifactReceivedEvent(
task_id=a2a_task.id,
artifact_id=artifact.artifact_id,
artifact_name=artifact.name,
artifact_description=artifact.description,
mime_type=artifact.parts[0].root.kind
if artifact.parts
else None,
size_bytes=artifact_size,
append=update.append or False,
last_chunk=update.last_chunk or False,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=effective_context_id,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if (
not is_final_update
and a2a_task.status.state
not in TERMINAL_STATES | ACTIONABLE_STATES
):
continue
final_result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
if final_result:
break
except A2AClientHTTPError as e:
if current_task_id:
logger.info(
"Stream interrupted with HTTP error, attempting recovery",
extra={
"task_id": current_task_id,
"error": str(e),
"status_code": e.status_code,
},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
"Successfully recovered task after HTTP error",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
"Failed to recover from HTTP error, returning failure",
extra={
"task_id": current_task_id,
"status_code": e.status_code,
"original_error": str(e),
},
)
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_type = "http_error"
status_code = e.status_code
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
error_type = type(e).__name__.lower()
if current_task_id:
logger.info(
f"Stream interrupted with {error_type}, attempting recovery",
extra={"task_id": current_task_id, "error": str(e)},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
f"Successfully recovered task after {error_type}",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
f"Failed to recover from {error_type}, returning failure",
extra={
"task_id": current_task_id,
"error_type": error_type,
"original_error": str(e),
},
)
error_msg = f"Connection error during streaming: {e!s}"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
logger.exception(
"Unexpected error during streaming",
extra={
"task_id": current_task_id,
"error_type": type(e).__name__,
"endpoint": params.endpoint,
},
)
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
error_type = "unexpected_error"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:
try:
await aclose()
except Exception as close_error:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(close_error),
error_type="stream_close_error",
a2a_agent_name=params.a2a_agent_name,
operation="stream_close",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
if final_result:
return final_result
return TaskStateResult(
status=TaskState.completed,
result=" ".join(result_parts) if result_parts else "",
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)

View File

@@ -0,0 +1,28 @@
"""Common parameter extraction for streaming handlers."""
from __future__ import annotations
from a2a.types import TaskStatusUpdateEvent
def process_status_update(
update: TaskStatusUpdateEvent,
result_parts: list[str],
) -> bool:
"""Process a status update event and extract text parts.
Args:
update: The status update event.
result_parts: List to append text parts to (modified in place).
Returns:
True if this is a final update, False otherwise.
"""
is_final = update.final
if update.status and update.status.message and update.status.message.parts:
result_parts.extend(
part.root.text
for part in update.status.message.parts
if part.root.kind == "text" and part.root.text
)
return is_final

View File

@@ -0,0 +1 @@
"""A2A utility modules for client operations."""

View File

@@ -0,0 +1,587 @@
"""AgentCard utilities for A2A client and server operations."""
from __future__ import annotations
import asyncio
from collections.abc import MutableMapping
from functools import lru_cache
import ssl
import time
from types import MethodType
from typing import TYPE_CHECKING
from a2a.client.errors import A2AClientHTTPError
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AAgentCardFetchedEvent,
A2AAuthenticationFailedEvent,
A2AConnectionErrorEvent,
)
import httpx
from crewai_a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai_a2a.auth.utils import (
_auth_store,
configure_auth_client,
retry_on_401,
)
from crewai_a2a.config import A2AServerConfig
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
from crewai_a2a.auth.client_schemes import ClientAuthScheme
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
"""Get TLS verify parameter from auth scheme.
Args:
auth: Optional authentication scheme with TLS config.
Returns:
SSL context, CA cert path, True for default verification,
or False if verification disabled.
"""
if auth and auth.tls:
return auth.tls.get_httpx_ssl_context()
return True
async def _prepare_auth_headers(
auth: ClientAuthScheme | None,
timeout: int,
) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]:
"""Prepare authentication headers and TLS verification settings.
Args:
auth: Optional authentication scheme.
timeout: Request timeout in seconds.
Returns:
Tuple of (headers dict, TLS verify setting).
"""
headers: MutableMapping[str, str] = {}
verify = _get_tls_verify(auth)
if auth:
async with httpx.AsyncClient(
timeout=timeout, verify=verify
) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
return headers, verify
def _get_server_config(agent: Agent) -> A2AServerConfig | None:
"""Get A2AServerConfig from an agent's a2a configuration.
Args:
agent: The Agent instance to check.
Returns:
A2AServerConfig if present, None otherwise.
"""
if agent.a2a is None:
return None
if isinstance(agent.a2a, A2AServerConfig):
return agent.a2a
if isinstance(agent.a2a, list):
for config in agent.a2a:
if isinstance(config, A2AServerConfig):
return config
return None
def fetch_agent_card(
endpoint: str,
auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
cache_ttl: int = 300,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint with optional caching.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).
Returns:
AgentCard object with agent capabilities and skills.
Raises:
httpx.HTTPStatusError: If the request fails.
A2AClientHTTPError: If authentication fails.
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
async def afetch_agent_card(
endpoint: str,
auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint asynchronously.
Native async implementation. Use this when running in an async context.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
Returns:
AgentCard object with agent capabilities and skills.
Raises:
httpx.HTTPStatusError: If the request fails.
A2AClientHTTPError: If authentication fails.
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)
return agent_card
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
@lru_cache()
def _fetch_agent_card_cached(
endpoint: str,
auth_hash: str,
timeout: int,
_ttl_hash: int,
) -> AgentCard:
"""Cached sync version of fetch_agent_card."""
auth = _auth_store.get(auth_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
_afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
async def _afetch_agent_card_cached(
endpoint: str,
auth_hash: str,
timeout: int,
) -> AgentCard:
"""Cached async implementation of AgentCard fetching."""
auth = _auth_store.get(auth_hash)
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
async def _afetch_agent_card_impl(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
) -> AgentCard:
"""Internal async implementation of AgentCard fetching."""
start_time = time.perf_counter()
if "/.well-known/agent-card.json" in endpoint:
base_url = endpoint.replace("/.well-known/agent-card.json", "")
agent_card_path = "/.well-known/agent-card.json"
else:
url_parts = endpoint.split("/", 3)
base_url = f"{url_parts[0]}//{url_parts[2]}"
agent_card_path = (
f"/{url_parts[3]}"
if len(url_parts) > 3 and url_parts[3]
else "/.well-known/agent-card.json"
)
headers, verify = await _prepare_auth_headers(auth, timeout)
async with httpx.AsyncClient(
timeout=timeout, headers=headers, verify=verify
) as temp_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_client)
agent_card_url = f"{base_url}{agent_card_path}"
async def _fetch_agent_card_request() -> httpx.Response:
return await temp_client.get(agent_card_url)
try:
response = await retry_on_401(
request_func=_fetch_agent_card_request,
auth_scheme=auth,
client=temp_client,
headers=temp_client.headers,
max_retries=2,
)
response.raise_for_status()
agent_card = AgentCard.model_validate(response.json())
fetch_time_ms = (time.perf_counter() - start_time) * 1000
agent_card_dict = agent_card.model_dump(exclude_none=True)
crewai_event_bus.emit(
None,
A2AAgentCardFetchedEvent(
endpoint=endpoint,
a2a_agent_name=agent_card.name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
cached=False,
fetch_time_ms=fetch_time_ms,
),
)
return agent_card
except httpx.HTTPStatusError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
response_body = e.response.text[:1000] if e.response.text else None
if e.response.status_code == 401:
error_details = ["Authentication failed"]
www_auth = e.response.headers.get("WWW-Authenticate")
if www_auth:
error_details.append(f"WWW-Authenticate: {www_auth}")
if not auth:
error_details.append("No auth scheme provided")
msg = " | ".join(error_details)
auth_type = type(auth).__name__ if auth else None
crewai_event_bus.emit(
None,
A2AAuthenticationFailedEvent(
endpoint=endpoint,
auth_type=auth_type,
error=msg,
status_code=401,
metadata={
"elapsed_ms": elapsed_ms,
"response_body": response_body,
"www_authenticate": www_auth,
"request_url": str(e.request.url),
},
),
)
raise A2AClientHTTPError(401, msg) from e
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="http_error",
status_code=e.response.status_code,
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"response_body": response_body,
"request_url": str(e.request.url),
},
),
)
raise
except httpx.TimeoutException as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="timeout",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"timeout_config": timeout,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
except httpx.ConnectError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="connection_error",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
except httpx.RequestError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="request_error",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
def _task_to_skill(task: Task) -> AgentSkill:
"""Convert a CrewAI Task to an A2A AgentSkill.
Args:
task: The CrewAI Task to convert.
Returns:
AgentSkill representing the task's capability.
"""
task_name = task.name or task.description[:50]
task_id = task_name.lower().replace(" ", "_")
tags: list[str] = []
if task.agent:
tags.append(task.agent.role.lower().replace(" ", "-"))
return AgentSkill(
id=task_id,
name=task_name,
description=task.description,
tags=tags,
examples=[task.expected_output] if task.expected_output else None,
)
def _tool_to_skill(tool_name: str, tool_description: str) -> AgentSkill:
"""Convert an Agent's tool to an A2A AgentSkill.
Args:
tool_name: Name of the tool.
tool_description: Description of what the tool does.
Returns:
AgentSkill representing the tool's capability.
"""
tool_id = tool_name.lower().replace(" ", "_")
return AgentSkill(
id=tool_id,
name=tool_name,
description=tool_description,
tags=[tool_name.lower().replace(" ", "-")],
)
def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard:
"""Generate an A2A AgentCard from a Crew instance.
Args:
crew: The Crew instance to generate a card for.
url: The base URL where this crew will be exposed.
Returns:
AgentCard describing the crew's capabilities.
"""
crew_name = getattr(crew, "name", None) or crew.__class__.__name__
description_parts: list[str] = []
crew_description = getattr(crew, "description", None)
if crew_description:
description_parts.append(crew_description)
else:
agent_roles = [agent.role for agent in crew.agents]
description_parts.append(
f"A crew of {len(crew.agents)} agents: {', '.join(agent_roles)}"
)
skills = [_task_to_skill(task) for task in crew.tasks]
return AgentCard(
name=crew_name,
description=" ".join(description_parts),
url=url,
version="1.0.0",
capabilities=AgentCapabilities(
streaming=True,
push_notifications=True,
),
default_input_modes=["text/plain", "application/json"],
default_output_modes=["text/plain", "application/json"],
skills=skills,
)
def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
"""Generate an A2A AgentCard from an Agent instance.
Uses A2AServerConfig values when available, falling back to agent properties.
If signing_config is provided, the card will be signed with JWS.
Args:
agent: The Agent instance to generate a card for.
url: The base URL where this agent will be exposed.
Returns:
AgentCard describing the agent's capabilities.
"""
from crewai_a2a.utils.agent_card_signing import sign_agent_card
server_config = _get_server_config(agent) or A2AServerConfig()
name = server_config.name or agent.role
description_parts = [agent.goal]
if agent.backstory:
description_parts.append(agent.backstory)
description = server_config.description or " ".join(description_parts)
skills: list[AgentSkill] = (
server_config.skills.copy() if server_config.skills else []
)
if not skills:
if agent.tools:
for tool in agent.tools:
tool_name = getattr(tool, "name", None) or tool.__class__.__name__
tool_desc = getattr(tool, "description", None) or f"Tool: {tool_name}"
skills.append(_tool_to_skill(tool_name, tool_desc))
if not skills:
skills.append(
AgentSkill(
id=agent.role.lower().replace(" ", "_"),
name=agent.role,
description=agent.goal,
tags=[agent.role.lower().replace(" ", "-")],
)
)
capabilities = server_config.capabilities
if server_config.server_extensions:
from crewai_a2a.extensions.server import ServerExtensionRegistry
registry = ServerExtensionRegistry(server_config.server_extensions)
ext_list = registry.get_agent_extensions()
existing_exts = list(capabilities.extensions) if capabilities.extensions else []
existing_uris = {e.uri for e in existing_exts}
for ext in ext_list:
if ext.uri not in existing_uris:
existing_exts.append(ext)
capabilities = capabilities.model_copy(update={"extensions": existing_exts})
card = AgentCard(
name=name,
description=description,
url=server_config.url or url,
version=server_config.version,
capabilities=capabilities,
default_input_modes=server_config.default_input_modes,
default_output_modes=server_config.default_output_modes,
skills=skills,
preferred_transport=server_config.transport.preferred,
protocol_version=server_config.protocol_version,
provider=server_config.provider,
documentation_url=server_config.documentation_url,
icon_url=server_config.icon_url,
additional_interfaces=server_config.additional_interfaces,
security=server_config.security,
security_schemes=server_config.security_schemes,
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
)
if server_config.signing_config:
signature = sign_agent_card(
card,
private_key=server_config.signing_config.get_private_key(),
key_id=server_config.signing_config.key_id,
algorithm=server_config.signing_config.algorithm,
)
card = card.model_copy(update={"signatures": [signature]})
elif server_config.signatures:
card = card.model_copy(update={"signatures": server_config.signatures})
return card
def inject_a2a_server_methods(agent: Agent) -> None:
"""Inject A2A server methods onto an Agent instance.
Adds a `to_agent_card(url: str) -> AgentCard` method to the agent
that generates an A2A-compliant AgentCard.
Only injects if the agent has an A2AServerConfig.
Args:
agent: The Agent instance to inject methods onto.
"""
if _get_server_config(agent) is None:
return
def _to_agent_card(self: Agent, url: str) -> AgentCard:
return _agent_to_agent_card(self, url)
object.__setattr__(agent, "to_agent_card", MethodType(_to_agent_card, agent))

View File

@@ -0,0 +1,236 @@
"""AgentCard JWS signing utilities.
This module provides functions for signing and verifying AgentCards using
JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients
to verify the authenticity and integrity of agent card information.
Example:
>>> from crewai_a2a.utils.agent_card_signing import sign_agent_card
>>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1")
>>> card_with_sig = card.model_copy(update={"signatures": [signature]})
"""
from __future__ import annotations
import base64
import json
import logging
from typing import Any, Literal
from a2a.types import AgentCard, AgentCardSignature
import jwt
from pydantic import SecretStr
logger = logging.getLogger(__name__)
SigningAlgorithm = Literal[
"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512"
]
def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes:
"""Normalize private key to bytes format.
Args:
private_key: PEM-encoded private key as string, bytes, or SecretStr.
Returns:
Private key as bytes.
"""
if isinstance(private_key, SecretStr):
private_key = private_key.get_secret_value()
if isinstance(private_key, str):
private_key = private_key.encode()
return private_key
def _serialize_agent_card(agent_card: AgentCard) -> str:
"""Serialize AgentCard to canonical JSON for signing.
Excludes the signatures field to avoid circular reference during signing.
Uses sorted keys and compact separators for deterministic output.
Args:
agent_card: The AgentCard to serialize.
Returns:
Canonical JSON string representation.
"""
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))
def _base64url_encode(data: bytes | str) -> str:
"""Encode data to URL-safe base64 without padding.
Args:
data: Data to encode.
Returns:
URL-safe base64 encoded string without padding.
"""
if isinstance(data, str):
data = data.encode()
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def sign_agent_card(
agent_card: AgentCard,
private_key: str | bytes | SecretStr,
key_id: str | None = None,
algorithm: SigningAlgorithm = "RS256",
) -> AgentCardSignature:
"""Sign an AgentCard using JWS (RFC 7515).
Creates a detached JWS signature for the AgentCard. The signature covers
all fields except the signatures field itself.
Args:
agent_card: The AgentCard to sign.
private_key: PEM-encoded private key (RSA, EC, or RSA-PSS).
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
Returns:
AgentCardSignature with protected header and signature.
Raises:
jwt.exceptions.InvalidKeyError: If the private key is invalid.
ValueError: If the algorithm is not supported for the key type.
Example:
>>> signature = sign_agent_card(
... agent_card,
... private_key_pem="-----BEGIN PRIVATE KEY-----...",
... key_id="my-key-id",
... )
"""
key_bytes = _normalize_private_key(private_key)
payload = _serialize_agent_card(agent_card)
protected_header: dict[str, Any] = {"typ": "JWS"}
if key_id:
protected_header["kid"] = key_id
jws_token = jwt.api_jws.encode(
payload.encode(),
key_bytes,
algorithm=algorithm,
headers=protected_header,
)
parts = jws_token.split(".")
protected_b64 = parts[0]
signature_b64 = parts[2]
header: dict[str, Any] | None = None
if key_id:
header = {"kid": key_id}
return AgentCardSignature(
protected=protected_b64,
signature=signature_b64,
header=header,
)
def verify_agent_card_signature(
agent_card: AgentCard,
signature: AgentCardSignature,
public_key: str | bytes,
algorithms: list[str] | None = None,
) -> bool:
"""Verify an AgentCard JWS signature.
Validates that the signature was created with the corresponding private key
and that the AgentCard content has not been modified.
Args:
agent_card: The AgentCard to verify.
signature: The AgentCardSignature to validate.
public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms.
Returns:
True if signature is valid, False otherwise.
Example:
>>> is_valid = verify_agent_card_signature(
... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..."
... )
"""
if algorithms is None:
algorithms = [
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
if isinstance(public_key, str):
public_key = public_key.encode()
payload = _serialize_agent_card(agent_card)
payload_b64 = _base64url_encode(payload)
jws_token = f"{signature.protected}.{payload_b64}.{signature.signature}"
try:
jwt.api_jws.decode(
jws_token,
public_key,
algorithms=algorithms,
)
return True
except jwt.InvalidSignatureError:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "invalid_signature"},
)
return False
except jwt.DecodeError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "decode_error", "error": str(e)},
)
return False
except jwt.InvalidAlgorithmError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "algorithm_error", "error": str(e)},
)
return False
def get_key_id_from_signature(signature: AgentCardSignature) -> str | None:
"""Extract the key ID (kid) from an AgentCardSignature.
Checks both the unprotected header and the protected header for the kid claim.
Args:
signature: The AgentCardSignature to extract from.
Returns:
The key ID if present, None otherwise.
"""
if signature.header and "kid" in signature.header:
kid: str = signature.header["kid"]
return kid
try:
protected = signature.protected
padding_needed = 4 - (len(protected) % 4)
if padding_needed != 4:
protected += "=" * padding_needed
protected_json = base64.urlsafe_b64decode(protected).decode()
protected_header: dict[str, Any] = json.loads(protected_json)
return protected_header.get("kid")
except (ValueError, json.JSONDecodeError):
return None

View File

@@ -0,0 +1,338 @@
"""Content type negotiation for A2A protocol.
This module handles negotiation of input/output MIME types between A2A clients
and servers based on AgentCard capabilities.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Final, Literal, cast
from a2a.types import Part
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent
if TYPE_CHECKING:
from a2a.types import AgentCard, AgentSkill
TEXT_PLAIN: Literal["text/plain"] = "text/plain"
APPLICATION_JSON: Literal["application/json"] = "application/json"
IMAGE_PNG: Literal["image/png"] = "image/png"
IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg"
IMAGE_WILDCARD: Literal["image/*"] = "image/*"
APPLICATION_PDF: Literal["application/pdf"] = "application/pdf"
APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = (
"application/octet-stream"
)
DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
@dataclass
class NegotiatedContentTypes:
"""Result of content type negotiation."""
input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"]
output_modes: Annotated[
list[str], "Negotiated output MIME types the server will produce"
]
effective_input_modes: Annotated[list[str], "Server's effective input modes"]
effective_output_modes: Annotated[list[str], "Server's effective output modes"]
skill_name: Annotated[
str | None, "Skill name if negotiation was skill-specific"
] = None
class ContentTypeNegotiationError(Exception):
"""Raised when no compatible content types can be negotiated."""
def __init__(
self,
client_input_modes: list[str],
client_output_modes: list[str],
server_input_modes: list[str],
server_output_modes: list[str],
direction: str = "both",
message: str | None = None,
) -> None:
self.client_input_modes = client_input_modes
self.client_output_modes = client_output_modes
self.server_input_modes = server_input_modes
self.server_output_modes = server_output_modes
self.direction = direction
if message is None:
if direction == "input":
message = (
f"No compatible input content types. "
f"Client supports: {client_input_modes}, "
f"Server accepts: {server_input_modes}"
)
elif direction == "output":
message = (
f"No compatible output content types. "
f"Client accepts: {client_output_modes}, "
f"Server produces: {server_output_modes}"
)
else:
message = (
f"No compatible content types. "
f"Input - Client: {client_input_modes}, Server: {server_input_modes}. "
f"Output - Client: {client_output_modes}, Server: {server_output_modes}"
)
super().__init__(message)
def _normalize_mime_type(mime_type: str) -> str:
"""Normalize MIME type for comparison (lowercase, strip whitespace)."""
return mime_type.lower().strip()
def _mime_types_compatible(client_type: str, server_type: str) -> bool:
"""Check if two MIME types are compatible.
Handles wildcards like image/* matching image/png.
"""
client_normalized = _normalize_mime_type(client_type)
server_normalized = _normalize_mime_type(server_type)
if client_normalized == server_normalized:
return True
if "*" in client_normalized or "*" in server_normalized:
client_parts = client_normalized.split("/")
server_parts = server_normalized.split("/")
if len(client_parts) == 2 and len(server_parts) == 2:
type_match = (
client_parts[0] == server_parts[0]
or client_parts[0] == "*"
or server_parts[0] == "*"
)
subtype_match = (
client_parts[1] == server_parts[1]
or client_parts[1] == "*"
or server_parts[1] == "*"
)
return type_match and subtype_match
return False
def _find_compatible_modes(
client_modes: list[str], server_modes: list[str]
) -> list[str]:
"""Find compatible MIME types between client and server.
Returns modes in client preference order.
"""
compatible = []
for client_mode in client_modes:
for server_mode in server_modes:
if _mime_types_compatible(client_mode, server_mode):
if "*" in client_mode and "*" not in server_mode:
if server_mode not in compatible:
compatible.append(server_mode)
else:
if client_mode not in compatible:
compatible.append(client_mode)
break
return compatible
def _get_effective_modes(
agent_card: AgentCard,
skill_name: str | None = None,
) -> tuple[list[str], list[str], AgentSkill | None]:
"""Get effective input/output modes from agent card.
If skill_name is provided and the skill has custom modes, those are used.
Otherwise, falls back to agent card defaults.
"""
skill: AgentSkill | None = None
if skill_name and agent_card.skills:
for s in agent_card.skills:
if s.name == skill_name or s.id == skill_name:
skill = s
break
if skill:
input_modes = (
skill.input_modes if skill.input_modes else agent_card.default_input_modes
)
output_modes = (
skill.output_modes
if skill.output_modes
else agent_card.default_output_modes
)
else:
input_modes = agent_card.default_input_modes
output_modes = agent_card.default_output_modes
return input_modes, output_modes, skill
def negotiate_content_types(
agent_card: AgentCard,
client_input_modes: list[str] | None = None,
client_output_modes: list[str] | None = None,
skill_name: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
strict: bool = False,
) -> NegotiatedContentTypes:
"""Negotiate content types between client and server.
Args:
agent_card: The remote agent's card with capability info.
client_input_modes: MIME types the client can send. Defaults to text/plain and application/json.
client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json.
skill_name: Optional skill to use for mode lookup.
emit_event: Whether to emit a content type negotiation event.
endpoint: Agent endpoint (for event metadata).
a2a_agent_name: Agent name (for event metadata).
strict: If True, raises error when no compatible types found.
If False, returns empty lists for incompatible directions.
Returns:
NegotiatedContentTypes with compatible input and output modes.
Raises:
ContentTypeNegotiationError: If strict=True and no compatible types found.
"""
if client_input_modes is None:
client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy())
if client_output_modes is None:
client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy())
server_input_modes, server_output_modes, skill = _get_effective_modes(
agent_card, skill_name
)
compatible_input = _find_compatible_modes(client_input_modes, server_input_modes)
compatible_output = _find_compatible_modes(client_output_modes, server_output_modes)
if strict:
if not compatible_input and not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
)
if not compatible_input:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="input",
)
if not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="output",
)
result = NegotiatedContentTypes(
input_modes=compatible_input,
output_modes=compatible_output,
effective_input_modes=server_input_modes,
effective_output_modes=server_output_modes,
skill_name=skill.name if skill else None,
)
if emit_event:
crewai_event_bus.emit(
None,
A2AContentTypeNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
skill_name=skill_name,
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
negotiated_input_modes=compatible_input,
negotiated_output_modes=compatible_output,
negotiation_success=bool(compatible_input and compatible_output),
),
)
return result
def validate_content_type(
content_type: str,
allowed_modes: list[str],
) -> bool:
"""Validate that a content type is allowed by a list of modes.
Args:
content_type: The MIME type to validate.
allowed_modes: List of allowed MIME types (may include wildcards).
Returns:
True if content_type is compatible with any allowed mode.
"""
for mode in allowed_modes:
if _mime_types_compatible(content_type, mode):
return True
return False
def get_part_content_type(part: Part) -> str:
"""Extract MIME type from an A2A Part.
Args:
part: A Part object containing TextPart, DataPart, or FilePart.
Returns:
The MIME type string for this part.
"""
root = part.root
if root.kind == "text":
return TEXT_PLAIN
if root.kind == "data":
return APPLICATION_JSON
if root.kind == "file":
return root.file.mime_type or APPLICATION_OCTET_STREAM
return APPLICATION_OCTET_STREAM
def validate_message_parts(
parts: list[Part],
allowed_modes: list[str],
) -> list[str]:
"""Validate that all message parts have allowed content types.
Args:
parts: List of Parts from the incoming message.
allowed_modes: List of allowed MIME types (from default_input_modes).
Returns:
List of invalid content types found (empty if all valid).
"""
invalid_types: list[str] = []
for part in parts:
content_type = get_part_content_type(part)
if not validate_content_type(content_type, allowed_modes):
if content_type not in invalid_types:
invalid_types.append(content_type)
return invalid_types

View File

@@ -0,0 +1,980 @@
"""A2A delegation utilities for executing tasks on remote agents."""
from __future__ import annotations
import asyncio
import base64
from collections.abc import AsyncIterator, Callable, MutableMapping
from contextlib import asynccontextmanager
import logging
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
from a2a.client import Client, ClientConfig, ClientFactory
from a2a.types import (
AgentCard,
FilePart,
FileWithBytes,
Message,
Part,
PushNotificationConfig as A2APushNotificationConfig,
Role,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
)
import httpx
from pydantic import BaseModel
from crewai_a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai_a2a.auth.utils import (
_auth_store,
configure_auth_client,
validate_auth_against_agent_card,
)
from crewai_a2a.config import ClientTransportConfig, GRPCClientConfig
from crewai_a2a.extensions.registry import (
ExtensionsMiddleware,
validate_required_extensions,
)
from crewai_a2a.task_helpers import TaskStateResult
from crewai_a2a.types import (
HANDLER_REGISTRY,
HandlerType,
PartsDict,
PartsMetadataDict,
TransportType,
)
from crewai_a2a.updates import (
PollingConfig,
PushNotificationConfig,
StreamingHandler,
UpdateConfig,
)
from crewai_a2a.utils.agent_card import (
_afetch_agent_card_cached,
_get_tls_verify,
_prepare_auth_headers,
)
from crewai_a2a.utils.content_type import (
DEFAULT_CLIENT_OUTPUT_MODES,
negotiate_content_types,
)
from crewai_a2a.utils.transport import (
NegotiatedTransport,
TransportNegotiationError,
negotiate_transport,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from a2a.types import Message
from crewai_a2a.auth.client_schemes import ClientAuthScheme
_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC"
def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]:
"""Convert FileInput dictionary to FilePart objects.
Args:
input_files: Dictionary mapping names to FileInput objects.
Returns:
List of Part objects containing FilePart data.
"""
if not input_files:
return []
try:
import crewai_files # noqa: F401
except ImportError:
logger.debug("crewai_files not installed, skipping file parts")
return []
parts: list[Part] = []
for name, file_input in input_files.items():
content_bytes = file_input.read()
content_base64 = base64.b64encode(content_bytes).decode()
file_with_bytes = FileWithBytes(
bytes=content_base64,
mimeType=file_input.content_type,
name=file_input.filename or name,
)
parts.append(Part(root=FilePart(file=file_with_bytes)))
return parts
def get_handler(config: UpdateConfig | None) -> HandlerType:
"""Get the handler class for a given update config.
Args:
config: Update mechanism configuration.
Returns:
Handler class for the config type, defaults to StreamingHandler.
"""
if config is None:
return StreamingHandler
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
def execute_a2a_delegation(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
context_id: str | None = None,
task_id: str | None = None,
reference_task_ids: list[str] | None = None,
metadata: dict[str, Any] | None = None,
extensions: dict[str, Any] | None = None,
conversation_history: list[Message] | None = None,
agent_id: str | None = None,
agent_role: Role | None = None,
agent_branch: Any | None = None,
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
updates: UpdateConfig | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent synchronously.
WARNING: This function blocks the entire thread by creating and running a new
event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts
for better performance and resource efficiency.
This is a synchronous wrapper around aexecute_a2a_delegation that creates a
new event loop to run the async implementation. It is provided for compatibility
with synchronous code paths only.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
task_description: The task to delegate.
context: Optional context information.
context_id: Context ID for correlating messages/tasks.
task_id: Specific task identifier.
reference_task_ids: List of related task IDs.
metadata: Additional metadata.
extensions: Protocol extensions for custom fields.
conversation_history: Previous Message objects from conversation.
agent_id: Agent identifier for logging.
agent_role: Role of the CrewAI agent delegating the task.
agent_branch: Optional agent tree branch for logging.
response_model: Optional Pydantic model for structured outputs.
turn_number: Optional turn number for multi-turn conversations.
updates: Update mechanism config from A2AConfig.updates.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
input_files: Optional dictionary of files to send to remote agent.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
Raises:
RuntimeError: If called from an async context with a running event loop.
"""
try:
asyncio.get_running_loop()
raise RuntimeError(
"execute_a2a_delegation() cannot be called from an async context. "
"Use 'await aexecute_a2a_delegation()' instead."
)
except RuntimeError as e:
if "no running event loop" not in str(e).lower():
raise
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
aexecute_a2a_delegation(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_role=agent_role,
agent_branch=agent_branch,
response_model=response_model,
turn_number=turn_number,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
input_files=input_files,
)
)
finally:
try:
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
loop.close()
async def aexecute_a2a_delegation(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
context_id: str | None = None,
task_id: str | None = None,
reference_task_ids: list[str] | None = None,
metadata: dict[str, Any] | None = None,
extensions: dict[str, Any] | None = None,
conversation_history: list[Message] | None = None,
agent_id: str | None = None,
agent_role: Role | None = None,
agent_branch: Any | None = None,
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
updates: UpdateConfig | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent asynchronously.
Native async implementation with multi-turn support. Use this when running
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
Args:
endpoint: A2A agent endpoint URL.
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
task_description: The task to delegate.
context: Optional context information.
context_id: Context ID for correlating messages/tasks.
task_id: Specific task identifier.
reference_task_ids: List of related task IDs.
metadata: Additional metadata.
extensions: Protocol extensions for custom fields.
conversation_history: Previous Message objects from conversation.
agent_id: Agent identifier for logging.
agent_role: Role of the CrewAI agent delegating the task.
agent_branch: Optional agent tree branch for logging.
response_model: Optional Pydantic model for structured outputs.
turn_number: Optional turn number for multi-turn conversations.
updates: Update mechanism config from A2AConfig.updates.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
input_files: Optional dictionary of files to send to remote agent.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
"""
if conversation_history is None:
conversation_history = []
is_multiturn = len(conversation_history) > 0
if turn_number is None:
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
try:
result = await _aexecute_a2a_delegation_impl(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
is_multiturn=is_multiturn,
turn_number=turn_number,
agent_branch=agent_branch,
agent_id=agent_id,
agent_role=agent_role,
response_model=response_model,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
input_files=input_files,
)
except Exception as e:
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status="failed",
result=None,
error=str(e),
context_id=context_id,
is_multiturn=is_multiturn,
endpoint=endpoint,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
raise
agent_card_data = result.get("agent_card")
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status=result["status"],
result=result.get("result"),
error=result.get("error"),
context_id=context_id,
is_multiturn=is_multiturn,
endpoint=endpoint,
a2a_agent_name=result.get("a2a_agent_name"),
agent_card=agent_card_data,
provider=agent_card_data.get("provider") if agent_card_data else None,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
return result
async def _aexecute_a2a_delegation_impl(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None,
context_id: str | None,
task_id: str | None,
reference_task_ids: list[str] | None,
metadata: dict[str, Any] | None,
extensions: dict[str, Any] | None,
conversation_history: list[Message],
is_multiturn: bool,
turn_number: int,
agent_branch: Any | None,
agent_id: str | None,
agent_role: str | None,
response_model: type[BaseModel] | None,
updates: UpdateConfig | None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
"""Internal async implementation of A2A delegation."""
if transport is None:
transport = ClientTransportConfig()
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = _auth_store.compute_key("none", endpoint)
_auth_store.set(auth_hash, auth)
agent_card = await _afetch_agent_card_cached(
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
)
validate_auth_against_agent_card(agent_card, auth)
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
if unsupported_exts:
ext_uris = [ext.uri for ext in unsupported_exts]
raise ValueError(
f"Agent requires extensions not supported by client: {ext_uris}"
)
negotiated: NegotiatedTransport | None = None
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
effective_url = endpoint
client_transports: list[str] = (
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
)
try:
negotiated = negotiate_transport(
agent_card=agent_card,
client_supported_transports=client_transports,
client_preferred_transport=transport.preferred,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
effective_transport = negotiated.transport # type: ignore[assignment]
effective_url = negotiated.url
except TransportNegotiationError as e:
logger.warning(
"Transport negotiation failed, using fallback",
extra={
"error": str(e),
"fallback_transport": effective_transport,
"fallback_url": effective_url,
"endpoint": endpoint,
"client_transports": client_transports,
"server_transports": [
iface.transport for iface in agent_card.additional_interfaces or []
]
+ [agent_card.preferred_transport or "JSONRPC"],
},
)
effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy()
content_negotiated = negotiate_content_types(
agent_card=agent_card,
client_output_modes=accepted_output_modes,
skill_name=skill_id,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
if content_negotiated.output_modes:
effective_output_modes = content_negotiated.output_modes
headers, _ = await _prepare_auth_headers(auth, timeout)
a2a_agent_name = None
if agent_card.name:
a2a_agent_name = agent_card.name
agent_card_dict = agent_card.model_dump(exclude_none=True)
crewai_event_bus.emit(
agent_branch,
A2ADelegationStartedEvent(
endpoint=endpoint,
task_description=task_description,
agent_id=agent_id or endpoint,
context_id=context_id,
is_multiturn=is_multiturn,
turn_number=turn_number,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
if turn_number == 1:
agent_id_for_event = agent_id or endpoint
crewai_event_bus.emit(
agent_branch,
A2AConversationStartedEvent(
agent_id=agent_id_for_event,
endpoint=endpoint,
context_id=context_id,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
message_parts = []
if context:
message_parts.append(f"Context:\n{context}\n\n")
message_parts.append(f"{task_description}")
message_text = "".join(message_parts)
if is_multiturn and conversation_history and not task_id:
if first_task_id := conversation_history[0].task_id:
task_id = first_task_id
parts: PartsDict = {"text": message_text}
if response_model:
parts.update(
{
"metadata": PartsMetadataDict(
mimeType="application/json",
schema=response_model.model_json_schema(),
)
}
)
message_metadata = metadata.copy() if metadata else {}
if skill_id:
message_metadata["skill_id"] = skill_id
parts_list: list[Part] = [Part(root=TextPart(**parts))]
parts_list.extend(_create_file_parts(input_files))
message = Message(
role=Role.user,
message_id=str(uuid.uuid4()),
parts=parts_list,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=message_metadata if message_metadata else None,
extensions=extensions,
)
new_messages: list[Message] = [*conversation_history, message]
crewai_event_bus.emit(
None,
A2AMessageSentEvent(
message=message_text,
turn_number=turn_number,
context_id=context_id,
message_id=message.message_id,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
skill_id=skill_id,
metadata=message_metadata if message_metadata else None,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
handler = get_handler(updates)
use_polling = isinstance(updates, PollingConfig)
handler_kwargs: dict[str, Any] = {
"turn_number": turn_number,
"is_multiturn": is_multiturn,
"agent_role": agent_role,
"context_id": context_id,
"task_id": task_id,
"endpoint": endpoint,
"agent_branch": agent_branch,
"a2a_agent_name": a2a_agent_name,
"from_task": from_task,
"from_agent": from_agent,
}
if isinstance(updates, PollingConfig):
handler_kwargs.update(
{
"polling_interval": updates.interval,
"polling_timeout": updates.timeout or float(timeout),
"history_length": updates.history_length,
"max_polls": updates.max_polls,
}
)
elif isinstance(updates, PushNotificationConfig):
handler_kwargs.update(
{
"config": updates,
"result_store": updates.result_store,
"polling_timeout": updates.timeout or float(timeout),
"polling_interval": updates.interval,
}
)
push_config_for_client = (
updates if isinstance(updates, PushNotificationConfig) else None
)
use_streaming = not use_polling and push_config_for_client is None
client_agent_card = agent_card
if effective_url != agent_card.url:
client_agent_card = agent_card.model_copy(update={"url": effective_url})
async with _create_a2a_client(
agent_card=client_agent_card,
transport_protocol=effective_transport,
timeout=timeout,
headers=headers,
streaming=use_streaming,
auth=auth,
use_polling=use_polling,
push_notification_config=push_config_for_client,
client_extensions=client_extensions,
accepted_output_modes=effective_output_modes, # type: ignore[arg-type]
grpc_config=transport.grpc,
) as client:
result = await handler.execute(
client=client,
message=message,
new_messages=new_messages,
agent_card=agent_card,
**handler_kwargs,
)
result["a2a_agent_name"] = a2a_agent_name
result["agent_card"] = agent_card.model_dump(exclude_none=True)
return result
def _normalize_grpc_metadata(
metadata: tuple[tuple[str, str], ...] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Lowercase all gRPC metadata keys.
gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK)
use mixed-case headers like 'X-A2A-Extensions'. This normalizes them.
"""
if metadata is None:
return None
return tuple((key.lower(), value) for key, value in metadata)
def _create_grpc_interceptors(
auth_metadata: list[tuple[str, str]] | None = None,
) -> list[Any]:
"""Create gRPC interceptors for metadata normalization and auth injection.
Args:
auth_metadata: Optional auth metadata to inject into all calls.
Used for insecure channels that need auth (non-localhost without TLS).
Returns a list of interceptors that lowercase metadata keys for gRPC
compatibility. Must be called after grpc is imported.
"""
import grpc.aio # type: ignore[import-untyped]
def _merge_metadata(
existing: tuple[tuple[str, str], ...] | None,
auth: list[tuple[str, str]] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Merge existing metadata with auth metadata and normalize keys."""
merged: list[tuple[str, str]] = []
if existing:
merged.extend(existing)
if auth:
merged.extend(auth)
if not merged:
return None
return tuple((key.lower(), value) for key, value in merged)
def _inject_metadata(client_call_details: Any) -> Any:
"""Inject merged metadata into call details."""
return client_call_details._replace(
metadata=_merge_metadata(client_call_details.metadata, auth_metadata)
)
class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-unary calls that injects auth metadata."""
async def intercept_unary_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-unary call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-stream calls that injects auth metadata."""
async def intercept_unary_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-stream call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-unary calls that injects auth metadata."""
async def intercept_stream_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-unary call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-stream calls that injects auth metadata."""
async def intercept_stream_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-stream call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
return [
MetadataUnaryUnary(),
MetadataUnaryStream(),
MetadataStreamUnary(),
MetadataStreamStream(),
]
def _create_grpc_channel_factory(
grpc_config: GRPCClientConfig,
auth: ClientAuthScheme | None = None,
) -> Callable[[str], Any]:
"""Create a gRPC channel factory with the given configuration.
Args:
grpc_config: gRPC client configuration with channel options.
auth: Optional ClientAuthScheme for TLS and auth configuration.
Returns:
A callable that creates gRPC channels from URLs.
"""
try:
import grpc
except ImportError as e:
raise ImportError(
"gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
) from e
auth_metadata: list[tuple[str, str]] = []
if auth is not None:
from crewai_a2a.auth.client_schemes import (
APIKeyAuth,
BearerTokenAuth,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
)
if isinstance(auth, HTTPDigestAuth):
raise ValueError(
"HTTPDigestAuth is not supported with gRPC transport. "
"Digest authentication requires HTTP challenge-response flow. "
"Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
)
if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
raise ValueError(
f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
"gRPC only supports header-based authentication. "
"Use APIKeyAuth with location='header' instead."
)
if isinstance(auth, BearerTokenAuth):
auth_metadata.append(("authorization", f"Bearer {auth.token}"))
elif isinstance(auth, HTTPBasicAuth):
import base64
basic_credentials = f"{auth.username}:{auth.password}"
encoded = base64.b64encode(basic_credentials.encode()).decode()
auth_metadata.append(("authorization", f"Basic {encoded}"))
elif isinstance(auth, APIKeyAuth) and auth.location == "header":
header_name = auth.name.lower()
auth_metadata.append((header_name, auth.api_key))
elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
if auth._access_token:
auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))
def factory(url: str) -> Any:
"""Create a gRPC channel for the given URL."""
target = url
use_tls = False
if url.startswith("grpcs://"):
target = url[8:]
use_tls = True
elif url.startswith("grpc://"):
target = url[7:]
elif url.startswith("https://"):
target = url[8:]
use_tls = True
elif url.startswith("http://"):
target = url[7:]
options: list[tuple[str, Any]] = []
if grpc_config.max_send_message_length is not None:
options.append(
("grpc.max_send_message_length", grpc_config.max_send_message_length)
)
if grpc_config.max_receive_message_length is not None:
options.append(
(
"grpc.max_receive_message_length",
grpc_config.max_receive_message_length,
)
)
if grpc_config.keepalive_time_ms is not None:
options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
if grpc_config.keepalive_timeout_ms is not None:
options.append(
("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
)
channel_credentials = None
if auth and hasattr(auth, "tls") and auth.tls:
channel_credentials = auth.tls.get_grpc_credentials()
elif use_tls:
channel_credentials = grpc.ssl_channel_credentials()
if channel_credentials and auth_metadata:
class AuthMetadataPlugin(grpc.AuthMetadataPlugin): # type: ignore[misc,no-any-unimported]
"""gRPC auth metadata plugin that adds auth headers as metadata."""
def __init__(self, metadata: list[tuple[str, str]]) -> None:
self._metadata = tuple(metadata)
def __call__( # type: ignore[no-any-unimported]
self,
context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
) -> None:
callback(self._metadata, None)
call_creds = grpc.metadata_call_credentials(
AuthMetadataPlugin(auth_metadata)
)
credentials = grpc.composite_channel_credentials(
channel_credentials, call_creds
)
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target, credentials, options=options or None, interceptors=interceptors
)
if channel_credentials:
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target,
channel_credentials,
options=options or None,
interceptors=interceptors,
)
interceptors = _create_grpc_interceptors(
auth_metadata=auth_metadata if auth_metadata else None
)
return grpc.aio.insecure_channel(
target, options=options or None, interceptors=interceptors
)
return factory
@asynccontextmanager
async def _create_a2a_client(
agent_card: AgentCard,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
timeout: int,
headers: MutableMapping[str, str],
streaming: bool,
auth: ClientAuthScheme | None = None,
use_polling: bool = False,
push_notification_config: PushNotificationConfig | None = None,
client_extensions: list[str] | None = None,
accepted_output_modes: list[str] | None = None,
grpc_config: GRPCClientConfig | None = None,
) -> AsyncIterator[Client]:
"""Create and configure an A2A client.
Args:
agent_card: The A2A agent card.
transport_protocol: Transport protocol to use.
timeout: Request timeout in seconds.
headers: HTTP headers (already with auth applied).
streaming: Enable streaming responses.
auth: Optional ClientAuthScheme for client configuration.
use_polling: Enable polling mode.
push_notification_config: Optional push notification config.
client_extensions: A2A protocol extension URIs to declare support for.
accepted_output_modes: MIME types the client can accept in responses.
grpc_config: Optional gRPC client configuration.
Yields:
Configured A2A client instance.
"""
verify = _get_tls_verify(auth)
async with httpx.AsyncClient(
timeout=timeout,
headers=headers,
verify=verify,
) as httpx_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, httpx_client)
push_configs: list[A2APushNotificationConfig] = []
if push_notification_config is not None:
push_configs.append(
A2APushNotificationConfig(
url=str(push_notification_config.url),
id=push_notification_config.id,
token=push_notification_config.token,
authentication=push_notification_config.authentication,
)
)
grpc_channel_factory = None
if transport_protocol == "GRPC":
grpc_channel_factory = _create_grpc_channel_factory(
grpc_config or GRPCClientConfig(),
auth=auth,
)
config = ClientConfig(
httpx_client=httpx_client,
supported_transports=[transport_protocol],
streaming=streaming and not use_polling,
polling=use_polling,
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
push_notification_configs=push_configs,
grpc_channel_factory=grpc_channel_factory,
)
factory = ClientFactory(config)
client = factory.create(agent_card)
if client_extensions:
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
yield client

View File

@@ -0,0 +1,131 @@
"""Structured JSON logging utilities for A2A module."""
from __future__ import annotations
from contextvars import ContextVar
from datetime import datetime, timezone
import json
import logging
from typing import Any
_log_context: ContextVar[dict[str, Any] | None] = ContextVar(
"log_context", default=None
)
class JSONFormatter(logging.Formatter):
"""JSON formatter for structured logging.
Outputs logs as JSON with consistent fields for log aggregators.
"""
def format(self, record: logging.LogRecord) -> str:
"""Format log record as JSON string."""
log_data: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
if record.exc_info:
log_data["exception"] = self.formatException(record.exc_info)
context = _log_context.get()
if context is not None:
log_data.update(context)
if hasattr(record, "task_id"):
log_data["task_id"] = record.task_id
if hasattr(record, "context_id"):
log_data["context_id"] = record.context_id
if hasattr(record, "agent"):
log_data["agent"] = record.agent
if hasattr(record, "endpoint"):
log_data["endpoint"] = record.endpoint
if hasattr(record, "extension"):
log_data["extension"] = record.extension
if hasattr(record, "error"):
log_data["error"] = record.error
for key, value in record.__dict__.items():
if key.startswith("_") or key in (
"name",
"msg",
"args",
"created",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"exc_info",
"exc_text",
"thread",
"threadName",
"taskName",
"message",
):
continue
if key not in log_data:
log_data[key] = value
return json.dumps(log_data, default=str)
class LogContext:
"""Context manager for adding fields to all logs within a scope.
Example:
with LogContext(task_id="abc", context_id="xyz"):
logger.info("Processing task") # Includes task_id and context_id
"""
def __init__(self, **fields: Any) -> None:
self._fields = fields
self._token: Any = None
def __enter__(self) -> LogContext:
current = _log_context.get() or {}
new_context = {**current, **self._fields}
self._token = _log_context.set(new_context)
return self
def __exit__(self, *args: Any) -> None:
_log_context.reset(self._token)
def configure_json_logging(logger_name: str = "crewai.a2a") -> None:
"""Configure JSON logging for the A2A module.
Args:
logger_name: Logger name to configure.
"""
logger = logging.getLogger(logger_name)
for handler in logger.handlers[:]:
logger.removeHandler(handler)
handler = logging.StreamHandler()
handler.setFormatter(JSONFormatter())
logger.addHandler(handler)
def get_logger(name: str) -> logging.Logger:
"""Get a logger configured for structured JSON output.
Args:
name: Logger name.
Returns:
Configured logger instance.
"""
return logging.getLogger(name)

View File

@@ -0,0 +1,101 @@
"""Response model utilities for A2A agent interactions."""
from __future__ import annotations
from typing import TypeAlias
from crewai.types.utils import create_literals_from_strings
from pydantic import BaseModel, Field, create_model
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None:
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
Args:
agent_ids: List of available A2A agent IDs.
Returns:
Dynamically created Pydantic model with Literal-constrained a2a_ids field,
or None if agent_ids is empty.
"""
if not agent_ids:
return None
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
return create_model(
"AgentResponse",
a2a_ids=(
tuple[DynamicLiteral, ...], # type: ignore[valid-type]
Field(
default_factory=tuple,
max_length=len(agent_ids),
description="A2A agent IDs to delegate to.",
),
),
message=(
str,
Field(
description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
),
),
is_a2a=(
bool,
Field(
description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately."
),
),
__base__=BaseModel,
)
def extract_a2a_agent_ids_from_config(
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]:
"""Extract A2A agent IDs from A2A configuration.
Filters out A2AServerConfig since it doesn't have an endpoint for delegation.
Args:
a2a_config: A2A configuration (any type).
Returns:
Tuple of client A2A configs list and agent endpoint IDs.
"""
if a2a_config is None:
return [], ()
configs: list[A2AConfigTypes]
if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)):
configs = [a2a_config]
else:
configs = a2a_config
# Filter to only client configs (those with endpoint)
client_configs: list[A2AClientConfigTypes] = [
config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
]
return client_configs, tuple(config.endpoint for config in client_configs)
def get_a2a_agents_and_response_model(
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]:
"""Get A2A agent configs and response model.
Args:
a2a_config: A2A configuration (any type).
Returns:
Tuple of client A2A configs and response model.
"""
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
return a2a_agents, create_agent_response_model(agent_ids)

View File

@@ -0,0 +1,585 @@
"""A2A task utilities for server-side task management."""
from __future__ import annotations
import asyncio
import base64
from collections.abc import Callable, Coroutine
from datetime import datetime
from functools import wraps
import json
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast
from urllib.parse import urlparse
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import (
Artifact,
FileWithBytes,
FileWithUri,
InternalError,
InvalidParamsError,
Message,
Part,
Task as A2ATask,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils import (
get_data_parts,
get_file_parts,
new_agent_text_message,
new_data_artifact,
new_text_artifact,
)
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AServerTaskCanceledEvent,
A2AServerTaskCompletedEvent,
A2AServerTaskFailedEvent,
A2AServerTaskStartedEvent,
)
from crewai.task import Task
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
from pydantic import BaseModel
from crewai_a2a.utils.agent_card import _get_server_config
from crewai_a2a.utils.content_type import validate_message_parts
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai_a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
class RedisCacheConfig(TypedDict, total=False):
"""Configuration for aiocache Redis backend."""
cache: str
endpoint: str
port: int
db: int
password: str
def _parse_redis_url(url: str) -> RedisCacheConfig:
"""Parse a Redis URL into aiocache configuration.
Args:
url: Redis connection URL (e.g., redis://localhost:6379/0).
Returns:
Configuration dict for aiocache.RedisCache.
"""
parsed = urlparse(url)
config: RedisCacheConfig = {
"cache": "aiocache.RedisCache",
"endpoint": parsed.hostname or "localhost",
"port": parsed.port or 6379,
}
if parsed.path and parsed.path != "/":
try:
config["db"] = int(parsed.path.lstrip("/"))
except ValueError:
pass
if parsed.password:
config["password"] = parsed.password
return config
_redis_url = os.environ.get("REDIS_URL")
caches.set_config(
{
"default": _parse_redis_url(_redis_url)
if _redis_url
else {
"cache": "aiocache.SimpleMemoryCache",
}
}
)
def cancellable(
fn: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T]]:
"""Decorator that enables cancellation for A2A task execution.
Runs a cancellation watcher concurrently with the wrapped function.
When a cancel event is published, the execution is cancelled.
Args:
fn: The async function to wrap.
Returns:
Wrapped function with cancellation support.
"""
@wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrap function with cancellation monitoring."""
context: RequestContext | None = None
for arg in args:
if isinstance(arg, RequestContext):
context = arg
break
if context is None:
context = cast(RequestContext | None, kwargs.get("context"))
if context is None:
return await fn(*args, **kwargs)
task_id = context.task_id
cache = caches.get("default")
async def poll_for_cancel() -> bool:
"""Poll cache for cancellation flag."""
while True:
if await cache.get(f"cancel:{task_id}"):
return True
await asyncio.sleep(0.1)
async def watch_for_cancel() -> bool:
"""Watch for cancellation events via pub/sub or polling."""
if isinstance(cache, SimpleMemoryCache):
return await poll_for_cancel()
try:
client = cache.client
pubsub = client.pubsub()
await pubsub.subscribe(f"cancel:{task_id}")
async for message in pubsub.listen():
if message["type"] == "message":
return True
except (OSError, ConnectionError) as e:
logger.warning(
"Cancel watcher Redis error, falling back to polling",
extra={"task_id": task_id, "error": str(e)},
)
return await poll_for_cancel()
return False
execute_task = asyncio.create_task(fn(*args, **kwargs))
cancel_watch = asyncio.create_task(watch_for_cancel())
try:
done, _ = await asyncio.wait(
[execute_task, cancel_watch],
return_when=asyncio.FIRST_COMPLETED,
)
if cancel_watch in done:
execute_task.cancel()
try:
await execute_task
except asyncio.CancelledError:
pass
raise asyncio.CancelledError(f"Task {task_id} was cancelled")
cancel_watch.cancel()
return execute_task.result()
finally:
await cache.delete(f"cancel:{task_id}")
return wrapper
def _convert_a2a_files_to_file_inputs(
a2a_files: list[FileWithBytes | FileWithUri],
) -> dict[str, Any]:
"""Convert a2a file types to crewai FileInput dict.
Args:
a2a_files: List of FileWithBytes or FileWithUri from a2a SDK.
Returns:
Dictionary mapping file names to FileInput objects.
"""
try:
from crewai_files import File, FileBytes
except ImportError:
logger.debug("crewai_files not installed, returning empty file dict")
return {}
file_dict: dict[str, Any] = {}
for idx, a2a_file in enumerate(a2a_files):
if isinstance(a2a_file, FileWithBytes):
file_bytes = base64.b64decode(a2a_file.bytes)
name = a2a_file.name or f"file_{idx}"
file_source = FileBytes(data=file_bytes, filename=a2a_file.name)
file_dict[name] = File(source=file_source)
elif isinstance(a2a_file, FileWithUri):
name = a2a_file.name or f"file_{idx}"
file_dict[name] = File(source=a2a_file.uri)
return file_dict
def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
"""Extract response schema from message parts metadata.
The client may include a JSON schema in TextPart metadata to specify
the expected response format (see delegation.py line 463).
Args:
parts: List of message parts.
Returns:
JSON schema dict if found, None otherwise.
"""
for part in parts:
if part.root.kind == "text" and part.root.metadata:
schema = part.root.metadata.get("schema")
if schema and isinstance(schema, dict):
return schema # type: ignore[no-any-return]
return None
def _create_result_artifact(
result: Any,
task_id: str,
) -> Artifact:
"""Create artifact from task result, using DataPart for structured data.
Args:
result: The task execution result.
task_id: The task ID for naming the artifact.
Returns:
Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings).
"""
artifact_name = f"result_{task_id}"
if isinstance(result, dict):
return new_data_artifact(artifact_name, result)
if isinstance(result, BaseModel):
return new_data_artifact(artifact_name, result.model_dump())
return new_text_artifact(artifact_name, str(result))
def _build_task_description(
user_message: str,
structured_inputs: list[dict[str, Any]],
) -> str:
"""Build task description including structured data if present.
Args:
user_message: The original user message text.
structured_inputs: List of structured data from DataParts.
Returns:
Task description with structured data appended if present.
"""
if not structured_inputs:
return user_message
structured_json = json.dumps(structured_inputs, indent=2)
return f"{user_message}\n\nStructured Data:\n{structured_json}"
async def execute(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
) -> None:
"""Execute an A2A task using a CrewAI agent.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
"""
await _execute_impl(agent, context, event_queue, None, None)
@cancellable
async def _execute_impl(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry | None,
extension_context: ExtensionContext | None,
) -> None:
"""Internal implementation for task execution with optional extensions."""
server_config = _get_server_config(agent)
if context.message and context.message.parts and server_config:
allowed_modes = server_config.default_input_modes
invalid_types = validate_message_parts(context.message.parts, allowed_modes)
if invalid_types:
raise ServerError(
InvalidParamsError(
message=f"Unsupported content type(s): {', '.join(invalid_types)}. "
f"Supported: {', '.join(allowed_modes)}"
)
)
if extension_registry and extension_context:
await extension_registry.invoke_on_request(extension_context)
user_message = context.get_user_input()
response_model: type[BaseModel] | None = None
structured_inputs: list[dict[str, Any]] = []
a2a_files: list[FileWithBytes | FileWithUri] = []
if context.message and context.message.parts:
schema = _extract_response_schema(context.message.parts)
if schema:
try:
response_model = create_model_from_schema(schema)
except Exception as e:
logger.debug(
"Failed to create response model from schema",
extra={"error": str(e), "schema_title": schema.get("title")},
)
structured_inputs = get_data_parts(context.message.parts)
a2a_files = get_file_parts(context.message.parts)
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
msg = "task_id and context_id are required"
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(
task_id="",
context_id="",
error=msg,
from_agent=agent,
),
)
raise ServerError(InvalidParamsError(message=msg)) from None
task = Task(
description=_build_task_description(user_message, structured_inputs),
expected_output="Response to the user's request",
agent=agent,
response_model=response_model,
input_files=_convert_a2a_files_to_file_inputs(a2a_files),
)
crewai_event_bus.emit(
agent,
A2AServerTaskStartedEvent(
task_id=task_id,
context_id=context_id,
from_task=task,
from_agent=agent,
),
)
try:
result = await agent.aexecute_task(task=task, tools=agent.tools)
if extension_registry and extension_context:
result = await extension_registry.invoke_on_response(
extension_context, result
)
result_str = str(result)
history: list[Message] = [context.message] if context.message else []
history.append(new_agent_text_message(result_str, context_id, task_id))
await event_queue.enqueue_event(
A2ATask(
id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.completed),
artifacts=[_create_result_artifact(result, task_id)],
history=history,
)
)
crewai_event_bus.emit(
agent,
A2AServerTaskCompletedEvent(
task_id=task_id,
context_id=context_id,
result=str(result),
from_task=task,
from_agent=agent,
),
)
except asyncio.CancelledError:
crewai_event_bus.emit(
agent,
A2AServerTaskCanceledEvent(
task_id=task_id,
context_id=context_id,
from_task=task,
from_agent=agent,
),
)
raise
except Exception as e:
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(
task_id=task_id,
context_id=context_id,
error=str(e),
from_task=task,
from_agent=agent,
),
)
raise ServerError(
error=InternalError(message=f"Task execution failed: {e}")
) from e
async def execute_with_extensions(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry,
extension_context: ExtensionContext,
) -> None:
"""Execute an A2A task with extension hooks.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
extension_registry: Registry of server extensions.
extension_context: Context for extension hooks.
"""
await _execute_impl(
agent, context, event_queue, extension_registry, extension_context
)
async def cancel(
context: RequestContext,
event_queue: EventQueue,
) -> A2ATask | None:
"""Cancel an A2A task.
Publishes a cancel event that the cancellable decorator listens for.
Args:
context: The A2A request context containing task information.
event_queue: The event queue for sending the cancellation status.
Returns:
The canceled task with updated status.
"""
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
if context.current_task and context.current_task.status.state in (
TaskState.completed,
TaskState.failed,
TaskState.canceled,
):
return context.current_task
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.canceled),
final=True,
)
)
if context.current_task:
context.current_task.status = TaskStatus(state=TaskState.canceled)
return context.current_task
return None
def list_tasks(
tasks: list[A2ATask],
context_id: str | None = None,
status: TaskState | None = None,
status_timestamp_after: datetime | None = None,
page_size: int = 50,
page_token: str | None = None,
history_length: int | None = None,
include_artifacts: bool = False,
) -> tuple[list[A2ATask], str | None, int]:
"""Filter and paginate A2A tasks.
Provides filtering by context, status, and timestamp, along with
cursor-based pagination. This is a pure utility function that operates
on an in-memory list of tasks - storage retrieval is handled separately.
Args:
tasks: All tasks to filter.
context_id: Filter by context ID to get tasks in a conversation.
status: Filter by task state (e.g., completed, working).
status_timestamp_after: Filter to tasks updated after this time.
page_size: Maximum tasks per page (default 50).
page_token: Base64-encoded cursor from previous response.
history_length: Limit history messages per task (None = full history).
include_artifacts: Whether to include task artifacts (default False).
Returns:
Tuple of (filtered_tasks, next_page_token, total_count).
- filtered_tasks: Tasks matching filters, paginated and trimmed.
- next_page_token: Token for next page, or None if no more pages.
- total_count: Total number of tasks matching filters (before pagination).
"""
filtered: list[A2ATask] = []
for task in tasks:
if context_id and task.context_id != context_id:
continue
if status and task.status.state != status:
continue
if status_timestamp_after and task.status.timestamp:
ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00"))
if ts <= status_timestamp_after:
continue
filtered.append(task)
def get_timestamp(t: A2ATask) -> datetime:
"""Extract timestamp from task status for sorting."""
if t.status.timestamp is None:
return datetime.min
return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00"))
filtered.sort(key=get_timestamp, reverse=True)
total = len(filtered)
start = 0
if page_token:
try:
cursor_id = base64.b64decode(page_token).decode()
for idx, task in enumerate(filtered):
if task.id == cursor_id:
start = idx + 1
break
except (ValueError, UnicodeDecodeError):
pass
page = filtered[start : start + page_size]
result: list[A2ATask] = []
for task in page:
task = task.model_copy(deep=True)
if history_length is not None and task.history:
task.history = task.history[-history_length:]
if not include_artifacts:
task.artifacts = None
result.append(task)
next_token: str | None = None
if result and len(result) == page_size:
next_token = base64.b64encode(result[-1].id.encode()).decode()
return result, next_token, total

View File

@@ -0,0 +1,214 @@
"""Transport negotiation utilities for A2A protocol.
This module provides functionality for negotiating the transport protocol
between an A2A client and server based on their respective capabilities
and preferences.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Final, Literal
from a2a.types import AgentCard, AgentInterface
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"]
JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC"
GRPC_TRANSPORT: Literal["GRPC"] = "GRPC"
HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON"
DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [
JSONRPC_TRANSPORT,
GRPC_TRANSPORT,
HTTP_JSON_TRANSPORT,
]
@dataclass
class NegotiatedTransport:
"""Result of transport negotiation.
Attributes:
transport: The negotiated transport protocol.
url: The URL to use for this transport.
source: How the transport was selected ('preferred', 'additional', 'fallback').
"""
transport: str
url: str
source: NegotiationSource
class TransportNegotiationError(Exception):
"""Raised when no compatible transport can be negotiated."""
def __init__(
self,
client_transports: list[str],
server_transports: list[str],
message: str | None = None,
) -> None:
"""Initialize the error with negotiation details.
Args:
client_transports: Transports supported by the client.
server_transports: Transports supported by the server.
message: Optional custom error message.
"""
self.client_transports = client_transports
self.server_transports = server_transports
if message is None:
message = (
f"No compatible transport found. "
f"Client supports: {client_transports}. "
f"Server supports: {server_transports}."
)
super().__init__(message)
def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
"""Extract all available interfaces from an AgentCard.
Creates a unified list of interfaces including the primary URL and
any additional interfaces declared by the agent.
Args:
agent_card: The agent's card containing transport information.
Returns:
List of AgentInterface objects representing all available endpoints.
"""
interfaces: list[AgentInterface] = []
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
interfaces.append(
AgentInterface(
transport=primary_transport,
url=agent_card.url,
)
)
if agent_card.additional_interfaces:
for interface in agent_card.additional_interfaces:
is_duplicate = any(
i.url == interface.url and i.transport == interface.transport
for i in interfaces
)
if not is_duplicate:
interfaces.append(interface)
return interfaces
def negotiate_transport(
agent_card: AgentCard,
client_supported_transports: list[str] | None = None,
client_preferred_transport: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> NegotiatedTransport:
"""Negotiate the transport protocol between client and server.
Compares the client's supported transports with the server's available
interfaces to find a compatible transport and URL.
Negotiation logic:
1. If client_preferred_transport is set and server supports it → use it
2. Otherwise, if server's preferred is in client's supported → use server's
3. Otherwise, find first match from client's supported in server's interfaces
Args:
agent_card: The server's AgentCard with transport information.
client_supported_transports: Transports the client can use.
Defaults to ["JSONRPC"] if not specified.
client_preferred_transport: Client's preferred transport. If set and
server supports it, takes priority over server preference.
emit_event: Whether to emit a transport negotiation event.
endpoint: Original endpoint URL for event metadata.
a2a_agent_name: Agent name for event metadata.
Returns:
NegotiatedTransport with the selected transport, URL, and source.
Raises:
TransportNegotiationError: If no compatible transport is found.
"""
if client_supported_transports is None:
client_supported_transports = [JSONRPC_TRANSPORT]
client_transports = [t.upper() for t in client_supported_transports]
client_preferred = (
client_preferred_transport.upper() if client_preferred_transport else None
)
server_interfaces = _get_server_interfaces(agent_card)
server_transports = [i.transport.upper() for i in server_interfaces]
transport_to_interface: dict[str, AgentInterface] = {}
for interface in server_interfaces:
transport_upper = interface.transport.upper()
if transport_upper not in transport_to_interface:
transport_to_interface[transport_upper] = interface
result: NegotiatedTransport | None = None
if client_preferred and client_preferred in transport_to_interface:
interface = transport_to_interface[client_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="client_preferred",
)
else:
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
if (
server_preferred in client_transports
and server_preferred in transport_to_interface
):
interface = transport_to_interface[server_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="server_preferred",
)
else:
for transport in client_transports:
if transport in transport_to_interface:
interface = transport_to_interface[transport]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="fallback",
)
break
if result is None:
raise TransportNegotiationError(
client_transports=client_transports,
server_transports=server_transports,
)
if emit_event:
crewai_event_bus.emit(
None,
A2ATransportNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
negotiated_transport=result.transport,
negotiated_url=result.url,
source=result.source,
client_supported_transports=client_transports,
server_supported_transports=server_transports,
server_preferred_transport=agent_card.preferred_transport
or JSONRPC_TRANSPORT,
client_preferred_transport=client_preferred,
),
)
return result

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,44 @@
interactions:
- request:
body: ''
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
host:
- localhost:9999
method: GET
uri: http://localhost:9999/.well-known/agent-card.json
response:
body:
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
perform calculations, or get the current time in any timezone.","name":"GPT
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
a general conversation with the AI assistant. Ask questions, get explanations,
or just chat.","examples":["Hello, how are you?","Explain quantum computing
in simple terms","What can you help me with?"],"id":"conversation","name":"General
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
mathematical calculations including arithmetic, exponents, and more.","examples":["What
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
the current date and time in any timezone.","examples":["What time is it?","What''s
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
headers:
content-length:
- '1198'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:17:00 GMT
server:
- uvicorn
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,126 @@
interactions:
- request:
body: ''
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
host:
- localhost:9999
method: GET
uri: http://localhost:9999/.well-known/agent-card.json
response:
body:
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
perform calculations, or get the current time in any timezone.","name":"GPT
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
a general conversation with the AI assistant. Ask questions, get explanations,
or just chat.","examples":["Hello, how are you?","Explain quantum computing
in simple terms","What can you help me with?"],"id":"conversation","name":"General
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
mathematical calculations including arithmetic, exponents, and more.","examples":["What
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
the current date and time in any timezone.","examples":["What time is it?","What''s
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
headers:
content-length:
- '1198'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:16:58 GMT
server:
- uvicorn
status:
code: 200
message: OK
- request:
body: '{"id":"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"e1e63c75-3ea0-49fb-b512-5128a2476416","parts":[{"kind":"text","text":"What
is 2 + 2?"}],"role":"user"}}}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*, text/event-stream'
accept-encoding:
- ACCEPT-ENCODING-XXX
cache-control:
- no-store
connection:
- keep-alive
content-length:
- '301'
content-type:
- application/json
host:
- localhost:9999
method: POST
uri: http://localhost:9999/
response:
body:
string: "data: {\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\ndata:
{\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\ndata:
{\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"54bb7ff3-f2c0-4eb3-b427-bf1c8cf90832\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool:
calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\n"
headers:
Transfer-Encoding:
- chunked
cache-control:
- no-store
connection:
- keep-alive
content-type:
- text/event-stream; charset=utf-8
date:
- Tue, 06 Jan 2026 14:16:58 GMT
server:
- uvicorn
x-accel-buffering:
- 'no'
status:
code: 200
message: OK
- request:
body: '{"id":"cb1e4af3-d2d0-4848-96b8-7082ee6171d1","jsonrpc":"2.0","method":"tasks/get","params":{"historyLength":100,"id":"0dd4d3af-f35d-409d-9462-01218e5641f9"}}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
content-length:
- '157'
content-type:
- application/json
host:
- localhost:9999
method: POST
uri: http://localhost:9999/
response:
body:
string: '{"id":"cb1e4af3-d2d0-4848-96b8-7082ee6171d1","jsonrpc":"2.0","result":{"contextId":"b9e14c1b-734d-4d1e-864a-e6dda5231d71","history":[{"contextId":"b9e14c1b-734d-4d1e-864a-e6dda5231d71","kind":"message","messageId":"e1e63c75-3ea0-49fb-b512-5128a2476416","parts":[{"kind":"text","text":"What
is 2 + 2?"}],"role":"user","taskId":"0dd4d3af-f35d-409d-9462-01218e5641f9"}],"id":"0dd4d3af-f35d-409d-9462-01218e5641f9","kind":"task","status":{"message":{"kind":"message","messageId":"54bb7ff3-f2c0-4eb3-b427-bf1c8cf90832","parts":[{"kind":"text","text":"\n[Tool:
calculator] 2 + 2 = 4\n2 + 2 equals 4."}],"role":"agent"},"state":"completed"}}}'
headers:
content-length:
- '635'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:17:00 GMT
server:
- uvicorn
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,90 @@
interactions:
- request:
body: ''
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
host:
- localhost:9999
method: GET
uri: http://localhost:9999/.well-known/agent-card.json
response:
body:
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
perform calculations, or get the current time in any timezone.","name":"GPT
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
a general conversation with the AI assistant. Ask questions, get explanations,
or just chat.","examples":["Hello, how are you?","Explain quantum computing
in simple terms","What can you help me with?"],"id":"conversation","name":"General
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
mathematical calculations including arithmetic, exponents, and more.","examples":["What
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
the current date and time in any timezone.","examples":["What time is it?","What''s
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
headers:
content-length:
- '1198'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:17:02 GMT
server:
- uvicorn
status:
code: 200
message: OK
- request:
body: '{"id":"8cf25b61-8884-4246-adce-fccb32e176ab","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"c145297f-7331-4835-adcc-66b51de92a2b","parts":[{"kind":"text","text":"What
is 2 + 2?"}],"role":"user"}}}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*, text/event-stream'
accept-encoding:
- ACCEPT-ENCODING-XXX
cache-control:
- no-store
connection:
- keep-alive
content-length:
- '301'
content-type:
- application/json
host:
- localhost:9999
method: POST
uri: http://localhost:9999/
response:
body:
string: "data: {\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\ndata:
{\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\ndata:
{\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"25f81e3c-b7e8-48b5-a98a-4066f3637a13\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool:
calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\n"
headers:
Transfer-Encoding:
- chunked
cache-control:
- no-store
connection:
- keep-alive
content-type:
- text/event-stream; charset=utf-8
date:
- Tue, 06 Jan 2026 14:17:02 GMT
server:
- uvicorn
x-accel-buffering:
- 'no'
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,90 @@
interactions:
- request:
body: ''
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
host:
- localhost:9999
method: GET
uri: http://localhost:9999/.well-known/agent-card.json
response:
body:
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
perform calculations, or get the current time in any timezone.","name":"GPT
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
a general conversation with the AI assistant. Ask questions, get explanations,
or just chat.","examples":["Hello, how are you?","Explain quantum computing
in simple terms","What can you help me with?"],"id":"conversation","name":"General
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
mathematical calculations including arithmetic, exponents, and more.","examples":["What
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
the current date and time in any timezone.","examples":["What time is it?","What''s
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
headers:
content-length:
- '1198'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:17:00 GMT
server:
- uvicorn
status:
code: 200
message: OK
- request:
body: '{"id":"3a17c6bf-8db6-45a6-8535-34c45c0c4936","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"712558a3-6d92-4591-be8a-9dd8566dde82","parts":[{"kind":"text","text":"What
is 2 + 2?"}],"role":"user"}}}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*, text/event-stream'
accept-encoding:
- ACCEPT-ENCODING-XXX
cache-control:
- no-store
connection:
- keep-alive
content-length:
- '301'
content-type:
- application/json
host:
- localhost:9999
method: POST
uri: http://localhost:9999/
response:
body:
string: "data: {\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\ndata:
{\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\ndata:
{\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"916324aa-fd25-4849-bceb-c4644e2fcbb0\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool:
calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\n"
headers:
Transfer-Encoding:
- chunked
cache-control:
- no-store
connection:
- keep-alive
content-type:
- text/event-stream; charset=utf-8
date:
- Tue, 06 Jan 2026 14:17:00 GMT
server:
- uvicorn
x-accel-buffering:
- 'no'
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,21 @@
"""Pytest configuration for crewai-a2a tests.
Ensures Agent model is properly rebuilt with A2A types,
which can fail silently during circular import resolution.
"""
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
def pytest_configure() -> None:
"""Rebuild Agent/LiteAgent models after crewai_a2a is fully loaded."""
from crewai.agent.core import Agent
from crewai.lite_agent import LiteAgent
ns = {
"A2AConfig": A2AConfig,
"A2AClientConfig": A2AClientConfig,
"A2AServerConfig": A2AServerConfig,
}
Agent.model_rebuild(_types_namespace=ns)
LiteAgent.model_rebuild(_types_namespace=ns)

View File

@@ -3,15 +3,13 @@ from __future__ import annotations
import os
import uuid
from a2a.client import ClientFactory
from a2a.types import AgentCard, Message, Part, Role, Task, TaskState, TextPart
from crewai_a2a.updates.polling.handler import PollingHandler
from crewai_a2a.updates.streaming.handler import StreamingHandler
import pytest
import pytest_asyncio
from a2a.client import ClientFactory
from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart
from crewai.a2a.updates.polling.handler import PollingHandler
from crewai.a2a.updates.streaming.handler import StreamingHandler
A2A_TEST_ENDPOINT = os.getenv("A2A_TEST_ENDPOINT", "http://localhost:9999")
@@ -162,7 +160,7 @@ class TestA2APushNotificationHandler:
)
@pytest.fixture
def mock_task(self) -> "Task":
def mock_task(self) -> Task:
"""Create a minimal valid task for testing."""
from a2a.types import Task, TaskStatus
@@ -182,11 +180,12 @@ class TestA2APushNotificationHandler:
from unittest.mock import AsyncMock, MagicMock
from a2a.types import Task, TaskStatus
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
from crewai_a2a.updates.push_notifications.handler import (
PushNotificationHandler,
)
from pydantic import AnyHttpUrl
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
completed_task = Task(
id="task-123",
context_id="ctx-123",
@@ -246,11 +245,12 @@ class TestA2APushNotificationHandler:
from unittest.mock import AsyncMock, MagicMock
from a2a.types import Task, TaskStatus
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
from crewai_a2a.updates.push_notifications.handler import (
PushNotificationHandler,
)
from pydantic import AnyHttpUrl
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
mock_store = MagicMock()
mock_store.wait_for_result = AsyncMock(return_value=None)
@@ -303,7 +303,9 @@ class TestA2APushNotificationHandler:
"""Test that push handler fails gracefully without config."""
from unittest.mock import MagicMock
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
from crewai_a2a.updates.push_notifications.handler import (
PushNotificationHandler,
)
mock_client = MagicMock()

View File

@@ -3,10 +3,9 @@
from __future__ import annotations
from a2a.types import AgentCard, AgentSkill
from crewai import Agent
from crewai.a2a.config import A2AClientConfig, A2AServerConfig
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
from crewai_a2a.config import A2AClientConfig, A2AServerConfig
from crewai_a2a.utils.agent_card import inject_a2a_server_methods
class TestInjectA2AServerMethods:

View File

@@ -6,13 +6,12 @@ import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus
from crewai.a2a.utils.task import cancel, cancellable, execute
from crewai_a2a.utils.task import cancel, cancellable, execute
import pytest
import pytest_asyncio
@pytest.fixture
@@ -85,8 +84,11 @@ class TestCancellableDecorator:
assert call_count == 1
@pytest.mark.asyncio
async def test_executes_function_with_context(self, mock_context: MagicMock) -> None:
async def test_executes_function_with_context(
self, mock_context: MagicMock
) -> None:
"""Function executes normally with RequestContext when not cancelled."""
@cancellable
async def my_func(context: RequestContext) -> str:
await asyncio.sleep(0.01)
@@ -134,6 +136,7 @@ class TestCancellableDecorator:
@pytest.mark.asyncio
async def test_extracts_context_from_kwargs(self, mock_context: MagicMock) -> None:
"""Context can be passed as keyword argument."""
@cancellable
async def my_func(value: int, context: RequestContext | None = None) -> int:
return value + 1
@@ -156,8 +159,8 @@ class TestExecute:
) -> None:
"""Execute completes successfully and enqueues completed task."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -175,8 +178,8 @@ class TestExecute:
) -> None:
"""Execute emits A2AServerTaskStartedEvent."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -197,8 +200,8 @@ class TestExecute:
) -> None:
"""Execute emits A2AServerTaskCompletedEvent on success."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -221,8 +224,8 @@ class TestExecute:
mock_agent.aexecute_task = AsyncMock(side_effect=ValueError("Test error"))
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(Exception):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -245,8 +248,8 @@ class TestExecute:
mock_agent.aexecute_task = AsyncMock(side_effect=asyncio.CancelledError())
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(asyncio.CancelledError):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -354,6 +357,7 @@ class TestExecuteAndCancelIntegration:
mock_task: MagicMock,
) -> None:
"""Calling cancel stops a running execute."""
async def slow_task(**kwargs: Any) -> str:
await asyncio.sleep(2.0)
return "should not complete"
@@ -361,8 +365,8 @@ class TestExecuteAndCancelIntegration:
mock_agent.aexecute_task = slow_task
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus"),
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus"),
):
execute_task = asyncio.create_task(
execute(mock_agent, mock_context, mock_event_queue)
@@ -372,4 +376,4 @@ class TestExecuteAndCancelIntegration:
await cancel(mock_context, mock_event_queue)
with pytest.raises(asyncio.CancelledError):
await execute_task
await execute_task

View File

@@ -96,12 +96,7 @@ azure-ai-inference = [
anthropic = [
"anthropic~=0.73.0",
]
a2a = [
"a2a-sdk~=0.3.10",
"httpx-auth~=0.23.1",
"httpx-sse~=0.4.0",
"aiocache[redis,memcached]~=0.12.3",
]
a2a = ["crewai-a2a==1.10.1b1"]
file-processing = [
"crewai-files",
]
@@ -132,6 +127,7 @@ torchvision = [
{ index = "pytorch", marker = "python_version < '3.13'" },
]
crewai-files = { workspace = true }
crewai-a2a = { workspace = true }
[build-system]

View File

@@ -1,10 +1,13 @@
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""
"""Backward-compatibility shim — use ``crewai_a2a`` instead."""
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
import warnings
__all__ = [
"A2AClientConfig",
"A2AConfig",
"A2AServerConfig",
]
warnings.warn(
"'crewai.a2a' has been moved to 'crewai_a2a'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai_a2a import * # noqa: E402, F403

View File

@@ -1,36 +1,13 @@
"""A2A authentication schemas."""
"""Backward-compatibility shim — use ``crewai_a2a.auth`` instead."""
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
TLSConfig,
)
from crewai.a2a.auth.server_schemes import (
AuthenticatedUser,
OIDCAuth,
ServerAuthScheme,
SimpleTokenAuth,
import warnings
warnings.warn(
"'crewai.a2a.auth' has been moved to 'crewai_a2a.auth'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
__all__ = [
"APIKeyAuth",
"AuthScheme",
"AuthenticatedUser",
"BearerTokenAuth",
"ClientAuthScheme",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
"OIDCAuth",
"ServerAuthScheme",
"SimpleTokenAuth",
"TLSConfig",
]
from crewai_a2a.auth import * # noqa: E402, F403

View File

@@ -1,550 +1,13 @@
"""Authentication schemes for A2A protocol clients.
"""Backward-compatibility shim — use ``crewai_a2a.auth.client_schemes`` instead."""
Supported authentication methods:
- Bearer tokens
- OAuth2 (Client Credentials, Authorization Code)
- API Keys (header, query, cookie)
- HTTP Basic authentication
- HTTP Digest authentication
- mTLS (mutual TLS) client certificate authentication
"""
import warnings
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
import base64
from collections.abc import Awaitable, Callable, MutableMapping
from pathlib import Path
import ssl
import time
from typing import TYPE_CHECKING, ClassVar, Literal
import urllib.parse
warnings.warn(
"'crewai.a2a.auth.client_schemes' has been moved to 'crewai_a2a.auth.client_schemes'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
import httpx
from httpx import DigestAuth
from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr
from typing_extensions import deprecated
if TYPE_CHECKING:
import grpc # type: ignore[import-untyped]
class TLSConfig(BaseModel):
"""TLS/mTLS configuration for secure client connections.
Supports mutual TLS (mTLS) where the client presents a certificate to the server,
and standard TLS with custom CA verification.
Attributes:
client_cert_path: Path to client certificate file (PEM format) for mTLS.
client_key_path: Path to client private key file (PEM format) for mTLS.
ca_cert_path: Path to CA certificate bundle for server verification.
verify: Whether to verify server certificates. Set False only for development.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
client_cert_path: FilePath | None = Field(
default=None,
description="Path to client certificate file (PEM format) for mTLS",
)
client_key_path: FilePath | None = Field(
default=None,
description="Path to client private key file (PEM format) for mTLS",
)
ca_cert_path: FilePath | None = Field(
default=None,
description="Path to CA certificate bundle for server verification",
)
verify: bool = Field(
default=True,
description="Whether to verify server certificates. Set False only for development.",
)
def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
"""Build SSL context for httpx client.
Returns:
SSL context if certificates configured, True for default verification,
False if verification disabled, or path to CA bundle.
"""
if not self.verify:
return False
if self.client_cert_path and self.client_key_path:
context = ssl.create_default_context()
if self.ca_cert_path:
context.load_verify_locations(cafile=str(self.ca_cert_path))
context.load_cert_chain(
certfile=str(self.client_cert_path),
keyfile=str(self.client_key_path),
)
return context
if self.ca_cert_path:
return str(self.ca_cert_path)
return True
def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported]
"""Build gRPC channel credentials for secure connections.
Returns:
gRPC SSL credentials if certificates configured, None otherwise.
"""
try:
import grpc
except ImportError:
return None
if not self.verify and not self.client_cert_path:
return None
root_certs: bytes | None = None
private_key: bytes | None = None
certificate_chain: bytes | None = None
if self.ca_cert_path:
root_certs = Path(self.ca_cert_path).read_bytes()
if self.client_cert_path and self.client_key_path:
private_key = Path(self.client_key_path).read_bytes()
certificate_chain = Path(self.client_cert_path).read_bytes()
return grpc.ssl_channel_credentials(
root_certificates=root_certs,
private_key=private_key,
certificate_chain=certificate_chain,
)
class ClientAuthScheme(ABC, BaseModel):
"""Base class for client-side authentication schemes.
Client auth schemes apply credentials to outgoing requests.
Attributes:
tls: Optional TLS/mTLS configuration for secure connections.
"""
tls: TLSConfig | None = Field(
default=None,
description="TLS/mTLS configuration for secure connections",
)
@abstractmethod
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply authentication to request headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with authentication applied.
"""
...
@deprecated("Use ClientAuthScheme instead", category=FutureWarning)
class AuthScheme(ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme instead."""
class BearerTokenAuth(ClientAuthScheme):
"""Bearer token authentication (Authorization: Bearer <token>).
Attributes:
token: Bearer token for authentication.
"""
token: str = Field(description="Bearer token")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply Bearer token to Authorization header.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Bearer token in Authorization header.
"""
headers["Authorization"] = f"Bearer {self.token}"
return headers
class HTTPBasicAuth(ClientAuthScheme):
"""HTTP Basic authentication.
Attributes:
username: Username for Basic authentication.
password: Password for Basic authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply HTTP Basic authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Basic auth in Authorization header.
"""
credentials = f"{self.username}:{self.password}"
encoded = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded}"
return headers
class HTTPDigestAuth(ClientAuthScheme):
"""HTTP Digest authentication.
Note: Uses httpx-auth library for digest implementation.
Attributes:
username: Username for Digest authentication.
password: Password for Digest authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
_configured_client_id: int | None = PrivateAttr(default=None)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Digest auth is handled by httpx auth flow, not headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Unchanged headers (Digest auth handled by httpx auth flow).
"""
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client with Digest auth.
Idempotent: Only configures the client once. Subsequent calls on the same
client instance are no-ops to prevent overwriting auth configuration.
Args:
client: HTTP client to configure with Digest authentication.
"""
client_id = id(client)
if self._configured_client_id == client_id:
return
client.auth = DigestAuth(self.username, self.password)
self._configured_client_id = client_id
class APIKeyAuth(ClientAuthScheme):
"""API Key authentication (header, query, or cookie).
Attributes:
api_key: API key value for authentication.
location: Where to send the API key (header, query, or cookie).
name: Parameter name for the API key (default: X-API-Key).
"""
api_key: str = Field(description="API key value")
location: Literal["header", "query", "cookie"] = Field(
default="header", description="Where to send the API key"
)
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
_configured_client_ids: set[int] = PrivateAttr(default_factory=set)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply API key authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with API key (for header/cookie locations).
"""
if self.location == "header":
headers[self.name] = self.api_key
elif self.location == "cookie":
headers["Cookie"] = f"{self.name}={self.api_key}"
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client for query param API keys.
Idempotent: Only adds the request hook once per client instance.
Subsequent calls on the same client are no-ops to prevent hook accumulation.
Args:
client: HTTP client to configure with query param API key hook.
"""
if self.location == "query":
client_id = id(client)
if client_id in self._configured_client_ids:
return
async def _add_api_key_param(request: httpx.Request) -> None:
url = httpx.URL(request.url)
request.url = url.copy_add_param(self.name, self.api_key)
client.event_hooks["request"].append(_add_api_key_param)
self._configured_client_ids.add(client_id)
class OAuth2ClientCredentials(ClientAuthScheme):
"""OAuth2 Client Credentials flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches
when multiple requests share the same auth instance.
Attributes:
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
scopes: List of required OAuth2 scopes.
"""
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time,
preventing race conditions when multiple concurrent requests use the same
auth instance.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
"""
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
async with self._lock:
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
"""Fetch OAuth2 access token using client credentials flow.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token request fails.
"""
data = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
}
if self.scopes:
data["scope"] = " ".join(self.scopes)
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
class OAuth2AuthorizationCode(ClientAuthScheme):
"""OAuth2 Authorization Code flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token operations.
Note: Requires interactive authorization.
Attributes:
authorization_url: OAuth2 authorization endpoint URL.
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
redirect_uri: OAuth2 redirect URI for callback.
scopes: List of required OAuth2 scopes.
"""
authorization_url: str = Field(description="OAuth2 authorization endpoint")
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
redirect_uri: str = Field(description="OAuth2 redirect URI")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_refresh_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
default=None
)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
def set_authorization_callback(
self, callback: Callable[[str], Awaitable[str]] | None
) -> None:
"""Set callback to handle authorization URL.
Args:
callback: Async function that receives authorization URL and returns auth code.
"""
self._authorization_callback = callback
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine handles token operations
(initial fetch or refresh) at a time.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
Raises:
ValueError: If authorization callback is not set.
"""
if self._access_token is None:
if self._authorization_callback is None:
msg = "Authorization callback not set. Use set_authorization_callback()"
raise ValueError(msg)
async with self._lock:
if self._access_token is None:
await self._fetch_initial_token(client)
elif self._token_expires_at and time.time() >= self._token_expires_at:
async with self._lock:
if self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
"""Fetch initial access token using authorization code flow.
Args:
client: HTTP client for making token request.
Raises:
ValueError: If authorization callback is not set.
httpx.HTTPStatusError: If token request fails.
"""
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(self.scopes),
}
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
if self._authorization_callback is None:
msg = "Authorization callback not set"
raise ValueError(msg)
auth_code = await self._authorization_callback(auth_url)
data = {
"grant_type": "authorization_code",
"code": auth_code,
"client_id": self.client_id,
"client_secret": self.client_secret,
"redirect_uri": self.redirect_uri,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
self._refresh_token = token_data.get("refresh_token")
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
"""Refresh the access token using refresh token.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token refresh request fails.
"""
if not self._refresh_token:
await self._fetch_initial_token(client)
return
data = {
"grant_type": "refresh_token",
"refresh_token": self._refresh_token,
"client_id": self.client_id,
"client_secret": self.client_secret,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
if "refresh_token" in token_data:
self._refresh_token = token_data["refresh_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
from crewai_a2a.auth.client_schemes import * # noqa: E402, F403

View File

@@ -1,71 +1,13 @@
"""Deprecated: Authentication schemes for A2A protocol agents.
"""Backward-compatibility shim — use ``crewai_a2a.auth.schemas`` instead."""
This module is deprecated. Import from crewai.a2a.auth instead:
- crewai.a2a.auth.ClientAuthScheme (replaces AuthScheme)
- crewai.a2a.auth.BearerTokenAuth
- crewai.a2a.auth.HTTPBasicAuth
- crewai.a2a.auth.HTTPDigestAuth
- crewai.a2a.auth.APIKeyAuth
- crewai.a2a.auth.OAuth2ClientCredentials
- crewai.a2a.auth.OAuth2AuthorizationCode
"""
import warnings
from __future__ import annotations
from typing_extensions import deprecated
from crewai.a2a.auth.client_schemes import (
APIKeyAuth as _APIKeyAuth,
BearerTokenAuth as _BearerTokenAuth,
ClientAuthScheme as _ClientAuthScheme,
HTTPBasicAuth as _HTTPBasicAuth,
HTTPDigestAuth as _HTTPDigestAuth,
OAuth2AuthorizationCode as _OAuth2AuthorizationCode,
OAuth2ClientCredentials as _OAuth2ClientCredentials,
warnings.warn(
"'crewai.a2a.auth.schemas' has been moved to 'crewai_a2a.auth.schemas'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
@deprecated("Use ClientAuthScheme from crewai.a2a.auth instead", category=FutureWarning)
class AuthScheme(_ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class BearerTokenAuth(_BearerTokenAuth):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class HTTPBasicAuth(_HTTPBasicAuth):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class HTTPDigestAuth(_HTTPDigestAuth):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class APIKeyAuth(_APIKeyAuth):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class OAuth2ClientCredentials(_OAuth2ClientCredentials):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class OAuth2AuthorizationCode(_OAuth2AuthorizationCode):
"""Deprecated: Import from crewai.a2a.auth instead."""
__all__ = [
"APIKeyAuth",
"AuthScheme",
"BearerTokenAuth",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
]
from crewai_a2a.auth.schemas import * # noqa: E402, F403

View File

@@ -1,739 +1,13 @@
"""Server-side authentication schemes for A2A protocol.
"""Backward-compatibility shim — use ``crewai_a2a.auth.server_schemes`` instead."""
These schemes validate incoming requests to A2A server endpoints.
import warnings
Supported authentication methods:
- Simple token validation with static bearer tokens
- OpenID Connect with JWT validation using JWKS
- OAuth2 with JWT validation or token introspection
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
import os
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
import jwt
from jwt import PyJWKClient
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
HttpUrl,
PrivateAttr,
SecretStr,
model_validator,
warnings.warn(
"'crewai.a2a.auth.server_schemes' has been moved to 'crewai_a2a.auth.server_schemes'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from typing_extensions import Self
if TYPE_CHECKING:
from a2a.types import OAuth2SecurityScheme
logger = logging.getLogger(__name__)
try:
from fastapi import HTTPException, status as http_status
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
except ImportError:
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
"""Fallback HTTPException when FastAPI is not installed."""
def __init__(
self,
status_code: int,
detail: str | None = None,
headers: dict[str, str] | None = None,
) -> None:
self.status_code = status_code
self.detail = detail
self.headers = headers
super().__init__(detail)
HTTP_401_UNAUTHORIZED = 401
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_503_SERVICE_UNAVAILABLE = 503
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
"""Coerce string to SecretStr."""
if v is None or isinstance(v, SecretStr):
return v
return SecretStr(v)
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
JWTAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
@dataclass
class AuthenticatedUser:
"""Result of successful authentication.
Attributes:
token: The original token that was validated.
scheme: Name of the authentication scheme used.
claims: JWT claims from OIDC or OAuth2 authentication.
"""
token: str
scheme: str
claims: dict[str, Any] | None = None
class ServerAuthScheme(ABC, BaseModel):
"""Base class for server-side authentication schemes.
Each scheme validates incoming requests and returns an AuthenticatedUser
on success, or raises HTTPException on failure.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@abstractmethod
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate the provided token.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
...
class SimpleTokenAuth(ServerAuthScheme):
"""Simple bearer token authentication.
Validates tokens against a configured static token or AUTH_TOKEN env var.
Attributes:
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
"""
token: CoercedSecretStr | None = Field(
default=None,
description="Expected token. Falls back to AUTH_TOKEN env var.",
)
def _get_expected_token(self) -> str | None:
"""Get the expected token value."""
if self.token:
return self.token.get_secret_value()
return os.environ.get("AUTH_TOKEN")
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using simple token comparison.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
expected = self._get_expected_token()
if expected is None:
logger.warning(
"Simple token authentication failed",
extra={"reason": "no_token_configured"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Authentication not configured",
)
if token != expected:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
)
return AuthenticatedUser(
token=token,
scheme="simple_token",
)
class OIDCAuth(ServerAuthScheme):
"""OpenID Connect authentication.
Validates JWTs using JWKS with caching support via PyJWT.
Attributes:
issuer: The OpenID Connect issuer URL.
audience: The expected audience claim.
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
algorithms: List of allowed signing algorithms.
required_claims: List of claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
issuer: HttpUrl = Field(
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
)
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
jwks_url: HttpUrl | None = Field(
default=None,
description="Explicit JWKS URL. Derived from issuer if not set.",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="List of allowed signing algorithms (RS256, ES256, etc.)",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
description="List of claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _init_jwk_client(self) -> Self:
"""Initialize the JWK client after model creation."""
jwks_url = (
str(self.jwks_url)
if self.jwks_url
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
)
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OIDC JWT validation.
Args:
token: The JWT to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OIDC not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=str(self.issuer).rstrip("/"),
leeway=self.clock_skew_seconds,
options={
"require": self.required_claims,
},
)
return AuthenticatedUser(
token=token,
scheme="oidc",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "token_expired", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_audience", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OIDC authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oidc",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
class OAuth2ServerAuth(ServerAuthScheme):
"""OAuth2 authentication for A2A server.
Declares OAuth2 security scheme in AgentCard and validates tokens using
either JWKS for JWT tokens or token introspection for opaque tokens.
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
Attributes:
token_url: OAuth2 token endpoint URL for client_credentials flow.
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
refresh_url: Optional refresh token endpoint URL.
scopes: Available OAuth2 scopes with descriptions.
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
introspection_client_id: Client ID for introspection endpoint authentication.
introspection_client_secret: Client secret for introspection endpoint.
audience: Expected audience claim for JWT validation.
issuer: Expected issuer claim for JWT validation.
algorithms: Allowed JWT signing algorithms.
required_claims: Claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
token_url: HttpUrl = Field(
description="OAuth2 token endpoint URL",
)
authorization_url: HttpUrl | None = Field(
default=None,
description="OAuth2 authorization endpoint URL for authorization_code flow",
)
refresh_url: HttpUrl | None = Field(
default=None,
description="OAuth2 refresh token endpoint URL",
)
scopes: dict[str, str] = Field(
default_factory=dict,
description="Available OAuth2 scopes with descriptions",
)
jwks_url: HttpUrl | None = Field(
default=None,
description="JWKS URL for JWT validation. Required if not using introspection.",
)
introspection_url: HttpUrl | None = Field(
default=None,
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
)
introspection_client_id: str | None = Field(
default=None,
description="Client ID for introspection endpoint authentication",
)
introspection_client_secret: CoercedSecretStr | None = Field(
default=None,
description="Client secret for introspection endpoint authentication",
)
audience: str | None = Field(
default=None,
description="Expected audience claim for JWT validation",
)
issuer: str | None = Field(
default=None,
description="Expected issuer claim for JWT validation",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="Allowed JWT signing algorithms",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat"],
description="Claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _validate_and_init(self) -> Self:
"""Validate configuration and initialize JWKS client if needed."""
if not self.jwks_url and not self.introspection_url:
raise ValueError(
"Either jwks_url or introspection_url must be provided for token validation"
)
if self.introspection_url:
if not self.introspection_client_id or not self.introspection_client_secret:
raise ValueError(
"introspection_client_id and introspection_client_secret are required "
"when using token introspection"
)
if self.jwks_url:
self._jwk_client = PyJWKClient(
str(self.jwks_url), lifespan=self.jwks_cache_ttl
)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token validation.
Uses JWKS validation if jwks_url is configured, otherwise falls back
to token introspection.
Args:
token: The OAuth2 access token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client:
return await self._authenticate_jwt(token)
return await self._authenticate_introspection(token)
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
"""Authenticate using JWKS JWT validation."""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 JWKS not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
decode_options: dict[str, Any] = {
"require": self.required_claims,
}
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=self.issuer,
leeway=self.clock_skew_seconds,
options=decode_options,
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_expired", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_audience", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OAuth2 authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oauth2",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
import httpx
if not self.introspection_url:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 introspection not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.post(
str(self.introspection_url),
data={"token": token},
auth=(
self.introspection_client_id or "",
self.introspection_client_secret.get_secret_value()
if self.introspection_client_secret
else "",
),
)
response.raise_for_status()
introspection_result = response.json()
except httpx.HTTPStatusError as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "http_error", "status_code": e.response.status_code},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection service unavailable",
) from None
except Exception as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "unexpected_error", "error": str(e)},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection failed",
) from None
if not introspection_result.get("active", False):
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_not_active", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token is not active",
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=introspection_result,
)
def to_security_scheme(self) -> OAuth2SecurityScheme:
"""Generate OAuth2SecurityScheme for AgentCard declaration.
Creates an OAuth2SecurityScheme with appropriate flows based on
the configured URLs. Includes client_credentials flow if token_url
is set, and authorization_code flow if authorization_url is set.
Returns:
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
"""
from a2a.types import (
AuthorizationCodeOAuthFlow,
ClientCredentialsOAuthFlow,
OAuth2SecurityScheme,
OAuthFlows,
)
client_credentials = None
authorization_code = None
if self.token_url:
client_credentials = ClientCredentialsOAuthFlow(
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
if self.authorization_url:
authorization_code = AuthorizationCodeOAuthFlow(
authorization_url=str(self.authorization_url),
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
return OAuth2SecurityScheme(
flows=OAuthFlows(
client_credentials=client_credentials,
authorization_code=authorization_code,
),
description="OAuth2 authentication",
)
class APIKeyServerAuth(ServerAuthScheme):
"""API Key authentication for A2A server.
Validates requests using an API key in a header, query parameter, or cookie.
Attributes:
name: The name of the API key parameter (default: X-API-Key).
location: Where to look for the API key (header, query, or cookie).
api_key: The expected API key value.
"""
name: str = Field(
default="X-API-Key",
description="Name of the API key parameter",
)
location: Literal["header", "query", "cookie"] = Field(
default="header",
description="Where to look for the API key",
)
api_key: CoercedSecretStr = Field(
description="Expected API key value",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using API key comparison.
Args:
token: The API key to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if token != self.api_key.get_secret_value():
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return AuthenticatedUser(
token=token,
scheme="api_key",
)
class MTLSServerAuth(ServerAuthScheme):
"""Mutual TLS authentication marker for AgentCard declaration.
This scheme is primarily for AgentCard security_schemes declaration.
Actual mTLS verification happens at the TLS/transport layer, not
at the application layer via token validation.
When configured, this signals to clients that the server requires
client certificates for authentication.
"""
description: str = Field(
default="Mutual TLS certificate authentication",
description="Description for the security scheme",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Return authenticated user for mTLS.
mTLS verification happens at the transport layer before this is called.
If we reach this point, the TLS handshake with client cert succeeded.
Args:
token: Certificate subject or identifier (from TLS layer).
Returns:
AuthenticatedUser indicating mTLS authentication.
"""
return AuthenticatedUser(
token=token or "mtls-verified",
scheme="mtls",
)
from crewai_a2a.auth.server_schemes import * # noqa: E402, F403

View File

@@ -1,273 +1,13 @@
"""Authentication utilities for A2A protocol agent communication.
"""Backward-compatibility shim — use ``crewai_a2a.auth.utils`` instead."""
Provides validation and retry logic for various authentication schemes including
OAuth2, API keys, and HTTP authentication methods.
"""
import warnings
import asyncio
from collections.abc import Awaitable, Callable, MutableMapping
import hashlib
import re
import threading
from typing import Final, Literal, cast
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
APIKeySecurityScheme,
AgentCard,
HTTPAuthSecurityScheme,
OAuth2SecurityScheme,
)
from httpx import AsyncClient, Response
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
warnings.warn(
"'crewai.a2a.auth.utils' has been moved to 'crewai_a2a.auth.utils'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
class _AuthStore:
"""Store for authentication schemes with safe concurrent access."""
def __init__(self) -> None:
self._store: dict[str, ClientAuthScheme | None] = {}
self._lock = threading.RLock()
@staticmethod
def compute_key(auth_type: str, auth_data: str) -> str:
"""Compute a collision-resistant key using SHA-256."""
content = f"{auth_type}:{auth_data}"
return hashlib.sha256(content.encode()).hexdigest()
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
"""Store an auth scheme."""
with self._lock:
self._store[key] = auth
def get(self, key: str) -> ClientAuthScheme | None:
"""Retrieve an auth scheme by key."""
with self._lock:
return self._store.get(key)
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
with self._lock:
self._store[key] = value
def __getitem__(self, key: str) -> ClientAuthScheme | None:
with self._lock:
return self._store[key]
_auth_store = _AuthStore()
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
OAuth2SecurityScheme: (
OAuth2ClientCredentials,
OAuth2AuthorizationCode,
BearerTokenAuth,
),
APIKeySecurityScheme: (APIKeyAuth,),
}
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
"basic": HTTPBasicAuth,
"digest": HTTPDigestAuth,
"bearer": BearerTokenAuth,
}
def _raise_auth_mismatch(
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
provided_auth: ClientAuthScheme,
) -> None:
"""Raise authentication mismatch error.
Args:
expected_classes: Expected authentication class or tuple of classes.
provided_auth: Actually provided authentication instance.
Raises:
A2AClientHTTPError: Always raises with 401 status code.
"""
if isinstance(expected_classes, tuple):
if len(expected_classes) == 1:
required = expected_classes[0].__name__
else:
names = [cls.__name__ for cls in expected_classes]
required = f"one of ({', '.join(names)})"
else:
required = expected_classes.__name__
msg = (
f"AgentCard requires {required} authentication, "
f"but {type(provided_auth).__name__} was provided"
)
raise A2AClientHTTPError(401, msg)
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
"""Parse WWW-Authenticate header into auth challenges.
Args:
header_value: The WWW-Authenticate header value.
Returns:
Dictionary mapping auth scheme to its parameters.
Example: {"Bearer": {"realm": "api", "scope": "read write"}}
"""
if not header_value:
return {}
challenges: dict[str, dict[str, str]] = {}
for match in _SCHEME_PATTERN.finditer(header_value):
scheme = match.group(1)
params_str = match.group(2)
params: dict[str, str] = {}
for param_match in _PARAM_PATTERN.finditer(params_str):
key = param_match.group(1)
value = param_match.group(2) or param_match.group(3)
params[key] = value
challenges[scheme] = params
return challenges
def validate_auth_against_agent_card(
agent_card: AgentCard, auth: ClientAuthScheme | None
) -> None:
"""Validate that provided auth matches AgentCard security requirements.
Args:
agent_card: The A2A AgentCard containing security requirements.
auth: User-provided authentication scheme (or None).
Raises:
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
"""
if not agent_card.security or not agent_card.security_schemes:
return
if not auth:
msg = "AgentCard requires authentication but no auth scheme provided"
raise A2AClientHTTPError(401, msg)
first_security_req = agent_card.security[0] if agent_card.security else {}
for scheme_name in first_security_req.keys():
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
if not security_scheme_wrapper:
continue
scheme = security_scheme_wrapper.root
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
if not isinstance(auth, allowed_classes):
_raise_auth_mismatch(allowed_classes, auth)
return
if isinstance(scheme, HTTPAuthSecurityScheme):
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
if not isinstance(auth, required_class):
_raise_auth_mismatch(required_class, auth)
return
msg = "Could not validate auth against AgentCard security requirements"
raise A2AClientHTTPError(401, msg)
async def retry_on_401(
request_func: Callable[[], Awaitable[Response]],
auth_scheme: ClientAuthScheme | None,
client: AsyncClient,
headers: MutableMapping[str, str],
max_retries: int = 3,
) -> Response:
"""Retry a request on 401 authentication error.
Handles 401 errors by:
1. Parsing WWW-Authenticate header
2. Re-acquiring credentials
3. Retrying the request
Args:
request_func: Async function that makes the HTTP request.
auth_scheme: Authentication scheme to refresh credentials with.
client: HTTP client for making requests.
headers: Request headers to update with new auth.
max_retries: Maximum number of retry attempts (default: 3).
Returns:
HTTP response from the request.
Raises:
httpx.HTTPStatusError: If retries are exhausted or auth scheme is None.
"""
last_response: Response | None = None
last_challenges: dict[str, dict[str, str]] = {}
for attempt in range(max_retries):
response = await request_func()
if response.status_code != 401:
return response
last_response = response
if auth_scheme is None:
response.raise_for_status()
return response
www_authenticate = response.headers.get("WWW-Authenticate", "")
challenges = parse_www_authenticate(www_authenticate)
last_challenges = challenges
if attempt >= max_retries - 1:
break
backoff_time = 2**attempt
await asyncio.sleep(backoff_time)
await auth_scheme.apply_auth(client, headers)
if last_response:
last_response.raise_for_status()
return last_response
msg = "retry_on_401 failed without making any requests"
if last_challenges:
challenge_info = ", ".join(
f"{scheme} (realm={params.get('realm', 'N/A')})"
for scheme, params in last_challenges.items()
)
msg = f"{msg}. Server challenges: {challenge_info}"
raise RuntimeError(msg)
def configure_auth_client(
auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient
) -> None:
"""Configure HTTP client with auth-specific settings.
Only HTTPDigestAuth and APIKeyAuth need client configuration.
Args:
auth: Authentication scheme that requires client configuration.
client: HTTP client to configure.
"""
auth.configure_client(client)
from crewai_a2a.auth.utils import * # noqa: E402, F403

View File

@@ -1,690 +1,13 @@
"""A2A configuration types.
"""Backward-compatibility shim — use ``crewai_a2a.config`` instead."""
This module is separate from experimental.a2a to avoid circular imports.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, ClassVar, Literal, cast
import warnings
from pydantic import (
BaseModel,
ConfigDict,
Field,
FilePath,
PrivateAttr,
SecretStr,
model_validator,
warnings.warn(
"'crewai.a2a.config' has been moved to 'crewai_a2a.config'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from typing_extensions import Self, deprecated
from crewai.a2a.auth.client_schemes import ClientAuthScheme
from crewai.a2a.auth.server_schemes import ServerAuthScheme
from crewai.a2a.extensions.base import ValidatedA2AExtension
from crewai.a2a.types import ProtocolVersion, TransportType, Url
try:
from a2a.types import (
AgentCapabilities,
AgentCardSignature,
AgentInterface,
AgentProvider,
AgentSkill,
SecurityScheme,
)
from crewai.a2a.extensions.server import ServerExtension
from crewai.a2a.updates import UpdateConfig
except ImportError:
UpdateConfig: Any = Any # type: ignore[no-redef]
AgentCapabilities: Any = Any # type: ignore[no-redef]
AgentCardSignature: Any = Any # type: ignore[no-redef]
AgentInterface: Any = Any # type: ignore[no-redef]
AgentProvider: Any = Any # type: ignore[no-redef]
SecurityScheme: Any = Any # type: ignore[no-redef]
AgentSkill: Any = Any # type: ignore[no-redef]
ServerExtension: Any = Any # type: ignore[no-redef]
def _get_default_update_config() -> UpdateConfig:
from crewai.a2a.updates import StreamingConfig
return StreamingConfig()
SigningAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
class AgentCardSigningConfig(BaseModel):
"""Configuration for AgentCard JWS signing.
Provides the private key and algorithm settings for signing AgentCards.
Either private_key_path or private_key_pem must be provided, but not both.
Attributes:
private_key_path: Path to a PEM-encoded private key file.
private_key_pem: PEM-encoded private key as a secret string.
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
private_key_path: FilePath | None = Field(
default=None,
description="Path to PEM-encoded private key file",
)
private_key_pem: SecretStr | None = Field(
default=None,
description="PEM-encoded private key",
)
key_id: str | None = Field(
default=None,
description="Key identifier for JWS header (kid claim)",
)
algorithm: SigningAlgorithm = Field(
default="RS256",
description="Signing algorithm (RS256, ES256, PS256, etc.)",
)
@model_validator(mode="after")
def _validate_key_source(self) -> Self:
"""Ensure exactly one key source is provided."""
has_path = self.private_key_path is not None
has_pem = self.private_key_pem is not None
if not has_path and not has_pem:
raise ValueError(
"Either private_key_path or private_key_pem must be provided"
)
if has_path and has_pem:
raise ValueError(
"Only one of private_key_path or private_key_pem should be provided"
)
return self
def get_private_key(self) -> str:
"""Get the private key content.
Returns:
The PEM-encoded private key as a string.
"""
if self.private_key_pem:
return self.private_key_pem.get_secret_value()
if self.private_key_path:
return Path(self.private_key_path).read_text()
raise ValueError("No private key configured")
class GRPCServerConfig(BaseModel):
"""gRPC server transport configuration.
Presence of this config in ServerTransportConfig.grpc enables gRPC transport.
Attributes:
host: Hostname to advertise in agent cards (default: localhost).
Use docker service name (e.g., 'web') for docker-compose setups.
port: Port for the gRPC server.
tls_cert_path: Path to TLS certificate file for gRPC.
tls_key_path: Path to TLS private key file for gRPC.
max_workers: Maximum number of workers for the gRPC thread pool.
reflection_enabled: Whether to enable gRPC reflection for debugging.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
host: str = Field(
default="localhost",
description="Hostname to advertise in agent cards for gRPC connections",
)
port: int = Field(
default=50051,
description="Port for the gRPC server",
)
tls_cert_path: str | None = Field(
default=None,
description="Path to TLS certificate file for gRPC",
)
tls_key_path: str | None = Field(
default=None,
description="Path to TLS private key file for gRPC",
)
max_workers: int = Field(
default=10,
description="Maximum number of workers for the gRPC thread pool",
)
reflection_enabled: bool = Field(
default=False,
description="Whether to enable gRPC reflection for debugging",
)
class GRPCClientConfig(BaseModel):
"""gRPC client transport configuration.
Attributes:
max_send_message_length: Maximum size for outgoing messages in bytes.
max_receive_message_length: Maximum size for incoming messages in bytes.
keepalive_time_ms: Time between keepalive pings in milliseconds.
keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_send_message_length: int | None = Field(
default=None,
description="Maximum size for outgoing messages in bytes",
)
max_receive_message_length: int | None = Field(
default=None,
description="Maximum size for incoming messages in bytes",
)
keepalive_time_ms: int | None = Field(
default=None,
description="Time between keepalive pings in milliseconds",
)
keepalive_timeout_ms: int | None = Field(
default=None,
description="Timeout for keepalive ping response in milliseconds",
)
class JSONRPCServerConfig(BaseModel):
"""JSON-RPC server transport configuration.
Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport.
Attributes:
rpc_path: URL path for the JSON-RPC endpoint.
agent_card_path: URL path for the agent card endpoint.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
rpc_path: str = Field(
default="/a2a",
description="URL path for the JSON-RPC endpoint",
)
agent_card_path: str = Field(
default="/.well-known/agent-card.json",
description="URL path for the agent card endpoint",
)
class JSONRPCClientConfig(BaseModel):
"""JSON-RPC client transport configuration.
Attributes:
max_request_size: Maximum request body size in bytes.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_request_size: int | None = Field(
default=None,
description="Maximum request body size in bytes",
)
class HTTPJSONConfig(BaseModel):
"""HTTP+JSON transport configuration.
Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
class ServerPushNotificationConfig(BaseModel):
"""Configuration for outgoing webhook push notifications.
Controls how the server signs and delivers push notifications to clients.
Attributes:
signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
signature_secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC-SHA256 signing of outgoing push notifications",
)
class ServerTransportConfig(BaseModel):
"""Transport configuration for A2A server.
Groups all transport-related settings including preferred transport
and protocol-specific configurations.
Attributes:
preferred: Transport protocol for the preferred endpoint.
jsonrpc: JSON-RPC server transport configuration.
grpc: gRPC server transport configuration.
http_json: HTTP+JSON transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType = Field(
default="JSONRPC",
description="Transport protocol for the preferred endpoint",
)
jsonrpc: JSONRPCServerConfig = Field(
default_factory=JSONRPCServerConfig,
description="JSON-RPC server transport configuration",
)
grpc: GRPCServerConfig | None = Field(
default=None,
description="gRPC server transport configuration",
)
http_json: HTTPJSONConfig | None = Field(
default=None,
description="HTTP+JSON transport configuration",
)
def _migrate_client_transport_fields(
transport: ClientTransportConfig,
transport_protocol: TransportType | None,
supported_transports: list[TransportType] | None,
) -> None:
"""Migrate deprecated transport fields to new config."""
if transport_protocol is not None:
warnings.warn(
"transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "preferred", transport_protocol)
if supported_transports is not None:
warnings.warn(
"supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "supported", supported_transports)
class ClientTransportConfig(BaseModel):
"""Transport configuration for A2A client.
Groups all client transport-related settings including preferred transport,
supported transports for negotiation, and protocol-specific configurations.
Transport negotiation logic:
1. If `preferred` is set and server supports it → use client's preferred
2. Otherwise, if server's preferred is in client's `supported` → use server's preferred
3. Otherwise, find first match from client's `supported` in server's interfaces
Attributes:
preferred: Client's preferred transport. If set, client preference takes priority.
supported: Transports the client can use, in order of preference.
jsonrpc: JSON-RPC client transport configuration.
grpc: gRPC client transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType | None = Field(
default=None,
description="Client's preferred transport. If set, takes priority over server preference.",
)
supported: list[TransportType] = Field(
default_factory=lambda: cast(list[TransportType], ["JSONRPC"]),
description="Transports the client can use, in order of preference",
)
jsonrpc: JSONRPCClientConfig = Field(
default_factory=JSONRPCClientConfig,
description="JSON-RPC client transport configuration",
)
grpc: GRPCClientConfig = Field(
default_factory=GRPCClientConfig,
description="gRPC client transport configuration",
)
@deprecated(
"""
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead.
""",
category=FutureWarning,
)
class A2AConfig(BaseModel):
"""Configuration for A2A protocol integration.
Deprecated:
Use A2AClientConfig instead. This class will be removed in a future version.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=_get_default_update_config,
description="Update mechanism config",
)
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
use_client_preference: bool | None = Field(
default=None,
description="Deprecated: Set transport.preferred to enable client preference",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
if self.use_client_preference is not None:
warnings.warn(
"use_client_preference is deprecated, set transport.preferred to enable client preference",
FutureWarning,
stacklevel=4,
)
if self.use_client_preference and self.transport.supported:
object.__setattr__(
self.transport, "preferred", self.transport.supported[0]
)
return self
class A2AClientConfig(BaseModel):
"""Configuration for connecting to remote A2A agents.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
accepted_output_modes: Media types the client can accept in responses.
extensions: Extension URIs the client supports (A2A protocol extensions).
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=_get_default_update_config,
description="Update mechanism config",
)
accepted_output_modes: list[str] = Field(
default_factory=lambda: ["application/json"],
description="Media types the client can accept in responses",
)
extensions: list[str] = Field(
default_factory=list,
description="Extension URIs the client supports",
)
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
return self
class A2AServerConfig(BaseModel):
"""Configuration for exposing a Crew or Agent as an A2A server.
All fields correspond to A2A AgentCard fields. Fields like name, description,
and skills can be auto-derived from the Crew/Agent if not provided.
Attributes:
name: Human-readable name for the agent.
description: Human-readable description of the agent.
version: Version string for the agent card.
skills: List of agent skills/capabilities.
default_input_modes: Default supported input MIME types.
default_output_modes: Default supported output MIME types.
capabilities: Declaration of optional capabilities.
protocol_version: A2A protocol version this agent supports.
provider: Information about the agent's service provider.
documentation_url: URL to the agent's documentation.
icon_url: URL to an icon for the agent.
additional_interfaces: Additional supported interfaces.
security: Security requirement objects for all interactions.
security_schemes: Security schemes available to authorize requests.
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
url: Preferred endpoint URL for the agent.
signing_config: Configuration for signing the AgentCard with JWS.
signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
push_notifications: Configuration for outgoing push notifications.
transport: Transport configuration (preferred transport, gRPC, REST settings).
auth: Authentication scheme for A2A endpoints.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
name: str | None = Field(
default=None,
description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.",
)
description: str | None = Field(
default=None,
description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.",
)
version: str = Field(
default="1.0.0",
description="Version string for the agent card",
)
skills: list[AgentSkill] = Field(
default_factory=list,
description="List of agent skills. Auto-derived from tasks/tools if not provided.",
)
default_input_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported input MIME types",
)
default_output_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported output MIME types",
)
capabilities: AgentCapabilities = Field(
default_factory=lambda: AgentCapabilities(
streaming=True,
push_notifications=False,
),
description="Declaration of optional capabilities supported by the agent",
)
protocol_version: ProtocolVersion = Field(
default="0.3.0",
description="A2A protocol version this agent supports",
)
provider: AgentProvider | None = Field(
default=None,
description="Information about the agent's service provider",
)
documentation_url: Url | None = Field(
default=None,
description="URL to the agent's documentation",
)
icon_url: Url | None = Field(
default=None,
description="URL to an icon for the agent",
)
additional_interfaces: list[AgentInterface] = Field(
default_factory=list,
description="Additional supported interfaces.",
)
security: list[dict[str, list[str]]] = Field(
default_factory=list,
description="Security requirement objects for all agent interactions",
)
security_schemes: dict[str, SecurityScheme] = Field(
default_factory=dict,
description="Security schemes available to authorize requests",
)
supports_authenticated_extended_card: bool = Field(
default=False,
description="Whether agent provides extended card to authenticated users",
)
url: Url | None = Field(
default=None,
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
)
signing_config: AgentCardSigningConfig | None = Field(
default=None,
description="Configuration for signing the AgentCard with JWS",
)
signatures: list[AgentCardSignature] | None = Field(
default=None,
description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
exclude=True,
deprecated=True,
)
server_extensions: list[ServerExtension] = Field(
default_factory=list,
description="Server-side A2A protocol extensions that modify agent behavior",
)
push_notifications: ServerPushNotificationConfig | None = Field(
default=None,
description="Configuration for outgoing push notifications",
)
transport: ServerTransportConfig = Field(
default_factory=ServerTransportConfig,
description="Transport configuration (preferred transport, gRPC, REST settings)",
)
preferred_transport: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
deprecated=True,
)
auth: ServerAuthScheme | None = Field(
default=None,
description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
)
@model_validator(mode="after")
def _migrate_deprecated_fields(self) -> Self:
"""Migrate deprecated fields to new config."""
if self.preferred_transport is not None:
warnings.warn(
"preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=4,
)
object.__setattr__(self.transport, "preferred", self.preferred_transport)
if self.signatures is not None:
warnings.warn(
"signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
"The signatures field will be removed in v2.0.0.",
FutureWarning,
stacklevel=4,
)
return self
from crewai_a2a.config import * # noqa: E402, F403

View File

@@ -1,491 +1,13 @@
"""A2A error codes and error response utilities.
"""Backward-compatibility shim — use ``crewai_a2a.errors`` instead."""
This module provides a centralized mapping of all A2A protocol error codes
as defined in the A2A specification, plus custom CrewAI extensions.
import warnings
Error codes follow JSON-RPC 2.0 conventions:
- -32700 to -32600: Standard JSON-RPC errors
- -32099 to -32000: Server errors (A2A-specific)
- -32768 to -32100: Reserved for implementation-defined errors
"""
from __future__ import annotations
warnings.warn(
"'crewai.a2a.errors' has been moved to 'crewai_a2a.errors'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any
from a2a.client.errors import A2AClientTimeoutError
class A2APollingTimeoutError(A2AClientTimeoutError):
"""Raised when polling exceeds the configured timeout."""
class A2AErrorCode(IntEnum):
"""A2A protocol error codes.
Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
"""
# JSON-RPC 2.0 Standard Errors (-32700 to -32600)
JSON_PARSE_ERROR = -32700
"""Invalid JSON was received by the server."""
INVALID_REQUEST = -32600
"""The JSON sent is not a valid Request object."""
METHOD_NOT_FOUND = -32601
"""The method does not exist / is not available."""
INVALID_PARAMS = -32602
"""Invalid method parameter(s)."""
INTERNAL_ERROR = -32603
"""Internal JSON-RPC error."""
# A2A-Specific Errors (-32099 to -32000)
TASK_NOT_FOUND = -32001
"""The specified task was not found."""
TASK_NOT_CANCELABLE = -32002
"""The task cannot be canceled (already completed/failed)."""
PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
"""Push notifications are not supported by this agent."""
UNSUPPORTED_OPERATION = -32004
"""The requested operation is not supported."""
CONTENT_TYPE_NOT_SUPPORTED = -32005
"""Incompatible content types between client and server."""
INVALID_AGENT_RESPONSE = -32006
"""The agent produced an invalid response."""
# CrewAI Custom Extensions (-32768 to -32100)
UNSUPPORTED_VERSION = -32009
"""The requested A2A protocol version is not supported."""
UNSUPPORTED_EXTENSION = -32010
"""Client does not support required protocol extensions."""
AUTHENTICATION_REQUIRED = -32011
"""Authentication is required for this operation."""
AUTHORIZATION_FAILED = -32012
"""Authorization check failed (insufficient permissions)."""
RATE_LIMIT_EXCEEDED = -32013
"""Rate limit exceeded for this client/operation."""
TASK_TIMEOUT = -32014
"""Task execution timed out."""
TRANSPORT_NEGOTIATION_FAILED = -32015
"""Failed to negotiate a compatible transport protocol."""
CONTEXT_NOT_FOUND = -32016
"""The specified context was not found."""
SKILL_NOT_FOUND = -32017
"""The specified skill was not found."""
ARTIFACT_NOT_FOUND = -32018
"""The specified artifact was not found."""
# Error code to default message mapping
ERROR_MESSAGES: dict[int, str] = {
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
A2AErrorCode.INVALID_PARAMS: "Invalid params",
A2AErrorCode.INTERNAL_ERROR: "Internal error",
A2AErrorCode.TASK_NOT_FOUND: "Task not found",
A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
}
@dataclass
class A2AError(Exception):
"""Base exception for A2A protocol errors.
Attributes:
code: The A2A/JSON-RPC error code.
message: Human-readable error message.
data: Optional additional error data.
"""
code: int
message: str | None = None
data: Any = None
def __post_init__(self) -> None:
if self.message is None:
self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
super().__init__(self.message)
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-RPC error object format."""
error: dict[str, Any] = {
"code": self.code,
"message": self.message,
}
if self.data is not None:
error["data"] = self.data
return error
def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
"""Convert to full JSON-RPC error response."""
return {
"jsonrpc": "2.0",
"error": self.to_dict(),
"id": request_id,
}
@dataclass
class JSONParseError(A2AError):
"""Invalid JSON was received."""
code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)
@dataclass
class InvalidRequestError(A2AError):
"""The JSON sent is not a valid Request object."""
code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)
@dataclass
class MethodNotFoundError(A2AError):
"""The method does not exist / is not available."""
code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
method: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.method:
self.message = f"Method not found: {self.method}"
super().__post_init__()
@dataclass
class InvalidParamsError(A2AError):
"""Invalid method parameter(s)."""
code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
param: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.param and self.reason:
self.message = f"Invalid parameter '{self.param}': {self.reason}"
elif self.param:
self.message = f"Invalid parameter: {self.param}"
super().__post_init__()
@dataclass
class InternalError(A2AError):
"""Internal JSON-RPC error."""
code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False)
@dataclass
class TaskNotFoundError(A2AError):
"""The specified task was not found."""
code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
task_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.task_id:
self.message = f"Task not found: {self.task_id}"
super().__post_init__()
@dataclass
class TaskNotCancelableError(A2AError):
"""The task cannot be canceled."""
code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
task_id: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.reason:
self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
elif self.task_id:
self.message = f"Task {self.task_id} cannot be canceled"
super().__post_init__()
@dataclass
class PushNotificationNotSupportedError(A2AError):
"""Push notifications are not supported."""
code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)
@dataclass
class UnsupportedOperationError(A2AError):
"""The requested operation is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
operation: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.operation:
self.message = f"Operation not supported: {self.operation}"
super().__post_init__()
@dataclass
class ContentTypeNotSupportedError(A2AError):
"""Incompatible content types."""
code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
requested_types: list[str] | None = None
supported_types: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_types and self.supported_types:
self.message = (
f"Content type not supported. Requested: {self.requested_types}, "
f"Supported: {self.supported_types}"
)
super().__post_init__()
@dataclass
class InvalidAgentResponseError(A2AError):
"""The agent produced an invalid response."""
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
@dataclass
class UnsupportedVersionError(A2AError):
"""The requested A2A version is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
requested_version: str | None = None
supported_versions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_version:
msg = f"Unsupported A2A version: {self.requested_version}"
if self.supported_versions:
msg += f". Supported versions: {', '.join(self.supported_versions)}"
self.message = msg
super().__post_init__()
@dataclass
class UnsupportedExtensionError(A2AError):
"""Client does not support required extensions."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
required_extensions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_extensions:
self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
super().__post_init__()
@dataclass
class AuthenticationRequiredError(A2AError):
"""Authentication is required."""
code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)
@dataclass
class AuthorizationFailedError(A2AError):
"""Authorization check failed."""
code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
required_scope: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_scope:
self.message = (
f"Authorization failed. Required scope: {self.required_scope}"
)
super().__post_init__()
@dataclass
class RateLimitExceededError(A2AError):
"""Rate limit exceeded."""
code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False)
retry_after: int | None = None
def __post_init__(self) -> None:
if self.message is None and self.retry_after:
self.message = (
f"Rate limit exceeded. Retry after {self.retry_after} seconds"
)
if self.retry_after:
self.data = {"retry_after": self.retry_after}
super().__post_init__()
@dataclass
class TaskTimeoutError(A2AError):
"""Task execution timed out."""
code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False)
task_id: str | None = None
timeout_seconds: float | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.timeout_seconds:
self.message = (
f"Task {self.task_id} timed out after {self.timeout_seconds}s"
)
elif self.task_id:
self.message = f"Task {self.task_id} timed out"
super().__post_init__()
@dataclass
class TransportNegotiationFailedError(A2AError):
"""Failed to negotiate a compatible transport protocol."""
code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False)
client_transports: list[str] | None = None
server_transports: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.client_transports and self.server_transports:
self.message = (
f"Transport negotiation failed. Client: {self.client_transports}, "
f"Server: {self.server_transports}"
)
super().__post_init__()
@dataclass
class ContextNotFoundError(A2AError):
"""The specified context was not found."""
code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
context_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.context_id:
self.message = f"Context not found: {self.context_id}"
super().__post_init__()
@dataclass
class SkillNotFoundError(A2AError):
"""The specified skill was not found."""
code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
skill_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.skill_id:
self.message = f"Skill not found: {self.skill_id}"
super().__post_init__()
@dataclass
class ArtifactNotFoundError(A2AError):
"""The specified artifact was not found."""
code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
artifact_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.artifact_id:
self.message = f"Artifact not found: {self.artifact_id}"
super().__post_init__()
def create_error_response(
code: int | A2AErrorCode,
message: str | None = None,
data: Any = None,
request_id: str | int | None = None,
) -> dict[str, Any]:
"""Create a JSON-RPC error response.
Args:
code: Error code (A2AErrorCode or int).
message: Optional error message (uses default if not provided).
data: Optional additional error data.
request_id: Request ID for correlation.
Returns:
Dict in JSON-RPC error response format.
"""
error = A2AError(code=int(code), message=message, data=data)
return error.to_response(request_id)
def is_retryable_error(code: int) -> bool:
"""Check if an error is potentially retryable.
Args:
code: Error code to check.
Returns:
True if the error might be resolved by retrying.
"""
retryable_codes = {
A2AErrorCode.INTERNAL_ERROR,
A2AErrorCode.RATE_LIMIT_EXCEEDED,
A2AErrorCode.TASK_TIMEOUT,
}
return code in retryable_codes
def is_client_error(code: int) -> bool:
"""Check if an error is a client-side error.
Args:
code: Error code to check.
Returns:
True if the error is due to client request issues.
"""
client_error_codes = {
A2AErrorCode.JSON_PARSE_ERROR,
A2AErrorCode.INVALID_REQUEST,
A2AErrorCode.METHOD_NOT_FOUND,
A2AErrorCode.INVALID_PARAMS,
A2AErrorCode.TASK_NOT_FOUND,
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
A2AErrorCode.UNSUPPORTED_VERSION,
A2AErrorCode.UNSUPPORTED_EXTENSION,
A2AErrorCode.CONTEXT_NOT_FOUND,
A2AErrorCode.SKILL_NOT_FOUND,
A2AErrorCode.ARTIFACT_NOT_FOUND,
}
return code in client_error_codes
from crewai_a2a.errors import * # noqa: E402, F403

View File

@@ -1,37 +1,13 @@
"""A2A Protocol Extensions for CrewAI.
"""Backward-compatibility shim — use ``crewai_a2a.extensions`` instead."""
This module contains extensions to the A2A (Agent-to-Agent) protocol.
import warnings
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
processes requests and responses during delegation to remote agents. These provide
hooks for tool injection, prompt augmentation, and response processing.
**Server-side extensions** (ServerExtension) allow agents to offer additional
functionality beyond the core A2A specification. Clients activate extensions
via the X-A2A-Extensions header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from crewai.a2a.extensions.base import (
A2AExtension,
ConversationState,
ExtensionRegistry,
ValidatedA2AExtension,
)
from crewai.a2a.extensions.server import (
ExtensionContext,
ServerExtension,
ServerExtensionRegistry,
warnings.warn(
"'crewai.a2a.extensions' has been moved to 'crewai_a2a.extensions'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
__all__ = [
"A2AExtension",
"ConversationState",
"ExtensionContext",
"ExtensionRegistry",
"ServerExtension",
"ServerExtensionRegistry",
"ValidatedA2AExtension",
]
from crewai_a2a.extensions import * # noqa: E402, F403

View File

@@ -1,238 +1,13 @@
"""Base extension interface for CrewAI A2A wrapper processing hooks.
"""Backward-compatibility shim — use ``crewai_a2a.extensions.base`` instead."""
This module defines the protocol for extending CrewAI's A2A wrapper functionality
with custom logic for tool injection, prompt augmentation, and response processing.
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
A2A protocol extensions are capability declarations using AgentExtension objects
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
from pydantic import BeforeValidator
import warnings
if TYPE_CHECKING:
from a2a.types import Message
warnings.warn(
"'crewai.a2a.extensions.base' has been moved to 'crewai_a2a.extensions.base'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai.agent.core import Agent
def _validate_a2a_extension(v: Any) -> Any:
"""Validate that value implements A2AExtension protocol."""
if not isinstance(v, A2AExtension):
raise ValueError(
f"Value must implement A2AExtension protocol. "
f"Got {type(v).__name__} which is missing required methods."
)
return v
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
@runtime_checkable
class ConversationState(Protocol):
"""Protocol for extension-specific conversation state.
Extensions can define their own state classes that implement this protocol
to track conversation-specific data extracted from message history.
"""
def is_ready(self) -> bool:
"""Check if the state indicates readiness for some action.
Returns:
True if the state is ready, False otherwise.
"""
...
@runtime_checkable
class A2AExtension(Protocol):
"""Protocol for A2A wrapper extensions.
Extensions can implement this protocol to inject custom logic into
the A2A conversation flow at various integration points.
Example:
class MyExtension:
def inject_tools(self, agent: Agent) -> None:
# Add custom tools to the agent
pass
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
# Extract state from conversation
return None
def augment_prompt(
self, base_prompt: str, conversation_state: ConversationState | None
) -> str:
# Add custom instructions
return base_prompt
def process_response(
self, agent_response: Any, conversation_state: ConversationState | None
) -> Any:
# Modify response if needed
return agent_response
"""
def inject_tools(self, agent: Agent) -> None:
"""Inject extension-specific tools into the agent.
Called when an agent is wrapped with A2A capabilities. Extensions
can add tools that enable extension-specific functionality.
Args:
agent: The agent instance to inject tools into.
"""
...
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
"""Extract extension-specific state from conversation history.
Called during prompt augmentation to allow extensions to analyze
the conversation history and extract relevant state information.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Extension-specific conversation state, or None if no relevant state.
"""
...
def augment_prompt(
self,
base_prompt: str,
conversation_state: ConversationState | None,
) -> str:
"""Augment the task prompt with extension-specific instructions.
Called during prompt augmentation to allow extensions to add
custom instructions based on conversation state.
Args:
base_prompt: The base prompt to augment.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The augmented prompt with extension-specific instructions.
"""
...
def process_response(
self,
agent_response: Any,
conversation_state: ConversationState | None,
) -> Any:
"""Process and potentially modify the agent response.
Called after parsing the agent's response, allowing extensions to
enhance or modify the response based on conversation state.
Args:
agent_response: The parsed agent response.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The processed agent response (may be modified or original).
"""
...
class ExtensionRegistry:
"""Registry for managing A2A extensions.
Maintains a collection of extensions and provides methods to invoke
their hooks at various integration points.
"""
def __init__(self) -> None:
"""Initialize the extension registry."""
self._extensions: list[A2AExtension] = []
def register(self, extension: A2AExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
"""
self._extensions.append(extension)
def inject_all_tools(self, agent: Agent) -> None:
"""Inject tools from all registered extensions.
Args:
agent: The agent instance to inject tools into.
"""
for extension in self._extensions:
extension.inject_tools(agent)
def extract_all_states(
self, conversation_history: Sequence[Message]
) -> dict[type[A2AExtension], ConversationState]:
"""Extract conversation states from all registered extensions.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Mapping of extension types to their conversation states.
"""
states: dict[type[A2AExtension], ConversationState] = {}
for extension in self._extensions:
state = extension.extract_state_from_history(conversation_history)
if state is not None:
states[type(extension)] = state
return states
def augment_prompt_with_all(
self,
base_prompt: str,
extension_states: dict[type[A2AExtension], ConversationState],
) -> str:
"""Augment prompt with instructions from all registered extensions.
Args:
base_prompt: The base prompt to augment.
extension_states: Mapping of extension types to conversation states.
Returns:
The fully augmented prompt.
"""
augmented = base_prompt
for extension in self._extensions:
state = extension_states.get(type(extension))
augmented = extension.augment_prompt(augmented, state)
return augmented
def process_response_with_all(
self,
agent_response: Any,
extension_states: dict[type[A2AExtension], ConversationState],
) -> Any:
"""Process response through all registered extensions.
Args:
agent_response: The parsed agent response.
extension_states: Mapping of extension types to conversation states.
Returns:
The processed agent response.
"""
processed = agent_response
for extension in self._extensions:
state = extension_states.get(type(extension))
processed = extension.process_response(processed, state)
return processed
from crewai_a2a.extensions.base import * # noqa: E402, F403

View File

@@ -1,170 +1,13 @@
"""A2A Protocol extension utilities.
"""Backward-compatibility shim — use ``crewai_a2a.extensions.registry`` instead."""
This module provides utilities for working with A2A protocol extensions as
defined in the A2A specification. Extensions are capability declarations in
AgentCard.capabilities.extensions using AgentExtension objects, activated
via the X-A2A-Extensions HTTP header.
import warnings
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from typing import Any
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
warnings.warn(
"'crewai.a2a.extensions.registry' has been moved to 'crewai_a2a.extensions.registry'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from a2a.types import AgentCard, AgentExtension
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
def get_extensions_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> list[str]:
"""Extract extension URIs from A2A configuration.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Deduplicated list of extension URIs from all configs.
"""
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[str] = set()
result: list[str] = []
for config in configs:
if not isinstance(config, A2AClientConfig):
continue
for uri in config.extensions:
if uri not in seen:
seen.add(uri)
result.append(uri)
return result
class ExtensionsMiddleware(ClientCallInterceptor):
"""Middleware to add X-A2A-Extensions header to requests.
This middleware adds the extensions header to all outgoing requests,
declaring which A2A protocol extensions the client supports.
"""
def __init__(self, extensions: list[str]) -> None:
"""Initialize with extension URIs.
Args:
extensions: List of extension URIs the client supports.
"""
self._extensions = extensions
async def intercept(
self,
method_name: str,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any],
agent_card: AgentCard | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Add extensions header to the request.
Args:
method_name: The A2A method being called.
request_payload: The JSON-RPC request payload.
http_kwargs: HTTP request kwargs (headers, etc).
agent_card: The target agent's card.
context: Optional call context.
Returns:
Tuple of (request_payload, modified_http_kwargs).
"""
if self._extensions:
headers = http_kwargs.setdefault("headers", {})
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
return request_payload, http_kwargs
def validate_required_extensions(
agent_card: AgentCard,
client_extensions: list[str] | None,
) -> list[AgentExtension]:
"""Validate that client supports all required extensions from agent.
Args:
agent_card: The agent's card with declared extensions.
client_extensions: Extension URIs the client supports.
Returns:
List of unsupported required extensions.
Raises:
None - returns list of unsupported extensions for caller to handle.
"""
unsupported: list[AgentExtension] = []
client_set = set(client_extensions or [])
if not agent_card.capabilities or not agent_card.capabilities.extensions:
return unsupported
unsupported.extend(
ext
for ext in agent_card.capabilities.extensions
if ext.required and ext.uri not in client_set
)
return unsupported
def create_extension_registry_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> ExtensionRegistry:
"""Create an extension registry from A2A client configuration.
Extracts client_extensions from each A2AClientConfig and registers them
with the ExtensionRegistry. These extensions provide CrewAI-specific
processing hooks (tool injection, prompt augmentation, response processing).
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Extension registry with all client_extensions registered.
Example:
class LoggingExtension:
def inject_tools(self, agent): pass
def extract_state_from_history(self, history): return None
def augment_prompt(self, prompt, state): return prompt
def process_response(self, response, state):
print(f"Response: {response}")
return response
config = A2AClientConfig(
endpoint="https://agent.example.com",
client_extensions=[LoggingExtension()],
)
registry = create_extension_registry_from_config(config)
"""
registry = ExtensionRegistry()
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[int] = set()
for config in configs:
if isinstance(config, (A2AConfig, A2AClientConfig)):
client_exts = getattr(config, "client_extensions", [])
for extension in client_exts:
ext_id = id(extension)
if ext_id not in seen:
seen.add(ext_id)
registry.register(extension)
return registry
from crewai_a2a.extensions.registry import * # noqa: E402, F403

View File

@@ -1,305 +1,13 @@
"""A2A protocol server extensions for CrewAI agents.
"""Backward-compatibility shim — use ``crewai_a2a.extensions.server`` instead."""
This module provides the base class and context for implementing A2A protocol
extensions on the server side. Extensions allow agents to offer additional
functionality beyond the core A2A specification.
import warnings
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
warnings.warn(
"'crewai.a2a.extensions.server' has been moved to 'crewai_a2a.extensions.server'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
from typing import TYPE_CHECKING, Annotated, Any
from a2a.types import AgentExtension
from pydantic_core import CoreSchema, core_schema
if TYPE_CHECKING:
from a2a.server.context import ServerCallContext
from pydantic import GetCoreSchemaHandler
logger = logging.getLogger(__name__)
@dataclass
class ExtensionContext:
"""Context passed to extension hooks during request processing.
Provides access to request metadata, client extensions, and shared state
that extensions can read from and write to.
Attributes:
metadata: Request metadata dict, includes extension-namespaced keys.
client_extensions: Set of extension URIs the client declared support for.
state: Mutable dict for extensions to share data during request lifecycle.
server_context: The underlying A2A server call context.
"""
metadata: dict[str, Any]
client_extensions: set[str]
state: dict[str, Any] = field(default_factory=dict)
server_context: ServerCallContext | None = None
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
"""Get extension-specific metadata value.
Extension metadata uses namespaced keys in the format:
"{extension_uri}/{key}"
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
Returns:
The metadata value, or None if not present.
"""
full_key = f"{uri}/{key}"
return self.metadata.get(full_key)
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
"""Set extension-specific metadata value.
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
value: The value to set.
"""
full_key = f"{uri}/{key}"
self.metadata[full_key] = value
class ServerExtension(ABC):
"""Base class for A2A protocol server extensions.
Subclass this to create custom extensions that modify agent behavior
when clients activate them. Extensions are identified by URI and can
be marked as required.
Example:
class SamplingExtension(ServerExtension):
uri = "urn:crewai:ext:sampling/v1"
required = True
def __init__(self, max_tokens: int = 4096):
self.max_tokens = max_tokens
@property
def params(self) -> dict[str, Any]:
return {"max_tokens": self.max_tokens}
async def on_request(self, context: ExtensionContext) -> None:
limit = context.get_extension_metadata(self.uri, "limit")
if limit:
context.state["token_limit"] = int(limit)
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
return result
"""
uri: Annotated[str, "Extension URI identifier. Must be unique."]
required: Annotated[bool, "Whether clients must support this extension."] = False
description: Annotated[
str | None, "Human-readable description of the extension."
] = None
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
"""Tell Pydantic how to validate ServerExtension instances."""
return core_schema.is_instance_schema(cls)
@property
def params(self) -> dict[str, Any] | None:
"""Extension parameters to advertise in AgentCard.
Override this property to expose configuration that clients can read.
Returns:
Dict of parameter names to values, or None.
"""
return None
def agent_extension(self) -> AgentExtension:
"""Generate the AgentExtension object for the AgentCard.
Returns:
AgentExtension with this extension's URI, required flag, and params.
"""
return AgentExtension(
uri=self.uri,
required=self.required if self.required else None,
description=self.description,
params=self.params,
)
def is_active(self, context: ExtensionContext) -> bool:
"""Check if this extension is active for the current request.
An extension is active if the client declared support for it.
Args:
context: The extension context for the current request.
Returns:
True if the client supports this extension.
"""
return self.uri in context.client_extensions
@abstractmethod
async def on_request(self, context: ExtensionContext) -> None:
"""Called before agent execution if extension is active.
Use this hook to:
- Read extension-specific metadata from the request
- Set up state for the execution
- Modify execution parameters via context.state
Args:
context: The extension context with request metadata and state.
"""
...
@abstractmethod
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Called after agent execution if extension is active.
Use this hook to:
- Modify or enhance the result
- Add extension-specific metadata to the response
- Clean up any resources
Args:
context: The extension context with request metadata and state.
result: The agent execution result.
Returns:
The result, potentially modified.
"""
...
class ServerExtensionRegistry:
"""Registry for managing server-side A2A protocol extensions.
Collects extensions and provides methods to generate AgentCapabilities
and invoke extension hooks during request processing.
"""
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
"""Initialize the registry with optional extensions.
Args:
extensions: Initial list of extensions to register.
"""
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
self._by_uri: dict[str, ServerExtension] = {
ext.uri: ext for ext in self._extensions
}
def register(self, extension: ServerExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
Raises:
ValueError: If an extension with the same URI is already registered.
"""
if extension.uri in self._by_uri:
raise ValueError(f"Extension already registered: {extension.uri}")
self._extensions.append(extension)
self._by_uri[extension.uri] = extension
def get_agent_extensions(self) -> list[AgentExtension]:
"""Get AgentExtension objects for all registered extensions.
Returns:
List of AgentExtension objects for the AgentCard.
"""
return [ext.agent_extension() for ext in self._extensions]
def get_extension(self, uri: str) -> ServerExtension | None:
"""Get an extension by URI.
Args:
uri: The extension URI.
Returns:
The extension, or None if not found.
"""
return self._by_uri.get(uri)
@staticmethod
def create_context(
metadata: dict[str, Any],
client_extensions: set[str],
server_context: ServerCallContext | None = None,
) -> ExtensionContext:
"""Create an ExtensionContext for a request.
Args:
metadata: Request metadata dict.
client_extensions: Set of extension URIs from client.
server_context: Optional server call context.
Returns:
ExtensionContext for use in hooks.
"""
return ExtensionContext(
metadata=metadata,
client_extensions=client_extensions,
server_context=server_context,
)
async def invoke_on_request(self, context: ExtensionContext) -> None:
"""Invoke on_request hooks for all active extensions.
Tracks activated extensions and isolates errors from individual hooks.
Args:
context: The extension context for the request.
"""
for extension in self._extensions:
if extension.is_active(context):
try:
await extension.on_request(context)
if context.server_context is not None:
context.server_context.activated_extensions.add(extension.uri)
except Exception:
logger.exception(
"Extension on_request hook failed",
extra={"extension": extension.uri},
)
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Invoke on_response hooks for all active extensions.
Isolates errors from individual hooks to prevent one failing extension
from breaking the entire response.
Args:
context: The extension context for the request.
result: The agent execution result.
Returns:
The result after all extensions have processed it.
"""
processed = result
for extension in self._extensions:
if extension.is_active(context):
try:
processed = await extension.on_response(context, processed)
except Exception:
logger.exception(
"Extension on_response hook failed",
extra={"extension": extension.uri},
)
return processed
from crewai_a2a.extensions.server import * # noqa: E402, F403

View File

@@ -1,480 +1,13 @@
"""Helper functions for processing A2A task results."""
"""Backward-compatibility shim — use ``crewai_a2a.task_helpers`` instead."""
from __future__ import annotations
import warnings
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypedDict
import uuid
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from typing_extensions import NotRequired
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
warnings.warn(
"'crewai.a2a.task_helpers' has been moved to 'crewai_a2a.task_helpers'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
SendMessageEvent = (
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
)
TERMINAL_STATES: frozenset[TaskState] = frozenset(
{
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
}
)
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
{
TaskState.input_required,
TaskState.auth_required,
}
)
PENDING_STATES: frozenset[TaskState] = frozenset(
{
TaskState.submitted,
TaskState.working,
}
)
class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state."""
status: TaskState
history: list[Message]
result: NotRequired[str]
error: NotRequired[str]
agent_card: NotRequired[dict[str, Any]]
a2a_agent_name: NotRequired[str | None]
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
"""Extract result parts from A2A task status message, history, and artifacts.
Args:
a2a_task: A2A Task object with status, history, and artifacts
Returns:
List of result text parts
"""
result_parts: list[str] = []
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
result_parts.extend(
part.root.text for part in msg.parts if part.root.kind == "text"
)
if not result_parts and a2a_task.history:
for history_msg in reversed(a2a_task.history):
if history_msg.role == Role.agent:
result_parts.extend(
part.root.text
for part in history_msg.parts
if part.root.kind == "text"
)
break
if a2a_task.artifacts:
result_parts.extend(
part.root.text
for artifact in a2a_task.artifacts
for part in artifact.parts
if part.root.kind == "text"
)
return result_parts
def extract_error_message(a2a_task: A2ATask, default: str) -> str:
"""Extract error message from A2A task.
Args:
a2a_task: A2A Task object
default: Default message if no error found
Returns:
Error message string
"""
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
if msg:
for part in msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return str(msg)
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
for part in history_msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return default
def process_task_state(
a2a_task: A2ATask,
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
result_parts: list[str] | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
is_final: bool = True,
) -> TaskStateResult | None:
"""Process A2A task state and return result dictionary.
Shared logic for both polling and streaming handlers.
Args:
a2a_task: The A2A task to process.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
turn_number: Current turn number.
is_multiturn: Whether multi-turn conversation.
agent_role: Agent role for logging.
result_parts: Accumulated result parts (streaming passes accumulated,
polling passes None to extract from task).
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
from_task: Optional CrewAI Task for event metadata.
from_agent: Optional CrewAI Agent for event metadata.
is_final: Whether this is the final response in the stream.
Returns:
Result dictionary if terminal/actionable state, None otherwise.
"""
if result_parts is None:
result_parts = []
if a2a_task.status.state == TaskState.completed:
if not result_parts:
extracted_parts = extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = " ".join(result_parts) if result_parts else ""
message_id = None
if a2a_task.status and a2a_task.status.message:
message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=message_id,
is_multiturn=is_multiturn,
status="completed",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.completed,
agent_card=agent_card.model_dump(exclude_none=True),
result=response_text,
history=new_messages,
)
if a2a_task.status.state == TaskState.input_required:
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = extract_error_message(a2a_task, "Additional input required")
if response_text and not a2a_task.history:
agent_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))],
context_id=a2a_task.context_id,
task_id=a2a_task.id,
)
new_messages.append(agent_message)
input_message_id = None
if a2a_task.status and a2a_task.status.message:
input_message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=input_message_id,
is_multiturn=is_multiturn,
status="input_required",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.input_required,
error=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
error_msg = extract_error_message(a2a_task, "Task failed without error message")
if a2a_task.history:
new_messages.extend(a2a_task.history)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.auth_required:
error_msg = extract_error_message(a2a_task, "Authentication required")
return TaskStateResult(
status=TaskState.auth_required,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.canceled:
error_msg = extract_error_message(a2a_task, "Task was canceled")
return TaskStateResult(
status=TaskState.canceled,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state in PENDING_STATES:
return None
return None
async def send_message_and_get_task_id(
event_stream: AsyncIterator[SendMessageEvent],
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
context_id: str | None = None,
) -> str | TaskStateResult:
"""Send message and process initial response.
Handles the common pattern of sending a message and either:
- Getting an immediate Message response (task completed synchronously)
- Getting a Task that needs polling/waiting for completion
Args:
event_stream: Async iterator from client.send_message()
new_messages: List to collect messages (modified in place)
agent_card: The agent card
turn_number: Current turn number
is_multiturn: Whether multi-turn conversation
agent_role: Agent role for logging
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
endpoint: Optional A2A endpoint URL.
a2a_agent_name: Optional A2A agent name.
context_id: Optional A2A context ID for correlation.
Returns:
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
"""
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
]
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=event.context_id,
message_id=event.message_id,
is_multiturn=is_multiturn,
status="completed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)
if isinstance(event, tuple):
a2a_task, _ = event
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
return a2a_task.id
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during send_message: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:
await aclose()
from crewai_a2a.task_helpers import * # noqa: E402, F403

View File

@@ -1,55 +1,13 @@
"""String templates for A2A (Agent-to-Agent) protocol messaging and status."""
"""Backward-compatibility shim — use ``crewai_a2a.templates`` instead."""
from string import Template
from typing import Final
import warnings
AVAILABLE_AGENTS_TEMPLATE: Final[Template] = Template(
"\n<AVAILABLE_A2A_AGENTS>\n $available_a2a_agents\n</AVAILABLE_A2A_AGENTS>\n"
warnings.warn(
"'crewai.a2a.templates' has been moved to 'crewai_a2a.templates'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
PREVIOUS_A2A_CONVERSATION_TEMPLATE: Final[Template] = Template(
"\n<PREVIOUS_A2A_CONVERSATION>\n"
" $previous_a2a_conversation"
"\n</PREVIOUS_A2A_CONVERSATION>\n"
)
CONVERSATION_TURN_INFO_TEMPLATE: Final[Template] = Template(
"\n<CONVERSATION_PROGRESS>\n"
' turn="$turn_count"\n'
' max_turns="$max_turns"\n'
" $warning"
"\n</CONVERSATION_PROGRESS>\n"
)
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE: Final[Template] = Template(
"\n<A2A_AGENTS_STATUS>\n"
" NOTE: A2A agents were configured but are currently unavailable.\n"
" You cannot delegate to remote agents for this task.\n\n"
" Unavailable Agents:\n"
" $unavailable_agents"
"\n</A2A_AGENTS_STATUS>\n"
)
REMOTE_AGENT_COMPLETED_NOTICE: Final[str] = """
<REMOTE_AGENT_STATUS>
STATUS: COMPLETED
The remote agent has finished processing your request. Their response is in the conversation history above.
You MUST now:
1. Extract the answer from the conversation history
2. Set is_a2a=false
3. Return the answer as your final message
DO NOT send another request - the task is already done.
</REMOTE_AGENT_STATUS>
"""
REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """
<REMOTE_AGENT_STATUS>
STATUS: RESPONSE_RECEIVED
The remote agent has responded. Their response is in the conversation history above.
You MUST now:
1. Set is_a2a=false (the remote task is complete and cannot receive more messages)
2. Provide YOUR OWN response to the original task based on the information received
IMPORTANT: Your response should be addressed to the USER who gave you the original task.
Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that...").
Do NOT address the remote agent directly or use "you" to refer to them.
</REMOTE_AGENT_STATUS>
"""
from crewai_a2a.templates import * # noqa: E402, F403

View File

@@ -1,104 +1,13 @@
"""Type definitions for A2A protocol message parts."""
"""Backward-compatibility shim — use ``crewai_a2a.types`` instead."""
from __future__ import annotations
import warnings
from typing import (
Annotated,
Any,
Literal,
Protocol,
TypedDict,
runtime_checkable,
warnings.warn(
"'crewai.a2a.types' has been moved to 'crewai_a2a.types'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired
try:
from crewai.a2a.updates import (
PollingConfig,
PollingHandler,
PushNotificationConfig,
PushNotificationHandler,
StreamingConfig,
StreamingHandler,
UpdateConfig,
)
except ImportError:
PollingConfig = Any # type: ignore[misc,assignment]
PollingHandler = Any # type: ignore[misc,assignment]
PushNotificationConfig = Any # type: ignore[misc,assignment]
PushNotificationHandler = Any # type: ignore[misc,assignment]
StreamingConfig = Any # type: ignore[misc,assignment]
StreamingHandler = Any # type: ignore[misc,assignment]
UpdateConfig = Any # type: ignore[misc,assignment]
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
ProtocolVersion = Literal[
"0.2.0",
"0.2.1",
"0.2.2",
"0.2.3",
"0.2.4",
"0.2.5",
"0.2.6",
"0.3.0",
"0.4.0",
]
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
Url = Annotated[
str,
BeforeValidator(
lambda value: str(http_url_adapter.validate_python(value, strict=True))
),
]
@runtime_checkable
class AgentResponseProtocol(Protocol):
"""Protocol for the dynamically created AgentResponse model."""
a2a_ids: tuple[str, ...]
message: str
is_a2a: bool
class PartsMetadataDict(TypedDict, total=False):
"""Metadata for A2A message parts.
Attributes:
mimeType: MIME type for the part content.
schema: JSON schema for the part content.
"""
mimeType: Literal["application/json"]
schema: dict[str, Any]
class PartsDict(TypedDict):
"""A2A message part containing text and optional metadata.
Attributes:
text: The text content of the message part.
metadata: Optional metadata describing the part content.
"""
text: str
metadata: NotRequired[PartsMetadataDict]
PollingHandlerType = type[PollingHandler]
StreamingHandlerType = type[StreamingHandler]
PushNotificationHandlerType = type[PushNotificationHandler]
HandlerType = PollingHandlerType | StreamingHandlerType | PushNotificationHandlerType
HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
PollingConfig: PollingHandler,
StreamingConfig: StreamingHandler,
PushNotificationConfig: PushNotificationHandler,
}
from crewai_a2a.types import * # noqa: E402, F403

View File

@@ -1,35 +1,13 @@
"""A2A update mechanism configuration types."""
"""Backward-compatibility shim — use ``crewai_a2a.updates`` instead."""
from crewai.a2a.updates.base import (
BaseHandlerKwargs,
PollingHandlerKwargs,
PushNotificationHandlerKwargs,
PushNotificationResultStore,
StreamingHandlerKwargs,
UpdateHandler,
import warnings
warnings.warn(
"'crewai.a2a.updates' has been moved to 'crewai_a2a.updates'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai.a2a.updates.polling.config import PollingConfig
from crewai.a2a.updates.polling.handler import PollingHandler
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
from crewai.a2a.updates.streaming.config import StreamingConfig
from crewai.a2a.updates.streaming.handler import StreamingHandler
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
__all__ = [
"BaseHandlerKwargs",
"PollingConfig",
"PollingHandler",
"PollingHandlerKwargs",
"PushNotificationConfig",
"PushNotificationHandler",
"PushNotificationHandlerKwargs",
"PushNotificationResultStore",
"StreamingConfig",
"StreamingHandler",
"StreamingHandlerKwargs",
"UpdateConfig",
"UpdateHandler",
]
from crewai_a2a.updates import * # noqa: E402, F403

View File

@@ -1,176 +1,13 @@
"""Base types for A2A update mechanism handlers."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.base`` instead."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
import warnings
class CommonParams(NamedTuple):
"""Common parameters shared across all update handlers.
warnings.warn(
"'crewai.a2a.updates.base' has been moved to 'crewai_a2a.updates.base'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
Groups the frequently-passed parameters to reduce duplication.
"""
turn_number: int
is_multiturn: bool
agent_role: str | None
endpoint: str
a2a_agent_name: str | None
context_id: str | None
from_task: Any
from_agent: Any
if TYPE_CHECKING:
from a2a.client import Client
from a2a.types import AgentCard, Message, Task
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
class BaseHandlerKwargs(TypedDict, total=False):
"""Base kwargs shared by all handlers."""
turn_number: int
is_multiturn: bool
agent_role: str | None
context_id: str | None
task_id: str | None
endpoint: str | None
agent_branch: Any
a2a_agent_name: str | None
from_task: Any
from_agent: Any
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for polling handler."""
polling_interval: float
polling_timeout: float
history_length: int
max_polls: int | None
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for streaming handler."""
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for push notification handler."""
config: PushNotificationConfig
result_store: PushNotificationResultStore
polling_timeout: float
polling_interval: float
class PushNotificationResultStore(Protocol):
"""Protocol for storing and retrieving push notification results.
This protocol defines the interface for a result store that the
PushNotificationHandler uses to wait for task completion.
"""
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
return core_schema.any_schema()
async def wait_for_result(
self,
task_id: str,
timeout: float,
poll_interval: float = 1.0,
) -> Task | None:
"""Wait for a task result to be available.
Args:
task_id: The task ID to wait for.
timeout: Max seconds to wait before returning None.
poll_interval: Seconds between polling attempts.
Returns:
The completed Task object, or None if timeout.
"""
...
async def get_result(self, task_id: str) -> Task | None:
"""Get a task result if available.
Args:
task_id: The task ID to retrieve.
Returns:
The Task object if available, None otherwise.
"""
...
async def store_result(self, task: Task) -> None:
"""Store a task result.
Args:
task: The Task object to store.
"""
...
class UpdateHandler(Protocol):
"""Protocol for A2A update mechanism handlers."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Any,
) -> TaskStateResult:
"""Execute the update mechanism and return result.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
**kwargs: Additional handler-specific parameters.
Returns:
Result dictionary with status, result/error, and history.
"""
...
def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams:
"""Extract common parameters from handler kwargs.
Args:
kwargs: Handler kwargs dict.
Returns:
CommonParams with extracted values.
Raises:
ValueError: If endpoint is not provided.
"""
endpoint = kwargs.get("endpoint")
if endpoint is None:
raise ValueError("endpoint is required for update handlers")
return CommonParams(
turn_number=kwargs.get("turn_number", 0),
is_multiturn=kwargs.get("is_multiturn", False),
agent_role=kwargs.get("agent_role"),
endpoint=endpoint,
a2a_agent_name=kwargs.get("a2a_agent_name"),
context_id=kwargs.get("context_id"),
from_task=kwargs.get("from_task"),
from_agent=kwargs.get("from_agent"),
)
from crewai_a2a.updates.base import * # noqa: E402, F403

View File

@@ -1 +1,13 @@
"""Polling update mechanism module."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.polling`` instead."""
import warnings
warnings.warn(
"'crewai.a2a.updates.polling' has been moved to 'crewai_a2a.updates.polling'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai_a2a.updates.polling import * # noqa: E402, F403

View File

@@ -1,25 +1,13 @@
"""Polling update mechanism configuration."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.polling.config`` instead."""
from __future__ import annotations
from pydantic import BaseModel, Field
import warnings
class PollingConfig(BaseModel):
"""Configuration for polling-based task updates.
warnings.warn(
"'crewai.a2a.updates.polling.config' has been moved to 'crewai_a2a.updates.polling.config'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
Attributes:
interval: Seconds between poll attempts.
timeout: Max seconds to poll before raising timeout error.
max_polls: Max number of poll attempts.
history_length: Number of messages to retrieve per poll.
"""
interval: float = Field(
default=2.0, gt=0, description="Seconds between poll attempts"
)
timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll")
max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts")
history_length: int = Field(
default=100, gt=0, description="Messages to retrieve per poll"
)
from crewai_a2a.updates.polling.config import * # noqa: E402, F403

View File

@@ -1,359 +1,13 @@
"""Polling update mechanism handler."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.polling.handler`` instead."""
from __future__ import annotations
import warnings
import asyncio
import time
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskQueryParams,
TaskState,
TextPart,
)
from typing_extensions import Unpack
from crewai.a2a.errors import A2APollingTimeoutError
from crewai.a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
send_message_and_get_task_id,
)
from crewai.a2a.updates.base import PollingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2APollingStartedEvent,
A2APollingStatusEvent,
A2AResponseReceivedEvent,
warnings.warn(
"'crewai.a2a.updates.polling.handler' has been moved to 'crewai_a2a.updates.polling.handler'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
async def _poll_task_until_complete(
client: Client,
task_id: str,
polling_interval: float,
polling_timeout: float,
agent_branch: Any | None = None,
history_length: int = 100,
max_polls: int | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
context_id: str | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> A2ATask:
"""Poll task status until terminal state reached.
Args:
client: A2A client instance.
task_id: Task ID to poll.
polling_interval: Seconds between poll attempts.
polling_timeout: Max seconds before timeout.
agent_branch: Agent tree branch for logging.
history_length: Number of messages to retrieve per poll.
max_polls: Max number of poll attempts (None = unlimited).
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
context_id: A2A context ID for correlation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
Returns:
Final task object in terminal state.
Raises:
A2APollingTimeoutError: If polling exceeds timeout or max_polls.
"""
start_time = time.monotonic()
poll_count = 0
while True:
poll_count += 1
task = await client.get_task(
TaskQueryParams(id=task_id, history_length=history_length)
)
elapsed = time.monotonic() - start_time
effective_context_id = task.context_id or context_id
crewai_event_bus.emit(
agent_branch,
A2APollingStatusEvent(
task_id=task_id,
context_id=effective_context_id,
state=str(task.status.state.value),
elapsed_seconds=elapsed,
poll_count=poll_count,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
return task
if elapsed > polling_timeout:
raise A2APollingTimeoutError(
f"Polling timeout after {polling_timeout}s ({poll_count} polls)"
)
if max_polls and poll_count >= max_polls:
raise A2APollingTimeoutError(
f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s"
)
await asyncio.sleep(polling_interval)
class PollingHandler:
"""Polling-based update handler."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[PollingHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using polling for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Polling-specific parameters.
Returns:
Dictionary with status, result/error, and history.
"""
polling_interval = kwargs.get("polling_interval", 2.0)
polling_timeout = kwargs.get("polling_timeout", 300.0)
endpoint = kwargs.get("endpoint", "")
agent_branch = kwargs.get("agent_branch")
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
history_length = kwargs.get("history_length", 100)
max_polls = kwargs.get("max_polls")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
a2a_agent_name = kwargs.get("a2a_agent_name")
from_task = kwargs.get("from_task")
from_agent = kwargs.get("from_agent")
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=context_id,
)
if not isinstance(result_or_task_id, str):
return result_or_task_id
task_id = result_or_task_id
crewai_event_bus.emit(
agent_branch,
A2APollingStartedEvent(
task_id=task_id,
context_id=context_id,
polling_interval=polling_interval,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
final_task = await _poll_task_until_complete(
client=client,
task_id=task_id,
polling_interval=polling_interval,
polling_timeout=polling_timeout,
agent_branch=agent_branch,
history_length=history_length,
max_polls=max_polls,
from_task=from_task,
from_agent=from_agent,
context_id=context_id,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
except A2APollingTimeoutError as e:
error_msg = str(e)
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="polling",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during polling: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="polling",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
from crewai_a2a.updates.polling.handler import * # noqa: E402, F403

View File

@@ -1 +1,13 @@
"""Push notification update mechanism module."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications`` instead."""
import warnings
warnings.warn(
"'crewai.a2a.updates.push_notifications' has been moved to 'crewai_a2a.updates.push_notifications'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai_a2a.updates.push_notifications import * # noqa: E402, F403

View File

@@ -1,65 +1,13 @@
"""Push notification update mechanism configuration."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications.config`` instead."""
from __future__ import annotations
from typing import Annotated
from a2a.types import PushNotificationAuthenticationInfo
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
from crewai.a2a.updates.base import PushNotificationResultStore
from crewai.a2a.updates.push_notifications.signature import WebhookSignatureConfig
import warnings
def _coerce_signature(
value: str | WebhookSignatureConfig | None,
) -> WebhookSignatureConfig | None:
"""Convert string secret to WebhookSignatureConfig."""
if value is None:
return None
if isinstance(value, str):
return WebhookSignatureConfig.hmac_sha256(secret=value)
return value
warnings.warn(
"'crewai.a2a.updates.push_notifications.config' has been moved to 'crewai_a2a.updates.push_notifications.config'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
SignatureInput = Annotated[
WebhookSignatureConfig | None,
BeforeValidator(_coerce_signature),
]
class PushNotificationConfig(BaseModel):
"""Configuration for webhook-based task updates.
Attributes:
url: Callback URL where agent sends push notifications.
id: Unique identifier for this config.
token: Token to validate incoming notifications.
authentication: Auth info for agent to use when calling webhook.
timeout: Max seconds to wait for task completion.
interval: Seconds between result polling attempts.
result_store: Store for receiving push notification results.
signature: HMAC signature config. Pass a string (secret) for defaults,
or WebhookSignatureConfig for custom settings.
"""
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
id: str | None = Field(default=None, description="Unique config identifier")
token: str | None = Field(default=None, description="Validation token")
authentication: PushNotificationAuthenticationInfo | None = Field(
default=None, description="Auth info for agent to use when calling webhook"
)
timeout: float | None = Field(
default=300.0, gt=0, description="Max seconds to wait for task completion"
)
interval: float = Field(
default=2.0, gt=0, description="Seconds between result polling attempts"
)
result_store: PushNotificationResultStore | None = Field(
default=None, description="Result store for push notification handling"
)
signature: SignatureInput = Field(
default=None,
description="HMAC signature config. Pass a string (secret) for simple usage, "
"or WebhookSignatureConfig for custom headers/tolerance.",
)
from crewai_a2a.updates.push_notifications.config import * # noqa: E402, F403

View File

@@ -1,354 +1,13 @@
"""Push notification (webhook) update mechanism handler."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications.handler`` instead."""
from __future__ import annotations
import warnings
import logging
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskState,
TextPart,
)
from typing_extensions import Unpack
from crewai.a2a.task_helpers import (
TaskStateResult,
process_task_state,
send_message_and_get_task_id,
)
from crewai.a2a.updates.base import (
CommonParams,
PushNotificationHandlerKwargs,
PushNotificationResultStore,
extract_common_params,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2APushNotificationRegisteredEvent,
A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
warnings.warn(
"'crewai.a2a.updates.push_notifications.handler' has been moved to 'crewai_a2a.updates.push_notifications.handler'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
logger = logging.getLogger(__name__)
def _handle_push_error(
error: Exception,
error_msg: str,
error_type: str,
new_messages: list[Message],
agent_branch: Any | None,
params: CommonParams,
task_id: str | None,
status_code: int | None = None,
) -> TaskStateResult:
"""Handle push notification errors with consistent event emission.
Args:
error: The exception that occurred.
error_msg: Formatted error message for the result.
error_type: Type of error for the event.
new_messages: List to append error message to.
agent_branch: Agent tree branch for events.
params: Common handler parameters.
task_id: A2A task ID.
status_code: HTTP status code if applicable.
Returns:
TaskStateResult with failed status.
"""
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(error),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
async def _wait_for_push_result(
task_id: str,
result_store: PushNotificationResultStore,
timeout: float,
poll_interval: float,
agent_branch: Any | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
context_id: str | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> A2ATask | None:
"""Wait for push notification result.
Args:
task_id: Task ID to wait for.
result_store: Store to retrieve results from.
timeout: Max seconds to wait.
poll_interval: Seconds between polling attempts.
agent_branch: Agent tree branch for logging.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
context_id: A2A context ID for correlation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent.
Returns:
Final task object, or None if timeout.
"""
task = await result_store.wait_for_result(
task_id=task_id,
timeout=timeout,
poll_interval=poll_interval,
)
if task is None:
crewai_event_bus.emit(
agent_branch,
A2APushNotificationTimeoutEvent(
task_id=task_id,
context_id=context_id,
timeout_seconds=timeout,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return task
class PushNotificationHandler:
"""Push notification (webhook) based update handler."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[PushNotificationHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using push notifications for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Push notification-specific parameters.
Returns:
Dictionary with status, result/error, and history.
Raises:
ValueError: If result_store or config not provided.
"""
config = kwargs.get("config")
result_store = kwargs.get("result_store")
polling_timeout = kwargs.get("polling_timeout", 300.0)
polling_interval = kwargs.get("polling_interval", 2.0)
agent_branch = kwargs.get("agent_branch")
task_id = kwargs.get("task_id")
params = extract_common_params(kwargs)
if config is None:
error_msg = (
"PushNotificationConfig is required for push notification handler"
)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
if result_store is None:
error_msg = (
"PushNotificationResultStore is required for push notification handler"
)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=params.context_id,
)
if not isinstance(result_or_task_id, str):
return result_or_task_id
task_id = result_or_task_id
crewai_event_bus.emit(
agent_branch,
A2APushNotificationRegisteredEvent(
task_id=task_id,
context_id=params.context_id,
callback_url=str(config.url),
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
logger.debug(
"Push notification callback for task %s configured at %s (via initial request)",
task_id,
config.url,
)
final_task = await _wait_for_push_result(
task_id=task_id,
result_store=result_store,
timeout=polling_timeout,
poll_interval=polling_interval,
agent_branch=agent_branch,
from_task=params.from_task,
from_agent=params.from_agent,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
)
if final_task is None:
return TaskStateResult(
status=TaskState.failed,
error=f"Push notification timeout after {polling_timeout}s",
history=new_messages,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
except A2AClientHTTPError as e:
return _handle_push_error(
error=e,
error_msg=f"HTTP Error {e.status_code}: {e!s}",
error_type="http_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
status_code=e.status_code,
)
except Exception as e:
return _handle_push_error(
error=e,
error_msg=f"Unexpected error during push notification: {e!s}",
error_type="unexpected_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
)
from crewai_a2a.updates.push_notifications.handler import * # noqa: E402, F403

View File

@@ -1,87 +1,13 @@
"""Webhook signature configuration for push notifications."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications.signature`` instead."""
from __future__ import annotations
from enum import Enum
import secrets
from pydantic import BaseModel, Field, SecretStr
import warnings
class WebhookSignatureMode(str, Enum):
"""Signature mode for webhook push notifications."""
warnings.warn(
"'crewai.a2a.updates.push_notifications.signature' has been moved to 'crewai_a2a.updates.push_notifications.signature'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
NONE = "none"
HMAC_SHA256 = "hmac_sha256"
class WebhookSignatureConfig(BaseModel):
"""Configuration for webhook signature verification.
Provides cryptographic integrity verification and replay attack protection
for A2A push notifications.
Attributes:
mode: Signature mode (none or hmac_sha256).
secret: Shared secret for HMAC computation (required for hmac_sha256 mode).
timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection.
header_name: HTTP header name for the signature.
timestamp_header_name: HTTP header name for the timestamp.
"""
mode: WebhookSignatureMode = Field(
default=WebhookSignatureMode.NONE,
description="Signature verification mode",
)
secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC computation",
)
timestamp_tolerance_seconds: int = Field(
default=300,
ge=0,
description="Max allowed timestamp age in seconds (5 min default)",
)
header_name: str = Field(
default="X-A2A-Signature",
description="HTTP header name for the signature",
)
timestamp_header_name: str = Field(
default="X-A2A-Signature-Timestamp",
description="HTTP header name for the timestamp",
)
@classmethod
def generate_secret(cls, length: int = 32) -> str:
"""Generate a cryptographically secure random secret.
Args:
length: Number of random bytes to generate (default 32).
Returns:
URL-safe base64-encoded secret string.
"""
return secrets.token_urlsafe(length)
@classmethod
def hmac_sha256(
cls,
secret: str | SecretStr,
timestamp_tolerance_seconds: int = 300,
) -> WebhookSignatureConfig:
"""Create an HMAC-SHA256 signature configuration.
Args:
secret: Shared secret for HMAC computation.
timestamp_tolerance_seconds: Max allowed timestamp age in seconds.
Returns:
Configured WebhookSignatureConfig for HMAC-SHA256.
"""
if isinstance(secret, str):
secret = SecretStr(secret)
return cls(
mode=WebhookSignatureMode.HMAC_SHA256,
secret=secret,
timestamp_tolerance_seconds=timestamp_tolerance_seconds,
)
from crewai_a2a.updates.push_notifications.signature import * # noqa: E402, F403

View File

@@ -1 +1,13 @@
"""Streaming update mechanism module."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming`` instead."""
import warnings
warnings.warn(
"'crewai.a2a.updates.streaming' has been moved to 'crewai_a2a.updates.streaming'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai_a2a.updates.streaming import * # noqa: E402, F403

View File

@@ -1,9 +1,13 @@
"""Streaming update mechanism configuration."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming.config`` instead."""
from __future__ import annotations
from pydantic import BaseModel
import warnings
class StreamingConfig(BaseModel):
"""Configuration for SSE-based task updates."""
warnings.warn(
"'crewai.a2a.updates.streaming.config' has been moved to 'crewai_a2a.updates.streaming.config'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai_a2a.updates.streaming.config import * # noqa: E402, F403

View File

@@ -1,646 +1,13 @@
"""Streaming (SSE) update mechanism handler."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming.handler`` instead."""
from __future__ import annotations
import warnings
import asyncio
import logging
from typing import Final
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskQueryParams,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from typing_extensions import Unpack
from crewai.a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
)
from crewai.a2a.updates.base import StreamingHandlerKwargs, extract_common_params
from crewai.a2a.updates.streaming.params import (
process_status_update,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AArtifactReceivedEvent,
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
A2AStreamingChunkEvent,
A2AStreamingStartedEvent,
warnings.warn(
"'crewai.a2a.updates.streaming.handler' has been moved to 'crewai_a2a.updates.streaming.handler'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
logger = logging.getLogger(__name__)
MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
class StreamingHandler:
"""SSE streaming-based update handler."""
@staticmethod
async def _try_recover_from_interruption( # type: ignore[misc]
client: Client,
task_id: str,
new_messages: list[Message],
agent_card: AgentCard,
result_parts: list[str],
**kwargs: Unpack[StreamingHandlerKwargs],
) -> TaskStateResult | None:
"""Attempt to recover from a stream interruption by checking task state.
If the task completed while we were disconnected, returns the result.
If the task is still running, attempts to resubscribe and continue.
Args:
client: A2A client instance.
task_id: The task ID to recover.
new_messages: List of collected messages.
agent_card: The agent card.
result_parts: Accumulated result text parts.
**kwargs: Handler parameters.
Returns:
TaskStateResult if recovery succeeded (task finished or resubscribe worked).
None if recovery not possible (caller should handle failure).
Note:
When None is returned, recovery failed and the original exception should
be handled by the caller. All recovery attempts are logged.
"""
params = extract_common_params(kwargs) # type: ignore[arg-type]
try:
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
if a2a_task.status.state in TERMINAL_STATES:
logger.info(
"Task completed during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if a2a_task.status.state in ACTIONABLE_STATES:
logger.info(
"Task in actionable state during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=False,
)
logger.info(
"Task still running, attempting resubscribe",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS):
try:
backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt)
if attempt > 0:
await asyncio.sleep(backoff)
event_stream = client.resubscribe(TaskIdParams(id=task_id))
async for event in event_stream:
if isinstance(event, tuple):
resubscribed_task, update = event
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
if (
is_final_update
or resubscribed_task.status.state
in TERMINAL_STATES | ACTIONABLE_STATES
):
return process_task_state(
a2a_task=resubscribed_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
elif isinstance(event, Message):
new_messages.append(event)
result_parts.extend(
part.root.text
for part in event.parts
if part.root.kind == "text"
)
final_task = await client.get_task(TaskQueryParams(id=task_id))
return process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
except Exception as resubscribe_error: # noqa: PERF203
logger.warning(
"Resubscribe attempt failed",
extra={
"task_id": task_id,
"attempt": attempt + 1,
"max_attempts": MAX_RESUBSCRIBE_ATTEMPTS,
"error": str(resubscribe_error),
},
)
if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1:
return None
except Exception as e:
logger.warning(
"Failed to recover from stream interruption due to unexpected error",
extra={
"task_id": task_id,
"error": str(e),
"error_type": type(e).__name__,
},
exc_info=True,
)
return None
logger.warning(
"Recovery exhausted all resubscribe attempts without success",
extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS},
)
return None
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[StreamingHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using SSE streaming for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Streaming-specific parameters.
Returns:
Dictionary with status, result/error, and history.
"""
task_id = kwargs.get("task_id")
agent_branch = kwargs.get("agent_branch")
params = extract_common_params(kwargs)
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
chunk_index = 0
current_task_id: str | None = task_id
crewai_event_bus.emit(
agent_branch,
A2AStreamingStartedEvent(
task_id=task_id,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
try:
async for event in event_stream:
if isinstance(event, tuple):
a2a_task, _ = event
current_task_id = a2a_task.id
if isinstance(event, Message):
new_messages.append(event)
message_context_id = event.context_id or params.context_id
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
crewai_event_bus.emit(
agent_branch,
A2AStreamingChunkEvent(
task_id=event.task_id or task_id,
context_id=message_context_id,
chunk=text,
chunk_index=chunk_index,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
chunk_index += 1
elif isinstance(event, tuple):
a2a_task, update = event
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
artifact_size = None
if artifact.parts:
artifact_size = sum(
len(p.root.text.encode())
if p.root.kind == "text"
else len(getattr(p.root, "data", b""))
for p in artifact.parts
)
effective_context_id = a2a_task.context_id or params.context_id
crewai_event_bus.emit(
agent_branch,
A2AArtifactReceivedEvent(
task_id=a2a_task.id,
artifact_id=artifact.artifact_id,
artifact_name=artifact.name,
artifact_description=artifact.description,
mime_type=artifact.parts[0].root.kind
if artifact.parts
else None,
size_bytes=artifact_size,
append=update.append or False,
last_chunk=update.last_chunk or False,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=effective_context_id,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if (
not is_final_update
and a2a_task.status.state
not in TERMINAL_STATES | ACTIONABLE_STATES
):
continue
final_result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
if final_result:
break
except A2AClientHTTPError as e:
if current_task_id:
logger.info(
"Stream interrupted with HTTP error, attempting recovery",
extra={
"task_id": current_task_id,
"error": str(e),
"status_code": e.status_code,
},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
"Successfully recovered task after HTTP error",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
"Failed to recover from HTTP error, returning failure",
extra={
"task_id": current_task_id,
"status_code": e.status_code,
"original_error": str(e),
},
)
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_type = "http_error"
status_code = e.status_code
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
error_type = type(e).__name__.lower()
if current_task_id:
logger.info(
f"Stream interrupted with {error_type}, attempting recovery",
extra={"task_id": current_task_id, "error": str(e)},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
f"Successfully recovered task after {error_type}",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
f"Failed to recover from {error_type}, returning failure",
extra={
"task_id": current_task_id,
"error_type": error_type,
"original_error": str(e),
},
)
error_msg = f"Connection error during streaming: {e!s}"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
logger.exception(
"Unexpected error during streaming",
extra={
"task_id": current_task_id,
"error_type": type(e).__name__,
"endpoint": params.endpoint,
},
)
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
error_type = "unexpected_error"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:
try:
await aclose()
except Exception as close_error:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(close_error),
error_type="stream_close_error",
a2a_agent_name=params.a2a_agent_name,
operation="stream_close",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
if final_result:
return final_result
return TaskStateResult(
status=TaskState.completed,
result=" ".join(result_parts) if result_parts else "",
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)
from crewai_a2a.updates.streaming.handler import * # noqa: E402, F403

View File

@@ -1,28 +1,13 @@
"""Common parameter extraction for streaming handlers."""
"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming.params`` instead."""
from __future__ import annotations
from a2a.types import TaskStatusUpdateEvent
import warnings
def process_status_update(
update: TaskStatusUpdateEvent,
result_parts: list[str],
) -> bool:
"""Process a status update event and extract text parts.
warnings.warn(
"'crewai.a2a.updates.streaming.params' has been moved to 'crewai_a2a.updates.streaming.params'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
Args:
update: The status update event.
result_parts: List to append text parts to (modified in place).
Returns:
True if this is a final update, False otherwise.
"""
is_final = update.final
if update.status and update.status.message and update.status.message.parts:
result_parts.extend(
part.root.text
for part in update.status.message.parts
if part.root.kind == "text" and part.root.text
)
return is_final
from crewai_a2a.updates.streaming.params import * # noqa: E402, F403

View File

@@ -1 +1,13 @@
"""A2A utility modules for client operations."""
"""Backward-compatibility shim — use ``crewai_a2a.utils`` instead."""
import warnings
warnings.warn(
"'crewai.a2a.utils' has been moved to 'crewai_a2a.utils'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai_a2a.utils import * # noqa: E402, F403

View File

@@ -1,586 +1,13 @@
"""AgentCard utilities for A2A client and server operations."""
"""Backward-compatibility shim — use ``crewai_a2a.utils.agent_card`` instead."""
from __future__ import annotations
import warnings
import asyncio
from collections.abc import MutableMapping
from functools import lru_cache
import ssl
import time
from types import MethodType
from typing import TYPE_CHECKING
from a2a.client.errors import A2AClientHTTPError
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
import httpx
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
configure_auth_client,
retry_on_401,
)
from crewai.a2a.config import A2AServerConfig
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AAgentCardFetchedEvent,
A2AAuthenticationFailedEvent,
A2AConnectionErrorEvent,
warnings.warn(
"'crewai.a2a.utils.agent_card' has been moved to 'crewai_a2a.utils.agent_card'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
if TYPE_CHECKING:
from crewai.a2a.auth.client_schemes import ClientAuthScheme
from crewai.agent import Agent
from crewai.task import Task
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
"""Get TLS verify parameter from auth scheme.
Args:
auth: Optional authentication scheme with TLS config.
Returns:
SSL context, CA cert path, True for default verification,
or False if verification disabled.
"""
if auth and auth.tls:
return auth.tls.get_httpx_ssl_context()
return True
async def _prepare_auth_headers(
auth: ClientAuthScheme | None,
timeout: int,
) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]:
"""Prepare authentication headers and TLS verification settings.
Args:
auth: Optional authentication scheme.
timeout: Request timeout in seconds.
Returns:
Tuple of (headers dict, TLS verify setting).
"""
headers: MutableMapping[str, str] = {}
verify = _get_tls_verify(auth)
if auth:
async with httpx.AsyncClient(
timeout=timeout, verify=verify
) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
return headers, verify
def _get_server_config(agent: Agent) -> A2AServerConfig | None:
"""Get A2AServerConfig from an agent's a2a configuration.
Args:
agent: The Agent instance to check.
Returns:
A2AServerConfig if present, None otherwise.
"""
if agent.a2a is None:
return None
if isinstance(agent.a2a, A2AServerConfig):
return agent.a2a
if isinstance(agent.a2a, list):
for config in agent.a2a:
if isinstance(config, A2AServerConfig):
return config
return None
def fetch_agent_card(
endpoint: str,
auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
cache_ttl: int = 300,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint with optional caching.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).
Returns:
AgentCard object with agent capabilities and skills.
Raises:
httpx.HTTPStatusError: If the request fails.
A2AClientHTTPError: If authentication fails.
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
async def afetch_agent_card(
endpoint: str,
auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint asynchronously.
Native async implementation. Use this when running in an async context.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
Returns:
AgentCard object with agent capabilities and skills.
Raises:
httpx.HTTPStatusError: If the request fails.
A2AClientHTTPError: If authentication fails.
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)
return agent_card
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
@lru_cache()
def _fetch_agent_card_cached(
endpoint: str,
auth_hash: str,
timeout: int,
_ttl_hash: int,
) -> AgentCard:
"""Cached sync version of fetch_agent_card."""
auth = _auth_store.get(auth_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
_afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
async def _afetch_agent_card_cached(
endpoint: str,
auth_hash: str,
timeout: int,
) -> AgentCard:
"""Cached async implementation of AgentCard fetching."""
auth = _auth_store.get(auth_hash)
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
async def _afetch_agent_card_impl(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
) -> AgentCard:
"""Internal async implementation of AgentCard fetching."""
start_time = time.perf_counter()
if "/.well-known/agent-card.json" in endpoint:
base_url = endpoint.replace("/.well-known/agent-card.json", "")
agent_card_path = "/.well-known/agent-card.json"
else:
url_parts = endpoint.split("/", 3)
base_url = f"{url_parts[0]}//{url_parts[2]}"
agent_card_path = (
f"/{url_parts[3]}"
if len(url_parts) > 3 and url_parts[3]
else "/.well-known/agent-card.json"
)
headers, verify = await _prepare_auth_headers(auth, timeout)
async with httpx.AsyncClient(
timeout=timeout, headers=headers, verify=verify
) as temp_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_client)
agent_card_url = f"{base_url}{agent_card_path}"
async def _fetch_agent_card_request() -> httpx.Response:
return await temp_client.get(agent_card_url)
try:
response = await retry_on_401(
request_func=_fetch_agent_card_request,
auth_scheme=auth,
client=temp_client,
headers=temp_client.headers,
max_retries=2,
)
response.raise_for_status()
agent_card = AgentCard.model_validate(response.json())
fetch_time_ms = (time.perf_counter() - start_time) * 1000
agent_card_dict = agent_card.model_dump(exclude_none=True)
crewai_event_bus.emit(
None,
A2AAgentCardFetchedEvent(
endpoint=endpoint,
a2a_agent_name=agent_card.name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
cached=False,
fetch_time_ms=fetch_time_ms,
),
)
return agent_card
except httpx.HTTPStatusError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
response_body = e.response.text[:1000] if e.response.text else None
if e.response.status_code == 401:
error_details = ["Authentication failed"]
www_auth = e.response.headers.get("WWW-Authenticate")
if www_auth:
error_details.append(f"WWW-Authenticate: {www_auth}")
if not auth:
error_details.append("No auth scheme provided")
msg = " | ".join(error_details)
auth_type = type(auth).__name__ if auth else None
crewai_event_bus.emit(
None,
A2AAuthenticationFailedEvent(
endpoint=endpoint,
auth_type=auth_type,
error=msg,
status_code=401,
metadata={
"elapsed_ms": elapsed_ms,
"response_body": response_body,
"www_authenticate": www_auth,
"request_url": str(e.request.url),
},
),
)
raise A2AClientHTTPError(401, msg) from e
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="http_error",
status_code=e.response.status_code,
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"response_body": response_body,
"request_url": str(e.request.url),
},
),
)
raise
except httpx.TimeoutException as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="timeout",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"timeout_config": timeout,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
except httpx.ConnectError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="connection_error",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
except httpx.RequestError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="request_error",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
def _task_to_skill(task: Task) -> AgentSkill:
"""Convert a CrewAI Task to an A2A AgentSkill.
Args:
task: The CrewAI Task to convert.
Returns:
AgentSkill representing the task's capability.
"""
task_name = task.name or task.description[:50]
task_id = task_name.lower().replace(" ", "_")
tags: list[str] = []
if task.agent:
tags.append(task.agent.role.lower().replace(" ", "-"))
return AgentSkill(
id=task_id,
name=task_name,
description=task.description,
tags=tags,
examples=[task.expected_output] if task.expected_output else None,
)
def _tool_to_skill(tool_name: str, tool_description: str) -> AgentSkill:
"""Convert an Agent's tool to an A2A AgentSkill.
Args:
tool_name: Name of the tool.
tool_description: Description of what the tool does.
Returns:
AgentSkill representing the tool's capability.
"""
tool_id = tool_name.lower().replace(" ", "_")
return AgentSkill(
id=tool_id,
name=tool_name,
description=tool_description,
tags=[tool_name.lower().replace(" ", "-")],
)
def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard:
"""Generate an A2A AgentCard from a Crew instance.
Args:
crew: The Crew instance to generate a card for.
url: The base URL where this crew will be exposed.
Returns:
AgentCard describing the crew's capabilities.
"""
crew_name = getattr(crew, "name", None) or crew.__class__.__name__
description_parts: list[str] = []
crew_description = getattr(crew, "description", None)
if crew_description:
description_parts.append(crew_description)
else:
agent_roles = [agent.role for agent in crew.agents]
description_parts.append(
f"A crew of {len(crew.agents)} agents: {', '.join(agent_roles)}"
)
skills = [_task_to_skill(task) for task in crew.tasks]
return AgentCard(
name=crew_name,
description=" ".join(description_parts),
url=url,
version="1.0.0",
capabilities=AgentCapabilities(
streaming=True,
push_notifications=True,
),
default_input_modes=["text/plain", "application/json"],
default_output_modes=["text/plain", "application/json"],
skills=skills,
)
def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
"""Generate an A2A AgentCard from an Agent instance.
Uses A2AServerConfig values when available, falling back to agent properties.
If signing_config is provided, the card will be signed with JWS.
Args:
agent: The Agent instance to generate a card for.
url: The base URL where this agent will be exposed.
Returns:
AgentCard describing the agent's capabilities.
"""
from crewai.a2a.utils.agent_card_signing import sign_agent_card
server_config = _get_server_config(agent) or A2AServerConfig()
name = server_config.name or agent.role
description_parts = [agent.goal]
if agent.backstory:
description_parts.append(agent.backstory)
description = server_config.description or " ".join(description_parts)
skills: list[AgentSkill] = (
server_config.skills.copy() if server_config.skills else []
)
if not skills:
if agent.tools:
for tool in agent.tools:
tool_name = getattr(tool, "name", None) or tool.__class__.__name__
tool_desc = getattr(tool, "description", None) or f"Tool: {tool_name}"
skills.append(_tool_to_skill(tool_name, tool_desc))
if not skills:
skills.append(
AgentSkill(
id=agent.role.lower().replace(" ", "_"),
name=agent.role,
description=agent.goal,
tags=[agent.role.lower().replace(" ", "-")],
)
)
capabilities = server_config.capabilities
if server_config.server_extensions:
from crewai.a2a.extensions.server import ServerExtensionRegistry
registry = ServerExtensionRegistry(server_config.server_extensions)
ext_list = registry.get_agent_extensions()
existing_exts = list(capabilities.extensions) if capabilities.extensions else []
existing_uris = {e.uri for e in existing_exts}
for ext in ext_list:
if ext.uri not in existing_uris:
existing_exts.append(ext)
capabilities = capabilities.model_copy(update={"extensions": existing_exts})
card = AgentCard(
name=name,
description=description,
url=server_config.url or url,
version=server_config.version,
capabilities=capabilities,
default_input_modes=server_config.default_input_modes,
default_output_modes=server_config.default_output_modes,
skills=skills,
preferred_transport=server_config.transport.preferred,
protocol_version=server_config.protocol_version,
provider=server_config.provider,
documentation_url=server_config.documentation_url,
icon_url=server_config.icon_url,
additional_interfaces=server_config.additional_interfaces,
security=server_config.security,
security_schemes=server_config.security_schemes,
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
)
if server_config.signing_config:
signature = sign_agent_card(
card,
private_key=server_config.signing_config.get_private_key(),
key_id=server_config.signing_config.key_id,
algorithm=server_config.signing_config.algorithm,
)
card = card.model_copy(update={"signatures": [signature]})
elif server_config.signatures:
card = card.model_copy(update={"signatures": server_config.signatures})
return card
def inject_a2a_server_methods(agent: Agent) -> None:
"""Inject A2A server methods onto an Agent instance.
Adds a `to_agent_card(url: str) -> AgentCard` method to the agent
that generates an A2A-compliant AgentCard.
Only injects if the agent has an A2AServerConfig.
Args:
agent: The Agent instance to inject methods onto.
"""
if _get_server_config(agent) is None:
return
def _to_agent_card(self: Agent, url: str) -> AgentCard:
return _agent_to_agent_card(self, url)
object.__setattr__(agent, "to_agent_card", MethodType(_to_agent_card, agent))
from crewai_a2a.utils.agent_card import * # noqa: E402, F403

View File

@@ -1,236 +1,13 @@
"""AgentCard JWS signing utilities.
"""Backward-compatibility shim — use ``crewai_a2a.utils.agent_card_signing`` instead."""
This module provides functions for signing and verifying AgentCards using
JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients
to verify the authenticity and integrity of agent card information.
Example:
>>> from crewai.a2a.utils.agent_card_signing import sign_agent_card
>>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1")
>>> card_with_sig = card.model_copy(update={"signatures": [signature]})
"""
from __future__ import annotations
import base64
import json
import logging
from typing import Any, Literal
from a2a.types import AgentCard, AgentCardSignature
import jwt
from pydantic import SecretStr
import warnings
logger = logging.getLogger(__name__)
warnings.warn(
"'crewai.a2a.utils.agent_card_signing' has been moved to 'crewai_a2a.utils.agent_card_signing'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
SigningAlgorithm = Literal[
"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512"
]
def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes:
"""Normalize private key to bytes format.
Args:
private_key: PEM-encoded private key as string, bytes, or SecretStr.
Returns:
Private key as bytes.
"""
if isinstance(private_key, SecretStr):
private_key = private_key.get_secret_value()
if isinstance(private_key, str):
private_key = private_key.encode()
return private_key
def _serialize_agent_card(agent_card: AgentCard) -> str:
"""Serialize AgentCard to canonical JSON for signing.
Excludes the signatures field to avoid circular reference during signing.
Uses sorted keys and compact separators for deterministic output.
Args:
agent_card: The AgentCard to serialize.
Returns:
Canonical JSON string representation.
"""
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))
def _base64url_encode(data: bytes | str) -> str:
"""Encode data to URL-safe base64 without padding.
Args:
data: Data to encode.
Returns:
URL-safe base64 encoded string without padding.
"""
if isinstance(data, str):
data = data.encode()
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def sign_agent_card(
agent_card: AgentCard,
private_key: str | bytes | SecretStr,
key_id: str | None = None,
algorithm: SigningAlgorithm = "RS256",
) -> AgentCardSignature:
"""Sign an AgentCard using JWS (RFC 7515).
Creates a detached JWS signature for the AgentCard. The signature covers
all fields except the signatures field itself.
Args:
agent_card: The AgentCard to sign.
private_key: PEM-encoded private key (RSA, EC, or RSA-PSS).
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
Returns:
AgentCardSignature with protected header and signature.
Raises:
jwt.exceptions.InvalidKeyError: If the private key is invalid.
ValueError: If the algorithm is not supported for the key type.
Example:
>>> signature = sign_agent_card(
... agent_card,
... private_key_pem="-----BEGIN PRIVATE KEY-----...",
... key_id="my-key-id",
... )
"""
key_bytes = _normalize_private_key(private_key)
payload = _serialize_agent_card(agent_card)
protected_header: dict[str, Any] = {"typ": "JWS"}
if key_id:
protected_header["kid"] = key_id
jws_token = jwt.api_jws.encode(
payload.encode(),
key_bytes,
algorithm=algorithm,
headers=protected_header,
)
parts = jws_token.split(".")
protected_b64 = parts[0]
signature_b64 = parts[2]
header: dict[str, Any] | None = None
if key_id:
header = {"kid": key_id}
return AgentCardSignature(
protected=protected_b64,
signature=signature_b64,
header=header,
)
def verify_agent_card_signature(
agent_card: AgentCard,
signature: AgentCardSignature,
public_key: str | bytes,
algorithms: list[str] | None = None,
) -> bool:
"""Verify an AgentCard JWS signature.
Validates that the signature was created with the corresponding private key
and that the AgentCard content has not been modified.
Args:
agent_card: The AgentCard to verify.
signature: The AgentCardSignature to validate.
public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms.
Returns:
True if signature is valid, False otherwise.
Example:
>>> is_valid = verify_agent_card_signature(
... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..."
... )
"""
if algorithms is None:
algorithms = [
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
if isinstance(public_key, str):
public_key = public_key.encode()
payload = _serialize_agent_card(agent_card)
payload_b64 = _base64url_encode(payload)
jws_token = f"{signature.protected}.{payload_b64}.{signature.signature}"
try:
jwt.api_jws.decode(
jws_token,
public_key,
algorithms=algorithms,
)
return True
except jwt.InvalidSignatureError:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "invalid_signature"},
)
return False
except jwt.DecodeError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "decode_error", "error": str(e)},
)
return False
except jwt.InvalidAlgorithmError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "algorithm_error", "error": str(e)},
)
return False
def get_key_id_from_signature(signature: AgentCardSignature) -> str | None:
"""Extract the key ID (kid) from an AgentCardSignature.
Checks both the unprotected header and the protected header for the kid claim.
Args:
signature: The AgentCardSignature to extract from.
Returns:
The key ID if present, None otherwise.
"""
if signature.header and "kid" in signature.header:
kid: str = signature.header["kid"]
return kid
try:
protected = signature.protected
padding_needed = 4 - (len(protected) % 4)
if padding_needed != 4:
protected += "=" * padding_needed
protected_json = base64.urlsafe_b64decode(protected).decode()
protected_header: dict[str, Any] = json.loads(protected_json)
return protected_header.get("kid")
except (ValueError, json.JSONDecodeError):
return None
from crewai_a2a.utils.agent_card_signing import * # noqa: E402, F403

View File

@@ -1,339 +1,13 @@
"""Content type negotiation for A2A protocol.
"""Backward-compatibility shim — use ``crewai_a2a.utils.content_type`` instead."""
This module handles negotiation of input/output MIME types between A2A clients
and servers based on AgentCard capabilities.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Final, Literal, cast
from a2a.types import Part
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent
import warnings
if TYPE_CHECKING:
from a2a.types import AgentCard, AgentSkill
TEXT_PLAIN: Literal["text/plain"] = "text/plain"
APPLICATION_JSON: Literal["application/json"] = "application/json"
IMAGE_PNG: Literal["image/png"] = "image/png"
IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg"
IMAGE_WILDCARD: Literal["image/*"] = "image/*"
APPLICATION_PDF: Literal["application/pdf"] = "application/pdf"
APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = (
"application/octet-stream"
warnings.warn(
"'crewai.a2a.utils.content_type' has been moved to 'crewai_a2a.utils.content_type'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
@dataclass
class NegotiatedContentTypes:
"""Result of content type negotiation."""
input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"]
output_modes: Annotated[
list[str], "Negotiated output MIME types the server will produce"
]
effective_input_modes: Annotated[list[str], "Server's effective input modes"]
effective_output_modes: Annotated[list[str], "Server's effective output modes"]
skill_name: Annotated[
str | None, "Skill name if negotiation was skill-specific"
] = None
class ContentTypeNegotiationError(Exception):
"""Raised when no compatible content types can be negotiated."""
def __init__(
self,
client_input_modes: list[str],
client_output_modes: list[str],
server_input_modes: list[str],
server_output_modes: list[str],
direction: str = "both",
message: str | None = None,
) -> None:
self.client_input_modes = client_input_modes
self.client_output_modes = client_output_modes
self.server_input_modes = server_input_modes
self.server_output_modes = server_output_modes
self.direction = direction
if message is None:
if direction == "input":
message = (
f"No compatible input content types. "
f"Client supports: {client_input_modes}, "
f"Server accepts: {server_input_modes}"
)
elif direction == "output":
message = (
f"No compatible output content types. "
f"Client accepts: {client_output_modes}, "
f"Server produces: {server_output_modes}"
)
else:
message = (
f"No compatible content types. "
f"Input - Client: {client_input_modes}, Server: {server_input_modes}. "
f"Output - Client: {client_output_modes}, Server: {server_output_modes}"
)
super().__init__(message)
def _normalize_mime_type(mime_type: str) -> str:
"""Normalize MIME type for comparison (lowercase, strip whitespace)."""
return mime_type.lower().strip()
def _mime_types_compatible(client_type: str, server_type: str) -> bool:
"""Check if two MIME types are compatible.
Handles wildcards like image/* matching image/png.
"""
client_normalized = _normalize_mime_type(client_type)
server_normalized = _normalize_mime_type(server_type)
if client_normalized == server_normalized:
return True
if "*" in client_normalized or "*" in server_normalized:
client_parts = client_normalized.split("/")
server_parts = server_normalized.split("/")
if len(client_parts) == 2 and len(server_parts) == 2:
type_match = (
client_parts[0] == server_parts[0]
or client_parts[0] == "*"
or server_parts[0] == "*"
)
subtype_match = (
client_parts[1] == server_parts[1]
or client_parts[1] == "*"
or server_parts[1] == "*"
)
return type_match and subtype_match
return False
def _find_compatible_modes(
client_modes: list[str], server_modes: list[str]
) -> list[str]:
"""Find compatible MIME types between client and server.
Returns modes in client preference order.
"""
compatible = []
for client_mode in client_modes:
for server_mode in server_modes:
if _mime_types_compatible(client_mode, server_mode):
if "*" in client_mode and "*" not in server_mode:
if server_mode not in compatible:
compatible.append(server_mode)
else:
if client_mode not in compatible:
compatible.append(client_mode)
break
return compatible
def _get_effective_modes(
agent_card: AgentCard,
skill_name: str | None = None,
) -> tuple[list[str], list[str], AgentSkill | None]:
"""Get effective input/output modes from agent card.
If skill_name is provided and the skill has custom modes, those are used.
Otherwise, falls back to agent card defaults.
"""
skill: AgentSkill | None = None
if skill_name and agent_card.skills:
for s in agent_card.skills:
if s.name == skill_name or s.id == skill_name:
skill = s
break
if skill:
input_modes = (
skill.input_modes if skill.input_modes else agent_card.default_input_modes
)
output_modes = (
skill.output_modes
if skill.output_modes
else agent_card.default_output_modes
)
else:
input_modes = agent_card.default_input_modes
output_modes = agent_card.default_output_modes
return input_modes, output_modes, skill
def negotiate_content_types(
agent_card: AgentCard,
client_input_modes: list[str] | None = None,
client_output_modes: list[str] | None = None,
skill_name: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
strict: bool = False,
) -> NegotiatedContentTypes:
"""Negotiate content types between client and server.
Args:
agent_card: The remote agent's card with capability info.
client_input_modes: MIME types the client can send. Defaults to text/plain and application/json.
client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json.
skill_name: Optional skill to use for mode lookup.
emit_event: Whether to emit a content type negotiation event.
endpoint: Agent endpoint (for event metadata).
a2a_agent_name: Agent name (for event metadata).
strict: If True, raises error when no compatible types found.
If False, returns empty lists for incompatible directions.
Returns:
NegotiatedContentTypes with compatible input and output modes.
Raises:
ContentTypeNegotiationError: If strict=True and no compatible types found.
"""
if client_input_modes is None:
client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy())
if client_output_modes is None:
client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy())
server_input_modes, server_output_modes, skill = _get_effective_modes(
agent_card, skill_name
)
compatible_input = _find_compatible_modes(client_input_modes, server_input_modes)
compatible_output = _find_compatible_modes(client_output_modes, server_output_modes)
if strict:
if not compatible_input and not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
)
if not compatible_input:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="input",
)
if not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="output",
)
result = NegotiatedContentTypes(
input_modes=compatible_input,
output_modes=compatible_output,
effective_input_modes=server_input_modes,
effective_output_modes=server_output_modes,
skill_name=skill.name if skill else None,
)
if emit_event:
crewai_event_bus.emit(
None,
A2AContentTypeNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
skill_name=skill_name,
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
negotiated_input_modes=compatible_input,
negotiated_output_modes=compatible_output,
negotiation_success=bool(compatible_input and compatible_output),
),
)
return result
def validate_content_type(
content_type: str,
allowed_modes: list[str],
) -> bool:
"""Validate that a content type is allowed by a list of modes.
Args:
content_type: The MIME type to validate.
allowed_modes: List of allowed MIME types (may include wildcards).
Returns:
True if content_type is compatible with any allowed mode.
"""
for mode in allowed_modes:
if _mime_types_compatible(content_type, mode):
return True
return False
def get_part_content_type(part: Part) -> str:
"""Extract MIME type from an A2A Part.
Args:
part: A Part object containing TextPart, DataPart, or FilePart.
Returns:
The MIME type string for this part.
"""
root = part.root
if root.kind == "text":
return TEXT_PLAIN
if root.kind == "data":
return APPLICATION_JSON
if root.kind == "file":
return root.file.mime_type or APPLICATION_OCTET_STREAM
return APPLICATION_OCTET_STREAM
def validate_message_parts(
parts: list[Part],
allowed_modes: list[str],
) -> list[str]:
"""Validate that all message parts have allowed content types.
Args:
parts: List of Parts from the incoming message.
allowed_modes: List of allowed MIME types (from default_input_modes).
Returns:
List of invalid content types found (empty if all valid).
"""
invalid_types: list[str] = []
for part in parts:
content_type = get_part_content_type(part)
if not validate_content_type(content_type, allowed_modes):
if content_type not in invalid_types:
invalid_types.append(content_type)
return invalid_types
from crewai_a2a.utils.content_type import * # noqa: E402, F403

View File

@@ -1,980 +1,13 @@
"""A2A delegation utilities for executing tasks on remote agents."""
"""Backward-compatibility shim — use ``crewai_a2a.utils.delegation`` instead."""
from __future__ import annotations
import warnings
import asyncio
import base64
from collections.abc import AsyncIterator, Callable, MutableMapping
from contextlib import asynccontextmanager
import logging
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
from a2a.client import Client, ClientConfig, ClientFactory
from a2a.types import (
AgentCard,
FilePart,
FileWithBytes,
Message,
Part,
PushNotificationConfig as A2APushNotificationConfig,
Role,
TextPart,
)
import httpx
from pydantic import BaseModel
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
configure_auth_client,
validate_auth_against_agent_card,
)
from crewai.a2a.config import ClientTransportConfig, GRPCClientConfig
from crewai.a2a.extensions.registry import (
ExtensionsMiddleware,
validate_required_extensions,
)
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.types import (
HANDLER_REGISTRY,
HandlerType,
PartsDict,
PartsMetadataDict,
TransportType,
)
from crewai.a2a.updates import (
PollingConfig,
PushNotificationConfig,
StreamingHandler,
UpdateConfig,
)
from crewai.a2a.utils.agent_card import (
_afetch_agent_card_cached,
_get_tls_verify,
_prepare_auth_headers,
)
from crewai.a2a.utils.content_type import (
DEFAULT_CLIENT_OUTPUT_MODES,
negotiate_content_types,
)
from crewai.a2a.utils.transport import (
NegotiatedTransport,
TransportNegotiationError,
negotiate_transport,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
warnings.warn(
"'crewai.a2a.utils.delegation' has been moved to 'crewai_a2a.utils.delegation'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from a2a.types import Message
from crewai.a2a.auth.client_schemes import ClientAuthScheme
_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC"
def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]:
"""Convert FileInput dictionary to FilePart objects.
Args:
input_files: Dictionary mapping names to FileInput objects.
Returns:
List of Part objects containing FilePart data.
"""
if not input_files:
return []
try:
import crewai_files # noqa: F401
except ImportError:
logger.debug("crewai_files not installed, skipping file parts")
return []
parts: list[Part] = []
for name, file_input in input_files.items():
content_bytes = file_input.read()
content_base64 = base64.b64encode(content_bytes).decode()
file_with_bytes = FileWithBytes(
bytes=content_base64,
mimeType=file_input.content_type,
name=file_input.filename or name,
)
parts.append(Part(root=FilePart(file=file_with_bytes)))
return parts
def get_handler(config: UpdateConfig | None) -> HandlerType:
"""Get the handler class for a given update config.
Args:
config: Update mechanism configuration.
Returns:
Handler class for the config type, defaults to StreamingHandler.
"""
if config is None:
return StreamingHandler
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
def execute_a2a_delegation(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
context_id: str | None = None,
task_id: str | None = None,
reference_task_ids: list[str] | None = None,
metadata: dict[str, Any] | None = None,
extensions: dict[str, Any] | None = None,
conversation_history: list[Message] | None = None,
agent_id: str | None = None,
agent_role: Role | None = None,
agent_branch: Any | None = None,
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
updates: UpdateConfig | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent synchronously.
WARNING: This function blocks the entire thread by creating and running a new
event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts
for better performance and resource efficiency.
This is a synchronous wrapper around aexecute_a2a_delegation that creates a
new event loop to run the async implementation. It is provided for compatibility
with synchronous code paths only.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
task_description: The task to delegate.
context: Optional context information.
context_id: Context ID for correlating messages/tasks.
task_id: Specific task identifier.
reference_task_ids: List of related task IDs.
metadata: Additional metadata.
extensions: Protocol extensions for custom fields.
conversation_history: Previous Message objects from conversation.
agent_id: Agent identifier for logging.
agent_role: Role of the CrewAI agent delegating the task.
agent_branch: Optional agent tree branch for logging.
response_model: Optional Pydantic model for structured outputs.
turn_number: Optional turn number for multi-turn conversations.
updates: Update mechanism config from A2AConfig.updates.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
input_files: Optional dictionary of files to send to remote agent.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
Raises:
RuntimeError: If called from an async context with a running event loop.
"""
try:
asyncio.get_running_loop()
raise RuntimeError(
"execute_a2a_delegation() cannot be called from an async context. "
"Use 'await aexecute_a2a_delegation()' instead."
)
except RuntimeError as e:
if "no running event loop" not in str(e).lower():
raise
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
aexecute_a2a_delegation(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_role=agent_role,
agent_branch=agent_branch,
response_model=response_model,
turn_number=turn_number,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
input_files=input_files,
)
)
finally:
try:
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
loop.close()
async def aexecute_a2a_delegation(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
context_id: str | None = None,
task_id: str | None = None,
reference_task_ids: list[str] | None = None,
metadata: dict[str, Any] | None = None,
extensions: dict[str, Any] | None = None,
conversation_history: list[Message] | None = None,
agent_id: str | None = None,
agent_role: Role | None = None,
agent_branch: Any | None = None,
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
updates: UpdateConfig | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent asynchronously.
Native async implementation with multi-turn support. Use this when running
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
Args:
endpoint: A2A agent endpoint URL.
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
task_description: The task to delegate.
context: Optional context information.
context_id: Context ID for correlating messages/tasks.
task_id: Specific task identifier.
reference_task_ids: List of related task IDs.
metadata: Additional metadata.
extensions: Protocol extensions for custom fields.
conversation_history: Previous Message objects from conversation.
agent_id: Agent identifier for logging.
agent_role: Role of the CrewAI agent delegating the task.
agent_branch: Optional agent tree branch for logging.
response_model: Optional Pydantic model for structured outputs.
turn_number: Optional turn number for multi-turn conversations.
updates: Update mechanism config from A2AConfig.updates.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
input_files: Optional dictionary of files to send to remote agent.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
"""
if conversation_history is None:
conversation_history = []
is_multiturn = len(conversation_history) > 0
if turn_number is None:
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
try:
result = await _aexecute_a2a_delegation_impl(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
is_multiturn=is_multiturn,
turn_number=turn_number,
agent_branch=agent_branch,
agent_id=agent_id,
agent_role=agent_role,
response_model=response_model,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
input_files=input_files,
)
except Exception as e:
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status="failed",
result=None,
error=str(e),
context_id=context_id,
is_multiturn=is_multiturn,
endpoint=endpoint,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
raise
agent_card_data = result.get("agent_card")
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status=result["status"],
result=result.get("result"),
error=result.get("error"),
context_id=context_id,
is_multiturn=is_multiturn,
endpoint=endpoint,
a2a_agent_name=result.get("a2a_agent_name"),
agent_card=agent_card_data,
provider=agent_card_data.get("provider") if agent_card_data else None,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
return result
async def _aexecute_a2a_delegation_impl(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None,
context_id: str | None,
task_id: str | None,
reference_task_ids: list[str] | None,
metadata: dict[str, Any] | None,
extensions: dict[str, Any] | None,
conversation_history: list[Message],
is_multiturn: bool,
turn_number: int,
agent_branch: Any | None,
agent_id: str | None,
agent_role: str | None,
response_model: type[BaseModel] | None,
updates: UpdateConfig | None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
"""Internal async implementation of A2A delegation."""
if transport is None:
transport = ClientTransportConfig()
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = _auth_store.compute_key("none", endpoint)
_auth_store.set(auth_hash, auth)
agent_card = await _afetch_agent_card_cached(
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
)
validate_auth_against_agent_card(agent_card, auth)
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
if unsupported_exts:
ext_uris = [ext.uri for ext in unsupported_exts]
raise ValueError(
f"Agent requires extensions not supported by client: {ext_uris}"
)
negotiated: NegotiatedTransport | None = None
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
effective_url = endpoint
client_transports: list[str] = (
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
)
try:
negotiated = negotiate_transport(
agent_card=agent_card,
client_supported_transports=client_transports,
client_preferred_transport=transport.preferred,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
effective_transport = negotiated.transport # type: ignore[assignment]
effective_url = negotiated.url
except TransportNegotiationError as e:
logger.warning(
"Transport negotiation failed, using fallback",
extra={
"error": str(e),
"fallback_transport": effective_transport,
"fallback_url": effective_url,
"endpoint": endpoint,
"client_transports": client_transports,
"server_transports": [
iface.transport for iface in agent_card.additional_interfaces or []
]
+ [agent_card.preferred_transport or "JSONRPC"],
},
)
effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy()
content_negotiated = negotiate_content_types(
agent_card=agent_card,
client_output_modes=accepted_output_modes,
skill_name=skill_id,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
if content_negotiated.output_modes:
effective_output_modes = content_negotiated.output_modes
headers, _ = await _prepare_auth_headers(auth, timeout)
a2a_agent_name = None
if agent_card.name:
a2a_agent_name = agent_card.name
agent_card_dict = agent_card.model_dump(exclude_none=True)
crewai_event_bus.emit(
agent_branch,
A2ADelegationStartedEvent(
endpoint=endpoint,
task_description=task_description,
agent_id=agent_id or endpoint,
context_id=context_id,
is_multiturn=is_multiturn,
turn_number=turn_number,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
if turn_number == 1:
agent_id_for_event = agent_id or endpoint
crewai_event_bus.emit(
agent_branch,
A2AConversationStartedEvent(
agent_id=agent_id_for_event,
endpoint=endpoint,
context_id=context_id,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
message_parts = []
if context:
message_parts.append(f"Context:\n{context}\n\n")
message_parts.append(f"{task_description}")
message_text = "".join(message_parts)
if is_multiturn and conversation_history and not task_id:
if first_task_id := conversation_history[0].task_id:
task_id = first_task_id
parts: PartsDict = {"text": message_text}
if response_model:
parts.update(
{
"metadata": PartsMetadataDict(
mimeType="application/json",
schema=response_model.model_json_schema(),
)
}
)
message_metadata = metadata.copy() if metadata else {}
if skill_id:
message_metadata["skill_id"] = skill_id
parts_list: list[Part] = [Part(root=TextPart(**parts))]
parts_list.extend(_create_file_parts(input_files))
message = Message(
role=Role.user,
message_id=str(uuid.uuid4()),
parts=parts_list,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=message_metadata if message_metadata else None,
extensions=extensions,
)
new_messages: list[Message] = [*conversation_history, message]
crewai_event_bus.emit(
None,
A2AMessageSentEvent(
message=message_text,
turn_number=turn_number,
context_id=context_id,
message_id=message.message_id,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
skill_id=skill_id,
metadata=message_metadata if message_metadata else None,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
handler = get_handler(updates)
use_polling = isinstance(updates, PollingConfig)
handler_kwargs: dict[str, Any] = {
"turn_number": turn_number,
"is_multiturn": is_multiturn,
"agent_role": agent_role,
"context_id": context_id,
"task_id": task_id,
"endpoint": endpoint,
"agent_branch": agent_branch,
"a2a_agent_name": a2a_agent_name,
"from_task": from_task,
"from_agent": from_agent,
}
if isinstance(updates, PollingConfig):
handler_kwargs.update(
{
"polling_interval": updates.interval,
"polling_timeout": updates.timeout or float(timeout),
"history_length": updates.history_length,
"max_polls": updates.max_polls,
}
)
elif isinstance(updates, PushNotificationConfig):
handler_kwargs.update(
{
"config": updates,
"result_store": updates.result_store,
"polling_timeout": updates.timeout or float(timeout),
"polling_interval": updates.interval,
}
)
push_config_for_client = (
updates if isinstance(updates, PushNotificationConfig) else None
)
use_streaming = not use_polling and push_config_for_client is None
client_agent_card = agent_card
if effective_url != agent_card.url:
client_agent_card = agent_card.model_copy(update={"url": effective_url})
async with _create_a2a_client(
agent_card=client_agent_card,
transport_protocol=effective_transport,
timeout=timeout,
headers=headers,
streaming=use_streaming,
auth=auth,
use_polling=use_polling,
push_notification_config=push_config_for_client,
client_extensions=client_extensions,
accepted_output_modes=effective_output_modes, # type: ignore[arg-type]
grpc_config=transport.grpc,
) as client:
result = await handler.execute(
client=client,
message=message,
new_messages=new_messages,
agent_card=agent_card,
**handler_kwargs,
)
result["a2a_agent_name"] = a2a_agent_name
result["agent_card"] = agent_card.model_dump(exclude_none=True)
return result
def _normalize_grpc_metadata(
metadata: tuple[tuple[str, str], ...] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Lowercase all gRPC metadata keys.
gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK)
use mixed-case headers like 'X-A2A-Extensions'. This normalizes them.
"""
if metadata is None:
return None
return tuple((key.lower(), value) for key, value in metadata)
def _create_grpc_interceptors(
auth_metadata: list[tuple[str, str]] | None = None,
) -> list[Any]:
"""Create gRPC interceptors for metadata normalization and auth injection.
Args:
auth_metadata: Optional auth metadata to inject into all calls.
Used for insecure channels that need auth (non-localhost without TLS).
Returns a list of interceptors that lowercase metadata keys for gRPC
compatibility. Must be called after grpc is imported.
"""
import grpc.aio # type: ignore[import-untyped]
def _merge_metadata(
existing: tuple[tuple[str, str], ...] | None,
auth: list[tuple[str, str]] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Merge existing metadata with auth metadata and normalize keys."""
merged: list[tuple[str, str]] = []
if existing:
merged.extend(existing)
if auth:
merged.extend(auth)
if not merged:
return None
return tuple((key.lower(), value) for key, value in merged)
def _inject_metadata(client_call_details: Any) -> Any:
"""Inject merged metadata into call details."""
return client_call_details._replace(
metadata=_merge_metadata(client_call_details.metadata, auth_metadata)
)
class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-unary calls that injects auth metadata."""
async def intercept_unary_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-unary call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-stream calls that injects auth metadata."""
async def intercept_unary_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-stream call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-unary calls that injects auth metadata."""
async def intercept_stream_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-unary call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-stream calls that injects auth metadata."""
async def intercept_stream_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-stream call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
return [
MetadataUnaryUnary(),
MetadataUnaryStream(),
MetadataStreamUnary(),
MetadataStreamStream(),
]
def _create_grpc_channel_factory(
grpc_config: GRPCClientConfig,
auth: ClientAuthScheme | None = None,
) -> Callable[[str], Any]:
"""Create a gRPC channel factory with the given configuration.
Args:
grpc_config: gRPC client configuration with channel options.
auth: Optional ClientAuthScheme for TLS and auth configuration.
Returns:
A callable that creates gRPC channels from URLs.
"""
try:
import grpc
except ImportError as e:
raise ImportError(
"gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
) from e
auth_metadata: list[tuple[str, str]] = []
if auth is not None:
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
BearerTokenAuth,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
)
if isinstance(auth, HTTPDigestAuth):
raise ValueError(
"HTTPDigestAuth is not supported with gRPC transport. "
"Digest authentication requires HTTP challenge-response flow. "
"Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
)
if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
raise ValueError(
f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
"gRPC only supports header-based authentication. "
"Use APIKeyAuth with location='header' instead."
)
if isinstance(auth, BearerTokenAuth):
auth_metadata.append(("authorization", f"Bearer {auth.token}"))
elif isinstance(auth, HTTPBasicAuth):
import base64
basic_credentials = f"{auth.username}:{auth.password}"
encoded = base64.b64encode(basic_credentials.encode()).decode()
auth_metadata.append(("authorization", f"Basic {encoded}"))
elif isinstance(auth, APIKeyAuth) and auth.location == "header":
header_name = auth.name.lower()
auth_metadata.append((header_name, auth.api_key))
elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
if auth._access_token:
auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))
def factory(url: str) -> Any:
"""Create a gRPC channel for the given URL."""
target = url
use_tls = False
if url.startswith("grpcs://"):
target = url[8:]
use_tls = True
elif url.startswith("grpc://"):
target = url[7:]
elif url.startswith("https://"):
target = url[8:]
use_tls = True
elif url.startswith("http://"):
target = url[7:]
options: list[tuple[str, Any]] = []
if grpc_config.max_send_message_length is not None:
options.append(
("grpc.max_send_message_length", grpc_config.max_send_message_length)
)
if grpc_config.max_receive_message_length is not None:
options.append(
(
"grpc.max_receive_message_length",
grpc_config.max_receive_message_length,
)
)
if grpc_config.keepalive_time_ms is not None:
options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
if grpc_config.keepalive_timeout_ms is not None:
options.append(
("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
)
channel_credentials = None
if auth and hasattr(auth, "tls") and auth.tls:
channel_credentials = auth.tls.get_grpc_credentials()
elif use_tls:
channel_credentials = grpc.ssl_channel_credentials()
if channel_credentials and auth_metadata:
class AuthMetadataPlugin(grpc.AuthMetadataPlugin): # type: ignore[misc,no-any-unimported]
"""gRPC auth metadata plugin that adds auth headers as metadata."""
def __init__(self, metadata: list[tuple[str, str]]) -> None:
self._metadata = tuple(metadata)
def __call__( # type: ignore[no-any-unimported]
self,
context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
) -> None:
callback(self._metadata, None)
call_creds = grpc.metadata_call_credentials(
AuthMetadataPlugin(auth_metadata)
)
credentials = grpc.composite_channel_credentials(
channel_credentials, call_creds
)
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target, credentials, options=options or None, interceptors=interceptors
)
if channel_credentials:
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target,
channel_credentials,
options=options or None,
interceptors=interceptors,
)
interceptors = _create_grpc_interceptors(
auth_metadata=auth_metadata if auth_metadata else None
)
return grpc.aio.insecure_channel(
target, options=options or None, interceptors=interceptors
)
return factory
@asynccontextmanager
async def _create_a2a_client(
agent_card: AgentCard,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
timeout: int,
headers: MutableMapping[str, str],
streaming: bool,
auth: ClientAuthScheme | None = None,
use_polling: bool = False,
push_notification_config: PushNotificationConfig | None = None,
client_extensions: list[str] | None = None,
accepted_output_modes: list[str] | None = None,
grpc_config: GRPCClientConfig | None = None,
) -> AsyncIterator[Client]:
"""Create and configure an A2A client.
Args:
agent_card: The A2A agent card.
transport_protocol: Transport protocol to use.
timeout: Request timeout in seconds.
headers: HTTP headers (already with auth applied).
streaming: Enable streaming responses.
auth: Optional ClientAuthScheme for client configuration.
use_polling: Enable polling mode.
push_notification_config: Optional push notification config.
client_extensions: A2A protocol extension URIs to declare support for.
accepted_output_modes: MIME types the client can accept in responses.
grpc_config: Optional gRPC client configuration.
Yields:
Configured A2A client instance.
"""
verify = _get_tls_verify(auth)
async with httpx.AsyncClient(
timeout=timeout,
headers=headers,
verify=verify,
) as httpx_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, httpx_client)
push_configs: list[A2APushNotificationConfig] = []
if push_notification_config is not None:
push_configs.append(
A2APushNotificationConfig(
url=str(push_notification_config.url),
id=push_notification_config.id,
token=push_notification_config.token,
authentication=push_notification_config.authentication,
)
)
grpc_channel_factory = None
if transport_protocol == "GRPC":
grpc_channel_factory = _create_grpc_channel_factory(
grpc_config or GRPCClientConfig(),
auth=auth,
)
config = ClientConfig(
httpx_client=httpx_client,
supported_transports=[transport_protocol],
streaming=streaming and not use_polling,
polling=use_polling,
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
push_notification_configs=push_configs,
grpc_channel_factory=grpc_channel_factory,
)
factory = ClientFactory(config)
client = factory.create(agent_card)
if client_extensions:
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
yield client
from crewai_a2a.utils.delegation import * # noqa: E402, F403

View File

@@ -1,131 +1,13 @@
"""Structured JSON logging utilities for A2A module."""
"""Backward-compatibility shim — use ``crewai_a2a.utils.logging`` instead."""
from __future__ import annotations
from contextvars import ContextVar
from datetime import datetime, timezone
import json
import logging
from typing import Any
import warnings
_log_context: ContextVar[dict[str, Any] | None] = ContextVar(
"log_context", default=None
warnings.warn(
"'crewai.a2a.utils.logging' has been moved to 'crewai_a2a.utils.logging'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
class JSONFormatter(logging.Formatter):
"""JSON formatter for structured logging.
Outputs logs as JSON with consistent fields for log aggregators.
"""
def format(self, record: logging.LogRecord) -> str:
"""Format log record as JSON string."""
log_data: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
if record.exc_info:
log_data["exception"] = self.formatException(record.exc_info)
context = _log_context.get()
if context is not None:
log_data.update(context)
if hasattr(record, "task_id"):
log_data["task_id"] = record.task_id
if hasattr(record, "context_id"):
log_data["context_id"] = record.context_id
if hasattr(record, "agent"):
log_data["agent"] = record.agent
if hasattr(record, "endpoint"):
log_data["endpoint"] = record.endpoint
if hasattr(record, "extension"):
log_data["extension"] = record.extension
if hasattr(record, "error"):
log_data["error"] = record.error
for key, value in record.__dict__.items():
if key.startswith("_") or key in (
"name",
"msg",
"args",
"created",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"exc_info",
"exc_text",
"thread",
"threadName",
"taskName",
"message",
):
continue
if key not in log_data:
log_data[key] = value
return json.dumps(log_data, default=str)
class LogContext:
"""Context manager for adding fields to all logs within a scope.
Example:
with LogContext(task_id="abc", context_id="xyz"):
logger.info("Processing task") # Includes task_id and context_id
"""
def __init__(self, **fields: Any) -> None:
self._fields = fields
self._token: Any = None
def __enter__(self) -> LogContext:
current = _log_context.get() or {}
new_context = {**current, **self._fields}
self._token = _log_context.set(new_context)
return self
def __exit__(self, *args: Any) -> None:
_log_context.reset(self._token)
def configure_json_logging(logger_name: str = "crewai.a2a") -> None:
"""Configure JSON logging for the A2A module.
Args:
logger_name: Logger name to configure.
"""
logger = logging.getLogger(logger_name)
for handler in logger.handlers[:]:
logger.removeHandler(handler)
handler = logging.StreamHandler()
handler.setFormatter(JSONFormatter())
logger.addHandler(handler)
def get_logger(name: str) -> logging.Logger:
"""Get a logger configured for structured JSON output.
Args:
name: Logger name.
Returns:
Configured logger instance.
"""
return logging.getLogger(name)
from crewai_a2a.utils.logging import * # noqa: E402, F403

View File

@@ -1,101 +1,13 @@
"""Response model utilities for A2A agent interactions."""
"""Backward-compatibility shim — use ``crewai_a2a.utils.response_model`` instead."""
from __future__ import annotations
from typing import TypeAlias
from pydantic import BaseModel, Field, create_model
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.types.utils import create_literals_from_strings
import warnings
A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig
warnings.warn(
"'crewai.a2a.utils.response_model' has been moved to 'crewai_a2a.utils.response_model'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None:
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
Args:
agent_ids: List of available A2A agent IDs.
Returns:
Dynamically created Pydantic model with Literal-constrained a2a_ids field,
or None if agent_ids is empty.
"""
if not agent_ids:
return None
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
return create_model(
"AgentResponse",
a2a_ids=(
tuple[DynamicLiteral, ...], # type: ignore[valid-type]
Field(
default_factory=tuple,
max_length=len(agent_ids),
description="A2A agent IDs to delegate to.",
),
),
message=(
str,
Field(
description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
),
),
is_a2a=(
bool,
Field(
description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately."
),
),
__base__=BaseModel,
)
def extract_a2a_agent_ids_from_config(
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]:
"""Extract A2A agent IDs from A2A configuration.
Filters out A2AServerConfig since it doesn't have an endpoint for delegation.
Args:
a2a_config: A2A configuration (any type).
Returns:
Tuple of client A2A configs list and agent endpoint IDs.
"""
if a2a_config is None:
return [], ()
configs: list[A2AConfigTypes]
if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)):
configs = [a2a_config]
else:
configs = a2a_config
# Filter to only client configs (those with endpoint)
client_configs: list[A2AClientConfigTypes] = [
config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
]
return client_configs, tuple(config.endpoint for config in client_configs)
def get_a2a_agents_and_response_model(
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]:
"""Get A2A agent configs and response model.
Args:
a2a_config: A2A configuration (any type).
Returns:
Tuple of client A2A configs and response model.
"""
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
return a2a_agents, create_agent_response_model(agent_ids)
from crewai_a2a.utils.response_model import * # noqa: E402, F403

View File

@@ -1,584 +1,13 @@
"""A2A task utilities for server-side task management."""
"""Backward-compatibility shim — use ``crewai_a2a.utils.task`` instead."""
from __future__ import annotations
import asyncio
import base64
from collections.abc import Callable, Coroutine
from datetime import datetime
from functools import wraps
import json
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast
from urllib.parse import urlparse
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import (
Artifact,
FileWithBytes,
FileWithUri,
InternalError,
InvalidParamsError,
Message,
Part,
Task as A2ATask,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils import (
get_data_parts,
get_file_parts,
new_agent_text_message,
new_data_artifact,
new_text_artifact,
)
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from pydantic import BaseModel
from crewai.a2a.utils.agent_card import _get_server_config
from crewai.a2a.utils.content_type import validate_message_parts
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AServerTaskCanceledEvent,
A2AServerTaskCompletedEvent,
A2AServerTaskFailedEvent,
A2AServerTaskStartedEvent,
)
from crewai.task import Task
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
import warnings
if TYPE_CHECKING:
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
from crewai.agent import Agent
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
class RedisCacheConfig(TypedDict, total=False):
"""Configuration for aiocache Redis backend."""
cache: str
endpoint: str
port: int
db: int
password: str
def _parse_redis_url(url: str) -> RedisCacheConfig:
"""Parse a Redis URL into aiocache configuration.
Args:
url: Redis connection URL (e.g., redis://localhost:6379/0).
Returns:
Configuration dict for aiocache.RedisCache.
"""
parsed = urlparse(url)
config: RedisCacheConfig = {
"cache": "aiocache.RedisCache",
"endpoint": parsed.hostname or "localhost",
"port": parsed.port or 6379,
}
if parsed.path and parsed.path != "/":
try:
config["db"] = int(parsed.path.lstrip("/"))
except ValueError:
pass
if parsed.password:
config["password"] = parsed.password
return config
_redis_url = os.environ.get("REDIS_URL")
caches.set_config(
{
"default": _parse_redis_url(_redis_url)
if _redis_url
else {
"cache": "aiocache.SimpleMemoryCache",
}
}
warnings.warn(
"'crewai.a2a.utils.task' has been moved to 'crewai_a2a.utils.task'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
def cancellable(
fn: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T]]:
"""Decorator that enables cancellation for A2A task execution.
Runs a cancellation watcher concurrently with the wrapped function.
When a cancel event is published, the execution is cancelled.
Args:
fn: The async function to wrap.
Returns:
Wrapped function with cancellation support.
"""
@wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrap function with cancellation monitoring."""
context: RequestContext | None = None
for arg in args:
if isinstance(arg, RequestContext):
context = arg
break
if context is None:
context = cast(RequestContext | None, kwargs.get("context"))
if context is None:
return await fn(*args, **kwargs)
task_id = context.task_id
cache = caches.get("default")
async def poll_for_cancel() -> bool:
"""Poll cache for cancellation flag."""
while True:
if await cache.get(f"cancel:{task_id}"):
return True
await asyncio.sleep(0.1)
async def watch_for_cancel() -> bool:
"""Watch for cancellation events via pub/sub or polling."""
if isinstance(cache, SimpleMemoryCache):
return await poll_for_cancel()
try:
client = cache.client
pubsub = client.pubsub()
await pubsub.subscribe(f"cancel:{task_id}")
async for message in pubsub.listen():
if message["type"] == "message":
return True
except (OSError, ConnectionError) as e:
logger.warning(
"Cancel watcher Redis error, falling back to polling",
extra={"task_id": task_id, "error": str(e)},
)
return await poll_for_cancel()
return False
execute_task = asyncio.create_task(fn(*args, **kwargs))
cancel_watch = asyncio.create_task(watch_for_cancel())
try:
done, _ = await asyncio.wait(
[execute_task, cancel_watch],
return_when=asyncio.FIRST_COMPLETED,
)
if cancel_watch in done:
execute_task.cancel()
try:
await execute_task
except asyncio.CancelledError:
pass
raise asyncio.CancelledError(f"Task {task_id} was cancelled")
cancel_watch.cancel()
return execute_task.result()
finally:
await cache.delete(f"cancel:{task_id}")
return wrapper
def _convert_a2a_files_to_file_inputs(
a2a_files: list[FileWithBytes | FileWithUri],
) -> dict[str, Any]:
"""Convert a2a file types to crewai FileInput dict.
Args:
a2a_files: List of FileWithBytes or FileWithUri from a2a SDK.
Returns:
Dictionary mapping file names to FileInput objects.
"""
try:
from crewai_files import File, FileBytes
except ImportError:
logger.debug("crewai_files not installed, returning empty file dict")
return {}
file_dict: dict[str, Any] = {}
for idx, a2a_file in enumerate(a2a_files):
if isinstance(a2a_file, FileWithBytes):
file_bytes = base64.b64decode(a2a_file.bytes)
name = a2a_file.name or f"file_{idx}"
file_source = FileBytes(data=file_bytes, filename=a2a_file.name)
file_dict[name] = File(source=file_source)
elif isinstance(a2a_file, FileWithUri):
name = a2a_file.name or f"file_{idx}"
file_dict[name] = File(source=a2a_file.uri)
return file_dict
def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
"""Extract response schema from message parts metadata.
The client may include a JSON schema in TextPart metadata to specify
the expected response format (see delegation.py line 463).
Args:
parts: List of message parts.
Returns:
JSON schema dict if found, None otherwise.
"""
for part in parts:
if part.root.kind == "text" and part.root.metadata:
schema = part.root.metadata.get("schema")
if schema and isinstance(schema, dict):
return schema # type: ignore[no-any-return]
return None
def _create_result_artifact(
result: Any,
task_id: str,
) -> Artifact:
"""Create artifact from task result, using DataPart for structured data.
Args:
result: The task execution result.
task_id: The task ID for naming the artifact.
Returns:
Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings).
"""
artifact_name = f"result_{task_id}"
if isinstance(result, dict):
return new_data_artifact(artifact_name, result)
if isinstance(result, BaseModel):
return new_data_artifact(artifact_name, result.model_dump())
return new_text_artifact(artifact_name, str(result))
def _build_task_description(
user_message: str,
structured_inputs: list[dict[str, Any]],
) -> str:
"""Build task description including structured data if present.
Args:
user_message: The original user message text.
structured_inputs: List of structured data from DataParts.
Returns:
Task description with structured data appended if present.
"""
if not structured_inputs:
return user_message
structured_json = json.dumps(structured_inputs, indent=2)
return f"{user_message}\n\nStructured Data:\n{structured_json}"
async def execute(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
) -> None:
"""Execute an A2A task using a CrewAI agent.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
"""
await _execute_impl(agent, context, event_queue, None, None)
@cancellable
async def _execute_impl(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry | None,
extension_context: ExtensionContext | None,
) -> None:
"""Internal implementation for task execution with optional extensions."""
server_config = _get_server_config(agent)
if context.message and context.message.parts and server_config:
allowed_modes = server_config.default_input_modes
invalid_types = validate_message_parts(context.message.parts, allowed_modes)
if invalid_types:
raise ServerError(
InvalidParamsError(
message=f"Unsupported content type(s): {', '.join(invalid_types)}. "
f"Supported: {', '.join(allowed_modes)}"
)
)
if extension_registry and extension_context:
await extension_registry.invoke_on_request(extension_context)
user_message = context.get_user_input()
response_model: type[BaseModel] | None = None
structured_inputs: list[dict[str, Any]] = []
a2a_files: list[FileWithBytes | FileWithUri] = []
if context.message and context.message.parts:
schema = _extract_response_schema(context.message.parts)
if schema:
try:
response_model = create_model_from_schema(schema)
except Exception as e:
logger.debug(
"Failed to create response model from schema",
extra={"error": str(e), "schema_title": schema.get("title")},
)
structured_inputs = get_data_parts(context.message.parts)
a2a_files = get_file_parts(context.message.parts)
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
msg = "task_id and context_id are required"
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(
task_id="",
context_id="",
error=msg,
from_agent=agent,
),
)
raise ServerError(InvalidParamsError(message=msg)) from None
task = Task(
description=_build_task_description(user_message, structured_inputs),
expected_output="Response to the user's request",
agent=agent,
response_model=response_model,
input_files=_convert_a2a_files_to_file_inputs(a2a_files),
)
crewai_event_bus.emit(
agent,
A2AServerTaskStartedEvent(
task_id=task_id,
context_id=context_id,
from_task=task,
from_agent=agent,
),
)
try:
result = await agent.aexecute_task(task=task, tools=agent.tools)
if extension_registry and extension_context:
result = await extension_registry.invoke_on_response(
extension_context, result
)
result_str = str(result)
history: list[Message] = [context.message] if context.message else []
history.append(new_agent_text_message(result_str, context_id, task_id))
await event_queue.enqueue_event(
A2ATask(
id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.completed),
artifacts=[_create_result_artifact(result, task_id)],
history=history,
)
)
crewai_event_bus.emit(
agent,
A2AServerTaskCompletedEvent(
task_id=task_id,
context_id=context_id,
result=str(result),
from_task=task,
from_agent=agent,
),
)
except asyncio.CancelledError:
crewai_event_bus.emit(
agent,
A2AServerTaskCanceledEvent(
task_id=task_id,
context_id=context_id,
from_task=task,
from_agent=agent,
),
)
raise
except Exception as e:
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(
task_id=task_id,
context_id=context_id,
error=str(e),
from_task=task,
from_agent=agent,
),
)
raise ServerError(
error=InternalError(message=f"Task execution failed: {e}")
) from e
async def execute_with_extensions(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry,
extension_context: ExtensionContext,
) -> None:
"""Execute an A2A task with extension hooks.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
extension_registry: Registry of server extensions.
extension_context: Context for extension hooks.
"""
await _execute_impl(
agent, context, event_queue, extension_registry, extension_context
)
async def cancel(
context: RequestContext,
event_queue: EventQueue,
) -> A2ATask | None:
"""Cancel an A2A task.
Publishes a cancel event that the cancellable decorator listens for.
Args:
context: The A2A request context containing task information.
event_queue: The event queue for sending the cancellation status.
Returns:
The canceled task with updated status.
"""
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
if context.current_task and context.current_task.status.state in (
TaskState.completed,
TaskState.failed,
TaskState.canceled,
):
return context.current_task
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.canceled),
final=True,
)
)
if context.current_task:
context.current_task.status = TaskStatus(state=TaskState.canceled)
return context.current_task
return None
def list_tasks(
tasks: list[A2ATask],
context_id: str | None = None,
status: TaskState | None = None,
status_timestamp_after: datetime | None = None,
page_size: int = 50,
page_token: str | None = None,
history_length: int | None = None,
include_artifacts: bool = False,
) -> tuple[list[A2ATask], str | None, int]:
"""Filter and paginate A2A tasks.
Provides filtering by context, status, and timestamp, along with
cursor-based pagination. This is a pure utility function that operates
on an in-memory list of tasks - storage retrieval is handled separately.
Args:
tasks: All tasks to filter.
context_id: Filter by context ID to get tasks in a conversation.
status: Filter by task state (e.g., completed, working).
status_timestamp_after: Filter to tasks updated after this time.
page_size: Maximum tasks per page (default 50).
page_token: Base64-encoded cursor from previous response.
history_length: Limit history messages per task (None = full history).
include_artifacts: Whether to include task artifacts (default False).
Returns:
Tuple of (filtered_tasks, next_page_token, total_count).
- filtered_tasks: Tasks matching filters, paginated and trimmed.
- next_page_token: Token for next page, or None if no more pages.
- total_count: Total number of tasks matching filters (before pagination).
"""
filtered: list[A2ATask] = []
for task in tasks:
if context_id and task.context_id != context_id:
continue
if status and task.status.state != status:
continue
if status_timestamp_after and task.status.timestamp:
ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00"))
if ts <= status_timestamp_after:
continue
filtered.append(task)
def get_timestamp(t: A2ATask) -> datetime:
"""Extract timestamp from task status for sorting."""
if t.status.timestamp is None:
return datetime.min
return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00"))
filtered.sort(key=get_timestamp, reverse=True)
total = len(filtered)
start = 0
if page_token:
try:
cursor_id = base64.b64decode(page_token).decode()
for idx, task in enumerate(filtered):
if task.id == cursor_id:
start = idx + 1
break
except (ValueError, UnicodeDecodeError):
pass
page = filtered[start : start + page_size]
result: list[A2ATask] = []
for task in page:
task = task.model_copy(deep=True)
if history_length is not None and task.history:
task.history = task.history[-history_length:]
if not include_artifacts:
task.artifacts = None
result.append(task)
next_token: str | None = None
if result and len(result) == page_size:
next_token = base64.b64encode(result[-1].id.encode()).decode()
return result, next_token, total
from crewai_a2a.utils.task import * # noqa: E402, F403

View File

@@ -1,215 +1,13 @@
"""Transport negotiation utilities for A2A protocol.
"""Backward-compatibility shim — use ``crewai_a2a.utils.transport`` instead."""
This module provides functionality for negotiating the transport protocol
between an A2A client and server based on their respective capabilities
and preferences.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Final, Literal
from a2a.types import AgentCard, AgentInterface
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
import warnings
TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"]
warnings.warn(
"'crewai.a2a.utils.transport' has been moved to 'crewai_a2a.utils.transport'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC"
GRPC_TRANSPORT: Literal["GRPC"] = "GRPC"
HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON"
DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [
JSONRPC_TRANSPORT,
GRPC_TRANSPORT,
HTTP_JSON_TRANSPORT,
]
@dataclass
class NegotiatedTransport:
"""Result of transport negotiation.
Attributes:
transport: The negotiated transport protocol.
url: The URL to use for this transport.
source: How the transport was selected ('preferred', 'additional', 'fallback').
"""
transport: str
url: str
source: NegotiationSource
class TransportNegotiationError(Exception):
"""Raised when no compatible transport can be negotiated."""
def __init__(
self,
client_transports: list[str],
server_transports: list[str],
message: str | None = None,
) -> None:
"""Initialize the error with negotiation details.
Args:
client_transports: Transports supported by the client.
server_transports: Transports supported by the server.
message: Optional custom error message.
"""
self.client_transports = client_transports
self.server_transports = server_transports
if message is None:
message = (
f"No compatible transport found. "
f"Client supports: {client_transports}. "
f"Server supports: {server_transports}."
)
super().__init__(message)
def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
"""Extract all available interfaces from an AgentCard.
Creates a unified list of interfaces including the primary URL and
any additional interfaces declared by the agent.
Args:
agent_card: The agent's card containing transport information.
Returns:
List of AgentInterface objects representing all available endpoints.
"""
interfaces: list[AgentInterface] = []
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
interfaces.append(
AgentInterface(
transport=primary_transport,
url=agent_card.url,
)
)
if agent_card.additional_interfaces:
for interface in agent_card.additional_interfaces:
is_duplicate = any(
i.url == interface.url and i.transport == interface.transport
for i in interfaces
)
if not is_duplicate:
interfaces.append(interface)
return interfaces
def negotiate_transport(
agent_card: AgentCard,
client_supported_transports: list[str] | None = None,
client_preferred_transport: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> NegotiatedTransport:
"""Negotiate the transport protocol between client and server.
Compares the client's supported transports with the server's available
interfaces to find a compatible transport and URL.
Negotiation logic:
1. If client_preferred_transport is set and server supports it → use it
2. Otherwise, if server's preferred is in client's supported → use server's
3. Otherwise, find first match from client's supported in server's interfaces
Args:
agent_card: The server's AgentCard with transport information.
client_supported_transports: Transports the client can use.
Defaults to ["JSONRPC"] if not specified.
client_preferred_transport: Client's preferred transport. If set and
server supports it, takes priority over server preference.
emit_event: Whether to emit a transport negotiation event.
endpoint: Original endpoint URL for event metadata.
a2a_agent_name: Agent name for event metadata.
Returns:
NegotiatedTransport with the selected transport, URL, and source.
Raises:
TransportNegotiationError: If no compatible transport is found.
"""
if client_supported_transports is None:
client_supported_transports = [JSONRPC_TRANSPORT]
client_transports = [t.upper() for t in client_supported_transports]
client_preferred = (
client_preferred_transport.upper() if client_preferred_transport else None
)
server_interfaces = _get_server_interfaces(agent_card)
server_transports = [i.transport.upper() for i in server_interfaces]
transport_to_interface: dict[str, AgentInterface] = {}
for interface in server_interfaces:
transport_upper = interface.transport.upper()
if transport_upper not in transport_to_interface:
transport_to_interface[transport_upper] = interface
result: NegotiatedTransport | None = None
if client_preferred and client_preferred in transport_to_interface:
interface = transport_to_interface[client_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="client_preferred",
)
else:
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
if (
server_preferred in client_transports
and server_preferred in transport_to_interface
):
interface = transport_to_interface[server_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="server_preferred",
)
else:
for transport in client_transports:
if transport in transport_to_interface:
interface = transport_to_interface[transport]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="fallback",
)
break
if result is None:
raise TransportNegotiationError(
client_transports=client_transports,
server_transports=server_transports,
)
if emit_event:
crewai_event_bus.emit(
None,
A2ATransportNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
negotiated_transport=result.transport,
negotiated_url=result.url,
source=result.source,
client_supported_transports=client_transports,
server_supported_transports=server_transports,
server_preferred_transport=agent_card.preferred_transport
or JSONRPC_TRANSPORT,
client_preferred_transport=client_preferred,
),
)
return result
from crewai_a2a.utils.transport import * # noqa: E402, F403

File diff suppressed because it is too large Load Diff

View File

@@ -84,16 +84,16 @@ from crewai.utilities.training_handler import CrewTrainingHandler
try:
from crewai.a2a.types import AgentResponseProtocol
from crewai_a2a.types import AgentResponseProtocol
except ImportError:
AgentResponseProtocol = None # type: ignore[assignment, misc]
if TYPE_CHECKING:
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai_files import FileInput
from crewai_tools import CodeInterpreterTool
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.agents.agent_builder.base_agent import PlatformAppOrAction
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
@@ -1740,7 +1740,7 @@ class Agent(BaseAgent):
# Rebuild Agent model to resolve A2A type forward references
try:
from crewai.a2a.config import (
from crewai_a2a.config import (
A2AClientConfig as _A2AClientConfig,
A2AConfig as _A2AConfig,
A2AServerConfig as _A2AServerConfig,

View File

@@ -58,10 +58,10 @@ class AgentMeta(ModelMetaclass):
a2a_value = getattr(self, "a2a", None)
if a2a_value is not None:
from crewai.a2a.extensions.registry import (
from crewai_a2a.extensions.registry import (
create_extension_registry_from_config,
)
from crewai.a2a.wrapper import wrap_agent_with_a2a_instance
from crewai_a2a.wrapper import wrap_agent_with_a2a_instance
extension_registry = create_extension_registry_from_config(
a2a_value

View File

@@ -31,10 +31,9 @@ from typing_extensions import Self
if TYPE_CHECKING:
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai_files import FileInput
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
@@ -120,8 +119,9 @@ def _kickoff_with_a2a_support(
Returns:
LiteAgentOutput from either local execution or A2A delegation.
"""
from crewai.a2a.utils.response_model import get_a2a_agents_and_response_model
from crewai.a2a.wrapper import _execute_task_with_a2a
from crewai_a2a.utils.response_model import get_a2a_agents_and_response_model
from crewai_a2a.wrapper import _execute_task_with_a2a
from crewai.task import Task
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(agent.a2a)
@@ -319,11 +319,11 @@ class LiteAgent(FlowTrackable, BaseModel):
def setup_a2a_support(self) -> Self:
"""Setup A2A extensions and server methods if a2a config exists."""
if self.a2a:
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.registry import (
from crewai_a2a.config import A2AClientConfig, A2AConfig
from crewai_a2a.extensions.registry import (
create_extension_registry_from_config,
)
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
from crewai_a2a.utils.agent_card import inject_a2a_server_methods
configs = self.a2a if isinstance(self.a2a, list) else [self.a2a]
client_configs = [
@@ -995,7 +995,7 @@ class LiteAgent(FlowTrackable, BaseModel):
try:
from crewai.a2a.config import (
from crewai_a2a.config import (
A2AClientConfig as _A2AClientConfig,
A2AConfig as _A2AConfig,
A2AServerConfig as _A2AServerConfig,

View File

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
import pytest
from crewai.a2a.config import A2AConfig
from crewai_a2a.config import A2AConfig
try:
from a2a.types import Message, Role
@@ -27,8 +27,8 @@ def _create_mock_agent_card(name: str = "Test", url: str = "http://test-endpoint
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
def test_trust_remote_completion_status_true_returns_directly():
"""When trust_remote_completion_status=True and A2A returns completed, return result directly."""
from crewai.a2a.wrapper import _delegate_to_a2a
from crewai.a2a.types import AgentResponseProtocol
from crewai_a2a.wrapper import _delegate_to_a2a
from crewai_a2a.types import AgentResponseProtocol
from crewai import Agent, Task
a2a_config = A2AConfig(
@@ -51,8 +51,8 @@ def test_trust_remote_completion_status_true_returns_directly():
a2a_ids = ["http://test-endpoint.com/"]
with (
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
patch("crewai_a2a.wrapper.execute_a2a_delegation") as mock_execute,
patch("crewai_a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
):
mock_card = _create_mock_agent_card()
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
@@ -83,7 +83,7 @@ def test_trust_remote_completion_status_true_returns_directly():
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
def test_trust_remote_completion_status_false_continues_conversation():
"""When trust_remote_completion_status=False and A2A returns completed, ask server agent."""
from crewai.a2a.wrapper import _delegate_to_a2a
from crewai_a2a.wrapper import _delegate_to_a2a
from crewai import Agent, Task
a2a_config = A2AConfig(
@@ -116,8 +116,8 @@ def test_trust_remote_completion_status_false_continues_conversation():
return "unexpected"
with (
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
patch("crewai_a2a.wrapper.execute_a2a_delegation") as mock_execute,
patch("crewai_a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
):
mock_card = _create_mock_agent_card()
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})

View File

@@ -7,7 +7,7 @@ import os
import pytest
from crewai import Agent
from crewai.a2a.config import A2AClientConfig
from crewai_a2a.config import A2AClientConfig
A2A_TEST_ENDPOINT = os.getenv(

View File

@@ -5,7 +5,7 @@ from unittest.mock import patch
import pytest
from crewai import Agent
from crewai.a2a.config import A2AConfig
from crewai_a2a.config import A2AConfig
try:
import a2a # noqa: F401

View File

@@ -106,6 +106,7 @@ ignore-decorators = ["typing.overload"]
"lib/crewai/tests/**/*.py" = ["S101", "RET504", "S105", "S106"] # Allow assert statements, unnecessary assignments, and hardcoded passwords in tests
"lib/crewai-tools/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "RUF012", "N818", "E402", "RUF043", "S110", "B017"] # Allow various test-specific patterns
"lib/crewai-files/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "B017", "F841"] # Allow assert statements and blind exception assertions in tests
"lib/crewai-a2a/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "B017", "F821"] # Allow assert statements, unnecessary assignments, hardcoded passwords, blind exceptions, and forward refs in tests
[tool.mypy]
@@ -118,7 +119,7 @@ warn_return_any = true
show_error_codes = true
warn_unused_ignores = true
python_version = "3.12"
exclude = "(?x)(^lib/crewai/src/crewai/cli/templates/|^lib/crewai/tests/|^lib/crewai-tools/tests/|^lib/crewai-files/tests/)"
exclude = "(?x)(^lib/crewai/src/crewai/cli/templates/|^lib/crewai/tests/|^lib/crewai-tools/tests/|^lib/crewai-files/tests/|^lib/crewai-a2a/tests/)"
plugins = ["pydantic.mypy"]
@@ -134,6 +135,7 @@ testpaths = [
"lib/crewai/tests",
"lib/crewai-tools/tests",
"lib/crewai-files/tests",
"lib/crewai-a2a/tests",
]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"
@@ -157,6 +159,7 @@ members = [
"lib/crewai-tools",
"lib/devtools",
"lib/crewai-files",
"lib/crewai-a2a",
]
@@ -165,3 +168,4 @@ crewai = { workspace = true }
crewai-tools = { workspace = true }
crewai-devtools = { workspace = true }
crewai-files = { workspace = true }
crewai-a2a = { workspace = true }

31
uv.lock generated
View File

@@ -15,6 +15,7 @@ resolution-markers = [
[manifest]
members = [
"crewai",
"crewai-a2a",
"crewai-devtools",
"crewai-files",
"crewai-tools",
@@ -1124,10 +1125,7 @@ dependencies = [
[package.optional-dependencies]
a2a = [
{ name = "a2a-sdk" },
{ name = "aiocache", extra = ["memcached", "redis"] },
{ name = "httpx-auth" },
{ name = "httpx-sse" },
{ name = "crewai-a2a" },
]
anthropic = [
{ name = "anthropic" },
@@ -1181,9 +1179,7 @@ watson = [
[package.metadata]
requires-dist = [
{ name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
{ name = "aiosqlite", specifier = "~=0.21.0" },
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
{ name = "appdirs", specifier = "~=1.4.4" },
@@ -1192,13 +1188,12 @@ requires-dist = [
{ name = "boto3", marker = "extra == 'bedrock'", specifier = "~=1.40.45" },
{ name = "chromadb", specifier = "~=1.1.0" },
{ name = "click", specifier = "~=8.1.7" },
{ name = "crewai-a2a", marker = "extra == 'a2a'", editable = "lib/crewai-a2a" },
{ name = "crewai-files", marker = "extra == 'file-processing'", editable = "lib/crewai-files" },
{ name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" },
{ name = "docling", marker = "extra == 'docling'", specifier = "~=2.63.0" },
{ name = "google-genai", marker = "extra == 'google-genai'", specifier = "~=1.49.0" },
{ name = "httpx", specifier = "~=0.28.1" },
{ name = "httpx-auth", marker = "extra == 'a2a'", specifier = "~=0.23.1" },
{ name = "httpx-sse", marker = "extra == 'a2a'", specifier = "~=0.4.0" },
{ name = "ibm-watsonx-ai", marker = "extra == 'watson'", specifier = "~=1.3.39" },
{ name = "instructor", specifier = ">=1.3.3" },
{ name = "json-repair", specifier = "~=0.25.2" },
@@ -1233,6 +1228,26 @@ requires-dist = [
]
provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "tools", "voyageai", "watson"]
[[package]]
name = "crewai-a2a"
source = { editable = "lib/crewai-a2a" }
dependencies = [
{ name = "a2a-sdk" },
{ name = "aiocache", extra = ["memcached", "redis"] },
{ name = "crewai" },
{ name = "httpx-auth" },
{ name = "httpx-sse" },
]
[package.metadata]
requires-dist = [
{ name = "a2a-sdk", specifier = "~=0.3.10" },
{ name = "aiocache", extras = ["memcached", "redis"], specifier = "~=0.12.3" },
{ name = "crewai", editable = "lib/crewai" },
{ name = "httpx-auth", specifier = "~=0.23.1" },
{ name = "httpx-sse", specifier = "~=0.4.0" },
]
[[package]]
name = "crewai-devtools"
source = { editable = "lib/devtools" }