mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-05 02:28:15 +00:00
Compare commits
2 Commits
docs/add-v
...
gl/chore/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8b3acb58a4 | ||
|
|
8a3c2d5ca6 |
@@ -19,7 +19,7 @@ repos:
|
||||
language: system
|
||||
pass_filenames: true
|
||||
types: [python]
|
||||
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
|
||||
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/|lib/crewai-a2a/tests/)
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
rev: 0.9.3
|
||||
hooks:
|
||||
|
||||
@@ -12,6 +12,7 @@ from dotenv import load_dotenv
|
||||
import pytest
|
||||
from vcr.request import Request # type: ignore[import-untyped]
|
||||
|
||||
|
||||
try:
|
||||
import vcr.stubs.httpx_stubs as httpx_stubs # type: ignore[import-untyped]
|
||||
except ModuleNotFoundError:
|
||||
@@ -225,7 +226,7 @@ def vcr_cassette_dir(request: Any) -> str:
|
||||
|
||||
for parent in test_file.parents:
|
||||
if (
|
||||
parent.name in ("crewai", "crewai-tools", "crewai-files")
|
||||
parent.name in ("crewai", "crewai-tools", "crewai-files", "crewai-a2a")
|
||||
and parent.parent.name == "lib"
|
||||
):
|
||||
package_root = parent
|
||||
|
||||
189
lib/crewai-a2a/README.md
Normal file
189
lib/crewai-a2a/README.md
Normal file
@@ -0,0 +1,189 @@
|
||||
# crewai-a2a
|
||||
|
||||
Agent-to-Agent (A2A) protocol support for CrewAI. Enables agents to discover, authenticate, and communicate with remote A2A-compatible agents.
|
||||
|
||||
## Quick Links
|
||||
|
||||
[Homepage](https://www.crewai.com/) | [Documentation](https://docs.crewai.com/) | [Community](https://community.crewai.com/)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
uv pip install crewai[a2a]
|
||||
# or
|
||||
uv add 'crewai[a2a]'
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Connecting to a Remote A2A Agent
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai_a2a import A2AClientConfig
|
||||
|
||||
agent = Agent(
|
||||
role="Coordinator",
|
||||
goal="Delegate research tasks",
|
||||
a2a=[
|
||||
A2AClientConfig(endpoint="https://research-agent.example.com"),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
### Exposing an Agent as an A2A Server
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai_a2a import A2AServerConfig
|
||||
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Answer research questions",
|
||||
a2a_server=A2AServerConfig(
|
||||
name="Research Agent",
|
||||
description="Answers research questions using web search",
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
### Client Schemes
|
||||
|
||||
```python
|
||||
from crewai_a2a.auth import (
|
||||
BearerTokenAuth,
|
||||
HTTPBasicAuth,
|
||||
APIKeyAuth,
|
||||
OAuth2ClientCredentials,
|
||||
)
|
||||
from crewai_a2a.config import A2AClientConfig
|
||||
|
||||
# Bearer token
|
||||
A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
auth=BearerTokenAuth(token="my-token"),
|
||||
)
|
||||
|
||||
# API key
|
||||
A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
auth=APIKeyAuth(api_key="key", location="header", name="X-API-Key"),
|
||||
)
|
||||
|
||||
# OAuth2 client credentials
|
||||
A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
auth=OAuth2ClientCredentials(
|
||||
token_url="https://auth.example.com/token",
|
||||
client_id="id",
|
||||
client_secret="secret",
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
### Server Schemes
|
||||
|
||||
```python
|
||||
from crewai_a2a.auth import SimpleTokenAuth, OIDCAuth
|
||||
from crewai_a2a.config import A2AServerConfig
|
||||
|
||||
# Simple token validation
|
||||
A2AServerConfig(auth=SimpleTokenAuth(token="expected-token"))
|
||||
|
||||
# OpenID Connect
|
||||
A2AServerConfig(
|
||||
auth=OIDCAuth(
|
||||
issuer="https://auth.example.com",
|
||||
audience="my-agent",
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
## Update Mechanisms
|
||||
|
||||
Control how the client receives task updates from remote agents.
|
||||
|
||||
```python
|
||||
from crewai_a2a.updates import PollingConfig, StreamingConfig, PushNotificationConfig
|
||||
from crewai_a2a.config import A2AClientConfig
|
||||
|
||||
|
||||
# Polling
|
||||
A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
updates=PollingConfig(interval=2.0, timeout=60),
|
||||
)
|
||||
|
||||
# Server-Sent Events streaming
|
||||
A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
updates=StreamingConfig(),
|
||||
)
|
||||
|
||||
# Webhook push notifications
|
||||
A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
updates=PushNotificationConfig(
|
||||
url="https://my-server.example.com/webhook",
|
||||
timeout=300,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
## Extensions
|
||||
|
||||
### Client Extensions
|
||||
|
||||
Client extensions inject tools, augment prompts, and process responses.
|
||||
|
||||
```python
|
||||
from crewai_a2a.extensions import A2AExtension
|
||||
|
||||
|
||||
class MyExtension(A2AExtension):
|
||||
def inject_tools(self, agent):
|
||||
...
|
||||
|
||||
def augment_prompt(self, base_prompt, conversation_state):
|
||||
return f"{base_prompt}\n\nAdditional context from extension."
|
||||
```
|
||||
|
||||
### Server Extensions
|
||||
|
||||
Server extensions add protocol-level capabilities to your A2A server.
|
||||
|
||||
```python
|
||||
from crewai_a2a.extensions import ServerExtension
|
||||
|
||||
class MyServerExtension(ServerExtension):
|
||||
uri = "urn:my-org:my-extension"
|
||||
description = "Custom protocol extension"
|
||||
|
||||
async def on_request(self, context):
|
||||
...
|
||||
|
||||
async def on_response(self, context, result):
|
||||
...
|
||||
```
|
||||
|
||||
## Transport
|
||||
|
||||
Three transport protocols are supported: JSON-RPC (default), gRPC, and HTTP+JSON.
|
||||
|
||||
```python
|
||||
from crewai_a2a.config import ClientTransportConfig, GRPCClientConfig
|
||||
from crewai_a2a.config import A2AClientConfig
|
||||
|
||||
|
||||
A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
transport=ClientTransportConfig(
|
||||
preferred="GRPC",
|
||||
grpc=GRPCClientConfig(
|
||||
max_send_message_length=4 * 1024 * 1024,
|
||||
),
|
||||
),
|
||||
)
|
||||
```
|
||||
25
lib/crewai-a2a/pyproject.toml
Normal file
25
lib/crewai-a2a/pyproject.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[project]
|
||||
name = "crewai-a2a"
|
||||
dynamic = ["version"]
|
||||
description = "A2A (Agent-to-Agent) protocol support for CrewAI"
|
||||
readme = "README.md"
|
||||
authors = [{ name = "Greyson LaLonde", email = "greyson@crewai.com" }]
|
||||
requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
"crewai==1.10.1b1",
|
||||
"a2a-sdk~=0.3.10",
|
||||
"httpx-auth~=0.23.1",
|
||||
"httpx-sse~=0.4.0",
|
||||
"aiocache[redis,memcached]~=0.12.3",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.version]
|
||||
path = "src/crewai_a2a/__init__.py"
|
||||
|
||||
[tool.uv.sources]
|
||||
crewai = { workspace = true }
|
||||
crewai-files = { workspace = true }
|
||||
12
lib/crewai-a2a/src/crewai_a2a/__init__.py
Normal file
12
lib/crewai-a2a/src/crewai_a2a/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""
|
||||
|
||||
__version__ = "1.10.1b1"
|
||||
|
||||
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"A2AClientConfig",
|
||||
"A2AConfig",
|
||||
"A2AServerConfig",
|
||||
]
|
||||
36
lib/crewai-a2a/src/crewai_a2a/auth/__init__.py
Normal file
36
lib/crewai-a2a/src/crewai_a2a/auth/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""A2A authentication schemas."""
|
||||
|
||||
from crewai_a2a.auth.client_schemes import (
|
||||
APIKeyAuth,
|
||||
AuthScheme,
|
||||
BearerTokenAuth,
|
||||
ClientAuthScheme,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
TLSConfig,
|
||||
)
|
||||
from crewai_a2a.auth.server_schemes import (
|
||||
AuthenticatedUser,
|
||||
OIDCAuth,
|
||||
ServerAuthScheme,
|
||||
SimpleTokenAuth,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"APIKeyAuth",
|
||||
"AuthScheme",
|
||||
"AuthenticatedUser",
|
||||
"BearerTokenAuth",
|
||||
"ClientAuthScheme",
|
||||
"HTTPBasicAuth",
|
||||
"HTTPDigestAuth",
|
||||
"OAuth2AuthorizationCode",
|
||||
"OAuth2ClientCredentials",
|
||||
"OIDCAuth",
|
||||
"ServerAuthScheme",
|
||||
"SimpleTokenAuth",
|
||||
"TLSConfig",
|
||||
]
|
||||
550
lib/crewai-a2a/src/crewai_a2a/auth/client_schemes.py
Normal file
550
lib/crewai-a2a/src/crewai_a2a/auth/client_schemes.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""Authentication schemes for A2A protocol clients.
|
||||
|
||||
Supported authentication methods:
|
||||
- Bearer tokens
|
||||
- OAuth2 (Client Credentials, Authorization Code)
|
||||
- API Keys (header, query, cookie)
|
||||
- HTTP Basic authentication
|
||||
- HTTP Digest authentication
|
||||
- mTLS (mutual TLS) client certificate authentication
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Awaitable, Callable, MutableMapping
|
||||
from pathlib import Path
|
||||
import ssl
|
||||
import time
|
||||
from typing import TYPE_CHECKING, ClassVar, Literal
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from httpx import DigestAuth
|
||||
from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr
|
||||
from typing_extensions import deprecated
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import grpc # type: ignore[import-untyped]
|
||||
|
||||
|
||||
class TLSConfig(BaseModel):
|
||||
"""TLS/mTLS configuration for secure client connections.
|
||||
|
||||
Supports mutual TLS (mTLS) where the client presents a certificate to the server,
|
||||
and standard TLS with custom CA verification.
|
||||
|
||||
Attributes:
|
||||
client_cert_path: Path to client certificate file (PEM format) for mTLS.
|
||||
client_key_path: Path to client private key file (PEM format) for mTLS.
|
||||
ca_cert_path: Path to CA certificate bundle for server verification.
|
||||
verify: Whether to verify server certificates. Set False only for development.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
client_cert_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to client certificate file (PEM format) for mTLS",
|
||||
)
|
||||
client_key_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to client private key file (PEM format) for mTLS",
|
||||
)
|
||||
ca_cert_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to CA certificate bundle for server verification",
|
||||
)
|
||||
verify: bool = Field(
|
||||
default=True,
|
||||
description="Whether to verify server certificates. Set False only for development.",
|
||||
)
|
||||
|
||||
def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
|
||||
"""Build SSL context for httpx client.
|
||||
|
||||
Returns:
|
||||
SSL context if certificates configured, True for default verification,
|
||||
False if verification disabled, or path to CA bundle.
|
||||
"""
|
||||
if not self.verify:
|
||||
return False
|
||||
|
||||
if self.client_cert_path and self.client_key_path:
|
||||
context = ssl.create_default_context()
|
||||
|
||||
if self.ca_cert_path:
|
||||
context.load_verify_locations(cafile=str(self.ca_cert_path))
|
||||
|
||||
context.load_cert_chain(
|
||||
certfile=str(self.client_cert_path),
|
||||
keyfile=str(self.client_key_path),
|
||||
)
|
||||
return context
|
||||
|
||||
if self.ca_cert_path:
|
||||
return str(self.ca_cert_path)
|
||||
|
||||
return True
|
||||
|
||||
def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported]
|
||||
"""Build gRPC channel credentials for secure connections.
|
||||
|
||||
Returns:
|
||||
gRPC SSL credentials if certificates configured, None otherwise.
|
||||
"""
|
||||
try:
|
||||
import grpc
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
if not self.verify and not self.client_cert_path:
|
||||
return None
|
||||
|
||||
root_certs: bytes | None = None
|
||||
private_key: bytes | None = None
|
||||
certificate_chain: bytes | None = None
|
||||
|
||||
if self.ca_cert_path:
|
||||
root_certs = Path(self.ca_cert_path).read_bytes()
|
||||
|
||||
if self.client_cert_path and self.client_key_path:
|
||||
private_key = Path(self.client_key_path).read_bytes()
|
||||
certificate_chain = Path(self.client_cert_path).read_bytes()
|
||||
|
||||
return grpc.ssl_channel_credentials(
|
||||
root_certificates=root_certs,
|
||||
private_key=private_key,
|
||||
certificate_chain=certificate_chain,
|
||||
)
|
||||
|
||||
|
||||
class ClientAuthScheme(ABC, BaseModel):
|
||||
"""Base class for client-side authentication schemes.
|
||||
|
||||
Client auth schemes apply credentials to outgoing requests.
|
||||
|
||||
Attributes:
|
||||
tls: Optional TLS/mTLS configuration for secure connections.
|
||||
"""
|
||||
|
||||
tls: TLSConfig | None = Field(
|
||||
default=None,
|
||||
description="TLS/mTLS configuration for secure connections",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply authentication to request headers.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with authentication applied.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@deprecated("Use ClientAuthScheme instead", category=FutureWarning)
|
||||
class AuthScheme(ClientAuthScheme):
|
||||
"""Deprecated: Use ClientAuthScheme instead."""
|
||||
|
||||
|
||||
class BearerTokenAuth(ClientAuthScheme):
|
||||
"""Bearer token authentication (Authorization: Bearer <token>).
|
||||
|
||||
Attributes:
|
||||
token: Bearer token for authentication.
|
||||
"""
|
||||
|
||||
token: str = Field(description="Bearer token")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply Bearer token to Authorization header.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with Bearer token in Authorization header.
|
||||
"""
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
return headers
|
||||
|
||||
|
||||
class HTTPBasicAuth(ClientAuthScheme):
|
||||
"""HTTP Basic authentication.
|
||||
|
||||
Attributes:
|
||||
username: Username for Basic authentication.
|
||||
password: Password for Basic authentication.
|
||||
"""
|
||||
|
||||
username: str = Field(description="Username")
|
||||
password: str = Field(description="Password")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply HTTP Basic authentication.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with Basic auth in Authorization header.
|
||||
"""
|
||||
credentials = f"{self.username}:{self.password}"
|
||||
encoded = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded}"
|
||||
return headers
|
||||
|
||||
|
||||
class HTTPDigestAuth(ClientAuthScheme):
|
||||
"""HTTP Digest authentication.
|
||||
|
||||
Note: Uses httpx-auth library for digest implementation.
|
||||
|
||||
Attributes:
|
||||
username: Username for Digest authentication.
|
||||
password: Password for Digest authentication.
|
||||
"""
|
||||
|
||||
username: str = Field(description="Username")
|
||||
password: str = Field(description="Password")
|
||||
|
||||
_configured_client_id: int | None = PrivateAttr(default=None)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Digest auth is handled by httpx auth flow, not headers.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Unchanged headers (Digest auth handled by httpx auth flow).
|
||||
"""
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client with Digest auth.
|
||||
|
||||
Idempotent: Only configures the client once. Subsequent calls on the same
|
||||
client instance are no-ops to prevent overwriting auth configuration.
|
||||
|
||||
Args:
|
||||
client: HTTP client to configure with Digest authentication.
|
||||
"""
|
||||
client_id = id(client)
|
||||
if self._configured_client_id == client_id:
|
||||
return
|
||||
|
||||
client.auth = DigestAuth(self.username, self.password)
|
||||
self._configured_client_id = client_id
|
||||
|
||||
|
||||
class APIKeyAuth(ClientAuthScheme):
|
||||
"""API Key authentication (header, query, or cookie).
|
||||
|
||||
Attributes:
|
||||
api_key: API key value for authentication.
|
||||
location: Where to send the API key (header, query, or cookie).
|
||||
name: Parameter name for the API key (default: X-API-Key).
|
||||
"""
|
||||
|
||||
api_key: str = Field(description="API key value")
|
||||
location: Literal["header", "query", "cookie"] = Field(
|
||||
default="header", description="Where to send the API key"
|
||||
)
|
||||
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
|
||||
|
||||
_configured_client_ids: set[int] = PrivateAttr(default_factory=set)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply API key authentication.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with API key (for header/cookie locations).
|
||||
"""
|
||||
if self.location == "header":
|
||||
headers[self.name] = self.api_key
|
||||
elif self.location == "cookie":
|
||||
headers["Cookie"] = f"{self.name}={self.api_key}"
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client for query param API keys.
|
||||
|
||||
Idempotent: Only adds the request hook once per client instance.
|
||||
Subsequent calls on the same client are no-ops to prevent hook accumulation.
|
||||
|
||||
Args:
|
||||
client: HTTP client to configure with query param API key hook.
|
||||
"""
|
||||
if self.location == "query":
|
||||
client_id = id(client)
|
||||
if client_id in self._configured_client_ids:
|
||||
return
|
||||
|
||||
async def _add_api_key_param(request: httpx.Request) -> None:
|
||||
url = httpx.URL(request.url)
|
||||
request.url = url.copy_add_param(self.name, self.api_key)
|
||||
|
||||
client.event_hooks["request"].append(_add_api_key_param)
|
||||
self._configured_client_ids.add(client_id)
|
||||
|
||||
|
||||
class OAuth2ClientCredentials(ClientAuthScheme):
|
||||
"""OAuth2 Client Credentials flow authentication.
|
||||
|
||||
Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches
|
||||
when multiple requests share the same auth instance.
|
||||
|
||||
Attributes:
|
||||
token_url: OAuth2 token endpoint URL.
|
||||
client_id: OAuth2 client identifier.
|
||||
client_secret: OAuth2 client secret.
|
||||
scopes: List of required OAuth2 scopes.
|
||||
"""
|
||||
|
||||
token_url: str = Field(description="OAuth2 token endpoint")
|
||||
client_id: str = Field(description="OAuth2 client ID")
|
||||
client_secret: str = Field(description="OAuth2 client secret")
|
||||
scopes: list[str] = Field(
|
||||
default_factory=list, description="Required OAuth2 scopes"
|
||||
)
|
||||
|
||||
_access_token: str | None = PrivateAttr(default=None)
|
||||
_token_expires_at: float | None = PrivateAttr(default=None)
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header.
|
||||
|
||||
Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time,
|
||||
preventing race conditions when multiple concurrent requests use the same
|
||||
auth instance.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with OAuth2 access token in Authorization header.
|
||||
"""
|
||||
if (
|
||||
self._access_token is None
|
||||
or self._token_expires_at is None
|
||||
or time.time() >= self._token_expires_at
|
||||
):
|
||||
async with self._lock:
|
||||
if (
|
||||
self._access_token is None
|
||||
or self._token_expires_at is None
|
||||
or time.time() >= self._token_expires_at
|
||||
):
|
||||
await self._fetch_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Fetch OAuth2 access token using client credentials flow.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token request.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If token request fails.
|
||||
"""
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
}
|
||||
|
||||
if self.scopes:
|
||||
data["scope"] = " ".join(self.scopes)
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(ClientAuthScheme):
|
||||
"""OAuth2 Authorization Code flow authentication.
|
||||
|
||||
Thread-safe implementation with asyncio.Lock to prevent concurrent token operations.
|
||||
|
||||
Note: Requires interactive authorization.
|
||||
|
||||
Attributes:
|
||||
authorization_url: OAuth2 authorization endpoint URL.
|
||||
token_url: OAuth2 token endpoint URL.
|
||||
client_id: OAuth2 client identifier.
|
||||
client_secret: OAuth2 client secret.
|
||||
redirect_uri: OAuth2 redirect URI for callback.
|
||||
scopes: List of required OAuth2 scopes.
|
||||
"""
|
||||
|
||||
authorization_url: str = Field(description="OAuth2 authorization endpoint")
|
||||
token_url: str = Field(description="OAuth2 token endpoint")
|
||||
client_id: str = Field(description="OAuth2 client ID")
|
||||
client_secret: str = Field(description="OAuth2 client secret")
|
||||
redirect_uri: str = Field(description="OAuth2 redirect URI")
|
||||
scopes: list[str] = Field(
|
||||
default_factory=list, description="Required OAuth2 scopes"
|
||||
)
|
||||
|
||||
_access_token: str | None = PrivateAttr(default=None)
|
||||
_refresh_token: str | None = PrivateAttr(default=None)
|
||||
_token_expires_at: float | None = PrivateAttr(default=None)
|
||||
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
|
||||
default=None
|
||||
)
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
def set_authorization_callback(
|
||||
self, callback: Callable[[str], Awaitable[str]] | None
|
||||
) -> None:
|
||||
"""Set callback to handle authorization URL.
|
||||
|
||||
Args:
|
||||
callback: Async function that receives authorization URL and returns auth code.
|
||||
"""
|
||||
self._authorization_callback = callback
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header.
|
||||
|
||||
Uses asyncio.Lock to ensure only one coroutine handles token operations
|
||||
(initial fetch or refresh) at a time.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with OAuth2 access token in Authorization header.
|
||||
|
||||
Raises:
|
||||
ValueError: If authorization callback is not set.
|
||||
"""
|
||||
if self._access_token is None:
|
||||
if self._authorization_callback is None:
|
||||
msg = "Authorization callback not set. Use set_authorization_callback()"
|
||||
raise ValueError(msg)
|
||||
async with self._lock:
|
||||
if self._access_token is None:
|
||||
await self._fetch_initial_token(client)
|
||||
elif self._token_expires_at and time.time() >= self._token_expires_at:
|
||||
async with self._lock:
|
||||
if self._token_expires_at and time.time() >= self._token_expires_at:
|
||||
await self._refresh_access_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Fetch initial access token using authorization code flow.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token request.
|
||||
|
||||
Raises:
|
||||
ValueError: If authorization callback is not set.
|
||||
httpx.HTTPStatusError: If token request fails.
|
||||
"""
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": " ".join(self.scopes),
|
||||
}
|
||||
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
if self._authorization_callback is None:
|
||||
msg = "Authorization callback not set"
|
||||
raise ValueError(msg)
|
||||
auth_code = await self._authorization_callback(auth_url)
|
||||
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
self._refresh_token = token_data.get("refresh_token")
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
|
||||
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Refresh the access token using refresh token.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token request.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If token refresh request fails.
|
||||
"""
|
||||
if not self._refresh_token:
|
||||
await self._fetch_initial_token(client)
|
||||
return
|
||||
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self._refresh_token,
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
}
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
if "refresh_token" in token_data:
|
||||
self._refresh_token = token_data["refresh_token"]
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
71
lib/crewai-a2a/src/crewai_a2a/auth/schemas.py
Normal file
71
lib/crewai-a2a/src/crewai_a2a/auth/schemas.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Deprecated: Authentication schemes for A2A protocol agents.
|
||||
|
||||
This module is deprecated. Import from crewai_a2a.auth instead:
|
||||
- crewai_a2a.auth.ClientAuthScheme (replaces AuthScheme)
|
||||
- crewai_a2a.auth.BearerTokenAuth
|
||||
- crewai_a2a.auth.HTTPBasicAuth
|
||||
- crewai_a2a.auth.HTTPDigestAuth
|
||||
- crewai_a2a.auth.APIKeyAuth
|
||||
- crewai_a2a.auth.OAuth2ClientCredentials
|
||||
- crewai_a2a.auth.OAuth2AuthorizationCode
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from crewai_a2a.auth.client_schemes import (
|
||||
APIKeyAuth as _APIKeyAuth,
|
||||
BearerTokenAuth as _BearerTokenAuth,
|
||||
ClientAuthScheme as _ClientAuthScheme,
|
||||
HTTPBasicAuth as _HTTPBasicAuth,
|
||||
HTTPDigestAuth as _HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode as _OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials as _OAuth2ClientCredentials,
|
||||
)
|
||||
|
||||
|
||||
@deprecated("Use ClientAuthScheme from crewai_a2a.auth instead", category=FutureWarning)
|
||||
class AuthScheme(_ClientAuthScheme):
|
||||
"""Deprecated: Use ClientAuthScheme from crewai_a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
|
||||
class BearerTokenAuth(_BearerTokenAuth):
|
||||
"""Deprecated: Import from crewai_a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
|
||||
class HTTPBasicAuth(_HTTPBasicAuth):
|
||||
"""Deprecated: Import from crewai_a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
|
||||
class HTTPDigestAuth(_HTTPDigestAuth):
|
||||
"""Deprecated: Import from crewai_a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
|
||||
class APIKeyAuth(_APIKeyAuth):
|
||||
"""Deprecated: Import from crewai_a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
|
||||
class OAuth2ClientCredentials(_OAuth2ClientCredentials):
|
||||
"""Deprecated: Import from crewai_a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
|
||||
class OAuth2AuthorizationCode(_OAuth2AuthorizationCode):
|
||||
"""Deprecated: Import from crewai_a2a.auth instead."""
|
||||
|
||||
|
||||
__all__ = [
|
||||
"APIKeyAuth",
|
||||
"AuthScheme",
|
||||
"BearerTokenAuth",
|
||||
"HTTPBasicAuth",
|
||||
"HTTPDigestAuth",
|
||||
"OAuth2AuthorizationCode",
|
||||
"OAuth2ClientCredentials",
|
||||
]
|
||||
742
lib/crewai-a2a/src/crewai_a2a/auth/server_schemes.py
Normal file
742
lib/crewai-a2a/src/crewai_a2a/auth/server_schemes.py
Normal file
@@ -0,0 +1,742 @@
|
||||
"""Server-side authentication schemes for A2A protocol.
|
||||
|
||||
These schemes validate incoming requests to A2A server endpoints.
|
||||
|
||||
Supported authentication methods:
|
||||
- Simple token validation with static bearer tokens
|
||||
- OpenID Connect with JWT validation using JWKS
|
||||
- OAuth2 with JWT validation or token introspection
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
|
||||
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
Field,
|
||||
HttpUrl,
|
||||
PrivateAttr,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import OAuth2SecurityScheme
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from fastapi import ( # type: ignore[import-not-found]
|
||||
HTTPException,
|
||||
status as http_status,
|
||||
)
|
||||
|
||||
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
except ImportError:
|
||||
|
||||
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
|
||||
"""Fallback HTTPException when FastAPI is not installed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
detail: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
self.headers = headers
|
||||
super().__init__(detail)
|
||||
|
||||
HTTP_401_UNAUTHORIZED = 401
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = 500
|
||||
HTTP_503_SERVICE_UNAVAILABLE = 503
|
||||
|
||||
|
||||
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
|
||||
"""Coerce string to SecretStr."""
|
||||
if v is None or isinstance(v, SecretStr):
|
||||
return v
|
||||
return SecretStr(v)
|
||||
|
||||
|
||||
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
|
||||
|
||||
JWTAlgorithm = Literal[
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthenticatedUser:
|
||||
"""Result of successful authentication.
|
||||
|
||||
Attributes:
|
||||
token: The original token that was validated.
|
||||
scheme: Name of the authentication scheme used.
|
||||
claims: JWT claims from OIDC or OAuth2 authentication.
|
||||
"""
|
||||
|
||||
token: str
|
||||
scheme: str
|
||||
claims: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ServerAuthScheme(ABC, BaseModel):
|
||||
"""Base class for server-side authentication schemes.
|
||||
|
||||
Each scheme validates incoming requests and returns an AuthenticatedUser
|
||||
on success, or raises HTTPException on failure.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate the provided token.
|
||||
|
||||
Args:
|
||||
token: The bearer token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SimpleTokenAuth(ServerAuthScheme):
|
||||
"""Simple bearer token authentication.
|
||||
|
||||
Validates tokens against a configured static token or AUTH_TOKEN env var.
|
||||
|
||||
Attributes:
|
||||
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
|
||||
"""
|
||||
|
||||
token: CoercedSecretStr | None = Field(
|
||||
default=None,
|
||||
description="Expected token. Falls back to AUTH_TOKEN env var.",
|
||||
)
|
||||
|
||||
def _get_expected_token(self) -> str | None:
|
||||
"""Get the expected token value."""
|
||||
if self.token:
|
||||
return self.token.get_secret_value()
|
||||
return os.environ.get("AUTH_TOKEN")
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using simple token comparison.
|
||||
|
||||
Args:
|
||||
token: The bearer token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
expected = self._get_expected_token()
|
||||
|
||||
if expected is None:
|
||||
logger.warning(
|
||||
"Simple token authentication failed",
|
||||
extra={"reason": "no_token_configured"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication not configured",
|
||||
)
|
||||
|
||||
if token != expected:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="simple_token",
|
||||
)
|
||||
|
||||
|
||||
class OIDCAuth(ServerAuthScheme):
|
||||
"""OpenID Connect authentication.
|
||||
|
||||
Validates JWTs using JWKS with caching support via PyJWT.
|
||||
|
||||
Attributes:
|
||||
issuer: The OpenID Connect issuer URL.
|
||||
audience: The expected audience claim.
|
||||
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
|
||||
algorithms: List of allowed signing algorithms.
|
||||
required_claims: List of claims that must be present in the token.
|
||||
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||
clock_skew_seconds: Allowed clock skew for token validation.
|
||||
"""
|
||||
|
||||
issuer: HttpUrl = Field(
|
||||
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
|
||||
)
|
||||
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
|
||||
jwks_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="Explicit JWKS URL. Derived from issuer if not set.",
|
||||
)
|
||||
algorithms: list[str] = Field(
|
||||
default_factory=lambda: ["RS256"],
|
||||
description="List of allowed signing algorithms (RS256, ES256, etc.)",
|
||||
)
|
||||
required_claims: list[str] = Field(
|
||||
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
|
||||
description="List of claims that must be present in the token",
|
||||
)
|
||||
jwks_cache_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL for JWKS cache in seconds",
|
||||
ge=60,
|
||||
)
|
||||
clock_skew_seconds: float = Field(
|
||||
default=30.0,
|
||||
description="Allowed clock skew for token validation",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_jwk_client(self) -> Self:
|
||||
"""Initialize the JWK client after model creation."""
|
||||
jwks_url = (
|
||||
str(self.jwks_url)
|
||||
if self.jwks_url
|
||||
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
|
||||
)
|
||||
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
|
||||
return self
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OIDC JWT validation.
|
||||
|
||||
Args:
|
||||
token: The JWT to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if self._jwk_client is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OIDC not initialized",
|
||||
)
|
||||
|
||||
try:
|
||||
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=self.algorithms,
|
||||
audience=self.audience,
|
||||
issuer=str(self.issuer).rstrip("/"),
|
||||
leeway=self.clock_skew_seconds,
|
||||
options={
|
||||
"require": self.required_claims,
|
||||
},
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oidc",
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "token_expired", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
) from None
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_audience", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token audience",
|
||||
) from None
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_issuer", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token issuer",
|
||||
) from None
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Missing required claim: {e.claim}",
|
||||
) from None
|
||||
except jwt.PyJWKClientError as e:
|
||||
logger.error(
|
||||
"OIDC authentication failed",
|
||||
extra={
|
||||
"reason": "jwks_client_error",
|
||||
"error": str(e),
|
||||
"scheme": "oidc",
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Unable to fetch signing keys",
|
||||
) from None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
) from None
|
||||
|
||||
|
||||
class OAuth2ServerAuth(ServerAuthScheme):
|
||||
"""OAuth2 authentication for A2A server.
|
||||
|
||||
Declares OAuth2 security scheme in AgentCard and validates tokens using
|
||||
either JWKS for JWT tokens or token introspection for opaque tokens.
|
||||
|
||||
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
|
||||
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
|
||||
|
||||
Attributes:
|
||||
token_url: OAuth2 token endpoint URL for client_credentials flow.
|
||||
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
|
||||
refresh_url: Optional refresh token endpoint URL.
|
||||
scopes: Available OAuth2 scopes with descriptions.
|
||||
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
|
||||
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
|
||||
introspection_client_id: Client ID for introspection endpoint authentication.
|
||||
introspection_client_secret: Client secret for introspection endpoint.
|
||||
audience: Expected audience claim for JWT validation.
|
||||
issuer: Expected issuer claim for JWT validation.
|
||||
algorithms: Allowed JWT signing algorithms.
|
||||
required_claims: Claims that must be present in the token.
|
||||
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||
clock_skew_seconds: Allowed clock skew for token validation.
|
||||
"""
|
||||
|
||||
token_url: HttpUrl = Field(
|
||||
description="OAuth2 token endpoint URL",
|
||||
)
|
||||
authorization_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="OAuth2 authorization endpoint URL for authorization_code flow",
|
||||
)
|
||||
refresh_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="OAuth2 refresh token endpoint URL",
|
||||
)
|
||||
scopes: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Available OAuth2 scopes with descriptions",
|
||||
)
|
||||
jwks_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="JWKS URL for JWT validation. Required if not using introspection.",
|
||||
)
|
||||
introspection_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
|
||||
)
|
||||
introspection_client_id: str | None = Field(
|
||||
default=None,
|
||||
description="Client ID for introspection endpoint authentication",
|
||||
)
|
||||
introspection_client_secret: CoercedSecretStr | None = Field(
|
||||
default=None,
|
||||
description="Client secret for introspection endpoint authentication",
|
||||
)
|
||||
audience: str | None = Field(
|
||||
default=None,
|
||||
description="Expected audience claim for JWT validation",
|
||||
)
|
||||
issuer: str | None = Field(
|
||||
default=None,
|
||||
description="Expected issuer claim for JWT validation",
|
||||
)
|
||||
algorithms: list[str] = Field(
|
||||
default_factory=lambda: ["RS256"],
|
||||
description="Allowed JWT signing algorithms",
|
||||
)
|
||||
required_claims: list[str] = Field(
|
||||
default_factory=lambda: ["exp", "iat"],
|
||||
description="Claims that must be present in the token",
|
||||
)
|
||||
jwks_cache_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL for JWKS cache in seconds",
|
||||
ge=60,
|
||||
)
|
||||
clock_skew_seconds: float = Field(
|
||||
default=30.0,
|
||||
description="Allowed clock skew for token validation",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_and_init(self) -> Self:
|
||||
"""Validate configuration and initialize JWKS client if needed."""
|
||||
if not self.jwks_url and not self.introspection_url:
|
||||
raise ValueError(
|
||||
"Either jwks_url or introspection_url must be provided for token validation"
|
||||
)
|
||||
|
||||
if self.introspection_url:
|
||||
if not self.introspection_client_id or not self.introspection_client_secret:
|
||||
raise ValueError(
|
||||
"introspection_client_id and introspection_client_secret are required "
|
||||
"when using token introspection"
|
||||
)
|
||||
|
||||
if self.jwks_url:
|
||||
self._jwk_client = PyJWKClient(
|
||||
str(self.jwks_url), lifespan=self.jwks_cache_ttl
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OAuth2 token validation.
|
||||
|
||||
Uses JWKS validation if jwks_url is configured, otherwise falls back
|
||||
to token introspection.
|
||||
|
||||
Args:
|
||||
token: The OAuth2 access token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if self._jwk_client:
|
||||
return await self._authenticate_jwt(token)
|
||||
return await self._authenticate_introspection(token)
|
||||
|
||||
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using JWKS JWT validation."""
|
||||
if self._jwk_client is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth2 JWKS not initialized",
|
||||
)
|
||||
|
||||
try:
|
||||
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||
|
||||
decode_options: dict[str, Any] = {
|
||||
"require": self.required_claims,
|
||||
}
|
||||
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=self.algorithms,
|
||||
audience=self.audience,
|
||||
issuer=self.issuer,
|
||||
leeway=self.clock_skew_seconds,
|
||||
options=decode_options, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oauth2",
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "token_expired", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
) from None
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_audience", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token audience",
|
||||
) from None
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token issuer",
|
||||
) from None
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Missing required claim: {e.claim}",
|
||||
) from None
|
||||
except jwt.PyJWKClientError as e:
|
||||
logger.error(
|
||||
"OAuth2 authentication failed",
|
||||
extra={
|
||||
"reason": "jwks_client_error",
|
||||
"error": str(e),
|
||||
"scheme": "oauth2",
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Unable to fetch signing keys",
|
||||
) from None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
) from None
|
||||
|
||||
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
|
||||
import httpx
|
||||
|
||||
if not self.introspection_url:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth2 introspection not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
str(self.introspection_url),
|
||||
data={"token": token},
|
||||
auth=(
|
||||
self.introspection_client_id or "",
|
||||
self.introspection_client_secret.get_secret_value()
|
||||
if self.introspection_client_secret
|
||||
else "",
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
introspection_result = response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
"OAuth2 introspection failed",
|
||||
extra={"reason": "http_error", "status_code": e.response.status_code},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token introspection service unavailable",
|
||||
) from None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"OAuth2 introspection failed",
|
||||
extra={"reason": "unexpected_error", "error": str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token introspection failed",
|
||||
) from None
|
||||
|
||||
if not introspection_result.get("active", False):
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "token_not_active", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token is not active",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oauth2",
|
||||
claims=introspection_result,
|
||||
)
|
||||
|
||||
def to_security_scheme(self) -> OAuth2SecurityScheme:
|
||||
"""Generate OAuth2SecurityScheme for AgentCard declaration.
|
||||
|
||||
Creates an OAuth2SecurityScheme with appropriate flows based on
|
||||
the configured URLs. Includes client_credentials flow if token_url
|
||||
is set, and authorization_code flow if authorization_url is set.
|
||||
|
||||
Returns:
|
||||
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
|
||||
"""
|
||||
from a2a.types import (
|
||||
AuthorizationCodeOAuthFlow,
|
||||
ClientCredentialsOAuthFlow,
|
||||
OAuth2SecurityScheme,
|
||||
OAuthFlows,
|
||||
)
|
||||
|
||||
client_credentials = None
|
||||
authorization_code = None
|
||||
|
||||
if self.token_url:
|
||||
client_credentials = ClientCredentialsOAuthFlow(
|
||||
token_url=str(self.token_url),
|
||||
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
|
||||
if self.authorization_url:
|
||||
authorization_code = AuthorizationCodeOAuthFlow(
|
||||
authorization_url=str(self.authorization_url),
|
||||
token_url=str(self.token_url),
|
||||
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
|
||||
return OAuth2SecurityScheme(
|
||||
flows=OAuthFlows(
|
||||
client_credentials=client_credentials,
|
||||
authorization_code=authorization_code,
|
||||
),
|
||||
description="OAuth2 authentication",
|
||||
)
|
||||
|
||||
|
||||
class APIKeyServerAuth(ServerAuthScheme):
|
||||
"""API Key authentication for A2A server.
|
||||
|
||||
Validates requests using an API key in a header, query parameter, or cookie.
|
||||
|
||||
Attributes:
|
||||
name: The name of the API key parameter (default: X-API-Key).
|
||||
location: Where to look for the API key (header, query, or cookie).
|
||||
api_key: The expected API key value.
|
||||
"""
|
||||
|
||||
name: str = Field(
|
||||
default="X-API-Key",
|
||||
description="Name of the API key parameter",
|
||||
)
|
||||
location: Literal["header", "query", "cookie"] = Field(
|
||||
default="header",
|
||||
description="Where to look for the API key",
|
||||
)
|
||||
api_key: CoercedSecretStr = Field(
|
||||
description="Expected API key value",
|
||||
)
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using API key comparison.
|
||||
|
||||
Args:
|
||||
token: The API key to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if token != self.api_key.get_secret_value():
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="api_key",
|
||||
)
|
||||
|
||||
|
||||
class MTLSServerAuth(ServerAuthScheme):
|
||||
"""Mutual TLS authentication marker for AgentCard declaration.
|
||||
|
||||
This scheme is primarily for AgentCard security_schemes declaration.
|
||||
Actual mTLS verification happens at the TLS/transport layer, not
|
||||
at the application layer via token validation.
|
||||
|
||||
When configured, this signals to clients that the server requires
|
||||
client certificates for authentication.
|
||||
"""
|
||||
|
||||
description: str = Field(
|
||||
default="Mutual TLS certificate authentication",
|
||||
description="Description for the security scheme",
|
||||
)
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Return authenticated user for mTLS.
|
||||
|
||||
mTLS verification happens at the transport layer before this is called.
|
||||
If we reach this point, the TLS handshake with client cert succeeded.
|
||||
|
||||
Args:
|
||||
token: Certificate subject or identifier (from TLS layer).
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser indicating mTLS authentication.
|
||||
"""
|
||||
return AuthenticatedUser(
|
||||
token=token or "mtls-verified",
|
||||
scheme="mtls",
|
||||
)
|
||||
273
lib/crewai-a2a/src/crewai_a2a/auth/utils.py
Normal file
273
lib/crewai-a2a/src/crewai_a2a/auth/utils.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Authentication utilities for A2A protocol agent communication.
|
||||
|
||||
Provides validation and retry logic for various authentication schemes including
|
||||
OAuth2, API keys, and HTTP authentication methods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable, MutableMapping
|
||||
import hashlib
|
||||
import re
|
||||
import threading
|
||||
from typing import Final, Literal, cast
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
APIKeySecurityScheme,
|
||||
AgentCard,
|
||||
HTTPAuthSecurityScheme,
|
||||
OAuth2SecurityScheme,
|
||||
)
|
||||
from httpx import AsyncClient, Response
|
||||
|
||||
from crewai_a2a.auth.client_schemes import (
|
||||
APIKeyAuth,
|
||||
BearerTokenAuth,
|
||||
ClientAuthScheme,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
)
|
||||
|
||||
|
||||
class _AuthStore:
|
||||
"""Store for authentication schemes with safe concurrent access."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, ClientAuthScheme | None] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@staticmethod
|
||||
def compute_key(auth_type: str, auth_data: str) -> str:
|
||||
"""Compute a collision-resistant key using SHA-256."""
|
||||
content = f"{auth_type}:{auth_data}"
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
|
||||
"""Store an auth scheme."""
|
||||
with self._lock:
|
||||
self._store[key] = auth
|
||||
|
||||
def get(self, key: str) -> ClientAuthScheme | None:
|
||||
"""Retrieve an auth scheme by key."""
|
||||
with self._lock:
|
||||
return self._store.get(key)
|
||||
|
||||
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
|
||||
with self._lock:
|
||||
self._store[key] = value
|
||||
|
||||
def __getitem__(self, key: str) -> ClientAuthScheme | None:
|
||||
with self._lock:
|
||||
return self._store[key]
|
||||
|
||||
|
||||
_auth_store = _AuthStore()
|
||||
|
||||
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
|
||||
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
|
||||
|
||||
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
|
||||
OAuth2SecurityScheme: (
|
||||
OAuth2ClientCredentials,
|
||||
OAuth2AuthorizationCode,
|
||||
BearerTokenAuth,
|
||||
),
|
||||
APIKeySecurityScheme: (APIKeyAuth,),
|
||||
}
|
||||
|
||||
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
|
||||
|
||||
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
|
||||
"basic": HTTPBasicAuth,
|
||||
"digest": HTTPDigestAuth,
|
||||
"bearer": BearerTokenAuth,
|
||||
}
|
||||
|
||||
|
||||
def _raise_auth_mismatch(
|
||||
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
|
||||
provided_auth: ClientAuthScheme,
|
||||
) -> None:
|
||||
"""Raise authentication mismatch error.
|
||||
|
||||
Args:
|
||||
expected_classes: Expected authentication class or tuple of classes.
|
||||
provided_auth: Actually provided authentication instance.
|
||||
|
||||
Raises:
|
||||
A2AClientHTTPError: Always raises with 401 status code.
|
||||
"""
|
||||
if isinstance(expected_classes, tuple):
|
||||
if len(expected_classes) == 1:
|
||||
required = expected_classes[0].__name__
|
||||
else:
|
||||
names = [cls.__name__ for cls in expected_classes]
|
||||
required = f"one of ({', '.join(names)})"
|
||||
else:
|
||||
required = expected_classes.__name__
|
||||
|
||||
msg = (
|
||||
f"AgentCard requires {required} authentication, "
|
||||
f"but {type(provided_auth).__name__} was provided"
|
||||
)
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
|
||||
|
||||
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
|
||||
"""Parse WWW-Authenticate header into auth challenges.
|
||||
|
||||
Args:
|
||||
header_value: The WWW-Authenticate header value.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping auth scheme to its parameters.
|
||||
Example: {"Bearer": {"realm": "api", "scope": "read write"}}
|
||||
"""
|
||||
if not header_value:
|
||||
return {}
|
||||
|
||||
challenges: dict[str, dict[str, str]] = {}
|
||||
|
||||
for match in _SCHEME_PATTERN.finditer(header_value):
|
||||
scheme = match.group(1)
|
||||
params_str = match.group(2)
|
||||
|
||||
params: dict[str, str] = {}
|
||||
|
||||
for param_match in _PARAM_PATTERN.finditer(params_str):
|
||||
key = param_match.group(1)
|
||||
value = param_match.group(2) or param_match.group(3)
|
||||
params[key] = value
|
||||
|
||||
challenges[scheme] = params
|
||||
|
||||
return challenges
|
||||
|
||||
|
||||
def validate_auth_against_agent_card(
|
||||
agent_card: AgentCard, auth: ClientAuthScheme | None
|
||||
) -> None:
|
||||
"""Validate that provided auth matches AgentCard security requirements.
|
||||
|
||||
Args:
|
||||
agent_card: The A2A AgentCard containing security requirements.
|
||||
auth: User-provided authentication scheme (or None).
|
||||
|
||||
Raises:
|
||||
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
|
||||
"""
|
||||
|
||||
if not agent_card.security or not agent_card.security_schemes:
|
||||
return
|
||||
|
||||
if not auth:
|
||||
msg = "AgentCard requires authentication but no auth scheme provided"
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
|
||||
first_security_req = agent_card.security[0] if agent_card.security else {}
|
||||
|
||||
for scheme_name in first_security_req.keys():
|
||||
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
|
||||
if not security_scheme_wrapper:
|
||||
continue
|
||||
|
||||
scheme = security_scheme_wrapper.root
|
||||
|
||||
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
|
||||
if not isinstance(auth, allowed_classes):
|
||||
_raise_auth_mismatch(allowed_classes, auth)
|
||||
return
|
||||
|
||||
if isinstance(scheme, HTTPAuthSecurityScheme):
|
||||
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
|
||||
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
|
||||
if not isinstance(auth, required_class):
|
||||
_raise_auth_mismatch(required_class, auth)
|
||||
return
|
||||
|
||||
msg = "Could not validate auth against AgentCard security requirements"
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
|
||||
|
||||
async def retry_on_401(
|
||||
request_func: Callable[[], Awaitable[Response]],
|
||||
auth_scheme: ClientAuthScheme | None,
|
||||
client: AsyncClient,
|
||||
headers: MutableMapping[str, str],
|
||||
max_retries: int = 3,
|
||||
) -> Response:
|
||||
"""Retry a request on 401 authentication error.
|
||||
|
||||
Handles 401 errors by:
|
||||
1. Parsing WWW-Authenticate header
|
||||
2. Re-acquiring credentials
|
||||
3. Retrying the request
|
||||
|
||||
Args:
|
||||
request_func: Async function that makes the HTTP request.
|
||||
auth_scheme: Authentication scheme to refresh credentials with.
|
||||
client: HTTP client for making requests.
|
||||
headers: Request headers to update with new auth.
|
||||
max_retries: Maximum number of retry attempts (default: 3).
|
||||
|
||||
Returns:
|
||||
HTTP response from the request.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If retries are exhausted or auth scheme is None.
|
||||
"""
|
||||
last_response: Response | None = None
|
||||
last_challenges: dict[str, dict[str, str]] = {}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
response = await request_func()
|
||||
|
||||
if response.status_code != 401:
|
||||
return response
|
||||
|
||||
last_response = response
|
||||
|
||||
if auth_scheme is None:
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
www_authenticate = response.headers.get("WWW-Authenticate", "")
|
||||
challenges = parse_www_authenticate(www_authenticate)
|
||||
last_challenges = challenges
|
||||
|
||||
if attempt >= max_retries - 1:
|
||||
break
|
||||
|
||||
backoff_time = 2**attempt
|
||||
await asyncio.sleep(backoff_time)
|
||||
|
||||
await auth_scheme.apply_auth(client, headers)
|
||||
|
||||
if last_response:
|
||||
last_response.raise_for_status()
|
||||
return last_response
|
||||
|
||||
msg = "retry_on_401 failed without making any requests"
|
||||
if last_challenges:
|
||||
challenge_info = ", ".join(
|
||||
f"{scheme} (realm={params.get('realm', 'N/A')})"
|
||||
for scheme, params in last_challenges.items()
|
||||
)
|
||||
msg = f"{msg}. Server challenges: {challenge_info}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
def configure_auth_client(
|
||||
auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient
|
||||
) -> None:
|
||||
"""Configure HTTP client with auth-specific settings.
|
||||
|
||||
Only HTTPDigestAuth and APIKeyAuth need client configuration.
|
||||
|
||||
Args:
|
||||
auth: Authentication scheme that requires client configuration.
|
||||
client: HTTP client to configure.
|
||||
"""
|
||||
auth.configure_client(client)
|
||||
690
lib/crewai-a2a/src/crewai_a2a/config.py
Normal file
690
lib/crewai-a2a/src/crewai_a2a/config.py
Normal file
@@ -0,0 +1,690 @@
|
||||
"""A2A configuration types.
|
||||
|
||||
This module is separate from experimental.a2a to avoid circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
import warnings
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
FilePath,
|
||||
PrivateAttr,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
from crewai_a2a.auth.client_schemes import ClientAuthScheme
|
||||
from crewai_a2a.auth.server_schemes import ServerAuthScheme
|
||||
from crewai_a2a.extensions.base import ValidatedA2AExtension
|
||||
from crewai_a2a.types import ProtocolVersion, TransportType, Url
|
||||
|
||||
|
||||
try:
|
||||
from a2a.types import (
|
||||
AgentCapabilities,
|
||||
AgentCardSignature,
|
||||
AgentInterface,
|
||||
AgentProvider,
|
||||
AgentSkill,
|
||||
SecurityScheme,
|
||||
)
|
||||
|
||||
from crewai_a2a.extensions.server import ServerExtension
|
||||
from crewai_a2a.updates import UpdateConfig
|
||||
except ImportError:
|
||||
UpdateConfig: Any = Any # type: ignore[no-redef]
|
||||
AgentCapabilities: Any = Any # type: ignore[no-redef]
|
||||
AgentCardSignature: Any = Any # type: ignore[no-redef]
|
||||
AgentInterface: Any = Any # type: ignore[no-redef]
|
||||
AgentProvider: Any = Any # type: ignore[no-redef]
|
||||
SecurityScheme: Any = Any # type: ignore[no-redef]
|
||||
AgentSkill: Any = Any # type: ignore[no-redef]
|
||||
ServerExtension: Any = Any # type: ignore[no-redef]
|
||||
|
||||
|
||||
def _get_default_update_config() -> UpdateConfig:
|
||||
from crewai_a2a.updates import StreamingConfig
|
||||
|
||||
return StreamingConfig()
|
||||
|
||||
|
||||
SigningAlgorithm = Literal[
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
|
||||
class AgentCardSigningConfig(BaseModel):
|
||||
"""Configuration for AgentCard JWS signing.
|
||||
|
||||
Provides the private key and algorithm settings for signing AgentCards.
|
||||
Either private_key_path or private_key_pem must be provided, but not both.
|
||||
|
||||
Attributes:
|
||||
private_key_path: Path to a PEM-encoded private key file.
|
||||
private_key_pem: PEM-encoded private key as a secret string.
|
||||
key_id: Optional key identifier for the JWS header (kid claim).
|
||||
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
private_key_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to PEM-encoded private key file",
|
||||
)
|
||||
private_key_pem: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="PEM-encoded private key",
|
||||
)
|
||||
key_id: str | None = Field(
|
||||
default=None,
|
||||
description="Key identifier for JWS header (kid claim)",
|
||||
)
|
||||
algorithm: SigningAlgorithm = Field(
|
||||
default="RS256",
|
||||
description="Signing algorithm (RS256, ES256, PS256, etc.)",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_key_source(self) -> Self:
|
||||
"""Ensure exactly one key source is provided."""
|
||||
has_path = self.private_key_path is not None
|
||||
has_pem = self.private_key_pem is not None
|
||||
|
||||
if not has_path and not has_pem:
|
||||
raise ValueError(
|
||||
"Either private_key_path or private_key_pem must be provided"
|
||||
)
|
||||
if has_path and has_pem:
|
||||
raise ValueError(
|
||||
"Only one of private_key_path or private_key_pem should be provided"
|
||||
)
|
||||
return self
|
||||
|
||||
def get_private_key(self) -> str:
|
||||
"""Get the private key content.
|
||||
|
||||
Returns:
|
||||
The PEM-encoded private key as a string.
|
||||
"""
|
||||
if self.private_key_pem:
|
||||
return self.private_key_pem.get_secret_value()
|
||||
if self.private_key_path:
|
||||
return Path(self.private_key_path).read_text()
|
||||
raise ValueError("No private key configured")
|
||||
|
||||
|
||||
class GRPCServerConfig(BaseModel):
|
||||
"""gRPC server transport configuration.
|
||||
|
||||
Presence of this config in ServerTransportConfig.grpc enables gRPC transport.
|
||||
|
||||
Attributes:
|
||||
host: Hostname to advertise in agent cards (default: localhost).
|
||||
Use docker service name (e.g., 'web') for docker-compose setups.
|
||||
port: Port for the gRPC server.
|
||||
tls_cert_path: Path to TLS certificate file for gRPC.
|
||||
tls_key_path: Path to TLS private key file for gRPC.
|
||||
max_workers: Maximum number of workers for the gRPC thread pool.
|
||||
reflection_enabled: Whether to enable gRPC reflection for debugging.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
host: str = Field(
|
||||
default="localhost",
|
||||
description="Hostname to advertise in agent cards for gRPC connections",
|
||||
)
|
||||
port: int = Field(
|
||||
default=50051,
|
||||
description="Port for the gRPC server",
|
||||
)
|
||||
tls_cert_path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS certificate file for gRPC",
|
||||
)
|
||||
tls_key_path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS private key file for gRPC",
|
||||
)
|
||||
max_workers: int = Field(
|
||||
default=10,
|
||||
description="Maximum number of workers for the gRPC thread pool",
|
||||
)
|
||||
reflection_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable gRPC reflection for debugging",
|
||||
)
|
||||
|
||||
|
||||
class GRPCClientConfig(BaseModel):
|
||||
"""gRPC client transport configuration.
|
||||
|
||||
Attributes:
|
||||
max_send_message_length: Maximum size for outgoing messages in bytes.
|
||||
max_receive_message_length: Maximum size for incoming messages in bytes.
|
||||
keepalive_time_ms: Time between keepalive pings in milliseconds.
|
||||
keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
max_send_message_length: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum size for outgoing messages in bytes",
|
||||
)
|
||||
max_receive_message_length: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum size for incoming messages in bytes",
|
||||
)
|
||||
keepalive_time_ms: int | None = Field(
|
||||
default=None,
|
||||
description="Time between keepalive pings in milliseconds",
|
||||
)
|
||||
keepalive_timeout_ms: int | None = Field(
|
||||
default=None,
|
||||
description="Timeout for keepalive ping response in milliseconds",
|
||||
)
|
||||
|
||||
|
||||
class JSONRPCServerConfig(BaseModel):
|
||||
"""JSON-RPC server transport configuration.
|
||||
|
||||
Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport.
|
||||
|
||||
Attributes:
|
||||
rpc_path: URL path for the JSON-RPC endpoint.
|
||||
agent_card_path: URL path for the agent card endpoint.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
rpc_path: str = Field(
|
||||
default="/a2a",
|
||||
description="URL path for the JSON-RPC endpoint",
|
||||
)
|
||||
agent_card_path: str = Field(
|
||||
default="/.well-known/agent-card.json",
|
||||
description="URL path for the agent card endpoint",
|
||||
)
|
||||
|
||||
|
||||
class JSONRPCClientConfig(BaseModel):
|
||||
"""JSON-RPC client transport configuration.
|
||||
|
||||
Attributes:
|
||||
max_request_size: Maximum request body size in bytes.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
max_request_size: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum request body size in bytes",
|
||||
)
|
||||
|
||||
|
||||
class HTTPJSONConfig(BaseModel):
|
||||
"""HTTP+JSON transport configuration.
|
||||
|
||||
Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ServerPushNotificationConfig(BaseModel):
|
||||
"""Configuration for outgoing webhook push notifications.
|
||||
|
||||
Controls how the server signs and delivers push notifications to clients.
|
||||
|
||||
Attributes:
|
||||
signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
signature_secret: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Shared secret for HMAC-SHA256 signing of outgoing push notifications",
|
||||
)
|
||||
|
||||
|
||||
class ServerTransportConfig(BaseModel):
|
||||
"""Transport configuration for A2A server.
|
||||
|
||||
Groups all transport-related settings including preferred transport
|
||||
and protocol-specific configurations.
|
||||
|
||||
Attributes:
|
||||
preferred: Transport protocol for the preferred endpoint.
|
||||
jsonrpc: JSON-RPC server transport configuration.
|
||||
grpc: gRPC server transport configuration.
|
||||
http_json: HTTP+JSON transport configuration.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
preferred: TransportType = Field(
|
||||
default="JSONRPC",
|
||||
description="Transport protocol for the preferred endpoint",
|
||||
)
|
||||
jsonrpc: JSONRPCServerConfig = Field(
|
||||
default_factory=JSONRPCServerConfig,
|
||||
description="JSON-RPC server transport configuration",
|
||||
)
|
||||
grpc: GRPCServerConfig | None = Field(
|
||||
default=None,
|
||||
description="gRPC server transport configuration",
|
||||
)
|
||||
http_json: HTTPJSONConfig | None = Field(
|
||||
default=None,
|
||||
description="HTTP+JSON transport configuration",
|
||||
)
|
||||
|
||||
|
||||
def _migrate_client_transport_fields(
|
||||
transport: ClientTransportConfig,
|
||||
transport_protocol: TransportType | None,
|
||||
supported_transports: list[TransportType] | None,
|
||||
) -> None:
|
||||
"""Migrate deprecated transport fields to new config."""
|
||||
if transport_protocol is not None:
|
||||
warnings.warn(
|
||||
"transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead",
|
||||
FutureWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
object.__setattr__(transport, "preferred", transport_protocol)
|
||||
if supported_transports is not None:
|
||||
warnings.warn(
|
||||
"supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead",
|
||||
FutureWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
object.__setattr__(transport, "supported", supported_transports)
|
||||
|
||||
|
||||
class ClientTransportConfig(BaseModel):
|
||||
"""Transport configuration for A2A client.
|
||||
|
||||
Groups all client transport-related settings including preferred transport,
|
||||
supported transports for negotiation, and protocol-specific configurations.
|
||||
|
||||
Transport negotiation logic:
|
||||
1. If `preferred` is set and server supports it → use client's preferred
|
||||
2. Otherwise, if server's preferred is in client's `supported` → use server's preferred
|
||||
3. Otherwise, find first match from client's `supported` in server's interfaces
|
||||
|
||||
Attributes:
|
||||
preferred: Client's preferred transport. If set, client preference takes priority.
|
||||
supported: Transports the client can use, in order of preference.
|
||||
jsonrpc: JSON-RPC client transport configuration.
|
||||
grpc: gRPC client transport configuration.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
preferred: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Client's preferred transport. If set, takes priority over server preference.",
|
||||
)
|
||||
supported: list[TransportType] = Field(
|
||||
default_factory=lambda: cast(list[TransportType], ["JSONRPC"]),
|
||||
description="Transports the client can use, in order of preference",
|
||||
)
|
||||
jsonrpc: JSONRPCClientConfig = Field(
|
||||
default_factory=JSONRPCClientConfig,
|
||||
description="JSON-RPC client transport configuration",
|
||||
)
|
||||
grpc: GRPCClientConfig = Field(
|
||||
default_factory=GRPCClientConfig,
|
||||
description="gRPC client transport configuration",
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"""
|
||||
`crewai_a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
|
||||
use `crewai_a2a.config.A2AClientConfig` or `crewai_a2a.config.A2AServerConfig` instead.
|
||||
""",
|
||||
category=FutureWarning,
|
||||
)
|
||||
class A2AConfig(BaseModel):
|
||||
"""Configuration for A2A protocol integration.
|
||||
|
||||
Deprecated:
|
||||
Use A2AClientConfig instead. This class will be removed in a future version.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Authentication scheme.
|
||||
timeout: Request timeout in seconds.
|
||||
max_turns: Maximum conversation turns with A2A agent.
|
||||
response_model: Optional Pydantic model for structured A2A agent responses.
|
||||
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
|
||||
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
|
||||
updates: Update mechanism config.
|
||||
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
endpoint: Url = Field(description="A2A agent endpoint URL")
|
||||
auth: ClientAuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme",
|
||||
)
|
||||
timeout: int = Field(default=120, description="Request timeout in seconds")
|
||||
max_turns: int = Field(
|
||||
default=10, description="Maximum conversation turns with A2A agent"
|
||||
)
|
||||
response_model: type[BaseModel] | None = Field(
|
||||
default=None,
|
||||
description="Optional Pydantic model for structured A2A agent responses",
|
||||
)
|
||||
fail_fast: bool = Field(
|
||||
default=True,
|
||||
description="If True, raise error when agent unreachable; if False, skip",
|
||||
)
|
||||
trust_remote_completion_status: bool = Field(
|
||||
default=False,
|
||||
description="If True, return A2A result directly when completed",
|
||||
)
|
||||
updates: UpdateConfig = Field(
|
||||
default_factory=_get_default_update_config,
|
||||
description="Update mechanism config",
|
||||
)
|
||||
client_extensions: list[ValidatedA2AExtension] = Field(
|
||||
default_factory=list,
|
||||
description="Client-side processing hooks for tool injection and prompt augmentation",
|
||||
)
|
||||
transport: ClientTransportConfig = Field(
|
||||
default_factory=ClientTransportConfig,
|
||||
description="Transport configuration (preferred, supported transports, gRPC settings)",
|
||||
)
|
||||
transport_protocol: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.preferred instead",
|
||||
exclude=True,
|
||||
)
|
||||
supported_transports: list[TransportType] | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.supported instead",
|
||||
exclude=True,
|
||||
)
|
||||
use_client_preference: bool | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Set transport.preferred to enable client preference",
|
||||
exclude=True,
|
||||
)
|
||||
_parallel_delegation: bool = PrivateAttr(default=False)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _migrate_deprecated_transport_fields(self) -> Self:
|
||||
"""Migrate deprecated transport fields to new config."""
|
||||
_migrate_client_transport_fields(
|
||||
self.transport, self.transport_protocol, self.supported_transports
|
||||
)
|
||||
if self.use_client_preference is not None:
|
||||
warnings.warn(
|
||||
"use_client_preference is deprecated, set transport.preferred to enable client preference",
|
||||
FutureWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
if self.use_client_preference and self.transport.supported:
|
||||
object.__setattr__(
|
||||
self.transport, "preferred", self.transport.supported[0]
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class A2AClientConfig(BaseModel):
|
||||
"""Configuration for connecting to remote A2A agents.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Authentication scheme.
|
||||
timeout: Request timeout in seconds.
|
||||
max_turns: Maximum conversation turns with A2A agent.
|
||||
response_model: Optional Pydantic model for structured A2A agent responses.
|
||||
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
|
||||
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
|
||||
updates: Update mechanism config.
|
||||
accepted_output_modes: Media types the client can accept in responses.
|
||||
extensions: Extension URIs the client supports (A2A protocol extensions).
|
||||
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
endpoint: Url = Field(description="A2A agent endpoint URL")
|
||||
auth: ClientAuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme",
|
||||
)
|
||||
timeout: int = Field(default=120, description="Request timeout in seconds")
|
||||
max_turns: int = Field(
|
||||
default=10, description="Maximum conversation turns with A2A agent"
|
||||
)
|
||||
response_model: type[BaseModel] | None = Field(
|
||||
default=None,
|
||||
description="Optional Pydantic model for structured A2A agent responses",
|
||||
)
|
||||
fail_fast: bool = Field(
|
||||
default=True,
|
||||
description="If True, raise error when agent unreachable; if False, skip",
|
||||
)
|
||||
trust_remote_completion_status: bool = Field(
|
||||
default=False,
|
||||
description="If True, return A2A result directly when completed",
|
||||
)
|
||||
updates: UpdateConfig = Field(
|
||||
default_factory=_get_default_update_config,
|
||||
description="Update mechanism config",
|
||||
)
|
||||
accepted_output_modes: list[str] = Field(
|
||||
default_factory=lambda: ["application/json"],
|
||||
description="Media types the client can accept in responses",
|
||||
)
|
||||
extensions: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Extension URIs the client supports",
|
||||
)
|
||||
client_extensions: list[ValidatedA2AExtension] = Field(
|
||||
default_factory=list,
|
||||
description="Client-side processing hooks for tool injection and prompt augmentation",
|
||||
)
|
||||
transport: ClientTransportConfig = Field(
|
||||
default_factory=ClientTransportConfig,
|
||||
description="Transport configuration (preferred, supported transports, gRPC settings)",
|
||||
)
|
||||
transport_protocol: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.preferred instead",
|
||||
exclude=True,
|
||||
)
|
||||
supported_transports: list[TransportType] | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.supported instead",
|
||||
exclude=True,
|
||||
)
|
||||
_parallel_delegation: bool = PrivateAttr(default=False)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _migrate_deprecated_transport_fields(self) -> Self:
|
||||
"""Migrate deprecated transport fields to new config."""
|
||||
_migrate_client_transport_fields(
|
||||
self.transport, self.transport_protocol, self.supported_transports
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class A2AServerConfig(BaseModel):
|
||||
"""Configuration for exposing a Crew or Agent as an A2A server.
|
||||
|
||||
All fields correspond to A2A AgentCard fields. Fields like name, description,
|
||||
and skills can be auto-derived from the Crew/Agent if not provided.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable name for the agent.
|
||||
description: Human-readable description of the agent.
|
||||
version: Version string for the agent card.
|
||||
skills: List of agent skills/capabilities.
|
||||
default_input_modes: Default supported input MIME types.
|
||||
default_output_modes: Default supported output MIME types.
|
||||
capabilities: Declaration of optional capabilities.
|
||||
protocol_version: A2A protocol version this agent supports.
|
||||
provider: Information about the agent's service provider.
|
||||
documentation_url: URL to the agent's documentation.
|
||||
icon_url: URL to an icon for the agent.
|
||||
additional_interfaces: Additional supported interfaces.
|
||||
security: Security requirement objects for all interactions.
|
||||
security_schemes: Security schemes available to authorize requests.
|
||||
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
|
||||
url: Preferred endpoint URL for the agent.
|
||||
signing_config: Configuration for signing the AgentCard with JWS.
|
||||
signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
|
||||
server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
|
||||
push_notifications: Configuration for outgoing push notifications.
|
||||
transport: Transport configuration (preferred transport, gRPC, REST settings).
|
||||
auth: Authentication scheme for A2A endpoints.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.",
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.",
|
||||
)
|
||||
version: str = Field(
|
||||
default="1.0.0",
|
||||
description="Version string for the agent card",
|
||||
)
|
||||
skills: list[AgentSkill] = Field(
|
||||
default_factory=list,
|
||||
description="List of agent skills. Auto-derived from tasks/tools if not provided.",
|
||||
)
|
||||
default_input_modes: list[str] = Field(
|
||||
default_factory=lambda: ["text/plain", "application/json"],
|
||||
description="Default supported input MIME types",
|
||||
)
|
||||
default_output_modes: list[str] = Field(
|
||||
default_factory=lambda: ["text/plain", "application/json"],
|
||||
description="Default supported output MIME types",
|
||||
)
|
||||
capabilities: AgentCapabilities = Field(
|
||||
default_factory=lambda: AgentCapabilities(
|
||||
streaming=True,
|
||||
push_notifications=False,
|
||||
),
|
||||
description="Declaration of optional capabilities supported by the agent",
|
||||
)
|
||||
protocol_version: ProtocolVersion = Field(
|
||||
default="0.3.0",
|
||||
description="A2A protocol version this agent supports",
|
||||
)
|
||||
provider: AgentProvider | None = Field(
|
||||
default=None,
|
||||
description="Information about the agent's service provider",
|
||||
)
|
||||
documentation_url: Url | None = Field(
|
||||
default=None,
|
||||
description="URL to the agent's documentation",
|
||||
)
|
||||
icon_url: Url | None = Field(
|
||||
default=None,
|
||||
description="URL to an icon for the agent",
|
||||
)
|
||||
additional_interfaces: list[AgentInterface] = Field(
|
||||
default_factory=list,
|
||||
description="Additional supported interfaces.",
|
||||
)
|
||||
security: list[dict[str, list[str]]] = Field(
|
||||
default_factory=list,
|
||||
description="Security requirement objects for all agent interactions",
|
||||
)
|
||||
security_schemes: dict[str, SecurityScheme] = Field(
|
||||
default_factory=dict,
|
||||
description="Security schemes available to authorize requests",
|
||||
)
|
||||
supports_authenticated_extended_card: bool = Field(
|
||||
default=False,
|
||||
description="Whether agent provides extended card to authenticated users",
|
||||
)
|
||||
url: Url | None = Field(
|
||||
default=None,
|
||||
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
|
||||
)
|
||||
signing_config: AgentCardSigningConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for signing the AgentCard with JWS",
|
||||
)
|
||||
signatures: list[AgentCardSignature] | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
|
||||
exclude=True,
|
||||
deprecated=True,
|
||||
)
|
||||
server_extensions: list[ServerExtension] = Field(
|
||||
default_factory=list,
|
||||
description="Server-side A2A protocol extensions that modify agent behavior",
|
||||
)
|
||||
push_notifications: ServerPushNotificationConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for outgoing push notifications",
|
||||
)
|
||||
transport: ServerTransportConfig = Field(
|
||||
default_factory=ServerTransportConfig,
|
||||
description="Transport configuration (preferred transport, gRPC, REST settings)",
|
||||
)
|
||||
preferred_transport: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.preferred instead",
|
||||
exclude=True,
|
||||
deprecated=True,
|
||||
)
|
||||
auth: ServerAuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _migrate_deprecated_fields(self) -> Self:
|
||||
"""Migrate deprecated fields to new config."""
|
||||
if self.preferred_transport is not None:
|
||||
warnings.warn(
|
||||
"preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead",
|
||||
FutureWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
object.__setattr__(self.transport, "preferred", self.preferred_transport)
|
||||
if self.signatures is not None:
|
||||
warnings.warn(
|
||||
"signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
|
||||
"The signatures field will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
return self
|
||||
491
lib/crewai-a2a/src/crewai_a2a/errors.py
Normal file
491
lib/crewai-a2a/src/crewai_a2a/errors.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""A2A error codes and error response utilities.
|
||||
|
||||
This module provides a centralized mapping of all A2A protocol error codes
|
||||
as defined in the A2A specification, plus custom CrewAI extensions.
|
||||
|
||||
Error codes follow JSON-RPC 2.0 conventions:
|
||||
- -32700 to -32600: Standard JSON-RPC errors
|
||||
- -32099 to -32000: Server errors (A2A-specific)
|
||||
- -32768 to -32100: Reserved for implementation-defined errors
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
from a2a.client.errors import A2AClientTimeoutError
|
||||
|
||||
|
||||
class A2APollingTimeoutError(A2AClientTimeoutError):
|
||||
"""Raised when polling exceeds the configured timeout."""
|
||||
|
||||
|
||||
class A2AErrorCode(IntEnum):
|
||||
"""A2A protocol error codes.
|
||||
|
||||
Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
|
||||
"""
|
||||
|
||||
# JSON-RPC 2.0 Standard Errors (-32700 to -32600)
|
||||
JSON_PARSE_ERROR = -32700
|
||||
"""Invalid JSON was received by the server."""
|
||||
|
||||
INVALID_REQUEST = -32600
|
||||
"""The JSON sent is not a valid Request object."""
|
||||
|
||||
METHOD_NOT_FOUND = -32601
|
||||
"""The method does not exist / is not available."""
|
||||
|
||||
INVALID_PARAMS = -32602
|
||||
"""Invalid method parameter(s)."""
|
||||
|
||||
INTERNAL_ERROR = -32603
|
||||
"""Internal JSON-RPC error."""
|
||||
|
||||
# A2A-Specific Errors (-32099 to -32000)
|
||||
TASK_NOT_FOUND = -32001
|
||||
"""The specified task was not found."""
|
||||
|
||||
TASK_NOT_CANCELABLE = -32002
|
||||
"""The task cannot be canceled (already completed/failed)."""
|
||||
|
||||
PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
|
||||
"""Push notifications are not supported by this agent."""
|
||||
|
||||
UNSUPPORTED_OPERATION = -32004
|
||||
"""The requested operation is not supported."""
|
||||
|
||||
CONTENT_TYPE_NOT_SUPPORTED = -32005
|
||||
"""Incompatible content types between client and server."""
|
||||
|
||||
INVALID_AGENT_RESPONSE = -32006
|
||||
"""The agent produced an invalid response."""
|
||||
|
||||
# CrewAI Custom Extensions (-32768 to -32100)
|
||||
UNSUPPORTED_VERSION = -32009
|
||||
"""The requested A2A protocol version is not supported."""
|
||||
|
||||
UNSUPPORTED_EXTENSION = -32010
|
||||
"""Client does not support required protocol extensions."""
|
||||
|
||||
AUTHENTICATION_REQUIRED = -32011
|
||||
"""Authentication is required for this operation."""
|
||||
|
||||
AUTHORIZATION_FAILED = -32012
|
||||
"""Authorization check failed (insufficient permissions)."""
|
||||
|
||||
RATE_LIMIT_EXCEEDED = -32013
|
||||
"""Rate limit exceeded for this client/operation."""
|
||||
|
||||
TASK_TIMEOUT = -32014
|
||||
"""Task execution timed out."""
|
||||
|
||||
TRANSPORT_NEGOTIATION_FAILED = -32015
|
||||
"""Failed to negotiate a compatible transport protocol."""
|
||||
|
||||
CONTEXT_NOT_FOUND = -32016
|
||||
"""The specified context was not found."""
|
||||
|
||||
SKILL_NOT_FOUND = -32017
|
||||
"""The specified skill was not found."""
|
||||
|
||||
ARTIFACT_NOT_FOUND = -32018
|
||||
"""The specified artifact was not found."""
|
||||
|
||||
|
||||
# Error code to default message mapping
|
||||
ERROR_MESSAGES: dict[int, str] = {
|
||||
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
|
||||
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
|
||||
A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
|
||||
A2AErrorCode.INVALID_PARAMS: "Invalid params",
|
||||
A2AErrorCode.INTERNAL_ERROR: "Internal error",
|
||||
A2AErrorCode.TASK_NOT_FOUND: "Task not found",
|
||||
A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
|
||||
A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
|
||||
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
|
||||
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
|
||||
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
|
||||
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
|
||||
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
|
||||
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
|
||||
A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
|
||||
A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
|
||||
A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
|
||||
A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
|
||||
A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
|
||||
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
|
||||
A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class A2AError(Exception):
|
||||
"""Base exception for A2A protocol errors.
|
||||
|
||||
Attributes:
|
||||
code: The A2A/JSON-RPC error code.
|
||||
message: Human-readable error message.
|
||||
data: Optional additional error data.
|
||||
"""
|
||||
|
||||
code: int
|
||||
message: str | None = None
|
||||
data: Any = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
|
||||
super().__init__(self.message)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to JSON-RPC error object format."""
|
||||
error: dict[str, Any] = {
|
||||
"code": self.code,
|
||||
"message": self.message,
|
||||
}
|
||||
if self.data is not None:
|
||||
error["data"] = self.data
|
||||
return error
|
||||
|
||||
def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
|
||||
"""Convert to full JSON-RPC error response."""
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"error": self.to_dict(),
|
||||
"id": request_id,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONParseError(A2AError):
|
||||
"""Invalid JSON was received."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidRequestError(A2AError):
|
||||
"""The JSON sent is not a valid Request object."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MethodNotFoundError(A2AError):
|
||||
"""The method does not exist / is not available."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
|
||||
method: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.method:
|
||||
self.message = f"Method not found: {self.method}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidParamsError(A2AError):
|
||||
"""Invalid method parameter(s)."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
|
||||
param: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
if self.param and self.reason:
|
||||
self.message = f"Invalid parameter '{self.param}': {self.reason}"
|
||||
elif self.param:
|
||||
self.message = f"Invalid parameter: {self.param}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InternalError(A2AError):
|
||||
"""Internal JSON-RPC error."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskNotFoundError(A2AError):
|
||||
"""The specified task was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
|
||||
task_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.task_id:
|
||||
self.message = f"Task not found: {self.task_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskNotCancelableError(A2AError):
|
||||
"""The task cannot be canceled."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
|
||||
task_id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
if self.task_id and self.reason:
|
||||
self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
|
||||
elif self.task_id:
|
||||
self.message = f"Task {self.task_id} cannot be canceled"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PushNotificationNotSupportedError(A2AError):
|
||||
"""Push notifications are not supported."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedOperationError(A2AError):
|
||||
"""The requested operation is not supported."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
|
||||
operation: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.operation:
|
||||
self.message = f"Operation not supported: {self.operation}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentTypeNotSupportedError(A2AError):
|
||||
"""Incompatible content types."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
|
||||
requested_types: list[str] | None = None
|
||||
supported_types: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.requested_types and self.supported_types:
|
||||
self.message = (
|
||||
f"Content type not supported. Requested: {self.requested_types}, "
|
||||
f"Supported: {self.supported_types}"
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidAgentResponseError(A2AError):
|
||||
"""The agent produced an invalid response."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedVersionError(A2AError):
|
||||
"""The requested A2A version is not supported."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
|
||||
requested_version: str | None = None
|
||||
supported_versions: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.requested_version:
|
||||
msg = f"Unsupported A2A version: {self.requested_version}"
|
||||
if self.supported_versions:
|
||||
msg += f". Supported versions: {', '.join(self.supported_versions)}"
|
||||
self.message = msg
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedExtensionError(A2AError):
|
||||
"""Client does not support required extensions."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
|
||||
required_extensions: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.required_extensions:
|
||||
self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthenticationRequiredError(A2AError):
|
||||
"""Authentication is required."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationFailedError(A2AError):
|
||||
"""Authorization check failed."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
|
||||
required_scope: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.required_scope:
|
||||
self.message = (
|
||||
f"Authorization failed. Required scope: {self.required_scope}"
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitExceededError(A2AError):
|
||||
"""Rate limit exceeded."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False)
|
||||
retry_after: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.retry_after:
|
||||
self.message = (
|
||||
f"Rate limit exceeded. Retry after {self.retry_after} seconds"
|
||||
)
|
||||
if self.retry_after:
|
||||
self.data = {"retry_after": self.retry_after}
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskTimeoutError(A2AError):
|
||||
"""Task execution timed out."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False)
|
||||
task_id: str | None = None
|
||||
timeout_seconds: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
if self.task_id and self.timeout_seconds:
|
||||
self.message = (
|
||||
f"Task {self.task_id} timed out after {self.timeout_seconds}s"
|
||||
)
|
||||
elif self.task_id:
|
||||
self.message = f"Task {self.task_id} timed out"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransportNegotiationFailedError(A2AError):
|
||||
"""Failed to negotiate a compatible transport protocol."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False)
|
||||
client_transports: list[str] | None = None
|
||||
server_transports: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.client_transports and self.server_transports:
|
||||
self.message = (
|
||||
f"Transport negotiation failed. Client: {self.client_transports}, "
|
||||
f"Server: {self.server_transports}"
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextNotFoundError(A2AError):
|
||||
"""The specified context was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
|
||||
context_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.context_id:
|
||||
self.message = f"Context not found: {self.context_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillNotFoundError(A2AError):
|
||||
"""The specified skill was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
|
||||
skill_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.skill_id:
|
||||
self.message = f"Skill not found: {self.skill_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArtifactNotFoundError(A2AError):
|
||||
"""The specified artifact was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
|
||||
artifact_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.artifact_id:
|
||||
self.message = f"Artifact not found: {self.artifact_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
def create_error_response(
|
||||
code: int | A2AErrorCode,
|
||||
message: str | None = None,
|
||||
data: Any = None,
|
||||
request_id: str | int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a JSON-RPC error response.
|
||||
|
||||
Args:
|
||||
code: Error code (A2AErrorCode or int).
|
||||
message: Optional error message (uses default if not provided).
|
||||
data: Optional additional error data.
|
||||
request_id: Request ID for correlation.
|
||||
|
||||
Returns:
|
||||
Dict in JSON-RPC error response format.
|
||||
"""
|
||||
error = A2AError(code=int(code), message=message, data=data)
|
||||
return error.to_response(request_id)
|
||||
|
||||
|
||||
def is_retryable_error(code: int) -> bool:
|
||||
"""Check if an error is potentially retryable.
|
||||
|
||||
Args:
|
||||
code: Error code to check.
|
||||
|
||||
Returns:
|
||||
True if the error might be resolved by retrying.
|
||||
"""
|
||||
retryable_codes = {
|
||||
A2AErrorCode.INTERNAL_ERROR,
|
||||
A2AErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
A2AErrorCode.TASK_TIMEOUT,
|
||||
}
|
||||
return code in retryable_codes
|
||||
|
||||
|
||||
def is_client_error(code: int) -> bool:
|
||||
"""Check if an error is a client-side error.
|
||||
|
||||
Args:
|
||||
code: Error code to check.
|
||||
|
||||
Returns:
|
||||
True if the error is due to client request issues.
|
||||
"""
|
||||
client_error_codes = {
|
||||
A2AErrorCode.JSON_PARSE_ERROR,
|
||||
A2AErrorCode.INVALID_REQUEST,
|
||||
A2AErrorCode.METHOD_NOT_FOUND,
|
||||
A2AErrorCode.INVALID_PARAMS,
|
||||
A2AErrorCode.TASK_NOT_FOUND,
|
||||
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
|
||||
A2AErrorCode.UNSUPPORTED_VERSION,
|
||||
A2AErrorCode.UNSUPPORTED_EXTENSION,
|
||||
A2AErrorCode.CONTEXT_NOT_FOUND,
|
||||
A2AErrorCode.SKILL_NOT_FOUND,
|
||||
A2AErrorCode.ARTIFACT_NOT_FOUND,
|
||||
}
|
||||
return code in client_error_codes
|
||||
37
lib/crewai-a2a/src/crewai_a2a/extensions/__init__.py
Normal file
37
lib/crewai-a2a/src/crewai_a2a/extensions/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""A2A Protocol Extensions for CrewAI.
|
||||
|
||||
This module contains extensions to the A2A (Agent-to-Agent) protocol.
|
||||
|
||||
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
|
||||
processes requests and responses during delegation to remote agents. These provide
|
||||
hooks for tool injection, prompt augmentation, and response processing.
|
||||
|
||||
**Server-side extensions** (ServerExtension) allow agents to offer additional
|
||||
functionality beyond the core A2A specification. Clients activate extensions
|
||||
via the X-A2A-Extensions header.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from crewai_a2a.extensions.base import (
|
||||
A2AExtension,
|
||||
ConversationState,
|
||||
ExtensionRegistry,
|
||||
ValidatedA2AExtension,
|
||||
)
|
||||
from crewai_a2a.extensions.server import (
|
||||
ExtensionContext,
|
||||
ServerExtension,
|
||||
ServerExtensionRegistry,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"A2AExtension",
|
||||
"ConversationState",
|
||||
"ExtensionContext",
|
||||
"ExtensionRegistry",
|
||||
"ServerExtension",
|
||||
"ServerExtensionRegistry",
|
||||
"ValidatedA2AExtension",
|
||||
]
|
||||
237
lib/crewai-a2a/src/crewai_a2a/extensions/base.py
Normal file
237
lib/crewai-a2a/src/crewai_a2a/extensions/base.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Base extension interface for CrewAI A2A wrapper processing hooks.
|
||||
|
||||
This module defines the protocol for extending CrewAI's A2A wrapper functionality
|
||||
with custom logic for tool injection, prompt augmentation, and response processing.
|
||||
|
||||
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
|
||||
A2A protocol extensions are capability declarations using AgentExtension objects
|
||||
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BeforeValidator
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Message
|
||||
from crewai.agent.core import Agent
|
||||
|
||||
|
||||
def _validate_a2a_extension(v: Any) -> Any:
|
||||
"""Validate that value implements A2AExtension protocol."""
|
||||
if not isinstance(v, A2AExtension):
|
||||
raise ValueError(
|
||||
f"Value must implement A2AExtension protocol. "
|
||||
f"Got {type(v).__name__} which is missing required methods."
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ConversationState(Protocol):
|
||||
"""Protocol for extension-specific conversation state.
|
||||
|
||||
Extensions can define their own state classes that implement this protocol
|
||||
to track conversation-specific data extracted from message history.
|
||||
"""
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if the state indicates readiness for some action.
|
||||
|
||||
Returns:
|
||||
True if the state is ready, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class A2AExtension(Protocol):
|
||||
"""Protocol for A2A wrapper extensions.
|
||||
|
||||
Extensions can implement this protocol to inject custom logic into
|
||||
the A2A conversation flow at various integration points.
|
||||
|
||||
Example:
|
||||
class MyExtension:
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
# Add custom tools to the agent
|
||||
pass
|
||||
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> ConversationState | None:
|
||||
# Extract state from conversation
|
||||
return None
|
||||
|
||||
def augment_prompt(
|
||||
self, base_prompt: str, conversation_state: ConversationState | None
|
||||
) -> str:
|
||||
# Add custom instructions
|
||||
return base_prompt
|
||||
|
||||
def process_response(
|
||||
self, agent_response: Any, conversation_state: ConversationState | None
|
||||
) -> Any:
|
||||
# Modify response if needed
|
||||
return agent_response
|
||||
"""
|
||||
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
"""Inject extension-specific tools into the agent.
|
||||
|
||||
Called when an agent is wrapped with A2A capabilities. Extensions
|
||||
can add tools that enable extension-specific functionality.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to inject tools into.
|
||||
"""
|
||||
...
|
||||
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> ConversationState | None:
|
||||
"""Extract extension-specific state from conversation history.
|
||||
|
||||
Called during prompt augmentation to allow extensions to analyze
|
||||
the conversation history and extract relevant state information.
|
||||
|
||||
Args:
|
||||
conversation_history: The sequence of A2A messages exchanged.
|
||||
|
||||
Returns:
|
||||
Extension-specific conversation state, or None if no relevant state.
|
||||
"""
|
||||
...
|
||||
|
||||
def augment_prompt(
|
||||
self,
|
||||
base_prompt: str,
|
||||
conversation_state: ConversationState | None,
|
||||
) -> str:
|
||||
"""Augment the task prompt with extension-specific instructions.
|
||||
|
||||
Called during prompt augmentation to allow extensions to add
|
||||
custom instructions based on conversation state.
|
||||
|
||||
Args:
|
||||
base_prompt: The base prompt to augment.
|
||||
conversation_state: Extension-specific state from extract_state_from_history.
|
||||
|
||||
Returns:
|
||||
The augmented prompt with extension-specific instructions.
|
||||
"""
|
||||
...
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
agent_response: Any,
|
||||
conversation_state: ConversationState | None,
|
||||
) -> Any:
|
||||
"""Process and potentially modify the agent response.
|
||||
|
||||
Called after parsing the agent's response, allowing extensions to
|
||||
enhance or modify the response based on conversation state.
|
||||
|
||||
Args:
|
||||
agent_response: The parsed agent response.
|
||||
conversation_state: Extension-specific state from extract_state_from_history.
|
||||
|
||||
Returns:
|
||||
The processed agent response (may be modified or original).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ExtensionRegistry:
|
||||
"""Registry for managing A2A extensions.
|
||||
|
||||
Maintains a collection of extensions and provides methods to invoke
|
||||
their hooks at various integration points.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the extension registry."""
|
||||
self._extensions: list[A2AExtension] = []
|
||||
|
||||
def register(self, extension: A2AExtension) -> None:
|
||||
"""Register an extension.
|
||||
|
||||
Args:
|
||||
extension: The extension to register.
|
||||
"""
|
||||
self._extensions.append(extension)
|
||||
|
||||
def inject_all_tools(self, agent: Agent) -> None:
|
||||
"""Inject tools from all registered extensions.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to inject tools into.
|
||||
"""
|
||||
for extension in self._extensions:
|
||||
extension.inject_tools(agent)
|
||||
|
||||
def extract_all_states(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> dict[type[A2AExtension], ConversationState]:
|
||||
"""Extract conversation states from all registered extensions.
|
||||
|
||||
Args:
|
||||
conversation_history: The sequence of A2A messages exchanged.
|
||||
|
||||
Returns:
|
||||
Mapping of extension types to their conversation states.
|
||||
"""
|
||||
states: dict[type[A2AExtension], ConversationState] = {}
|
||||
for extension in self._extensions:
|
||||
state = extension.extract_state_from_history(conversation_history)
|
||||
if state is not None:
|
||||
states[type(extension)] = state
|
||||
return states
|
||||
|
||||
def augment_prompt_with_all(
|
||||
self,
|
||||
base_prompt: str,
|
||||
extension_states: dict[type[A2AExtension], ConversationState],
|
||||
) -> str:
|
||||
"""Augment prompt with instructions from all registered extensions.
|
||||
|
||||
Args:
|
||||
base_prompt: The base prompt to augment.
|
||||
extension_states: Mapping of extension types to conversation states.
|
||||
|
||||
Returns:
|
||||
The fully augmented prompt.
|
||||
"""
|
||||
augmented = base_prompt
|
||||
for extension in self._extensions:
|
||||
state = extension_states.get(type(extension))
|
||||
augmented = extension.augment_prompt(augmented, state)
|
||||
return augmented
|
||||
|
||||
def process_response_with_all(
|
||||
self,
|
||||
agent_response: Any,
|
||||
extension_states: dict[type[A2AExtension], ConversationState],
|
||||
) -> Any:
|
||||
"""Process response through all registered extensions.
|
||||
|
||||
Args:
|
||||
agent_response: The parsed agent response.
|
||||
extension_states: Mapping of extension types to conversation states.
|
||||
|
||||
Returns:
|
||||
The processed agent response.
|
||||
"""
|
||||
processed = agent_response
|
||||
for extension in self._extensions:
|
||||
state = extension_states.get(type(extension))
|
||||
processed = extension.process_response(processed, state)
|
||||
return processed
|
||||
170
lib/crewai-a2a/src/crewai_a2a/extensions/registry.py
Normal file
170
lib/crewai-a2a/src/crewai_a2a/extensions/registry.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""A2A Protocol extension utilities.
|
||||
|
||||
This module provides utilities for working with A2A protocol extensions as
|
||||
defined in the A2A specification. Extensions are capability declarations in
|
||||
AgentCard.capabilities.extensions using AgentExtension objects, activated
|
||||
via the X-A2A-Extensions HTTP header.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
|
||||
from a2a.extensions.common import (
|
||||
HTTP_EXTENSION_HEADER,
|
||||
)
|
||||
from a2a.types import AgentCard, AgentExtension
|
||||
|
||||
from crewai_a2a.config import A2AClientConfig, A2AConfig
|
||||
from crewai_a2a.extensions.base import ExtensionRegistry
|
||||
|
||||
|
||||
def get_extensions_from_config(
|
||||
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||
) -> list[str]:
|
||||
"""Extract extension URIs from A2A configuration.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (single or list).
|
||||
|
||||
Returns:
|
||||
Deduplicated list of extension URIs from all configs.
|
||||
"""
|
||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
|
||||
for config in configs:
|
||||
if not isinstance(config, A2AClientConfig):
|
||||
continue
|
||||
for uri in config.extensions:
|
||||
if uri not in seen:
|
||||
seen.add(uri)
|
||||
result.append(uri)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ExtensionsMiddleware(ClientCallInterceptor):
|
||||
"""Middleware to add X-A2A-Extensions header to requests.
|
||||
|
||||
This middleware adds the extensions header to all outgoing requests,
|
||||
declaring which A2A protocol extensions the client supports.
|
||||
"""
|
||||
|
||||
def __init__(self, extensions: list[str]) -> None:
|
||||
"""Initialize with extension URIs.
|
||||
|
||||
Args:
|
||||
extensions: List of extension URIs the client supports.
|
||||
"""
|
||||
self._extensions = extensions
|
||||
|
||||
async def intercept(
|
||||
self,
|
||||
method_name: str,
|
||||
request_payload: dict[str, Any],
|
||||
http_kwargs: dict[str, Any],
|
||||
agent_card: AgentCard | None,
|
||||
context: ClientCallContext | None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Add extensions header to the request.
|
||||
|
||||
Args:
|
||||
method_name: The A2A method being called.
|
||||
request_payload: The JSON-RPC request payload.
|
||||
http_kwargs: HTTP request kwargs (headers, etc).
|
||||
agent_card: The target agent's card.
|
||||
context: Optional call context.
|
||||
|
||||
Returns:
|
||||
Tuple of (request_payload, modified_http_kwargs).
|
||||
"""
|
||||
if self._extensions:
|
||||
headers = http_kwargs.setdefault("headers", {})
|
||||
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
|
||||
return request_payload, http_kwargs
|
||||
|
||||
|
||||
def validate_required_extensions(
|
||||
agent_card: AgentCard,
|
||||
client_extensions: list[str] | None,
|
||||
) -> list[AgentExtension]:
|
||||
"""Validate that client supports all required extensions from agent.
|
||||
|
||||
Args:
|
||||
agent_card: The agent's card with declared extensions.
|
||||
client_extensions: Extension URIs the client supports.
|
||||
|
||||
Returns:
|
||||
List of unsupported required extensions.
|
||||
|
||||
Raises:
|
||||
None - returns list of unsupported extensions for caller to handle.
|
||||
"""
|
||||
unsupported: list[AgentExtension] = []
|
||||
client_set = set(client_extensions or [])
|
||||
|
||||
if not agent_card.capabilities or not agent_card.capabilities.extensions:
|
||||
return unsupported
|
||||
|
||||
unsupported.extend(
|
||||
ext
|
||||
for ext in agent_card.capabilities.extensions
|
||||
if ext.required and ext.uri not in client_set
|
||||
)
|
||||
|
||||
return unsupported
|
||||
|
||||
|
||||
def create_extension_registry_from_config(
|
||||
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||
) -> ExtensionRegistry:
|
||||
"""Create an extension registry from A2A client configuration.
|
||||
|
||||
Extracts client_extensions from each A2AClientConfig and registers them
|
||||
with the ExtensionRegistry. These extensions provide CrewAI-specific
|
||||
processing hooks (tool injection, prompt augmentation, response processing).
|
||||
|
||||
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
|
||||
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (single or list).
|
||||
|
||||
Returns:
|
||||
Extension registry with all client_extensions registered.
|
||||
|
||||
Example:
|
||||
class LoggingExtension:
|
||||
def inject_tools(self, agent): pass
|
||||
def extract_state_from_history(self, history): return None
|
||||
def augment_prompt(self, prompt, state): return prompt
|
||||
def process_response(self, response, state):
|
||||
print(f"Response: {response}")
|
||||
return response
|
||||
|
||||
config = A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
client_extensions=[LoggingExtension()],
|
||||
)
|
||||
registry = create_extension_registry_from_config(config)
|
||||
"""
|
||||
registry = ExtensionRegistry()
|
||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||
|
||||
seen: set[int] = set()
|
||||
|
||||
for config in configs:
|
||||
if isinstance(config, (A2AConfig, A2AClientConfig)):
|
||||
client_exts = getattr(config, "client_extensions", [])
|
||||
for extension in client_exts:
|
||||
ext_id = id(extension)
|
||||
if ext_id not in seen:
|
||||
seen.add(ext_id)
|
||||
registry.register(extension)
|
||||
|
||||
return registry
|
||||
305
lib/crewai-a2a/src/crewai_a2a/extensions/server.py
Normal file
305
lib/crewai-a2a/src/crewai_a2a/extensions/server.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""A2A protocol server extensions for CrewAI agents.
|
||||
|
||||
This module provides the base class and context for implementing A2A protocol
|
||||
extensions on the server side. Extensions allow agents to offer additional
|
||||
functionality beyond the core A2A specification.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
|
||||
from a2a.types import AgentExtension
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.server.context import ServerCallContext
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtensionContext:
|
||||
"""Context passed to extension hooks during request processing.
|
||||
|
||||
Provides access to request metadata, client extensions, and shared state
|
||||
that extensions can read from and write to.
|
||||
|
||||
Attributes:
|
||||
metadata: Request metadata dict, includes extension-namespaced keys.
|
||||
client_extensions: Set of extension URIs the client declared support for.
|
||||
state: Mutable dict for extensions to share data during request lifecycle.
|
||||
server_context: The underlying A2A server call context.
|
||||
"""
|
||||
|
||||
metadata: dict[str, Any]
|
||||
client_extensions: set[str]
|
||||
state: dict[str, Any] = field(default_factory=dict)
|
||||
server_context: ServerCallContext | None = None
|
||||
|
||||
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
|
||||
"""Get extension-specific metadata value.
|
||||
|
||||
Extension metadata uses namespaced keys in the format:
|
||||
"{extension_uri}/{key}"
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
key: The metadata key within the extension namespace.
|
||||
|
||||
Returns:
|
||||
The metadata value, or None if not present.
|
||||
"""
|
||||
full_key = f"{uri}/{key}"
|
||||
return self.metadata.get(full_key)
|
||||
|
||||
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
|
||||
"""Set extension-specific metadata value.
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
key: The metadata key within the extension namespace.
|
||||
value: The value to set.
|
||||
"""
|
||||
full_key = f"{uri}/{key}"
|
||||
self.metadata[full_key] = value
|
||||
|
||||
|
||||
class ServerExtension(ABC):
|
||||
"""Base class for A2A protocol server extensions.
|
||||
|
||||
Subclass this to create custom extensions that modify agent behavior
|
||||
when clients activate them. Extensions are identified by URI and can
|
||||
be marked as required.
|
||||
|
||||
Example:
|
||||
class SamplingExtension(ServerExtension):
|
||||
uri = "urn:crewai:ext:sampling/v1"
|
||||
required = True
|
||||
|
||||
def __init__(self, max_tokens: int = 4096):
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@property
|
||||
def params(self) -> dict[str, Any]:
|
||||
return {"max_tokens": self.max_tokens}
|
||||
|
||||
async def on_request(self, context: ExtensionContext) -> None:
|
||||
limit = context.get_extension_metadata(self.uri, "limit")
|
||||
if limit:
|
||||
context.state["token_limit"] = int(limit)
|
||||
|
||||
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
return result
|
||||
"""
|
||||
|
||||
uri: Annotated[str, "Extension URI identifier. Must be unique."]
|
||||
required: Annotated[bool, "Whether clients must support this extension."] = False
|
||||
description: Annotated[
|
||||
str | None, "Human-readable description of the extension."
|
||||
] = None
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: Any,
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> CoreSchema:
|
||||
"""Tell Pydantic how to validate ServerExtension instances."""
|
||||
return core_schema.is_instance_schema(cls)
|
||||
|
||||
@property
|
||||
def params(self) -> dict[str, Any] | None:
|
||||
"""Extension parameters to advertise in AgentCard.
|
||||
|
||||
Override this property to expose configuration that clients can read.
|
||||
|
||||
Returns:
|
||||
Dict of parameter names to values, or None.
|
||||
"""
|
||||
return None
|
||||
|
||||
def agent_extension(self) -> AgentExtension:
|
||||
"""Generate the AgentExtension object for the AgentCard.
|
||||
|
||||
Returns:
|
||||
AgentExtension with this extension's URI, required flag, and params.
|
||||
"""
|
||||
return AgentExtension(
|
||||
uri=self.uri,
|
||||
required=self.required if self.required else None,
|
||||
description=self.description,
|
||||
params=self.params,
|
||||
)
|
||||
|
||||
def is_active(self, context: ExtensionContext) -> bool:
|
||||
"""Check if this extension is active for the current request.
|
||||
|
||||
An extension is active if the client declared support for it.
|
||||
|
||||
Args:
|
||||
context: The extension context for the current request.
|
||||
|
||||
Returns:
|
||||
True if the client supports this extension.
|
||||
"""
|
||||
return self.uri in context.client_extensions
|
||||
|
||||
@abstractmethod
|
||||
async def on_request(self, context: ExtensionContext) -> None:
|
||||
"""Called before agent execution if extension is active.
|
||||
|
||||
Use this hook to:
|
||||
- Read extension-specific metadata from the request
|
||||
- Set up state for the execution
|
||||
- Modify execution parameters via context.state
|
||||
|
||||
Args:
|
||||
context: The extension context with request metadata and state.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
"""Called after agent execution if extension is active.
|
||||
|
||||
Use this hook to:
|
||||
- Modify or enhance the result
|
||||
- Add extension-specific metadata to the response
|
||||
- Clean up any resources
|
||||
|
||||
Args:
|
||||
context: The extension context with request metadata and state.
|
||||
result: The agent execution result.
|
||||
|
||||
Returns:
|
||||
The result, potentially modified.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ServerExtensionRegistry:
|
||||
"""Registry for managing server-side A2A protocol extensions.
|
||||
|
||||
Collects extensions and provides methods to generate AgentCapabilities
|
||||
and invoke extension hooks during request processing.
|
||||
"""
|
||||
|
||||
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
|
||||
"""Initialize the registry with optional extensions.
|
||||
|
||||
Args:
|
||||
extensions: Initial list of extensions to register.
|
||||
"""
|
||||
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
|
||||
self._by_uri: dict[str, ServerExtension] = {
|
||||
ext.uri: ext for ext in self._extensions
|
||||
}
|
||||
|
||||
def register(self, extension: ServerExtension) -> None:
|
||||
"""Register an extension.
|
||||
|
||||
Args:
|
||||
extension: The extension to register.
|
||||
|
||||
Raises:
|
||||
ValueError: If an extension with the same URI is already registered.
|
||||
"""
|
||||
if extension.uri in self._by_uri:
|
||||
raise ValueError(f"Extension already registered: {extension.uri}")
|
||||
self._extensions.append(extension)
|
||||
self._by_uri[extension.uri] = extension
|
||||
|
||||
def get_agent_extensions(self) -> list[AgentExtension]:
|
||||
"""Get AgentExtension objects for all registered extensions.
|
||||
|
||||
Returns:
|
||||
List of AgentExtension objects for the AgentCard.
|
||||
"""
|
||||
return [ext.agent_extension() for ext in self._extensions]
|
||||
|
||||
def get_extension(self, uri: str) -> ServerExtension | None:
|
||||
"""Get an extension by URI.
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
|
||||
Returns:
|
||||
The extension, or None if not found.
|
||||
"""
|
||||
return self._by_uri.get(uri)
|
||||
|
||||
@staticmethod
|
||||
def create_context(
|
||||
metadata: dict[str, Any],
|
||||
client_extensions: set[str],
|
||||
server_context: ServerCallContext | None = None,
|
||||
) -> ExtensionContext:
|
||||
"""Create an ExtensionContext for a request.
|
||||
|
||||
Args:
|
||||
metadata: Request metadata dict.
|
||||
client_extensions: Set of extension URIs from client.
|
||||
server_context: Optional server call context.
|
||||
|
||||
Returns:
|
||||
ExtensionContext for use in hooks.
|
||||
"""
|
||||
return ExtensionContext(
|
||||
metadata=metadata,
|
||||
client_extensions=client_extensions,
|
||||
server_context=server_context,
|
||||
)
|
||||
|
||||
async def invoke_on_request(self, context: ExtensionContext) -> None:
|
||||
"""Invoke on_request hooks for all active extensions.
|
||||
|
||||
Tracks activated extensions and isolates errors from individual hooks.
|
||||
|
||||
Args:
|
||||
context: The extension context for the request.
|
||||
"""
|
||||
for extension in self._extensions:
|
||||
if extension.is_active(context):
|
||||
try:
|
||||
await extension.on_request(context)
|
||||
if context.server_context is not None:
|
||||
context.server_context.activated_extensions.add(extension.uri)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Extension on_request hook failed",
|
||||
extra={"extension": extension.uri},
|
||||
)
|
||||
|
||||
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
"""Invoke on_response hooks for all active extensions.
|
||||
|
||||
Isolates errors from individual hooks to prevent one failing extension
|
||||
from breaking the entire response.
|
||||
|
||||
Args:
|
||||
context: The extension context for the request.
|
||||
result: The agent execution result.
|
||||
|
||||
Returns:
|
||||
The result after all extensions have processed it.
|
||||
"""
|
||||
processed = result
|
||||
for extension in self._extensions:
|
||||
if extension.is_active(context):
|
||||
try:
|
||||
processed = await extension.on_response(context, processed)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Extension on_response hook failed",
|
||||
extra={"extension": extension.uri},
|
||||
)
|
||||
return processed
|
||||
0
lib/crewai-a2a/src/crewai_a2a/py.typed
Normal file
0
lib/crewai-a2a/src/crewai_a2a/py.typed
Normal file
479
lib/crewai-a2a/src/crewai_a2a/task_helpers.py
Normal file
479
lib/crewai-a2a/src/crewai_a2a/task_helpers.py
Normal file
@@ -0,0 +1,479 @@
|
||||
"""Helper functions for processing A2A task results."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
import uuid
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
SendMessageEvent = (
|
||||
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
|
||||
)
|
||||
|
||||
|
||||
TERMINAL_STATES: frozenset[TaskState] = frozenset(
|
||||
{
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.rejected,
|
||||
TaskState.canceled,
|
||||
}
|
||||
)
|
||||
|
||||
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
|
||||
{
|
||||
TaskState.input_required,
|
||||
TaskState.auth_required,
|
||||
}
|
||||
)
|
||||
|
||||
PENDING_STATES: frozenset[TaskState] = frozenset(
|
||||
{
|
||||
TaskState.submitted,
|
||||
TaskState.working,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TaskStateResult(TypedDict):
|
||||
"""Result dictionary from processing A2A task state."""
|
||||
|
||||
status: TaskState
|
||||
history: list[Message]
|
||||
result: NotRequired[str]
|
||||
error: NotRequired[str]
|
||||
agent_card: NotRequired[dict[str, Any]]
|
||||
a2a_agent_name: NotRequired[str | None]
|
||||
|
||||
|
||||
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
|
||||
"""Extract result parts from A2A task status message, history, and artifacts.
|
||||
|
||||
Args:
|
||||
a2a_task: A2A Task object with status, history, and artifacts
|
||||
|
||||
Returns:
|
||||
List of result text parts
|
||||
"""
|
||||
result_parts: list[str] = []
|
||||
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
msg = a2a_task.status.message
|
||||
result_parts.extend(
|
||||
part.root.text for part in msg.parts if part.root.kind == "text"
|
||||
)
|
||||
|
||||
if not result_parts and a2a_task.history:
|
||||
for history_msg in reversed(a2a_task.history):
|
||||
if history_msg.role == Role.agent:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in history_msg.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
break
|
||||
|
||||
if a2a_task.artifacts:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for artifact in a2a_task.artifacts
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
return result_parts
|
||||
|
||||
|
||||
def extract_error_message(a2a_task: A2ATask, default: str) -> str:
|
||||
"""Extract error message from A2A task.
|
||||
|
||||
Args:
|
||||
a2a_task: A2A Task object
|
||||
default: Default message if no error found
|
||||
|
||||
Returns:
|
||||
Error message string
|
||||
"""
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
msg = a2a_task.status.message
|
||||
if msg:
|
||||
for part in msg.parts:
|
||||
if part.root.kind == "text":
|
||||
return str(part.root.text)
|
||||
return str(msg)
|
||||
|
||||
if a2a_task.history:
|
||||
for history_msg in reversed(a2a_task.history):
|
||||
for part in history_msg.parts:
|
||||
if part.root.kind == "text":
|
||||
return str(part.root.text)
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def process_task_state(
|
||||
a2a_task: A2ATask,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
turn_number: int,
|
||||
is_multiturn: bool,
|
||||
agent_role: str | None,
|
||||
result_parts: list[str] | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
is_final: bool = True,
|
||||
) -> TaskStateResult | None:
|
||||
"""Process A2A task state and return result dictionary.
|
||||
|
||||
Shared logic for both polling and streaming handlers.
|
||||
|
||||
Args:
|
||||
a2a_task: The A2A task to process.
|
||||
new_messages: List to collect messages (modified in place).
|
||||
agent_card: The agent card.
|
||||
turn_number: Current turn number.
|
||||
is_multiturn: Whether multi-turn conversation.
|
||||
agent_role: Agent role for logging.
|
||||
result_parts: Accumulated result parts (streaming passes accumulated,
|
||||
polling passes None to extract from task).
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
from_task: Optional CrewAI Task for event metadata.
|
||||
from_agent: Optional CrewAI Agent for event metadata.
|
||||
is_final: Whether this is the final response in the stream.
|
||||
|
||||
Returns:
|
||||
Result dictionary if terminal/actionable state, None otherwise.
|
||||
"""
|
||||
if result_parts is None:
|
||||
result_parts = []
|
||||
|
||||
if a2a_task.status.state == TaskState.completed:
|
||||
if not result_parts:
|
||||
extracted_parts = extract_task_result_parts(a2a_task)
|
||||
result_parts.extend(extracted_parts)
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
message_id = None
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
message_id = a2a_task.status.message.message_id
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
context_id=a2a_task.context_id,
|
||||
message_id=message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="completed",
|
||||
final=is_final,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state == TaskState.input_required:
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
|
||||
response_text = extract_error_message(a2a_task, "Additional input required")
|
||||
if response_text and not a2a_task.history:
|
||||
agent_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=response_text))],
|
||||
context_id=a2a_task.context_id,
|
||||
task_id=a2a_task.id,
|
||||
)
|
||||
new_messages.append(agent_message)
|
||||
|
||||
input_message_id = None
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
input_message_id = a2a_task.status.message.message_id
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
context_id=a2a_task.context_id,
|
||||
message_id=input_message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="input_required",
|
||||
final=is_final,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.input_required,
|
||||
error=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
|
||||
error_msg = extract_error_message(a2a_task, "Task failed without error message")
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state == TaskState.auth_required:
|
||||
error_msg = extract_error_message(a2a_task, "Authentication required")
|
||||
return TaskStateResult(
|
||||
status=TaskState.auth_required,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state == TaskState.canceled:
|
||||
error_msg = extract_error_message(a2a_task, "Task was canceled")
|
||||
return TaskStateResult(
|
||||
status=TaskState.canceled,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state in PENDING_STATES:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def send_message_and_get_task_id(
|
||||
event_stream: AsyncIterator[SendMessageEvent],
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
turn_number: int,
|
||||
is_multiturn: bool,
|
||||
agent_role: str | None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
context_id: str | None = None,
|
||||
) -> str | TaskStateResult:
|
||||
"""Send message and process initial response.
|
||||
|
||||
Handles the common pattern of sending a message and either:
|
||||
- Getting an immediate Message response (task completed synchronously)
|
||||
- Getting a Task that needs polling/waiting for completion
|
||||
|
||||
Args:
|
||||
event_stream: Async iterator from client.send_message()
|
||||
new_messages: List to collect messages (modified in place)
|
||||
agent_card: The agent card
|
||||
turn_number: Current turn number
|
||||
is_multiturn: Whether multi-turn conversation
|
||||
agent_role: Agent role for logging
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
endpoint: Optional A2A endpoint URL.
|
||||
a2a_agent_name: Optional A2A agent name.
|
||||
context_id: Optional A2A context ID for correlation.
|
||||
|
||||
Returns:
|
||||
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
|
||||
"""
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
result_parts = [
|
||||
part.root.text for part in event.parts if part.root.kind == "text"
|
||||
]
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
context_id=event.context_id,
|
||||
message_id=event.message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="completed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
if isinstance(event, tuple):
|
||||
a2a_task, _ = event
|
||||
|
||||
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
||||
result = process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return a2a_task.id
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error="No task ID received from initial message",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="send_message",
|
||||
context_id=context_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during send_message: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="send_message",
|
||||
context_id=context_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
finally:
|
||||
aclose = getattr(event_stream, "aclose", None)
|
||||
if aclose:
|
||||
await aclose()
|
||||
55
lib/crewai-a2a/src/crewai_a2a/templates.py
Normal file
55
lib/crewai-a2a/src/crewai_a2a/templates.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""String templates for A2A (Agent-to-Agent) protocol messaging and status."""
|
||||
|
||||
from string import Template
|
||||
from typing import Final
|
||||
|
||||
|
||||
AVAILABLE_AGENTS_TEMPLATE: Final[Template] = Template(
|
||||
"\n<AVAILABLE_A2A_AGENTS>\n $available_a2a_agents\n</AVAILABLE_A2A_AGENTS>\n"
|
||||
)
|
||||
PREVIOUS_A2A_CONVERSATION_TEMPLATE: Final[Template] = Template(
|
||||
"\n<PREVIOUS_A2A_CONVERSATION>\n"
|
||||
" $previous_a2a_conversation"
|
||||
"\n</PREVIOUS_A2A_CONVERSATION>\n"
|
||||
)
|
||||
CONVERSATION_TURN_INFO_TEMPLATE: Final[Template] = Template(
|
||||
"\n<CONVERSATION_PROGRESS>\n"
|
||||
' turn="$turn_count"\n'
|
||||
' max_turns="$max_turns"\n'
|
||||
" $warning"
|
||||
"\n</CONVERSATION_PROGRESS>\n"
|
||||
)
|
||||
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE: Final[Template] = Template(
|
||||
"\n<A2A_AGENTS_STATUS>\n"
|
||||
" NOTE: A2A agents were configured but are currently unavailable.\n"
|
||||
" You cannot delegate to remote agents for this task.\n\n"
|
||||
" Unavailable Agents:\n"
|
||||
" $unavailable_agents"
|
||||
"\n</A2A_AGENTS_STATUS>\n"
|
||||
)
|
||||
REMOTE_AGENT_COMPLETED_NOTICE: Final[str] = """
|
||||
<REMOTE_AGENT_STATUS>
|
||||
STATUS: COMPLETED
|
||||
The remote agent has finished processing your request. Their response is in the conversation history above.
|
||||
You MUST now:
|
||||
1. Extract the answer from the conversation history
|
||||
2. Set is_a2a=false
|
||||
3. Return the answer as your final message
|
||||
DO NOT send another request - the task is already done.
|
||||
</REMOTE_AGENT_STATUS>
|
||||
"""
|
||||
|
||||
REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """
|
||||
<REMOTE_AGENT_STATUS>
|
||||
STATUS: RESPONSE_RECEIVED
|
||||
The remote agent has responded. Their response is in the conversation history above.
|
||||
|
||||
You MUST now:
|
||||
1. Set is_a2a=false (the remote task is complete and cannot receive more messages)
|
||||
2. Provide YOUR OWN response to the original task based on the information received
|
||||
|
||||
IMPORTANT: Your response should be addressed to the USER who gave you the original task.
|
||||
Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that...").
|
||||
Do NOT address the remote agent directly or use "you" to refer to them.
|
||||
</REMOTE_AGENT_STATUS>
|
||||
"""
|
||||
104
lib/crewai-a2a/src/crewai_a2a/types.py
Normal file
104
lib/crewai-a2a/src/crewai_a2a/types.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Type definitions for A2A protocol message parts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypedDict,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
||||
try:
|
||||
from crewai_a2a.updates import (
|
||||
PollingConfig,
|
||||
PollingHandler,
|
||||
PushNotificationConfig,
|
||||
PushNotificationHandler,
|
||||
StreamingConfig,
|
||||
StreamingHandler,
|
||||
UpdateConfig,
|
||||
)
|
||||
except ImportError:
|
||||
PollingConfig = Any # type: ignore[misc,assignment]
|
||||
PollingHandler = Any # type: ignore[misc,assignment]
|
||||
PushNotificationConfig = Any # type: ignore[misc,assignment]
|
||||
PushNotificationHandler = Any # type: ignore[misc,assignment]
|
||||
StreamingConfig = Any # type: ignore[misc,assignment]
|
||||
StreamingHandler = Any # type: ignore[misc,assignment]
|
||||
UpdateConfig = Any # type: ignore[misc,assignment]
|
||||
|
||||
|
||||
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
|
||||
ProtocolVersion = Literal[
|
||||
"0.2.0",
|
||||
"0.2.1",
|
||||
"0.2.2",
|
||||
"0.2.3",
|
||||
"0.2.4",
|
||||
"0.2.5",
|
||||
"0.2.6",
|
||||
"0.3.0",
|
||||
"0.4.0",
|
||||
]
|
||||
|
||||
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
|
||||
|
||||
Url = Annotated[
|
||||
str,
|
||||
BeforeValidator(
|
||||
lambda value: str(http_url_adapter.validate_python(value, strict=True))
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentResponseProtocol(Protocol):
|
||||
"""Protocol for the dynamically created AgentResponse model."""
|
||||
|
||||
a2a_ids: tuple[str, ...]
|
||||
message: str
|
||||
is_a2a: bool
|
||||
|
||||
|
||||
class PartsMetadataDict(TypedDict, total=False):
|
||||
"""Metadata for A2A message parts.
|
||||
|
||||
Attributes:
|
||||
mimeType: MIME type for the part content.
|
||||
schema: JSON schema for the part content.
|
||||
"""
|
||||
|
||||
mimeType: Literal["application/json"]
|
||||
schema: dict[str, Any]
|
||||
|
||||
|
||||
class PartsDict(TypedDict):
|
||||
"""A2A message part containing text and optional metadata.
|
||||
|
||||
Attributes:
|
||||
text: The text content of the message part.
|
||||
metadata: Optional metadata describing the part content.
|
||||
"""
|
||||
|
||||
text: str
|
||||
metadata: NotRequired[PartsMetadataDict]
|
||||
|
||||
|
||||
PollingHandlerType = type[PollingHandler]
|
||||
StreamingHandlerType = type[StreamingHandler]
|
||||
PushNotificationHandlerType = type[PushNotificationHandler]
|
||||
|
||||
HandlerType = PollingHandlerType | StreamingHandlerType | PushNotificationHandlerType
|
||||
|
||||
HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
|
||||
PollingConfig: PollingHandler,
|
||||
StreamingConfig: StreamingHandler,
|
||||
PushNotificationConfig: PushNotificationHandler,
|
||||
}
|
||||
35
lib/crewai-a2a/src/crewai_a2a/updates/__init__.py
Normal file
35
lib/crewai-a2a/src/crewai_a2a/updates/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""A2A update mechanism configuration types."""
|
||||
|
||||
from crewai_a2a.updates.base import (
|
||||
BaseHandlerKwargs,
|
||||
PollingHandlerKwargs,
|
||||
PushNotificationHandlerKwargs,
|
||||
PushNotificationResultStore,
|
||||
StreamingHandlerKwargs,
|
||||
UpdateHandler,
|
||||
)
|
||||
from crewai_a2a.updates.polling.config import PollingConfig
|
||||
from crewai_a2a.updates.polling.handler import PollingHandler
|
||||
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
from crewai_a2a.updates.push_notifications.handler import PushNotificationHandler
|
||||
from crewai_a2a.updates.streaming.config import StreamingConfig
|
||||
from crewai_a2a.updates.streaming.handler import StreamingHandler
|
||||
|
||||
|
||||
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
|
||||
|
||||
__all__ = [
|
||||
"BaseHandlerKwargs",
|
||||
"PollingConfig",
|
||||
"PollingHandler",
|
||||
"PollingHandlerKwargs",
|
||||
"PushNotificationConfig",
|
||||
"PushNotificationHandler",
|
||||
"PushNotificationHandlerKwargs",
|
||||
"PushNotificationResultStore",
|
||||
"StreamingConfig",
|
||||
"StreamingHandler",
|
||||
"StreamingHandlerKwargs",
|
||||
"UpdateConfig",
|
||||
"UpdateHandler",
|
||||
]
|
||||
176
lib/crewai-a2a/src/crewai_a2a/updates/base.py
Normal file
176
lib/crewai-a2a/src/crewai_a2a/updates/base.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Base types for A2A update mechanism handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
class CommonParams(NamedTuple):
|
||||
"""Common parameters shared across all update handlers.
|
||||
|
||||
Groups the frequently-passed parameters to reduce duplication.
|
||||
"""
|
||||
|
||||
turn_number: int
|
||||
is_multiturn: bool
|
||||
agent_role: str | None
|
||||
endpoint: str
|
||||
a2a_agent_name: str | None
|
||||
context_id: str | None
|
||||
from_task: Any
|
||||
from_agent: Any
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import Client
|
||||
from a2a.types import AgentCard, Message, Task
|
||||
|
||||
from crewai_a2a.task_helpers import TaskStateResult
|
||||
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
|
||||
|
||||
class BaseHandlerKwargs(TypedDict, total=False):
|
||||
"""Base kwargs shared by all handlers."""
|
||||
|
||||
turn_number: int
|
||||
is_multiturn: bool
|
||||
agent_role: str | None
|
||||
context_id: str | None
|
||||
task_id: str | None
|
||||
endpoint: str | None
|
||||
agent_branch: Any
|
||||
a2a_agent_name: str | None
|
||||
from_task: Any
|
||||
from_agent: Any
|
||||
|
||||
|
||||
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for polling handler."""
|
||||
|
||||
polling_interval: float
|
||||
polling_timeout: float
|
||||
history_length: int
|
||||
max_polls: int | None
|
||||
|
||||
|
||||
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for streaming handler."""
|
||||
|
||||
|
||||
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for push notification handler."""
|
||||
|
||||
config: PushNotificationConfig
|
||||
result_store: PushNotificationResultStore
|
||||
polling_timeout: float
|
||||
polling_interval: float
|
||||
|
||||
|
||||
class PushNotificationResultStore(Protocol):
|
||||
"""Protocol for storing and retrieving push notification results.
|
||||
|
||||
This protocol defines the interface for a result store that the
|
||||
PushNotificationHandler uses to wait for task completion.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: Any,
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> CoreSchema:
|
||||
return core_schema.any_schema()
|
||||
|
||||
async def wait_for_result(
|
||||
self,
|
||||
task_id: str,
|
||||
timeout: float,
|
||||
poll_interval: float = 1.0,
|
||||
) -> Task | None:
|
||||
"""Wait for a task result to be available.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to wait for.
|
||||
timeout: Max seconds to wait before returning None.
|
||||
poll_interval: Seconds between polling attempts.
|
||||
|
||||
Returns:
|
||||
The completed Task object, or None if timeout.
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_result(self, task_id: str) -> Task | None:
|
||||
"""Get a task result if available.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to retrieve.
|
||||
|
||||
Returns:
|
||||
The Task object if available, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def store_result(self, task: Task) -> None:
|
||||
"""Store a task result.
|
||||
|
||||
Args:
|
||||
task: The Task object to store.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class UpdateHandler(Protocol):
|
||||
"""Protocol for A2A update mechanism handlers."""
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Any,
|
||||
) -> TaskStateResult:
|
||||
"""Execute the update mechanism and return result.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages (modified in place).
|
||||
agent_card: The agent card.
|
||||
**kwargs: Additional handler-specific parameters.
|
||||
|
||||
Returns:
|
||||
Result dictionary with status, result/error, and history.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams:
|
||||
"""Extract common parameters from handler kwargs.
|
||||
|
||||
Args:
|
||||
kwargs: Handler kwargs dict.
|
||||
|
||||
Returns:
|
||||
CommonParams with extracted values.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint is not provided.
|
||||
"""
|
||||
endpoint = kwargs.get("endpoint")
|
||||
if endpoint is None:
|
||||
raise ValueError("endpoint is required for update handlers")
|
||||
|
||||
return CommonParams(
|
||||
turn_number=kwargs.get("turn_number", 0),
|
||||
is_multiturn=kwargs.get("is_multiturn", False),
|
||||
agent_role=kwargs.get("agent_role"),
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=kwargs.get("a2a_agent_name"),
|
||||
context_id=kwargs.get("context_id"),
|
||||
from_task=kwargs.get("from_task"),
|
||||
from_agent=kwargs.get("from_agent"),
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Polling update mechanism module."""
|
||||
25
lib/crewai-a2a/src/crewai_a2a/updates/polling/config.py
Normal file
25
lib/crewai-a2a/src/crewai_a2a/updates/polling/config.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Polling update mechanism configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PollingConfig(BaseModel):
|
||||
"""Configuration for polling-based task updates.
|
||||
|
||||
Attributes:
|
||||
interval: Seconds between poll attempts.
|
||||
timeout: Max seconds to poll before raising timeout error.
|
||||
max_polls: Max number of poll attempts.
|
||||
history_length: Number of messages to retrieve per poll.
|
||||
"""
|
||||
|
||||
interval: float = Field(
|
||||
default=2.0, gt=0, description="Seconds between poll attempts"
|
||||
)
|
||||
timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll")
|
||||
max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts")
|
||||
history_length: int = Field(
|
||||
default=100, gt=0, description="Messages to retrieve per poll"
|
||||
)
|
||||
359
lib/crewai-a2a/src/crewai_a2a/updates/polling/handler.py
Normal file
359
lib/crewai-a2a/src/crewai_a2a/updates/polling/handler.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""Polling update mechanism handler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
TaskQueryParams,
|
||||
TaskState,
|
||||
TextPart,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
A2APollingStartedEvent,
|
||||
A2APollingStatusEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai_a2a.errors import A2APollingTimeoutError
|
||||
from crewai_a2a.task_helpers import (
|
||||
ACTIONABLE_STATES,
|
||||
TERMINAL_STATES,
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
send_message_and_get_task_id,
|
||||
)
|
||||
from crewai_a2a.updates.base import PollingHandlerKwargs
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
|
||||
async def _poll_task_until_complete(
|
||||
client: Client,
|
||||
task_id: str,
|
||||
polling_interval: float,
|
||||
polling_timeout: float,
|
||||
agent_branch: Any | None = None,
|
||||
history_length: int = 100,
|
||||
max_polls: int | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
context_id: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
) -> A2ATask:
|
||||
"""Poll task status until terminal state reached.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
task_id: Task ID to poll.
|
||||
polling_interval: Seconds between poll attempts.
|
||||
polling_timeout: Max seconds before timeout.
|
||||
agent_branch: Agent tree branch for logging.
|
||||
history_length: Number of messages to retrieve per poll.
|
||||
max_polls: Max number of poll attempts (None = unlimited).
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
context_id: A2A context ID for correlation.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
|
||||
Returns:
|
||||
Final task object in terminal state.
|
||||
|
||||
Raises:
|
||||
A2APollingTimeoutError: If polling exceeds timeout or max_polls.
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
poll_count = 0
|
||||
|
||||
while True:
|
||||
poll_count += 1
|
||||
task = await client.get_task(
|
||||
TaskQueryParams(id=task_id, history_length=history_length)
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
effective_context_id = task.context_id or context_id
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APollingStatusEvent(
|
||||
task_id=task_id,
|
||||
context_id=effective_context_id,
|
||||
state=str(task.status.state.value),
|
||||
elapsed_seconds=elapsed,
|
||||
poll_count=poll_count,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
||||
return task
|
||||
|
||||
if elapsed > polling_timeout:
|
||||
raise A2APollingTimeoutError(
|
||||
f"Polling timeout after {polling_timeout}s ({poll_count} polls)"
|
||||
)
|
||||
|
||||
if max_polls and poll_count >= max_polls:
|
||||
raise A2APollingTimeoutError(
|
||||
f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s"
|
||||
)
|
||||
|
||||
await asyncio.sleep(polling_interval)
|
||||
|
||||
|
||||
class PollingHandler:
|
||||
"""Polling-based update handler."""
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Unpack[PollingHandlerKwargs],
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using polling for updates.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages.
|
||||
agent_card: The agent card.
|
||||
**kwargs: Polling-specific parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
"""
|
||||
polling_interval = kwargs.get("polling_interval", 2.0)
|
||||
polling_timeout = kwargs.get("polling_timeout", 300.0)
|
||||
endpoint = kwargs.get("endpoint", "")
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
turn_number = kwargs.get("turn_number", 0)
|
||||
is_multiturn = kwargs.get("is_multiturn", False)
|
||||
agent_role = kwargs.get("agent_role")
|
||||
history_length = kwargs.get("history_length", 100)
|
||||
max_polls = kwargs.get("max_polls")
|
||||
context_id = kwargs.get("context_id")
|
||||
task_id = kwargs.get("task_id")
|
||||
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||
from_task = kwargs.get("from_task")
|
||||
from_agent = kwargs.get("from_agent")
|
||||
|
||||
try:
|
||||
result_or_task_id = await send_message_and_get_task_id(
|
||||
event_stream=client.send_message(message),
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
if not isinstance(result_or_task_id, str):
|
||||
return result_or_task_id
|
||||
|
||||
task_id = result_or_task_id
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APollingStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
polling_interval=polling_interval,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
final_task = await _poll_task_until_complete(
|
||||
client=client,
|
||||
task_id=task_id,
|
||||
polling_interval=polling_interval,
|
||||
polling_timeout=polling_timeout,
|
||||
agent_branch=agent_branch,
|
||||
history_length=history_length,
|
||||
max_polls=max_polls,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
context_id=context_id,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
)
|
||||
|
||||
result = process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Unexpected task state: {final_task.status.state}",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2APollingTimeoutError as e:
|
||||
error_msg = str(e)
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="polling",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during polling: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="polling",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Push notification update mechanism module."""
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Push notification update mechanism configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from a2a.types import PushNotificationAuthenticationInfo
|
||||
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
|
||||
|
||||
from crewai_a2a.updates.base import PushNotificationResultStore
|
||||
from crewai_a2a.updates.push_notifications.signature import WebhookSignatureConfig
|
||||
|
||||
|
||||
def _coerce_signature(
|
||||
value: str | WebhookSignatureConfig | None,
|
||||
) -> WebhookSignatureConfig | None:
|
||||
"""Convert string secret to WebhookSignatureConfig."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return WebhookSignatureConfig.hmac_sha256(secret=value)
|
||||
return value
|
||||
|
||||
|
||||
SignatureInput = Annotated[
|
||||
WebhookSignatureConfig | None,
|
||||
BeforeValidator(_coerce_signature),
|
||||
]
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
"""Configuration for webhook-based task updates.
|
||||
|
||||
Attributes:
|
||||
url: Callback URL where agent sends push notifications.
|
||||
id: Unique identifier for this config.
|
||||
token: Token to validate incoming notifications.
|
||||
authentication: Auth info for agent to use when calling webhook.
|
||||
timeout: Max seconds to wait for task completion.
|
||||
interval: Seconds between result polling attempts.
|
||||
result_store: Store for receiving push notification results.
|
||||
signature: HMAC signature config. Pass a string (secret) for defaults,
|
||||
or WebhookSignatureConfig for custom settings.
|
||||
"""
|
||||
|
||||
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
|
||||
id: str | None = Field(default=None, description="Unique config identifier")
|
||||
token: str | None = Field(default=None, description="Validation token")
|
||||
authentication: PushNotificationAuthenticationInfo | None = Field(
|
||||
default=None, description="Auth info for agent to use when calling webhook"
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
default=300.0, gt=0, description="Max seconds to wait for task completion"
|
||||
)
|
||||
interval: float = Field(
|
||||
default=2.0, gt=0, description="Seconds between result polling attempts"
|
||||
)
|
||||
result_store: PushNotificationResultStore | None = Field(
|
||||
default=None, description="Result store for push notification handling"
|
||||
)
|
||||
signature: SignatureInput = Field(
|
||||
default=None,
|
||||
description="HMAC signature config. Pass a string (secret) for simple usage, "
|
||||
"or WebhookSignatureConfig for custom headers/tolerance.",
|
||||
)
|
||||
@@ -0,0 +1,354 @@
|
||||
"""Push notification (webhook) update mechanism handler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
TaskState,
|
||||
TextPart,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
A2APushNotificationRegisteredEvent,
|
||||
A2APushNotificationTimeoutEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai_a2a.task_helpers import (
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
send_message_and_get_task_id,
|
||||
)
|
||||
from crewai_a2a.updates.base import (
|
||||
CommonParams,
|
||||
PushNotificationHandlerKwargs,
|
||||
PushNotificationResultStore,
|
||||
extract_common_params,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _handle_push_error(
|
||||
error: Exception,
|
||||
error_msg: str,
|
||||
error_type: str,
|
||||
new_messages: list[Message],
|
||||
agent_branch: Any | None,
|
||||
params: CommonParams,
|
||||
task_id: str | None,
|
||||
status_code: int | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Handle push notification errors with consistent event emission.
|
||||
|
||||
Args:
|
||||
error: The exception that occurred.
|
||||
error_msg: Formatted error message for the result.
|
||||
error_type: Type of error for the event.
|
||||
new_messages: List to append error message to.
|
||||
agent_branch: Agent tree branch for events.
|
||||
params: Common handler parameters.
|
||||
task_id: A2A task ID.
|
||||
status_code: HTTP status code if applicable.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with failed status.
|
||||
"""
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(error),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
|
||||
async def _wait_for_push_result(
|
||||
task_id: str,
|
||||
result_store: PushNotificationResultStore,
|
||||
timeout: float,
|
||||
poll_interval: float,
|
||||
agent_branch: Any | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
context_id: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
) -> A2ATask | None:
|
||||
"""Wait for push notification result.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to wait for.
|
||||
result_store: Store to retrieve results from.
|
||||
timeout: Max seconds to wait.
|
||||
poll_interval: Seconds between polling attempts.
|
||||
agent_branch: Agent tree branch for logging.
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
context_id: A2A context ID for correlation.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent.
|
||||
|
||||
Returns:
|
||||
Final task object, or None if timeout.
|
||||
"""
|
||||
task = await result_store.wait_for_result(
|
||||
task_id=task_id,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
)
|
||||
|
||||
if task is None:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APushNotificationTimeoutEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
timeout_seconds=timeout,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
class PushNotificationHandler:
|
||||
"""Push notification (webhook) based update handler."""
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Unpack[PushNotificationHandlerKwargs],
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using push notifications for updates.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages.
|
||||
agent_card: The agent card.
|
||||
**kwargs: Push notification-specific parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
|
||||
Raises:
|
||||
ValueError: If result_store or config not provided.
|
||||
"""
|
||||
config = kwargs.get("config")
|
||||
result_store = kwargs.get("result_store")
|
||||
polling_timeout = kwargs.get("polling_timeout", 300.0)
|
||||
polling_interval = kwargs.get("polling_interval", 2.0)
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
task_id = kwargs.get("task_id")
|
||||
params = extract_common_params(kwargs)
|
||||
|
||||
if config is None:
|
||||
error_msg = (
|
||||
"PushNotificationConfig is required for push notification handler"
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=error_msg,
|
||||
error_type="configuration_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if result_store is None:
|
||||
error_msg = (
|
||||
"PushNotificationResultStore is required for push notification handler"
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=error_msg,
|
||||
error_type="configuration_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
try:
|
||||
result_or_task_id = await send_message_and_get_task_id(
|
||||
event_stream=client.send_message(message),
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
context_id=params.context_id,
|
||||
)
|
||||
|
||||
if not isinstance(result_or_task_id, str):
|
||||
return result_or_task_id
|
||||
|
||||
task_id = result_or_task_id
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APushNotificationRegisteredEvent(
|
||||
task_id=task_id,
|
||||
context_id=params.context_id,
|
||||
callback_url=str(config.url),
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Push notification callback for task %s configured at %s (via initial request)",
|
||||
task_id,
|
||||
config.url,
|
||||
)
|
||||
|
||||
final_task = await _wait_for_push_result(
|
||||
task_id=task_id,
|
||||
result_store=result_store,
|
||||
timeout=polling_timeout,
|
||||
poll_interval=polling_interval,
|
||||
agent_branch=agent_branch,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
context_id=params.context_id,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
)
|
||||
|
||||
if final_task is None:
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Push notification timeout after {polling_timeout}s",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
result = process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Unexpected task state: {final_task.status.state}",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
return _handle_push_error(
|
||||
error=e,
|
||||
error_msg=f"HTTP Error {e.status_code}: {e!s}",
|
||||
error_type="http_error",
|
||||
new_messages=new_messages,
|
||||
agent_branch=agent_branch,
|
||||
params=params,
|
||||
task_id=task_id,
|
||||
status_code=e.status_code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return _handle_push_error(
|
||||
error=e,
|
||||
error_msg=f"Unexpected error during push notification: {e!s}",
|
||||
error_type="unexpected_error",
|
||||
new_messages=new_messages,
|
||||
agent_branch=agent_branch,
|
||||
params=params,
|
||||
task_id=task_id,
|
||||
)
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Webhook signature configuration for push notifications."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import secrets
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class WebhookSignatureMode(str, Enum):
|
||||
"""Signature mode for webhook push notifications."""
|
||||
|
||||
NONE = "none"
|
||||
HMAC_SHA256 = "hmac_sha256"
|
||||
|
||||
|
||||
class WebhookSignatureConfig(BaseModel):
|
||||
"""Configuration for webhook signature verification.
|
||||
|
||||
Provides cryptographic integrity verification and replay attack protection
|
||||
for A2A push notifications.
|
||||
|
||||
Attributes:
|
||||
mode: Signature mode (none or hmac_sha256).
|
||||
secret: Shared secret for HMAC computation (required for hmac_sha256 mode).
|
||||
timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection.
|
||||
header_name: HTTP header name for the signature.
|
||||
timestamp_header_name: HTTP header name for the timestamp.
|
||||
"""
|
||||
|
||||
mode: WebhookSignatureMode = Field(
|
||||
default=WebhookSignatureMode.NONE,
|
||||
description="Signature verification mode",
|
||||
)
|
||||
secret: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Shared secret for HMAC computation",
|
||||
)
|
||||
timestamp_tolerance_seconds: int = Field(
|
||||
default=300,
|
||||
ge=0,
|
||||
description="Max allowed timestamp age in seconds (5 min default)",
|
||||
)
|
||||
header_name: str = Field(
|
||||
default="X-A2A-Signature",
|
||||
description="HTTP header name for the signature",
|
||||
)
|
||||
timestamp_header_name: str = Field(
|
||||
default="X-A2A-Signature-Timestamp",
|
||||
description="HTTP header name for the timestamp",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_secret(cls, length: int = 32) -> str:
|
||||
"""Generate a cryptographically secure random secret.
|
||||
|
||||
Args:
|
||||
length: Number of random bytes to generate (default 32).
|
||||
|
||||
Returns:
|
||||
URL-safe base64-encoded secret string.
|
||||
"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
@classmethod
|
||||
def hmac_sha256(
|
||||
cls,
|
||||
secret: str | SecretStr,
|
||||
timestamp_tolerance_seconds: int = 300,
|
||||
) -> WebhookSignatureConfig:
|
||||
"""Create an HMAC-SHA256 signature configuration.
|
||||
|
||||
Args:
|
||||
secret: Shared secret for HMAC computation.
|
||||
timestamp_tolerance_seconds: Max allowed timestamp age in seconds.
|
||||
|
||||
Returns:
|
||||
Configured WebhookSignatureConfig for HMAC-SHA256.
|
||||
"""
|
||||
if isinstance(secret, str):
|
||||
secret = SecretStr(secret)
|
||||
return cls(
|
||||
mode=WebhookSignatureMode.HMAC_SHA256,
|
||||
secret=secret,
|
||||
timestamp_tolerance_seconds=timestamp_tolerance_seconds,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Streaming update mechanism module."""
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Streaming update mechanism configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class StreamingConfig(BaseModel):
|
||||
"""Configuration for SSE-based task updates."""
|
||||
646
lib/crewai-a2a/src/crewai_a2a/updates/streaming/handler.py
Normal file
646
lib/crewai-a2a/src/crewai_a2a/updates/streaming/handler.py
Normal file
@@ -0,0 +1,646 @@
|
||||
"""Streaming (SSE) update mechanism handler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Final
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskIdParams,
|
||||
TaskQueryParams,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AArtifactReceivedEvent,
|
||||
A2AConnectionErrorEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
A2AStreamingChunkEvent,
|
||||
A2AStreamingStartedEvent,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai_a2a.task_helpers import (
|
||||
ACTIONABLE_STATES,
|
||||
TERMINAL_STATES,
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
)
|
||||
from crewai_a2a.updates.base import StreamingHandlerKwargs, extract_common_params
|
||||
from crewai_a2a.updates.streaming.params import (
|
||||
process_status_update,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
|
||||
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
|
||||
|
||||
|
||||
class StreamingHandler:
|
||||
"""SSE streaming-based update handler."""
|
||||
|
||||
@staticmethod
|
||||
async def _try_recover_from_interruption( # type: ignore[misc]
|
||||
client: Client,
|
||||
task_id: str,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
result_parts: list[str],
|
||||
**kwargs: Unpack[StreamingHandlerKwargs],
|
||||
) -> TaskStateResult | None:
|
||||
"""Attempt to recover from a stream interruption by checking task state.
|
||||
|
||||
If the task completed while we were disconnected, returns the result.
|
||||
If the task is still running, attempts to resubscribe and continue.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
task_id: The task ID to recover.
|
||||
new_messages: List of collected messages.
|
||||
agent_card: The agent card.
|
||||
result_parts: Accumulated result text parts.
|
||||
**kwargs: Handler parameters.
|
||||
|
||||
Returns:
|
||||
TaskStateResult if recovery succeeded (task finished or resubscribe worked).
|
||||
None if recovery not possible (caller should handle failure).
|
||||
|
||||
Note:
|
||||
When None is returned, recovery failed and the original exception should
|
||||
be handled by the caller. All recovery attempts are logged.
|
||||
"""
|
||||
params = extract_common_params(kwargs) # type: ignore[arg-type]
|
||||
|
||||
try:
|
||||
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
|
||||
|
||||
if a2a_task.status.state in TERMINAL_STATES:
|
||||
logger.info(
|
||||
"Task completed during stream interruption",
|
||||
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
|
||||
)
|
||||
return process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
)
|
||||
|
||||
if a2a_task.status.state in ACTIONABLE_STATES:
|
||||
logger.info(
|
||||
"Task in actionable state during stream interruption",
|
||||
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
|
||||
)
|
||||
return process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Task still running, attempting resubscribe",
|
||||
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
|
||||
)
|
||||
|
||||
for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS):
|
||||
try:
|
||||
backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt)
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(backoff)
|
||||
|
||||
event_stream = client.resubscribe(TaskIdParams(id=task_id))
|
||||
|
||||
async for event in event_stream:
|
||||
if isinstance(event, tuple):
|
||||
resubscribed_task, update = event
|
||||
|
||||
is_final_update = (
|
||||
process_status_update(update, result_parts)
|
||||
if isinstance(update, TaskStatusUpdateEvent)
|
||||
else False
|
||||
)
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
if (
|
||||
is_final_update
|
||||
or resubscribed_task.status.state
|
||||
in TERMINAL_STATES | ACTIONABLE_STATES
|
||||
):
|
||||
return process_task_state(
|
||||
a2a_task=resubscribed_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=is_final_update,
|
||||
)
|
||||
|
||||
elif isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in event.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
final_task = await client.get_task(TaskQueryParams(id=task_id))
|
||||
return process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
)
|
||||
|
||||
except Exception as resubscribe_error: # noqa: PERF203
|
||||
logger.warning(
|
||||
"Resubscribe attempt failed",
|
||||
extra={
|
||||
"task_id": task_id,
|
||||
"attempt": attempt + 1,
|
||||
"max_attempts": MAX_RESUBSCRIBE_ATTEMPTS,
|
||||
"error": str(resubscribe_error),
|
||||
},
|
||||
)
|
||||
if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to recover from stream interruption due to unexpected error",
|
||||
extra={
|
||||
"task_id": task_id,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.warning(
|
||||
"Recovery exhausted all resubscribe attempts without success",
|
||||
extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS},
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Unpack[StreamingHandlerKwargs],
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using SSE streaming for updates.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages.
|
||||
agent_card: The agent card.
|
||||
**kwargs: Streaming-specific parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
"""
|
||||
task_id = kwargs.get("task_id")
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
params = extract_common_params(kwargs)
|
||||
|
||||
result_parts: list[str] = []
|
||||
final_result: TaskStateResult | None = None
|
||||
event_stream = client.send_message(message)
|
||||
chunk_index = 0
|
||||
current_task_id: str | None = task_id
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AStreamingStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=params.context_id,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, tuple):
|
||||
a2a_task, _ = event
|
||||
current_task_id = a2a_task.id
|
||||
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
message_context_id = event.context_id or params.context_id
|
||||
for part in event.parts:
|
||||
if part.root.kind == "text":
|
||||
text = part.root.text
|
||||
result_parts.append(text)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AStreamingChunkEvent(
|
||||
task_id=event.task_id or task_id,
|
||||
context_id=message_context_id,
|
||||
chunk=text,
|
||||
chunk_index=chunk_index,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
chunk_index += 1
|
||||
|
||||
elif isinstance(event, tuple):
|
||||
a2a_task, update = event
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
artifact_size = None
|
||||
if artifact.parts:
|
||||
artifact_size = sum(
|
||||
len(p.root.text.encode())
|
||||
if p.root.kind == "text"
|
||||
else len(getattr(p.root, "data", b""))
|
||||
for p in artifact.parts
|
||||
)
|
||||
effective_context_id = a2a_task.context_id or params.context_id
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AArtifactReceivedEvent(
|
||||
task_id=a2a_task.id,
|
||||
artifact_id=artifact.artifact_id,
|
||||
artifact_name=artifact.name,
|
||||
artifact_description=artifact.description,
|
||||
mime_type=artifact.parts[0].root.kind
|
||||
if artifact.parts
|
||||
else None,
|
||||
size_bytes=artifact_size,
|
||||
append=update.append or False,
|
||||
last_chunk=update.last_chunk or False,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
context_id=effective_context_id,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
is_final_update = (
|
||||
process_status_update(update, result_parts)
|
||||
if isinstance(update, TaskStatusUpdateEvent)
|
||||
else False
|
||||
)
|
||||
|
||||
if (
|
||||
not is_final_update
|
||||
and a2a_task.status.state
|
||||
not in TERMINAL_STATES | ACTIONABLE_STATES
|
||||
):
|
||||
continue
|
||||
|
||||
final_result = process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=is_final_update,
|
||||
)
|
||||
if final_result:
|
||||
break
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
if current_task_id:
|
||||
logger.info(
|
||||
"Stream interrupted with HTTP error, attempting recovery",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error": str(e),
|
||||
"status_code": e.status_code,
|
||||
},
|
||||
)
|
||||
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
|
||||
recovered_result = (
|
||||
await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**recovery_kwargs,
|
||||
)
|
||||
)
|
||||
if recovered_result:
|
||||
logger.info(
|
||||
"Successfully recovered task after HTTP error",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status": str(recovered_result.get("status")),
|
||||
},
|
||||
)
|
||||
return recovered_result
|
||||
|
||||
logger.warning(
|
||||
"Failed to recover from HTTP error, returning failure",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status_code": e.status_code,
|
||||
"original_error": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
error_type = "http_error"
|
||||
status_code = e.status_code
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
|
||||
error_type = type(e).__name__.lower()
|
||||
if current_task_id:
|
||||
logger.info(
|
||||
f"Stream interrupted with {error_type}, attempting recovery",
|
||||
extra={"task_id": current_task_id, "error": str(e)},
|
||||
)
|
||||
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
|
||||
recovered_result = (
|
||||
await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**recovery_kwargs,
|
||||
)
|
||||
)
|
||||
if recovered_result:
|
||||
logger.info(
|
||||
f"Successfully recovered task after {error_type}",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status": str(recovered_result.get("status")),
|
||||
},
|
||||
)
|
||||
return recovered_result
|
||||
|
||||
logger.warning(
|
||||
f"Failed to recover from {error_type}, returning failure",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error_type": error_type,
|
||||
"original_error": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
error_msg = f"Connection error during streaming: {e!s}"
|
||||
status_code = None
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Unexpected error during streaming",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error_type": type(e).__name__,
|
||||
"endpoint": params.endpoint,
|
||||
},
|
||||
)
|
||||
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
|
||||
error_type = "unexpected_error"
|
||||
status_code = None
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
finally:
|
||||
aclose = getattr(event_stream, "aclose", None)
|
||||
if aclose:
|
||||
try:
|
||||
await aclose()
|
||||
except Exception as close_error:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(close_error),
|
||||
error_type="stream_close_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="stream_close",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if final_result:
|
||||
return final_result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
result=" ".join(result_parts) if result_parts else "",
|
||||
history=new_messages,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
)
|
||||
28
lib/crewai-a2a/src/crewai_a2a/updates/streaming/params.py
Normal file
28
lib/crewai-a2a/src/crewai_a2a/updates/streaming/params.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Common parameter extraction for streaming handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
|
||||
|
||||
def process_status_update(
|
||||
update: TaskStatusUpdateEvent,
|
||||
result_parts: list[str],
|
||||
) -> bool:
|
||||
"""Process a status update event and extract text parts.
|
||||
|
||||
Args:
|
||||
update: The status update event.
|
||||
result_parts: List to append text parts to (modified in place).
|
||||
|
||||
Returns:
|
||||
True if this is a final update, False otherwise.
|
||||
"""
|
||||
is_final = update.final
|
||||
if update.status and update.status.message and update.status.message.parts:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in update.status.message.parts
|
||||
if part.root.kind == "text" and part.root.text
|
||||
)
|
||||
return is_final
|
||||
1
lib/crewai-a2a/src/crewai_a2a/utils/__init__.py
Normal file
1
lib/crewai-a2a/src/crewai_a2a/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""A2A utility modules for client operations."""
|
||||
587
lib/crewai-a2a/src/crewai_a2a/utils/agent_card.py
Normal file
587
lib/crewai-a2a/src/crewai_a2a/utils/agent_card.py
Normal file
@@ -0,0 +1,587 @@
|
||||
"""AgentCard utilities for A2A client and server operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import MutableMapping
|
||||
from functools import lru_cache
|
||||
import ssl
|
||||
import time
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
|
||||
from aiocache import cached # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AAgentCardFetchedEvent,
|
||||
A2AAuthenticationFailedEvent,
|
||||
A2AConnectionErrorEvent,
|
||||
)
|
||||
import httpx
|
||||
|
||||
from crewai_a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai_a2a.auth.utils import (
|
||||
_auth_store,
|
||||
configure_auth_client,
|
||||
retry_on_401,
|
||||
)
|
||||
from crewai_a2a.config import A2AServerConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
from crewai_a2a.auth.client_schemes import ClientAuthScheme
|
||||
|
||||
|
||||
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
|
||||
"""Get TLS verify parameter from auth scheme.
|
||||
|
||||
Args:
|
||||
auth: Optional authentication scheme with TLS config.
|
||||
|
||||
Returns:
|
||||
SSL context, CA cert path, True for default verification,
|
||||
or False if verification disabled.
|
||||
"""
|
||||
if auth and auth.tls:
|
||||
return auth.tls.get_httpx_ssl_context()
|
||||
return True
|
||||
|
||||
|
||||
async def _prepare_auth_headers(
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]:
|
||||
"""Prepare authentication headers and TLS verification settings.
|
||||
|
||||
Args:
|
||||
auth: Optional authentication scheme.
|
||||
timeout: Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
Tuple of (headers dict, TLS verify setting).
|
||||
"""
|
||||
headers: MutableMapping[str, str] = {}
|
||||
verify = _get_tls_verify(auth)
|
||||
if auth:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout, verify=verify
|
||||
) as temp_auth_client:
|
||||
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_auth_client)
|
||||
headers = await auth.apply_auth(temp_auth_client, {})
|
||||
return headers, verify
|
||||
|
||||
|
||||
def _get_server_config(agent: Agent) -> A2AServerConfig | None:
|
||||
"""Get A2AServerConfig from an agent's a2a configuration.
|
||||
|
||||
Args:
|
||||
agent: The Agent instance to check.
|
||||
|
||||
Returns:
|
||||
A2AServerConfig if present, None otherwise.
|
||||
"""
|
||||
if agent.a2a is None:
|
||||
return None
|
||||
if isinstance(agent.a2a, A2AServerConfig):
|
||||
return agent.a2a
|
||||
if isinstance(agent.a2a, list):
|
||||
for config in agent.a2a:
|
||||
if isinstance(config, A2AServerConfig):
|
||||
return config
|
||||
return None
|
||||
|
||||
|
||||
def fetch_agent_card(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
timeout: int = 30,
|
||||
use_cache: bool = True,
|
||||
cache_ttl: int = 300,
|
||||
) -> AgentCard:
|
||||
"""Fetch AgentCard from an A2A endpoint with optional caching.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
use_cache: Whether to use caching (default True).
|
||||
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).
|
||||
|
||||
Returns:
|
||||
AgentCard object with agent capabilities and skills.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails.
|
||||
A2AClientHTTPError: If authentication fails.
|
||||
"""
|
||||
if use_cache:
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
"_access_token",
|
||||
"_token_expires_at",
|
||||
"_refresh_token",
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", "")
|
||||
_auth_store.set(auth_hash, auth)
|
||||
ttl_hash = int(time.time() // cache_ttl)
|
||||
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def afetch_agent_card(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
timeout: int = 30,
|
||||
use_cache: bool = True,
|
||||
) -> AgentCard:
|
||||
"""Fetch AgentCard from an A2A endpoint asynchronously.
|
||||
|
||||
Native async implementation. Use this when running in an async context.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
use_cache: Whether to use caching (default True).
|
||||
|
||||
Returns:
|
||||
AgentCard object with agent capabilities and skills.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails.
|
||||
A2AClientHTTPError: If authentication fails.
|
||||
"""
|
||||
if use_cache:
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
"_access_token",
|
||||
"_token_expires_at",
|
||||
"_refresh_token",
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", "")
|
||||
_auth_store.set(auth_hash, auth)
|
||||
agent_card: AgentCard = await _afetch_agent_card_cached(
|
||||
endpoint, auth_hash, timeout
|
||||
)
|
||||
return agent_card
|
||||
|
||||
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _fetch_agent_card_cached(
|
||||
endpoint: str,
|
||||
auth_hash: str,
|
||||
timeout: int,
|
||||
_ttl_hash: int,
|
||||
) -> AgentCard:
|
||||
"""Cached sync version of fetch_agent_card."""
|
||||
auth = _auth_store.get(auth_hash)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
_afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
|
||||
async def _afetch_agent_card_cached(
|
||||
endpoint: str,
|
||||
auth_hash: str,
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Cached async implementation of AgentCard fetching."""
|
||||
auth = _auth_store.get(auth_hash)
|
||||
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
|
||||
|
||||
async def _afetch_agent_card_impl(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Internal async implementation of AgentCard fetching."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if "/.well-known/agent-card.json" in endpoint:
|
||||
base_url = endpoint.replace("/.well-known/agent-card.json", "")
|
||||
agent_card_path = "/.well-known/agent-card.json"
|
||||
else:
|
||||
url_parts = endpoint.split("/", 3)
|
||||
base_url = f"{url_parts[0]}//{url_parts[2]}"
|
||||
agent_card_path = (
|
||||
f"/{url_parts[3]}"
|
||||
if len(url_parts) > 3 and url_parts[3]
|
||||
else "/.well-known/agent-card.json"
|
||||
)
|
||||
|
||||
headers, verify = await _prepare_auth_headers(auth, timeout)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout, headers=headers, verify=verify
|
||||
) as temp_client:
|
||||
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_client)
|
||||
|
||||
agent_card_url = f"{base_url}{agent_card_path}"
|
||||
|
||||
async def _fetch_agent_card_request() -> httpx.Response:
|
||||
return await temp_client.get(agent_card_url)
|
||||
|
||||
try:
|
||||
response = await retry_on_401(
|
||||
request_func=_fetch_agent_card_request,
|
||||
auth_scheme=auth,
|
||||
client=temp_client,
|
||||
headers=temp_client.headers,
|
||||
max_retries=2,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
agent_card = AgentCard.model_validate(response.json())
|
||||
fetch_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AAgentCardFetchedEvent(
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
provider=agent_card_dict.get("provider"),
|
||||
cached=False,
|
||||
fetch_time_ms=fetch_time_ms,
|
||||
),
|
||||
)
|
||||
|
||||
return agent_card
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
response_body = e.response.text[:1000] if e.response.text else None
|
||||
|
||||
if e.response.status_code == 401:
|
||||
error_details = ["Authentication failed"]
|
||||
www_auth = e.response.headers.get("WWW-Authenticate")
|
||||
if www_auth:
|
||||
error_details.append(f"WWW-Authenticate: {www_auth}")
|
||||
if not auth:
|
||||
error_details.append("No auth scheme provided")
|
||||
msg = " | ".join(error_details)
|
||||
|
||||
auth_type = type(auth).__name__ if auth else None
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AAuthenticationFailedEvent(
|
||||
endpoint=endpoint,
|
||||
auth_type=auth_type,
|
||||
error=msg,
|
||||
status_code=401,
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"response_body": response_body,
|
||||
"www_authenticate": www_auth,
|
||||
"request_url": str(e.request.url),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
raise A2AClientHTTPError(401, msg) from e
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.response.status_code,
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"response_body": response_body,
|
||||
"request_url": str(e.request.url),
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="timeout",
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"timeout_config": timeout,
|
||||
"request_url": str(e.request.url) if e.request else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="connection_error",
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"request_url": str(e.request.url) if e.request else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
except httpx.RequestError as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="request_error",
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"request_url": str(e.request.url) if e.request else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _task_to_skill(task: Task) -> AgentSkill:
|
||||
"""Convert a CrewAI Task to an A2A AgentSkill.
|
||||
|
||||
Args:
|
||||
task: The CrewAI Task to convert.
|
||||
|
||||
Returns:
|
||||
AgentSkill representing the task's capability.
|
||||
"""
|
||||
task_name = task.name or task.description[:50]
|
||||
task_id = task_name.lower().replace(" ", "_")
|
||||
|
||||
tags: list[str] = []
|
||||
if task.agent:
|
||||
tags.append(task.agent.role.lower().replace(" ", "-"))
|
||||
|
||||
return AgentSkill(
|
||||
id=task_id,
|
||||
name=task_name,
|
||||
description=task.description,
|
||||
tags=tags,
|
||||
examples=[task.expected_output] if task.expected_output else None,
|
||||
)
|
||||
|
||||
|
||||
def _tool_to_skill(tool_name: str, tool_description: str) -> AgentSkill:
|
||||
"""Convert an Agent's tool to an A2A AgentSkill.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool.
|
||||
tool_description: Description of what the tool does.
|
||||
|
||||
Returns:
|
||||
AgentSkill representing the tool's capability.
|
||||
"""
|
||||
tool_id = tool_name.lower().replace(" ", "_")
|
||||
|
||||
return AgentSkill(
|
||||
id=tool_id,
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
tags=[tool_name.lower().replace(" ", "-")],
|
||||
)
|
||||
|
||||
|
||||
def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard:
|
||||
"""Generate an A2A AgentCard from a Crew instance.
|
||||
|
||||
Args:
|
||||
crew: The Crew instance to generate a card for.
|
||||
url: The base URL where this crew will be exposed.
|
||||
|
||||
Returns:
|
||||
AgentCard describing the crew's capabilities.
|
||||
"""
|
||||
crew_name = getattr(crew, "name", None) or crew.__class__.__name__
|
||||
|
||||
description_parts: list[str] = []
|
||||
crew_description = getattr(crew, "description", None)
|
||||
if crew_description:
|
||||
description_parts.append(crew_description)
|
||||
else:
|
||||
agent_roles = [agent.role for agent in crew.agents]
|
||||
description_parts.append(
|
||||
f"A crew of {len(crew.agents)} agents: {', '.join(agent_roles)}"
|
||||
)
|
||||
|
||||
skills = [_task_to_skill(task) for task in crew.tasks]
|
||||
|
||||
return AgentCard(
|
||||
name=crew_name,
|
||||
description=" ".join(description_parts),
|
||||
url=url,
|
||||
version="1.0.0",
|
||||
capabilities=AgentCapabilities(
|
||||
streaming=True,
|
||||
push_notifications=True,
|
||||
),
|
||||
default_input_modes=["text/plain", "application/json"],
|
||||
default_output_modes=["text/plain", "application/json"],
|
||||
skills=skills,
|
||||
)
|
||||
|
||||
|
||||
def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
|
||||
"""Generate an A2A AgentCard from an Agent instance.
|
||||
|
||||
Uses A2AServerConfig values when available, falling back to agent properties.
|
||||
If signing_config is provided, the card will be signed with JWS.
|
||||
|
||||
Args:
|
||||
agent: The Agent instance to generate a card for.
|
||||
url: The base URL where this agent will be exposed.
|
||||
|
||||
Returns:
|
||||
AgentCard describing the agent's capabilities.
|
||||
"""
|
||||
from crewai_a2a.utils.agent_card_signing import sign_agent_card
|
||||
|
||||
server_config = _get_server_config(agent) or A2AServerConfig()
|
||||
|
||||
name = server_config.name or agent.role
|
||||
|
||||
description_parts = [agent.goal]
|
||||
if agent.backstory:
|
||||
description_parts.append(agent.backstory)
|
||||
description = server_config.description or " ".join(description_parts)
|
||||
|
||||
skills: list[AgentSkill] = (
|
||||
server_config.skills.copy() if server_config.skills else []
|
||||
)
|
||||
|
||||
if not skills:
|
||||
if agent.tools:
|
||||
for tool in agent.tools:
|
||||
tool_name = getattr(tool, "name", None) or tool.__class__.__name__
|
||||
tool_desc = getattr(tool, "description", None) or f"Tool: {tool_name}"
|
||||
skills.append(_tool_to_skill(tool_name, tool_desc))
|
||||
|
||||
if not skills:
|
||||
skills.append(
|
||||
AgentSkill(
|
||||
id=agent.role.lower().replace(" ", "_"),
|
||||
name=agent.role,
|
||||
description=agent.goal,
|
||||
tags=[agent.role.lower().replace(" ", "-")],
|
||||
)
|
||||
)
|
||||
|
||||
capabilities = server_config.capabilities
|
||||
if server_config.server_extensions:
|
||||
from crewai_a2a.extensions.server import ServerExtensionRegistry
|
||||
|
||||
registry = ServerExtensionRegistry(server_config.server_extensions)
|
||||
ext_list = registry.get_agent_extensions()
|
||||
|
||||
existing_exts = list(capabilities.extensions) if capabilities.extensions else []
|
||||
existing_uris = {e.uri for e in existing_exts}
|
||||
for ext in ext_list:
|
||||
if ext.uri not in existing_uris:
|
||||
existing_exts.append(ext)
|
||||
|
||||
capabilities = capabilities.model_copy(update={"extensions": existing_exts})
|
||||
|
||||
card = AgentCard(
|
||||
name=name,
|
||||
description=description,
|
||||
url=server_config.url or url,
|
||||
version=server_config.version,
|
||||
capabilities=capabilities,
|
||||
default_input_modes=server_config.default_input_modes,
|
||||
default_output_modes=server_config.default_output_modes,
|
||||
skills=skills,
|
||||
preferred_transport=server_config.transport.preferred,
|
||||
protocol_version=server_config.protocol_version,
|
||||
provider=server_config.provider,
|
||||
documentation_url=server_config.documentation_url,
|
||||
icon_url=server_config.icon_url,
|
||||
additional_interfaces=server_config.additional_interfaces,
|
||||
security=server_config.security,
|
||||
security_schemes=server_config.security_schemes,
|
||||
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
|
||||
)
|
||||
|
||||
if server_config.signing_config:
|
||||
signature = sign_agent_card(
|
||||
card,
|
||||
private_key=server_config.signing_config.get_private_key(),
|
||||
key_id=server_config.signing_config.key_id,
|
||||
algorithm=server_config.signing_config.algorithm,
|
||||
)
|
||||
card = card.model_copy(update={"signatures": [signature]})
|
||||
elif server_config.signatures:
|
||||
card = card.model_copy(update={"signatures": server_config.signatures})
|
||||
|
||||
return card
|
||||
|
||||
|
||||
def inject_a2a_server_methods(agent: Agent) -> None:
|
||||
"""Inject A2A server methods onto an Agent instance.
|
||||
|
||||
Adds a `to_agent_card(url: str) -> AgentCard` method to the agent
|
||||
that generates an A2A-compliant AgentCard.
|
||||
|
||||
Only injects if the agent has an A2AServerConfig.
|
||||
|
||||
Args:
|
||||
agent: The Agent instance to inject methods onto.
|
||||
"""
|
||||
if _get_server_config(agent) is None:
|
||||
return
|
||||
|
||||
def _to_agent_card(self: Agent, url: str) -> AgentCard:
|
||||
return _agent_to_agent_card(self, url)
|
||||
|
||||
object.__setattr__(agent, "to_agent_card", MethodType(_to_agent_card, agent))
|
||||
236
lib/crewai-a2a/src/crewai_a2a/utils/agent_card_signing.py
Normal file
236
lib/crewai-a2a/src/crewai_a2a/utils/agent_card_signing.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""AgentCard JWS signing utilities.
|
||||
|
||||
This module provides functions for signing and verifying AgentCards using
|
||||
JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients
|
||||
to verify the authenticity and integrity of agent card information.
|
||||
|
||||
Example:
|
||||
>>> from crewai_a2a.utils.agent_card_signing import sign_agent_card
|
||||
>>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1")
|
||||
>>> card_with_sig = card.model_copy(update={"signatures": [signature]})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from a2a.types import AgentCard, AgentCardSignature
|
||||
import jwt
|
||||
from pydantic import SecretStr
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SigningAlgorithm = Literal[
|
||||
"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512"
|
||||
]
|
||||
|
||||
|
||||
def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes:
|
||||
"""Normalize private key to bytes format.
|
||||
|
||||
Args:
|
||||
private_key: PEM-encoded private key as string, bytes, or SecretStr.
|
||||
|
||||
Returns:
|
||||
Private key as bytes.
|
||||
"""
|
||||
if isinstance(private_key, SecretStr):
|
||||
private_key = private_key.get_secret_value()
|
||||
if isinstance(private_key, str):
|
||||
private_key = private_key.encode()
|
||||
return private_key
|
||||
|
||||
|
||||
def _serialize_agent_card(agent_card: AgentCard) -> str:
|
||||
"""Serialize AgentCard to canonical JSON for signing.
|
||||
|
||||
Excludes the signatures field to avoid circular reference during signing.
|
||||
Uses sorted keys and compact separators for deterministic output.
|
||||
|
||||
Args:
|
||||
agent_card: The AgentCard to serialize.
|
||||
|
||||
Returns:
|
||||
Canonical JSON string representation.
|
||||
"""
|
||||
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
|
||||
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _base64url_encode(data: bytes | str) -> str:
|
||||
"""Encode data to URL-safe base64 without padding.
|
||||
|
||||
Args:
|
||||
data: Data to encode.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded string without padding.
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode()
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def sign_agent_card(
|
||||
agent_card: AgentCard,
|
||||
private_key: str | bytes | SecretStr,
|
||||
key_id: str | None = None,
|
||||
algorithm: SigningAlgorithm = "RS256",
|
||||
) -> AgentCardSignature:
|
||||
"""Sign an AgentCard using JWS (RFC 7515).
|
||||
|
||||
Creates a detached JWS signature for the AgentCard. The signature covers
|
||||
all fields except the signatures field itself.
|
||||
|
||||
Args:
|
||||
agent_card: The AgentCard to sign.
|
||||
private_key: PEM-encoded private key (RSA, EC, or RSA-PSS).
|
||||
key_id: Optional key identifier for the JWS header (kid claim).
|
||||
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
|
||||
|
||||
Returns:
|
||||
AgentCardSignature with protected header and signature.
|
||||
|
||||
Raises:
|
||||
jwt.exceptions.InvalidKeyError: If the private key is invalid.
|
||||
ValueError: If the algorithm is not supported for the key type.
|
||||
|
||||
Example:
|
||||
>>> signature = sign_agent_card(
|
||||
... agent_card,
|
||||
... private_key_pem="-----BEGIN PRIVATE KEY-----...",
|
||||
... key_id="my-key-id",
|
||||
... )
|
||||
"""
|
||||
key_bytes = _normalize_private_key(private_key)
|
||||
payload = _serialize_agent_card(agent_card)
|
||||
|
||||
protected_header: dict[str, Any] = {"typ": "JWS"}
|
||||
if key_id:
|
||||
protected_header["kid"] = key_id
|
||||
|
||||
jws_token = jwt.api_jws.encode(
|
||||
payload.encode(),
|
||||
key_bytes,
|
||||
algorithm=algorithm,
|
||||
headers=protected_header,
|
||||
)
|
||||
|
||||
parts = jws_token.split(".")
|
||||
protected_b64 = parts[0]
|
||||
signature_b64 = parts[2]
|
||||
|
||||
header: dict[str, Any] | None = None
|
||||
if key_id:
|
||||
header = {"kid": key_id}
|
||||
|
||||
return AgentCardSignature(
|
||||
protected=protected_b64,
|
||||
signature=signature_b64,
|
||||
header=header,
|
||||
)
|
||||
|
||||
|
||||
def verify_agent_card_signature(
|
||||
agent_card: AgentCard,
|
||||
signature: AgentCardSignature,
|
||||
public_key: str | bytes,
|
||||
algorithms: list[str] | None = None,
|
||||
) -> bool:
|
||||
"""Verify an AgentCard JWS signature.
|
||||
|
||||
Validates that the signature was created with the corresponding private key
|
||||
and that the AgentCard content has not been modified.
|
||||
|
||||
Args:
|
||||
agent_card: The AgentCard to verify.
|
||||
signature: The AgentCardSignature to validate.
|
||||
public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
|
||||
algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms.
|
||||
|
||||
Returns:
|
||||
True if signature is valid, False otherwise.
|
||||
|
||||
Example:
|
||||
>>> is_valid = verify_agent_card_signature(
|
||||
... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..."
|
||||
... )
|
||||
"""
|
||||
if algorithms is None:
|
||||
algorithms = [
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
if isinstance(public_key, str):
|
||||
public_key = public_key.encode()
|
||||
|
||||
payload = _serialize_agent_card(agent_card)
|
||||
payload_b64 = _base64url_encode(payload)
|
||||
jws_token = f"{signature.protected}.{payload_b64}.{signature.signature}"
|
||||
|
||||
try:
|
||||
jwt.api_jws.decode(
|
||||
jws_token,
|
||||
public_key,
|
||||
algorithms=algorithms,
|
||||
)
|
||||
return True
|
||||
except jwt.InvalidSignatureError:
|
||||
logger.debug(
|
||||
"AgentCard signature verification failed",
|
||||
extra={"reason": "invalid_signature"},
|
||||
)
|
||||
return False
|
||||
except jwt.DecodeError as e:
|
||||
logger.debug(
|
||||
"AgentCard signature verification failed",
|
||||
extra={"reason": "decode_error", "error": str(e)},
|
||||
)
|
||||
return False
|
||||
except jwt.InvalidAlgorithmError as e:
|
||||
logger.debug(
|
||||
"AgentCard signature verification failed",
|
||||
extra={"reason": "algorithm_error", "error": str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def get_key_id_from_signature(signature: AgentCardSignature) -> str | None:
|
||||
"""Extract the key ID (kid) from an AgentCardSignature.
|
||||
|
||||
Checks both the unprotected header and the protected header for the kid claim.
|
||||
|
||||
Args:
|
||||
signature: The AgentCardSignature to extract from.
|
||||
|
||||
Returns:
|
||||
The key ID if present, None otherwise.
|
||||
"""
|
||||
if signature.header and "kid" in signature.header:
|
||||
kid: str = signature.header["kid"]
|
||||
return kid
|
||||
|
||||
try:
|
||||
protected = signature.protected
|
||||
padding_needed = 4 - (len(protected) % 4)
|
||||
if padding_needed != 4:
|
||||
protected += "=" * padding_needed
|
||||
|
||||
protected_json = base64.urlsafe_b64decode(protected).decode()
|
||||
protected_header: dict[str, Any] = json.loads(protected_json)
|
||||
return protected_header.get("kid")
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
return None
|
||||
338
lib/crewai-a2a/src/crewai_a2a/utils/content_type.py
Normal file
338
lib/crewai-a2a/src/crewai_a2a/utils/content_type.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""Content type negotiation for A2A protocol.
|
||||
|
||||
This module handles negotiation of input/output MIME types between A2A clients
|
||||
and servers based on AgentCard capabilities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Annotated, Final, Literal, cast
|
||||
|
||||
from a2a.types import Part
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard, AgentSkill
|
||||
|
||||
|
||||
TEXT_PLAIN: Literal["text/plain"] = "text/plain"
|
||||
APPLICATION_JSON: Literal["application/json"] = "application/json"
|
||||
IMAGE_PNG: Literal["image/png"] = "image/png"
|
||||
IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg"
|
||||
IMAGE_WILDCARD: Literal["image/*"] = "image/*"
|
||||
APPLICATION_PDF: Literal["application/pdf"] = "application/pdf"
|
||||
APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = (
|
||||
"application/octet-stream"
|
||||
)
|
||||
|
||||
DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
|
||||
TEXT_PLAIN,
|
||||
APPLICATION_JSON,
|
||||
]
|
||||
DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
|
||||
TEXT_PLAIN,
|
||||
APPLICATION_JSON,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class NegotiatedContentTypes:
|
||||
"""Result of content type negotiation."""
|
||||
|
||||
input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"]
|
||||
output_modes: Annotated[
|
||||
list[str], "Negotiated output MIME types the server will produce"
|
||||
]
|
||||
effective_input_modes: Annotated[list[str], "Server's effective input modes"]
|
||||
effective_output_modes: Annotated[list[str], "Server's effective output modes"]
|
||||
skill_name: Annotated[
|
||||
str | None, "Skill name if negotiation was skill-specific"
|
||||
] = None
|
||||
|
||||
|
||||
class ContentTypeNegotiationError(Exception):
|
||||
"""Raised when no compatible content types can be negotiated."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_input_modes: list[str],
|
||||
client_output_modes: list[str],
|
||||
server_input_modes: list[str],
|
||||
server_output_modes: list[str],
|
||||
direction: str = "both",
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
self.client_input_modes = client_input_modes
|
||||
self.client_output_modes = client_output_modes
|
||||
self.server_input_modes = server_input_modes
|
||||
self.server_output_modes = server_output_modes
|
||||
self.direction = direction
|
||||
|
||||
if message is None:
|
||||
if direction == "input":
|
||||
message = (
|
||||
f"No compatible input content types. "
|
||||
f"Client supports: {client_input_modes}, "
|
||||
f"Server accepts: {server_input_modes}"
|
||||
)
|
||||
elif direction == "output":
|
||||
message = (
|
||||
f"No compatible output content types. "
|
||||
f"Client accepts: {client_output_modes}, "
|
||||
f"Server produces: {server_output_modes}"
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"No compatible content types. "
|
||||
f"Input - Client: {client_input_modes}, Server: {server_input_modes}. "
|
||||
f"Output - Client: {client_output_modes}, Server: {server_output_modes}"
|
||||
)
|
||||
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def _normalize_mime_type(mime_type: str) -> str:
|
||||
"""Normalize MIME type for comparison (lowercase, strip whitespace)."""
|
||||
return mime_type.lower().strip()
|
||||
|
||||
|
||||
def _mime_types_compatible(client_type: str, server_type: str) -> bool:
|
||||
"""Check if two MIME types are compatible.
|
||||
|
||||
Handles wildcards like image/* matching image/png.
|
||||
"""
|
||||
client_normalized = _normalize_mime_type(client_type)
|
||||
server_normalized = _normalize_mime_type(server_type)
|
||||
|
||||
if client_normalized == server_normalized:
|
||||
return True
|
||||
|
||||
if "*" in client_normalized or "*" in server_normalized:
|
||||
client_parts = client_normalized.split("/")
|
||||
server_parts = server_normalized.split("/")
|
||||
|
||||
if len(client_parts) == 2 and len(server_parts) == 2:
|
||||
type_match = (
|
||||
client_parts[0] == server_parts[0]
|
||||
or client_parts[0] == "*"
|
||||
or server_parts[0] == "*"
|
||||
)
|
||||
subtype_match = (
|
||||
client_parts[1] == server_parts[1]
|
||||
or client_parts[1] == "*"
|
||||
or server_parts[1] == "*"
|
||||
)
|
||||
return type_match and subtype_match
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _find_compatible_modes(
|
||||
client_modes: list[str], server_modes: list[str]
|
||||
) -> list[str]:
|
||||
"""Find compatible MIME types between client and server.
|
||||
|
||||
Returns modes in client preference order.
|
||||
"""
|
||||
compatible = []
|
||||
for client_mode in client_modes:
|
||||
for server_mode in server_modes:
|
||||
if _mime_types_compatible(client_mode, server_mode):
|
||||
if "*" in client_mode and "*" not in server_mode:
|
||||
if server_mode not in compatible:
|
||||
compatible.append(server_mode)
|
||||
else:
|
||||
if client_mode not in compatible:
|
||||
compatible.append(client_mode)
|
||||
break
|
||||
return compatible
|
||||
|
||||
|
||||
def _get_effective_modes(
|
||||
agent_card: AgentCard,
|
||||
skill_name: str | None = None,
|
||||
) -> tuple[list[str], list[str], AgentSkill | None]:
|
||||
"""Get effective input/output modes from agent card.
|
||||
|
||||
If skill_name is provided and the skill has custom modes, those are used.
|
||||
Otherwise, falls back to agent card defaults.
|
||||
"""
|
||||
skill: AgentSkill | None = None
|
||||
|
||||
if skill_name and agent_card.skills:
|
||||
for s in agent_card.skills:
|
||||
if s.name == skill_name or s.id == skill_name:
|
||||
skill = s
|
||||
break
|
||||
|
||||
if skill:
|
||||
input_modes = (
|
||||
skill.input_modes if skill.input_modes else agent_card.default_input_modes
|
||||
)
|
||||
output_modes = (
|
||||
skill.output_modes
|
||||
if skill.output_modes
|
||||
else agent_card.default_output_modes
|
||||
)
|
||||
else:
|
||||
input_modes = agent_card.default_input_modes
|
||||
output_modes = agent_card.default_output_modes
|
||||
|
||||
return input_modes, output_modes, skill
|
||||
|
||||
|
||||
def negotiate_content_types(
|
||||
agent_card: AgentCard,
|
||||
client_input_modes: list[str] | None = None,
|
||||
client_output_modes: list[str] | None = None,
|
||||
skill_name: str | None = None,
|
||||
emit_event: bool = True,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
strict: bool = False,
|
||||
) -> NegotiatedContentTypes:
|
||||
"""Negotiate content types between client and server.
|
||||
|
||||
Args:
|
||||
agent_card: The remote agent's card with capability info.
|
||||
client_input_modes: MIME types the client can send. Defaults to text/plain and application/json.
|
||||
client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json.
|
||||
skill_name: Optional skill to use for mode lookup.
|
||||
emit_event: Whether to emit a content type negotiation event.
|
||||
endpoint: Agent endpoint (for event metadata).
|
||||
a2a_agent_name: Agent name (for event metadata).
|
||||
strict: If True, raises error when no compatible types found.
|
||||
If False, returns empty lists for incompatible directions.
|
||||
|
||||
Returns:
|
||||
NegotiatedContentTypes with compatible input and output modes.
|
||||
|
||||
Raises:
|
||||
ContentTypeNegotiationError: If strict=True and no compatible types found.
|
||||
"""
|
||||
if client_input_modes is None:
|
||||
client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy())
|
||||
if client_output_modes is None:
|
||||
client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy())
|
||||
|
||||
server_input_modes, server_output_modes, skill = _get_effective_modes(
|
||||
agent_card, skill_name
|
||||
)
|
||||
|
||||
compatible_input = _find_compatible_modes(client_input_modes, server_input_modes)
|
||||
compatible_output = _find_compatible_modes(client_output_modes, server_output_modes)
|
||||
|
||||
if strict:
|
||||
if not compatible_input and not compatible_output:
|
||||
raise ContentTypeNegotiationError(
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
)
|
||||
if not compatible_input:
|
||||
raise ContentTypeNegotiationError(
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
direction="input",
|
||||
)
|
||||
if not compatible_output:
|
||||
raise ContentTypeNegotiationError(
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
direction="output",
|
||||
)
|
||||
|
||||
result = NegotiatedContentTypes(
|
||||
input_modes=compatible_input,
|
||||
output_modes=compatible_output,
|
||||
effective_input_modes=server_input_modes,
|
||||
effective_output_modes=server_output_modes,
|
||||
skill_name=skill.name if skill else None,
|
||||
)
|
||||
|
||||
if emit_event:
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AContentTypeNegotiatedEvent(
|
||||
endpoint=endpoint or agent_card.url,
|
||||
a2a_agent_name=a2a_agent_name or agent_card.name,
|
||||
skill_name=skill_name,
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
negotiated_input_modes=compatible_input,
|
||||
negotiated_output_modes=compatible_output,
|
||||
negotiation_success=bool(compatible_input and compatible_output),
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def validate_content_type(
|
||||
content_type: str,
|
||||
allowed_modes: list[str],
|
||||
) -> bool:
|
||||
"""Validate that a content type is allowed by a list of modes.
|
||||
|
||||
Args:
|
||||
content_type: The MIME type to validate.
|
||||
allowed_modes: List of allowed MIME types (may include wildcards).
|
||||
|
||||
Returns:
|
||||
True if content_type is compatible with any allowed mode.
|
||||
"""
|
||||
for mode in allowed_modes:
|
||||
if _mime_types_compatible(content_type, mode):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_part_content_type(part: Part) -> str:
|
||||
"""Extract MIME type from an A2A Part.
|
||||
|
||||
Args:
|
||||
part: A Part object containing TextPart, DataPart, or FilePart.
|
||||
|
||||
Returns:
|
||||
The MIME type string for this part.
|
||||
"""
|
||||
root = part.root
|
||||
if root.kind == "text":
|
||||
return TEXT_PLAIN
|
||||
if root.kind == "data":
|
||||
return APPLICATION_JSON
|
||||
if root.kind == "file":
|
||||
return root.file.mime_type or APPLICATION_OCTET_STREAM
|
||||
return APPLICATION_OCTET_STREAM
|
||||
|
||||
|
||||
def validate_message_parts(
|
||||
parts: list[Part],
|
||||
allowed_modes: list[str],
|
||||
) -> list[str]:
|
||||
"""Validate that all message parts have allowed content types.
|
||||
|
||||
Args:
|
||||
parts: List of Parts from the incoming message.
|
||||
allowed_modes: List of allowed MIME types (from default_input_modes).
|
||||
|
||||
Returns:
|
||||
List of invalid content types found (empty if all valid).
|
||||
"""
|
||||
invalid_types: list[str] = []
|
||||
for part in parts:
|
||||
content_type = get_part_content_type(part)
|
||||
if not validate_content_type(content_type, allowed_modes):
|
||||
if content_type not in invalid_types:
|
||||
invalid_types.append(content_type)
|
||||
return invalid_types
|
||||
980
lib/crewai-a2a/src/crewai_a2a/utils/delegation.py
Normal file
980
lib/crewai-a2a/src/crewai_a2a/utils/delegation.py
Normal file
@@ -0,0 +1,980 @@
|
||||
"""A2A delegation utilities for executing tasks on remote agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import AsyncIterator, Callable, MutableMapping
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client, ClientConfig, ClientFactory
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
FilePart,
|
||||
FileWithBytes,
|
||||
Message,
|
||||
Part,
|
||||
PushNotificationConfig as A2APushNotificationConfig,
|
||||
Role,
|
||||
TextPart,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConversationStartedEvent,
|
||||
A2ADelegationCompletedEvent,
|
||||
A2ADelegationStartedEvent,
|
||||
A2AMessageSentEvent,
|
||||
)
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai_a2a.auth.utils import (
|
||||
_auth_store,
|
||||
configure_auth_client,
|
||||
validate_auth_against_agent_card,
|
||||
)
|
||||
from crewai_a2a.config import ClientTransportConfig, GRPCClientConfig
|
||||
from crewai_a2a.extensions.registry import (
|
||||
ExtensionsMiddleware,
|
||||
validate_required_extensions,
|
||||
)
|
||||
from crewai_a2a.task_helpers import TaskStateResult
|
||||
from crewai_a2a.types import (
|
||||
HANDLER_REGISTRY,
|
||||
HandlerType,
|
||||
PartsDict,
|
||||
PartsMetadataDict,
|
||||
TransportType,
|
||||
)
|
||||
from crewai_a2a.updates import (
|
||||
PollingConfig,
|
||||
PushNotificationConfig,
|
||||
StreamingHandler,
|
||||
UpdateConfig,
|
||||
)
|
||||
from crewai_a2a.utils.agent_card import (
|
||||
_afetch_agent_card_cached,
|
||||
_get_tls_verify,
|
||||
_prepare_auth_headers,
|
||||
)
|
||||
from crewai_a2a.utils.content_type import (
|
||||
DEFAULT_CLIENT_OUTPUT_MODES,
|
||||
negotiate_content_types,
|
||||
)
|
||||
from crewai_a2a.utils.transport import (
|
||||
NegotiatedTransport,
|
||||
TransportNegotiationError,
|
||||
negotiate_transport,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Message
|
||||
|
||||
from crewai_a2a.auth.client_schemes import ClientAuthScheme
|
||||
|
||||
|
||||
_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC"
|
||||
|
||||
|
||||
def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]:
|
||||
"""Convert FileInput dictionary to FilePart objects.
|
||||
|
||||
Args:
|
||||
input_files: Dictionary mapping names to FileInput objects.
|
||||
|
||||
Returns:
|
||||
List of Part objects containing FilePart data.
|
||||
"""
|
||||
if not input_files:
|
||||
return []
|
||||
|
||||
try:
|
||||
import crewai_files # noqa: F401
|
||||
except ImportError:
|
||||
logger.debug("crewai_files not installed, skipping file parts")
|
||||
return []
|
||||
|
||||
parts: list[Part] = []
|
||||
for name, file_input in input_files.items():
|
||||
content_bytes = file_input.read()
|
||||
content_base64 = base64.b64encode(content_bytes).decode()
|
||||
file_with_bytes = FileWithBytes(
|
||||
bytes=content_base64,
|
||||
mimeType=file_input.content_type,
|
||||
name=file_input.filename or name,
|
||||
)
|
||||
parts.append(Part(root=FilePart(file=file_with_bytes)))
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def get_handler(config: UpdateConfig | None) -> HandlerType:
|
||||
"""Get the handler class for a given update config.
|
||||
|
||||
Args:
|
||||
config: Update mechanism configuration.
|
||||
|
||||
Returns:
|
||||
Handler class for the config type, defaults to StreamingHandler.
|
||||
"""
|
||||
if config is None:
|
||||
return StreamingHandler
|
||||
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
|
||||
|
||||
|
||||
def execute_a2a_delegation(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None = None,
|
||||
context_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
reference_task_ids: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
conversation_history: list[Message] | None = None,
|
||||
agent_id: str | None = None,
|
||||
agent_role: Role | None = None,
|
||||
agent_branch: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
turn_number: int | None = None,
|
||||
updates: UpdateConfig | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
transport: ClientTransportConfig | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
input_files: dict[str, Any] | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Execute a task delegation to a remote A2A agent synchronously.
|
||||
|
||||
WARNING: This function blocks the entire thread by creating and running a new
|
||||
event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts
|
||||
for better performance and resource efficiency.
|
||||
|
||||
This is a synchronous wrapper around aexecute_a2a_delegation that creates a
|
||||
new event loop to run the async implementation. It is provided for compatibility
|
||||
with synchronous code paths only.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
task_description: The task to delegate.
|
||||
context: Optional context information.
|
||||
context_id: Context ID for correlating messages/tasks.
|
||||
task_id: Specific task identifier.
|
||||
reference_task_ids: List of related task IDs.
|
||||
metadata: Additional metadata.
|
||||
extensions: Protocol extensions for custom fields.
|
||||
conversation_history: Previous Message objects from conversation.
|
||||
agent_id: Agent identifier for logging.
|
||||
agent_role: Role of the CrewAI agent delegating the task.
|
||||
agent_branch: Optional agent tree branch for logging.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
turn_number: Optional turn number for multi-turn conversations.
|
||||
updates: Update mechanism config from A2AConfig.updates.
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
skill_id: Optional skill ID to target a specific agent capability.
|
||||
client_extensions: A2A protocol extension URIs the client supports.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
accepted_output_modes: MIME types the client can accept in responses.
|
||||
input_files: Optional dictionary of files to send to remote agent.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with status, result/error, history, and agent_card.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If called from an async context with a running event loop.
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
raise RuntimeError(
|
||||
"execute_a2a_delegation() cannot be called from an async context. "
|
||||
"Use 'await aexecute_a2a_delegation()' instead."
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "no running event loop" not in str(e).lower():
|
||||
raise
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
aexecute_a2a_delegation(
|
||||
endpoint=endpoint,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
task_description=task_description,
|
||||
context=context,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=extensions,
|
||||
conversation_history=conversation_history,
|
||||
agent_id=agent_id,
|
||||
agent_role=agent_role,
|
||||
agent_branch=agent_branch,
|
||||
response_model=response_model,
|
||||
turn_number=turn_number,
|
||||
updates=updates,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
skill_id=skill_id,
|
||||
client_extensions=client_extensions,
|
||||
transport=transport,
|
||||
accepted_output_modes=accepted_output_modes,
|
||||
input_files=input_files,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def aexecute_a2a_delegation(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None = None,
|
||||
context_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
reference_task_ids: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
conversation_history: list[Message] | None = None,
|
||||
agent_id: str | None = None,
|
||||
agent_role: Role | None = None,
|
||||
agent_branch: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
turn_number: int | None = None,
|
||||
updates: UpdateConfig | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
transport: ClientTransportConfig | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
input_files: dict[str, Any] | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Execute a task delegation to a remote A2A agent asynchronously.
|
||||
|
||||
Native async implementation with multi-turn support. Use this when running
|
||||
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
task_description: The task to delegate.
|
||||
context: Optional context information.
|
||||
context_id: Context ID for correlating messages/tasks.
|
||||
task_id: Specific task identifier.
|
||||
reference_task_ids: List of related task IDs.
|
||||
metadata: Additional metadata.
|
||||
extensions: Protocol extensions for custom fields.
|
||||
conversation_history: Previous Message objects from conversation.
|
||||
agent_id: Agent identifier for logging.
|
||||
agent_role: Role of the CrewAI agent delegating the task.
|
||||
agent_branch: Optional agent tree branch for logging.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
turn_number: Optional turn number for multi-turn conversations.
|
||||
updates: Update mechanism config from A2AConfig.updates.
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
skill_id: Optional skill ID to target a specific agent capability.
|
||||
client_extensions: A2A protocol extension URIs the client supports.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
accepted_output_modes: MIME types the client can accept in responses.
|
||||
input_files: Optional dictionary of files to send to remote agent.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with status, result/error, history, and agent_card.
|
||||
"""
|
||||
if conversation_history is None:
|
||||
conversation_history = []
|
||||
|
||||
is_multiturn = len(conversation_history) > 0
|
||||
if turn_number is None:
|
||||
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
|
||||
|
||||
try:
|
||||
result = await _aexecute_a2a_delegation_impl(
|
||||
endpoint=endpoint,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
task_description=task_description,
|
||||
context=context,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=extensions,
|
||||
conversation_history=conversation_history,
|
||||
is_multiturn=is_multiturn,
|
||||
turn_number=turn_number,
|
||||
agent_branch=agent_branch,
|
||||
agent_id=agent_id,
|
||||
agent_role=agent_role,
|
||||
response_model=response_model,
|
||||
updates=updates,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
skill_id=skill_id,
|
||||
client_extensions=client_extensions,
|
||||
transport=transport,
|
||||
accepted_output_modes=accepted_output_modes,
|
||||
input_files=input_files,
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationCompletedEvent(
|
||||
status="failed",
|
||||
result=None,
|
||||
error=str(e),
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
endpoint=endpoint,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
agent_card_data = result.get("agent_card")
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationCompletedEvent(
|
||||
status=result["status"],
|
||||
result=result.get("result"),
|
||||
error=result.get("error"),
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=result.get("a2a_agent_name"),
|
||||
agent_card=agent_card_data,
|
||||
provider=agent_card_data.get("provider") if agent_card_data else None,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _aexecute_a2a_delegation_impl(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None,
|
||||
context_id: str | None,
|
||||
task_id: str | None,
|
||||
reference_task_ids: list[str] | None,
|
||||
metadata: dict[str, Any] | None,
|
||||
extensions: dict[str, Any] | None,
|
||||
conversation_history: list[Message],
|
||||
is_multiturn: bool,
|
||||
turn_number: int,
|
||||
agent_branch: Any | None,
|
||||
agent_id: str | None,
|
||||
agent_role: str | None,
|
||||
response_model: type[BaseModel] | None,
|
||||
updates: UpdateConfig | None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
transport: ClientTransportConfig | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
input_files: dict[str, Any] | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Internal async implementation of A2A delegation."""
|
||||
if transport is None:
|
||||
transport = ClientTransportConfig()
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
"_access_token",
|
||||
"_token_expires_at",
|
||||
"_refresh_token",
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", endpoint)
|
||||
_auth_store.set(auth_hash, auth)
|
||||
agent_card = await _afetch_agent_card_cached(
|
||||
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
|
||||
)
|
||||
|
||||
validate_auth_against_agent_card(agent_card, auth)
|
||||
|
||||
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
|
||||
if unsupported_exts:
|
||||
ext_uris = [ext.uri for ext in unsupported_exts]
|
||||
raise ValueError(
|
||||
f"Agent requires extensions not supported by client: {ext_uris}"
|
||||
)
|
||||
|
||||
negotiated: NegotiatedTransport | None = None
|
||||
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
|
||||
effective_url = endpoint
|
||||
|
||||
client_transports: list[str] = (
|
||||
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
|
||||
)
|
||||
|
||||
try:
|
||||
negotiated = negotiate_transport(
|
||||
agent_card=agent_card,
|
||||
client_supported_transports=client_transports,
|
||||
client_preferred_transport=transport.preferred,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
)
|
||||
effective_transport = negotiated.transport # type: ignore[assignment]
|
||||
effective_url = negotiated.url
|
||||
except TransportNegotiationError as e:
|
||||
logger.warning(
|
||||
"Transport negotiation failed, using fallback",
|
||||
extra={
|
||||
"error": str(e),
|
||||
"fallback_transport": effective_transport,
|
||||
"fallback_url": effective_url,
|
||||
"endpoint": endpoint,
|
||||
"client_transports": client_transports,
|
||||
"server_transports": [
|
||||
iface.transport for iface in agent_card.additional_interfaces or []
|
||||
]
|
||||
+ [agent_card.preferred_transport or "JSONRPC"],
|
||||
},
|
||||
)
|
||||
|
||||
effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy()
|
||||
|
||||
content_negotiated = negotiate_content_types(
|
||||
agent_card=agent_card,
|
||||
client_output_modes=accepted_output_modes,
|
||||
skill_name=skill_id,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
)
|
||||
if content_negotiated.output_modes:
|
||||
effective_output_modes = content_negotiated.output_modes
|
||||
|
||||
headers, _ = await _prepare_auth_headers(auth, timeout)
|
||||
|
||||
a2a_agent_name = None
|
||||
if agent_card.name:
|
||||
a2a_agent_name = agent_card.name
|
||||
|
||||
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationStartedEvent(
|
||||
endpoint=endpoint,
|
||||
task_description=task_description,
|
||||
agent_id=agent_id or endpoint,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
turn_number=turn_number,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
provider=agent_card_dict.get("provider"),
|
||||
skill_id=skill_id,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if turn_number == 1:
|
||||
agent_id_for_event = agent_id or endpoint
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConversationStartedEvent(
|
||||
agent_id=agent_id_for_event,
|
||||
endpoint=endpoint,
|
||||
context_id=context_id,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
provider=agent_card_dict.get("provider"),
|
||||
skill_id=skill_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
message_parts = []
|
||||
|
||||
if context:
|
||||
message_parts.append(f"Context:\n{context}\n\n")
|
||||
message_parts.append(f"{task_description}")
|
||||
message_text = "".join(message_parts)
|
||||
|
||||
if is_multiturn and conversation_history and not task_id:
|
||||
if first_task_id := conversation_history[0].task_id:
|
||||
task_id = first_task_id
|
||||
|
||||
parts: PartsDict = {"text": message_text}
|
||||
if response_model:
|
||||
parts.update(
|
||||
{
|
||||
"metadata": PartsMetadataDict(
|
||||
mimeType="application/json",
|
||||
schema=response_model.model_json_schema(),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
message_metadata = metadata.copy() if metadata else {}
|
||||
if skill_id:
|
||||
message_metadata["skill_id"] = skill_id
|
||||
|
||||
parts_list: list[Part] = [Part(root=TextPart(**parts))]
|
||||
parts_list.extend(_create_file_parts(input_files))
|
||||
|
||||
message = Message(
|
||||
role=Role.user,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=parts_list,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=message_metadata if message_metadata else None,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
new_messages: list[Message] = [*conversation_history, message]
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AMessageSentEvent(
|
||||
message=message_text,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
message_id=message.message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
skill_id=skill_id,
|
||||
metadata=message_metadata if message_metadata else None,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
handler = get_handler(updates)
|
||||
use_polling = isinstance(updates, PollingConfig)
|
||||
|
||||
handler_kwargs: dict[str, Any] = {
|
||||
"turn_number": turn_number,
|
||||
"is_multiturn": is_multiturn,
|
||||
"agent_role": agent_role,
|
||||
"context_id": context_id,
|
||||
"task_id": task_id,
|
||||
"endpoint": endpoint,
|
||||
"agent_branch": agent_branch,
|
||||
"a2a_agent_name": a2a_agent_name,
|
||||
"from_task": from_task,
|
||||
"from_agent": from_agent,
|
||||
}
|
||||
|
||||
if isinstance(updates, PollingConfig):
|
||||
handler_kwargs.update(
|
||||
{
|
||||
"polling_interval": updates.interval,
|
||||
"polling_timeout": updates.timeout or float(timeout),
|
||||
"history_length": updates.history_length,
|
||||
"max_polls": updates.max_polls,
|
||||
}
|
||||
)
|
||||
elif isinstance(updates, PushNotificationConfig):
|
||||
handler_kwargs.update(
|
||||
{
|
||||
"config": updates,
|
||||
"result_store": updates.result_store,
|
||||
"polling_timeout": updates.timeout or float(timeout),
|
||||
"polling_interval": updates.interval,
|
||||
}
|
||||
)
|
||||
|
||||
push_config_for_client = (
|
||||
updates if isinstance(updates, PushNotificationConfig) else None
|
||||
)
|
||||
|
||||
use_streaming = not use_polling and push_config_for_client is None
|
||||
|
||||
client_agent_card = agent_card
|
||||
if effective_url != agent_card.url:
|
||||
client_agent_card = agent_card.model_copy(update={"url": effective_url})
|
||||
|
||||
async with _create_a2a_client(
|
||||
agent_card=client_agent_card,
|
||||
transport_protocol=effective_transport,
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
streaming=use_streaming,
|
||||
auth=auth,
|
||||
use_polling=use_polling,
|
||||
push_notification_config=push_config_for_client,
|
||||
client_extensions=client_extensions,
|
||||
accepted_output_modes=effective_output_modes, # type: ignore[arg-type]
|
||||
grpc_config=transport.grpc,
|
||||
) as client:
|
||||
result = await handler.execute(
|
||||
client=client,
|
||||
message=message,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
**handler_kwargs,
|
||||
)
|
||||
result["a2a_agent_name"] = a2a_agent_name
|
||||
result["agent_card"] = agent_card.model_dump(exclude_none=True)
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_grpc_metadata(
|
||||
metadata: tuple[tuple[str, str], ...] | None,
|
||||
) -> tuple[tuple[str, str], ...] | None:
|
||||
"""Lowercase all gRPC metadata keys.
|
||||
|
||||
gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK)
|
||||
use mixed-case headers like 'X-A2A-Extensions'. This normalizes them.
|
||||
"""
|
||||
if metadata is None:
|
||||
return None
|
||||
return tuple((key.lower(), value) for key, value in metadata)
|
||||
|
||||
|
||||
def _create_grpc_interceptors(
|
||||
auth_metadata: list[tuple[str, str]] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Create gRPC interceptors for metadata normalization and auth injection.
|
||||
|
||||
Args:
|
||||
auth_metadata: Optional auth metadata to inject into all calls.
|
||||
Used for insecure channels that need auth (non-localhost without TLS).
|
||||
|
||||
Returns a list of interceptors that lowercase metadata keys for gRPC
|
||||
compatibility. Must be called after grpc is imported.
|
||||
"""
|
||||
import grpc.aio # type: ignore[import-untyped]
|
||||
|
||||
def _merge_metadata(
|
||||
existing: tuple[tuple[str, str], ...] | None,
|
||||
auth: list[tuple[str, str]] | None,
|
||||
) -> tuple[tuple[str, str], ...] | None:
|
||||
"""Merge existing metadata with auth metadata and normalize keys."""
|
||||
merged: list[tuple[str, str]] = []
|
||||
if existing:
|
||||
merged.extend(existing)
|
||||
if auth:
|
||||
merged.extend(auth)
|
||||
if not merged:
|
||||
return None
|
||||
return tuple((key.lower(), value) for key, value in merged)
|
||||
|
||||
def _inject_metadata(client_call_details: Any) -> Any:
|
||||
"""Inject merged metadata into call details."""
|
||||
return client_call_details._replace(
|
||||
metadata=_merge_metadata(client_call_details.metadata, auth_metadata)
|
||||
)
|
||||
|
||||
class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for unary-unary calls that injects auth metadata."""
|
||||
|
||||
async def intercept_unary_unary( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request
|
||||
):
|
||||
"""Intercept unary-unary call and inject metadata."""
|
||||
return await continuation(_inject_metadata(client_call_details), request)
|
||||
|
||||
class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for unary-stream calls that injects auth metadata."""
|
||||
|
||||
async def intercept_unary_stream( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request
|
||||
):
|
||||
"""Intercept unary-stream call and inject metadata."""
|
||||
return await continuation(_inject_metadata(client_call_details), request)
|
||||
|
||||
class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for stream-unary calls that injects auth metadata."""
|
||||
|
||||
async def intercept_stream_unary( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request_iterator
|
||||
):
|
||||
"""Intercept stream-unary call and inject metadata."""
|
||||
return await continuation(
|
||||
_inject_metadata(client_call_details), request_iterator
|
||||
)
|
||||
|
||||
class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for stream-stream calls that injects auth metadata."""
|
||||
|
||||
async def intercept_stream_stream( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request_iterator
|
||||
):
|
||||
"""Intercept stream-stream call and inject metadata."""
|
||||
return await continuation(
|
||||
_inject_metadata(client_call_details), request_iterator
|
||||
)
|
||||
|
||||
return [
|
||||
MetadataUnaryUnary(),
|
||||
MetadataUnaryStream(),
|
||||
MetadataStreamUnary(),
|
||||
MetadataStreamStream(),
|
||||
]
|
||||
|
||||
|
||||
def _create_grpc_channel_factory(
|
||||
grpc_config: GRPCClientConfig,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
) -> Callable[[str], Any]:
|
||||
"""Create a gRPC channel factory with the given configuration.
|
||||
|
||||
Args:
|
||||
grpc_config: gRPC client configuration with channel options.
|
||||
auth: Optional ClientAuthScheme for TLS and auth configuration.
|
||||
|
||||
Returns:
|
||||
A callable that creates gRPC channels from URLs.
|
||||
"""
|
||||
try:
|
||||
import grpc
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
|
||||
) from e
|
||||
|
||||
auth_metadata: list[tuple[str, str]] = []
|
||||
|
||||
if auth is not None:
|
||||
from crewai_a2a.auth.client_schemes import (
|
||||
APIKeyAuth,
|
||||
BearerTokenAuth,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
)
|
||||
|
||||
if isinstance(auth, HTTPDigestAuth):
|
||||
raise ValueError(
|
||||
"HTTPDigestAuth is not supported with gRPC transport. "
|
||||
"Digest authentication requires HTTP challenge-response flow. "
|
||||
"Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
|
||||
)
|
||||
if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
|
||||
raise ValueError(
|
||||
f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
|
||||
"gRPC only supports header-based authentication. "
|
||||
"Use APIKeyAuth with location='header' instead."
|
||||
)
|
||||
|
||||
if isinstance(auth, BearerTokenAuth):
|
||||
auth_metadata.append(("authorization", f"Bearer {auth.token}"))
|
||||
elif isinstance(auth, HTTPBasicAuth):
|
||||
import base64
|
||||
|
||||
basic_credentials = f"{auth.username}:{auth.password}"
|
||||
encoded = base64.b64encode(basic_credentials.encode()).decode()
|
||||
auth_metadata.append(("authorization", f"Basic {encoded}"))
|
||||
elif isinstance(auth, APIKeyAuth) and auth.location == "header":
|
||||
header_name = auth.name.lower()
|
||||
auth_metadata.append((header_name, auth.api_key))
|
||||
elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
|
||||
if auth._access_token:
|
||||
auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))
|
||||
|
||||
def factory(url: str) -> Any:
|
||||
"""Create a gRPC channel for the given URL."""
|
||||
target = url
|
||||
use_tls = False
|
||||
|
||||
if url.startswith("grpcs://"):
|
||||
target = url[8:]
|
||||
use_tls = True
|
||||
elif url.startswith("grpc://"):
|
||||
target = url[7:]
|
||||
elif url.startswith("https://"):
|
||||
target = url[8:]
|
||||
use_tls = True
|
||||
elif url.startswith("http://"):
|
||||
target = url[7:]
|
||||
|
||||
options: list[tuple[str, Any]] = []
|
||||
if grpc_config.max_send_message_length is not None:
|
||||
options.append(
|
||||
("grpc.max_send_message_length", grpc_config.max_send_message_length)
|
||||
)
|
||||
if grpc_config.max_receive_message_length is not None:
|
||||
options.append(
|
||||
(
|
||||
"grpc.max_receive_message_length",
|
||||
grpc_config.max_receive_message_length,
|
||||
)
|
||||
)
|
||||
if grpc_config.keepalive_time_ms is not None:
|
||||
options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
|
||||
if grpc_config.keepalive_timeout_ms is not None:
|
||||
options.append(
|
||||
("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
|
||||
)
|
||||
|
||||
channel_credentials = None
|
||||
if auth and hasattr(auth, "tls") and auth.tls:
|
||||
channel_credentials = auth.tls.get_grpc_credentials()
|
||||
elif use_tls:
|
||||
channel_credentials = grpc.ssl_channel_credentials()
|
||||
|
||||
if channel_credentials and auth_metadata:
|
||||
|
||||
class AuthMetadataPlugin(grpc.AuthMetadataPlugin): # type: ignore[misc,no-any-unimported]
|
||||
"""gRPC auth metadata plugin that adds auth headers as metadata."""
|
||||
|
||||
def __init__(self, metadata: list[tuple[str, str]]) -> None:
|
||||
self._metadata = tuple(metadata)
|
||||
|
||||
def __call__( # type: ignore[no-any-unimported]
|
||||
self,
|
||||
context: grpc.AuthMetadataContext,
|
||||
callback: grpc.AuthMetadataPluginCallback,
|
||||
) -> None:
|
||||
callback(self._metadata, None)
|
||||
|
||||
call_creds = grpc.metadata_call_credentials(
|
||||
AuthMetadataPlugin(auth_metadata)
|
||||
)
|
||||
credentials = grpc.composite_channel_credentials(
|
||||
channel_credentials, call_creds
|
||||
)
|
||||
interceptors = _create_grpc_interceptors()
|
||||
return grpc.aio.secure_channel(
|
||||
target, credentials, options=options or None, interceptors=interceptors
|
||||
)
|
||||
if channel_credentials:
|
||||
interceptors = _create_grpc_interceptors()
|
||||
return grpc.aio.secure_channel(
|
||||
target,
|
||||
channel_credentials,
|
||||
options=options or None,
|
||||
interceptors=interceptors,
|
||||
)
|
||||
interceptors = _create_grpc_interceptors(
|
||||
auth_metadata=auth_metadata if auth_metadata else None
|
||||
)
|
||||
return grpc.aio.insecure_channel(
|
||||
target, options=options or None, interceptors=interceptors
|
||||
)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_a2a_client(
|
||||
agent_card: AgentCard,
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
|
||||
timeout: int,
|
||||
headers: MutableMapping[str, str],
|
||||
streaming: bool,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
use_polling: bool = False,
|
||||
push_notification_config: PushNotificationConfig | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
grpc_config: GRPCClientConfig | None = None,
|
||||
) -> AsyncIterator[Client]:
|
||||
"""Create and configure an A2A client.
|
||||
|
||||
Args:
|
||||
agent_card: The A2A agent card.
|
||||
transport_protocol: Transport protocol to use.
|
||||
timeout: Request timeout in seconds.
|
||||
headers: HTTP headers (already with auth applied).
|
||||
streaming: Enable streaming responses.
|
||||
auth: Optional ClientAuthScheme for client configuration.
|
||||
use_polling: Enable polling mode.
|
||||
push_notification_config: Optional push notification config.
|
||||
client_extensions: A2A protocol extension URIs to declare support for.
|
||||
accepted_output_modes: MIME types the client can accept in responses.
|
||||
grpc_config: Optional gRPC client configuration.
|
||||
|
||||
Yields:
|
||||
Configured A2A client instance.
|
||||
"""
|
||||
verify = _get_tls_verify(auth)
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
verify=verify,
|
||||
) as httpx_client:
|
||||
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, httpx_client)
|
||||
|
||||
push_configs: list[A2APushNotificationConfig] = []
|
||||
if push_notification_config is not None:
|
||||
push_configs.append(
|
||||
A2APushNotificationConfig(
|
||||
url=str(push_notification_config.url),
|
||||
id=push_notification_config.id,
|
||||
token=push_notification_config.token,
|
||||
authentication=push_notification_config.authentication,
|
||||
)
|
||||
)
|
||||
|
||||
grpc_channel_factory = None
|
||||
if transport_protocol == "GRPC":
|
||||
grpc_channel_factory = _create_grpc_channel_factory(
|
||||
grpc_config or GRPCClientConfig(),
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
config = ClientConfig(
|
||||
httpx_client=httpx_client,
|
||||
supported_transports=[transport_protocol],
|
||||
streaming=streaming and not use_polling,
|
||||
polling=use_polling,
|
||||
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
|
||||
push_notification_configs=push_configs,
|
||||
grpc_channel_factory=grpc_channel_factory,
|
||||
)
|
||||
|
||||
factory = ClientFactory(config)
|
||||
client = factory.create(agent_card)
|
||||
|
||||
if client_extensions:
|
||||
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
|
||||
|
||||
yield client
|
||||
131
lib/crewai-a2a/src/crewai_a2a/utils/logging.py
Normal file
131
lib/crewai-a2a/src/crewai_a2a/utils/logging.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Structured JSON logging utilities for A2A module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
|
||||
_log_context: ContextVar[dict[str, Any] | None] = ContextVar(
|
||||
"log_context", default=None
|
||||
)
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""JSON formatter for structured logging.
|
||||
|
||||
Outputs logs as JSON with consistent fields for log aggregators.
|
||||
"""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Format log record as JSON string."""
|
||||
log_data: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
|
||||
if record.exc_info:
|
||||
log_data["exception"] = self.formatException(record.exc_info)
|
||||
|
||||
context = _log_context.get()
|
||||
if context is not None:
|
||||
log_data.update(context)
|
||||
|
||||
if hasattr(record, "task_id"):
|
||||
log_data["task_id"] = record.task_id
|
||||
if hasattr(record, "context_id"):
|
||||
log_data["context_id"] = record.context_id
|
||||
if hasattr(record, "agent"):
|
||||
log_data["agent"] = record.agent
|
||||
if hasattr(record, "endpoint"):
|
||||
log_data["endpoint"] = record.endpoint
|
||||
if hasattr(record, "extension"):
|
||||
log_data["extension"] = record.extension
|
||||
if hasattr(record, "error"):
|
||||
log_data["error"] = record.error
|
||||
|
||||
for key, value in record.__dict__.items():
|
||||
if key.startswith("_") or key in (
|
||||
"name",
|
||||
"msg",
|
||||
"args",
|
||||
"created",
|
||||
"filename",
|
||||
"funcName",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"lineno",
|
||||
"module",
|
||||
"msecs",
|
||||
"pathname",
|
||||
"process",
|
||||
"processName",
|
||||
"relativeCreated",
|
||||
"stack_info",
|
||||
"exc_info",
|
||||
"exc_text",
|
||||
"thread",
|
||||
"threadName",
|
||||
"taskName",
|
||||
"message",
|
||||
):
|
||||
continue
|
||||
if key not in log_data:
|
||||
log_data[key] = value
|
||||
|
||||
return json.dumps(log_data, default=str)
|
||||
|
||||
|
||||
class LogContext:
|
||||
"""Context manager for adding fields to all logs within a scope.
|
||||
|
||||
Example:
|
||||
with LogContext(task_id="abc", context_id="xyz"):
|
||||
logger.info("Processing task") # Includes task_id and context_id
|
||||
"""
|
||||
|
||||
def __init__(self, **fields: Any) -> None:
|
||||
self._fields = fields
|
||||
self._token: Any = None
|
||||
|
||||
def __enter__(self) -> LogContext:
|
||||
current = _log_context.get() or {}
|
||||
new_context = {**current, **self._fields}
|
||||
self._token = _log_context.set(new_context)
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
_log_context.reset(self._token)
|
||||
|
||||
|
||||
def configure_json_logging(logger_name: str = "crewai.a2a") -> None:
|
||||
"""Configure JSON logging for the A2A module.
|
||||
|
||||
Args:
|
||||
logger_name: Logger name to configure.
|
||||
"""
|
||||
logger = logging.getLogger(logger_name)
|
||||
|
||||
for handler in logger.handlers[:]:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(JSONFormatter())
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a logger configured for structured JSON output.
|
||||
|
||||
Args:
|
||||
name: Logger name.
|
||||
|
||||
Returns:
|
||||
Configured logger instance.
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
101
lib/crewai-a2a/src/crewai_a2a/utils/response_model.py
Normal file
101
lib/crewai-a2a/src/crewai_a2a/utils/response_model.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Response model utilities for A2A agent interactions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeAlias
|
||||
|
||||
from crewai.types.utils import create_literals_from_strings
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
|
||||
|
||||
A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
|
||||
A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig
|
||||
|
||||
|
||||
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None:
|
||||
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
|
||||
|
||||
Args:
|
||||
agent_ids: List of available A2A agent IDs.
|
||||
|
||||
Returns:
|
||||
Dynamically created Pydantic model with Literal-constrained a2a_ids field,
|
||||
or None if agent_ids is empty.
|
||||
"""
|
||||
if not agent_ids:
|
||||
return None
|
||||
|
||||
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
|
||||
|
||||
return create_model(
|
||||
"AgentResponse",
|
||||
a2a_ids=(
|
||||
tuple[DynamicLiteral, ...], # type: ignore[valid-type]
|
||||
Field(
|
||||
default_factory=tuple,
|
||||
max_length=len(agent_ids),
|
||||
description="A2A agent IDs to delegate to.",
|
||||
),
|
||||
),
|
||||
message=(
|
||||
str,
|
||||
Field(
|
||||
description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
|
||||
),
|
||||
),
|
||||
is_a2a=(
|
||||
bool,
|
||||
Field(
|
||||
description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately."
|
||||
),
|
||||
),
|
||||
__base__=BaseModel,
|
||||
)
|
||||
|
||||
|
||||
def extract_a2a_agent_ids_from_config(
|
||||
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
|
||||
) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]:
|
||||
"""Extract A2A agent IDs from A2A configuration.
|
||||
|
||||
Filters out A2AServerConfig since it doesn't have an endpoint for delegation.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (any type).
|
||||
|
||||
Returns:
|
||||
Tuple of client A2A configs list and agent endpoint IDs.
|
||||
"""
|
||||
if a2a_config is None:
|
||||
return [], ()
|
||||
|
||||
configs: list[A2AConfigTypes]
|
||||
if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)):
|
||||
configs = [a2a_config]
|
||||
else:
|
||||
configs = a2a_config
|
||||
|
||||
# Filter to only client configs (those with endpoint)
|
||||
client_configs: list[A2AClientConfigTypes] = [
|
||||
config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
|
||||
]
|
||||
|
||||
return client_configs, tuple(config.endpoint for config in client_configs)
|
||||
|
||||
|
||||
def get_a2a_agents_and_response_model(
|
||||
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
|
||||
) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]:
|
||||
"""Get A2A agent configs and response model.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (any type).
|
||||
|
||||
Returns:
|
||||
Tuple of client A2A configs and response model.
|
||||
"""
|
||||
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
|
||||
|
||||
return a2a_agents, create_agent_response_model(agent_ids)
|
||||
585
lib/crewai-a2a/src/crewai_a2a/utils/task.py
Normal file
585
lib/crewai-a2a/src/crewai_a2a/utils/task.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""A2A task utilities for server-side task management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Callable, Coroutine
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from a2a.server.agent_execution import RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
from a2a.types import (
|
||||
Artifact,
|
||||
FileWithBytes,
|
||||
FileWithUri,
|
||||
InternalError,
|
||||
InvalidParamsError,
|
||||
Message,
|
||||
Part,
|
||||
Task as A2ATask,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from a2a.utils import (
|
||||
get_data_parts,
|
||||
get_file_parts,
|
||||
new_agent_text_message,
|
||||
new_data_artifact,
|
||||
new_text_artifact,
|
||||
)
|
||||
from a2a.utils.errors import ServerError
|
||||
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AServerTaskCanceledEvent,
|
||||
A2AServerTaskCompletedEvent,
|
||||
A2AServerTaskFailedEvent,
|
||||
A2AServerTaskStartedEvent,
|
||||
)
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai_a2a.utils.agent_card import _get_server_config
|
||||
from crewai_a2a.utils.content_type import validate_message_parts
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
|
||||
from crewai_a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedisCacheConfig(TypedDict, total=False):
|
||||
"""Configuration for aiocache Redis backend."""
|
||||
|
||||
cache: str
|
||||
endpoint: str
|
||||
port: int
|
||||
db: int
|
||||
password: str
|
||||
|
||||
|
||||
def _parse_redis_url(url: str) -> RedisCacheConfig:
|
||||
"""Parse a Redis URL into aiocache configuration.
|
||||
|
||||
Args:
|
||||
url: Redis connection URL (e.g., redis://localhost:6379/0).
|
||||
|
||||
Returns:
|
||||
Configuration dict for aiocache.RedisCache.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
config: RedisCacheConfig = {
|
||||
"cache": "aiocache.RedisCache",
|
||||
"endpoint": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 6379,
|
||||
}
|
||||
if parsed.path and parsed.path != "/":
|
||||
try:
|
||||
config["db"] = int(parsed.path.lstrip("/"))
|
||||
except ValueError:
|
||||
pass
|
||||
if parsed.password:
|
||||
config["password"] = parsed.password
|
||||
return config
|
||||
|
||||
|
||||
_redis_url = os.environ.get("REDIS_URL")
|
||||
|
||||
caches.set_config(
|
||||
{
|
||||
"default": _parse_redis_url(_redis_url)
|
||||
if _redis_url
|
||||
else {
|
||||
"cache": "aiocache.SimpleMemoryCache",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def cancellable(
|
||||
fn: Callable[P, Coroutine[Any, Any, T]],
|
||||
) -> Callable[P, Coroutine[Any, Any, T]]:
|
||||
"""Decorator that enables cancellation for A2A task execution.
|
||||
|
||||
Runs a cancellation watcher concurrently with the wrapped function.
|
||||
When a cancel event is published, the execution is cancelled.
|
||||
|
||||
Args:
|
||||
fn: The async function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function with cancellation support.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
"""Wrap function with cancellation monitoring."""
|
||||
context: RequestContext | None = None
|
||||
for arg in args:
|
||||
if isinstance(arg, RequestContext):
|
||||
context = arg
|
||||
break
|
||||
if context is None:
|
||||
context = cast(RequestContext | None, kwargs.get("context"))
|
||||
|
||||
if context is None:
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
task_id = context.task_id
|
||||
cache = caches.get("default")
|
||||
|
||||
async def poll_for_cancel() -> bool:
|
||||
"""Poll cache for cancellation flag."""
|
||||
while True:
|
||||
if await cache.get(f"cancel:{task_id}"):
|
||||
return True
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def watch_for_cancel() -> bool:
|
||||
"""Watch for cancellation events via pub/sub or polling."""
|
||||
if isinstance(cache, SimpleMemoryCache):
|
||||
return await poll_for_cancel()
|
||||
|
||||
try:
|
||||
client = cache.client
|
||||
pubsub = client.pubsub()
|
||||
await pubsub.subscribe(f"cancel:{task_id}")
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
return True
|
||||
except (OSError, ConnectionError) as e:
|
||||
logger.warning(
|
||||
"Cancel watcher Redis error, falling back to polling",
|
||||
extra={"task_id": task_id, "error": str(e)},
|
||||
)
|
||||
return await poll_for_cancel()
|
||||
return False
|
||||
|
||||
execute_task = asyncio.create_task(fn(*args, **kwargs))
|
||||
cancel_watch = asyncio.create_task(watch_for_cancel())
|
||||
|
||||
try:
|
||||
done, _ = await asyncio.wait(
|
||||
[execute_task, cancel_watch],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
if cancel_watch in done:
|
||||
execute_task.cancel()
|
||||
try:
|
||||
await execute_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
raise asyncio.CancelledError(f"Task {task_id} was cancelled")
|
||||
cancel_watch.cancel()
|
||||
return execute_task.result()
|
||||
finally:
|
||||
await cache.delete(f"cancel:{task_id}")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _convert_a2a_files_to_file_inputs(
|
||||
a2a_files: list[FileWithBytes | FileWithUri],
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a2a file types to crewai FileInput dict.
|
||||
|
||||
Args:
|
||||
a2a_files: List of FileWithBytes or FileWithUri from a2a SDK.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping file names to FileInput objects.
|
||||
"""
|
||||
try:
|
||||
from crewai_files import File, FileBytes
|
||||
except ImportError:
|
||||
logger.debug("crewai_files not installed, returning empty file dict")
|
||||
return {}
|
||||
|
||||
file_dict: dict[str, Any] = {}
|
||||
for idx, a2a_file in enumerate(a2a_files):
|
||||
if isinstance(a2a_file, FileWithBytes):
|
||||
file_bytes = base64.b64decode(a2a_file.bytes)
|
||||
name = a2a_file.name or f"file_{idx}"
|
||||
file_source = FileBytes(data=file_bytes, filename=a2a_file.name)
|
||||
file_dict[name] = File(source=file_source)
|
||||
elif isinstance(a2a_file, FileWithUri):
|
||||
name = a2a_file.name or f"file_{idx}"
|
||||
file_dict[name] = File(source=a2a_file.uri)
|
||||
|
||||
return file_dict
|
||||
|
||||
|
||||
def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
|
||||
"""Extract response schema from message parts metadata.
|
||||
|
||||
The client may include a JSON schema in TextPart metadata to specify
|
||||
the expected response format (see delegation.py line 463).
|
||||
|
||||
Args:
|
||||
parts: List of message parts.
|
||||
|
||||
Returns:
|
||||
JSON schema dict if found, None otherwise.
|
||||
"""
|
||||
for part in parts:
|
||||
if part.root.kind == "text" and part.root.metadata:
|
||||
schema = part.root.metadata.get("schema")
|
||||
if schema and isinstance(schema, dict):
|
||||
return schema # type: ignore[no-any-return]
|
||||
return None
|
||||
|
||||
|
||||
def _create_result_artifact(
|
||||
result: Any,
|
||||
task_id: str,
|
||||
) -> Artifact:
|
||||
"""Create artifact from task result, using DataPart for structured data.
|
||||
|
||||
Args:
|
||||
result: The task execution result.
|
||||
task_id: The task ID for naming the artifact.
|
||||
|
||||
Returns:
|
||||
Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings).
|
||||
"""
|
||||
artifact_name = f"result_{task_id}"
|
||||
if isinstance(result, dict):
|
||||
return new_data_artifact(artifact_name, result)
|
||||
if isinstance(result, BaseModel):
|
||||
return new_data_artifact(artifact_name, result.model_dump())
|
||||
return new_text_artifact(artifact_name, str(result))
|
||||
|
||||
|
||||
def _build_task_description(
|
||||
user_message: str,
|
||||
structured_inputs: list[dict[str, Any]],
|
||||
) -> str:
|
||||
"""Build task description including structured data if present.
|
||||
|
||||
Args:
|
||||
user_message: The original user message text.
|
||||
structured_inputs: List of structured data from DataParts.
|
||||
|
||||
Returns:
|
||||
Task description with structured data appended if present.
|
||||
"""
|
||||
if not structured_inputs:
|
||||
return user_message
|
||||
|
||||
structured_json = json.dumps(structured_inputs, indent=2)
|
||||
return f"{user_message}\n\nStructured Data:\n{structured_json}"
|
||||
|
||||
|
||||
async def execute(
|
||||
agent: Agent,
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
) -> None:
|
||||
"""Execute an A2A task using a CrewAI agent.
|
||||
|
||||
Args:
|
||||
agent: The CrewAI agent to execute the task.
|
||||
context: The A2A request context containing the user's message.
|
||||
event_queue: The event queue for sending responses back.
|
||||
"""
|
||||
await _execute_impl(agent, context, event_queue, None, None)
|
||||
|
||||
|
||||
@cancellable
|
||||
async def _execute_impl(
|
||||
agent: Agent,
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
extension_registry: ServerExtensionRegistry | None,
|
||||
extension_context: ExtensionContext | None,
|
||||
) -> None:
|
||||
"""Internal implementation for task execution with optional extensions."""
|
||||
server_config = _get_server_config(agent)
|
||||
if context.message and context.message.parts and server_config:
|
||||
allowed_modes = server_config.default_input_modes
|
||||
invalid_types = validate_message_parts(context.message.parts, allowed_modes)
|
||||
if invalid_types:
|
||||
raise ServerError(
|
||||
InvalidParamsError(
|
||||
message=f"Unsupported content type(s): {', '.join(invalid_types)}. "
|
||||
f"Supported: {', '.join(allowed_modes)}"
|
||||
)
|
||||
)
|
||||
|
||||
if extension_registry and extension_context:
|
||||
await extension_registry.invoke_on_request(extension_context)
|
||||
|
||||
user_message = context.get_user_input()
|
||||
|
||||
response_model: type[BaseModel] | None = None
|
||||
structured_inputs: list[dict[str, Any]] = []
|
||||
a2a_files: list[FileWithBytes | FileWithUri] = []
|
||||
|
||||
if context.message and context.message.parts:
|
||||
schema = _extract_response_schema(context.message.parts)
|
||||
if schema:
|
||||
try:
|
||||
response_model = create_model_from_schema(schema)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to create response model from schema",
|
||||
extra={"error": str(e), "schema_title": schema.get("title")},
|
||||
)
|
||||
|
||||
structured_inputs = get_data_parts(context.message.parts)
|
||||
a2a_files = get_file_parts(context.message.parts)
|
||||
|
||||
task_id = context.task_id
|
||||
context_id = context.context_id
|
||||
if task_id is None or context_id is None:
|
||||
msg = "task_id and context_id are required"
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskFailedEvent(
|
||||
task_id="",
|
||||
context_id="",
|
||||
error=msg,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise ServerError(InvalidParamsError(message=msg)) from None
|
||||
|
||||
task = Task(
|
||||
description=_build_task_description(user_message, structured_inputs),
|
||||
expected_output="Response to the user's request",
|
||||
agent=agent,
|
||||
response_model=response_model,
|
||||
input_files=_convert_a2a_files_to_file_inputs(a2a_files),
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
result = await agent.aexecute_task(task=task, tools=agent.tools)
|
||||
if extension_registry and extension_context:
|
||||
result = await extension_registry.invoke_on_response(
|
||||
extension_context, result
|
||||
)
|
||||
result_str = str(result)
|
||||
history: list[Message] = [context.message] if context.message else []
|
||||
history.append(new_agent_text_message(result_str, context_id, task_id))
|
||||
await event_queue.enqueue_event(
|
||||
A2ATask(
|
||||
id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
artifacts=[_create_result_artifact(result, task_id)],
|
||||
history=history,
|
||||
)
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskCompletedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
result=str(result),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskCanceledEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskFailedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise ServerError(
|
||||
error=InternalError(message=f"Task execution failed: {e}")
|
||||
) from e
|
||||
|
||||
|
||||
async def execute_with_extensions(
|
||||
agent: Agent,
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
extension_registry: ServerExtensionRegistry,
|
||||
extension_context: ExtensionContext,
|
||||
) -> None:
|
||||
"""Execute an A2A task with extension hooks.
|
||||
|
||||
Args:
|
||||
agent: The CrewAI agent to execute the task.
|
||||
context: The A2A request context containing the user's message.
|
||||
event_queue: The event queue for sending responses back.
|
||||
extension_registry: Registry of server extensions.
|
||||
extension_context: Context for extension hooks.
|
||||
"""
|
||||
await _execute_impl(
|
||||
agent, context, event_queue, extension_registry, extension_context
|
||||
)
|
||||
|
||||
|
||||
async def cancel(
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
) -> A2ATask | None:
|
||||
"""Cancel an A2A task.
|
||||
|
||||
Publishes a cancel event that the cancellable decorator listens for.
|
||||
|
||||
Args:
|
||||
context: The A2A request context containing task information.
|
||||
event_queue: The event queue for sending the cancellation status.
|
||||
|
||||
Returns:
|
||||
The canceled task with updated status.
|
||||
"""
|
||||
task_id = context.task_id
|
||||
context_id = context.context_id
|
||||
if task_id is None or context_id is None:
|
||||
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
|
||||
|
||||
if context.current_task and context.current_task.status.state in (
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.canceled,
|
||||
):
|
||||
return context.current_task
|
||||
|
||||
cache = caches.get("default")
|
||||
|
||||
await cache.set(f"cancel:{task_id}", True, ttl=3600)
|
||||
if not isinstance(cache, SimpleMemoryCache):
|
||||
await cache.client.publish(f"cancel:{task_id}", "cancel")
|
||||
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.canceled),
|
||||
final=True,
|
||||
)
|
||||
)
|
||||
|
||||
if context.current_task:
|
||||
context.current_task.status = TaskStatus(state=TaskState.canceled)
|
||||
return context.current_task
|
||||
return None
|
||||
|
||||
|
||||
def list_tasks(
|
||||
tasks: list[A2ATask],
|
||||
context_id: str | None = None,
|
||||
status: TaskState | None = None,
|
||||
status_timestamp_after: datetime | None = None,
|
||||
page_size: int = 50,
|
||||
page_token: str | None = None,
|
||||
history_length: int | None = None,
|
||||
include_artifacts: bool = False,
|
||||
) -> tuple[list[A2ATask], str | None, int]:
|
||||
"""Filter and paginate A2A tasks.
|
||||
|
||||
Provides filtering by context, status, and timestamp, along with
|
||||
cursor-based pagination. This is a pure utility function that operates
|
||||
on an in-memory list of tasks - storage retrieval is handled separately.
|
||||
|
||||
Args:
|
||||
tasks: All tasks to filter.
|
||||
context_id: Filter by context ID to get tasks in a conversation.
|
||||
status: Filter by task state (e.g., completed, working).
|
||||
status_timestamp_after: Filter to tasks updated after this time.
|
||||
page_size: Maximum tasks per page (default 50).
|
||||
page_token: Base64-encoded cursor from previous response.
|
||||
history_length: Limit history messages per task (None = full history).
|
||||
include_artifacts: Whether to include task artifacts (default False).
|
||||
|
||||
Returns:
|
||||
Tuple of (filtered_tasks, next_page_token, total_count).
|
||||
- filtered_tasks: Tasks matching filters, paginated and trimmed.
|
||||
- next_page_token: Token for next page, or None if no more pages.
|
||||
- total_count: Total number of tasks matching filters (before pagination).
|
||||
"""
|
||||
filtered: list[A2ATask] = []
|
||||
for task in tasks:
|
||||
if context_id and task.context_id != context_id:
|
||||
continue
|
||||
if status and task.status.state != status:
|
||||
continue
|
||||
if status_timestamp_after and task.status.timestamp:
|
||||
ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00"))
|
||||
if ts <= status_timestamp_after:
|
||||
continue
|
||||
filtered.append(task)
|
||||
|
||||
def get_timestamp(t: A2ATask) -> datetime:
|
||||
"""Extract timestamp from task status for sorting."""
|
||||
if t.status.timestamp is None:
|
||||
return datetime.min
|
||||
return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00"))
|
||||
|
||||
filtered.sort(key=get_timestamp, reverse=True)
|
||||
total = len(filtered)
|
||||
|
||||
start = 0
|
||||
if page_token:
|
||||
try:
|
||||
cursor_id = base64.b64decode(page_token).decode()
|
||||
for idx, task in enumerate(filtered):
|
||||
if task.id == cursor_id:
|
||||
start = idx + 1
|
||||
break
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
pass
|
||||
|
||||
page = filtered[start : start + page_size]
|
||||
|
||||
result: list[A2ATask] = []
|
||||
for task in page:
|
||||
task = task.model_copy(deep=True)
|
||||
if history_length is not None and task.history:
|
||||
task.history = task.history[-history_length:]
|
||||
if not include_artifacts:
|
||||
task.artifacts = None
|
||||
result.append(task)
|
||||
|
||||
next_token: str | None = None
|
||||
if result and len(result) == page_size:
|
||||
next_token = base64.b64encode(result[-1].id.encode()).decode()
|
||||
|
||||
return result, next_token, total
|
||||
214
lib/crewai-a2a/src/crewai_a2a/utils/transport.py
Normal file
214
lib/crewai-a2a/src/crewai_a2a/utils/transport.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Transport negotiation utilities for A2A protocol.
|
||||
|
||||
This module provides functionality for negotiating the transport protocol
|
||||
between an A2A client and server based on their respective capabilities
|
||||
and preferences.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final, Literal
|
||||
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
|
||||
|
||||
|
||||
TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
|
||||
NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"]
|
||||
|
||||
JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC"
|
||||
GRPC_TRANSPORT: Literal["GRPC"] = "GRPC"
|
||||
HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON"
|
||||
|
||||
DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [
|
||||
JSONRPC_TRANSPORT,
|
||||
GRPC_TRANSPORT,
|
||||
HTTP_JSON_TRANSPORT,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class NegotiatedTransport:
|
||||
"""Result of transport negotiation.
|
||||
|
||||
Attributes:
|
||||
transport: The negotiated transport protocol.
|
||||
url: The URL to use for this transport.
|
||||
source: How the transport was selected ('preferred', 'additional', 'fallback').
|
||||
"""
|
||||
|
||||
transport: str
|
||||
url: str
|
||||
source: NegotiationSource
|
||||
|
||||
|
||||
class TransportNegotiationError(Exception):
|
||||
"""Raised when no compatible transport can be negotiated."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_transports: list[str],
|
||||
server_transports: list[str],
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the error with negotiation details.
|
||||
|
||||
Args:
|
||||
client_transports: Transports supported by the client.
|
||||
server_transports: Transports supported by the server.
|
||||
message: Optional custom error message.
|
||||
"""
|
||||
self.client_transports = client_transports
|
||||
self.server_transports = server_transports
|
||||
if message is None:
|
||||
message = (
|
||||
f"No compatible transport found. "
|
||||
f"Client supports: {client_transports}. "
|
||||
f"Server supports: {server_transports}."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
|
||||
"""Extract all available interfaces from an AgentCard.
|
||||
|
||||
Creates a unified list of interfaces including the primary URL and
|
||||
any additional interfaces declared by the agent.
|
||||
|
||||
Args:
|
||||
agent_card: The agent's card containing transport information.
|
||||
|
||||
Returns:
|
||||
List of AgentInterface objects representing all available endpoints.
|
||||
"""
|
||||
interfaces: list[AgentInterface] = []
|
||||
|
||||
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
|
||||
interfaces.append(
|
||||
AgentInterface(
|
||||
transport=primary_transport,
|
||||
url=agent_card.url,
|
||||
)
|
||||
)
|
||||
|
||||
if agent_card.additional_interfaces:
|
||||
for interface in agent_card.additional_interfaces:
|
||||
is_duplicate = any(
|
||||
i.url == interface.url and i.transport == interface.transport
|
||||
for i in interfaces
|
||||
)
|
||||
if not is_duplicate:
|
||||
interfaces.append(interface)
|
||||
|
||||
return interfaces
|
||||
|
||||
|
||||
def negotiate_transport(
|
||||
agent_card: AgentCard,
|
||||
client_supported_transports: list[str] | None = None,
|
||||
client_preferred_transport: str | None = None,
|
||||
emit_event: bool = True,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
) -> NegotiatedTransport:
|
||||
"""Negotiate the transport protocol between client and server.
|
||||
|
||||
Compares the client's supported transports with the server's available
|
||||
interfaces to find a compatible transport and URL.
|
||||
|
||||
Negotiation logic:
|
||||
1. If client_preferred_transport is set and server supports it → use it
|
||||
2. Otherwise, if server's preferred is in client's supported → use server's
|
||||
3. Otherwise, find first match from client's supported in server's interfaces
|
||||
|
||||
Args:
|
||||
agent_card: The server's AgentCard with transport information.
|
||||
client_supported_transports: Transports the client can use.
|
||||
Defaults to ["JSONRPC"] if not specified.
|
||||
client_preferred_transport: Client's preferred transport. If set and
|
||||
server supports it, takes priority over server preference.
|
||||
emit_event: Whether to emit a transport negotiation event.
|
||||
endpoint: Original endpoint URL for event metadata.
|
||||
a2a_agent_name: Agent name for event metadata.
|
||||
|
||||
Returns:
|
||||
NegotiatedTransport with the selected transport, URL, and source.
|
||||
|
||||
Raises:
|
||||
TransportNegotiationError: If no compatible transport is found.
|
||||
"""
|
||||
if client_supported_transports is None:
|
||||
client_supported_transports = [JSONRPC_TRANSPORT]
|
||||
|
||||
client_transports = [t.upper() for t in client_supported_transports]
|
||||
client_preferred = (
|
||||
client_preferred_transport.upper() if client_preferred_transport else None
|
||||
)
|
||||
|
||||
server_interfaces = _get_server_interfaces(agent_card)
|
||||
server_transports = [i.transport.upper() for i in server_interfaces]
|
||||
|
||||
transport_to_interface: dict[str, AgentInterface] = {}
|
||||
for interface in server_interfaces:
|
||||
transport_upper = interface.transport.upper()
|
||||
if transport_upper not in transport_to_interface:
|
||||
transport_to_interface[transport_upper] = interface
|
||||
|
||||
result: NegotiatedTransport | None = None
|
||||
|
||||
if client_preferred and client_preferred in transport_to_interface:
|
||||
interface = transport_to_interface[client_preferred]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
url=interface.url,
|
||||
source="client_preferred",
|
||||
)
|
||||
else:
|
||||
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
|
||||
if (
|
||||
server_preferred in client_transports
|
||||
and server_preferred in transport_to_interface
|
||||
):
|
||||
interface = transport_to_interface[server_preferred]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
url=interface.url,
|
||||
source="server_preferred",
|
||||
)
|
||||
else:
|
||||
for transport in client_transports:
|
||||
if transport in transport_to_interface:
|
||||
interface = transport_to_interface[transport]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
url=interface.url,
|
||||
source="fallback",
|
||||
)
|
||||
break
|
||||
|
||||
if result is None:
|
||||
raise TransportNegotiationError(
|
||||
client_transports=client_transports,
|
||||
server_transports=server_transports,
|
||||
)
|
||||
|
||||
if emit_event:
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2ATransportNegotiatedEvent(
|
||||
endpoint=endpoint or agent_card.url,
|
||||
a2a_agent_name=a2a_agent_name or agent_card.name,
|
||||
negotiated_transport=result.transport,
|
||||
negotiated_url=result.url,
|
||||
source=result.source,
|
||||
client_supported_transports=client_transports,
|
||||
server_supported_transports=server_transports,
|
||||
server_preferred_transport=agent_card.preferred_transport
|
||||
or JSONRPC_TRANSPORT,
|
||||
client_preferred_transport=client_preferred,
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
1752
lib/crewai-a2a/src/crewai_a2a/wrapper.py
Normal file
1752
lib/crewai-a2a/src/crewai_a2a/wrapper.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,44 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: ''
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
host:
|
||||
- localhost:9999
|
||||
method: GET
|
||||
uri: http://localhost:9999/.well-known/agent-card.json
|
||||
response:
|
||||
body:
|
||||
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
|
||||
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
|
||||
perform calculations, or get the current time in any timezone.","name":"GPT
|
||||
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
|
||||
a general conversation with the AI assistant. Ask questions, get explanations,
|
||||
or just chat.","examples":["Hello, how are you?","Explain quantum computing
|
||||
in simple terms","What can you help me with?"],"id":"conversation","name":"General
|
||||
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
|
||||
mathematical calculations including arithmetic, exponents, and more.","examples":["What
|
||||
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
|
||||
the current date and time in any timezone.","examples":["What time is it?","What''s
|
||||
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
|
||||
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
|
||||
headers:
|
||||
content-length:
|
||||
- '1198'
|
||||
content-type:
|
||||
- application/json
|
||||
date:
|
||||
- Tue, 06 Jan 2026 14:17:00 GMT
|
||||
server:
|
||||
- uvicorn
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,126 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: ''
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
host:
|
||||
- localhost:9999
|
||||
method: GET
|
||||
uri: http://localhost:9999/.well-known/agent-card.json
|
||||
response:
|
||||
body:
|
||||
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
|
||||
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
|
||||
perform calculations, or get the current time in any timezone.","name":"GPT
|
||||
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
|
||||
a general conversation with the AI assistant. Ask questions, get explanations,
|
||||
or just chat.","examples":["Hello, how are you?","Explain quantum computing
|
||||
in simple terms","What can you help me with?"],"id":"conversation","name":"General
|
||||
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
|
||||
mathematical calculations including arithmetic, exponents, and more.","examples":["What
|
||||
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
|
||||
the current date and time in any timezone.","examples":["What time is it?","What''s
|
||||
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
|
||||
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
|
||||
headers:
|
||||
content-length:
|
||||
- '1198'
|
||||
content-type:
|
||||
- application/json
|
||||
date:
|
||||
- Tue, 06 Jan 2026 14:16:58 GMT
|
||||
server:
|
||||
- uvicorn
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"id":"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"e1e63c75-3ea0-49fb-b512-5128a2476416","parts":[{"kind":"text","text":"What
|
||||
is 2 + 2?"}],"role":"user"}}}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- '*/*, text/event-stream'
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
cache-control:
|
||||
- no-store
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '301'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- localhost:9999
|
||||
method: POST
|
||||
uri: http://localhost:9999/
|
||||
response:
|
||||
body:
|
||||
string: "data: {\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\ndata:
|
||||
{\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\ndata:
|
||||
{\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"54bb7ff3-f2c0-4eb3-b427-bf1c8cf90832\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool:
|
||||
calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\n"
|
||||
headers:
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
cache-control:
|
||||
- no-store
|
||||
connection:
|
||||
- keep-alive
|
||||
content-type:
|
||||
- text/event-stream; charset=utf-8
|
||||
date:
|
||||
- Tue, 06 Jan 2026 14:16:58 GMT
|
||||
server:
|
||||
- uvicorn
|
||||
x-accel-buffering:
|
||||
- 'no'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"id":"cb1e4af3-d2d0-4848-96b8-7082ee6171d1","jsonrpc":"2.0","method":"tasks/get","params":{"historyLength":100,"id":"0dd4d3af-f35d-409d-9462-01218e5641f9"}}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '157'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- localhost:9999
|
||||
method: POST
|
||||
uri: http://localhost:9999/
|
||||
response:
|
||||
body:
|
||||
string: '{"id":"cb1e4af3-d2d0-4848-96b8-7082ee6171d1","jsonrpc":"2.0","result":{"contextId":"b9e14c1b-734d-4d1e-864a-e6dda5231d71","history":[{"contextId":"b9e14c1b-734d-4d1e-864a-e6dda5231d71","kind":"message","messageId":"e1e63c75-3ea0-49fb-b512-5128a2476416","parts":[{"kind":"text","text":"What
|
||||
is 2 + 2?"}],"role":"user","taskId":"0dd4d3af-f35d-409d-9462-01218e5641f9"}],"id":"0dd4d3af-f35d-409d-9462-01218e5641f9","kind":"task","status":{"message":{"kind":"message","messageId":"54bb7ff3-f2c0-4eb3-b427-bf1c8cf90832","parts":[{"kind":"text","text":"\n[Tool:
|
||||
calculator] 2 + 2 = 4\n2 + 2 equals 4."}],"role":"agent"},"state":"completed"}}}'
|
||||
headers:
|
||||
content-length:
|
||||
- '635'
|
||||
content-type:
|
||||
- application/json
|
||||
date:
|
||||
- Tue, 06 Jan 2026 14:17:00 GMT
|
||||
server:
|
||||
- uvicorn
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,90 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: ''
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
host:
|
||||
- localhost:9999
|
||||
method: GET
|
||||
uri: http://localhost:9999/.well-known/agent-card.json
|
||||
response:
|
||||
body:
|
||||
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
|
||||
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
|
||||
perform calculations, or get the current time in any timezone.","name":"GPT
|
||||
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
|
||||
a general conversation with the AI assistant. Ask questions, get explanations,
|
||||
or just chat.","examples":["Hello, how are you?","Explain quantum computing
|
||||
in simple terms","What can you help me with?"],"id":"conversation","name":"General
|
||||
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
|
||||
mathematical calculations including arithmetic, exponents, and more.","examples":["What
|
||||
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
|
||||
the current date and time in any timezone.","examples":["What time is it?","What''s
|
||||
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
|
||||
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
|
||||
headers:
|
||||
content-length:
|
||||
- '1198'
|
||||
content-type:
|
||||
- application/json
|
||||
date:
|
||||
- Tue, 06 Jan 2026 14:17:02 GMT
|
||||
server:
|
||||
- uvicorn
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"id":"8cf25b61-8884-4246-adce-fccb32e176ab","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"c145297f-7331-4835-adcc-66b51de92a2b","parts":[{"kind":"text","text":"What
|
||||
is 2 + 2?"}],"role":"user"}}}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- '*/*, text/event-stream'
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
cache-control:
|
||||
- no-store
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '301'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- localhost:9999
|
||||
method: POST
|
||||
uri: http://localhost:9999/
|
||||
response:
|
||||
body:
|
||||
string: "data: {\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\ndata:
|
||||
{\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\ndata:
|
||||
{\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"25f81e3c-b7e8-48b5-a98a-4066f3637a13\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool:
|
||||
calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\n"
|
||||
headers:
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
cache-control:
|
||||
- no-store
|
||||
connection:
|
||||
- keep-alive
|
||||
content-type:
|
||||
- text/event-stream; charset=utf-8
|
||||
date:
|
||||
- Tue, 06 Jan 2026 14:17:02 GMT
|
||||
server:
|
||||
- uvicorn
|
||||
x-accel-buffering:
|
||||
- 'no'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,90 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: ''
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
host:
|
||||
- localhost:9999
|
||||
method: GET
|
||||
uri: http://localhost:9999/.well-known/agent-card.json
|
||||
response:
|
||||
body:
|
||||
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
|
||||
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
|
||||
perform calculations, or get the current time in any timezone.","name":"GPT
|
||||
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
|
||||
a general conversation with the AI assistant. Ask questions, get explanations,
|
||||
or just chat.","examples":["Hello, how are you?","Explain quantum computing
|
||||
in simple terms","What can you help me with?"],"id":"conversation","name":"General
|
||||
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
|
||||
mathematical calculations including arithmetic, exponents, and more.","examples":["What
|
||||
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
|
||||
the current date and time in any timezone.","examples":["What time is it?","What''s
|
||||
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
|
||||
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
|
||||
headers:
|
||||
content-length:
|
||||
- '1198'
|
||||
content-type:
|
||||
- application/json
|
||||
date:
|
||||
- Tue, 06 Jan 2026 14:17:00 GMT
|
||||
server:
|
||||
- uvicorn
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"id":"3a17c6bf-8db6-45a6-8535-34c45c0c4936","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"712558a3-6d92-4591-be8a-9dd8566dde82","parts":[{"kind":"text","text":"What
|
||||
is 2 + 2?"}],"role":"user"}}}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- '*/*, text/event-stream'
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
cache-control:
|
||||
- no-store
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '301'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- localhost:9999
|
||||
method: POST
|
||||
uri: http://localhost:9999/
|
||||
response:
|
||||
body:
|
||||
string: "data: {\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\ndata:
|
||||
{\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\ndata:
|
||||
{\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"916324aa-fd25-4849-bceb-c4644e2fcbb0\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool:
|
||||
calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\n"
|
||||
headers:
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
cache-control:
|
||||
- no-store
|
||||
connection:
|
||||
- keep-alive
|
||||
content-type:
|
||||
- text/event-stream; charset=utf-8
|
||||
date:
|
||||
- Tue, 06 Jan 2026 14:17:00 GMT
|
||||
server:
|
||||
- uvicorn
|
||||
x-accel-buffering:
|
||||
- 'no'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
21
lib/crewai-a2a/tests/conftest.py
Normal file
21
lib/crewai-a2a/tests/conftest.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Pytest configuration for crewai-a2a tests.
|
||||
|
||||
Ensures Agent model is properly rebuilt with A2A types,
|
||||
which can fail silently during circular import resolution.
|
||||
"""
|
||||
|
||||
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
|
||||
|
||||
def pytest_configure() -> None:
|
||||
"""Rebuild Agent/LiteAgent models after crewai_a2a is fully loaded."""
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.lite_agent import LiteAgent
|
||||
|
||||
ns = {
|
||||
"A2AConfig": A2AConfig,
|
||||
"A2AClientConfig": A2AClientConfig,
|
||||
"A2AServerConfig": A2AServerConfig,
|
||||
}
|
||||
Agent.model_rebuild(_types_namespace=ns)
|
||||
LiteAgent.model_rebuild(_types_namespace=ns)
|
||||
@@ -3,15 +3,13 @@ from __future__ import annotations
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from a2a.client import ClientFactory
|
||||
from a2a.types import AgentCard, Message, Part, Role, Task, TaskState, TextPart
|
||||
from crewai_a2a.updates.polling.handler import PollingHandler
|
||||
from crewai_a2a.updates.streaming.handler import StreamingHandler
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from a2a.client import ClientFactory
|
||||
from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart
|
||||
|
||||
from crewai.a2a.updates.polling.handler import PollingHandler
|
||||
from crewai.a2a.updates.streaming.handler import StreamingHandler
|
||||
|
||||
|
||||
A2A_TEST_ENDPOINT = os.getenv("A2A_TEST_ENDPOINT", "http://localhost:9999")
|
||||
|
||||
@@ -162,7 +160,7 @@ class TestA2APushNotificationHandler:
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task(self) -> "Task":
|
||||
def mock_task(self) -> Task:
|
||||
"""Create a minimal valid task for testing."""
|
||||
from a2a.types import Task, TaskStatus
|
||||
|
||||
@@ -182,11 +180,12 @@ class TestA2APushNotificationHandler:
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from a2a.types import Task, TaskStatus
|
||||
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
from crewai_a2a.updates.push_notifications.handler import (
|
||||
PushNotificationHandler,
|
||||
)
|
||||
from pydantic import AnyHttpUrl
|
||||
|
||||
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
|
||||
|
||||
completed_task = Task(
|
||||
id="task-123",
|
||||
context_id="ctx-123",
|
||||
@@ -246,11 +245,12 @@ class TestA2APushNotificationHandler:
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from a2a.types import Task, TaskStatus
|
||||
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
from crewai_a2a.updates.push_notifications.handler import (
|
||||
PushNotificationHandler,
|
||||
)
|
||||
from pydantic import AnyHttpUrl
|
||||
|
||||
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.wait_for_result = AsyncMock(return_value=None)
|
||||
|
||||
@@ -303,7 +303,9 @@ class TestA2APushNotificationHandler:
|
||||
"""Test that push handler fails gracefully without config."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
|
||||
from crewai_a2a.updates.push_notifications.handler import (
|
||||
PushNotificationHandler,
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
@@ -3,10 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from a2a.types import AgentCard, AgentSkill
|
||||
|
||||
from crewai import Agent
|
||||
from crewai.a2a.config import A2AClientConfig, A2AServerConfig
|
||||
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
|
||||
from crewai_a2a.config import A2AClientConfig, A2AServerConfig
|
||||
from crewai_a2a.utils.agent_card import inject_a2a_server_methods
|
||||
|
||||
|
||||
class TestInjectA2AServerMethods:
|
||||
@@ -6,13 +6,12 @@ import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from a2a.server.agent_execution import RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus
|
||||
|
||||
from crewai.a2a.utils.task import cancel, cancellable, execute
|
||||
from crewai_a2a.utils.task import cancel, cancellable, execute
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -85,8 +84,11 @@ class TestCancellableDecorator:
|
||||
assert call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_executes_function_with_context(self, mock_context: MagicMock) -> None:
|
||||
async def test_executes_function_with_context(
|
||||
self, mock_context: MagicMock
|
||||
) -> None:
|
||||
"""Function executes normally with RequestContext when not cancelled."""
|
||||
|
||||
@cancellable
|
||||
async def my_func(context: RequestContext) -> str:
|
||||
await asyncio.sleep(0.01)
|
||||
@@ -134,6 +136,7 @@ class TestCancellableDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_context_from_kwargs(self, mock_context: MagicMock) -> None:
|
||||
"""Context can be passed as keyword argument."""
|
||||
|
||||
@cancellable
|
||||
async def my_func(value: int, context: RequestContext | None = None) -> int:
|
||||
return value + 1
|
||||
@@ -156,8 +159,8 @@ class TestExecute:
|
||||
) -> None:
|
||||
"""Execute completes successfully and enqueues completed task."""
|
||||
with (
|
||||
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
):
|
||||
await execute(mock_agent, mock_context, mock_event_queue)
|
||||
|
||||
@@ -175,8 +178,8 @@ class TestExecute:
|
||||
) -> None:
|
||||
"""Execute emits A2AServerTaskStartedEvent."""
|
||||
with (
|
||||
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
):
|
||||
await execute(mock_agent, mock_context, mock_event_queue)
|
||||
|
||||
@@ -197,8 +200,8 @@ class TestExecute:
|
||||
) -> None:
|
||||
"""Execute emits A2AServerTaskCompletedEvent on success."""
|
||||
with (
|
||||
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
):
|
||||
await execute(mock_agent, mock_context, mock_event_queue)
|
||||
|
||||
@@ -221,8 +224,8 @@ class TestExecute:
|
||||
mock_agent.aexecute_task = AsyncMock(side_effect=ValueError("Test error"))
|
||||
|
||||
with (
|
||||
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await execute(mock_agent, mock_context, mock_event_queue)
|
||||
@@ -245,8 +248,8 @@ class TestExecute:
|
||||
mock_agent.aexecute_task = AsyncMock(side_effect=asyncio.CancelledError())
|
||||
|
||||
with (
|
||||
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
|
||||
):
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await execute(mock_agent, mock_context, mock_event_queue)
|
||||
@@ -354,6 +357,7 @@ class TestExecuteAndCancelIntegration:
|
||||
mock_task: MagicMock,
|
||||
) -> None:
|
||||
"""Calling cancel stops a running execute."""
|
||||
|
||||
async def slow_task(**kwargs: Any) -> str:
|
||||
await asyncio.sleep(2.0)
|
||||
return "should not complete"
|
||||
@@ -361,8 +365,8 @@ class TestExecuteAndCancelIntegration:
|
||||
mock_agent.aexecute_task = slow_task
|
||||
|
||||
with (
|
||||
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai.a2a.utils.task.crewai_event_bus"),
|
||||
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
|
||||
patch("crewai_a2a.utils.task.crewai_event_bus"),
|
||||
):
|
||||
execute_task = asyncio.create_task(
|
||||
execute(mock_agent, mock_context, mock_event_queue)
|
||||
@@ -372,4 +376,4 @@ class TestExecuteAndCancelIntegration:
|
||||
await cancel(mock_context, mock_event_queue)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await execute_task
|
||||
await execute_task
|
||||
@@ -96,12 +96,7 @@ azure-ai-inference = [
|
||||
anthropic = [
|
||||
"anthropic~=0.73.0",
|
||||
]
|
||||
a2a = [
|
||||
"a2a-sdk~=0.3.10",
|
||||
"httpx-auth~=0.23.1",
|
||||
"httpx-sse~=0.4.0",
|
||||
"aiocache[redis,memcached]~=0.12.3",
|
||||
]
|
||||
a2a = ["crewai-a2a==1.10.1b1"]
|
||||
file-processing = [
|
||||
"crewai-files",
|
||||
]
|
||||
@@ -132,6 +127,7 @@ torchvision = [
|
||||
{ index = "pytorch", marker = "python_version < '3.13'" },
|
||||
]
|
||||
crewai-files = { workspace = true }
|
||||
crewai-a2a = { workspace = true }
|
||||
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a`` instead."""
|
||||
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
import warnings
|
||||
|
||||
|
||||
__all__ = [
|
||||
"A2AClientConfig",
|
||||
"A2AConfig",
|
||||
"A2AServerConfig",
|
||||
]
|
||||
warnings.warn(
|
||||
"'crewai.a2a' has been moved to 'crewai_a2a'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from crewai_a2a import * # noqa: E402, F403
|
||||
|
||||
@@ -1,36 +1,13 @@
|
||||
"""A2A authentication schemas."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.auth`` instead."""
|
||||
|
||||
from crewai.a2a.auth.client_schemes import (
|
||||
APIKeyAuth,
|
||||
AuthScheme,
|
||||
BearerTokenAuth,
|
||||
ClientAuthScheme,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
TLSConfig,
|
||||
)
|
||||
from crewai.a2a.auth.server_schemes import (
|
||||
AuthenticatedUser,
|
||||
OIDCAuth,
|
||||
ServerAuthScheme,
|
||||
SimpleTokenAuth,
|
||||
import warnings
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"'crewai.a2a.auth' has been moved to 'crewai_a2a.auth'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"APIKeyAuth",
|
||||
"AuthScheme",
|
||||
"AuthenticatedUser",
|
||||
"BearerTokenAuth",
|
||||
"ClientAuthScheme",
|
||||
"HTTPBasicAuth",
|
||||
"HTTPDigestAuth",
|
||||
"OAuth2AuthorizationCode",
|
||||
"OAuth2ClientCredentials",
|
||||
"OIDCAuth",
|
||||
"ServerAuthScheme",
|
||||
"SimpleTokenAuth",
|
||||
"TLSConfig",
|
||||
]
|
||||
from crewai_a2a.auth import * # noqa: E402, F403
|
||||
|
||||
@@ -1,550 +1,13 @@
|
||||
"""Authentication schemes for A2A protocol clients.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.auth.client_schemes`` instead."""
|
||||
|
||||
Supported authentication methods:
|
||||
- Bearer tokens
|
||||
- OAuth2 (Client Credentials, Authorization Code)
|
||||
- API Keys (header, query, cookie)
|
||||
- HTTP Basic authentication
|
||||
- HTTP Digest authentication
|
||||
- mTLS (mutual TLS) client certificate authentication
|
||||
"""
|
||||
import warnings
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Awaitable, Callable, MutableMapping
|
||||
from pathlib import Path
|
||||
import ssl
|
||||
import time
|
||||
from typing import TYPE_CHECKING, ClassVar, Literal
|
||||
import urllib.parse
|
||||
warnings.warn(
|
||||
"'crewai.a2a.auth.client_schemes' has been moved to 'crewai_a2a.auth.client_schemes'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from httpx import DigestAuth
|
||||
from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr
|
||||
from typing_extensions import deprecated
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import grpc # type: ignore[import-untyped]
|
||||
|
||||
|
||||
class TLSConfig(BaseModel):
|
||||
"""TLS/mTLS configuration for secure client connections.
|
||||
|
||||
Supports mutual TLS (mTLS) where the client presents a certificate to the server,
|
||||
and standard TLS with custom CA verification.
|
||||
|
||||
Attributes:
|
||||
client_cert_path: Path to client certificate file (PEM format) for mTLS.
|
||||
client_key_path: Path to client private key file (PEM format) for mTLS.
|
||||
ca_cert_path: Path to CA certificate bundle for server verification.
|
||||
verify: Whether to verify server certificates. Set False only for development.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
client_cert_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to client certificate file (PEM format) for mTLS",
|
||||
)
|
||||
client_key_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to client private key file (PEM format) for mTLS",
|
||||
)
|
||||
ca_cert_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to CA certificate bundle for server verification",
|
||||
)
|
||||
verify: bool = Field(
|
||||
default=True,
|
||||
description="Whether to verify server certificates. Set False only for development.",
|
||||
)
|
||||
|
||||
def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
|
||||
"""Build SSL context for httpx client.
|
||||
|
||||
Returns:
|
||||
SSL context if certificates configured, True for default verification,
|
||||
False if verification disabled, or path to CA bundle.
|
||||
"""
|
||||
if not self.verify:
|
||||
return False
|
||||
|
||||
if self.client_cert_path and self.client_key_path:
|
||||
context = ssl.create_default_context()
|
||||
|
||||
if self.ca_cert_path:
|
||||
context.load_verify_locations(cafile=str(self.ca_cert_path))
|
||||
|
||||
context.load_cert_chain(
|
||||
certfile=str(self.client_cert_path),
|
||||
keyfile=str(self.client_key_path),
|
||||
)
|
||||
return context
|
||||
|
||||
if self.ca_cert_path:
|
||||
return str(self.ca_cert_path)
|
||||
|
||||
return True
|
||||
|
||||
def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported]
|
||||
"""Build gRPC channel credentials for secure connections.
|
||||
|
||||
Returns:
|
||||
gRPC SSL credentials if certificates configured, None otherwise.
|
||||
"""
|
||||
try:
|
||||
import grpc
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
if not self.verify and not self.client_cert_path:
|
||||
return None
|
||||
|
||||
root_certs: bytes | None = None
|
||||
private_key: bytes | None = None
|
||||
certificate_chain: bytes | None = None
|
||||
|
||||
if self.ca_cert_path:
|
||||
root_certs = Path(self.ca_cert_path).read_bytes()
|
||||
|
||||
if self.client_cert_path and self.client_key_path:
|
||||
private_key = Path(self.client_key_path).read_bytes()
|
||||
certificate_chain = Path(self.client_cert_path).read_bytes()
|
||||
|
||||
return grpc.ssl_channel_credentials(
|
||||
root_certificates=root_certs,
|
||||
private_key=private_key,
|
||||
certificate_chain=certificate_chain,
|
||||
)
|
||||
|
||||
|
||||
class ClientAuthScheme(ABC, BaseModel):
|
||||
"""Base class for client-side authentication schemes.
|
||||
|
||||
Client auth schemes apply credentials to outgoing requests.
|
||||
|
||||
Attributes:
|
||||
tls: Optional TLS/mTLS configuration for secure connections.
|
||||
"""
|
||||
|
||||
tls: TLSConfig | None = Field(
|
||||
default=None,
|
||||
description="TLS/mTLS configuration for secure connections",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply authentication to request headers.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with authentication applied.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@deprecated("Use ClientAuthScheme instead", category=FutureWarning)
|
||||
class AuthScheme(ClientAuthScheme):
|
||||
"""Deprecated: Use ClientAuthScheme instead."""
|
||||
|
||||
|
||||
class BearerTokenAuth(ClientAuthScheme):
|
||||
"""Bearer token authentication (Authorization: Bearer <token>).
|
||||
|
||||
Attributes:
|
||||
token: Bearer token for authentication.
|
||||
"""
|
||||
|
||||
token: str = Field(description="Bearer token")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply Bearer token to Authorization header.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with Bearer token in Authorization header.
|
||||
"""
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
return headers
|
||||
|
||||
|
||||
class HTTPBasicAuth(ClientAuthScheme):
|
||||
"""HTTP Basic authentication.
|
||||
|
||||
Attributes:
|
||||
username: Username for Basic authentication.
|
||||
password: Password for Basic authentication.
|
||||
"""
|
||||
|
||||
username: str = Field(description="Username")
|
||||
password: str = Field(description="Password")
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply HTTP Basic authentication.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with Basic auth in Authorization header.
|
||||
"""
|
||||
credentials = f"{self.username}:{self.password}"
|
||||
encoded = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded}"
|
||||
return headers
|
||||
|
||||
|
||||
class HTTPDigestAuth(ClientAuthScheme):
|
||||
"""HTTP Digest authentication.
|
||||
|
||||
Note: Uses httpx-auth library for digest implementation.
|
||||
|
||||
Attributes:
|
||||
username: Username for Digest authentication.
|
||||
password: Password for Digest authentication.
|
||||
"""
|
||||
|
||||
username: str = Field(description="Username")
|
||||
password: str = Field(description="Password")
|
||||
|
||||
_configured_client_id: int | None = PrivateAttr(default=None)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Digest auth is handled by httpx auth flow, not headers.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Unchanged headers (Digest auth handled by httpx auth flow).
|
||||
"""
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client with Digest auth.
|
||||
|
||||
Idempotent: Only configures the client once. Subsequent calls on the same
|
||||
client instance are no-ops to prevent overwriting auth configuration.
|
||||
|
||||
Args:
|
||||
client: HTTP client to configure with Digest authentication.
|
||||
"""
|
||||
client_id = id(client)
|
||||
if self._configured_client_id == client_id:
|
||||
return
|
||||
|
||||
client.auth = DigestAuth(self.username, self.password)
|
||||
self._configured_client_id = client_id
|
||||
|
||||
|
||||
class APIKeyAuth(ClientAuthScheme):
|
||||
"""API Key authentication (header, query, or cookie).
|
||||
|
||||
Attributes:
|
||||
api_key: API key value for authentication.
|
||||
location: Where to send the API key (header, query, or cookie).
|
||||
name: Parameter name for the API key (default: X-API-Key).
|
||||
"""
|
||||
|
||||
api_key: str = Field(description="API key value")
|
||||
location: Literal["header", "query", "cookie"] = Field(
|
||||
default="header", description="Where to send the API key"
|
||||
)
|
||||
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
|
||||
|
||||
_configured_client_ids: set[int] = PrivateAttr(default_factory=set)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply API key authentication.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making auth requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with API key (for header/cookie locations).
|
||||
"""
|
||||
if self.location == "header":
|
||||
headers[self.name] = self.api_key
|
||||
elif self.location == "cookie":
|
||||
headers["Cookie"] = f"{self.name}={self.api_key}"
|
||||
return headers
|
||||
|
||||
def configure_client(self, client: httpx.AsyncClient) -> None:
|
||||
"""Configure client for query param API keys.
|
||||
|
||||
Idempotent: Only adds the request hook once per client instance.
|
||||
Subsequent calls on the same client are no-ops to prevent hook accumulation.
|
||||
|
||||
Args:
|
||||
client: HTTP client to configure with query param API key hook.
|
||||
"""
|
||||
if self.location == "query":
|
||||
client_id = id(client)
|
||||
if client_id in self._configured_client_ids:
|
||||
return
|
||||
|
||||
async def _add_api_key_param(request: httpx.Request) -> None:
|
||||
url = httpx.URL(request.url)
|
||||
request.url = url.copy_add_param(self.name, self.api_key)
|
||||
|
||||
client.event_hooks["request"].append(_add_api_key_param)
|
||||
self._configured_client_ids.add(client_id)
|
||||
|
||||
|
||||
class OAuth2ClientCredentials(ClientAuthScheme):
|
||||
"""OAuth2 Client Credentials flow authentication.
|
||||
|
||||
Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches
|
||||
when multiple requests share the same auth instance.
|
||||
|
||||
Attributes:
|
||||
token_url: OAuth2 token endpoint URL.
|
||||
client_id: OAuth2 client identifier.
|
||||
client_secret: OAuth2 client secret.
|
||||
scopes: List of required OAuth2 scopes.
|
||||
"""
|
||||
|
||||
token_url: str = Field(description="OAuth2 token endpoint")
|
||||
client_id: str = Field(description="OAuth2 client ID")
|
||||
client_secret: str = Field(description="OAuth2 client secret")
|
||||
scopes: list[str] = Field(
|
||||
default_factory=list, description="Required OAuth2 scopes"
|
||||
)
|
||||
|
||||
_access_token: str | None = PrivateAttr(default=None)
|
||||
_token_expires_at: float | None = PrivateAttr(default=None)
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header.
|
||||
|
||||
Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time,
|
||||
preventing race conditions when multiple concurrent requests use the same
|
||||
auth instance.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with OAuth2 access token in Authorization header.
|
||||
"""
|
||||
if (
|
||||
self._access_token is None
|
||||
or self._token_expires_at is None
|
||||
or time.time() >= self._token_expires_at
|
||||
):
|
||||
async with self._lock:
|
||||
if (
|
||||
self._access_token is None
|
||||
or self._token_expires_at is None
|
||||
or time.time() >= self._token_expires_at
|
||||
):
|
||||
await self._fetch_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Fetch OAuth2 access token using client credentials flow.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token request.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If token request fails.
|
||||
"""
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
}
|
||||
|
||||
if self.scopes:
|
||||
data["scope"] = " ".join(self.scopes)
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(ClientAuthScheme):
|
||||
"""OAuth2 Authorization Code flow authentication.
|
||||
|
||||
Thread-safe implementation with asyncio.Lock to prevent concurrent token operations.
|
||||
|
||||
Note: Requires interactive authorization.
|
||||
|
||||
Attributes:
|
||||
authorization_url: OAuth2 authorization endpoint URL.
|
||||
token_url: OAuth2 token endpoint URL.
|
||||
client_id: OAuth2 client identifier.
|
||||
client_secret: OAuth2 client secret.
|
||||
redirect_uri: OAuth2 redirect URI for callback.
|
||||
scopes: List of required OAuth2 scopes.
|
||||
"""
|
||||
|
||||
authorization_url: str = Field(description="OAuth2 authorization endpoint")
|
||||
token_url: str = Field(description="OAuth2 token endpoint")
|
||||
client_id: str = Field(description="OAuth2 client ID")
|
||||
client_secret: str = Field(description="OAuth2 client secret")
|
||||
redirect_uri: str = Field(description="OAuth2 redirect URI")
|
||||
scopes: list[str] = Field(
|
||||
default_factory=list, description="Required OAuth2 scopes"
|
||||
)
|
||||
|
||||
_access_token: str | None = PrivateAttr(default=None)
|
||||
_refresh_token: str | None = PrivateAttr(default=None)
|
||||
_token_expires_at: float | None = PrivateAttr(default=None)
|
||||
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
|
||||
default=None
|
||||
)
|
||||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
def set_authorization_callback(
|
||||
self, callback: Callable[[str], Awaitable[str]] | None
|
||||
) -> None:
|
||||
"""Set callback to handle authorization URL.
|
||||
|
||||
Args:
|
||||
callback: Async function that receives authorization URL and returns auth code.
|
||||
"""
|
||||
self._authorization_callback = callback
|
||||
|
||||
async def apply_auth(
|
||||
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Apply OAuth2 access token to Authorization header.
|
||||
|
||||
Uses asyncio.Lock to ensure only one coroutine handles token operations
|
||||
(initial fetch or refresh) at a time.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token requests.
|
||||
headers: Current request headers.
|
||||
|
||||
Returns:
|
||||
Updated headers with OAuth2 access token in Authorization header.
|
||||
|
||||
Raises:
|
||||
ValueError: If authorization callback is not set.
|
||||
"""
|
||||
if self._access_token is None:
|
||||
if self._authorization_callback is None:
|
||||
msg = "Authorization callback not set. Use set_authorization_callback()"
|
||||
raise ValueError(msg)
|
||||
async with self._lock:
|
||||
if self._access_token is None:
|
||||
await self._fetch_initial_token(client)
|
||||
elif self._token_expires_at and time.time() >= self._token_expires_at:
|
||||
async with self._lock:
|
||||
if self._token_expires_at and time.time() >= self._token_expires_at:
|
||||
await self._refresh_access_token(client)
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Fetch initial access token using authorization code flow.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token request.
|
||||
|
||||
Raises:
|
||||
ValueError: If authorization callback is not set.
|
||||
httpx.HTTPStatusError: If token request fails.
|
||||
"""
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": " ".join(self.scopes),
|
||||
}
|
||||
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
if self._authorization_callback is None:
|
||||
msg = "Authorization callback not set"
|
||||
raise ValueError(msg)
|
||||
auth_code = await self._authorization_callback(auth_url)
|
||||
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
self._refresh_token = token_data.get("refresh_token")
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
|
||||
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
|
||||
"""Refresh the access token using refresh token.
|
||||
|
||||
Args:
|
||||
client: HTTP client for making token request.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If token refresh request fails.
|
||||
"""
|
||||
if not self._refresh_token:
|
||||
await self._fetch_initial_token(client)
|
||||
return
|
||||
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self._refresh_token,
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
}
|
||||
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
token_data = response.json()
|
||||
self._access_token = token_data["access_token"]
|
||||
if "refresh_token" in token_data:
|
||||
self._refresh_token = token_data["refresh_token"]
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
from crewai_a2a.auth.client_schemes import * # noqa: E402, F403
|
||||
|
||||
@@ -1,71 +1,13 @@
|
||||
"""Deprecated: Authentication schemes for A2A protocol agents.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.auth.schemas`` instead."""
|
||||
|
||||
This module is deprecated. Import from crewai.a2a.auth instead:
|
||||
- crewai.a2a.auth.ClientAuthScheme (replaces AuthScheme)
|
||||
- crewai.a2a.auth.BearerTokenAuth
|
||||
- crewai.a2a.auth.HTTPBasicAuth
|
||||
- crewai.a2a.auth.HTTPDigestAuth
|
||||
- crewai.a2a.auth.APIKeyAuth
|
||||
- crewai.a2a.auth.OAuth2ClientCredentials
|
||||
- crewai.a2a.auth.OAuth2AuthorizationCode
|
||||
"""
|
||||
import warnings
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from crewai.a2a.auth.client_schemes import (
|
||||
APIKeyAuth as _APIKeyAuth,
|
||||
BearerTokenAuth as _BearerTokenAuth,
|
||||
ClientAuthScheme as _ClientAuthScheme,
|
||||
HTTPBasicAuth as _HTTPBasicAuth,
|
||||
HTTPDigestAuth as _HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode as _OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials as _OAuth2ClientCredentials,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.auth.schemas' has been moved to 'crewai_a2a.auth.schemas'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
@deprecated("Use ClientAuthScheme from crewai.a2a.auth instead", category=FutureWarning)
|
||||
class AuthScheme(_ClientAuthScheme):
|
||||
"""Deprecated: Use ClientAuthScheme from crewai.a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
|
||||
class BearerTokenAuth(_BearerTokenAuth):
|
||||
"""Deprecated: Import from crewai.a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
|
||||
class HTTPBasicAuth(_HTTPBasicAuth):
|
||||
"""Deprecated: Import from crewai.a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
|
||||
class HTTPDigestAuth(_HTTPDigestAuth):
|
||||
"""Deprecated: Import from crewai.a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
|
||||
class APIKeyAuth(_APIKeyAuth):
|
||||
"""Deprecated: Import from crewai.a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
|
||||
class OAuth2ClientCredentials(_OAuth2ClientCredentials):
|
||||
"""Deprecated: Import from crewai.a2a.auth instead."""
|
||||
|
||||
|
||||
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
|
||||
class OAuth2AuthorizationCode(_OAuth2AuthorizationCode):
|
||||
"""Deprecated: Import from crewai.a2a.auth instead."""
|
||||
|
||||
|
||||
__all__ = [
|
||||
"APIKeyAuth",
|
||||
"AuthScheme",
|
||||
"BearerTokenAuth",
|
||||
"HTTPBasicAuth",
|
||||
"HTTPDigestAuth",
|
||||
"OAuth2AuthorizationCode",
|
||||
"OAuth2ClientCredentials",
|
||||
]
|
||||
from crewai_a2a.auth.schemas import * # noqa: E402, F403
|
||||
|
||||
@@ -1,739 +1,13 @@
|
||||
"""Server-side authentication schemes for A2A protocol.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.auth.server_schemes`` instead."""
|
||||
|
||||
These schemes validate incoming requests to A2A server endpoints.
|
||||
import warnings
|
||||
|
||||
Supported authentication methods:
|
||||
- Simple token validation with static bearer tokens
|
||||
- OpenID Connect with JWT validation using JWKS
|
||||
- OAuth2 with JWT validation or token introspection
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
|
||||
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
Field,
|
||||
HttpUrl,
|
||||
PrivateAttr,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.auth.server_schemes' has been moved to 'crewai_a2a.auth.server_schemes'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import OAuth2SecurityScheme
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from fastapi import HTTPException, status as http_status
|
||||
|
||||
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
except ImportError:
|
||||
|
||||
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
|
||||
"""Fallback HTTPException when FastAPI is not installed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
detail: str | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
self.headers = headers
|
||||
super().__init__(detail)
|
||||
|
||||
HTTP_401_UNAUTHORIZED = 401
|
||||
HTTP_500_INTERNAL_SERVER_ERROR = 500
|
||||
HTTP_503_SERVICE_UNAVAILABLE = 503
|
||||
|
||||
|
||||
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
|
||||
"""Coerce string to SecretStr."""
|
||||
if v is None or isinstance(v, SecretStr):
|
||||
return v
|
||||
return SecretStr(v)
|
||||
|
||||
|
||||
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
|
||||
|
||||
JWTAlgorithm = Literal[
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthenticatedUser:
|
||||
"""Result of successful authentication.
|
||||
|
||||
Attributes:
|
||||
token: The original token that was validated.
|
||||
scheme: Name of the authentication scheme used.
|
||||
claims: JWT claims from OIDC or OAuth2 authentication.
|
||||
"""
|
||||
|
||||
token: str
|
||||
scheme: str
|
||||
claims: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ServerAuthScheme(ABC, BaseModel):
|
||||
"""Base class for server-side authentication schemes.
|
||||
|
||||
Each scheme validates incoming requests and returns an AuthenticatedUser
|
||||
on success, or raises HTTPException on failure.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate the provided token.
|
||||
|
||||
Args:
|
||||
token: The bearer token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SimpleTokenAuth(ServerAuthScheme):
|
||||
"""Simple bearer token authentication.
|
||||
|
||||
Validates tokens against a configured static token or AUTH_TOKEN env var.
|
||||
|
||||
Attributes:
|
||||
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
|
||||
"""
|
||||
|
||||
token: CoercedSecretStr | None = Field(
|
||||
default=None,
|
||||
description="Expected token. Falls back to AUTH_TOKEN env var.",
|
||||
)
|
||||
|
||||
def _get_expected_token(self) -> str | None:
|
||||
"""Get the expected token value."""
|
||||
if self.token:
|
||||
return self.token.get_secret_value()
|
||||
return os.environ.get("AUTH_TOKEN")
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using simple token comparison.
|
||||
|
||||
Args:
|
||||
token: The bearer token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
expected = self._get_expected_token()
|
||||
|
||||
if expected is None:
|
||||
logger.warning(
|
||||
"Simple token authentication failed",
|
||||
extra={"reason": "no_token_configured"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication not configured",
|
||||
)
|
||||
|
||||
if token != expected:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="simple_token",
|
||||
)
|
||||
|
||||
|
||||
class OIDCAuth(ServerAuthScheme):
|
||||
"""OpenID Connect authentication.
|
||||
|
||||
Validates JWTs using JWKS with caching support via PyJWT.
|
||||
|
||||
Attributes:
|
||||
issuer: The OpenID Connect issuer URL.
|
||||
audience: The expected audience claim.
|
||||
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
|
||||
algorithms: List of allowed signing algorithms.
|
||||
required_claims: List of claims that must be present in the token.
|
||||
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||
clock_skew_seconds: Allowed clock skew for token validation.
|
||||
"""
|
||||
|
||||
issuer: HttpUrl = Field(
|
||||
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
|
||||
)
|
||||
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
|
||||
jwks_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="Explicit JWKS URL. Derived from issuer if not set.",
|
||||
)
|
||||
algorithms: list[str] = Field(
|
||||
default_factory=lambda: ["RS256"],
|
||||
description="List of allowed signing algorithms (RS256, ES256, etc.)",
|
||||
)
|
||||
required_claims: list[str] = Field(
|
||||
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
|
||||
description="List of claims that must be present in the token",
|
||||
)
|
||||
jwks_cache_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL for JWKS cache in seconds",
|
||||
ge=60,
|
||||
)
|
||||
clock_skew_seconds: float = Field(
|
||||
default=30.0,
|
||||
description="Allowed clock skew for token validation",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_jwk_client(self) -> Self:
|
||||
"""Initialize the JWK client after model creation."""
|
||||
jwks_url = (
|
||||
str(self.jwks_url)
|
||||
if self.jwks_url
|
||||
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
|
||||
)
|
||||
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
|
||||
return self
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OIDC JWT validation.
|
||||
|
||||
Args:
|
||||
token: The JWT to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if self._jwk_client is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OIDC not initialized",
|
||||
)
|
||||
|
||||
try:
|
||||
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=self.algorithms,
|
||||
audience=self.audience,
|
||||
issuer=str(self.issuer).rstrip("/"),
|
||||
leeway=self.clock_skew_seconds,
|
||||
options={
|
||||
"require": self.required_claims,
|
||||
},
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oidc",
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "token_expired", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
) from None
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_audience", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token audience",
|
||||
) from None
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_issuer", "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token issuer",
|
||||
) from None
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Missing required claim: {e.claim}",
|
||||
) from None
|
||||
except jwt.PyJWKClientError as e:
|
||||
logger.error(
|
||||
"OIDC authentication failed",
|
||||
extra={
|
||||
"reason": "jwks_client_error",
|
||||
"error": str(e),
|
||||
"scheme": "oidc",
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Unable to fetch signing keys",
|
||||
) from None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.debug(
|
||||
"OIDC authentication failed",
|
||||
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
) from None
|
||||
|
||||
|
||||
class OAuth2ServerAuth(ServerAuthScheme):
|
||||
"""OAuth2 authentication for A2A server.
|
||||
|
||||
Declares OAuth2 security scheme in AgentCard and validates tokens using
|
||||
either JWKS for JWT tokens or token introspection for opaque tokens.
|
||||
|
||||
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
|
||||
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
|
||||
|
||||
Attributes:
|
||||
token_url: OAuth2 token endpoint URL for client_credentials flow.
|
||||
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
|
||||
refresh_url: Optional refresh token endpoint URL.
|
||||
scopes: Available OAuth2 scopes with descriptions.
|
||||
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
|
||||
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
|
||||
introspection_client_id: Client ID for introspection endpoint authentication.
|
||||
introspection_client_secret: Client secret for introspection endpoint.
|
||||
audience: Expected audience claim for JWT validation.
|
||||
issuer: Expected issuer claim for JWT validation.
|
||||
algorithms: Allowed JWT signing algorithms.
|
||||
required_claims: Claims that must be present in the token.
|
||||
jwks_cache_ttl: TTL for JWKS cache in seconds.
|
||||
clock_skew_seconds: Allowed clock skew for token validation.
|
||||
"""
|
||||
|
||||
token_url: HttpUrl = Field(
|
||||
description="OAuth2 token endpoint URL",
|
||||
)
|
||||
authorization_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="OAuth2 authorization endpoint URL for authorization_code flow",
|
||||
)
|
||||
refresh_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="OAuth2 refresh token endpoint URL",
|
||||
)
|
||||
scopes: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Available OAuth2 scopes with descriptions",
|
||||
)
|
||||
jwks_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="JWKS URL for JWT validation. Required if not using introspection.",
|
||||
)
|
||||
introspection_url: HttpUrl | None = Field(
|
||||
default=None,
|
||||
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
|
||||
)
|
||||
introspection_client_id: str | None = Field(
|
||||
default=None,
|
||||
description="Client ID for introspection endpoint authentication",
|
||||
)
|
||||
introspection_client_secret: CoercedSecretStr | None = Field(
|
||||
default=None,
|
||||
description="Client secret for introspection endpoint authentication",
|
||||
)
|
||||
audience: str | None = Field(
|
||||
default=None,
|
||||
description="Expected audience claim for JWT validation",
|
||||
)
|
||||
issuer: str | None = Field(
|
||||
default=None,
|
||||
description="Expected issuer claim for JWT validation",
|
||||
)
|
||||
algorithms: list[str] = Field(
|
||||
default_factory=lambda: ["RS256"],
|
||||
description="Allowed JWT signing algorithms",
|
||||
)
|
||||
required_claims: list[str] = Field(
|
||||
default_factory=lambda: ["exp", "iat"],
|
||||
description="Claims that must be present in the token",
|
||||
)
|
||||
jwks_cache_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL for JWKS cache in seconds",
|
||||
ge=60,
|
||||
)
|
||||
clock_skew_seconds: float = Field(
|
||||
default=30.0,
|
||||
description="Allowed clock skew for token validation",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_and_init(self) -> Self:
|
||||
"""Validate configuration and initialize JWKS client if needed."""
|
||||
if not self.jwks_url and not self.introspection_url:
|
||||
raise ValueError(
|
||||
"Either jwks_url or introspection_url must be provided for token validation"
|
||||
)
|
||||
|
||||
if self.introspection_url:
|
||||
if not self.introspection_client_id or not self.introspection_client_secret:
|
||||
raise ValueError(
|
||||
"introspection_client_id and introspection_client_secret are required "
|
||||
"when using token introspection"
|
||||
)
|
||||
|
||||
if self.jwks_url:
|
||||
self._jwk_client = PyJWKClient(
|
||||
str(self.jwks_url), lifespan=self.jwks_cache_ttl
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OAuth2 token validation.
|
||||
|
||||
Uses JWKS validation if jwks_url is configured, otherwise falls back
|
||||
to token introspection.
|
||||
|
||||
Args:
|
||||
token: The OAuth2 access token to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if self._jwk_client:
|
||||
return await self._authenticate_jwt(token)
|
||||
return await self._authenticate_introspection(token)
|
||||
|
||||
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using JWKS JWT validation."""
|
||||
if self._jwk_client is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth2 JWKS not initialized",
|
||||
)
|
||||
|
||||
try:
|
||||
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
|
||||
|
||||
decode_options: dict[str, Any] = {
|
||||
"require": self.required_claims,
|
||||
}
|
||||
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
signing_key.key,
|
||||
algorithms=self.algorithms,
|
||||
audience=self.audience,
|
||||
issuer=self.issuer,
|
||||
leeway=self.clock_skew_seconds,
|
||||
options=decode_options,
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oauth2",
|
||||
claims=claims,
|
||||
)
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "token_expired", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
) from None
|
||||
except jwt.InvalidAudienceError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_audience", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token audience",
|
||||
) from None
|
||||
except jwt.InvalidIssuerError:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token issuer",
|
||||
) from None
|
||||
except jwt.MissingRequiredClaimError as e:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Missing required claim: {e.claim}",
|
||||
) from None
|
||||
except jwt.PyJWKClientError as e:
|
||||
logger.error(
|
||||
"OAuth2 authentication failed",
|
||||
extra={
|
||||
"reason": "jwks_client_error",
|
||||
"error": str(e),
|
||||
"scheme": "oauth2",
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Unable to fetch signing keys",
|
||||
) from None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
) from None
|
||||
|
||||
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
|
||||
import httpx
|
||||
|
||||
if not self.introspection_url:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth2 introspection not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
str(self.introspection_url),
|
||||
data={"token": token},
|
||||
auth=(
|
||||
self.introspection_client_id or "",
|
||||
self.introspection_client_secret.get_secret_value()
|
||||
if self.introspection_client_secret
|
||||
else "",
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
introspection_result = response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
"OAuth2 introspection failed",
|
||||
extra={"reason": "http_error", "status_code": e.response.status_code},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token introspection service unavailable",
|
||||
) from None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"OAuth2 introspection failed",
|
||||
extra={"reason": "unexpected_error", "error": str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Token introspection failed",
|
||||
) from None
|
||||
|
||||
if not introspection_result.get("active", False):
|
||||
logger.debug(
|
||||
"OAuth2 authentication failed",
|
||||
extra={"reason": "token_not_active", "scheme": "oauth2"},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Token is not active",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="oauth2",
|
||||
claims=introspection_result,
|
||||
)
|
||||
|
||||
def to_security_scheme(self) -> OAuth2SecurityScheme:
|
||||
"""Generate OAuth2SecurityScheme for AgentCard declaration.
|
||||
|
||||
Creates an OAuth2SecurityScheme with appropriate flows based on
|
||||
the configured URLs. Includes client_credentials flow if token_url
|
||||
is set, and authorization_code flow if authorization_url is set.
|
||||
|
||||
Returns:
|
||||
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
|
||||
"""
|
||||
from a2a.types import (
|
||||
AuthorizationCodeOAuthFlow,
|
||||
ClientCredentialsOAuthFlow,
|
||||
OAuth2SecurityScheme,
|
||||
OAuthFlows,
|
||||
)
|
||||
|
||||
client_credentials = None
|
||||
authorization_code = None
|
||||
|
||||
if self.token_url:
|
||||
client_credentials = ClientCredentialsOAuthFlow(
|
||||
token_url=str(self.token_url),
|
||||
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
|
||||
if self.authorization_url:
|
||||
authorization_code = AuthorizationCodeOAuthFlow(
|
||||
authorization_url=str(self.authorization_url),
|
||||
token_url=str(self.token_url),
|
||||
refresh_url=str(self.refresh_url) if self.refresh_url else None,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
|
||||
return OAuth2SecurityScheme(
|
||||
flows=OAuthFlows(
|
||||
client_credentials=client_credentials,
|
||||
authorization_code=authorization_code,
|
||||
),
|
||||
description="OAuth2 authentication",
|
||||
)
|
||||
|
||||
|
||||
class APIKeyServerAuth(ServerAuthScheme):
|
||||
"""API Key authentication for A2A server.
|
||||
|
||||
Validates requests using an API key in a header, query parameter, or cookie.
|
||||
|
||||
Attributes:
|
||||
name: The name of the API key parameter (default: X-API-Key).
|
||||
location: Where to look for the API key (header, query, or cookie).
|
||||
api_key: The expected API key value.
|
||||
"""
|
||||
|
||||
name: str = Field(
|
||||
default="X-API-Key",
|
||||
description="Name of the API key parameter",
|
||||
)
|
||||
location: Literal["header", "query", "cookie"] = Field(
|
||||
default="header",
|
||||
description="Where to look for the API key",
|
||||
)
|
||||
api_key: CoercedSecretStr = Field(
|
||||
description="Expected API key value",
|
||||
)
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Authenticate using API key comparison.
|
||||
|
||||
Args:
|
||||
token: The API key to authenticate.
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser on successful authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails.
|
||||
"""
|
||||
if token != self.api_key.get_secret_value():
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key",
|
||||
)
|
||||
|
||||
return AuthenticatedUser(
|
||||
token=token,
|
||||
scheme="api_key",
|
||||
)
|
||||
|
||||
|
||||
class MTLSServerAuth(ServerAuthScheme):
|
||||
"""Mutual TLS authentication marker for AgentCard declaration.
|
||||
|
||||
This scheme is primarily for AgentCard security_schemes declaration.
|
||||
Actual mTLS verification happens at the TLS/transport layer, not
|
||||
at the application layer via token validation.
|
||||
|
||||
When configured, this signals to clients that the server requires
|
||||
client certificates for authentication.
|
||||
"""
|
||||
|
||||
description: str = Field(
|
||||
default="Mutual TLS certificate authentication",
|
||||
description="Description for the security scheme",
|
||||
)
|
||||
|
||||
async def authenticate(self, token: str) -> AuthenticatedUser:
|
||||
"""Return authenticated user for mTLS.
|
||||
|
||||
mTLS verification happens at the transport layer before this is called.
|
||||
If we reach this point, the TLS handshake with client cert succeeded.
|
||||
|
||||
Args:
|
||||
token: Certificate subject or identifier (from TLS layer).
|
||||
|
||||
Returns:
|
||||
AuthenticatedUser indicating mTLS authentication.
|
||||
"""
|
||||
return AuthenticatedUser(
|
||||
token=token or "mtls-verified",
|
||||
scheme="mtls",
|
||||
)
|
||||
from crewai_a2a.auth.server_schemes import * # noqa: E402, F403
|
||||
|
||||
@@ -1,273 +1,13 @@
|
||||
"""Authentication utilities for A2A protocol agent communication.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.auth.utils`` instead."""
|
||||
|
||||
Provides validation and retry logic for various authentication schemes including
|
||||
OAuth2, API keys, and HTTP authentication methods.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable, MutableMapping
|
||||
import hashlib
|
||||
import re
|
||||
import threading
|
||||
from typing import Final, Literal, cast
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
APIKeySecurityScheme,
|
||||
AgentCard,
|
||||
HTTPAuthSecurityScheme,
|
||||
OAuth2SecurityScheme,
|
||||
)
|
||||
from httpx import AsyncClient, Response
|
||||
|
||||
from crewai.a2a.auth.client_schemes import (
|
||||
APIKeyAuth,
|
||||
BearerTokenAuth,
|
||||
ClientAuthScheme,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.auth.utils' has been moved to 'crewai_a2a.auth.utils'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class _AuthStore:
|
||||
"""Store for authentication schemes with safe concurrent access."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[str, ClientAuthScheme | None] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@staticmethod
|
||||
def compute_key(auth_type: str, auth_data: str) -> str:
|
||||
"""Compute a collision-resistant key using SHA-256."""
|
||||
content = f"{auth_type}:{auth_data}"
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
|
||||
"""Store an auth scheme."""
|
||||
with self._lock:
|
||||
self._store[key] = auth
|
||||
|
||||
def get(self, key: str) -> ClientAuthScheme | None:
|
||||
"""Retrieve an auth scheme by key."""
|
||||
with self._lock:
|
||||
return self._store.get(key)
|
||||
|
||||
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
|
||||
with self._lock:
|
||||
self._store[key] = value
|
||||
|
||||
def __getitem__(self, key: str) -> ClientAuthScheme | None:
|
||||
with self._lock:
|
||||
return self._store[key]
|
||||
|
||||
|
||||
_auth_store = _AuthStore()
|
||||
|
||||
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
|
||||
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
|
||||
|
||||
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
|
||||
OAuth2SecurityScheme: (
|
||||
OAuth2ClientCredentials,
|
||||
OAuth2AuthorizationCode,
|
||||
BearerTokenAuth,
|
||||
),
|
||||
APIKeySecurityScheme: (APIKeyAuth,),
|
||||
}
|
||||
|
||||
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
|
||||
|
||||
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
|
||||
"basic": HTTPBasicAuth,
|
||||
"digest": HTTPDigestAuth,
|
||||
"bearer": BearerTokenAuth,
|
||||
}
|
||||
|
||||
|
||||
def _raise_auth_mismatch(
|
||||
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
|
||||
provided_auth: ClientAuthScheme,
|
||||
) -> None:
|
||||
"""Raise authentication mismatch error.
|
||||
|
||||
Args:
|
||||
expected_classes: Expected authentication class or tuple of classes.
|
||||
provided_auth: Actually provided authentication instance.
|
||||
|
||||
Raises:
|
||||
A2AClientHTTPError: Always raises with 401 status code.
|
||||
"""
|
||||
if isinstance(expected_classes, tuple):
|
||||
if len(expected_classes) == 1:
|
||||
required = expected_classes[0].__name__
|
||||
else:
|
||||
names = [cls.__name__ for cls in expected_classes]
|
||||
required = f"one of ({', '.join(names)})"
|
||||
else:
|
||||
required = expected_classes.__name__
|
||||
|
||||
msg = (
|
||||
f"AgentCard requires {required} authentication, "
|
||||
f"but {type(provided_auth).__name__} was provided"
|
||||
)
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
|
||||
|
||||
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
|
||||
"""Parse WWW-Authenticate header into auth challenges.
|
||||
|
||||
Args:
|
||||
header_value: The WWW-Authenticate header value.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping auth scheme to its parameters.
|
||||
Example: {"Bearer": {"realm": "api", "scope": "read write"}}
|
||||
"""
|
||||
if not header_value:
|
||||
return {}
|
||||
|
||||
challenges: dict[str, dict[str, str]] = {}
|
||||
|
||||
for match in _SCHEME_PATTERN.finditer(header_value):
|
||||
scheme = match.group(1)
|
||||
params_str = match.group(2)
|
||||
|
||||
params: dict[str, str] = {}
|
||||
|
||||
for param_match in _PARAM_PATTERN.finditer(params_str):
|
||||
key = param_match.group(1)
|
||||
value = param_match.group(2) or param_match.group(3)
|
||||
params[key] = value
|
||||
|
||||
challenges[scheme] = params
|
||||
|
||||
return challenges
|
||||
|
||||
|
||||
def validate_auth_against_agent_card(
|
||||
agent_card: AgentCard, auth: ClientAuthScheme | None
|
||||
) -> None:
|
||||
"""Validate that provided auth matches AgentCard security requirements.
|
||||
|
||||
Args:
|
||||
agent_card: The A2A AgentCard containing security requirements.
|
||||
auth: User-provided authentication scheme (or None).
|
||||
|
||||
Raises:
|
||||
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
|
||||
"""
|
||||
|
||||
if not agent_card.security or not agent_card.security_schemes:
|
||||
return
|
||||
|
||||
if not auth:
|
||||
msg = "AgentCard requires authentication but no auth scheme provided"
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
|
||||
first_security_req = agent_card.security[0] if agent_card.security else {}
|
||||
|
||||
for scheme_name in first_security_req.keys():
|
||||
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
|
||||
if not security_scheme_wrapper:
|
||||
continue
|
||||
|
||||
scheme = security_scheme_wrapper.root
|
||||
|
||||
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
|
||||
if not isinstance(auth, allowed_classes):
|
||||
_raise_auth_mismatch(allowed_classes, auth)
|
||||
return
|
||||
|
||||
if isinstance(scheme, HTTPAuthSecurityScheme):
|
||||
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
|
||||
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
|
||||
if not isinstance(auth, required_class):
|
||||
_raise_auth_mismatch(required_class, auth)
|
||||
return
|
||||
|
||||
msg = "Could not validate auth against AgentCard security requirements"
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
|
||||
|
||||
async def retry_on_401(
|
||||
request_func: Callable[[], Awaitable[Response]],
|
||||
auth_scheme: ClientAuthScheme | None,
|
||||
client: AsyncClient,
|
||||
headers: MutableMapping[str, str],
|
||||
max_retries: int = 3,
|
||||
) -> Response:
|
||||
"""Retry a request on 401 authentication error.
|
||||
|
||||
Handles 401 errors by:
|
||||
1. Parsing WWW-Authenticate header
|
||||
2. Re-acquiring credentials
|
||||
3. Retrying the request
|
||||
|
||||
Args:
|
||||
request_func: Async function that makes the HTTP request.
|
||||
auth_scheme: Authentication scheme to refresh credentials with.
|
||||
client: HTTP client for making requests.
|
||||
headers: Request headers to update with new auth.
|
||||
max_retries: Maximum number of retry attempts (default: 3).
|
||||
|
||||
Returns:
|
||||
HTTP response from the request.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If retries are exhausted or auth scheme is None.
|
||||
"""
|
||||
last_response: Response | None = None
|
||||
last_challenges: dict[str, dict[str, str]] = {}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
response = await request_func()
|
||||
|
||||
if response.status_code != 401:
|
||||
return response
|
||||
|
||||
last_response = response
|
||||
|
||||
if auth_scheme is None:
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
www_authenticate = response.headers.get("WWW-Authenticate", "")
|
||||
challenges = parse_www_authenticate(www_authenticate)
|
||||
last_challenges = challenges
|
||||
|
||||
if attempt >= max_retries - 1:
|
||||
break
|
||||
|
||||
backoff_time = 2**attempt
|
||||
await asyncio.sleep(backoff_time)
|
||||
|
||||
await auth_scheme.apply_auth(client, headers)
|
||||
|
||||
if last_response:
|
||||
last_response.raise_for_status()
|
||||
return last_response
|
||||
|
||||
msg = "retry_on_401 failed without making any requests"
|
||||
if last_challenges:
|
||||
challenge_info = ", ".join(
|
||||
f"{scheme} (realm={params.get('realm', 'N/A')})"
|
||||
for scheme, params in last_challenges.items()
|
||||
)
|
||||
msg = f"{msg}. Server challenges: {challenge_info}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
def configure_auth_client(
|
||||
auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient
|
||||
) -> None:
|
||||
"""Configure HTTP client with auth-specific settings.
|
||||
|
||||
Only HTTPDigestAuth and APIKeyAuth need client configuration.
|
||||
|
||||
Args:
|
||||
auth: Authentication scheme that requires client configuration.
|
||||
client: HTTP client to configure.
|
||||
"""
|
||||
auth.configure_client(client)
|
||||
from crewai_a2a.auth.utils import * # noqa: E402, F403
|
||||
|
||||
@@ -1,690 +1,13 @@
|
||||
"""A2A configuration types.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.config`` instead."""
|
||||
|
||||
This module is separate from experimental.a2a to avoid circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
import warnings
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
FilePath,
|
||||
PrivateAttr,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
|
||||
warnings.warn(
|
||||
"'crewai.a2a.config' has been moved to 'crewai_a2a.config'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
from crewai.a2a.auth.client_schemes import ClientAuthScheme
|
||||
from crewai.a2a.auth.server_schemes import ServerAuthScheme
|
||||
from crewai.a2a.extensions.base import ValidatedA2AExtension
|
||||
from crewai.a2a.types import ProtocolVersion, TransportType, Url
|
||||
|
||||
|
||||
try:
|
||||
from a2a.types import (
|
||||
AgentCapabilities,
|
||||
AgentCardSignature,
|
||||
AgentInterface,
|
||||
AgentProvider,
|
||||
AgentSkill,
|
||||
SecurityScheme,
|
||||
)
|
||||
|
||||
from crewai.a2a.extensions.server import ServerExtension
|
||||
from crewai.a2a.updates import UpdateConfig
|
||||
except ImportError:
|
||||
UpdateConfig: Any = Any # type: ignore[no-redef]
|
||||
AgentCapabilities: Any = Any # type: ignore[no-redef]
|
||||
AgentCardSignature: Any = Any # type: ignore[no-redef]
|
||||
AgentInterface: Any = Any # type: ignore[no-redef]
|
||||
AgentProvider: Any = Any # type: ignore[no-redef]
|
||||
SecurityScheme: Any = Any # type: ignore[no-redef]
|
||||
AgentSkill: Any = Any # type: ignore[no-redef]
|
||||
ServerExtension: Any = Any # type: ignore[no-redef]
|
||||
|
||||
|
||||
def _get_default_update_config() -> UpdateConfig:
|
||||
from crewai.a2a.updates import StreamingConfig
|
||||
|
||||
return StreamingConfig()
|
||||
|
||||
|
||||
SigningAlgorithm = Literal[
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
|
||||
class AgentCardSigningConfig(BaseModel):
|
||||
"""Configuration for AgentCard JWS signing.
|
||||
|
||||
Provides the private key and algorithm settings for signing AgentCards.
|
||||
Either private_key_path or private_key_pem must be provided, but not both.
|
||||
|
||||
Attributes:
|
||||
private_key_path: Path to a PEM-encoded private key file.
|
||||
private_key_pem: PEM-encoded private key as a secret string.
|
||||
key_id: Optional key identifier for the JWS header (kid claim).
|
||||
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
private_key_path: FilePath | None = Field(
|
||||
default=None,
|
||||
description="Path to PEM-encoded private key file",
|
||||
)
|
||||
private_key_pem: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="PEM-encoded private key",
|
||||
)
|
||||
key_id: str | None = Field(
|
||||
default=None,
|
||||
description="Key identifier for JWS header (kid claim)",
|
||||
)
|
||||
algorithm: SigningAlgorithm = Field(
|
||||
default="RS256",
|
||||
description="Signing algorithm (RS256, ES256, PS256, etc.)",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_key_source(self) -> Self:
|
||||
"""Ensure exactly one key source is provided."""
|
||||
has_path = self.private_key_path is not None
|
||||
has_pem = self.private_key_pem is not None
|
||||
|
||||
if not has_path and not has_pem:
|
||||
raise ValueError(
|
||||
"Either private_key_path or private_key_pem must be provided"
|
||||
)
|
||||
if has_path and has_pem:
|
||||
raise ValueError(
|
||||
"Only one of private_key_path or private_key_pem should be provided"
|
||||
)
|
||||
return self
|
||||
|
||||
def get_private_key(self) -> str:
|
||||
"""Get the private key content.
|
||||
|
||||
Returns:
|
||||
The PEM-encoded private key as a string.
|
||||
"""
|
||||
if self.private_key_pem:
|
||||
return self.private_key_pem.get_secret_value()
|
||||
if self.private_key_path:
|
||||
return Path(self.private_key_path).read_text()
|
||||
raise ValueError("No private key configured")
|
||||
|
||||
|
||||
class GRPCServerConfig(BaseModel):
|
||||
"""gRPC server transport configuration.
|
||||
|
||||
Presence of this config in ServerTransportConfig.grpc enables gRPC transport.
|
||||
|
||||
Attributes:
|
||||
host: Hostname to advertise in agent cards (default: localhost).
|
||||
Use docker service name (e.g., 'web') for docker-compose setups.
|
||||
port: Port for the gRPC server.
|
||||
tls_cert_path: Path to TLS certificate file for gRPC.
|
||||
tls_key_path: Path to TLS private key file for gRPC.
|
||||
max_workers: Maximum number of workers for the gRPC thread pool.
|
||||
reflection_enabled: Whether to enable gRPC reflection for debugging.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
host: str = Field(
|
||||
default="localhost",
|
||||
description="Hostname to advertise in agent cards for gRPC connections",
|
||||
)
|
||||
port: int = Field(
|
||||
default=50051,
|
||||
description="Port for the gRPC server",
|
||||
)
|
||||
tls_cert_path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS certificate file for gRPC",
|
||||
)
|
||||
tls_key_path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS private key file for gRPC",
|
||||
)
|
||||
max_workers: int = Field(
|
||||
default=10,
|
||||
description="Maximum number of workers for the gRPC thread pool",
|
||||
)
|
||||
reflection_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable gRPC reflection for debugging",
|
||||
)
|
||||
|
||||
|
||||
class GRPCClientConfig(BaseModel):
|
||||
"""gRPC client transport configuration.
|
||||
|
||||
Attributes:
|
||||
max_send_message_length: Maximum size for outgoing messages in bytes.
|
||||
max_receive_message_length: Maximum size for incoming messages in bytes.
|
||||
keepalive_time_ms: Time between keepalive pings in milliseconds.
|
||||
keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
max_send_message_length: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum size for outgoing messages in bytes",
|
||||
)
|
||||
max_receive_message_length: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum size for incoming messages in bytes",
|
||||
)
|
||||
keepalive_time_ms: int | None = Field(
|
||||
default=None,
|
||||
description="Time between keepalive pings in milliseconds",
|
||||
)
|
||||
keepalive_timeout_ms: int | None = Field(
|
||||
default=None,
|
||||
description="Timeout for keepalive ping response in milliseconds",
|
||||
)
|
||||
|
||||
|
||||
class JSONRPCServerConfig(BaseModel):
|
||||
"""JSON-RPC server transport configuration.
|
||||
|
||||
Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport.
|
||||
|
||||
Attributes:
|
||||
rpc_path: URL path for the JSON-RPC endpoint.
|
||||
agent_card_path: URL path for the agent card endpoint.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
rpc_path: str = Field(
|
||||
default="/a2a",
|
||||
description="URL path for the JSON-RPC endpoint",
|
||||
)
|
||||
agent_card_path: str = Field(
|
||||
default="/.well-known/agent-card.json",
|
||||
description="URL path for the agent card endpoint",
|
||||
)
|
||||
|
||||
|
||||
class JSONRPCClientConfig(BaseModel):
|
||||
"""JSON-RPC client transport configuration.
|
||||
|
||||
Attributes:
|
||||
max_request_size: Maximum request body size in bytes.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
max_request_size: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum request body size in bytes",
|
||||
)
|
||||
|
||||
|
||||
class HTTPJSONConfig(BaseModel):
|
||||
"""HTTP+JSON transport configuration.
|
||||
|
||||
Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ServerPushNotificationConfig(BaseModel):
|
||||
"""Configuration for outgoing webhook push notifications.
|
||||
|
||||
Controls how the server signs and delivers push notifications to clients.
|
||||
|
||||
Attributes:
|
||||
signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
signature_secret: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Shared secret for HMAC-SHA256 signing of outgoing push notifications",
|
||||
)
|
||||
|
||||
|
||||
class ServerTransportConfig(BaseModel):
|
||||
"""Transport configuration for A2A server.
|
||||
|
||||
Groups all transport-related settings including preferred transport
|
||||
and protocol-specific configurations.
|
||||
|
||||
Attributes:
|
||||
preferred: Transport protocol for the preferred endpoint.
|
||||
jsonrpc: JSON-RPC server transport configuration.
|
||||
grpc: gRPC server transport configuration.
|
||||
http_json: HTTP+JSON transport configuration.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
preferred: TransportType = Field(
|
||||
default="JSONRPC",
|
||||
description="Transport protocol for the preferred endpoint",
|
||||
)
|
||||
jsonrpc: JSONRPCServerConfig = Field(
|
||||
default_factory=JSONRPCServerConfig,
|
||||
description="JSON-RPC server transport configuration",
|
||||
)
|
||||
grpc: GRPCServerConfig | None = Field(
|
||||
default=None,
|
||||
description="gRPC server transport configuration",
|
||||
)
|
||||
http_json: HTTPJSONConfig | None = Field(
|
||||
default=None,
|
||||
description="HTTP+JSON transport configuration",
|
||||
)
|
||||
|
||||
|
||||
def _migrate_client_transport_fields(
|
||||
transport: ClientTransportConfig,
|
||||
transport_protocol: TransportType | None,
|
||||
supported_transports: list[TransportType] | None,
|
||||
) -> None:
|
||||
"""Migrate deprecated transport fields to new config."""
|
||||
if transport_protocol is not None:
|
||||
warnings.warn(
|
||||
"transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead",
|
||||
FutureWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
object.__setattr__(transport, "preferred", transport_protocol)
|
||||
if supported_transports is not None:
|
||||
warnings.warn(
|
||||
"supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead",
|
||||
FutureWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
object.__setattr__(transport, "supported", supported_transports)
|
||||
|
||||
|
||||
class ClientTransportConfig(BaseModel):
|
||||
"""Transport configuration for A2A client.
|
||||
|
||||
Groups all client transport-related settings including preferred transport,
|
||||
supported transports for negotiation, and protocol-specific configurations.
|
||||
|
||||
Transport negotiation logic:
|
||||
1. If `preferred` is set and server supports it → use client's preferred
|
||||
2. Otherwise, if server's preferred is in client's `supported` → use server's preferred
|
||||
3. Otherwise, find first match from client's `supported` in server's interfaces
|
||||
|
||||
Attributes:
|
||||
preferred: Client's preferred transport. If set, client preference takes priority.
|
||||
supported: Transports the client can use, in order of preference.
|
||||
jsonrpc: JSON-RPC client transport configuration.
|
||||
grpc: gRPC client transport configuration.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
preferred: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Client's preferred transport. If set, takes priority over server preference.",
|
||||
)
|
||||
supported: list[TransportType] = Field(
|
||||
default_factory=lambda: cast(list[TransportType], ["JSONRPC"]),
|
||||
description="Transports the client can use, in order of preference",
|
||||
)
|
||||
jsonrpc: JSONRPCClientConfig = Field(
|
||||
default_factory=JSONRPCClientConfig,
|
||||
description="JSON-RPC client transport configuration",
|
||||
)
|
||||
grpc: GRPCClientConfig = Field(
|
||||
default_factory=GRPCClientConfig,
|
||||
description="gRPC client transport configuration",
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"""
|
||||
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
|
||||
use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead.
|
||||
""",
|
||||
category=FutureWarning,
|
||||
)
|
||||
class A2AConfig(BaseModel):
|
||||
"""Configuration for A2A protocol integration.
|
||||
|
||||
Deprecated:
|
||||
Use A2AClientConfig instead. This class will be removed in a future version.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Authentication scheme.
|
||||
timeout: Request timeout in seconds.
|
||||
max_turns: Maximum conversation turns with A2A agent.
|
||||
response_model: Optional Pydantic model for structured A2A agent responses.
|
||||
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
|
||||
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
|
||||
updates: Update mechanism config.
|
||||
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
endpoint: Url = Field(description="A2A agent endpoint URL")
|
||||
auth: ClientAuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme",
|
||||
)
|
||||
timeout: int = Field(default=120, description="Request timeout in seconds")
|
||||
max_turns: int = Field(
|
||||
default=10, description="Maximum conversation turns with A2A agent"
|
||||
)
|
||||
response_model: type[BaseModel] | None = Field(
|
||||
default=None,
|
||||
description="Optional Pydantic model for structured A2A agent responses",
|
||||
)
|
||||
fail_fast: bool = Field(
|
||||
default=True,
|
||||
description="If True, raise error when agent unreachable; if False, skip",
|
||||
)
|
||||
trust_remote_completion_status: bool = Field(
|
||||
default=False,
|
||||
description="If True, return A2A result directly when completed",
|
||||
)
|
||||
updates: UpdateConfig = Field(
|
||||
default_factory=_get_default_update_config,
|
||||
description="Update mechanism config",
|
||||
)
|
||||
client_extensions: list[ValidatedA2AExtension] = Field(
|
||||
default_factory=list,
|
||||
description="Client-side processing hooks for tool injection and prompt augmentation",
|
||||
)
|
||||
transport: ClientTransportConfig = Field(
|
||||
default_factory=ClientTransportConfig,
|
||||
description="Transport configuration (preferred, supported transports, gRPC settings)",
|
||||
)
|
||||
transport_protocol: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.preferred instead",
|
||||
exclude=True,
|
||||
)
|
||||
supported_transports: list[TransportType] | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.supported instead",
|
||||
exclude=True,
|
||||
)
|
||||
use_client_preference: bool | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Set transport.preferred to enable client preference",
|
||||
exclude=True,
|
||||
)
|
||||
_parallel_delegation: bool = PrivateAttr(default=False)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _migrate_deprecated_transport_fields(self) -> Self:
|
||||
"""Migrate deprecated transport fields to new config."""
|
||||
_migrate_client_transport_fields(
|
||||
self.transport, self.transport_protocol, self.supported_transports
|
||||
)
|
||||
if self.use_client_preference is not None:
|
||||
warnings.warn(
|
||||
"use_client_preference is deprecated, set transport.preferred to enable client preference",
|
||||
FutureWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
if self.use_client_preference and self.transport.supported:
|
||||
object.__setattr__(
|
||||
self.transport, "preferred", self.transport.supported[0]
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class A2AClientConfig(BaseModel):
|
||||
"""Configuration for connecting to remote A2A agents.
|
||||
|
||||
Attributes:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Authentication scheme.
|
||||
timeout: Request timeout in seconds.
|
||||
max_turns: Maximum conversation turns with A2A agent.
|
||||
response_model: Optional Pydantic model for structured A2A agent responses.
|
||||
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
|
||||
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
|
||||
updates: Update mechanism config.
|
||||
accepted_output_modes: Media types the client can accept in responses.
|
||||
extensions: Extension URIs the client supports (A2A protocol extensions).
|
||||
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
endpoint: Url = Field(description="A2A agent endpoint URL")
|
||||
auth: ClientAuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme",
|
||||
)
|
||||
timeout: int = Field(default=120, description="Request timeout in seconds")
|
||||
max_turns: int = Field(
|
||||
default=10, description="Maximum conversation turns with A2A agent"
|
||||
)
|
||||
response_model: type[BaseModel] | None = Field(
|
||||
default=None,
|
||||
description="Optional Pydantic model for structured A2A agent responses",
|
||||
)
|
||||
fail_fast: bool = Field(
|
||||
default=True,
|
||||
description="If True, raise error when agent unreachable; if False, skip",
|
||||
)
|
||||
trust_remote_completion_status: bool = Field(
|
||||
default=False,
|
||||
description="If True, return A2A result directly when completed",
|
||||
)
|
||||
updates: UpdateConfig = Field(
|
||||
default_factory=_get_default_update_config,
|
||||
description="Update mechanism config",
|
||||
)
|
||||
accepted_output_modes: list[str] = Field(
|
||||
default_factory=lambda: ["application/json"],
|
||||
description="Media types the client can accept in responses",
|
||||
)
|
||||
extensions: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Extension URIs the client supports",
|
||||
)
|
||||
client_extensions: list[ValidatedA2AExtension] = Field(
|
||||
default_factory=list,
|
||||
description="Client-side processing hooks for tool injection and prompt augmentation",
|
||||
)
|
||||
transport: ClientTransportConfig = Field(
|
||||
default_factory=ClientTransportConfig,
|
||||
description="Transport configuration (preferred, supported transports, gRPC settings)",
|
||||
)
|
||||
transport_protocol: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.preferred instead",
|
||||
exclude=True,
|
||||
)
|
||||
supported_transports: list[TransportType] | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.supported instead",
|
||||
exclude=True,
|
||||
)
|
||||
_parallel_delegation: bool = PrivateAttr(default=False)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _migrate_deprecated_transport_fields(self) -> Self:
|
||||
"""Migrate deprecated transport fields to new config."""
|
||||
_migrate_client_transport_fields(
|
||||
self.transport, self.transport_protocol, self.supported_transports
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class A2AServerConfig(BaseModel):
|
||||
"""Configuration for exposing a Crew or Agent as an A2A server.
|
||||
|
||||
All fields correspond to A2A AgentCard fields. Fields like name, description,
|
||||
and skills can be auto-derived from the Crew/Agent if not provided.
|
||||
|
||||
Attributes:
|
||||
name: Human-readable name for the agent.
|
||||
description: Human-readable description of the agent.
|
||||
version: Version string for the agent card.
|
||||
skills: List of agent skills/capabilities.
|
||||
default_input_modes: Default supported input MIME types.
|
||||
default_output_modes: Default supported output MIME types.
|
||||
capabilities: Declaration of optional capabilities.
|
||||
protocol_version: A2A protocol version this agent supports.
|
||||
provider: Information about the agent's service provider.
|
||||
documentation_url: URL to the agent's documentation.
|
||||
icon_url: URL to an icon for the agent.
|
||||
additional_interfaces: Additional supported interfaces.
|
||||
security: Security requirement objects for all interactions.
|
||||
security_schemes: Security schemes available to authorize requests.
|
||||
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
|
||||
url: Preferred endpoint URL for the agent.
|
||||
signing_config: Configuration for signing the AgentCard with JWS.
|
||||
signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
|
||||
server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
|
||||
push_notifications: Configuration for outgoing push notifications.
|
||||
transport: Transport configuration (preferred transport, gRPC, REST settings).
|
||||
auth: Authentication scheme for A2A endpoints.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.",
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.",
|
||||
)
|
||||
version: str = Field(
|
||||
default="1.0.0",
|
||||
description="Version string for the agent card",
|
||||
)
|
||||
skills: list[AgentSkill] = Field(
|
||||
default_factory=list,
|
||||
description="List of agent skills. Auto-derived from tasks/tools if not provided.",
|
||||
)
|
||||
default_input_modes: list[str] = Field(
|
||||
default_factory=lambda: ["text/plain", "application/json"],
|
||||
description="Default supported input MIME types",
|
||||
)
|
||||
default_output_modes: list[str] = Field(
|
||||
default_factory=lambda: ["text/plain", "application/json"],
|
||||
description="Default supported output MIME types",
|
||||
)
|
||||
capabilities: AgentCapabilities = Field(
|
||||
default_factory=lambda: AgentCapabilities(
|
||||
streaming=True,
|
||||
push_notifications=False,
|
||||
),
|
||||
description="Declaration of optional capabilities supported by the agent",
|
||||
)
|
||||
protocol_version: ProtocolVersion = Field(
|
||||
default="0.3.0",
|
||||
description="A2A protocol version this agent supports",
|
||||
)
|
||||
provider: AgentProvider | None = Field(
|
||||
default=None,
|
||||
description="Information about the agent's service provider",
|
||||
)
|
||||
documentation_url: Url | None = Field(
|
||||
default=None,
|
||||
description="URL to the agent's documentation",
|
||||
)
|
||||
icon_url: Url | None = Field(
|
||||
default=None,
|
||||
description="URL to an icon for the agent",
|
||||
)
|
||||
additional_interfaces: list[AgentInterface] = Field(
|
||||
default_factory=list,
|
||||
description="Additional supported interfaces.",
|
||||
)
|
||||
security: list[dict[str, list[str]]] = Field(
|
||||
default_factory=list,
|
||||
description="Security requirement objects for all agent interactions",
|
||||
)
|
||||
security_schemes: dict[str, SecurityScheme] = Field(
|
||||
default_factory=dict,
|
||||
description="Security schemes available to authorize requests",
|
||||
)
|
||||
supports_authenticated_extended_card: bool = Field(
|
||||
default=False,
|
||||
description="Whether agent provides extended card to authenticated users",
|
||||
)
|
||||
url: Url | None = Field(
|
||||
default=None,
|
||||
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
|
||||
)
|
||||
signing_config: AgentCardSigningConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for signing the AgentCard with JWS",
|
||||
)
|
||||
signatures: list[AgentCardSignature] | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
|
||||
exclude=True,
|
||||
deprecated=True,
|
||||
)
|
||||
server_extensions: list[ServerExtension] = Field(
|
||||
default_factory=list,
|
||||
description="Server-side A2A protocol extensions that modify agent behavior",
|
||||
)
|
||||
push_notifications: ServerPushNotificationConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for outgoing push notifications",
|
||||
)
|
||||
transport: ServerTransportConfig = Field(
|
||||
default_factory=ServerTransportConfig,
|
||||
description="Transport configuration (preferred transport, gRPC, REST settings)",
|
||||
)
|
||||
preferred_transport: TransportType | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use transport.preferred instead",
|
||||
exclude=True,
|
||||
deprecated=True,
|
||||
)
|
||||
auth: ServerAuthScheme | None = Field(
|
||||
default=None,
|
||||
description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _migrate_deprecated_fields(self) -> Self:
|
||||
"""Migrate deprecated fields to new config."""
|
||||
if self.preferred_transport is not None:
|
||||
warnings.warn(
|
||||
"preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead",
|
||||
FutureWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
object.__setattr__(self.transport, "preferred", self.preferred_transport)
|
||||
if self.signatures is not None:
|
||||
warnings.warn(
|
||||
"signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
|
||||
"The signatures field will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
return self
|
||||
from crewai_a2a.config import * # noqa: E402, F403
|
||||
|
||||
@@ -1,491 +1,13 @@
|
||||
"""A2A error codes and error response utilities.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.errors`` instead."""
|
||||
|
||||
This module provides a centralized mapping of all A2A protocol error codes
|
||||
as defined in the A2A specification, plus custom CrewAI extensions.
|
||||
import warnings
|
||||
|
||||
Error codes follow JSON-RPC 2.0 conventions:
|
||||
- -32700 to -32600: Standard JSON-RPC errors
|
||||
- -32099 to -32000: Server errors (A2A-specific)
|
||||
- -32768 to -32100: Reserved for implementation-defined errors
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
warnings.warn(
|
||||
"'crewai.a2a.errors' has been moved to 'crewai_a2a.errors'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
from a2a.client.errors import A2AClientTimeoutError
|
||||
|
||||
|
||||
class A2APollingTimeoutError(A2AClientTimeoutError):
|
||||
"""Raised when polling exceeds the configured timeout."""
|
||||
|
||||
|
||||
class A2AErrorCode(IntEnum):
|
||||
"""A2A protocol error codes.
|
||||
|
||||
Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
|
||||
"""
|
||||
|
||||
# JSON-RPC 2.0 Standard Errors (-32700 to -32600)
|
||||
JSON_PARSE_ERROR = -32700
|
||||
"""Invalid JSON was received by the server."""
|
||||
|
||||
INVALID_REQUEST = -32600
|
||||
"""The JSON sent is not a valid Request object."""
|
||||
|
||||
METHOD_NOT_FOUND = -32601
|
||||
"""The method does not exist / is not available."""
|
||||
|
||||
INVALID_PARAMS = -32602
|
||||
"""Invalid method parameter(s)."""
|
||||
|
||||
INTERNAL_ERROR = -32603
|
||||
"""Internal JSON-RPC error."""
|
||||
|
||||
# A2A-Specific Errors (-32099 to -32000)
|
||||
TASK_NOT_FOUND = -32001
|
||||
"""The specified task was not found."""
|
||||
|
||||
TASK_NOT_CANCELABLE = -32002
|
||||
"""The task cannot be canceled (already completed/failed)."""
|
||||
|
||||
PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
|
||||
"""Push notifications are not supported by this agent."""
|
||||
|
||||
UNSUPPORTED_OPERATION = -32004
|
||||
"""The requested operation is not supported."""
|
||||
|
||||
CONTENT_TYPE_NOT_SUPPORTED = -32005
|
||||
"""Incompatible content types between client and server."""
|
||||
|
||||
INVALID_AGENT_RESPONSE = -32006
|
||||
"""The agent produced an invalid response."""
|
||||
|
||||
# CrewAI Custom Extensions (-32768 to -32100)
|
||||
UNSUPPORTED_VERSION = -32009
|
||||
"""The requested A2A protocol version is not supported."""
|
||||
|
||||
UNSUPPORTED_EXTENSION = -32010
|
||||
"""Client does not support required protocol extensions."""
|
||||
|
||||
AUTHENTICATION_REQUIRED = -32011
|
||||
"""Authentication is required for this operation."""
|
||||
|
||||
AUTHORIZATION_FAILED = -32012
|
||||
"""Authorization check failed (insufficient permissions)."""
|
||||
|
||||
RATE_LIMIT_EXCEEDED = -32013
|
||||
"""Rate limit exceeded for this client/operation."""
|
||||
|
||||
TASK_TIMEOUT = -32014
|
||||
"""Task execution timed out."""
|
||||
|
||||
TRANSPORT_NEGOTIATION_FAILED = -32015
|
||||
"""Failed to negotiate a compatible transport protocol."""
|
||||
|
||||
CONTEXT_NOT_FOUND = -32016
|
||||
"""The specified context was not found."""
|
||||
|
||||
SKILL_NOT_FOUND = -32017
|
||||
"""The specified skill was not found."""
|
||||
|
||||
ARTIFACT_NOT_FOUND = -32018
|
||||
"""The specified artifact was not found."""
|
||||
|
||||
|
||||
# Error code to default message mapping
|
||||
ERROR_MESSAGES: dict[int, str] = {
|
||||
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
|
||||
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
|
||||
A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
|
||||
A2AErrorCode.INVALID_PARAMS: "Invalid params",
|
||||
A2AErrorCode.INTERNAL_ERROR: "Internal error",
|
||||
A2AErrorCode.TASK_NOT_FOUND: "Task not found",
|
||||
A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
|
||||
A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
|
||||
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
|
||||
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
|
||||
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
|
||||
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
|
||||
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
|
||||
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
|
||||
A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
|
||||
A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
|
||||
A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
|
||||
A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
|
||||
A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
|
||||
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
|
||||
A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class A2AError(Exception):
|
||||
"""Base exception for A2A protocol errors.
|
||||
|
||||
Attributes:
|
||||
code: The A2A/JSON-RPC error code.
|
||||
message: Human-readable error message.
|
||||
data: Optional additional error data.
|
||||
"""
|
||||
|
||||
code: int
|
||||
message: str | None = None
|
||||
data: Any = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
|
||||
super().__init__(self.message)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to JSON-RPC error object format."""
|
||||
error: dict[str, Any] = {
|
||||
"code": self.code,
|
||||
"message": self.message,
|
||||
}
|
||||
if self.data is not None:
|
||||
error["data"] = self.data
|
||||
return error
|
||||
|
||||
def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
|
||||
"""Convert to full JSON-RPC error response."""
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"error": self.to_dict(),
|
||||
"id": request_id,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONParseError(A2AError):
|
||||
"""Invalid JSON was received."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidRequestError(A2AError):
|
||||
"""The JSON sent is not a valid Request object."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MethodNotFoundError(A2AError):
|
||||
"""The method does not exist / is not available."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
|
||||
method: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.method:
|
||||
self.message = f"Method not found: {self.method}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidParamsError(A2AError):
|
||||
"""Invalid method parameter(s)."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
|
||||
param: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
if self.param and self.reason:
|
||||
self.message = f"Invalid parameter '{self.param}': {self.reason}"
|
||||
elif self.param:
|
||||
self.message = f"Invalid parameter: {self.param}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InternalError(A2AError):
|
||||
"""Internal JSON-RPC error."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskNotFoundError(A2AError):
|
||||
"""The specified task was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
|
||||
task_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.task_id:
|
||||
self.message = f"Task not found: {self.task_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskNotCancelableError(A2AError):
|
||||
"""The task cannot be canceled."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
|
||||
task_id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
if self.task_id and self.reason:
|
||||
self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
|
||||
elif self.task_id:
|
||||
self.message = f"Task {self.task_id} cannot be canceled"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PushNotificationNotSupportedError(A2AError):
|
||||
"""Push notifications are not supported."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedOperationError(A2AError):
|
||||
"""The requested operation is not supported."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
|
||||
operation: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.operation:
|
||||
self.message = f"Operation not supported: {self.operation}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentTypeNotSupportedError(A2AError):
|
||||
"""Incompatible content types."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
|
||||
requested_types: list[str] | None = None
|
||||
supported_types: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.requested_types and self.supported_types:
|
||||
self.message = (
|
||||
f"Content type not supported. Requested: {self.requested_types}, "
|
||||
f"Supported: {self.supported_types}"
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidAgentResponseError(A2AError):
|
||||
"""The agent produced an invalid response."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedVersionError(A2AError):
|
||||
"""The requested A2A version is not supported."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
|
||||
requested_version: str | None = None
|
||||
supported_versions: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.requested_version:
|
||||
msg = f"Unsupported A2A version: {self.requested_version}"
|
||||
if self.supported_versions:
|
||||
msg += f". Supported versions: {', '.join(self.supported_versions)}"
|
||||
self.message = msg
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnsupportedExtensionError(A2AError):
|
||||
"""Client does not support required extensions."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
|
||||
required_extensions: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.required_extensions:
|
||||
self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthenticationRequiredError(A2AError):
|
||||
"""Authentication is required."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationFailedError(A2AError):
|
||||
"""Authorization check failed."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
|
||||
required_scope: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.required_scope:
|
||||
self.message = (
|
||||
f"Authorization failed. Required scope: {self.required_scope}"
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitExceededError(A2AError):
|
||||
"""Rate limit exceeded."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False)
|
||||
retry_after: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.retry_after:
|
||||
self.message = (
|
||||
f"Rate limit exceeded. Retry after {self.retry_after} seconds"
|
||||
)
|
||||
if self.retry_after:
|
||||
self.data = {"retry_after": self.retry_after}
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskTimeoutError(A2AError):
|
||||
"""Task execution timed out."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False)
|
||||
task_id: str | None = None
|
||||
timeout_seconds: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None:
|
||||
if self.task_id and self.timeout_seconds:
|
||||
self.message = (
|
||||
f"Task {self.task_id} timed out after {self.timeout_seconds}s"
|
||||
)
|
||||
elif self.task_id:
|
||||
self.message = f"Task {self.task_id} timed out"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransportNegotiationFailedError(A2AError):
|
||||
"""Failed to negotiate a compatible transport protocol."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False)
|
||||
client_transports: list[str] | None = None
|
||||
server_transports: list[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.client_transports and self.server_transports:
|
||||
self.message = (
|
||||
f"Transport negotiation failed. Client: {self.client_transports}, "
|
||||
f"Server: {self.server_transports}"
|
||||
)
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextNotFoundError(A2AError):
|
||||
"""The specified context was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
|
||||
context_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.context_id:
|
||||
self.message = f"Context not found: {self.context_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillNotFoundError(A2AError):
|
||||
"""The specified skill was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
|
||||
skill_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.skill_id:
|
||||
self.message = f"Skill not found: {self.skill_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArtifactNotFoundError(A2AError):
|
||||
"""The specified artifact was not found."""
|
||||
|
||||
code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
|
||||
artifact_id: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.message is None and self.artifact_id:
|
||||
self.message = f"Artifact not found: {self.artifact_id}"
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
def create_error_response(
|
||||
code: int | A2AErrorCode,
|
||||
message: str | None = None,
|
||||
data: Any = None,
|
||||
request_id: str | int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a JSON-RPC error response.
|
||||
|
||||
Args:
|
||||
code: Error code (A2AErrorCode or int).
|
||||
message: Optional error message (uses default if not provided).
|
||||
data: Optional additional error data.
|
||||
request_id: Request ID for correlation.
|
||||
|
||||
Returns:
|
||||
Dict in JSON-RPC error response format.
|
||||
"""
|
||||
error = A2AError(code=int(code), message=message, data=data)
|
||||
return error.to_response(request_id)
|
||||
|
||||
|
||||
def is_retryable_error(code: int) -> bool:
|
||||
"""Check if an error is potentially retryable.
|
||||
|
||||
Args:
|
||||
code: Error code to check.
|
||||
|
||||
Returns:
|
||||
True if the error might be resolved by retrying.
|
||||
"""
|
||||
retryable_codes = {
|
||||
A2AErrorCode.INTERNAL_ERROR,
|
||||
A2AErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
A2AErrorCode.TASK_TIMEOUT,
|
||||
}
|
||||
return code in retryable_codes
|
||||
|
||||
|
||||
def is_client_error(code: int) -> bool:
|
||||
"""Check if an error is a client-side error.
|
||||
|
||||
Args:
|
||||
code: Error code to check.
|
||||
|
||||
Returns:
|
||||
True if the error is due to client request issues.
|
||||
"""
|
||||
client_error_codes = {
|
||||
A2AErrorCode.JSON_PARSE_ERROR,
|
||||
A2AErrorCode.INVALID_REQUEST,
|
||||
A2AErrorCode.METHOD_NOT_FOUND,
|
||||
A2AErrorCode.INVALID_PARAMS,
|
||||
A2AErrorCode.TASK_NOT_FOUND,
|
||||
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
|
||||
A2AErrorCode.UNSUPPORTED_VERSION,
|
||||
A2AErrorCode.UNSUPPORTED_EXTENSION,
|
||||
A2AErrorCode.CONTEXT_NOT_FOUND,
|
||||
A2AErrorCode.SKILL_NOT_FOUND,
|
||||
A2AErrorCode.ARTIFACT_NOT_FOUND,
|
||||
}
|
||||
return code in client_error_codes
|
||||
from crewai_a2a.errors import * # noqa: E402, F403
|
||||
|
||||
@@ -1,37 +1,13 @@
|
||||
"""A2A Protocol Extensions for CrewAI.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.extensions`` instead."""
|
||||
|
||||
This module contains extensions to the A2A (Agent-to-Agent) protocol.
|
||||
import warnings
|
||||
|
||||
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
|
||||
processes requests and responses during delegation to remote agents. These provide
|
||||
hooks for tool injection, prompt augmentation, and response processing.
|
||||
|
||||
**Server-side extensions** (ServerExtension) allow agents to offer additional
|
||||
functionality beyond the core A2A specification. Clients activate extensions
|
||||
via the X-A2A-Extensions header.
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from crewai.a2a.extensions.base import (
|
||||
A2AExtension,
|
||||
ConversationState,
|
||||
ExtensionRegistry,
|
||||
ValidatedA2AExtension,
|
||||
)
|
||||
from crewai.a2a.extensions.server import (
|
||||
ExtensionContext,
|
||||
ServerExtension,
|
||||
ServerExtensionRegistry,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.extensions' has been moved to 'crewai_a2a.extensions'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"A2AExtension",
|
||||
"ConversationState",
|
||||
"ExtensionContext",
|
||||
"ExtensionRegistry",
|
||||
"ServerExtension",
|
||||
"ServerExtensionRegistry",
|
||||
"ValidatedA2AExtension",
|
||||
]
|
||||
from crewai_a2a.extensions import * # noqa: E402, F403
|
||||
|
||||
@@ -1,238 +1,13 @@
|
||||
"""Base extension interface for CrewAI A2A wrapper processing hooks.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.extensions.base`` instead."""
|
||||
|
||||
This module defines the protocol for extending CrewAI's A2A wrapper functionality
|
||||
with custom logic for tool injection, prompt augmentation, and response processing.
|
||||
|
||||
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
|
||||
A2A protocol extensions are capability declarations using AgentExtension objects
|
||||
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BeforeValidator
|
||||
import warnings
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Message
|
||||
warnings.warn(
|
||||
"'crewai.a2a.extensions.base' has been moved to 'crewai_a2a.extensions.base'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
|
||||
|
||||
def _validate_a2a_extension(v: Any) -> Any:
|
||||
"""Validate that value implements A2AExtension protocol."""
|
||||
if not isinstance(v, A2AExtension):
|
||||
raise ValueError(
|
||||
f"Value must implement A2AExtension protocol. "
|
||||
f"Got {type(v).__name__} which is missing required methods."
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ConversationState(Protocol):
|
||||
"""Protocol for extension-specific conversation state.
|
||||
|
||||
Extensions can define their own state classes that implement this protocol
|
||||
to track conversation-specific data extracted from message history.
|
||||
"""
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if the state indicates readiness for some action.
|
||||
|
||||
Returns:
|
||||
True if the state is ready, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class A2AExtension(Protocol):
|
||||
"""Protocol for A2A wrapper extensions.
|
||||
|
||||
Extensions can implement this protocol to inject custom logic into
|
||||
the A2A conversation flow at various integration points.
|
||||
|
||||
Example:
|
||||
class MyExtension:
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
# Add custom tools to the agent
|
||||
pass
|
||||
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> ConversationState | None:
|
||||
# Extract state from conversation
|
||||
return None
|
||||
|
||||
def augment_prompt(
|
||||
self, base_prompt: str, conversation_state: ConversationState | None
|
||||
) -> str:
|
||||
# Add custom instructions
|
||||
return base_prompt
|
||||
|
||||
def process_response(
|
||||
self, agent_response: Any, conversation_state: ConversationState | None
|
||||
) -> Any:
|
||||
# Modify response if needed
|
||||
return agent_response
|
||||
"""
|
||||
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
"""Inject extension-specific tools into the agent.
|
||||
|
||||
Called when an agent is wrapped with A2A capabilities. Extensions
|
||||
can add tools that enable extension-specific functionality.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to inject tools into.
|
||||
"""
|
||||
...
|
||||
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> ConversationState | None:
|
||||
"""Extract extension-specific state from conversation history.
|
||||
|
||||
Called during prompt augmentation to allow extensions to analyze
|
||||
the conversation history and extract relevant state information.
|
||||
|
||||
Args:
|
||||
conversation_history: The sequence of A2A messages exchanged.
|
||||
|
||||
Returns:
|
||||
Extension-specific conversation state, or None if no relevant state.
|
||||
"""
|
||||
...
|
||||
|
||||
def augment_prompt(
|
||||
self,
|
||||
base_prompt: str,
|
||||
conversation_state: ConversationState | None,
|
||||
) -> str:
|
||||
"""Augment the task prompt with extension-specific instructions.
|
||||
|
||||
Called during prompt augmentation to allow extensions to add
|
||||
custom instructions based on conversation state.
|
||||
|
||||
Args:
|
||||
base_prompt: The base prompt to augment.
|
||||
conversation_state: Extension-specific state from extract_state_from_history.
|
||||
|
||||
Returns:
|
||||
The augmented prompt with extension-specific instructions.
|
||||
"""
|
||||
...
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
agent_response: Any,
|
||||
conversation_state: ConversationState | None,
|
||||
) -> Any:
|
||||
"""Process and potentially modify the agent response.
|
||||
|
||||
Called after parsing the agent's response, allowing extensions to
|
||||
enhance or modify the response based on conversation state.
|
||||
|
||||
Args:
|
||||
agent_response: The parsed agent response.
|
||||
conversation_state: Extension-specific state from extract_state_from_history.
|
||||
|
||||
Returns:
|
||||
The processed agent response (may be modified or original).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ExtensionRegistry:
|
||||
"""Registry for managing A2A extensions.
|
||||
|
||||
Maintains a collection of extensions and provides methods to invoke
|
||||
their hooks at various integration points.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the extension registry."""
|
||||
self._extensions: list[A2AExtension] = []
|
||||
|
||||
def register(self, extension: A2AExtension) -> None:
|
||||
"""Register an extension.
|
||||
|
||||
Args:
|
||||
extension: The extension to register.
|
||||
"""
|
||||
self._extensions.append(extension)
|
||||
|
||||
def inject_all_tools(self, agent: Agent) -> None:
|
||||
"""Inject tools from all registered extensions.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to inject tools into.
|
||||
"""
|
||||
for extension in self._extensions:
|
||||
extension.inject_tools(agent)
|
||||
|
||||
def extract_all_states(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> dict[type[A2AExtension], ConversationState]:
|
||||
"""Extract conversation states from all registered extensions.
|
||||
|
||||
Args:
|
||||
conversation_history: The sequence of A2A messages exchanged.
|
||||
|
||||
Returns:
|
||||
Mapping of extension types to their conversation states.
|
||||
"""
|
||||
states: dict[type[A2AExtension], ConversationState] = {}
|
||||
for extension in self._extensions:
|
||||
state = extension.extract_state_from_history(conversation_history)
|
||||
if state is not None:
|
||||
states[type(extension)] = state
|
||||
return states
|
||||
|
||||
def augment_prompt_with_all(
|
||||
self,
|
||||
base_prompt: str,
|
||||
extension_states: dict[type[A2AExtension], ConversationState],
|
||||
) -> str:
|
||||
"""Augment prompt with instructions from all registered extensions.
|
||||
|
||||
Args:
|
||||
base_prompt: The base prompt to augment.
|
||||
extension_states: Mapping of extension types to conversation states.
|
||||
|
||||
Returns:
|
||||
The fully augmented prompt.
|
||||
"""
|
||||
augmented = base_prompt
|
||||
for extension in self._extensions:
|
||||
state = extension_states.get(type(extension))
|
||||
augmented = extension.augment_prompt(augmented, state)
|
||||
return augmented
|
||||
|
||||
def process_response_with_all(
|
||||
self,
|
||||
agent_response: Any,
|
||||
extension_states: dict[type[A2AExtension], ConversationState],
|
||||
) -> Any:
|
||||
"""Process response through all registered extensions.
|
||||
|
||||
Args:
|
||||
agent_response: The parsed agent response.
|
||||
extension_states: Mapping of extension types to conversation states.
|
||||
|
||||
Returns:
|
||||
The processed agent response.
|
||||
"""
|
||||
processed = agent_response
|
||||
for extension in self._extensions:
|
||||
state = extension_states.get(type(extension))
|
||||
processed = extension.process_response(processed, state)
|
||||
return processed
|
||||
from crewai_a2a.extensions.base import * # noqa: E402, F403
|
||||
|
||||
@@ -1,170 +1,13 @@
|
||||
"""A2A Protocol extension utilities.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.extensions.registry`` instead."""
|
||||
|
||||
This module provides utilities for working with A2A protocol extensions as
|
||||
defined in the A2A specification. Extensions are capability declarations in
|
||||
AgentCard.capabilities.extensions using AgentExtension objects, activated
|
||||
via the X-A2A-Extensions HTTP header.
|
||||
import warnings
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
|
||||
from a2a.extensions.common import (
|
||||
HTTP_EXTENSION_HEADER,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.extensions.registry' has been moved to 'crewai_a2a.extensions.registry'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from a2a.types import AgentCard, AgentExtension
|
||||
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig
|
||||
from crewai.a2a.extensions.base import ExtensionRegistry
|
||||
|
||||
|
||||
def get_extensions_from_config(
|
||||
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||
) -> list[str]:
|
||||
"""Extract extension URIs from A2A configuration.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (single or list).
|
||||
|
||||
Returns:
|
||||
Deduplicated list of extension URIs from all configs.
|
||||
"""
|
||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
|
||||
for config in configs:
|
||||
if not isinstance(config, A2AClientConfig):
|
||||
continue
|
||||
for uri in config.extensions:
|
||||
if uri not in seen:
|
||||
seen.add(uri)
|
||||
result.append(uri)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ExtensionsMiddleware(ClientCallInterceptor):
|
||||
"""Middleware to add X-A2A-Extensions header to requests.
|
||||
|
||||
This middleware adds the extensions header to all outgoing requests,
|
||||
declaring which A2A protocol extensions the client supports.
|
||||
"""
|
||||
|
||||
def __init__(self, extensions: list[str]) -> None:
|
||||
"""Initialize with extension URIs.
|
||||
|
||||
Args:
|
||||
extensions: List of extension URIs the client supports.
|
||||
"""
|
||||
self._extensions = extensions
|
||||
|
||||
async def intercept(
|
||||
self,
|
||||
method_name: str,
|
||||
request_payload: dict[str, Any],
|
||||
http_kwargs: dict[str, Any],
|
||||
agent_card: AgentCard | None,
|
||||
context: ClientCallContext | None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Add extensions header to the request.
|
||||
|
||||
Args:
|
||||
method_name: The A2A method being called.
|
||||
request_payload: The JSON-RPC request payload.
|
||||
http_kwargs: HTTP request kwargs (headers, etc).
|
||||
agent_card: The target agent's card.
|
||||
context: Optional call context.
|
||||
|
||||
Returns:
|
||||
Tuple of (request_payload, modified_http_kwargs).
|
||||
"""
|
||||
if self._extensions:
|
||||
headers = http_kwargs.setdefault("headers", {})
|
||||
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
|
||||
return request_payload, http_kwargs
|
||||
|
||||
|
||||
def validate_required_extensions(
|
||||
agent_card: AgentCard,
|
||||
client_extensions: list[str] | None,
|
||||
) -> list[AgentExtension]:
|
||||
"""Validate that client supports all required extensions from agent.
|
||||
|
||||
Args:
|
||||
agent_card: The agent's card with declared extensions.
|
||||
client_extensions: Extension URIs the client supports.
|
||||
|
||||
Returns:
|
||||
List of unsupported required extensions.
|
||||
|
||||
Raises:
|
||||
None - returns list of unsupported extensions for caller to handle.
|
||||
"""
|
||||
unsupported: list[AgentExtension] = []
|
||||
client_set = set(client_extensions or [])
|
||||
|
||||
if not agent_card.capabilities or not agent_card.capabilities.extensions:
|
||||
return unsupported
|
||||
|
||||
unsupported.extend(
|
||||
ext
|
||||
for ext in agent_card.capabilities.extensions
|
||||
if ext.required and ext.uri not in client_set
|
||||
)
|
||||
|
||||
return unsupported
|
||||
|
||||
|
||||
def create_extension_registry_from_config(
|
||||
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
|
||||
) -> ExtensionRegistry:
|
||||
"""Create an extension registry from A2A client configuration.
|
||||
|
||||
Extracts client_extensions from each A2AClientConfig and registers them
|
||||
with the ExtensionRegistry. These extensions provide CrewAI-specific
|
||||
processing hooks (tool injection, prompt augmentation, response processing).
|
||||
|
||||
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
|
||||
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (single or list).
|
||||
|
||||
Returns:
|
||||
Extension registry with all client_extensions registered.
|
||||
|
||||
Example:
|
||||
class LoggingExtension:
|
||||
def inject_tools(self, agent): pass
|
||||
def extract_state_from_history(self, history): return None
|
||||
def augment_prompt(self, prompt, state): return prompt
|
||||
def process_response(self, response, state):
|
||||
print(f"Response: {response}")
|
||||
return response
|
||||
|
||||
config = A2AClientConfig(
|
||||
endpoint="https://agent.example.com",
|
||||
client_extensions=[LoggingExtension()],
|
||||
)
|
||||
registry = create_extension_registry_from_config(config)
|
||||
"""
|
||||
registry = ExtensionRegistry()
|
||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||
|
||||
seen: set[int] = set()
|
||||
|
||||
for config in configs:
|
||||
if isinstance(config, (A2AConfig, A2AClientConfig)):
|
||||
client_exts = getattr(config, "client_extensions", [])
|
||||
for extension in client_exts:
|
||||
ext_id = id(extension)
|
||||
if ext_id not in seen:
|
||||
seen.add(ext_id)
|
||||
registry.register(extension)
|
||||
|
||||
return registry
|
||||
from crewai_a2a.extensions.registry import * # noqa: E402, F403
|
||||
|
||||
@@ -1,305 +1,13 @@
|
||||
"""A2A protocol server extensions for CrewAI agents.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.extensions.server`` instead."""
|
||||
|
||||
This module provides the base class and context for implementing A2A protocol
|
||||
extensions on the server side. Extensions allow agents to offer additional
|
||||
functionality beyond the core A2A specification.
|
||||
import warnings
|
||||
|
||||
See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
warnings.warn(
|
||||
"'crewai.a2a.extensions.server' has been moved to 'crewai_a2a.extensions.server'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
|
||||
from a2a.types import AgentExtension
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.server.context import ServerCallContext
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtensionContext:
|
||||
"""Context passed to extension hooks during request processing.
|
||||
|
||||
Provides access to request metadata, client extensions, and shared state
|
||||
that extensions can read from and write to.
|
||||
|
||||
Attributes:
|
||||
metadata: Request metadata dict, includes extension-namespaced keys.
|
||||
client_extensions: Set of extension URIs the client declared support for.
|
||||
state: Mutable dict for extensions to share data during request lifecycle.
|
||||
server_context: The underlying A2A server call context.
|
||||
"""
|
||||
|
||||
metadata: dict[str, Any]
|
||||
client_extensions: set[str]
|
||||
state: dict[str, Any] = field(default_factory=dict)
|
||||
server_context: ServerCallContext | None = None
|
||||
|
||||
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
|
||||
"""Get extension-specific metadata value.
|
||||
|
||||
Extension metadata uses namespaced keys in the format:
|
||||
"{extension_uri}/{key}"
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
key: The metadata key within the extension namespace.
|
||||
|
||||
Returns:
|
||||
The metadata value, or None if not present.
|
||||
"""
|
||||
full_key = f"{uri}/{key}"
|
||||
return self.metadata.get(full_key)
|
||||
|
||||
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
|
||||
"""Set extension-specific metadata value.
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
key: The metadata key within the extension namespace.
|
||||
value: The value to set.
|
||||
"""
|
||||
full_key = f"{uri}/{key}"
|
||||
self.metadata[full_key] = value
|
||||
|
||||
|
||||
class ServerExtension(ABC):
|
||||
"""Base class for A2A protocol server extensions.
|
||||
|
||||
Subclass this to create custom extensions that modify agent behavior
|
||||
when clients activate them. Extensions are identified by URI and can
|
||||
be marked as required.
|
||||
|
||||
Example:
|
||||
class SamplingExtension(ServerExtension):
|
||||
uri = "urn:crewai:ext:sampling/v1"
|
||||
required = True
|
||||
|
||||
def __init__(self, max_tokens: int = 4096):
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@property
|
||||
def params(self) -> dict[str, Any]:
|
||||
return {"max_tokens": self.max_tokens}
|
||||
|
||||
async def on_request(self, context: ExtensionContext) -> None:
|
||||
limit = context.get_extension_metadata(self.uri, "limit")
|
||||
if limit:
|
||||
context.state["token_limit"] = int(limit)
|
||||
|
||||
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
return result
|
||||
"""
|
||||
|
||||
uri: Annotated[str, "Extension URI identifier. Must be unique."]
|
||||
required: Annotated[bool, "Whether clients must support this extension."] = False
|
||||
description: Annotated[
|
||||
str | None, "Human-readable description of the extension."
|
||||
] = None
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: Any,
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> CoreSchema:
|
||||
"""Tell Pydantic how to validate ServerExtension instances."""
|
||||
return core_schema.is_instance_schema(cls)
|
||||
|
||||
@property
|
||||
def params(self) -> dict[str, Any] | None:
|
||||
"""Extension parameters to advertise in AgentCard.
|
||||
|
||||
Override this property to expose configuration that clients can read.
|
||||
|
||||
Returns:
|
||||
Dict of parameter names to values, or None.
|
||||
"""
|
||||
return None
|
||||
|
||||
def agent_extension(self) -> AgentExtension:
|
||||
"""Generate the AgentExtension object for the AgentCard.
|
||||
|
||||
Returns:
|
||||
AgentExtension with this extension's URI, required flag, and params.
|
||||
"""
|
||||
return AgentExtension(
|
||||
uri=self.uri,
|
||||
required=self.required if self.required else None,
|
||||
description=self.description,
|
||||
params=self.params,
|
||||
)
|
||||
|
||||
def is_active(self, context: ExtensionContext) -> bool:
|
||||
"""Check if this extension is active for the current request.
|
||||
|
||||
An extension is active if the client declared support for it.
|
||||
|
||||
Args:
|
||||
context: The extension context for the current request.
|
||||
|
||||
Returns:
|
||||
True if the client supports this extension.
|
||||
"""
|
||||
return self.uri in context.client_extensions
|
||||
|
||||
@abstractmethod
|
||||
async def on_request(self, context: ExtensionContext) -> None:
|
||||
"""Called before agent execution if extension is active.
|
||||
|
||||
Use this hook to:
|
||||
- Read extension-specific metadata from the request
|
||||
- Set up state for the execution
|
||||
- Modify execution parameters via context.state
|
||||
|
||||
Args:
|
||||
context: The extension context with request metadata and state.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
"""Called after agent execution if extension is active.
|
||||
|
||||
Use this hook to:
|
||||
- Modify or enhance the result
|
||||
- Add extension-specific metadata to the response
|
||||
- Clean up any resources
|
||||
|
||||
Args:
|
||||
context: The extension context with request metadata and state.
|
||||
result: The agent execution result.
|
||||
|
||||
Returns:
|
||||
The result, potentially modified.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ServerExtensionRegistry:
|
||||
"""Registry for managing server-side A2A protocol extensions.
|
||||
|
||||
Collects extensions and provides methods to generate AgentCapabilities
|
||||
and invoke extension hooks during request processing.
|
||||
"""
|
||||
|
||||
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
|
||||
"""Initialize the registry with optional extensions.
|
||||
|
||||
Args:
|
||||
extensions: Initial list of extensions to register.
|
||||
"""
|
||||
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
|
||||
self._by_uri: dict[str, ServerExtension] = {
|
||||
ext.uri: ext for ext in self._extensions
|
||||
}
|
||||
|
||||
def register(self, extension: ServerExtension) -> None:
|
||||
"""Register an extension.
|
||||
|
||||
Args:
|
||||
extension: The extension to register.
|
||||
|
||||
Raises:
|
||||
ValueError: If an extension with the same URI is already registered.
|
||||
"""
|
||||
if extension.uri in self._by_uri:
|
||||
raise ValueError(f"Extension already registered: {extension.uri}")
|
||||
self._extensions.append(extension)
|
||||
self._by_uri[extension.uri] = extension
|
||||
|
||||
def get_agent_extensions(self) -> list[AgentExtension]:
|
||||
"""Get AgentExtension objects for all registered extensions.
|
||||
|
||||
Returns:
|
||||
List of AgentExtension objects for the AgentCard.
|
||||
"""
|
||||
return [ext.agent_extension() for ext in self._extensions]
|
||||
|
||||
def get_extension(self, uri: str) -> ServerExtension | None:
|
||||
"""Get an extension by URI.
|
||||
|
||||
Args:
|
||||
uri: The extension URI.
|
||||
|
||||
Returns:
|
||||
The extension, or None if not found.
|
||||
"""
|
||||
return self._by_uri.get(uri)
|
||||
|
||||
@staticmethod
|
||||
def create_context(
|
||||
metadata: dict[str, Any],
|
||||
client_extensions: set[str],
|
||||
server_context: ServerCallContext | None = None,
|
||||
) -> ExtensionContext:
|
||||
"""Create an ExtensionContext for a request.
|
||||
|
||||
Args:
|
||||
metadata: Request metadata dict.
|
||||
client_extensions: Set of extension URIs from client.
|
||||
server_context: Optional server call context.
|
||||
|
||||
Returns:
|
||||
ExtensionContext for use in hooks.
|
||||
"""
|
||||
return ExtensionContext(
|
||||
metadata=metadata,
|
||||
client_extensions=client_extensions,
|
||||
server_context=server_context,
|
||||
)
|
||||
|
||||
async def invoke_on_request(self, context: ExtensionContext) -> None:
|
||||
"""Invoke on_request hooks for all active extensions.
|
||||
|
||||
Tracks activated extensions and isolates errors from individual hooks.
|
||||
|
||||
Args:
|
||||
context: The extension context for the request.
|
||||
"""
|
||||
for extension in self._extensions:
|
||||
if extension.is_active(context):
|
||||
try:
|
||||
await extension.on_request(context)
|
||||
if context.server_context is not None:
|
||||
context.server_context.activated_extensions.add(extension.uri)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Extension on_request hook failed",
|
||||
extra={"extension": extension.uri},
|
||||
)
|
||||
|
||||
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
|
||||
"""Invoke on_response hooks for all active extensions.
|
||||
|
||||
Isolates errors from individual hooks to prevent one failing extension
|
||||
from breaking the entire response.
|
||||
|
||||
Args:
|
||||
context: The extension context for the request.
|
||||
result: The agent execution result.
|
||||
|
||||
Returns:
|
||||
The result after all extensions have processed it.
|
||||
"""
|
||||
processed = result
|
||||
for extension in self._extensions:
|
||||
if extension.is_active(context):
|
||||
try:
|
||||
processed = await extension.on_response(context, processed)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Extension on_response hook failed",
|
||||
extra={"extension": extension.uri},
|
||||
)
|
||||
return processed
|
||||
from crewai_a2a.extensions.server import * # noqa: E402, F403
|
||||
|
||||
@@ -1,480 +1,13 @@
|
||||
"""Helper functions for processing A2A task results."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.task_helpers`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
import warnings
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
import uuid
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.task_helpers' has been moved to 'crewai_a2a.task_helpers'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
SendMessageEvent = (
|
||||
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
|
||||
)
|
||||
|
||||
|
||||
TERMINAL_STATES: frozenset[TaskState] = frozenset(
|
||||
{
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.rejected,
|
||||
TaskState.canceled,
|
||||
}
|
||||
)
|
||||
|
||||
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
|
||||
{
|
||||
TaskState.input_required,
|
||||
TaskState.auth_required,
|
||||
}
|
||||
)
|
||||
|
||||
PENDING_STATES: frozenset[TaskState] = frozenset(
|
||||
{
|
||||
TaskState.submitted,
|
||||
TaskState.working,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TaskStateResult(TypedDict):
|
||||
"""Result dictionary from processing A2A task state."""
|
||||
|
||||
status: TaskState
|
||||
history: list[Message]
|
||||
result: NotRequired[str]
|
||||
error: NotRequired[str]
|
||||
agent_card: NotRequired[dict[str, Any]]
|
||||
a2a_agent_name: NotRequired[str | None]
|
||||
|
||||
|
||||
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
|
||||
"""Extract result parts from A2A task status message, history, and artifacts.
|
||||
|
||||
Args:
|
||||
a2a_task: A2A Task object with status, history, and artifacts
|
||||
|
||||
Returns:
|
||||
List of result text parts
|
||||
"""
|
||||
result_parts: list[str] = []
|
||||
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
msg = a2a_task.status.message
|
||||
result_parts.extend(
|
||||
part.root.text for part in msg.parts if part.root.kind == "text"
|
||||
)
|
||||
|
||||
if not result_parts and a2a_task.history:
|
||||
for history_msg in reversed(a2a_task.history):
|
||||
if history_msg.role == Role.agent:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in history_msg.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
break
|
||||
|
||||
if a2a_task.artifacts:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for artifact in a2a_task.artifacts
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
return result_parts
|
||||
|
||||
|
||||
def extract_error_message(a2a_task: A2ATask, default: str) -> str:
|
||||
"""Extract error message from A2A task.
|
||||
|
||||
Args:
|
||||
a2a_task: A2A Task object
|
||||
default: Default message if no error found
|
||||
|
||||
Returns:
|
||||
Error message string
|
||||
"""
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
msg = a2a_task.status.message
|
||||
if msg:
|
||||
for part in msg.parts:
|
||||
if part.root.kind == "text":
|
||||
return str(part.root.text)
|
||||
return str(msg)
|
||||
|
||||
if a2a_task.history:
|
||||
for history_msg in reversed(a2a_task.history):
|
||||
for part in history_msg.parts:
|
||||
if part.root.kind == "text":
|
||||
return str(part.root.text)
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def process_task_state(
|
||||
a2a_task: A2ATask,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
turn_number: int,
|
||||
is_multiturn: bool,
|
||||
agent_role: str | None,
|
||||
result_parts: list[str] | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
is_final: bool = True,
|
||||
) -> TaskStateResult | None:
|
||||
"""Process A2A task state and return result dictionary.
|
||||
|
||||
Shared logic for both polling and streaming handlers.
|
||||
|
||||
Args:
|
||||
a2a_task: The A2A task to process.
|
||||
new_messages: List to collect messages (modified in place).
|
||||
agent_card: The agent card.
|
||||
turn_number: Current turn number.
|
||||
is_multiturn: Whether multi-turn conversation.
|
||||
agent_role: Agent role for logging.
|
||||
result_parts: Accumulated result parts (streaming passes accumulated,
|
||||
polling passes None to extract from task).
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
from_task: Optional CrewAI Task for event metadata.
|
||||
from_agent: Optional CrewAI Agent for event metadata.
|
||||
is_final: Whether this is the final response in the stream.
|
||||
|
||||
Returns:
|
||||
Result dictionary if terminal/actionable state, None otherwise.
|
||||
"""
|
||||
if result_parts is None:
|
||||
result_parts = []
|
||||
|
||||
if a2a_task.status.state == TaskState.completed:
|
||||
if not result_parts:
|
||||
extracted_parts = extract_task_result_parts(a2a_task)
|
||||
result_parts.extend(extracted_parts)
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
message_id = None
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
message_id = a2a_task.status.message.message_id
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
context_id=a2a_task.context_id,
|
||||
message_id=message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="completed",
|
||||
final=is_final,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state == TaskState.input_required:
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
|
||||
response_text = extract_error_message(a2a_task, "Additional input required")
|
||||
if response_text and not a2a_task.history:
|
||||
agent_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=response_text))],
|
||||
context_id=a2a_task.context_id,
|
||||
task_id=a2a_task.id,
|
||||
)
|
||||
new_messages.append(agent_message)
|
||||
|
||||
input_message_id = None
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
input_message_id = a2a_task.status.message.message_id
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
context_id=a2a_task.context_id,
|
||||
message_id=input_message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="input_required",
|
||||
final=is_final,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.input_required,
|
||||
error=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
|
||||
error_msg = extract_error_message(a2a_task, "Task failed without error message")
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state == TaskState.auth_required:
|
||||
error_msg = extract_error_message(a2a_task, "Authentication required")
|
||||
return TaskStateResult(
|
||||
status=TaskState.auth_required,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state == TaskState.canceled:
|
||||
error_msg = extract_error_message(a2a_task, "Task was canceled")
|
||||
return TaskStateResult(
|
||||
status=TaskState.canceled,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state in PENDING_STATES:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def send_message_and_get_task_id(
|
||||
event_stream: AsyncIterator[SendMessageEvent],
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
turn_number: int,
|
||||
is_multiturn: bool,
|
||||
agent_role: str | None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
context_id: str | None = None,
|
||||
) -> str | TaskStateResult:
|
||||
"""Send message and process initial response.
|
||||
|
||||
Handles the common pattern of sending a message and either:
|
||||
- Getting an immediate Message response (task completed synchronously)
|
||||
- Getting a Task that needs polling/waiting for completion
|
||||
|
||||
Args:
|
||||
event_stream: Async iterator from client.send_message()
|
||||
new_messages: List to collect messages (modified in place)
|
||||
agent_card: The agent card
|
||||
turn_number: Current turn number
|
||||
is_multiturn: Whether multi-turn conversation
|
||||
agent_role: Agent role for logging
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
endpoint: Optional A2A endpoint URL.
|
||||
a2a_agent_name: Optional A2A agent name.
|
||||
context_id: Optional A2A context ID for correlation.
|
||||
|
||||
Returns:
|
||||
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
|
||||
"""
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
result_parts = [
|
||||
part.root.text for part in event.parts if part.root.kind == "text"
|
||||
]
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
context_id=event.context_id,
|
||||
message_id=event.message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="completed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
if isinstance(event, tuple):
|
||||
a2a_task, _ = event
|
||||
|
||||
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
||||
result = process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return a2a_task.id
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error="No task ID received from initial message",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="send_message",
|
||||
context_id=context_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during send_message: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="send_message",
|
||||
context_id=context_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
finally:
|
||||
aclose = getattr(event_stream, "aclose", None)
|
||||
if aclose:
|
||||
await aclose()
|
||||
from crewai_a2a.task_helpers import * # noqa: E402, F403
|
||||
|
||||
@@ -1,55 +1,13 @@
|
||||
"""String templates for A2A (Agent-to-Agent) protocol messaging and status."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.templates`` instead."""
|
||||
|
||||
from string import Template
|
||||
from typing import Final
|
||||
import warnings
|
||||
|
||||
|
||||
AVAILABLE_AGENTS_TEMPLATE: Final[Template] = Template(
|
||||
"\n<AVAILABLE_A2A_AGENTS>\n $available_a2a_agents\n</AVAILABLE_A2A_AGENTS>\n"
|
||||
warnings.warn(
|
||||
"'crewai.a2a.templates' has been moved to 'crewai_a2a.templates'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
PREVIOUS_A2A_CONVERSATION_TEMPLATE: Final[Template] = Template(
|
||||
"\n<PREVIOUS_A2A_CONVERSATION>\n"
|
||||
" $previous_a2a_conversation"
|
||||
"\n</PREVIOUS_A2A_CONVERSATION>\n"
|
||||
)
|
||||
CONVERSATION_TURN_INFO_TEMPLATE: Final[Template] = Template(
|
||||
"\n<CONVERSATION_PROGRESS>\n"
|
||||
' turn="$turn_count"\n'
|
||||
' max_turns="$max_turns"\n'
|
||||
" $warning"
|
||||
"\n</CONVERSATION_PROGRESS>\n"
|
||||
)
|
||||
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE: Final[Template] = Template(
|
||||
"\n<A2A_AGENTS_STATUS>\n"
|
||||
" NOTE: A2A agents were configured but are currently unavailable.\n"
|
||||
" You cannot delegate to remote agents for this task.\n\n"
|
||||
" Unavailable Agents:\n"
|
||||
" $unavailable_agents"
|
||||
"\n</A2A_AGENTS_STATUS>\n"
|
||||
)
|
||||
REMOTE_AGENT_COMPLETED_NOTICE: Final[str] = """
|
||||
<REMOTE_AGENT_STATUS>
|
||||
STATUS: COMPLETED
|
||||
The remote agent has finished processing your request. Their response is in the conversation history above.
|
||||
You MUST now:
|
||||
1. Extract the answer from the conversation history
|
||||
2. Set is_a2a=false
|
||||
3. Return the answer as your final message
|
||||
DO NOT send another request - the task is already done.
|
||||
</REMOTE_AGENT_STATUS>
|
||||
"""
|
||||
|
||||
REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """
|
||||
<REMOTE_AGENT_STATUS>
|
||||
STATUS: RESPONSE_RECEIVED
|
||||
The remote agent has responded. Their response is in the conversation history above.
|
||||
|
||||
You MUST now:
|
||||
1. Set is_a2a=false (the remote task is complete and cannot receive more messages)
|
||||
2. Provide YOUR OWN response to the original task based on the information received
|
||||
|
||||
IMPORTANT: Your response should be addressed to the USER who gave you the original task.
|
||||
Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that...").
|
||||
Do NOT address the remote agent directly or use "you" to refer to them.
|
||||
</REMOTE_AGENT_STATUS>
|
||||
"""
|
||||
from crewai_a2a.templates import * # noqa: E402, F403
|
||||
|
||||
@@ -1,104 +1,13 @@
|
||||
"""Type definitions for A2A protocol message parts."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.types`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
import warnings
|
||||
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypedDict,
|
||||
runtime_checkable,
|
||||
|
||||
warnings.warn(
|
||||
"'crewai.a2a.types' has been moved to 'crewai_a2a.types'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
||||
try:
|
||||
from crewai.a2a.updates import (
|
||||
PollingConfig,
|
||||
PollingHandler,
|
||||
PushNotificationConfig,
|
||||
PushNotificationHandler,
|
||||
StreamingConfig,
|
||||
StreamingHandler,
|
||||
UpdateConfig,
|
||||
)
|
||||
except ImportError:
|
||||
PollingConfig = Any # type: ignore[misc,assignment]
|
||||
PollingHandler = Any # type: ignore[misc,assignment]
|
||||
PushNotificationConfig = Any # type: ignore[misc,assignment]
|
||||
PushNotificationHandler = Any # type: ignore[misc,assignment]
|
||||
StreamingConfig = Any # type: ignore[misc,assignment]
|
||||
StreamingHandler = Any # type: ignore[misc,assignment]
|
||||
UpdateConfig = Any # type: ignore[misc,assignment]
|
||||
|
||||
|
||||
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
|
||||
ProtocolVersion = Literal[
|
||||
"0.2.0",
|
||||
"0.2.1",
|
||||
"0.2.2",
|
||||
"0.2.3",
|
||||
"0.2.4",
|
||||
"0.2.5",
|
||||
"0.2.6",
|
||||
"0.3.0",
|
||||
"0.4.0",
|
||||
]
|
||||
|
||||
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
|
||||
|
||||
Url = Annotated[
|
||||
str,
|
||||
BeforeValidator(
|
||||
lambda value: str(http_url_adapter.validate_python(value, strict=True))
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AgentResponseProtocol(Protocol):
|
||||
"""Protocol for the dynamically created AgentResponse model."""
|
||||
|
||||
a2a_ids: tuple[str, ...]
|
||||
message: str
|
||||
is_a2a: bool
|
||||
|
||||
|
||||
class PartsMetadataDict(TypedDict, total=False):
|
||||
"""Metadata for A2A message parts.
|
||||
|
||||
Attributes:
|
||||
mimeType: MIME type for the part content.
|
||||
schema: JSON schema for the part content.
|
||||
"""
|
||||
|
||||
mimeType: Literal["application/json"]
|
||||
schema: dict[str, Any]
|
||||
|
||||
|
||||
class PartsDict(TypedDict):
|
||||
"""A2A message part containing text and optional metadata.
|
||||
|
||||
Attributes:
|
||||
text: The text content of the message part.
|
||||
metadata: Optional metadata describing the part content.
|
||||
"""
|
||||
|
||||
text: str
|
||||
metadata: NotRequired[PartsMetadataDict]
|
||||
|
||||
|
||||
PollingHandlerType = type[PollingHandler]
|
||||
StreamingHandlerType = type[StreamingHandler]
|
||||
PushNotificationHandlerType = type[PushNotificationHandler]
|
||||
|
||||
HandlerType = PollingHandlerType | StreamingHandlerType | PushNotificationHandlerType
|
||||
|
||||
HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
|
||||
PollingConfig: PollingHandler,
|
||||
StreamingConfig: StreamingHandler,
|
||||
PushNotificationConfig: PushNotificationHandler,
|
||||
}
|
||||
from crewai_a2a.types import * # noqa: E402, F403
|
||||
|
||||
@@ -1,35 +1,13 @@
|
||||
"""A2A update mechanism configuration types."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates`` instead."""
|
||||
|
||||
from crewai.a2a.updates.base import (
|
||||
BaseHandlerKwargs,
|
||||
PollingHandlerKwargs,
|
||||
PushNotificationHandlerKwargs,
|
||||
PushNotificationResultStore,
|
||||
StreamingHandlerKwargs,
|
||||
UpdateHandler,
|
||||
import warnings
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates' has been moved to 'crewai_a2a.updates'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
from crewai.a2a.updates.polling.config import PollingConfig
|
||||
from crewai.a2a.updates.polling.handler import PollingHandler
|
||||
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
|
||||
from crewai.a2a.updates.streaming.config import StreamingConfig
|
||||
from crewai.a2a.updates.streaming.handler import StreamingHandler
|
||||
|
||||
|
||||
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
|
||||
|
||||
__all__ = [
|
||||
"BaseHandlerKwargs",
|
||||
"PollingConfig",
|
||||
"PollingHandler",
|
||||
"PollingHandlerKwargs",
|
||||
"PushNotificationConfig",
|
||||
"PushNotificationHandler",
|
||||
"PushNotificationHandlerKwargs",
|
||||
"PushNotificationResultStore",
|
||||
"StreamingConfig",
|
||||
"StreamingHandler",
|
||||
"StreamingHandlerKwargs",
|
||||
"UpdateConfig",
|
||||
"UpdateHandler",
|
||||
]
|
||||
from crewai_a2a.updates import * # noqa: E402, F403
|
||||
|
||||
@@ -1,176 +1,13 @@
|
||||
"""Base types for A2A update mechanism handlers."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.base`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
import warnings
|
||||
|
||||
|
||||
class CommonParams(NamedTuple):
|
||||
"""Common parameters shared across all update handlers.
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.base' has been moved to 'crewai_a2a.updates.base'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
Groups the frequently-passed parameters to reduce duplication.
|
||||
"""
|
||||
|
||||
turn_number: int
|
||||
is_multiturn: bool
|
||||
agent_role: str | None
|
||||
endpoint: str
|
||||
a2a_agent_name: str | None
|
||||
context_id: str | None
|
||||
from_task: Any
|
||||
from_agent: Any
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import Client
|
||||
from a2a.types import AgentCard, Message, Task
|
||||
|
||||
from crewai.a2a.task_helpers import TaskStateResult
|
||||
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
|
||||
|
||||
class BaseHandlerKwargs(TypedDict, total=False):
|
||||
"""Base kwargs shared by all handlers."""
|
||||
|
||||
turn_number: int
|
||||
is_multiturn: bool
|
||||
agent_role: str | None
|
||||
context_id: str | None
|
||||
task_id: str | None
|
||||
endpoint: str | None
|
||||
agent_branch: Any
|
||||
a2a_agent_name: str | None
|
||||
from_task: Any
|
||||
from_agent: Any
|
||||
|
||||
|
||||
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for polling handler."""
|
||||
|
||||
polling_interval: float
|
||||
polling_timeout: float
|
||||
history_length: int
|
||||
max_polls: int | None
|
||||
|
||||
|
||||
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for streaming handler."""
|
||||
|
||||
|
||||
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for push notification handler."""
|
||||
|
||||
config: PushNotificationConfig
|
||||
result_store: PushNotificationResultStore
|
||||
polling_timeout: float
|
||||
polling_interval: float
|
||||
|
||||
|
||||
class PushNotificationResultStore(Protocol):
|
||||
"""Protocol for storing and retrieving push notification results.
|
||||
|
||||
This protocol defines the interface for a result store that the
|
||||
PushNotificationHandler uses to wait for task completion.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: Any,
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> CoreSchema:
|
||||
return core_schema.any_schema()
|
||||
|
||||
async def wait_for_result(
|
||||
self,
|
||||
task_id: str,
|
||||
timeout: float,
|
||||
poll_interval: float = 1.0,
|
||||
) -> Task | None:
|
||||
"""Wait for a task result to be available.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to wait for.
|
||||
timeout: Max seconds to wait before returning None.
|
||||
poll_interval: Seconds between polling attempts.
|
||||
|
||||
Returns:
|
||||
The completed Task object, or None if timeout.
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_result(self, task_id: str) -> Task | None:
|
||||
"""Get a task result if available.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to retrieve.
|
||||
|
||||
Returns:
|
||||
The Task object if available, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def store_result(self, task: Task) -> None:
|
||||
"""Store a task result.
|
||||
|
||||
Args:
|
||||
task: The Task object to store.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class UpdateHandler(Protocol):
|
||||
"""Protocol for A2A update mechanism handlers."""
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Any,
|
||||
) -> TaskStateResult:
|
||||
"""Execute the update mechanism and return result.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages (modified in place).
|
||||
agent_card: The agent card.
|
||||
**kwargs: Additional handler-specific parameters.
|
||||
|
||||
Returns:
|
||||
Result dictionary with status, result/error, and history.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams:
|
||||
"""Extract common parameters from handler kwargs.
|
||||
|
||||
Args:
|
||||
kwargs: Handler kwargs dict.
|
||||
|
||||
Returns:
|
||||
CommonParams with extracted values.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint is not provided.
|
||||
"""
|
||||
endpoint = kwargs.get("endpoint")
|
||||
if endpoint is None:
|
||||
raise ValueError("endpoint is required for update handlers")
|
||||
|
||||
return CommonParams(
|
||||
turn_number=kwargs.get("turn_number", 0),
|
||||
is_multiturn=kwargs.get("is_multiturn", False),
|
||||
agent_role=kwargs.get("agent_role"),
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=kwargs.get("a2a_agent_name"),
|
||||
context_id=kwargs.get("context_id"),
|
||||
from_task=kwargs.get("from_task"),
|
||||
from_agent=kwargs.get("from_agent"),
|
||||
)
|
||||
from crewai_a2a.updates.base import * # noqa: E402, F403
|
||||
|
||||
@@ -1 +1,13 @@
|
||||
"""Polling update mechanism module."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.polling`` instead."""
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.polling' has been moved to 'crewai_a2a.updates.polling'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from crewai_a2a.updates.polling import * # noqa: E402, F403
|
||||
|
||||
@@ -1,25 +1,13 @@
|
||||
"""Polling update mechanism configuration."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.polling.config`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import warnings
|
||||
|
||||
|
||||
class PollingConfig(BaseModel):
|
||||
"""Configuration for polling-based task updates.
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.polling.config' has been moved to 'crewai_a2a.updates.polling.config'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
Attributes:
|
||||
interval: Seconds between poll attempts.
|
||||
timeout: Max seconds to poll before raising timeout error.
|
||||
max_polls: Max number of poll attempts.
|
||||
history_length: Number of messages to retrieve per poll.
|
||||
"""
|
||||
|
||||
interval: float = Field(
|
||||
default=2.0, gt=0, description="Seconds between poll attempts"
|
||||
)
|
||||
timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll")
|
||||
max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts")
|
||||
history_length: int = Field(
|
||||
default=100, gt=0, description="Messages to retrieve per poll"
|
||||
)
|
||||
from crewai_a2a.updates.polling.config import * # noqa: E402, F403
|
||||
|
||||
@@ -1,359 +1,13 @@
|
||||
"""Polling update mechanism handler."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.polling.handler`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
import warnings
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
TaskQueryParams,
|
||||
TaskState,
|
||||
TextPart,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.a2a.errors import A2APollingTimeoutError
|
||||
from crewai.a2a.task_helpers import (
|
||||
ACTIONABLE_STATES,
|
||||
TERMINAL_STATES,
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
send_message_and_get_task_id,
|
||||
)
|
||||
from crewai.a2a.updates.base import PollingHandlerKwargs
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
A2APollingStartedEvent,
|
||||
A2APollingStatusEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.polling.handler' has been moved to 'crewai_a2a.updates.polling.handler'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
|
||||
async def _poll_task_until_complete(
|
||||
client: Client,
|
||||
task_id: str,
|
||||
polling_interval: float,
|
||||
polling_timeout: float,
|
||||
agent_branch: Any | None = None,
|
||||
history_length: int = 100,
|
||||
max_polls: int | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
context_id: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
) -> A2ATask:
|
||||
"""Poll task status until terminal state reached.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
task_id: Task ID to poll.
|
||||
polling_interval: Seconds between poll attempts.
|
||||
polling_timeout: Max seconds before timeout.
|
||||
agent_branch: Agent tree branch for logging.
|
||||
history_length: Number of messages to retrieve per poll.
|
||||
max_polls: Max number of poll attempts (None = unlimited).
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
context_id: A2A context ID for correlation.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent from agent card.
|
||||
|
||||
Returns:
|
||||
Final task object in terminal state.
|
||||
|
||||
Raises:
|
||||
A2APollingTimeoutError: If polling exceeds timeout or max_polls.
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
poll_count = 0
|
||||
|
||||
while True:
|
||||
poll_count += 1
|
||||
task = await client.get_task(
|
||||
TaskQueryParams(id=task_id, history_length=history_length)
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
effective_context_id = task.context_id or context_id
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APollingStatusEvent(
|
||||
task_id=task_id,
|
||||
context_id=effective_context_id,
|
||||
state=str(task.status.state.value),
|
||||
elapsed_seconds=elapsed,
|
||||
poll_count=poll_count,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
||||
return task
|
||||
|
||||
if elapsed > polling_timeout:
|
||||
raise A2APollingTimeoutError(
|
||||
f"Polling timeout after {polling_timeout}s ({poll_count} polls)"
|
||||
)
|
||||
|
||||
if max_polls and poll_count >= max_polls:
|
||||
raise A2APollingTimeoutError(
|
||||
f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s"
|
||||
)
|
||||
|
||||
await asyncio.sleep(polling_interval)
|
||||
|
||||
|
||||
class PollingHandler:
|
||||
"""Polling-based update handler."""
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Unpack[PollingHandlerKwargs],
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using polling for updates.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages.
|
||||
agent_card: The agent card.
|
||||
**kwargs: Polling-specific parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
"""
|
||||
polling_interval = kwargs.get("polling_interval", 2.0)
|
||||
polling_timeout = kwargs.get("polling_timeout", 300.0)
|
||||
endpoint = kwargs.get("endpoint", "")
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
turn_number = kwargs.get("turn_number", 0)
|
||||
is_multiturn = kwargs.get("is_multiturn", False)
|
||||
agent_role = kwargs.get("agent_role")
|
||||
history_length = kwargs.get("history_length", 100)
|
||||
max_polls = kwargs.get("max_polls")
|
||||
context_id = kwargs.get("context_id")
|
||||
task_id = kwargs.get("task_id")
|
||||
a2a_agent_name = kwargs.get("a2a_agent_name")
|
||||
from_task = kwargs.get("from_task")
|
||||
from_agent = kwargs.get("from_agent")
|
||||
|
||||
try:
|
||||
result_or_task_id = await send_message_and_get_task_id(
|
||||
event_stream=client.send_message(message),
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
if not isinstance(result_or_task_id, str):
|
||||
return result_or_task_id
|
||||
|
||||
task_id = result_or_task_id
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APollingStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
polling_interval=polling_interval,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
final_task = await _poll_task_until_complete(
|
||||
client=client,
|
||||
task_id=task_id,
|
||||
polling_interval=polling_interval,
|
||||
polling_timeout=polling_timeout,
|
||||
agent_branch=agent_branch,
|
||||
history_length=history_length,
|
||||
max_polls=max_polls,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
context_id=context_id,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
)
|
||||
|
||||
result = process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Unexpected task state: {final_task.status.state}",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2APollingTimeoutError as e:
|
||||
error_msg = str(e)
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="polling",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during polling: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="polling",
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
from crewai_a2a.updates.polling.handler import * # noqa: E402, F403
|
||||
|
||||
@@ -1 +1,13 @@
|
||||
"""Push notification update mechanism module."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications`` instead."""
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.push_notifications' has been moved to 'crewai_a2a.updates.push_notifications'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from crewai_a2a.updates.push_notifications import * # noqa: E402, F403
|
||||
|
||||
@@ -1,65 +1,13 @@
|
||||
"""Push notification update mechanism configuration."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications.config`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from a2a.types import PushNotificationAuthenticationInfo
|
||||
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
|
||||
|
||||
from crewai.a2a.updates.base import PushNotificationResultStore
|
||||
from crewai.a2a.updates.push_notifications.signature import WebhookSignatureConfig
|
||||
import warnings
|
||||
|
||||
|
||||
def _coerce_signature(
|
||||
value: str | WebhookSignatureConfig | None,
|
||||
) -> WebhookSignatureConfig | None:
|
||||
"""Convert string secret to WebhookSignatureConfig."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return WebhookSignatureConfig.hmac_sha256(secret=value)
|
||||
return value
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.push_notifications.config' has been moved to 'crewai_a2a.updates.push_notifications.config'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
SignatureInput = Annotated[
|
||||
WebhookSignatureConfig | None,
|
||||
BeforeValidator(_coerce_signature),
|
||||
]
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
"""Configuration for webhook-based task updates.
|
||||
|
||||
Attributes:
|
||||
url: Callback URL where agent sends push notifications.
|
||||
id: Unique identifier for this config.
|
||||
token: Token to validate incoming notifications.
|
||||
authentication: Auth info for agent to use when calling webhook.
|
||||
timeout: Max seconds to wait for task completion.
|
||||
interval: Seconds between result polling attempts.
|
||||
result_store: Store for receiving push notification results.
|
||||
signature: HMAC signature config. Pass a string (secret) for defaults,
|
||||
or WebhookSignatureConfig for custom settings.
|
||||
"""
|
||||
|
||||
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
|
||||
id: str | None = Field(default=None, description="Unique config identifier")
|
||||
token: str | None = Field(default=None, description="Validation token")
|
||||
authentication: PushNotificationAuthenticationInfo | None = Field(
|
||||
default=None, description="Auth info for agent to use when calling webhook"
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
default=300.0, gt=0, description="Max seconds to wait for task completion"
|
||||
)
|
||||
interval: float = Field(
|
||||
default=2.0, gt=0, description="Seconds between result polling attempts"
|
||||
)
|
||||
result_store: PushNotificationResultStore | None = Field(
|
||||
default=None, description="Result store for push notification handling"
|
||||
)
|
||||
signature: SignatureInput = Field(
|
||||
default=None,
|
||||
description="HMAC signature config. Pass a string (secret) for simple usage, "
|
||||
"or WebhookSignatureConfig for custom headers/tolerance.",
|
||||
)
|
||||
from crewai_a2a.updates.push_notifications.config import * # noqa: E402, F403
|
||||
|
||||
@@ -1,354 +1,13 @@
|
||||
"""Push notification (webhook) update mechanism handler."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications.handler`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
import warnings
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
TaskState,
|
||||
TextPart,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.a2a.task_helpers import (
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
send_message_and_get_task_id,
|
||||
)
|
||||
from crewai.a2a.updates.base import (
|
||||
CommonParams,
|
||||
PushNotificationHandlerKwargs,
|
||||
PushNotificationResultStore,
|
||||
extract_common_params,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
A2APushNotificationRegisteredEvent,
|
||||
A2APushNotificationTimeoutEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.push_notifications.handler' has been moved to 'crewai_a2a.updates.push_notifications.handler'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _handle_push_error(
|
||||
error: Exception,
|
||||
error_msg: str,
|
||||
error_type: str,
|
||||
new_messages: list[Message],
|
||||
agent_branch: Any | None,
|
||||
params: CommonParams,
|
||||
task_id: str | None,
|
||||
status_code: int | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Handle push notification errors with consistent event emission.
|
||||
|
||||
Args:
|
||||
error: The exception that occurred.
|
||||
error_msg: Formatted error message for the result.
|
||||
error_type: Type of error for the event.
|
||||
new_messages: List to append error message to.
|
||||
agent_branch: Agent tree branch for events.
|
||||
params: Common handler parameters.
|
||||
task_id: A2A task ID.
|
||||
status_code: HTTP status code if applicable.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with failed status.
|
||||
"""
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(error),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
|
||||
async def _wait_for_push_result(
|
||||
task_id: str,
|
||||
result_store: PushNotificationResultStore,
|
||||
timeout: float,
|
||||
poll_interval: float,
|
||||
agent_branch: Any | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
context_id: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
) -> A2ATask | None:
|
||||
"""Wait for push notification result.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to wait for.
|
||||
result_store: Store to retrieve results from.
|
||||
timeout: Max seconds to wait.
|
||||
poll_interval: Seconds between polling attempts.
|
||||
agent_branch: Agent tree branch for logging.
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
context_id: A2A context ID for correlation.
|
||||
endpoint: A2A agent endpoint URL.
|
||||
a2a_agent_name: Name of the A2A agent.
|
||||
|
||||
Returns:
|
||||
Final task object, or None if timeout.
|
||||
"""
|
||||
task = await result_store.wait_for_result(
|
||||
task_id=task_id,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
)
|
||||
|
||||
if task is None:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APushNotificationTimeoutEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
timeout_seconds=timeout,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
class PushNotificationHandler:
|
||||
"""Push notification (webhook) based update handler."""
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Unpack[PushNotificationHandlerKwargs],
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using push notifications for updates.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages.
|
||||
agent_card: The agent card.
|
||||
**kwargs: Push notification-specific parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
|
||||
Raises:
|
||||
ValueError: If result_store or config not provided.
|
||||
"""
|
||||
config = kwargs.get("config")
|
||||
result_store = kwargs.get("result_store")
|
||||
polling_timeout = kwargs.get("polling_timeout", 300.0)
|
||||
polling_interval = kwargs.get("polling_interval", 2.0)
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
task_id = kwargs.get("task_id")
|
||||
params = extract_common_params(kwargs)
|
||||
|
||||
if config is None:
|
||||
error_msg = (
|
||||
"PushNotificationConfig is required for push notification handler"
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=error_msg,
|
||||
error_type="configuration_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if result_store is None:
|
||||
error_msg = (
|
||||
"PushNotificationResultStore is required for push notification handler"
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=error_msg,
|
||||
error_type="configuration_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="push_notification",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
try:
|
||||
result_or_task_id = await send_message_and_get_task_id(
|
||||
event_stream=client.send_message(message),
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
context_id=params.context_id,
|
||||
)
|
||||
|
||||
if not isinstance(result_or_task_id, str):
|
||||
return result_or_task_id
|
||||
|
||||
task_id = result_or_task_id
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APushNotificationRegisteredEvent(
|
||||
task_id=task_id,
|
||||
context_id=params.context_id,
|
||||
callback_url=str(config.url),
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Push notification callback for task %s configured at %s (via initial request)",
|
||||
task_id,
|
||||
config.url,
|
||||
)
|
||||
|
||||
final_task = await _wait_for_push_result(
|
||||
task_id=task_id,
|
||||
result_store=result_store,
|
||||
timeout=polling_timeout,
|
||||
poll_interval=polling_interval,
|
||||
agent_branch=agent_branch,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
context_id=params.context_id,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
)
|
||||
|
||||
if final_task is None:
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Push notification timeout after {polling_timeout}s",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
result = process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Unexpected task state: {final_task.status.state}",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
return _handle_push_error(
|
||||
error=e,
|
||||
error_msg=f"HTTP Error {e.status_code}: {e!s}",
|
||||
error_type="http_error",
|
||||
new_messages=new_messages,
|
||||
agent_branch=agent_branch,
|
||||
params=params,
|
||||
task_id=task_id,
|
||||
status_code=e.status_code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return _handle_push_error(
|
||||
error=e,
|
||||
error_msg=f"Unexpected error during push notification: {e!s}",
|
||||
error_type="unexpected_error",
|
||||
new_messages=new_messages,
|
||||
agent_branch=agent_branch,
|
||||
params=params,
|
||||
task_id=task_id,
|
||||
)
|
||||
from crewai_a2a.updates.push_notifications.handler import * # noqa: E402, F403
|
||||
|
||||
@@ -1,87 +1,13 @@
|
||||
"""Webhook signature configuration for push notifications."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.push_notifications.signature`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import secrets
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
import warnings
|
||||
|
||||
|
||||
class WebhookSignatureMode(str, Enum):
|
||||
"""Signature mode for webhook push notifications."""
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.push_notifications.signature' has been moved to 'crewai_a2a.updates.push_notifications.signature'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
NONE = "none"
|
||||
HMAC_SHA256 = "hmac_sha256"
|
||||
|
||||
|
||||
class WebhookSignatureConfig(BaseModel):
|
||||
"""Configuration for webhook signature verification.
|
||||
|
||||
Provides cryptographic integrity verification and replay attack protection
|
||||
for A2A push notifications.
|
||||
|
||||
Attributes:
|
||||
mode: Signature mode (none or hmac_sha256).
|
||||
secret: Shared secret for HMAC computation (required for hmac_sha256 mode).
|
||||
timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection.
|
||||
header_name: HTTP header name for the signature.
|
||||
timestamp_header_name: HTTP header name for the timestamp.
|
||||
"""
|
||||
|
||||
mode: WebhookSignatureMode = Field(
|
||||
default=WebhookSignatureMode.NONE,
|
||||
description="Signature verification mode",
|
||||
)
|
||||
secret: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Shared secret for HMAC computation",
|
||||
)
|
||||
timestamp_tolerance_seconds: int = Field(
|
||||
default=300,
|
||||
ge=0,
|
||||
description="Max allowed timestamp age in seconds (5 min default)",
|
||||
)
|
||||
header_name: str = Field(
|
||||
default="X-A2A-Signature",
|
||||
description="HTTP header name for the signature",
|
||||
)
|
||||
timestamp_header_name: str = Field(
|
||||
default="X-A2A-Signature-Timestamp",
|
||||
description="HTTP header name for the timestamp",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_secret(cls, length: int = 32) -> str:
|
||||
"""Generate a cryptographically secure random secret.
|
||||
|
||||
Args:
|
||||
length: Number of random bytes to generate (default 32).
|
||||
|
||||
Returns:
|
||||
URL-safe base64-encoded secret string.
|
||||
"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
@classmethod
|
||||
def hmac_sha256(
|
||||
cls,
|
||||
secret: str | SecretStr,
|
||||
timestamp_tolerance_seconds: int = 300,
|
||||
) -> WebhookSignatureConfig:
|
||||
"""Create an HMAC-SHA256 signature configuration.
|
||||
|
||||
Args:
|
||||
secret: Shared secret for HMAC computation.
|
||||
timestamp_tolerance_seconds: Max allowed timestamp age in seconds.
|
||||
|
||||
Returns:
|
||||
Configured WebhookSignatureConfig for HMAC-SHA256.
|
||||
"""
|
||||
if isinstance(secret, str):
|
||||
secret = SecretStr(secret)
|
||||
return cls(
|
||||
mode=WebhookSignatureMode.HMAC_SHA256,
|
||||
secret=secret,
|
||||
timestamp_tolerance_seconds=timestamp_tolerance_seconds,
|
||||
)
|
||||
from crewai_a2a.updates.push_notifications.signature import * # noqa: E402, F403
|
||||
|
||||
@@ -1 +1,13 @@
|
||||
"""Streaming update mechanism module."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming`` instead."""
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.streaming' has been moved to 'crewai_a2a.updates.streaming'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from crewai_a2a.updates.streaming import * # noqa: E402, F403
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
"""Streaming update mechanism configuration."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming.config`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
import warnings
|
||||
|
||||
|
||||
class StreamingConfig(BaseModel):
|
||||
"""Configuration for SSE-based task updates."""
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.streaming.config' has been moved to 'crewai_a2a.updates.streaming.config'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from crewai_a2a.updates.streaming.config import * # noqa: E402, F403
|
||||
|
||||
@@ -1,646 +1,13 @@
|
||||
"""Streaming (SSE) update mechanism handler."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming.handler`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
import warnings
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Final
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskIdParams,
|
||||
TaskQueryParams,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.a2a.task_helpers import (
|
||||
ACTIONABLE_STATES,
|
||||
TERMINAL_STATES,
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
)
|
||||
from crewai.a2a.updates.base import StreamingHandlerKwargs, extract_common_params
|
||||
from crewai.a2a.updates.streaming.params import (
|
||||
process_status_update,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AArtifactReceivedEvent,
|
||||
A2AConnectionErrorEvent,
|
||||
A2AResponseReceivedEvent,
|
||||
A2AStreamingChunkEvent,
|
||||
A2AStreamingStartedEvent,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.streaming.handler' has been moved to 'crewai_a2a.updates.streaming.handler'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
|
||||
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
|
||||
|
||||
|
||||
class StreamingHandler:
|
||||
"""SSE streaming-based update handler."""
|
||||
|
||||
@staticmethod
|
||||
async def _try_recover_from_interruption( # type: ignore[misc]
|
||||
client: Client,
|
||||
task_id: str,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
result_parts: list[str],
|
||||
**kwargs: Unpack[StreamingHandlerKwargs],
|
||||
) -> TaskStateResult | None:
|
||||
"""Attempt to recover from a stream interruption by checking task state.
|
||||
|
||||
If the task completed while we were disconnected, returns the result.
|
||||
If the task is still running, attempts to resubscribe and continue.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
task_id: The task ID to recover.
|
||||
new_messages: List of collected messages.
|
||||
agent_card: The agent card.
|
||||
result_parts: Accumulated result text parts.
|
||||
**kwargs: Handler parameters.
|
||||
|
||||
Returns:
|
||||
TaskStateResult if recovery succeeded (task finished or resubscribe worked).
|
||||
None if recovery not possible (caller should handle failure).
|
||||
|
||||
Note:
|
||||
When None is returned, recovery failed and the original exception should
|
||||
be handled by the caller. All recovery attempts are logged.
|
||||
"""
|
||||
params = extract_common_params(kwargs) # type: ignore[arg-type]
|
||||
|
||||
try:
|
||||
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
|
||||
|
||||
if a2a_task.status.state in TERMINAL_STATES:
|
||||
logger.info(
|
||||
"Task completed during stream interruption",
|
||||
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
|
||||
)
|
||||
return process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
)
|
||||
|
||||
if a2a_task.status.state in ACTIONABLE_STATES:
|
||||
logger.info(
|
||||
"Task in actionable state during stream interruption",
|
||||
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
|
||||
)
|
||||
return process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Task still running, attempting resubscribe",
|
||||
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
|
||||
)
|
||||
|
||||
for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS):
|
||||
try:
|
||||
backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt)
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(backoff)
|
||||
|
||||
event_stream = client.resubscribe(TaskIdParams(id=task_id))
|
||||
|
||||
async for event in event_stream:
|
||||
if isinstance(event, tuple):
|
||||
resubscribed_task, update = event
|
||||
|
||||
is_final_update = (
|
||||
process_status_update(update, result_parts)
|
||||
if isinstance(update, TaskStatusUpdateEvent)
|
||||
else False
|
||||
)
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
if (
|
||||
is_final_update
|
||||
or resubscribed_task.status.state
|
||||
in TERMINAL_STATES | ACTIONABLE_STATES
|
||||
):
|
||||
return process_task_state(
|
||||
a2a_task=resubscribed_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=is_final_update,
|
||||
)
|
||||
|
||||
elif isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in event.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
final_task = await client.get_task(TaskQueryParams(id=task_id))
|
||||
return process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
)
|
||||
|
||||
except Exception as resubscribe_error: # noqa: PERF203
|
||||
logger.warning(
|
||||
"Resubscribe attempt failed",
|
||||
extra={
|
||||
"task_id": task_id,
|
||||
"attempt": attempt + 1,
|
||||
"max_attempts": MAX_RESUBSCRIBE_ATTEMPTS,
|
||||
"error": str(resubscribe_error),
|
||||
},
|
||||
)
|
||||
if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to recover from stream interruption due to unexpected error",
|
||||
extra={
|
||||
"task_id": task_id,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.warning(
|
||||
"Recovery exhausted all resubscribe attempts without success",
|
||||
extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS},
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Unpack[StreamingHandlerKwargs],
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using SSE streaming for updates.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages.
|
||||
agent_card: The agent card.
|
||||
**kwargs: Streaming-specific parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
"""
|
||||
task_id = kwargs.get("task_id")
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
params = extract_common_params(kwargs)
|
||||
|
||||
result_parts: list[str] = []
|
||||
final_result: TaskStateResult | None = None
|
||||
event_stream = client.send_message(message)
|
||||
chunk_index = 0
|
||||
current_task_id: str | None = task_id
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AStreamingStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=params.context_id,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, tuple):
|
||||
a2a_task, _ = event
|
||||
current_task_id = a2a_task.id
|
||||
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
message_context_id = event.context_id or params.context_id
|
||||
for part in event.parts:
|
||||
if part.root.kind == "text":
|
||||
text = part.root.text
|
||||
result_parts.append(text)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AStreamingChunkEvent(
|
||||
task_id=event.task_id or task_id,
|
||||
context_id=message_context_id,
|
||||
chunk=text,
|
||||
chunk_index=chunk_index,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
chunk_index += 1
|
||||
|
||||
elif isinstance(event, tuple):
|
||||
a2a_task, update = event
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
artifact_size = None
|
||||
if artifact.parts:
|
||||
artifact_size = sum(
|
||||
len(p.root.text.encode())
|
||||
if p.root.kind == "text"
|
||||
else len(getattr(p.root, "data", b""))
|
||||
for p in artifact.parts
|
||||
)
|
||||
effective_context_id = a2a_task.context_id or params.context_id
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AArtifactReceivedEvent(
|
||||
task_id=a2a_task.id,
|
||||
artifact_id=artifact.artifact_id,
|
||||
artifact_name=artifact.name,
|
||||
artifact_description=artifact.description,
|
||||
mime_type=artifact.parts[0].root.kind
|
||||
if artifact.parts
|
||||
else None,
|
||||
size_bytes=artifact_size,
|
||||
append=update.append or False,
|
||||
last_chunk=update.last_chunk or False,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
context_id=effective_context_id,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
is_final_update = (
|
||||
process_status_update(update, result_parts)
|
||||
if isinstance(update, TaskStatusUpdateEvent)
|
||||
else False
|
||||
)
|
||||
|
||||
if (
|
||||
not is_final_update
|
||||
and a2a_task.status.state
|
||||
not in TERMINAL_STATES | ACTIONABLE_STATES
|
||||
):
|
||||
continue
|
||||
|
||||
final_result = process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=is_final_update,
|
||||
)
|
||||
if final_result:
|
||||
break
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
if current_task_id:
|
||||
logger.info(
|
||||
"Stream interrupted with HTTP error, attempting recovery",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error": str(e),
|
||||
"status_code": e.status_code,
|
||||
},
|
||||
)
|
||||
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
|
||||
recovered_result = (
|
||||
await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**recovery_kwargs,
|
||||
)
|
||||
)
|
||||
if recovered_result:
|
||||
logger.info(
|
||||
"Successfully recovered task after HTTP error",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status": str(recovered_result.get("status")),
|
||||
},
|
||||
)
|
||||
return recovered_result
|
||||
|
||||
logger.warning(
|
||||
"Failed to recover from HTTP error, returning failure",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status_code": e.status_code,
|
||||
"original_error": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
error_type = "http_error"
|
||||
status_code = e.status_code
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
|
||||
error_type = type(e).__name__.lower()
|
||||
if current_task_id:
|
||||
logger.info(
|
||||
f"Stream interrupted with {error_type}, attempting recovery",
|
||||
extra={"task_id": current_task_id, "error": str(e)},
|
||||
)
|
||||
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
|
||||
recovered_result = (
|
||||
await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**recovery_kwargs,
|
||||
)
|
||||
)
|
||||
if recovered_result:
|
||||
logger.info(
|
||||
f"Successfully recovered task after {error_type}",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status": str(recovered_result.get("status")),
|
||||
},
|
||||
)
|
||||
return recovered_result
|
||||
|
||||
logger.warning(
|
||||
f"Failed to recover from {error_type}, returning failure",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error_type": error_type,
|
||||
"original_error": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
error_msg = f"Connection error during streaming: {e!s}"
|
||||
status_code = None
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Unexpected error during streaming",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error_type": type(e).__name__,
|
||||
"endpoint": params.endpoint,
|
||||
},
|
||||
)
|
||||
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
|
||||
error_type = "unexpected_error"
|
||||
status_code = None
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
finally:
|
||||
aclose = getattr(event_stream, "aclose", None)
|
||||
if aclose:
|
||||
try:
|
||||
await aclose()
|
||||
except Exception as close_error:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(close_error),
|
||||
error_type="stream_close_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="stream_close",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if final_result:
|
||||
return final_result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
result=" ".join(result_parts) if result_parts else "",
|
||||
history=new_messages,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
)
|
||||
from crewai_a2a.updates.streaming.handler import * # noqa: E402, F403
|
||||
|
||||
@@ -1,28 +1,13 @@
|
||||
"""Common parameter extraction for streaming handlers."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.updates.streaming.params`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
import warnings
|
||||
|
||||
|
||||
def process_status_update(
|
||||
update: TaskStatusUpdateEvent,
|
||||
result_parts: list[str],
|
||||
) -> bool:
|
||||
"""Process a status update event and extract text parts.
|
||||
warnings.warn(
|
||||
"'crewai.a2a.updates.streaming.params' has been moved to 'crewai_a2a.updates.streaming.params'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
Args:
|
||||
update: The status update event.
|
||||
result_parts: List to append text parts to (modified in place).
|
||||
|
||||
Returns:
|
||||
True if this is a final update, False otherwise.
|
||||
"""
|
||||
is_final = update.final
|
||||
if update.status and update.status.message and update.status.message.parts:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in update.status.message.parts
|
||||
if part.root.kind == "text" and part.root.text
|
||||
)
|
||||
return is_final
|
||||
from crewai_a2a.updates.streaming.params import * # noqa: E402, F403
|
||||
|
||||
@@ -1 +1,13 @@
|
||||
"""A2A utility modules for client operations."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.utils`` instead."""
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
warnings.warn(
|
||||
"'crewai.a2a.utils' has been moved to 'crewai_a2a.utils'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from crewai_a2a.utils import * # noqa: E402, F403
|
||||
|
||||
@@ -1,586 +1,13 @@
|
||||
"""AgentCard utilities for A2A client and server operations."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.utils.agent_card`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
import warnings
|
||||
|
||||
import asyncio
|
||||
from collections.abc import MutableMapping
|
||||
from functools import lru_cache
|
||||
import ssl
|
||||
import time
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
|
||||
from aiocache import cached # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
import httpx
|
||||
|
||||
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai.a2a.auth.utils import (
|
||||
_auth_store,
|
||||
configure_auth_client,
|
||||
retry_on_401,
|
||||
)
|
||||
from crewai.a2a.config import A2AServerConfig
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AAgentCardFetchedEvent,
|
||||
A2AAuthenticationFailedEvent,
|
||||
A2AConnectionErrorEvent,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.utils.agent_card' has been moved to 'crewai_a2a.utils.agent_card'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.auth.client_schemes import ClientAuthScheme
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
|
||||
"""Get TLS verify parameter from auth scheme.
|
||||
|
||||
Args:
|
||||
auth: Optional authentication scheme with TLS config.
|
||||
|
||||
Returns:
|
||||
SSL context, CA cert path, True for default verification,
|
||||
or False if verification disabled.
|
||||
"""
|
||||
if auth and auth.tls:
|
||||
return auth.tls.get_httpx_ssl_context()
|
||||
return True
|
||||
|
||||
|
||||
async def _prepare_auth_headers(
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]:
|
||||
"""Prepare authentication headers and TLS verification settings.
|
||||
|
||||
Args:
|
||||
auth: Optional authentication scheme.
|
||||
timeout: Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
Tuple of (headers dict, TLS verify setting).
|
||||
"""
|
||||
headers: MutableMapping[str, str] = {}
|
||||
verify = _get_tls_verify(auth)
|
||||
if auth:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout, verify=verify
|
||||
) as temp_auth_client:
|
||||
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_auth_client)
|
||||
headers = await auth.apply_auth(temp_auth_client, {})
|
||||
return headers, verify
|
||||
|
||||
|
||||
def _get_server_config(agent: Agent) -> A2AServerConfig | None:
|
||||
"""Get A2AServerConfig from an agent's a2a configuration.
|
||||
|
||||
Args:
|
||||
agent: The Agent instance to check.
|
||||
|
||||
Returns:
|
||||
A2AServerConfig if present, None otherwise.
|
||||
"""
|
||||
if agent.a2a is None:
|
||||
return None
|
||||
if isinstance(agent.a2a, A2AServerConfig):
|
||||
return agent.a2a
|
||||
if isinstance(agent.a2a, list):
|
||||
for config in agent.a2a:
|
||||
if isinstance(config, A2AServerConfig):
|
||||
return config
|
||||
return None
|
||||
|
||||
|
||||
def fetch_agent_card(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
timeout: int = 30,
|
||||
use_cache: bool = True,
|
||||
cache_ttl: int = 300,
|
||||
) -> AgentCard:
|
||||
"""Fetch AgentCard from an A2A endpoint with optional caching.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
use_cache: Whether to use caching (default True).
|
||||
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).
|
||||
|
||||
Returns:
|
||||
AgentCard object with agent capabilities and skills.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails.
|
||||
A2AClientHTTPError: If authentication fails.
|
||||
"""
|
||||
if use_cache:
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
"_access_token",
|
||||
"_token_expires_at",
|
||||
"_refresh_token",
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", "")
|
||||
_auth_store.set(auth_hash, auth)
|
||||
ttl_hash = int(time.time() // cache_ttl)
|
||||
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def afetch_agent_card(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
timeout: int = 30,
|
||||
use_cache: bool = True,
|
||||
) -> AgentCard:
|
||||
"""Fetch AgentCard from an A2A endpoint asynchronously.
|
||||
|
||||
Native async implementation. Use this when running in an async context.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
use_cache: Whether to use caching (default True).
|
||||
|
||||
Returns:
|
||||
AgentCard object with agent capabilities and skills.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails.
|
||||
A2AClientHTTPError: If authentication fails.
|
||||
"""
|
||||
if use_cache:
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
"_access_token",
|
||||
"_token_expires_at",
|
||||
"_refresh_token",
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", "")
|
||||
_auth_store.set(auth_hash, auth)
|
||||
agent_card: AgentCard = await _afetch_agent_card_cached(
|
||||
endpoint, auth_hash, timeout
|
||||
)
|
||||
return agent_card
|
||||
|
||||
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _fetch_agent_card_cached(
|
||||
endpoint: str,
|
||||
auth_hash: str,
|
||||
timeout: int,
|
||||
_ttl_hash: int,
|
||||
) -> AgentCard:
|
||||
"""Cached sync version of fetch_agent_card."""
|
||||
auth = _auth_store.get(auth_hash)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
_afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
|
||||
async def _afetch_agent_card_cached(
|
||||
endpoint: str,
|
||||
auth_hash: str,
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Cached async implementation of AgentCard fetching."""
|
||||
auth = _auth_store.get(auth_hash)
|
||||
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
|
||||
|
||||
async def _afetch_agent_card_impl(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Internal async implementation of AgentCard fetching."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if "/.well-known/agent-card.json" in endpoint:
|
||||
base_url = endpoint.replace("/.well-known/agent-card.json", "")
|
||||
agent_card_path = "/.well-known/agent-card.json"
|
||||
else:
|
||||
url_parts = endpoint.split("/", 3)
|
||||
base_url = f"{url_parts[0]}//{url_parts[2]}"
|
||||
agent_card_path = (
|
||||
f"/{url_parts[3]}"
|
||||
if len(url_parts) > 3 and url_parts[3]
|
||||
else "/.well-known/agent-card.json"
|
||||
)
|
||||
|
||||
headers, verify = await _prepare_auth_headers(auth, timeout)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout, headers=headers, verify=verify
|
||||
) as temp_client:
|
||||
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, temp_client)
|
||||
|
||||
agent_card_url = f"{base_url}{agent_card_path}"
|
||||
|
||||
async def _fetch_agent_card_request() -> httpx.Response:
|
||||
return await temp_client.get(agent_card_url)
|
||||
|
||||
try:
|
||||
response = await retry_on_401(
|
||||
request_func=_fetch_agent_card_request,
|
||||
auth_scheme=auth,
|
||||
client=temp_client,
|
||||
headers=temp_client.headers,
|
||||
max_retries=2,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
agent_card = AgentCard.model_validate(response.json())
|
||||
fetch_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AAgentCardFetchedEvent(
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
provider=agent_card_dict.get("provider"),
|
||||
cached=False,
|
||||
fetch_time_ms=fetch_time_ms,
|
||||
),
|
||||
)
|
||||
|
||||
return agent_card
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
response_body = e.response.text[:1000] if e.response.text else None
|
||||
|
||||
if e.response.status_code == 401:
|
||||
error_details = ["Authentication failed"]
|
||||
www_auth = e.response.headers.get("WWW-Authenticate")
|
||||
if www_auth:
|
||||
error_details.append(f"WWW-Authenticate: {www_auth}")
|
||||
if not auth:
|
||||
error_details.append("No auth scheme provided")
|
||||
msg = " | ".join(error_details)
|
||||
|
||||
auth_type = type(auth).__name__ if auth else None
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AAuthenticationFailedEvent(
|
||||
endpoint=endpoint,
|
||||
auth_type=auth_type,
|
||||
error=msg,
|
||||
status_code=401,
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"response_body": response_body,
|
||||
"www_authenticate": www_auth,
|
||||
"request_url": str(e.request.url),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
raise A2AClientHTTPError(401, msg) from e
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.response.status_code,
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"response_body": response_body,
|
||||
"request_url": str(e.request.url),
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="timeout",
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"timeout_config": timeout,
|
||||
"request_url": str(e.request.url) if e.request else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="connection_error",
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"request_url": str(e.request.url) if e.request else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
except httpx.RequestError as e:
|
||||
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="request_error",
|
||||
operation="fetch_agent_card",
|
||||
metadata={
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"request_url": str(e.request.url) if e.request else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _task_to_skill(task: Task) -> AgentSkill:
|
||||
"""Convert a CrewAI Task to an A2A AgentSkill.
|
||||
|
||||
Args:
|
||||
task: The CrewAI Task to convert.
|
||||
|
||||
Returns:
|
||||
AgentSkill representing the task's capability.
|
||||
"""
|
||||
task_name = task.name or task.description[:50]
|
||||
task_id = task_name.lower().replace(" ", "_")
|
||||
|
||||
tags: list[str] = []
|
||||
if task.agent:
|
||||
tags.append(task.agent.role.lower().replace(" ", "-"))
|
||||
|
||||
return AgentSkill(
|
||||
id=task_id,
|
||||
name=task_name,
|
||||
description=task.description,
|
||||
tags=tags,
|
||||
examples=[task.expected_output] if task.expected_output else None,
|
||||
)
|
||||
|
||||
|
||||
def _tool_to_skill(tool_name: str, tool_description: str) -> AgentSkill:
|
||||
"""Convert an Agent's tool to an A2A AgentSkill.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool.
|
||||
tool_description: Description of what the tool does.
|
||||
|
||||
Returns:
|
||||
AgentSkill representing the tool's capability.
|
||||
"""
|
||||
tool_id = tool_name.lower().replace(" ", "_")
|
||||
|
||||
return AgentSkill(
|
||||
id=tool_id,
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
tags=[tool_name.lower().replace(" ", "-")],
|
||||
)
|
||||
|
||||
|
||||
def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard:
|
||||
"""Generate an A2A AgentCard from a Crew instance.
|
||||
|
||||
Args:
|
||||
crew: The Crew instance to generate a card for.
|
||||
url: The base URL where this crew will be exposed.
|
||||
|
||||
Returns:
|
||||
AgentCard describing the crew's capabilities.
|
||||
"""
|
||||
crew_name = getattr(crew, "name", None) or crew.__class__.__name__
|
||||
|
||||
description_parts: list[str] = []
|
||||
crew_description = getattr(crew, "description", None)
|
||||
if crew_description:
|
||||
description_parts.append(crew_description)
|
||||
else:
|
||||
agent_roles = [agent.role for agent in crew.agents]
|
||||
description_parts.append(
|
||||
f"A crew of {len(crew.agents)} agents: {', '.join(agent_roles)}"
|
||||
)
|
||||
|
||||
skills = [_task_to_skill(task) for task in crew.tasks]
|
||||
|
||||
return AgentCard(
|
||||
name=crew_name,
|
||||
description=" ".join(description_parts),
|
||||
url=url,
|
||||
version="1.0.0",
|
||||
capabilities=AgentCapabilities(
|
||||
streaming=True,
|
||||
push_notifications=True,
|
||||
),
|
||||
default_input_modes=["text/plain", "application/json"],
|
||||
default_output_modes=["text/plain", "application/json"],
|
||||
skills=skills,
|
||||
)
|
||||
|
||||
|
||||
def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
|
||||
"""Generate an A2A AgentCard from an Agent instance.
|
||||
|
||||
Uses A2AServerConfig values when available, falling back to agent properties.
|
||||
If signing_config is provided, the card will be signed with JWS.
|
||||
|
||||
Args:
|
||||
agent: The Agent instance to generate a card for.
|
||||
url: The base URL where this agent will be exposed.
|
||||
|
||||
Returns:
|
||||
AgentCard describing the agent's capabilities.
|
||||
"""
|
||||
from crewai.a2a.utils.agent_card_signing import sign_agent_card
|
||||
|
||||
server_config = _get_server_config(agent) or A2AServerConfig()
|
||||
|
||||
name = server_config.name or agent.role
|
||||
|
||||
description_parts = [agent.goal]
|
||||
if agent.backstory:
|
||||
description_parts.append(agent.backstory)
|
||||
description = server_config.description or " ".join(description_parts)
|
||||
|
||||
skills: list[AgentSkill] = (
|
||||
server_config.skills.copy() if server_config.skills else []
|
||||
)
|
||||
|
||||
if not skills:
|
||||
if agent.tools:
|
||||
for tool in agent.tools:
|
||||
tool_name = getattr(tool, "name", None) or tool.__class__.__name__
|
||||
tool_desc = getattr(tool, "description", None) or f"Tool: {tool_name}"
|
||||
skills.append(_tool_to_skill(tool_name, tool_desc))
|
||||
|
||||
if not skills:
|
||||
skills.append(
|
||||
AgentSkill(
|
||||
id=agent.role.lower().replace(" ", "_"),
|
||||
name=agent.role,
|
||||
description=agent.goal,
|
||||
tags=[agent.role.lower().replace(" ", "-")],
|
||||
)
|
||||
)
|
||||
|
||||
capabilities = server_config.capabilities
|
||||
if server_config.server_extensions:
|
||||
from crewai.a2a.extensions.server import ServerExtensionRegistry
|
||||
|
||||
registry = ServerExtensionRegistry(server_config.server_extensions)
|
||||
ext_list = registry.get_agent_extensions()
|
||||
|
||||
existing_exts = list(capabilities.extensions) if capabilities.extensions else []
|
||||
existing_uris = {e.uri for e in existing_exts}
|
||||
for ext in ext_list:
|
||||
if ext.uri not in existing_uris:
|
||||
existing_exts.append(ext)
|
||||
|
||||
capabilities = capabilities.model_copy(update={"extensions": existing_exts})
|
||||
|
||||
card = AgentCard(
|
||||
name=name,
|
||||
description=description,
|
||||
url=server_config.url or url,
|
||||
version=server_config.version,
|
||||
capabilities=capabilities,
|
||||
default_input_modes=server_config.default_input_modes,
|
||||
default_output_modes=server_config.default_output_modes,
|
||||
skills=skills,
|
||||
preferred_transport=server_config.transport.preferred,
|
||||
protocol_version=server_config.protocol_version,
|
||||
provider=server_config.provider,
|
||||
documentation_url=server_config.documentation_url,
|
||||
icon_url=server_config.icon_url,
|
||||
additional_interfaces=server_config.additional_interfaces,
|
||||
security=server_config.security,
|
||||
security_schemes=server_config.security_schemes,
|
||||
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
|
||||
)
|
||||
|
||||
if server_config.signing_config:
|
||||
signature = sign_agent_card(
|
||||
card,
|
||||
private_key=server_config.signing_config.get_private_key(),
|
||||
key_id=server_config.signing_config.key_id,
|
||||
algorithm=server_config.signing_config.algorithm,
|
||||
)
|
||||
card = card.model_copy(update={"signatures": [signature]})
|
||||
elif server_config.signatures:
|
||||
card = card.model_copy(update={"signatures": server_config.signatures})
|
||||
|
||||
return card
|
||||
|
||||
|
||||
def inject_a2a_server_methods(agent: Agent) -> None:
|
||||
"""Inject A2A server methods onto an Agent instance.
|
||||
|
||||
Adds a `to_agent_card(url: str) -> AgentCard` method to the agent
|
||||
that generates an A2A-compliant AgentCard.
|
||||
|
||||
Only injects if the agent has an A2AServerConfig.
|
||||
|
||||
Args:
|
||||
agent: The Agent instance to inject methods onto.
|
||||
"""
|
||||
if _get_server_config(agent) is None:
|
||||
return
|
||||
|
||||
def _to_agent_card(self: Agent, url: str) -> AgentCard:
|
||||
return _agent_to_agent_card(self, url)
|
||||
|
||||
object.__setattr__(agent, "to_agent_card", MethodType(_to_agent_card, agent))
|
||||
from crewai_a2a.utils.agent_card import * # noqa: E402, F403
|
||||
|
||||
@@ -1,236 +1,13 @@
|
||||
"""AgentCard JWS signing utilities.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.utils.agent_card_signing`` instead."""
|
||||
|
||||
This module provides functions for signing and verifying AgentCards using
|
||||
JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients
|
||||
to verify the authenticity and integrity of agent card information.
|
||||
|
||||
Example:
|
||||
>>> from crewai.a2a.utils.agent_card_signing import sign_agent_card
|
||||
>>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1")
|
||||
>>> card_with_sig = card.model_copy(update={"signatures": [signature]})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from a2a.types import AgentCard, AgentCardSignature
|
||||
import jwt
|
||||
from pydantic import SecretStr
|
||||
import warnings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
warnings.warn(
|
||||
"'crewai.a2a.utils.agent_card_signing' has been moved to 'crewai_a2a.utils.agent_card_signing'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
SigningAlgorithm = Literal[
|
||||
"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512"
|
||||
]
|
||||
|
||||
|
||||
def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes:
|
||||
"""Normalize private key to bytes format.
|
||||
|
||||
Args:
|
||||
private_key: PEM-encoded private key as string, bytes, or SecretStr.
|
||||
|
||||
Returns:
|
||||
Private key as bytes.
|
||||
"""
|
||||
if isinstance(private_key, SecretStr):
|
||||
private_key = private_key.get_secret_value()
|
||||
if isinstance(private_key, str):
|
||||
private_key = private_key.encode()
|
||||
return private_key
|
||||
|
||||
|
||||
def _serialize_agent_card(agent_card: AgentCard) -> str:
|
||||
"""Serialize AgentCard to canonical JSON for signing.
|
||||
|
||||
Excludes the signatures field to avoid circular reference during signing.
|
||||
Uses sorted keys and compact separators for deterministic output.
|
||||
|
||||
Args:
|
||||
agent_card: The AgentCard to serialize.
|
||||
|
||||
Returns:
|
||||
Canonical JSON string representation.
|
||||
"""
|
||||
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
|
||||
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _base64url_encode(data: bytes | str) -> str:
|
||||
"""Encode data to URL-safe base64 without padding.
|
||||
|
||||
Args:
|
||||
data: Data to encode.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded string without padding.
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode()
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def sign_agent_card(
|
||||
agent_card: AgentCard,
|
||||
private_key: str | bytes | SecretStr,
|
||||
key_id: str | None = None,
|
||||
algorithm: SigningAlgorithm = "RS256",
|
||||
) -> AgentCardSignature:
|
||||
"""Sign an AgentCard using JWS (RFC 7515).
|
||||
|
||||
Creates a detached JWS signature for the AgentCard. The signature covers
|
||||
all fields except the signatures field itself.
|
||||
|
||||
Args:
|
||||
agent_card: The AgentCard to sign.
|
||||
private_key: PEM-encoded private key (RSA, EC, or RSA-PSS).
|
||||
key_id: Optional key identifier for the JWS header (kid claim).
|
||||
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
|
||||
|
||||
Returns:
|
||||
AgentCardSignature with protected header and signature.
|
||||
|
||||
Raises:
|
||||
jwt.exceptions.InvalidKeyError: If the private key is invalid.
|
||||
ValueError: If the algorithm is not supported for the key type.
|
||||
|
||||
Example:
|
||||
>>> signature = sign_agent_card(
|
||||
... agent_card,
|
||||
... private_key_pem="-----BEGIN PRIVATE KEY-----...",
|
||||
... key_id="my-key-id",
|
||||
... )
|
||||
"""
|
||||
key_bytes = _normalize_private_key(private_key)
|
||||
payload = _serialize_agent_card(agent_card)
|
||||
|
||||
protected_header: dict[str, Any] = {"typ": "JWS"}
|
||||
if key_id:
|
||||
protected_header["kid"] = key_id
|
||||
|
||||
jws_token = jwt.api_jws.encode(
|
||||
payload.encode(),
|
||||
key_bytes,
|
||||
algorithm=algorithm,
|
||||
headers=protected_header,
|
||||
)
|
||||
|
||||
parts = jws_token.split(".")
|
||||
protected_b64 = parts[0]
|
||||
signature_b64 = parts[2]
|
||||
|
||||
header: dict[str, Any] | None = None
|
||||
if key_id:
|
||||
header = {"kid": key_id}
|
||||
|
||||
return AgentCardSignature(
|
||||
protected=protected_b64,
|
||||
signature=signature_b64,
|
||||
header=header,
|
||||
)
|
||||
|
||||
|
||||
def verify_agent_card_signature(
|
||||
agent_card: AgentCard,
|
||||
signature: AgentCardSignature,
|
||||
public_key: str | bytes,
|
||||
algorithms: list[str] | None = None,
|
||||
) -> bool:
|
||||
"""Verify an AgentCard JWS signature.
|
||||
|
||||
Validates that the signature was created with the corresponding private key
|
||||
and that the AgentCard content has not been modified.
|
||||
|
||||
Args:
|
||||
agent_card: The AgentCard to verify.
|
||||
signature: The AgentCardSignature to validate.
|
||||
public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
|
||||
algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms.
|
||||
|
||||
Returns:
|
||||
True if signature is valid, False otherwise.
|
||||
|
||||
Example:
|
||||
>>> is_valid = verify_agent_card_signature(
|
||||
... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..."
|
||||
... )
|
||||
"""
|
||||
if algorithms is None:
|
||||
algorithms = [
|
||||
"RS256",
|
||||
"RS384",
|
||||
"RS512",
|
||||
"ES256",
|
||||
"ES384",
|
||||
"ES512",
|
||||
"PS256",
|
||||
"PS384",
|
||||
"PS512",
|
||||
]
|
||||
|
||||
if isinstance(public_key, str):
|
||||
public_key = public_key.encode()
|
||||
|
||||
payload = _serialize_agent_card(agent_card)
|
||||
payload_b64 = _base64url_encode(payload)
|
||||
jws_token = f"{signature.protected}.{payload_b64}.{signature.signature}"
|
||||
|
||||
try:
|
||||
jwt.api_jws.decode(
|
||||
jws_token,
|
||||
public_key,
|
||||
algorithms=algorithms,
|
||||
)
|
||||
return True
|
||||
except jwt.InvalidSignatureError:
|
||||
logger.debug(
|
||||
"AgentCard signature verification failed",
|
||||
extra={"reason": "invalid_signature"},
|
||||
)
|
||||
return False
|
||||
except jwt.DecodeError as e:
|
||||
logger.debug(
|
||||
"AgentCard signature verification failed",
|
||||
extra={"reason": "decode_error", "error": str(e)},
|
||||
)
|
||||
return False
|
||||
except jwt.InvalidAlgorithmError as e:
|
||||
logger.debug(
|
||||
"AgentCard signature verification failed",
|
||||
extra={"reason": "algorithm_error", "error": str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def get_key_id_from_signature(signature: AgentCardSignature) -> str | None:
|
||||
"""Extract the key ID (kid) from an AgentCardSignature.
|
||||
|
||||
Checks both the unprotected header and the protected header for the kid claim.
|
||||
|
||||
Args:
|
||||
signature: The AgentCardSignature to extract from.
|
||||
|
||||
Returns:
|
||||
The key ID if present, None otherwise.
|
||||
"""
|
||||
if signature.header and "kid" in signature.header:
|
||||
kid: str = signature.header["kid"]
|
||||
return kid
|
||||
|
||||
try:
|
||||
protected = signature.protected
|
||||
padding_needed = 4 - (len(protected) % 4)
|
||||
if padding_needed != 4:
|
||||
protected += "=" * padding_needed
|
||||
|
||||
protected_json = base64.urlsafe_b64decode(protected).decode()
|
||||
protected_header: dict[str, Any] = json.loads(protected_json)
|
||||
return protected_header.get("kid")
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
return None
|
||||
from crewai_a2a.utils.agent_card_signing import * # noqa: E402, F403
|
||||
|
||||
@@ -1,339 +1,13 @@
|
||||
"""Content type negotiation for A2A protocol.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.utils.content_type`` instead."""
|
||||
|
||||
This module handles negotiation of input/output MIME types between A2A clients
|
||||
and servers based on AgentCard capabilities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Annotated, Final, Literal, cast
|
||||
|
||||
from a2a.types import Part
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent
|
||||
import warnings
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard, AgentSkill
|
||||
|
||||
|
||||
TEXT_PLAIN: Literal["text/plain"] = "text/plain"
|
||||
APPLICATION_JSON: Literal["application/json"] = "application/json"
|
||||
IMAGE_PNG: Literal["image/png"] = "image/png"
|
||||
IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg"
|
||||
IMAGE_WILDCARD: Literal["image/*"] = "image/*"
|
||||
APPLICATION_PDF: Literal["application/pdf"] = "application/pdf"
|
||||
APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = (
|
||||
"application/octet-stream"
|
||||
warnings.warn(
|
||||
"'crewai.a2a.utils.content_type' has been moved to 'crewai_a2a.utils.content_type'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
|
||||
TEXT_PLAIN,
|
||||
APPLICATION_JSON,
|
||||
]
|
||||
DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
|
||||
TEXT_PLAIN,
|
||||
APPLICATION_JSON,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class NegotiatedContentTypes:
|
||||
"""Result of content type negotiation."""
|
||||
|
||||
input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"]
|
||||
output_modes: Annotated[
|
||||
list[str], "Negotiated output MIME types the server will produce"
|
||||
]
|
||||
effective_input_modes: Annotated[list[str], "Server's effective input modes"]
|
||||
effective_output_modes: Annotated[list[str], "Server's effective output modes"]
|
||||
skill_name: Annotated[
|
||||
str | None, "Skill name if negotiation was skill-specific"
|
||||
] = None
|
||||
|
||||
|
||||
class ContentTypeNegotiationError(Exception):
|
||||
"""Raised when no compatible content types can be negotiated."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_input_modes: list[str],
|
||||
client_output_modes: list[str],
|
||||
server_input_modes: list[str],
|
||||
server_output_modes: list[str],
|
||||
direction: str = "both",
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
self.client_input_modes = client_input_modes
|
||||
self.client_output_modes = client_output_modes
|
||||
self.server_input_modes = server_input_modes
|
||||
self.server_output_modes = server_output_modes
|
||||
self.direction = direction
|
||||
|
||||
if message is None:
|
||||
if direction == "input":
|
||||
message = (
|
||||
f"No compatible input content types. "
|
||||
f"Client supports: {client_input_modes}, "
|
||||
f"Server accepts: {server_input_modes}"
|
||||
)
|
||||
elif direction == "output":
|
||||
message = (
|
||||
f"No compatible output content types. "
|
||||
f"Client accepts: {client_output_modes}, "
|
||||
f"Server produces: {server_output_modes}"
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"No compatible content types. "
|
||||
f"Input - Client: {client_input_modes}, Server: {server_input_modes}. "
|
||||
f"Output - Client: {client_output_modes}, Server: {server_output_modes}"
|
||||
)
|
||||
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def _normalize_mime_type(mime_type: str) -> str:
|
||||
"""Normalize MIME type for comparison (lowercase, strip whitespace)."""
|
||||
return mime_type.lower().strip()
|
||||
|
||||
|
||||
def _mime_types_compatible(client_type: str, server_type: str) -> bool:
|
||||
"""Check if two MIME types are compatible.
|
||||
|
||||
Handles wildcards like image/* matching image/png.
|
||||
"""
|
||||
client_normalized = _normalize_mime_type(client_type)
|
||||
server_normalized = _normalize_mime_type(server_type)
|
||||
|
||||
if client_normalized == server_normalized:
|
||||
return True
|
||||
|
||||
if "*" in client_normalized or "*" in server_normalized:
|
||||
client_parts = client_normalized.split("/")
|
||||
server_parts = server_normalized.split("/")
|
||||
|
||||
if len(client_parts) == 2 and len(server_parts) == 2:
|
||||
type_match = (
|
||||
client_parts[0] == server_parts[0]
|
||||
or client_parts[0] == "*"
|
||||
or server_parts[0] == "*"
|
||||
)
|
||||
subtype_match = (
|
||||
client_parts[1] == server_parts[1]
|
||||
or client_parts[1] == "*"
|
||||
or server_parts[1] == "*"
|
||||
)
|
||||
return type_match and subtype_match
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _find_compatible_modes(
|
||||
client_modes: list[str], server_modes: list[str]
|
||||
) -> list[str]:
|
||||
"""Find compatible MIME types between client and server.
|
||||
|
||||
Returns modes in client preference order.
|
||||
"""
|
||||
compatible = []
|
||||
for client_mode in client_modes:
|
||||
for server_mode in server_modes:
|
||||
if _mime_types_compatible(client_mode, server_mode):
|
||||
if "*" in client_mode and "*" not in server_mode:
|
||||
if server_mode not in compatible:
|
||||
compatible.append(server_mode)
|
||||
else:
|
||||
if client_mode not in compatible:
|
||||
compatible.append(client_mode)
|
||||
break
|
||||
return compatible
|
||||
|
||||
|
||||
def _get_effective_modes(
|
||||
agent_card: AgentCard,
|
||||
skill_name: str | None = None,
|
||||
) -> tuple[list[str], list[str], AgentSkill | None]:
|
||||
"""Get effective input/output modes from agent card.
|
||||
|
||||
If skill_name is provided and the skill has custom modes, those are used.
|
||||
Otherwise, falls back to agent card defaults.
|
||||
"""
|
||||
skill: AgentSkill | None = None
|
||||
|
||||
if skill_name and agent_card.skills:
|
||||
for s in agent_card.skills:
|
||||
if s.name == skill_name or s.id == skill_name:
|
||||
skill = s
|
||||
break
|
||||
|
||||
if skill:
|
||||
input_modes = (
|
||||
skill.input_modes if skill.input_modes else agent_card.default_input_modes
|
||||
)
|
||||
output_modes = (
|
||||
skill.output_modes
|
||||
if skill.output_modes
|
||||
else agent_card.default_output_modes
|
||||
)
|
||||
else:
|
||||
input_modes = agent_card.default_input_modes
|
||||
output_modes = agent_card.default_output_modes
|
||||
|
||||
return input_modes, output_modes, skill
|
||||
|
||||
|
||||
def negotiate_content_types(
|
||||
agent_card: AgentCard,
|
||||
client_input_modes: list[str] | None = None,
|
||||
client_output_modes: list[str] | None = None,
|
||||
skill_name: str | None = None,
|
||||
emit_event: bool = True,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
strict: bool = False,
|
||||
) -> NegotiatedContentTypes:
|
||||
"""Negotiate content types between client and server.
|
||||
|
||||
Args:
|
||||
agent_card: The remote agent's card with capability info.
|
||||
client_input_modes: MIME types the client can send. Defaults to text/plain and application/json.
|
||||
client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json.
|
||||
skill_name: Optional skill to use for mode lookup.
|
||||
emit_event: Whether to emit a content type negotiation event.
|
||||
endpoint: Agent endpoint (for event metadata).
|
||||
a2a_agent_name: Agent name (for event metadata).
|
||||
strict: If True, raises error when no compatible types found.
|
||||
If False, returns empty lists for incompatible directions.
|
||||
|
||||
Returns:
|
||||
NegotiatedContentTypes with compatible input and output modes.
|
||||
|
||||
Raises:
|
||||
ContentTypeNegotiationError: If strict=True and no compatible types found.
|
||||
"""
|
||||
if client_input_modes is None:
|
||||
client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy())
|
||||
if client_output_modes is None:
|
||||
client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy())
|
||||
|
||||
server_input_modes, server_output_modes, skill = _get_effective_modes(
|
||||
agent_card, skill_name
|
||||
)
|
||||
|
||||
compatible_input = _find_compatible_modes(client_input_modes, server_input_modes)
|
||||
compatible_output = _find_compatible_modes(client_output_modes, server_output_modes)
|
||||
|
||||
if strict:
|
||||
if not compatible_input and not compatible_output:
|
||||
raise ContentTypeNegotiationError(
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
)
|
||||
if not compatible_input:
|
||||
raise ContentTypeNegotiationError(
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
direction="input",
|
||||
)
|
||||
if not compatible_output:
|
||||
raise ContentTypeNegotiationError(
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
direction="output",
|
||||
)
|
||||
|
||||
result = NegotiatedContentTypes(
|
||||
input_modes=compatible_input,
|
||||
output_modes=compatible_output,
|
||||
effective_input_modes=server_input_modes,
|
||||
effective_output_modes=server_output_modes,
|
||||
skill_name=skill.name if skill else None,
|
||||
)
|
||||
|
||||
if emit_event:
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AContentTypeNegotiatedEvent(
|
||||
endpoint=endpoint or agent_card.url,
|
||||
a2a_agent_name=a2a_agent_name or agent_card.name,
|
||||
skill_name=skill_name,
|
||||
client_input_modes=client_input_modes,
|
||||
client_output_modes=client_output_modes,
|
||||
server_input_modes=server_input_modes,
|
||||
server_output_modes=server_output_modes,
|
||||
negotiated_input_modes=compatible_input,
|
||||
negotiated_output_modes=compatible_output,
|
||||
negotiation_success=bool(compatible_input and compatible_output),
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def validate_content_type(
|
||||
content_type: str,
|
||||
allowed_modes: list[str],
|
||||
) -> bool:
|
||||
"""Validate that a content type is allowed by a list of modes.
|
||||
|
||||
Args:
|
||||
content_type: The MIME type to validate.
|
||||
allowed_modes: List of allowed MIME types (may include wildcards).
|
||||
|
||||
Returns:
|
||||
True if content_type is compatible with any allowed mode.
|
||||
"""
|
||||
for mode in allowed_modes:
|
||||
if _mime_types_compatible(content_type, mode):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_part_content_type(part: Part) -> str:
|
||||
"""Extract MIME type from an A2A Part.
|
||||
|
||||
Args:
|
||||
part: A Part object containing TextPart, DataPart, or FilePart.
|
||||
|
||||
Returns:
|
||||
The MIME type string for this part.
|
||||
"""
|
||||
root = part.root
|
||||
if root.kind == "text":
|
||||
return TEXT_PLAIN
|
||||
if root.kind == "data":
|
||||
return APPLICATION_JSON
|
||||
if root.kind == "file":
|
||||
return root.file.mime_type or APPLICATION_OCTET_STREAM
|
||||
return APPLICATION_OCTET_STREAM
|
||||
|
||||
|
||||
def validate_message_parts(
|
||||
parts: list[Part],
|
||||
allowed_modes: list[str],
|
||||
) -> list[str]:
|
||||
"""Validate that all message parts have allowed content types.
|
||||
|
||||
Args:
|
||||
parts: List of Parts from the incoming message.
|
||||
allowed_modes: List of allowed MIME types (from default_input_modes).
|
||||
|
||||
Returns:
|
||||
List of invalid content types found (empty if all valid).
|
||||
"""
|
||||
invalid_types: list[str] = []
|
||||
for part in parts:
|
||||
content_type = get_part_content_type(part)
|
||||
if not validate_content_type(content_type, allowed_modes):
|
||||
if content_type not in invalid_types:
|
||||
invalid_types.append(content_type)
|
||||
return invalid_types
|
||||
from crewai_a2a.utils.content_type import * # noqa: E402, F403
|
||||
|
||||
@@ -1,980 +1,13 @@
|
||||
"""A2A delegation utilities for executing tasks on remote agents."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.utils.delegation`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
import warnings
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import AsyncIterator, Callable, MutableMapping
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client, ClientConfig, ClientFactory
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
FilePart,
|
||||
FileWithBytes,
|
||||
Message,
|
||||
Part,
|
||||
PushNotificationConfig as A2APushNotificationConfig,
|
||||
Role,
|
||||
TextPart,
|
||||
)
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai.a2a.auth.utils import (
|
||||
_auth_store,
|
||||
configure_auth_client,
|
||||
validate_auth_against_agent_card,
|
||||
)
|
||||
from crewai.a2a.config import ClientTransportConfig, GRPCClientConfig
|
||||
from crewai.a2a.extensions.registry import (
|
||||
ExtensionsMiddleware,
|
||||
validate_required_extensions,
|
||||
)
|
||||
from crewai.a2a.task_helpers import TaskStateResult
|
||||
from crewai.a2a.types import (
|
||||
HANDLER_REGISTRY,
|
||||
HandlerType,
|
||||
PartsDict,
|
||||
PartsMetadataDict,
|
||||
TransportType,
|
||||
)
|
||||
from crewai.a2a.updates import (
|
||||
PollingConfig,
|
||||
PushNotificationConfig,
|
||||
StreamingHandler,
|
||||
UpdateConfig,
|
||||
)
|
||||
from crewai.a2a.utils.agent_card import (
|
||||
_afetch_agent_card_cached,
|
||||
_get_tls_verify,
|
||||
_prepare_auth_headers,
|
||||
)
|
||||
from crewai.a2a.utils.content_type import (
|
||||
DEFAULT_CLIENT_OUTPUT_MODES,
|
||||
negotiate_content_types,
|
||||
)
|
||||
from crewai.a2a.utils.transport import (
|
||||
NegotiatedTransport,
|
||||
TransportNegotiationError,
|
||||
negotiate_transport,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConversationStartedEvent,
|
||||
A2ADelegationCompletedEvent,
|
||||
A2ADelegationStartedEvent,
|
||||
A2AMessageSentEvent,
|
||||
warnings.warn(
|
||||
"'crewai.a2a.utils.delegation' has been moved to 'crewai_a2a.utils.delegation'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Message
|
||||
|
||||
from crewai.a2a.auth.client_schemes import ClientAuthScheme
|
||||
|
||||
|
||||
_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC"
|
||||
|
||||
|
||||
def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]:
|
||||
"""Convert FileInput dictionary to FilePart objects.
|
||||
|
||||
Args:
|
||||
input_files: Dictionary mapping names to FileInput objects.
|
||||
|
||||
Returns:
|
||||
List of Part objects containing FilePart data.
|
||||
"""
|
||||
if not input_files:
|
||||
return []
|
||||
|
||||
try:
|
||||
import crewai_files # noqa: F401
|
||||
except ImportError:
|
||||
logger.debug("crewai_files not installed, skipping file parts")
|
||||
return []
|
||||
|
||||
parts: list[Part] = []
|
||||
for name, file_input in input_files.items():
|
||||
content_bytes = file_input.read()
|
||||
content_base64 = base64.b64encode(content_bytes).decode()
|
||||
file_with_bytes = FileWithBytes(
|
||||
bytes=content_base64,
|
||||
mimeType=file_input.content_type,
|
||||
name=file_input.filename or name,
|
||||
)
|
||||
parts.append(Part(root=FilePart(file=file_with_bytes)))
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def get_handler(config: UpdateConfig | None) -> HandlerType:
|
||||
"""Get the handler class for a given update config.
|
||||
|
||||
Args:
|
||||
config: Update mechanism configuration.
|
||||
|
||||
Returns:
|
||||
Handler class for the config type, defaults to StreamingHandler.
|
||||
"""
|
||||
if config is None:
|
||||
return StreamingHandler
|
||||
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
|
||||
|
||||
|
||||
def execute_a2a_delegation(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None = None,
|
||||
context_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
reference_task_ids: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
conversation_history: list[Message] | None = None,
|
||||
agent_id: str | None = None,
|
||||
agent_role: Role | None = None,
|
||||
agent_branch: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
turn_number: int | None = None,
|
||||
updates: UpdateConfig | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
transport: ClientTransportConfig | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
input_files: dict[str, Any] | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Execute a task delegation to a remote A2A agent synchronously.
|
||||
|
||||
WARNING: This function blocks the entire thread by creating and running a new
|
||||
event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts
|
||||
for better performance and resource efficiency.
|
||||
|
||||
This is a synchronous wrapper around aexecute_a2a_delegation that creates a
|
||||
new event loop to run the async implementation. It is provided for compatibility
|
||||
with synchronous code paths only.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL (AgentCard URL).
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
task_description: The task to delegate.
|
||||
context: Optional context information.
|
||||
context_id: Context ID for correlating messages/tasks.
|
||||
task_id: Specific task identifier.
|
||||
reference_task_ids: List of related task IDs.
|
||||
metadata: Additional metadata.
|
||||
extensions: Protocol extensions for custom fields.
|
||||
conversation_history: Previous Message objects from conversation.
|
||||
agent_id: Agent identifier for logging.
|
||||
agent_role: Role of the CrewAI agent delegating the task.
|
||||
agent_branch: Optional agent tree branch for logging.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
turn_number: Optional turn number for multi-turn conversations.
|
||||
updates: Update mechanism config from A2AConfig.updates.
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
skill_id: Optional skill ID to target a specific agent capability.
|
||||
client_extensions: A2A protocol extension URIs the client supports.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
accepted_output_modes: MIME types the client can accept in responses.
|
||||
input_files: Optional dictionary of files to send to remote agent.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with status, result/error, history, and agent_card.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If called from an async context with a running event loop.
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
raise RuntimeError(
|
||||
"execute_a2a_delegation() cannot be called from an async context. "
|
||||
"Use 'await aexecute_a2a_delegation()' instead."
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "no running event loop" not in str(e).lower():
|
||||
raise
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
aexecute_a2a_delegation(
|
||||
endpoint=endpoint,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
task_description=task_description,
|
||||
context=context,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=extensions,
|
||||
conversation_history=conversation_history,
|
||||
agent_id=agent_id,
|
||||
agent_role=agent_role,
|
||||
agent_branch=agent_branch,
|
||||
response_model=response_model,
|
||||
turn_number=turn_number,
|
||||
updates=updates,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
skill_id=skill_id,
|
||||
client_extensions=client_extensions,
|
||||
transport=transport,
|
||||
accepted_output_modes=accepted_output_modes,
|
||||
input_files=input_files,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def aexecute_a2a_delegation(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None = None,
|
||||
context_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
reference_task_ids: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
extensions: dict[str, Any] | None = None,
|
||||
conversation_history: list[Message] | None = None,
|
||||
agent_id: str | None = None,
|
||||
agent_role: Role | None = None,
|
||||
agent_branch: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
turn_number: int | None = None,
|
||||
updates: UpdateConfig | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
transport: ClientTransportConfig | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
input_files: dict[str, Any] | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Execute a task delegation to a remote A2A agent asynchronously.
|
||||
|
||||
Native async implementation with multi-turn support. Use this when running
|
||||
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL.
|
||||
auth: Optional ClientAuthScheme for authentication.
|
||||
timeout: Request timeout in seconds.
|
||||
task_description: The task to delegate.
|
||||
context: Optional context information.
|
||||
context_id: Context ID for correlating messages/tasks.
|
||||
task_id: Specific task identifier.
|
||||
reference_task_ids: List of related task IDs.
|
||||
metadata: Additional metadata.
|
||||
extensions: Protocol extensions for custom fields.
|
||||
conversation_history: Previous Message objects from conversation.
|
||||
agent_id: Agent identifier for logging.
|
||||
agent_role: Role of the CrewAI agent delegating the task.
|
||||
agent_branch: Optional agent tree branch for logging.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
turn_number: Optional turn number for multi-turn conversations.
|
||||
updates: Update mechanism config from A2AConfig.updates.
|
||||
from_task: Optional CrewAI Task object for event metadata.
|
||||
from_agent: Optional CrewAI Agent object for event metadata.
|
||||
skill_id: Optional skill ID to target a specific agent capability.
|
||||
client_extensions: A2A protocol extension URIs the client supports.
|
||||
transport: Transport configuration (preferred, supported transports, gRPC settings).
|
||||
accepted_output_modes: MIME types the client can accept in responses.
|
||||
input_files: Optional dictionary of files to send to remote agent.
|
||||
|
||||
Returns:
|
||||
TaskStateResult with status, result/error, history, and agent_card.
|
||||
"""
|
||||
if conversation_history is None:
|
||||
conversation_history = []
|
||||
|
||||
is_multiturn = len(conversation_history) > 0
|
||||
if turn_number is None:
|
||||
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
|
||||
|
||||
try:
|
||||
result = await _aexecute_a2a_delegation_impl(
|
||||
endpoint=endpoint,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
task_description=task_description,
|
||||
context=context,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=extensions,
|
||||
conversation_history=conversation_history,
|
||||
is_multiturn=is_multiturn,
|
||||
turn_number=turn_number,
|
||||
agent_branch=agent_branch,
|
||||
agent_id=agent_id,
|
||||
agent_role=agent_role,
|
||||
response_model=response_model,
|
||||
updates=updates,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
skill_id=skill_id,
|
||||
client_extensions=client_extensions,
|
||||
transport=transport,
|
||||
accepted_output_modes=accepted_output_modes,
|
||||
input_files=input_files,
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationCompletedEvent(
|
||||
status="failed",
|
||||
result=None,
|
||||
error=str(e),
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
endpoint=endpoint,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
agent_card_data = result.get("agent_card")
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationCompletedEvent(
|
||||
status=result["status"],
|
||||
result=result.get("result"),
|
||||
error=result.get("error"),
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=result.get("a2a_agent_name"),
|
||||
agent_card=agent_card_data,
|
||||
provider=agent_card_data.get("provider") if agent_card_data else None,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _aexecute_a2a_delegation_impl(
|
||||
endpoint: str,
|
||||
auth: ClientAuthScheme | None,
|
||||
timeout: int,
|
||||
task_description: str,
|
||||
context: str | None,
|
||||
context_id: str | None,
|
||||
task_id: str | None,
|
||||
reference_task_ids: list[str] | None,
|
||||
metadata: dict[str, Any] | None,
|
||||
extensions: dict[str, Any] | None,
|
||||
conversation_history: list[Message],
|
||||
is_multiturn: bool,
|
||||
turn_number: int,
|
||||
agent_branch: Any | None,
|
||||
agent_id: str | None,
|
||||
agent_role: str | None,
|
||||
response_model: type[BaseModel] | None,
|
||||
updates: UpdateConfig | None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
skill_id: str | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
transport: ClientTransportConfig | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
input_files: dict[str, Any] | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Internal async implementation of A2A delegation."""
|
||||
if transport is None:
|
||||
transport = ClientTransportConfig()
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
"_access_token",
|
||||
"_token_expires_at",
|
||||
"_refresh_token",
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
|
||||
else:
|
||||
auth_hash = _auth_store.compute_key("none", endpoint)
|
||||
_auth_store.set(auth_hash, auth)
|
||||
agent_card = await _afetch_agent_card_cached(
|
||||
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
|
||||
)
|
||||
|
||||
validate_auth_against_agent_card(agent_card, auth)
|
||||
|
||||
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
|
||||
if unsupported_exts:
|
||||
ext_uris = [ext.uri for ext in unsupported_exts]
|
||||
raise ValueError(
|
||||
f"Agent requires extensions not supported by client: {ext_uris}"
|
||||
)
|
||||
|
||||
negotiated: NegotiatedTransport | None = None
|
||||
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
|
||||
effective_url = endpoint
|
||||
|
||||
client_transports: list[str] = (
|
||||
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
|
||||
)
|
||||
|
||||
try:
|
||||
negotiated = negotiate_transport(
|
||||
agent_card=agent_card,
|
||||
client_supported_transports=client_transports,
|
||||
client_preferred_transport=transport.preferred,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
)
|
||||
effective_transport = negotiated.transport # type: ignore[assignment]
|
||||
effective_url = negotiated.url
|
||||
except TransportNegotiationError as e:
|
||||
logger.warning(
|
||||
"Transport negotiation failed, using fallback",
|
||||
extra={
|
||||
"error": str(e),
|
||||
"fallback_transport": effective_transport,
|
||||
"fallback_url": effective_url,
|
||||
"endpoint": endpoint,
|
||||
"client_transports": client_transports,
|
||||
"server_transports": [
|
||||
iface.transport for iface in agent_card.additional_interfaces or []
|
||||
]
|
||||
+ [agent_card.preferred_transport or "JSONRPC"],
|
||||
},
|
||||
)
|
||||
|
||||
effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy()
|
||||
|
||||
content_negotiated = negotiate_content_types(
|
||||
agent_card=agent_card,
|
||||
client_output_modes=accepted_output_modes,
|
||||
skill_name=skill_id,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
)
|
||||
if content_negotiated.output_modes:
|
||||
effective_output_modes = content_negotiated.output_modes
|
||||
|
||||
headers, _ = await _prepare_auth_headers(auth, timeout)
|
||||
|
||||
a2a_agent_name = None
|
||||
if agent_card.name:
|
||||
a2a_agent_name = agent_card.name
|
||||
|
||||
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationStartedEvent(
|
||||
endpoint=endpoint,
|
||||
task_description=task_description,
|
||||
agent_id=agent_id or endpoint,
|
||||
context_id=context_id,
|
||||
is_multiturn=is_multiturn,
|
||||
turn_number=turn_number,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
provider=agent_card_dict.get("provider"),
|
||||
skill_id=skill_id,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if turn_number == 1:
|
||||
agent_id_for_event = agent_id or endpoint
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConversationStartedEvent(
|
||||
agent_id=agent_id_for_event,
|
||||
endpoint=endpoint,
|
||||
context_id=context_id,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
provider=agent_card_dict.get("provider"),
|
||||
skill_id=skill_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=metadata,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
message_parts = []
|
||||
|
||||
if context:
|
||||
message_parts.append(f"Context:\n{context}\n\n")
|
||||
message_parts.append(f"{task_description}")
|
||||
message_text = "".join(message_parts)
|
||||
|
||||
if is_multiturn and conversation_history and not task_id:
|
||||
if first_task_id := conversation_history[0].task_id:
|
||||
task_id = first_task_id
|
||||
|
||||
parts: PartsDict = {"text": message_text}
|
||||
if response_model:
|
||||
parts.update(
|
||||
{
|
||||
"metadata": PartsMetadataDict(
|
||||
mimeType="application/json",
|
||||
schema=response_model.model_json_schema(),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
message_metadata = metadata.copy() if metadata else {}
|
||||
if skill_id:
|
||||
message_metadata["skill_id"] = skill_id
|
||||
|
||||
parts_list: list[Part] = [Part(root=TextPart(**parts))]
|
||||
parts_list.extend(_create_file_parts(input_files))
|
||||
|
||||
message = Message(
|
||||
role=Role.user,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=parts_list,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
metadata=message_metadata if message_metadata else None,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
new_messages: list[Message] = [*conversation_history, message]
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AMessageSentEvent(
|
||||
message=message_text,
|
||||
turn_number=turn_number,
|
||||
context_id=context_id,
|
||||
message_id=message.message_id,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
skill_id=skill_id,
|
||||
metadata=message_metadata if message_metadata else None,
|
||||
extensions=list(extensions.keys()) if extensions else None,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
handler = get_handler(updates)
|
||||
use_polling = isinstance(updates, PollingConfig)
|
||||
|
||||
handler_kwargs: dict[str, Any] = {
|
||||
"turn_number": turn_number,
|
||||
"is_multiturn": is_multiturn,
|
||||
"agent_role": agent_role,
|
||||
"context_id": context_id,
|
||||
"task_id": task_id,
|
||||
"endpoint": endpoint,
|
||||
"agent_branch": agent_branch,
|
||||
"a2a_agent_name": a2a_agent_name,
|
||||
"from_task": from_task,
|
||||
"from_agent": from_agent,
|
||||
}
|
||||
|
||||
if isinstance(updates, PollingConfig):
|
||||
handler_kwargs.update(
|
||||
{
|
||||
"polling_interval": updates.interval,
|
||||
"polling_timeout": updates.timeout or float(timeout),
|
||||
"history_length": updates.history_length,
|
||||
"max_polls": updates.max_polls,
|
||||
}
|
||||
)
|
||||
elif isinstance(updates, PushNotificationConfig):
|
||||
handler_kwargs.update(
|
||||
{
|
||||
"config": updates,
|
||||
"result_store": updates.result_store,
|
||||
"polling_timeout": updates.timeout or float(timeout),
|
||||
"polling_interval": updates.interval,
|
||||
}
|
||||
)
|
||||
|
||||
push_config_for_client = (
|
||||
updates if isinstance(updates, PushNotificationConfig) else None
|
||||
)
|
||||
|
||||
use_streaming = not use_polling and push_config_for_client is None
|
||||
|
||||
client_agent_card = agent_card
|
||||
if effective_url != agent_card.url:
|
||||
client_agent_card = agent_card.model_copy(update={"url": effective_url})
|
||||
|
||||
async with _create_a2a_client(
|
||||
agent_card=client_agent_card,
|
||||
transport_protocol=effective_transport,
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
streaming=use_streaming,
|
||||
auth=auth,
|
||||
use_polling=use_polling,
|
||||
push_notification_config=push_config_for_client,
|
||||
client_extensions=client_extensions,
|
||||
accepted_output_modes=effective_output_modes, # type: ignore[arg-type]
|
||||
grpc_config=transport.grpc,
|
||||
) as client:
|
||||
result = await handler.execute(
|
||||
client=client,
|
||||
message=message,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
**handler_kwargs,
|
||||
)
|
||||
result["a2a_agent_name"] = a2a_agent_name
|
||||
result["agent_card"] = agent_card.model_dump(exclude_none=True)
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_grpc_metadata(
|
||||
metadata: tuple[tuple[str, str], ...] | None,
|
||||
) -> tuple[tuple[str, str], ...] | None:
|
||||
"""Lowercase all gRPC metadata keys.
|
||||
|
||||
gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK)
|
||||
use mixed-case headers like 'X-A2A-Extensions'. This normalizes them.
|
||||
"""
|
||||
if metadata is None:
|
||||
return None
|
||||
return tuple((key.lower(), value) for key, value in metadata)
|
||||
|
||||
|
||||
def _create_grpc_interceptors(
|
||||
auth_metadata: list[tuple[str, str]] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Create gRPC interceptors for metadata normalization and auth injection.
|
||||
|
||||
Args:
|
||||
auth_metadata: Optional auth metadata to inject into all calls.
|
||||
Used for insecure channels that need auth (non-localhost without TLS).
|
||||
|
||||
Returns a list of interceptors that lowercase metadata keys for gRPC
|
||||
compatibility. Must be called after grpc is imported.
|
||||
"""
|
||||
import grpc.aio # type: ignore[import-untyped]
|
||||
|
||||
def _merge_metadata(
|
||||
existing: tuple[tuple[str, str], ...] | None,
|
||||
auth: list[tuple[str, str]] | None,
|
||||
) -> tuple[tuple[str, str], ...] | None:
|
||||
"""Merge existing metadata with auth metadata and normalize keys."""
|
||||
merged: list[tuple[str, str]] = []
|
||||
if existing:
|
||||
merged.extend(existing)
|
||||
if auth:
|
||||
merged.extend(auth)
|
||||
if not merged:
|
||||
return None
|
||||
return tuple((key.lower(), value) for key, value in merged)
|
||||
|
||||
def _inject_metadata(client_call_details: Any) -> Any:
|
||||
"""Inject merged metadata into call details."""
|
||||
return client_call_details._replace(
|
||||
metadata=_merge_metadata(client_call_details.metadata, auth_metadata)
|
||||
)
|
||||
|
||||
class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for unary-unary calls that injects auth metadata."""
|
||||
|
||||
async def intercept_unary_unary( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request
|
||||
):
|
||||
"""Intercept unary-unary call and inject metadata."""
|
||||
return await continuation(_inject_metadata(client_call_details), request)
|
||||
|
||||
class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for unary-stream calls that injects auth metadata."""
|
||||
|
||||
async def intercept_unary_stream( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request
|
||||
):
|
||||
"""Intercept unary-stream call and inject metadata."""
|
||||
return await continuation(_inject_metadata(client_call_details), request)
|
||||
|
||||
class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for stream-unary calls that injects auth metadata."""
|
||||
|
||||
async def intercept_stream_unary( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request_iterator
|
||||
):
|
||||
"""Intercept stream-unary call and inject metadata."""
|
||||
return await continuation(
|
||||
_inject_metadata(client_call_details), request_iterator
|
||||
)
|
||||
|
||||
class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
|
||||
"""Interceptor for stream-stream calls that injects auth metadata."""
|
||||
|
||||
async def intercept_stream_stream( # type: ignore[no-untyped-def]
|
||||
self, continuation, client_call_details, request_iterator
|
||||
):
|
||||
"""Intercept stream-stream call and inject metadata."""
|
||||
return await continuation(
|
||||
_inject_metadata(client_call_details), request_iterator
|
||||
)
|
||||
|
||||
return [
|
||||
MetadataUnaryUnary(),
|
||||
MetadataUnaryStream(),
|
||||
MetadataStreamUnary(),
|
||||
MetadataStreamStream(),
|
||||
]
|
||||
|
||||
|
||||
def _create_grpc_channel_factory(
|
||||
grpc_config: GRPCClientConfig,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
) -> Callable[[str], Any]:
|
||||
"""Create a gRPC channel factory with the given configuration.
|
||||
|
||||
Args:
|
||||
grpc_config: gRPC client configuration with channel options.
|
||||
auth: Optional ClientAuthScheme for TLS and auth configuration.
|
||||
|
||||
Returns:
|
||||
A callable that creates gRPC channels from URLs.
|
||||
"""
|
||||
try:
|
||||
import grpc
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
|
||||
) from e
|
||||
|
||||
auth_metadata: list[tuple[str, str]] = []
|
||||
|
||||
if auth is not None:
|
||||
from crewai.a2a.auth.client_schemes import (
|
||||
APIKeyAuth,
|
||||
BearerTokenAuth,
|
||||
HTTPBasicAuth,
|
||||
HTTPDigestAuth,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2ClientCredentials,
|
||||
)
|
||||
|
||||
if isinstance(auth, HTTPDigestAuth):
|
||||
raise ValueError(
|
||||
"HTTPDigestAuth is not supported with gRPC transport. "
|
||||
"Digest authentication requires HTTP challenge-response flow. "
|
||||
"Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
|
||||
)
|
||||
if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
|
||||
raise ValueError(
|
||||
f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
|
||||
"gRPC only supports header-based authentication. "
|
||||
"Use APIKeyAuth with location='header' instead."
|
||||
)
|
||||
|
||||
if isinstance(auth, BearerTokenAuth):
|
||||
auth_metadata.append(("authorization", f"Bearer {auth.token}"))
|
||||
elif isinstance(auth, HTTPBasicAuth):
|
||||
import base64
|
||||
|
||||
basic_credentials = f"{auth.username}:{auth.password}"
|
||||
encoded = base64.b64encode(basic_credentials.encode()).decode()
|
||||
auth_metadata.append(("authorization", f"Basic {encoded}"))
|
||||
elif isinstance(auth, APIKeyAuth) and auth.location == "header":
|
||||
header_name = auth.name.lower()
|
||||
auth_metadata.append((header_name, auth.api_key))
|
||||
elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
|
||||
if auth._access_token:
|
||||
auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))
|
||||
|
||||
def factory(url: str) -> Any:
|
||||
"""Create a gRPC channel for the given URL."""
|
||||
target = url
|
||||
use_tls = False
|
||||
|
||||
if url.startswith("grpcs://"):
|
||||
target = url[8:]
|
||||
use_tls = True
|
||||
elif url.startswith("grpc://"):
|
||||
target = url[7:]
|
||||
elif url.startswith("https://"):
|
||||
target = url[8:]
|
||||
use_tls = True
|
||||
elif url.startswith("http://"):
|
||||
target = url[7:]
|
||||
|
||||
options: list[tuple[str, Any]] = []
|
||||
if grpc_config.max_send_message_length is not None:
|
||||
options.append(
|
||||
("grpc.max_send_message_length", grpc_config.max_send_message_length)
|
||||
)
|
||||
if grpc_config.max_receive_message_length is not None:
|
||||
options.append(
|
||||
(
|
||||
"grpc.max_receive_message_length",
|
||||
grpc_config.max_receive_message_length,
|
||||
)
|
||||
)
|
||||
if grpc_config.keepalive_time_ms is not None:
|
||||
options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
|
||||
if grpc_config.keepalive_timeout_ms is not None:
|
||||
options.append(
|
||||
("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
|
||||
)
|
||||
|
||||
channel_credentials = None
|
||||
if auth and hasattr(auth, "tls") and auth.tls:
|
||||
channel_credentials = auth.tls.get_grpc_credentials()
|
||||
elif use_tls:
|
||||
channel_credentials = grpc.ssl_channel_credentials()
|
||||
|
||||
if channel_credentials and auth_metadata:
|
||||
|
||||
class AuthMetadataPlugin(grpc.AuthMetadataPlugin): # type: ignore[misc,no-any-unimported]
|
||||
"""gRPC auth metadata plugin that adds auth headers as metadata."""
|
||||
|
||||
def __init__(self, metadata: list[tuple[str, str]]) -> None:
|
||||
self._metadata = tuple(metadata)
|
||||
|
||||
def __call__( # type: ignore[no-any-unimported]
|
||||
self,
|
||||
context: grpc.AuthMetadataContext,
|
||||
callback: grpc.AuthMetadataPluginCallback,
|
||||
) -> None:
|
||||
callback(self._metadata, None)
|
||||
|
||||
call_creds = grpc.metadata_call_credentials(
|
||||
AuthMetadataPlugin(auth_metadata)
|
||||
)
|
||||
credentials = grpc.composite_channel_credentials(
|
||||
channel_credentials, call_creds
|
||||
)
|
||||
interceptors = _create_grpc_interceptors()
|
||||
return grpc.aio.secure_channel(
|
||||
target, credentials, options=options or None, interceptors=interceptors
|
||||
)
|
||||
if channel_credentials:
|
||||
interceptors = _create_grpc_interceptors()
|
||||
return grpc.aio.secure_channel(
|
||||
target,
|
||||
channel_credentials,
|
||||
options=options or None,
|
||||
interceptors=interceptors,
|
||||
)
|
||||
interceptors = _create_grpc_interceptors(
|
||||
auth_metadata=auth_metadata if auth_metadata else None
|
||||
)
|
||||
return grpc.aio.insecure_channel(
|
||||
target, options=options or None, interceptors=interceptors
|
||||
)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _create_a2a_client(
|
||||
agent_card: AgentCard,
|
||||
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
|
||||
timeout: int,
|
||||
headers: MutableMapping[str, str],
|
||||
streaming: bool,
|
||||
auth: ClientAuthScheme | None = None,
|
||||
use_polling: bool = False,
|
||||
push_notification_config: PushNotificationConfig | None = None,
|
||||
client_extensions: list[str] | None = None,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
grpc_config: GRPCClientConfig | None = None,
|
||||
) -> AsyncIterator[Client]:
|
||||
"""Create and configure an A2A client.
|
||||
|
||||
Args:
|
||||
agent_card: The A2A agent card.
|
||||
transport_protocol: Transport protocol to use.
|
||||
timeout: Request timeout in seconds.
|
||||
headers: HTTP headers (already with auth applied).
|
||||
streaming: Enable streaming responses.
|
||||
auth: Optional ClientAuthScheme for client configuration.
|
||||
use_polling: Enable polling mode.
|
||||
push_notification_config: Optional push notification config.
|
||||
client_extensions: A2A protocol extension URIs to declare support for.
|
||||
accepted_output_modes: MIME types the client can accept in responses.
|
||||
grpc_config: Optional gRPC client configuration.
|
||||
|
||||
Yields:
|
||||
Configured A2A client instance.
|
||||
"""
|
||||
verify = _get_tls_verify(auth)
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
verify=verify,
|
||||
) as httpx_client:
|
||||
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, httpx_client)
|
||||
|
||||
push_configs: list[A2APushNotificationConfig] = []
|
||||
if push_notification_config is not None:
|
||||
push_configs.append(
|
||||
A2APushNotificationConfig(
|
||||
url=str(push_notification_config.url),
|
||||
id=push_notification_config.id,
|
||||
token=push_notification_config.token,
|
||||
authentication=push_notification_config.authentication,
|
||||
)
|
||||
)
|
||||
|
||||
grpc_channel_factory = None
|
||||
if transport_protocol == "GRPC":
|
||||
grpc_channel_factory = _create_grpc_channel_factory(
|
||||
grpc_config or GRPCClientConfig(),
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
config = ClientConfig(
|
||||
httpx_client=httpx_client,
|
||||
supported_transports=[transport_protocol],
|
||||
streaming=streaming and not use_polling,
|
||||
polling=use_polling,
|
||||
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
|
||||
push_notification_configs=push_configs,
|
||||
grpc_channel_factory=grpc_channel_factory,
|
||||
)
|
||||
|
||||
factory = ClientFactory(config)
|
||||
client = factory.create(agent_card)
|
||||
|
||||
if client_extensions:
|
||||
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
|
||||
|
||||
yield client
|
||||
from crewai_a2a.utils.delegation import * # noqa: E402, F403
|
||||
|
||||
@@ -1,131 +1,13 @@
|
||||
"""Structured JSON logging utilities for A2A module."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.utils.logging`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
|
||||
_log_context: ContextVar[dict[str, Any] | None] = ContextVar(
|
||||
"log_context", default=None
|
||||
warnings.warn(
|
||||
"'crewai.a2a.utils.logging' has been moved to 'crewai_a2a.utils.logging'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""JSON formatter for structured logging.
|
||||
|
||||
Outputs logs as JSON with consistent fields for log aggregators.
|
||||
"""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Format log record as JSON string."""
|
||||
log_data: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
|
||||
if record.exc_info:
|
||||
log_data["exception"] = self.formatException(record.exc_info)
|
||||
|
||||
context = _log_context.get()
|
||||
if context is not None:
|
||||
log_data.update(context)
|
||||
|
||||
if hasattr(record, "task_id"):
|
||||
log_data["task_id"] = record.task_id
|
||||
if hasattr(record, "context_id"):
|
||||
log_data["context_id"] = record.context_id
|
||||
if hasattr(record, "agent"):
|
||||
log_data["agent"] = record.agent
|
||||
if hasattr(record, "endpoint"):
|
||||
log_data["endpoint"] = record.endpoint
|
||||
if hasattr(record, "extension"):
|
||||
log_data["extension"] = record.extension
|
||||
if hasattr(record, "error"):
|
||||
log_data["error"] = record.error
|
||||
|
||||
for key, value in record.__dict__.items():
|
||||
if key.startswith("_") or key in (
|
||||
"name",
|
||||
"msg",
|
||||
"args",
|
||||
"created",
|
||||
"filename",
|
||||
"funcName",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"lineno",
|
||||
"module",
|
||||
"msecs",
|
||||
"pathname",
|
||||
"process",
|
||||
"processName",
|
||||
"relativeCreated",
|
||||
"stack_info",
|
||||
"exc_info",
|
||||
"exc_text",
|
||||
"thread",
|
||||
"threadName",
|
||||
"taskName",
|
||||
"message",
|
||||
):
|
||||
continue
|
||||
if key not in log_data:
|
||||
log_data[key] = value
|
||||
|
||||
return json.dumps(log_data, default=str)
|
||||
|
||||
|
||||
class LogContext:
|
||||
"""Context manager for adding fields to all logs within a scope.
|
||||
|
||||
Example:
|
||||
with LogContext(task_id="abc", context_id="xyz"):
|
||||
logger.info("Processing task") # Includes task_id and context_id
|
||||
"""
|
||||
|
||||
def __init__(self, **fields: Any) -> None:
|
||||
self._fields = fields
|
||||
self._token: Any = None
|
||||
|
||||
def __enter__(self) -> LogContext:
|
||||
current = _log_context.get() or {}
|
||||
new_context = {**current, **self._fields}
|
||||
self._token = _log_context.set(new_context)
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
_log_context.reset(self._token)
|
||||
|
||||
|
||||
def configure_json_logging(logger_name: str = "crewai.a2a") -> None:
|
||||
"""Configure JSON logging for the A2A module.
|
||||
|
||||
Args:
|
||||
logger_name: Logger name to configure.
|
||||
"""
|
||||
logger = logging.getLogger(logger_name)
|
||||
|
||||
for handler in logger.handlers[:]:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(JSONFormatter())
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a logger configured for structured JSON output.
|
||||
|
||||
Args:
|
||||
name: Logger name.
|
||||
|
||||
Returns:
|
||||
Configured logger instance.
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
from crewai_a2a.utils.logging import * # noqa: E402, F403
|
||||
|
||||
@@ -1,101 +1,13 @@
|
||||
"""Response model utilities for A2A agent interactions."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.utils.response_model`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
from crewai.types.utils import create_literals_from_strings
|
||||
import warnings
|
||||
|
||||
|
||||
A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
|
||||
A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig
|
||||
warnings.warn(
|
||||
"'crewai.a2a.utils.response_model' has been moved to 'crewai_a2a.utils.response_model'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None:
|
||||
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
|
||||
|
||||
Args:
|
||||
agent_ids: List of available A2A agent IDs.
|
||||
|
||||
Returns:
|
||||
Dynamically created Pydantic model with Literal-constrained a2a_ids field,
|
||||
or None if agent_ids is empty.
|
||||
"""
|
||||
if not agent_ids:
|
||||
return None
|
||||
|
||||
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
|
||||
|
||||
return create_model(
|
||||
"AgentResponse",
|
||||
a2a_ids=(
|
||||
tuple[DynamicLiteral, ...], # type: ignore[valid-type]
|
||||
Field(
|
||||
default_factory=tuple,
|
||||
max_length=len(agent_ids),
|
||||
description="A2A agent IDs to delegate to.",
|
||||
),
|
||||
),
|
||||
message=(
|
||||
str,
|
||||
Field(
|
||||
description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
|
||||
),
|
||||
),
|
||||
is_a2a=(
|
||||
bool,
|
||||
Field(
|
||||
description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately."
|
||||
),
|
||||
),
|
||||
__base__=BaseModel,
|
||||
)
|
||||
|
||||
|
||||
def extract_a2a_agent_ids_from_config(
|
||||
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
|
||||
) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]:
|
||||
"""Extract A2A agent IDs from A2A configuration.
|
||||
|
||||
Filters out A2AServerConfig since it doesn't have an endpoint for delegation.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (any type).
|
||||
|
||||
Returns:
|
||||
Tuple of client A2A configs list and agent endpoint IDs.
|
||||
"""
|
||||
if a2a_config is None:
|
||||
return [], ()
|
||||
|
||||
configs: list[A2AConfigTypes]
|
||||
if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)):
|
||||
configs = [a2a_config]
|
||||
else:
|
||||
configs = a2a_config
|
||||
|
||||
# Filter to only client configs (those with endpoint)
|
||||
client_configs: list[A2AClientConfigTypes] = [
|
||||
config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
|
||||
]
|
||||
|
||||
return client_configs, tuple(config.endpoint for config in client_configs)
|
||||
|
||||
|
||||
def get_a2a_agents_and_response_model(
|
||||
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
|
||||
) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]:
|
||||
"""Get A2A agent configs and response model.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (any type).
|
||||
|
||||
Returns:
|
||||
Tuple of client A2A configs and response model.
|
||||
"""
|
||||
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
|
||||
|
||||
return a2a_agents, create_agent_response_model(agent_ids)
|
||||
from crewai_a2a.utils.response_model import * # noqa: E402, F403
|
||||
|
||||
@@ -1,584 +1,13 @@
|
||||
"""A2A task utilities for server-side task management."""
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.utils.task`` instead."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Callable, Coroutine
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from a2a.server.agent_execution import RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
from a2a.types import (
|
||||
Artifact,
|
||||
FileWithBytes,
|
||||
FileWithUri,
|
||||
InternalError,
|
||||
InvalidParamsError,
|
||||
Message,
|
||||
Part,
|
||||
Task as A2ATask,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from a2a.utils import (
|
||||
get_data_parts,
|
||||
get_file_parts,
|
||||
new_agent_text_message,
|
||||
new_data_artifact,
|
||||
new_text_artifact,
|
||||
)
|
||||
from a2a.utils.errors import ServerError
|
||||
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.a2a.utils.agent_card import _get_server_config
|
||||
from crewai.a2a.utils.content_type import validate_message_parts
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AServerTaskCanceledEvent,
|
||||
A2AServerTaskCompletedEvent,
|
||||
A2AServerTaskFailedEvent,
|
||||
A2AServerTaskStartedEvent,
|
||||
)
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
import warnings
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
|
||||
from crewai.agent import Agent
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedisCacheConfig(TypedDict, total=False):
|
||||
"""Configuration for aiocache Redis backend."""
|
||||
|
||||
cache: str
|
||||
endpoint: str
|
||||
port: int
|
||||
db: int
|
||||
password: str
|
||||
|
||||
|
||||
def _parse_redis_url(url: str) -> RedisCacheConfig:
|
||||
"""Parse a Redis URL into aiocache configuration.
|
||||
|
||||
Args:
|
||||
url: Redis connection URL (e.g., redis://localhost:6379/0).
|
||||
|
||||
Returns:
|
||||
Configuration dict for aiocache.RedisCache.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
config: RedisCacheConfig = {
|
||||
"cache": "aiocache.RedisCache",
|
||||
"endpoint": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 6379,
|
||||
}
|
||||
if parsed.path and parsed.path != "/":
|
||||
try:
|
||||
config["db"] = int(parsed.path.lstrip("/"))
|
||||
except ValueError:
|
||||
pass
|
||||
if parsed.password:
|
||||
config["password"] = parsed.password
|
||||
return config
|
||||
|
||||
|
||||
_redis_url = os.environ.get("REDIS_URL")
|
||||
|
||||
caches.set_config(
|
||||
{
|
||||
"default": _parse_redis_url(_redis_url)
|
||||
if _redis_url
|
||||
else {
|
||||
"cache": "aiocache.SimpleMemoryCache",
|
||||
}
|
||||
}
|
||||
warnings.warn(
|
||||
"'crewai.a2a.utils.task' has been moved to 'crewai_a2a.utils.task'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
def cancellable(
|
||||
fn: Callable[P, Coroutine[Any, Any, T]],
|
||||
) -> Callable[P, Coroutine[Any, Any, T]]:
|
||||
"""Decorator that enables cancellation for A2A task execution.
|
||||
|
||||
Runs a cancellation watcher concurrently with the wrapped function.
|
||||
When a cancel event is published, the execution is cancelled.
|
||||
|
||||
Args:
|
||||
fn: The async function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function with cancellation support.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
"""Wrap function with cancellation monitoring."""
|
||||
context: RequestContext | None = None
|
||||
for arg in args:
|
||||
if isinstance(arg, RequestContext):
|
||||
context = arg
|
||||
break
|
||||
if context is None:
|
||||
context = cast(RequestContext | None, kwargs.get("context"))
|
||||
|
||||
if context is None:
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
task_id = context.task_id
|
||||
cache = caches.get("default")
|
||||
|
||||
async def poll_for_cancel() -> bool:
|
||||
"""Poll cache for cancellation flag."""
|
||||
while True:
|
||||
if await cache.get(f"cancel:{task_id}"):
|
||||
return True
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def watch_for_cancel() -> bool:
|
||||
"""Watch for cancellation events via pub/sub or polling."""
|
||||
if isinstance(cache, SimpleMemoryCache):
|
||||
return await poll_for_cancel()
|
||||
|
||||
try:
|
||||
client = cache.client
|
||||
pubsub = client.pubsub()
|
||||
await pubsub.subscribe(f"cancel:{task_id}")
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
return True
|
||||
except (OSError, ConnectionError) as e:
|
||||
logger.warning(
|
||||
"Cancel watcher Redis error, falling back to polling",
|
||||
extra={"task_id": task_id, "error": str(e)},
|
||||
)
|
||||
return await poll_for_cancel()
|
||||
return False
|
||||
|
||||
execute_task = asyncio.create_task(fn(*args, **kwargs))
|
||||
cancel_watch = asyncio.create_task(watch_for_cancel())
|
||||
|
||||
try:
|
||||
done, _ = await asyncio.wait(
|
||||
[execute_task, cancel_watch],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
if cancel_watch in done:
|
||||
execute_task.cancel()
|
||||
try:
|
||||
await execute_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
raise asyncio.CancelledError(f"Task {task_id} was cancelled")
|
||||
cancel_watch.cancel()
|
||||
return execute_task.result()
|
||||
finally:
|
||||
await cache.delete(f"cancel:{task_id}")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _convert_a2a_files_to_file_inputs(
|
||||
a2a_files: list[FileWithBytes | FileWithUri],
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a2a file types to crewai FileInput dict.
|
||||
|
||||
Args:
|
||||
a2a_files: List of FileWithBytes or FileWithUri from a2a SDK.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping file names to FileInput objects.
|
||||
"""
|
||||
try:
|
||||
from crewai_files import File, FileBytes
|
||||
except ImportError:
|
||||
logger.debug("crewai_files not installed, returning empty file dict")
|
||||
return {}
|
||||
|
||||
file_dict: dict[str, Any] = {}
|
||||
for idx, a2a_file in enumerate(a2a_files):
|
||||
if isinstance(a2a_file, FileWithBytes):
|
||||
file_bytes = base64.b64decode(a2a_file.bytes)
|
||||
name = a2a_file.name or f"file_{idx}"
|
||||
file_source = FileBytes(data=file_bytes, filename=a2a_file.name)
|
||||
file_dict[name] = File(source=file_source)
|
||||
elif isinstance(a2a_file, FileWithUri):
|
||||
name = a2a_file.name or f"file_{idx}"
|
||||
file_dict[name] = File(source=a2a_file.uri)
|
||||
|
||||
return file_dict
|
||||
|
||||
|
||||
def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
|
||||
"""Extract response schema from message parts metadata.
|
||||
|
||||
The client may include a JSON schema in TextPart metadata to specify
|
||||
the expected response format (see delegation.py line 463).
|
||||
|
||||
Args:
|
||||
parts: List of message parts.
|
||||
|
||||
Returns:
|
||||
JSON schema dict if found, None otherwise.
|
||||
"""
|
||||
for part in parts:
|
||||
if part.root.kind == "text" and part.root.metadata:
|
||||
schema = part.root.metadata.get("schema")
|
||||
if schema and isinstance(schema, dict):
|
||||
return schema # type: ignore[no-any-return]
|
||||
return None
|
||||
|
||||
|
||||
def _create_result_artifact(
|
||||
result: Any,
|
||||
task_id: str,
|
||||
) -> Artifact:
|
||||
"""Create artifact from task result, using DataPart for structured data.
|
||||
|
||||
Args:
|
||||
result: The task execution result.
|
||||
task_id: The task ID for naming the artifact.
|
||||
|
||||
Returns:
|
||||
Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings).
|
||||
"""
|
||||
artifact_name = f"result_{task_id}"
|
||||
if isinstance(result, dict):
|
||||
return new_data_artifact(artifact_name, result)
|
||||
if isinstance(result, BaseModel):
|
||||
return new_data_artifact(artifact_name, result.model_dump())
|
||||
return new_text_artifact(artifact_name, str(result))
|
||||
|
||||
|
||||
def _build_task_description(
|
||||
user_message: str,
|
||||
structured_inputs: list[dict[str, Any]],
|
||||
) -> str:
|
||||
"""Build task description including structured data if present.
|
||||
|
||||
Args:
|
||||
user_message: The original user message text.
|
||||
structured_inputs: List of structured data from DataParts.
|
||||
|
||||
Returns:
|
||||
Task description with structured data appended if present.
|
||||
"""
|
||||
if not structured_inputs:
|
||||
return user_message
|
||||
|
||||
structured_json = json.dumps(structured_inputs, indent=2)
|
||||
return f"{user_message}\n\nStructured Data:\n{structured_json}"
|
||||
|
||||
|
||||
async def execute(
|
||||
agent: Agent,
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
) -> None:
|
||||
"""Execute an A2A task using a CrewAI agent.
|
||||
|
||||
Args:
|
||||
agent: The CrewAI agent to execute the task.
|
||||
context: The A2A request context containing the user's message.
|
||||
event_queue: The event queue for sending responses back.
|
||||
"""
|
||||
await _execute_impl(agent, context, event_queue, None, None)
|
||||
|
||||
|
||||
@cancellable
|
||||
async def _execute_impl(
|
||||
agent: Agent,
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
extension_registry: ServerExtensionRegistry | None,
|
||||
extension_context: ExtensionContext | None,
|
||||
) -> None:
|
||||
"""Internal implementation for task execution with optional extensions."""
|
||||
server_config = _get_server_config(agent)
|
||||
if context.message and context.message.parts and server_config:
|
||||
allowed_modes = server_config.default_input_modes
|
||||
invalid_types = validate_message_parts(context.message.parts, allowed_modes)
|
||||
if invalid_types:
|
||||
raise ServerError(
|
||||
InvalidParamsError(
|
||||
message=f"Unsupported content type(s): {', '.join(invalid_types)}. "
|
||||
f"Supported: {', '.join(allowed_modes)}"
|
||||
)
|
||||
)
|
||||
|
||||
if extension_registry and extension_context:
|
||||
await extension_registry.invoke_on_request(extension_context)
|
||||
|
||||
user_message = context.get_user_input()
|
||||
|
||||
response_model: type[BaseModel] | None = None
|
||||
structured_inputs: list[dict[str, Any]] = []
|
||||
a2a_files: list[FileWithBytes | FileWithUri] = []
|
||||
|
||||
if context.message and context.message.parts:
|
||||
schema = _extract_response_schema(context.message.parts)
|
||||
if schema:
|
||||
try:
|
||||
response_model = create_model_from_schema(schema)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to create response model from schema",
|
||||
extra={"error": str(e), "schema_title": schema.get("title")},
|
||||
)
|
||||
|
||||
structured_inputs = get_data_parts(context.message.parts)
|
||||
a2a_files = get_file_parts(context.message.parts)
|
||||
|
||||
task_id = context.task_id
|
||||
context_id = context.context_id
|
||||
if task_id is None or context_id is None:
|
||||
msg = "task_id and context_id are required"
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskFailedEvent(
|
||||
task_id="",
|
||||
context_id="",
|
||||
error=msg,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise ServerError(InvalidParamsError(message=msg)) from None
|
||||
|
||||
task = Task(
|
||||
description=_build_task_description(user_message, structured_inputs),
|
||||
expected_output="Response to the user's request",
|
||||
agent=agent,
|
||||
response_model=response_model,
|
||||
input_files=_convert_a2a_files_to_file_inputs(a2a_files),
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskStartedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
result = await agent.aexecute_task(task=task, tools=agent.tools)
|
||||
if extension_registry and extension_context:
|
||||
result = await extension_registry.invoke_on_response(
|
||||
extension_context, result
|
||||
)
|
||||
result_str = str(result)
|
||||
history: list[Message] = [context.message] if context.message else []
|
||||
history.append(new_agent_text_message(result_str, context_id, task_id))
|
||||
await event_queue.enqueue_event(
|
||||
A2ATask(
|
||||
id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
artifacts=[_create_result_artifact(result, task_id)],
|
||||
history=history,
|
||||
)
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskCompletedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
result=str(result),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskCanceledEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
A2AServerTaskFailedEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise ServerError(
|
||||
error=InternalError(message=f"Task execution failed: {e}")
|
||||
) from e
|
||||
|
||||
|
||||
async def execute_with_extensions(
|
||||
agent: Agent,
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
extension_registry: ServerExtensionRegistry,
|
||||
extension_context: ExtensionContext,
|
||||
) -> None:
|
||||
"""Execute an A2A task with extension hooks.
|
||||
|
||||
Args:
|
||||
agent: The CrewAI agent to execute the task.
|
||||
context: The A2A request context containing the user's message.
|
||||
event_queue: The event queue for sending responses back.
|
||||
extension_registry: Registry of server extensions.
|
||||
extension_context: Context for extension hooks.
|
||||
"""
|
||||
await _execute_impl(
|
||||
agent, context, event_queue, extension_registry, extension_context
|
||||
)
|
||||
|
||||
|
||||
async def cancel(
|
||||
context: RequestContext,
|
||||
event_queue: EventQueue,
|
||||
) -> A2ATask | None:
|
||||
"""Cancel an A2A task.
|
||||
|
||||
Publishes a cancel event that the cancellable decorator listens for.
|
||||
|
||||
Args:
|
||||
context: The A2A request context containing task information.
|
||||
event_queue: The event queue for sending the cancellation status.
|
||||
|
||||
Returns:
|
||||
The canceled task with updated status.
|
||||
"""
|
||||
task_id = context.task_id
|
||||
context_id = context.context_id
|
||||
if task_id is None or context_id is None:
|
||||
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
|
||||
|
||||
if context.current_task and context.current_task.status.state in (
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.canceled,
|
||||
):
|
||||
return context.current_task
|
||||
|
||||
cache = caches.get("default")
|
||||
|
||||
await cache.set(f"cancel:{task_id}", True, ttl=3600)
|
||||
if not isinstance(cache, SimpleMemoryCache):
|
||||
await cache.client.publish(f"cancel:{task_id}", "cancel")
|
||||
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.canceled),
|
||||
final=True,
|
||||
)
|
||||
)
|
||||
|
||||
if context.current_task:
|
||||
context.current_task.status = TaskStatus(state=TaskState.canceled)
|
||||
return context.current_task
|
||||
return None
|
||||
|
||||
|
||||
def list_tasks(
|
||||
tasks: list[A2ATask],
|
||||
context_id: str | None = None,
|
||||
status: TaskState | None = None,
|
||||
status_timestamp_after: datetime | None = None,
|
||||
page_size: int = 50,
|
||||
page_token: str | None = None,
|
||||
history_length: int | None = None,
|
||||
include_artifacts: bool = False,
|
||||
) -> tuple[list[A2ATask], str | None, int]:
|
||||
"""Filter and paginate A2A tasks.
|
||||
|
||||
Provides filtering by context, status, and timestamp, along with
|
||||
cursor-based pagination. This is a pure utility function that operates
|
||||
on an in-memory list of tasks - storage retrieval is handled separately.
|
||||
|
||||
Args:
|
||||
tasks: All tasks to filter.
|
||||
context_id: Filter by context ID to get tasks in a conversation.
|
||||
status: Filter by task state (e.g., completed, working).
|
||||
status_timestamp_after: Filter to tasks updated after this time.
|
||||
page_size: Maximum tasks per page (default 50).
|
||||
page_token: Base64-encoded cursor from previous response.
|
||||
history_length: Limit history messages per task (None = full history).
|
||||
include_artifacts: Whether to include task artifacts (default False).
|
||||
|
||||
Returns:
|
||||
Tuple of (filtered_tasks, next_page_token, total_count).
|
||||
- filtered_tasks: Tasks matching filters, paginated and trimmed.
|
||||
- next_page_token: Token for next page, or None if no more pages.
|
||||
- total_count: Total number of tasks matching filters (before pagination).
|
||||
"""
|
||||
filtered: list[A2ATask] = []
|
||||
for task in tasks:
|
||||
if context_id and task.context_id != context_id:
|
||||
continue
|
||||
if status and task.status.state != status:
|
||||
continue
|
||||
if status_timestamp_after and task.status.timestamp:
|
||||
ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00"))
|
||||
if ts <= status_timestamp_after:
|
||||
continue
|
||||
filtered.append(task)
|
||||
|
||||
def get_timestamp(t: A2ATask) -> datetime:
|
||||
"""Extract timestamp from task status for sorting."""
|
||||
if t.status.timestamp is None:
|
||||
return datetime.min
|
||||
return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00"))
|
||||
|
||||
filtered.sort(key=get_timestamp, reverse=True)
|
||||
total = len(filtered)
|
||||
|
||||
start = 0
|
||||
if page_token:
|
||||
try:
|
||||
cursor_id = base64.b64decode(page_token).decode()
|
||||
for idx, task in enumerate(filtered):
|
||||
if task.id == cursor_id:
|
||||
start = idx + 1
|
||||
break
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
pass
|
||||
|
||||
page = filtered[start : start + page_size]
|
||||
|
||||
result: list[A2ATask] = []
|
||||
for task in page:
|
||||
task = task.model_copy(deep=True)
|
||||
if history_length is not None and task.history:
|
||||
task.history = task.history[-history_length:]
|
||||
if not include_artifacts:
|
||||
task.artifacts = None
|
||||
result.append(task)
|
||||
|
||||
next_token: str | None = None
|
||||
if result and len(result) == page_size:
|
||||
next_token = base64.b64encode(result[-1].id.encode()).decode()
|
||||
|
||||
return result, next_token, total
|
||||
from crewai_a2a.utils.task import * # noqa: E402, F403
|
||||
|
||||
@@ -1,215 +1,13 @@
|
||||
"""Transport negotiation utilities for A2A protocol.
|
||||
"""Backward-compatibility shim — use ``crewai_a2a.utils.transport`` instead."""
|
||||
|
||||
This module provides functionality for negotiating the transport protocol
|
||||
between an A2A client and server based on their respective capabilities
|
||||
and preferences.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final, Literal
|
||||
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
|
||||
import warnings
|
||||
|
||||
|
||||
TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
|
||||
NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"]
|
||||
warnings.warn(
|
||||
"'crewai.a2a.utils.transport' has been moved to 'crewai_a2a.utils.transport'. "
|
||||
"Please update your imports. The old path will be removed in v2.0.0.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC"
|
||||
GRPC_TRANSPORT: Literal["GRPC"] = "GRPC"
|
||||
HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON"
|
||||
|
||||
DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [
|
||||
JSONRPC_TRANSPORT,
|
||||
GRPC_TRANSPORT,
|
||||
HTTP_JSON_TRANSPORT,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class NegotiatedTransport:
|
||||
"""Result of transport negotiation.
|
||||
|
||||
Attributes:
|
||||
transport: The negotiated transport protocol.
|
||||
url: The URL to use for this transport.
|
||||
source: How the transport was selected ('preferred', 'additional', 'fallback').
|
||||
"""
|
||||
|
||||
transport: str
|
||||
url: str
|
||||
source: NegotiationSource
|
||||
|
||||
|
||||
class TransportNegotiationError(Exception):
|
||||
"""Raised when no compatible transport can be negotiated."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_transports: list[str],
|
||||
server_transports: list[str],
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the error with negotiation details.
|
||||
|
||||
Args:
|
||||
client_transports: Transports supported by the client.
|
||||
server_transports: Transports supported by the server.
|
||||
message: Optional custom error message.
|
||||
"""
|
||||
self.client_transports = client_transports
|
||||
self.server_transports = server_transports
|
||||
if message is None:
|
||||
message = (
|
||||
f"No compatible transport found. "
|
||||
f"Client supports: {client_transports}. "
|
||||
f"Server supports: {server_transports}."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
|
||||
"""Extract all available interfaces from an AgentCard.
|
||||
|
||||
Creates a unified list of interfaces including the primary URL and
|
||||
any additional interfaces declared by the agent.
|
||||
|
||||
Args:
|
||||
agent_card: The agent's card containing transport information.
|
||||
|
||||
Returns:
|
||||
List of AgentInterface objects representing all available endpoints.
|
||||
"""
|
||||
interfaces: list[AgentInterface] = []
|
||||
|
||||
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
|
||||
interfaces.append(
|
||||
AgentInterface(
|
||||
transport=primary_transport,
|
||||
url=agent_card.url,
|
||||
)
|
||||
)
|
||||
|
||||
if agent_card.additional_interfaces:
|
||||
for interface in agent_card.additional_interfaces:
|
||||
is_duplicate = any(
|
||||
i.url == interface.url and i.transport == interface.transport
|
||||
for i in interfaces
|
||||
)
|
||||
if not is_duplicate:
|
||||
interfaces.append(interface)
|
||||
|
||||
return interfaces
|
||||
|
||||
|
||||
def negotiate_transport(
|
||||
agent_card: AgentCard,
|
||||
client_supported_transports: list[str] | None = None,
|
||||
client_preferred_transport: str | None = None,
|
||||
emit_event: bool = True,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
) -> NegotiatedTransport:
|
||||
"""Negotiate the transport protocol between client and server.
|
||||
|
||||
Compares the client's supported transports with the server's available
|
||||
interfaces to find a compatible transport and URL.
|
||||
|
||||
Negotiation logic:
|
||||
1. If client_preferred_transport is set and server supports it → use it
|
||||
2. Otherwise, if server's preferred is in client's supported → use server's
|
||||
3. Otherwise, find first match from client's supported in server's interfaces
|
||||
|
||||
Args:
|
||||
agent_card: The server's AgentCard with transport information.
|
||||
client_supported_transports: Transports the client can use.
|
||||
Defaults to ["JSONRPC"] if not specified.
|
||||
client_preferred_transport: Client's preferred transport. If set and
|
||||
server supports it, takes priority over server preference.
|
||||
emit_event: Whether to emit a transport negotiation event.
|
||||
endpoint: Original endpoint URL for event metadata.
|
||||
a2a_agent_name: Agent name for event metadata.
|
||||
|
||||
Returns:
|
||||
NegotiatedTransport with the selected transport, URL, and source.
|
||||
|
||||
Raises:
|
||||
TransportNegotiationError: If no compatible transport is found.
|
||||
"""
|
||||
if client_supported_transports is None:
|
||||
client_supported_transports = [JSONRPC_TRANSPORT]
|
||||
|
||||
client_transports = [t.upper() for t in client_supported_transports]
|
||||
client_preferred = (
|
||||
client_preferred_transport.upper() if client_preferred_transport else None
|
||||
)
|
||||
|
||||
server_interfaces = _get_server_interfaces(agent_card)
|
||||
server_transports = [i.transport.upper() for i in server_interfaces]
|
||||
|
||||
transport_to_interface: dict[str, AgentInterface] = {}
|
||||
for interface in server_interfaces:
|
||||
transport_upper = interface.transport.upper()
|
||||
if transport_upper not in transport_to_interface:
|
||||
transport_to_interface[transport_upper] = interface
|
||||
|
||||
result: NegotiatedTransport | None = None
|
||||
|
||||
if client_preferred and client_preferred in transport_to_interface:
|
||||
interface = transport_to_interface[client_preferred]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
url=interface.url,
|
||||
source="client_preferred",
|
||||
)
|
||||
else:
|
||||
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
|
||||
if (
|
||||
server_preferred in client_transports
|
||||
and server_preferred in transport_to_interface
|
||||
):
|
||||
interface = transport_to_interface[server_preferred]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
url=interface.url,
|
||||
source="server_preferred",
|
||||
)
|
||||
else:
|
||||
for transport in client_transports:
|
||||
if transport in transport_to_interface:
|
||||
interface = transport_to_interface[transport]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
url=interface.url,
|
||||
source="fallback",
|
||||
)
|
||||
break
|
||||
|
||||
if result is None:
|
||||
raise TransportNegotiationError(
|
||||
client_transports=client_transports,
|
||||
server_transports=server_transports,
|
||||
)
|
||||
|
||||
if emit_event:
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2ATransportNegotiatedEvent(
|
||||
endpoint=endpoint or agent_card.url,
|
||||
a2a_agent_name=a2a_agent_name or agent_card.name,
|
||||
negotiated_transport=result.transport,
|
||||
negotiated_url=result.url,
|
||||
source=result.source,
|
||||
client_supported_transports=client_transports,
|
||||
server_supported_transports=server_transports,
|
||||
server_preferred_transport=agent_card.preferred_transport
|
||||
or JSONRPC_TRANSPORT,
|
||||
client_preferred_transport=client_preferred,
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
from crewai_a2a.utils.transport import * # noqa: E402, F403
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -84,16 +84,16 @@ from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
|
||||
try:
|
||||
from crewai.a2a.types import AgentResponseProtocol
|
||||
from crewai_a2a.types import AgentResponseProtocol
|
||||
except ImportError:
|
||||
AgentResponseProtocol = None # type: ignore[assignment, misc]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
from crewai_files import FileInput
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
from crewai.agents.agent_builder.base_agent import PlatformAppOrAction
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
@@ -1740,7 +1740,7 @@ class Agent(BaseAgent):
|
||||
|
||||
# Rebuild Agent model to resolve A2A type forward references
|
||||
try:
|
||||
from crewai.a2a.config import (
|
||||
from crewai_a2a.config import (
|
||||
A2AClientConfig as _A2AClientConfig,
|
||||
A2AConfig as _A2AConfig,
|
||||
A2AServerConfig as _A2AServerConfig,
|
||||
|
||||
@@ -58,10 +58,10 @@ class AgentMeta(ModelMetaclass):
|
||||
|
||||
a2a_value = getattr(self, "a2a", None)
|
||||
if a2a_value is not None:
|
||||
from crewai.a2a.extensions.registry import (
|
||||
from crewai_a2a.extensions.registry import (
|
||||
create_extension_registry_from_config,
|
||||
)
|
||||
from crewai.a2a.wrapper import wrap_agent_with_a2a_instance
|
||||
from crewai_a2a.wrapper import wrap_agent_with_a2a_instance
|
||||
|
||||
extension_registry = create_extension_registry_from_config(
|
||||
a2a_value
|
||||
|
||||
@@ -31,10 +31,9 @@ from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
from crewai_files import FileInput
|
||||
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
@@ -120,8 +119,9 @@ def _kickoff_with_a2a_support(
|
||||
Returns:
|
||||
LiteAgentOutput from either local execution or A2A delegation.
|
||||
"""
|
||||
from crewai.a2a.utils.response_model import get_a2a_agents_and_response_model
|
||||
from crewai.a2a.wrapper import _execute_task_with_a2a
|
||||
from crewai_a2a.utils.response_model import get_a2a_agents_and_response_model
|
||||
from crewai_a2a.wrapper import _execute_task_with_a2a
|
||||
|
||||
from crewai.task import Task
|
||||
|
||||
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(agent.a2a)
|
||||
@@ -319,11 +319,11 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
def setup_a2a_support(self) -> Self:
|
||||
"""Setup A2A extensions and server methods if a2a config exists."""
|
||||
if self.a2a:
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig
|
||||
from crewai.a2a.extensions.registry import (
|
||||
from crewai_a2a.config import A2AClientConfig, A2AConfig
|
||||
from crewai_a2a.extensions.registry import (
|
||||
create_extension_registry_from_config,
|
||||
)
|
||||
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
|
||||
from crewai_a2a.utils.agent_card import inject_a2a_server_methods
|
||||
|
||||
configs = self.a2a if isinstance(self.a2a, list) else [self.a2a]
|
||||
client_configs = [
|
||||
@@ -995,7 +995,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
|
||||
try:
|
||||
from crewai.a2a.config import (
|
||||
from crewai_a2a.config import (
|
||||
A2AClientConfig as _A2AClientConfig,
|
||||
A2AConfig as _A2AConfig,
|
||||
A2AServerConfig as _A2AServerConfig,
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai_a2a.config import A2AConfig
|
||||
|
||||
try:
|
||||
from a2a.types import Message, Role
|
||||
@@ -27,8 +27,8 @@ def _create_mock_agent_card(name: str = "Test", url: str = "http://test-endpoint
|
||||
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
|
||||
def test_trust_remote_completion_status_true_returns_directly():
|
||||
"""When trust_remote_completion_status=True and A2A returns completed, return result directly."""
|
||||
from crewai.a2a.wrapper import _delegate_to_a2a
|
||||
from crewai.a2a.types import AgentResponseProtocol
|
||||
from crewai_a2a.wrapper import _delegate_to_a2a
|
||||
from crewai_a2a.types import AgentResponseProtocol
|
||||
from crewai import Agent, Task
|
||||
|
||||
a2a_config = A2AConfig(
|
||||
@@ -51,8 +51,8 @@ def test_trust_remote_completion_status_true_returns_directly():
|
||||
a2a_ids = ["http://test-endpoint.com/"]
|
||||
|
||||
with (
|
||||
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||
patch("crewai_a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||
patch("crewai_a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||
):
|
||||
mock_card = _create_mock_agent_card()
|
||||
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
||||
@@ -83,7 +83,7 @@ def test_trust_remote_completion_status_true_returns_directly():
|
||||
@pytest.mark.skipif(not A2A_SDK_INSTALLED, reason="Requires a2a-sdk to be installed")
|
||||
def test_trust_remote_completion_status_false_continues_conversation():
|
||||
"""When trust_remote_completion_status=False and A2A returns completed, ask server agent."""
|
||||
from crewai.a2a.wrapper import _delegate_to_a2a
|
||||
from crewai_a2a.wrapper import _delegate_to_a2a
|
||||
from crewai import Agent, Task
|
||||
|
||||
a2a_config = A2AConfig(
|
||||
@@ -116,8 +116,8 @@ def test_trust_remote_completion_status_false_continues_conversation():
|
||||
return "unexpected"
|
||||
|
||||
with (
|
||||
patch("crewai.a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||
patch("crewai.a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||
patch("crewai_a2a.wrapper.execute_a2a_delegation") as mock_execute,
|
||||
patch("crewai_a2a.wrapper._fetch_agent_cards_concurrently") as mock_fetch,
|
||||
):
|
||||
mock_card = _create_mock_agent_card()
|
||||
mock_fetch.return_value = ({"http://test-endpoint.com/": mock_card}, {})
|
||||
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
import pytest
|
||||
|
||||
from crewai import Agent
|
||||
from crewai.a2a.config import A2AClientConfig
|
||||
from crewai_a2a.config import A2AClientConfig
|
||||
|
||||
|
||||
A2A_TEST_ENDPOINT = os.getenv(
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from crewai import Agent
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai_a2a.config import A2AConfig
|
||||
|
||||
try:
|
||||
import a2a # noqa: F401
|
||||
|
||||
@@ -106,6 +106,7 @@ ignore-decorators = ["typing.overload"]
|
||||
"lib/crewai/tests/**/*.py" = ["S101", "RET504", "S105", "S106"] # Allow assert statements, unnecessary assignments, and hardcoded passwords in tests
|
||||
"lib/crewai-tools/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "RUF012", "N818", "E402", "RUF043", "S110", "B017"] # Allow various test-specific patterns
|
||||
"lib/crewai-files/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "B017", "F841"] # Allow assert statements and blind exception assertions in tests
|
||||
"lib/crewai-a2a/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "B017", "F821"] # Allow assert statements, unnecessary assignments, hardcoded passwords, blind exceptions, and forward refs in tests
|
||||
|
||||
|
||||
[tool.mypy]
|
||||
@@ -118,7 +119,7 @@ warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = true
|
||||
python_version = "3.12"
|
||||
exclude = "(?x)(^lib/crewai/src/crewai/cli/templates/|^lib/crewai/tests/|^lib/crewai-tools/tests/|^lib/crewai-files/tests/)"
|
||||
exclude = "(?x)(^lib/crewai/src/crewai/cli/templates/|^lib/crewai/tests/|^lib/crewai-tools/tests/|^lib/crewai-files/tests/|^lib/crewai-a2a/tests/)"
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
|
||||
@@ -134,6 +135,7 @@ testpaths = [
|
||||
"lib/crewai/tests",
|
||||
"lib/crewai-tools/tests",
|
||||
"lib/crewai-files/tests",
|
||||
"lib/crewai-a2a/tests",
|
||||
]
|
||||
asyncio_mode = "strict"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
@@ -157,6 +159,7 @@ members = [
|
||||
"lib/crewai-tools",
|
||||
"lib/devtools",
|
||||
"lib/crewai-files",
|
||||
"lib/crewai-a2a",
|
||||
]
|
||||
|
||||
|
||||
@@ -165,3 +168,4 @@ crewai = { workspace = true }
|
||||
crewai-tools = { workspace = true }
|
||||
crewai-devtools = { workspace = true }
|
||||
crewai-files = { workspace = true }
|
||||
crewai-a2a = { workspace = true }
|
||||
|
||||
31
uv.lock
generated
31
uv.lock
generated
@@ -15,6 +15,7 @@ resolution-markers = [
|
||||
[manifest]
|
||||
members = [
|
||||
"crewai",
|
||||
"crewai-a2a",
|
||||
"crewai-devtools",
|
||||
"crewai-files",
|
||||
"crewai-tools",
|
||||
@@ -1124,10 +1125,7 @@ dependencies = [
|
||||
|
||||
[package.optional-dependencies]
|
||||
a2a = [
|
||||
{ name = "a2a-sdk" },
|
||||
{ name = "aiocache", extra = ["memcached", "redis"] },
|
||||
{ name = "httpx-auth" },
|
||||
{ name = "httpx-sse" },
|
||||
{ name = "crewai-a2a" },
|
||||
]
|
||||
anthropic = [
|
||||
{ name = "anthropic" },
|
||||
@@ -1181,9 +1179,7 @@ watson = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "a2a-sdk", marker = "extra == 'a2a'", specifier = "~=0.3.10" },
|
||||
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
|
||||
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
|
||||
{ name = "aiosqlite", specifier = "~=0.21.0" },
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
|
||||
{ name = "appdirs", specifier = "~=1.4.4" },
|
||||
@@ -1192,13 +1188,12 @@ requires-dist = [
|
||||
{ name = "boto3", marker = "extra == 'bedrock'", specifier = "~=1.40.45" },
|
||||
{ name = "chromadb", specifier = "~=1.1.0" },
|
||||
{ name = "click", specifier = "~=8.1.7" },
|
||||
{ name = "crewai-a2a", marker = "extra == 'a2a'", editable = "lib/crewai-a2a" },
|
||||
{ name = "crewai-files", marker = "extra == 'file-processing'", editable = "lib/crewai-files" },
|
||||
{ name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" },
|
||||
{ name = "docling", marker = "extra == 'docling'", specifier = "~=2.63.0" },
|
||||
{ name = "google-genai", marker = "extra == 'google-genai'", specifier = "~=1.49.0" },
|
||||
{ name = "httpx", specifier = "~=0.28.1" },
|
||||
{ name = "httpx-auth", marker = "extra == 'a2a'", specifier = "~=0.23.1" },
|
||||
{ name = "httpx-sse", marker = "extra == 'a2a'", specifier = "~=0.4.0" },
|
||||
{ name = "ibm-watsonx-ai", marker = "extra == 'watson'", specifier = "~=1.3.39" },
|
||||
{ name = "instructor", specifier = ">=1.3.3" },
|
||||
{ name = "json-repair", specifier = "~=0.25.2" },
|
||||
@@ -1233,6 +1228,26 @@ requires-dist = [
|
||||
]
|
||||
provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "tools", "voyageai", "watson"]
|
||||
|
||||
[[package]]
|
||||
name = "crewai-a2a"
|
||||
source = { editable = "lib/crewai-a2a" }
|
||||
dependencies = [
|
||||
{ name = "a2a-sdk" },
|
||||
{ name = "aiocache", extra = ["memcached", "redis"] },
|
||||
{ name = "crewai" },
|
||||
{ name = "httpx-auth" },
|
||||
{ name = "httpx-sse" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "a2a-sdk", specifier = "~=0.3.10" },
|
||||
{ name = "aiocache", extras = ["memcached", "redis"], specifier = "~=0.12.3" },
|
||||
{ name = "crewai", editable = "lib/crewai" },
|
||||
{ name = "httpx-auth", specifier = "~=0.23.1" },
|
||||
{ name = "httpx-sse", specifier = "~=0.4.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crewai-devtools"
|
||||
source = { editable = "lib/devtools" }
|
||||
|
||||
Reference in New Issue
Block a user