Compare commits

..

4 Commits

Author SHA1 Message Date
Lorenze Jay
8f022be106 feat: bump versions to 1.8.1 (#4242)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
* feat: bump versions to 1.8.1

* bump bump
2026-01-14 20:49:14 -08:00
Greyson LaLonde
6a19b0a279 feat: a2a task execution utilities 2026-01-14 22:56:17 -05:00
Greyson LaLonde
641c336b2c chore: a2a agent card docs, refine existing a2a docs 2026-01-14 22:46:53 -05:00
Greyson LaLonde
22f1812824 feat: add a2a server config; agent card generation 2026-01-14 22:09:11 -05:00
26 changed files with 2055 additions and 929 deletions

View File

@@ -1,43 +1,48 @@
---
title: Agent-to-Agent (A2A) Protocol
description: Enable CrewAI agents to delegate tasks to remote A2A-compliant agents for specialized handling
description: Agents delegate tasks to remote A2A agents and/or operate as A2A-compliant server agents.
icon: network-wired
mode: "wide"
---
## A2A Agent Delegation
CrewAI supports the Agent-to-Agent (A2A) protocol, allowing agents to delegate tasks to remote specialized agents. The agent's LLM automatically decides whether to handle a task directly or delegate to an A2A agent based on the task requirements.
<Note>
A2A delegation requires the `a2a-sdk` package. Install with: `uv add 'crewai[a2a]'` or `pip install 'crewai[a2a]'`
</Note>
CrewAI treats [A2A protocol](https://a2a-protocol.org/latest/) as a first-class delegation primitive, enabling agents to delegate tasks, request information, and collaborate with remote agents, as well as act as A2A-compliant server agents.
In client mode, agents autonomously choose between local execution and remote delegation based on task requirements.
## How It Works
When an agent is configured with A2A capabilities:
1. The LLM analyzes each task
1. The Agent analyzes each task
2. It decides to either:
- Handle the task directly using its own capabilities
- Delegate to a remote A2A agent for specialized handling
3. If delegating, the agent communicates with the remote A2A agent through the protocol
4. Results are returned to the CrewAI workflow
<Note>
A2A delegation requires the `a2a-sdk` package. Install with: `uv add 'crewai[a2a]'` or `pip install 'crewai[a2a]'`
</Note>
## Basic Configuration
<Warning>
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0. Use `A2AClientConfig` for connecting to remote agents and/or `A2AServerConfig` for exposing agents as servers.
</Warning>
Configure an agent for A2A delegation by setting the `a2a` parameter:
```python Code
from crewai import Agent, Crew, Task
from crewai.a2a import A2AConfig
from crewai.a2a import A2AClientConfig
agent = Agent(
role="Research Coordinator",
goal="Coordinate research tasks efficiently",
backstory="Expert at delegating to specialized research agents",
llm="gpt-4o",
a2a=A2AConfig(
a2a=A2AClientConfig(
endpoint="https://example.com/.well-known/agent-card.json",
timeout=120,
max_turns=10
@@ -54,9 +59,9 @@ crew = Crew(agents=[agent], tasks=[task], verbose=True)
result = crew.kickoff()
```
## Configuration Options
## Client Configuration Options
The `A2AConfig` class accepts the following parameters:
The `A2AClientConfig` class accepts the following parameters:
<ParamField path="endpoint" type="str" required>
The A2A agent endpoint URL (typically points to `.well-known/agent-card.json`)
@@ -95,14 +100,30 @@ The `A2AConfig` class accepts the following parameters:
Transport protocol for A2A communication. Options: `JSONRPC` (default), `GRPC`, or `HTTP+JSON`.
</ParamField>
<ParamField path="accepted_output_modes" type="list[str]" default='["application/json"]'>
Media types the client can accept in responses.
</ParamField>
<ParamField path="supported_transports" type="list[str]" default='["JSONRPC"]'>
Ordered list of transport protocols the client supports.
</ParamField>
<ParamField path="use_client_preference" type="bool" default="False">
Whether to prioritize client transport preferences over server.
</ParamField>
<ParamField path="extensions" type="list[str]" default="[]">
Extension URIs the client supports.
</ParamField>
## Authentication
For A2A agents that require authentication, use one of the provided auth schemes:
<Tabs>
<Tab title="Bearer Token">
```python Code
from crewai.a2a import A2AConfig
```python bearer_token_auth.py lines
from crewai.a2a import A2AClientConfig
from crewai.a2a.auth import BearerTokenAuth
agent = Agent(
@@ -110,18 +131,18 @@ agent = Agent(
goal="Coordinate tasks with secured agents",
backstory="Manages secure agent communications",
llm="gpt-4o",
a2a=A2AConfig(
a2a=A2AClientConfig(
endpoint="https://secure-agent.example.com/.well-known/agent-card.json",
auth=BearerTokenAuth(token="your-bearer-token"),
timeout=120
)
)
```
```
</Tab>
<Tab title="API Key">
```python Code
from crewai.a2a import A2AConfig
```python api_key_auth.py lines
from crewai.a2a import A2AClientConfig
from crewai.a2a.auth import APIKeyAuth
agent = Agent(
@@ -129,7 +150,7 @@ agent = Agent(
goal="Coordinate with API-based agents",
backstory="Manages API-authenticated communications",
llm="gpt-4o",
a2a=A2AConfig(
a2a=A2AClientConfig(
endpoint="https://api-agent.example.com/.well-known/agent-card.json",
auth=APIKeyAuth(
api_key="your-api-key",
@@ -139,12 +160,12 @@ agent = Agent(
timeout=120
)
)
```
```
</Tab>
<Tab title="OAuth2">
```python Code
from crewai.a2a import A2AConfig
```python oauth2_auth.py lines
from crewai.a2a import A2AClientConfig
from crewai.a2a.auth import OAuth2ClientCredentials
agent = Agent(
@@ -152,7 +173,7 @@ agent = Agent(
goal="Coordinate with OAuth-secured agents",
backstory="Manages OAuth-authenticated communications",
llm="gpt-4o",
a2a=A2AConfig(
a2a=A2AClientConfig(
endpoint="https://oauth-agent.example.com/.well-known/agent-card.json",
auth=OAuth2ClientCredentials(
token_url="https://auth.example.com/oauth/token",
@@ -163,12 +184,12 @@ agent = Agent(
timeout=120
)
)
```
```
</Tab>
<Tab title="HTTP Basic">
```python Code
from crewai.a2a import A2AConfig
```python http_basic_auth.py lines
from crewai.a2a import A2AClientConfig
from crewai.a2a.auth import HTTPBasicAuth
agent = Agent(
@@ -176,7 +197,7 @@ agent = Agent(
goal="Coordinate with basic auth agents",
backstory="Manages basic authentication communications",
llm="gpt-4o",
a2a=A2AConfig(
a2a=A2AClientConfig(
endpoint="https://basic-agent.example.com/.well-known/agent-card.json",
auth=HTTPBasicAuth(
username="your-username",
@@ -185,7 +206,7 @@ agent = Agent(
timeout=120
)
)
```
```
</Tab>
</Tabs>
@@ -194,7 +215,7 @@ agent = Agent(
Configure multiple A2A agents for delegation by passing a list:
```python Code
from crewai.a2a import A2AConfig
from crewai.a2a import A2AClientConfig
from crewai.a2a.auth import BearerTokenAuth
agent = Agent(
@@ -203,11 +224,11 @@ agent = Agent(
backstory="Expert at delegating to the right specialist",
llm="gpt-4o",
a2a=[
A2AConfig(
A2AClientConfig(
endpoint="https://research.example.com/.well-known/agent-card.json",
timeout=120
),
A2AConfig(
A2AClientConfig(
endpoint="https://data.example.com/.well-known/agent-card.json",
auth=BearerTokenAuth(token="data-token"),
timeout=90
@@ -223,7 +244,7 @@ The LLM will automatically choose which A2A agent to delegate to based on the ta
Control how agent connection failures are handled using the `fail_fast` parameter:
```python Code
from crewai.a2a import A2AConfig
from crewai.a2a import A2AClientConfig
# Fail immediately on connection errors (default)
agent = Agent(
@@ -231,7 +252,7 @@ agent = Agent(
goal="Coordinate research tasks",
backstory="Expert at delegation",
llm="gpt-4o",
a2a=A2AConfig(
a2a=A2AClientConfig(
endpoint="https://research.example.com/.well-known/agent-card.json",
fail_fast=True
)
@@ -244,11 +265,11 @@ agent = Agent(
backstory="Expert at working with available resources",
llm="gpt-4o",
a2a=[
A2AConfig(
A2AClientConfig(
endpoint="https://primary.example.com/.well-known/agent-card.json",
fail_fast=False
),
A2AConfig(
A2AClientConfig(
endpoint="https://backup.example.com/.well-known/agent-card.json",
fail_fast=False
)
@@ -267,8 +288,8 @@ Control how your agent receives task status updates from remote A2A agents:
<Tabs>
<Tab title="Streaming (Default)">
```python Code
from crewai.a2a import A2AConfig
```python streaming_config.py lines
from crewai.a2a import A2AClientConfig
from crewai.a2a.updates import StreamingConfig
agent = Agent(
@@ -276,17 +297,17 @@ agent = Agent(
goal="Coordinate research tasks",
backstory="Expert at delegation",
llm="gpt-4o",
a2a=A2AConfig(
a2a=A2AClientConfig(
endpoint="https://research.example.com/.well-known/agent-card.json",
updates=StreamingConfig()
)
)
```
```
</Tab>
<Tab title="Polling">
```python Code
from crewai.a2a import A2AConfig
```python polling_config.py lines
from crewai.a2a import A2AClientConfig
from crewai.a2a.updates import PollingConfig
agent = Agent(
@@ -294,7 +315,7 @@ agent = Agent(
goal="Coordinate research tasks",
backstory="Expert at delegation",
llm="gpt-4o",
a2a=A2AConfig(
a2a=A2AClientConfig(
endpoint="https://research.example.com/.well-known/agent-card.json",
updates=PollingConfig(
interval=2.0,
@@ -303,12 +324,12 @@ agent = Agent(
)
)
)
```
```
</Tab>
<Tab title="Push Notifications">
```python Code
from crewai.a2a import A2AConfig
```python push_notifications_config.py lines
from crewai.a2a import A2AClientConfig
from crewai.a2a.updates import PushNotificationConfig
agent = Agent(
@@ -316,19 +337,137 @@ agent = Agent(
goal="Coordinate research tasks",
backstory="Expert at delegation",
llm="gpt-4o",
a2a=A2AConfig(
a2a=A2AClientConfig(
endpoint="https://research.example.com/.well-known/agent-card.json",
updates=PushNotificationConfig(
url={base_url}/a2a/callback",
url="{base_url}/a2a/callback",
token="your-validation-token",
timeout=300.0
)
)
)
```
```
</Tab>
</Tabs>
## Exposing Agents as A2A Servers
You can expose your CrewAI agents as A2A-compliant servers, allowing other A2A clients to delegate tasks to them.
### Server Configuration
Add an `A2AServerConfig` to your agent to enable server capabilities:
```python a2a_server_agent.py lines
from crewai import Agent
from crewai.a2a import A2AServerConfig
agent = Agent(
role="Data Analyst",
goal="Analyze datasets and provide insights",
backstory="Expert data scientist with statistical analysis skills",
llm="gpt-4o",
a2a=A2AServerConfig(url="https://your-server.com")
)
```
### Server Configuration Options
<ParamField path="name" type="str" default="None">
Human-readable name for the agent. Defaults to the agent's role if not provided.
</ParamField>
<ParamField path="description" type="str" default="None">
Human-readable description. Defaults to the agent's goal and backstory if not provided.
</ParamField>
<ParamField path="version" type="str" default="1.0.0">
Version string for the agent card.
</ParamField>
<ParamField path="skills" type="list[AgentSkill]" default="[]">
List of agent skills. Auto-generated from agent tools if not provided.
</ParamField>
<ParamField path="capabilities" type="AgentCapabilities" default="AgentCapabilities(streaming=True, push_notifications=False)">
Declaration of optional capabilities supported by the agent.
</ParamField>
<ParamField path="default_input_modes" type="list[str]" default='["text/plain", "application/json"]'>
Supported input MIME types.
</ParamField>
<ParamField path="default_output_modes" type="list[str]" default='["text/plain", "application/json"]'>
Supported output MIME types.
</ParamField>
<ParamField path="url" type="str" default="None">
Preferred endpoint URL. If set, overrides the URL passed to `to_agent_card()`.
</ParamField>
<ParamField path="preferred_transport" type="Literal['JSONRPC', 'GRPC', 'HTTP+JSON']" default="JSONRPC">
Transport protocol for the preferred endpoint.
</ParamField>
<ParamField path="protocol_version" type="str" default="0.3">
A2A protocol version this agent supports.
</ParamField>
<ParamField path="provider" type="AgentProvider" default="None">
Information about the agent's service provider.
</ParamField>
<ParamField path="documentation_url" type="str" default="None">
URL to the agent's documentation.
</ParamField>
<ParamField path="icon_url" type="str" default="None">
URL to an icon for the agent.
</ParamField>
<ParamField path="additional_interfaces" type="list[AgentInterface]" default="[]">
Additional supported interfaces (transport and URL combinations).
</ParamField>
<ParamField path="security" type="list[dict[str, list[str]]]" default="[]">
Security requirement objects for all agent interactions.
</ParamField>
<ParamField path="security_schemes" type="dict[str, SecurityScheme]" default="{}">
Security schemes available to authorize requests.
</ParamField>
<ParamField path="supports_authenticated_extended_card" type="bool" default="False">
Whether agent provides extended card to authenticated users.
</ParamField>
<ParamField path="signatures" type="list[AgentCardSignature]" default="[]">
JSON Web Signatures for the AgentCard.
</ParamField>
### Combined Client and Server
An agent can act as both client and server by providing both configurations:
```python Code
from crewai import Agent
from crewai.a2a import A2AClientConfig, A2AServerConfig
agent = Agent(
role="Research Coordinator",
goal="Coordinate research and serve analysis requests",
backstory="Expert at delegation and analysis",
llm="gpt-4o",
a2a=[
A2AClientConfig(
endpoint="https://specialist.example.com/.well-known/agent-card.json",
timeout=120
),
A2AServerConfig(url="https://your-server.com")
]
)
```
## Best Practices
<CardGroup cols={2}>

View File

@@ -12,7 +12,7 @@ dependencies = [
"pytube~=15.0.0",
"requests~=2.32.5",
"docker~=7.1.0",
"crewai==1.8.0",
"crewai==1.8.1",
"lancedb~=0.5.4",
"tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4",

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.8.0"
__version__ = "1.8.1"

View File

@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.8.0",
"crewai-tools==1.8.1",
]
embeddings = [
"tiktoken~=0.8.0"

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.8.0"
__version__ = "1.8.1"
_telemetry_submitted = False

View File

@@ -1,8 +1,10 @@
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""
from crewai.a2a.config import A2AConfig
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
__all__ = [
"A2AClientConfig",
"A2AConfig",
"A2AServerConfig",
]

View File

@@ -5,45 +5,57 @@ This module is separate from experimental.a2a to avoid circular imports.
from __future__ import annotations
from typing import Annotated, Any, ClassVar, Literal
from importlib.metadata import version
from typing import Any, ClassVar, Literal
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
HttpUrl,
TypeAdapter,
)
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import deprecated
from crewai.a2a.auth.schemas import AuthScheme
from crewai.a2a.types import TransportType, Url
try:
from a2a.types import (
AgentCapabilities,
AgentCardSignature,
AgentInterface,
AgentProvider,
AgentSkill,
SecurityScheme,
)
from crewai.a2a.updates import UpdateConfig
except ImportError:
UpdateConfig = Any
AgentCapabilities = Any
AgentCardSignature = Any
AgentInterface = Any
AgentProvider = Any
SecurityScheme = Any
AgentSkill = Any
UpdateConfig = Any # type: ignore[misc,assignment]
http_url_adapter = TypeAdapter(HttpUrl)
Url = Annotated[
str,
BeforeValidator(
lambda value: str(http_url_adapter.validate_python(value, strict=True))
),
]
def _get_default_update_config() -> UpdateConfig:
from crewai.a2a.updates import StreamingConfig
return StreamingConfig()
@deprecated(
"""
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead.
""",
category=FutureWarning,
)
class A2AConfig(BaseModel):
"""Configuration for A2A protocol integration.
Deprecated:
Use A2AClientConfig instead. This class will be removed in a future version.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
@@ -87,3 +99,176 @@ class A2AConfig(BaseModel):
default="JSONRPC",
description="Specified mode of A2A transport protocol",
)
class A2AClientConfig(BaseModel):
"""Configuration for connecting to remote A2A agents.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
accepted_output_modes: Media types the client can accept in responses.
supported_transports: Ordered list of transport protocols the client supports.
use_client_preference: Whether to prioritize client transport preferences over server.
extensions: Extension URIs the client supports.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: AuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=_get_default_update_config,
description="Update mechanism config",
)
accepted_output_modes: list[str] = Field(
default_factory=lambda: ["application/json"],
description="Media types the client can accept in responses",
)
supported_transports: list[str] = Field(
default_factory=lambda: ["JSONRPC"],
description="Ordered list of transport protocols the client supports",
)
use_client_preference: bool = Field(
default=False,
description="Whether to prioritize client transport preferences over server",
)
extensions: list[str] = Field(
default_factory=list,
description="Extension URIs the client supports",
)
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
default="JSONRPC",
description="Specified mode of A2A transport protocol",
)
class A2AServerConfig(BaseModel):
"""Configuration for exposing a Crew or Agent as an A2A server.
All fields correspond to A2A AgentCard fields. Fields like name, description,
and skills can be auto-derived from the Crew/Agent if not provided.
Attributes:
name: Human-readable name for the agent.
description: Human-readable description of the agent.
version: Version string for the agent card.
skills: List of agent skills/capabilities.
default_input_modes: Default supported input MIME types.
default_output_modes: Default supported output MIME types.
capabilities: Declaration of optional capabilities.
preferred_transport: Transport protocol for the preferred endpoint.
protocol_version: A2A protocol version this agent supports.
provider: Information about the agent's service provider.
documentation_url: URL to the agent's documentation.
icon_url: URL to an icon for the agent.
additional_interfaces: Additional supported interfaces.
security: Security requirement objects for all interactions.
security_schemes: Security schemes available to authorize requests.
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
url: Preferred endpoint URL for the agent.
signatures: JSON Web Signatures for the AgentCard.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
name: str | None = Field(
default=None,
description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.",
)
description: str | None = Field(
default=None,
description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.",
)
version: str = Field(
default="1.0.0",
description="Version string for the agent card",
)
skills: list[AgentSkill] = Field(
default_factory=list,
description="List of agent skills. Auto-derived from tasks/tools if not provided.",
)
default_input_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported input MIME types",
)
default_output_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported output MIME types",
)
capabilities: AgentCapabilities = Field(
default_factory=lambda: AgentCapabilities(
streaming=True,
push_notifications=False,
),
description="Declaration of optional capabilities supported by the agent",
)
preferred_transport: TransportType = Field(
default="JSONRPC",
description="Transport protocol for the preferred endpoint",
)
protocol_version: str = Field(
default_factory=lambda: version("a2a-sdk"),
description="A2A protocol version this agent supports",
)
provider: AgentProvider | None = Field(
default=None,
description="Information about the agent's service provider",
)
documentation_url: Url | None = Field(
default=None,
description="URL to the agent's documentation",
)
icon_url: Url | None = Field(
default=None,
description="URL to an icon for the agent",
)
additional_interfaces: list[AgentInterface] = Field(
default_factory=list,
description="Additional supported interfaces (transport and URL combinations)",
)
security: list[dict[str, list[str]]] = Field(
default_factory=list,
description="Security requirement objects for all agent interactions",
)
security_schemes: dict[str, SecurityScheme] = Field(
default_factory=dict,
description="Security schemes available to authorize requests",
)
supports_authenticated_extended_card: bool = Field(
default=False,
description="Whether agent provides extended card to authenticated users",
)
url: Url | None = Field(
default=None,
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
)
signatures: list[AgentCardSignature] = Field(
default_factory=list,
description="JSON Web Signatures for the AgentCard",
)

View File

@@ -1,7 +1,17 @@
"""Type definitions for A2A protocol message parts."""
from typing import Any, Literal, Protocol, TypedDict, runtime_checkable
from __future__ import annotations
from typing import (
Annotated,
Any,
Literal,
Protocol,
TypedDict,
runtime_checkable,
)
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired
from crewai.a2a.updates import (
@@ -15,6 +25,18 @@ from crewai.a2a.updates import (
)
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
Url = Annotated[
str,
BeforeValidator(
lambda value: str(http_url_adapter.validate_python(value, strict=True))
),
]
@runtime_checkable
class AgentResponseProtocol(Protocol):
"""Protocol for the dynamically created AgentResponse model."""

View File

@@ -0,0 +1 @@
"""A2A utility modules for client operations."""

View File

@@ -0,0 +1,399 @@
"""AgentCard utilities for A2A client and server operations."""
from __future__ import annotations
import asyncio
from collections.abc import MutableMapping
from functools import lru_cache
import time
from types import MethodType
from typing import TYPE_CHECKING
from a2a.client.errors import A2AClientHTTPError
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
import httpx
from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
configure_auth_client,
retry_on_401,
)
from crewai.a2a.config import A2AServerConfig
from crewai.crew import Crew
if TYPE_CHECKING:
from crewai.a2a.auth.schemas import AuthScheme
from crewai.agent import Agent
from crewai.task import Task
def _get_server_config(agent: Agent) -> A2AServerConfig | None:
"""Get A2AServerConfig from an agent's a2a configuration.
Args:
agent: The Agent instance to check.
Returns:
A2AServerConfig if present, None otherwise.
"""
if agent.a2a is None:
return None
if isinstance(agent.a2a, A2AServerConfig):
return agent.a2a
if isinstance(agent.a2a, list):
for config in agent.a2a:
if isinstance(config, A2AServerConfig):
return config
return None
def fetch_agent_card(
endpoint: str,
auth: AuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
cache_ttl: int = 300,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint with optional caching.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional AuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).
Returns:
AgentCard object with agent capabilities and skills.
Raises:
httpx.HTTPStatusError: If the request fails.
A2AClientHTTPError: If authentication fails.
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = hash((type(auth).__name__, auth_data))
else:
auth_hash = 0
_auth_store[auth_hash] = auth
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
async def afetch_agent_card(
endpoint: str,
auth: AuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint asynchronously.
Native async implementation. Use this when running in an async context.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional AuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
Returns:
AgentCard object with agent capabilities and skills.
Raises:
httpx.HTTPStatusError: If the request fails.
A2AClientHTTPError: If authentication fails.
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = hash((type(auth).__name__, auth_data))
else:
auth_hash = 0
_auth_store[auth_hash] = auth
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)
return agent_card
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
@lru_cache()
def _fetch_agent_card_cached(
endpoint: str,
auth_hash: int,
timeout: int,
_ttl_hash: int,
) -> AgentCard:
"""Cached sync version of fetch_agent_card."""
auth = _auth_store.get(auth_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
_afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
async def _afetch_agent_card_cached(
endpoint: str,
auth_hash: int,
timeout: int,
) -> AgentCard:
"""Cached async implementation of AgentCard fetching."""
auth = _auth_store.get(auth_hash)
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
async def _afetch_agent_card_impl(
endpoint: str,
auth: AuthScheme | None,
timeout: int,
) -> AgentCard:
"""Internal async implementation of AgentCard fetching."""
if "/.well-known/agent-card.json" in endpoint:
base_url = endpoint.replace("/.well-known/agent-card.json", "")
agent_card_path = "/.well-known/agent-card.json"
else:
url_parts = endpoint.split("/", 3)
base_url = f"{url_parts[0]}//{url_parts[2]}"
agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/"
headers: MutableMapping[str, str] = {}
if auth:
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_client)
agent_card_url = f"{base_url}{agent_card_path}"
async def _fetch_agent_card_request() -> httpx.Response:
return await temp_client.get(agent_card_url)
try:
response = await retry_on_401(
request_func=_fetch_agent_card_request,
auth_scheme=auth,
client=temp_client,
headers=temp_client.headers,
max_retries=2,
)
response.raise_for_status()
return AgentCard.model_validate(response.json())
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
error_details = ["Authentication failed"]
www_auth = e.response.headers.get("WWW-Authenticate")
if www_auth:
error_details.append(f"WWW-Authenticate: {www_auth}")
if not auth:
error_details.append("No auth scheme provided")
msg = " | ".join(error_details)
raise A2AClientHTTPError(401, msg) from e
raise
def _task_to_skill(task: Task) -> AgentSkill:
"""Convert a CrewAI Task to an A2A AgentSkill.
Args:
task: The CrewAI Task to convert.
Returns:
AgentSkill representing the task's capability.
"""
task_name = task.name or task.description[:50]
task_id = task_name.lower().replace(" ", "_")
tags: list[str] = []
if task.agent:
tags.append(task.agent.role.lower().replace(" ", "-"))
return AgentSkill(
id=task_id,
name=task_name,
description=task.description,
tags=tags,
examples=[task.expected_output] if task.expected_output else None,
)
def _tool_to_skill(tool_name: str, tool_description: str) -> AgentSkill:
"""Convert an Agent's tool to an A2A AgentSkill.
Args:
tool_name: Name of the tool.
tool_description: Description of what the tool does.
Returns:
AgentSkill representing the tool's capability.
"""
tool_id = tool_name.lower().replace(" ", "_")
return AgentSkill(
id=tool_id,
name=tool_name,
description=tool_description,
tags=[tool_name.lower().replace(" ", "-")],
)
def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard:
"""Generate an A2A AgentCard from a Crew instance.
Args:
crew: The Crew instance to generate a card for.
url: The base URL where this crew will be exposed.
Returns:
AgentCard describing the crew's capabilities.
"""
crew_name = getattr(crew, "name", None) or crew.__class__.__name__
description_parts: list[str] = []
crew_description = getattr(crew, "description", None)
if crew_description:
description_parts.append(crew_description)
else:
agent_roles = [agent.role for agent in crew.agents]
description_parts.append(
f"A crew of {len(crew.agents)} agents: {', '.join(agent_roles)}"
)
skills = [_task_to_skill(task) for task in crew.tasks]
return AgentCard(
name=crew_name,
description=" ".join(description_parts),
url=url,
version="1.0.0",
capabilities=AgentCapabilities(
streaming=True,
push_notifications=True,
),
default_input_modes=["text/plain", "application/json"],
default_output_modes=["text/plain", "application/json"],
skills=skills,
)
def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
"""Generate an A2A AgentCard from an Agent instance.
Uses A2AServerConfig values when available, falling back to agent properties.
Args:
agent: The Agent instance to generate a card for.
url: The base URL where this agent will be exposed.
Returns:
AgentCard describing the agent's capabilities.
"""
server_config = _get_server_config(agent) or A2AServerConfig()
name = server_config.name or agent.role
description_parts = [agent.goal]
if agent.backstory:
description_parts.append(agent.backstory)
description = server_config.description or " ".join(description_parts)
skills: list[AgentSkill] = (
server_config.skills.copy() if server_config.skills else []
)
if not skills:
if agent.tools:
for tool in agent.tools:
tool_name = getattr(tool, "name", None) or tool.__class__.__name__
tool_desc = getattr(tool, "description", None) or f"Tool: {tool_name}"
skills.append(_tool_to_skill(tool_name, tool_desc))
if not skills:
skills.append(
AgentSkill(
id=agent.role.lower().replace(" ", "_"),
name=agent.role,
description=agent.goal,
tags=[agent.role.lower().replace(" ", "-")],
)
)
return AgentCard(
name=name,
description=description,
url=server_config.url or url,
version=server_config.version,
capabilities=server_config.capabilities,
default_input_modes=server_config.default_input_modes,
default_output_modes=server_config.default_output_modes,
skills=skills,
protocol_version=server_config.protocol_version,
provider=server_config.provider,
documentation_url=server_config.documentation_url,
icon_url=server_config.icon_url,
additional_interfaces=server_config.additional_interfaces,
security=server_config.security,
security_schemes=server_config.security_schemes,
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
signatures=server_config.signatures,
)
def inject_a2a_server_methods(agent: Agent) -> None:
"""Inject A2A server methods onto an Agent instance.
Adds a `to_agent_card(url: str) -> AgentCard` method to the agent
that generates an A2A-compliant AgentCard.
Only injects if the agent has an A2AServerConfig.
Args:
agent: The Agent instance to inject methods onto.
"""
if _get_server_config(agent) is None:
return
def _to_agent_card(self: Agent, url: str) -> AgentCard:
return _agent_to_agent_card(self, url)
object.__setattr__(agent, "to_agent_card", MethodType(_to_agent_card, agent))

View File

@@ -1,16 +1,14 @@
"""Utility functions for A2A (Agent-to-Agent) protocol delegation."""
"""A2A delegation utilities for executing tasks on remote agents."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, MutableMapping
from contextlib import asynccontextmanager
from functools import lru_cache
import time
from typing import TYPE_CHECKING, Any, Literal
import uuid
from a2a.client import A2AClientHTTPError, Client, ClientConfig, ClientFactory
from a2a.client import Client, ClientConfig, ClientFactory
from a2a.types import (
AgentCard,
Message,
@@ -19,19 +17,15 @@ from a2a.types import (
Role,
TextPart,
)
from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
import httpx
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel
from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
configure_auth_client,
retry_on_401,
validate_auth_against_agent_card,
)
from crewai.a2a.config import A2AConfig
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.types import (
HANDLER_REGISTRY,
@@ -45,6 +39,7 @@ from crewai.a2a.updates import (
StreamingHandler,
UpdateConfig,
)
from crewai.a2a.utils.agent_card import _afetch_agent_card_cached
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationStartedEvent,
@@ -52,7 +47,6 @@ from crewai.events.types.a2a_events import (
A2ADelegationStartedEvent,
A2AMessageSentEvent,
)
from crewai.types.utils import create_literals_from_strings
if TYPE_CHECKING:
@@ -75,187 +69,6 @@ def get_handler(config: UpdateConfig | None) -> HandlerType:
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
@lru_cache()
def _fetch_agent_card_cached(
endpoint: str,
auth_hash: int,
timeout: int,
_ttl_hash: int,
) -> AgentCard:
"""Cached sync version of fetch_agent_card."""
auth = _auth_store.get(auth_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
_afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
def fetch_agent_card(
endpoint: str,
auth: AuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
cache_ttl: int = 300,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint with optional caching.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL)
auth: Optional AuthScheme for authentication
timeout: Request timeout in seconds
use_cache: Whether to use caching (default True)
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes)
Returns:
AgentCard object with agent capabilities and skills
Raises:
httpx.HTTPStatusError: If the request fails
A2AClientHTTPError: If authentication fails
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = hash((type(auth).__name__, auth_data))
else:
auth_hash = 0
_auth_store[auth_hash] = auth
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
async def afetch_agent_card(
endpoint: str,
auth: AuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint asynchronously.
Native async implementation. Use this when running in an async context.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional AuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
Returns:
AgentCard object with agent capabilities and skills.
Raises:
httpx.HTTPStatusError: If the request fails.
A2AClientHTTPError: If authentication fails.
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = hash((type(auth).__name__, auth_data))
else:
auth_hash = 0
_auth_store[auth_hash] = auth
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)
return agent_card
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
async def _afetch_agent_card_cached(
endpoint: str,
auth_hash: int,
timeout: int,
) -> AgentCard:
"""Cached async implementation of AgentCard fetching."""
auth = _auth_store.get(auth_hash)
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
async def _afetch_agent_card_impl(
endpoint: str,
auth: AuthScheme | None,
timeout: int,
) -> AgentCard:
"""Internal async implementation of AgentCard fetching."""
if "/.well-known/agent-card.json" in endpoint:
base_url = endpoint.replace("/.well-known/agent-card.json", "")
agent_card_path = "/.well-known/agent-card.json"
else:
url_parts = endpoint.split("/", 3)
base_url = f"{url_parts[0]}//{url_parts[2]}"
agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/"
headers: MutableMapping[str, str] = {}
if auth:
async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_client)
agent_card_url = f"{base_url}{agent_card_path}"
async def _fetch_agent_card_request() -> httpx.Response:
return await temp_client.get(agent_card_url)
try:
response = await retry_on_401(
request_func=_fetch_agent_card_request,
auth_scheme=auth,
client=temp_client,
headers=temp_client.headers,
max_retries=2,
)
response.raise_for_status()
return AgentCard.model_validate(response.json())
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
error_details = ["Authentication failed"]
www_auth = e.response.headers.get("WWW-Authenticate")
if www_auth:
error_details.append(f"WWW-Authenticate: {www_auth}")
if not auth:
error_details.append("No auth scheme provided")
msg = " | ".join(error_details)
raise A2AClientHTTPError(401, msg) from e
raise
def execute_a2a_delegation(
endpoint: str,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
@@ -644,19 +457,18 @@ async def _create_a2a_client(
"""Create and configure an A2A client.
Args:
agent_card: The A2A agent card
transport_protocol: Transport protocol to use
timeout: Request timeout in seconds
headers: HTTP headers (already with auth applied)
streaming: Enable streaming responses
auth: Optional AuthScheme for client configuration
use_polling: Enable polling mode
push_notification_config: Optional push notification config to include in requests
agent_card: The A2A agent card.
transport_protocol: Transport protocol to use.
timeout: Request timeout in seconds.
headers: HTTP headers (already with auth applied).
streaming: Enable streaming responses.
auth: Optional AuthScheme for client configuration.
use_polling: Enable polling mode.
push_notification_config: Optional push notification config.
Yields:
Configured A2A client instance
Configured A2A client instance.
"""
async with httpx.AsyncClient(
timeout=timeout,
headers=headers,
@@ -687,78 +499,3 @@ async def _create_a2a_client(
factory = ClientFactory(config)
client = factory.create(agent_card)
yield client
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel]:
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
Args:
agent_ids: List of available A2A agent IDs
Returns:
Dynamically created Pydantic model with Literal-constrained a2a_ids field
"""
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
return create_model(
"AgentResponse",
a2a_ids=(
tuple[DynamicLiteral, ...], # type: ignore[valid-type]
Field(
default_factory=tuple,
max_length=len(agent_ids),
description="A2A agent IDs to delegate to.",
),
),
message=(
str,
Field(
description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
),
),
is_a2a=(
bool,
Field(
description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately."
),
),
__base__=BaseModel,
)
def extract_a2a_agent_ids_from_config(
a2a_config: list[A2AConfig] | A2AConfig | None,
) -> tuple[list[A2AConfig], tuple[str, ...]]:
"""Extract A2A agent IDs from A2A configuration.
Args:
a2a_config: A2A configuration
Returns:
List of A2A agent IDs
"""
if a2a_config is None:
return [], ()
if isinstance(a2a_config, A2AConfig):
a2a_agents = [a2a_config]
else:
a2a_agents = a2a_config
return a2a_agents, tuple(config.endpoint for config in a2a_agents)
def get_a2a_agents_and_response_model(
a2a_config: list[A2AConfig] | A2AConfig | None,
) -> tuple[list[A2AConfig], type[BaseModel]]:
"""Get A2A agent IDs and response model.
Args:
a2a_config: A2A configuration
Returns:
Tuple of A2A agent IDs and response model
"""
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
return a2a_agents, create_agent_response_model(agent_ids)

View File

@@ -0,0 +1,101 @@
"""Response model utilities for A2A agent interactions."""
from __future__ import annotations
from typing import TypeAlias
from pydantic import BaseModel, Field, create_model
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.types.utils import create_literals_from_strings
A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None:
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
Args:
agent_ids: List of available A2A agent IDs.
Returns:
Dynamically created Pydantic model with Literal-constrained a2a_ids field,
or None if agent_ids is empty.
"""
if not agent_ids:
return None
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
return create_model(
"AgentResponse",
a2a_ids=(
tuple[DynamicLiteral, ...], # type: ignore[valid-type]
Field(
default_factory=tuple,
max_length=len(agent_ids),
description="A2A agent IDs to delegate to.",
),
),
message=(
str,
Field(
description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
),
),
is_a2a=(
bool,
Field(
description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately."
),
),
__base__=BaseModel,
)
def extract_a2a_agent_ids_from_config(
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]:
"""Extract A2A agent IDs from A2A configuration.
Filters out A2AServerConfig since it doesn't have an endpoint for delegation.
Args:
a2a_config: A2A configuration (any type).
Returns:
Tuple of client A2A configs list and agent endpoint IDs.
"""
if a2a_config is None:
return [], ()
configs: list[A2AConfigTypes]
if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)):
configs = [a2a_config]
else:
configs = a2a_config
# Filter to only client configs (those with endpoint)
client_configs: list[A2AClientConfigTypes] = [
config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
]
return client_configs, tuple(config.endpoint for config in client_configs)
def get_a2a_agents_and_response_model(
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]:
"""Get A2A agent configs and response model.
Args:
a2a_config: A2A configuration (any type).
Returns:
Tuple of client A2A configs and response model.
"""
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
return a2a_agents, create_agent_response_model(agent_ids)

View File

@@ -0,0 +1,284 @@
"""A2A task utilities for server-side task management."""
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from functools import wraps
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import (
InternalError,
InvalidParamsError,
Message,
Task as A2ATask,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils import new_agent_text_message, new_text_artifact
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AServerTaskCanceledEvent,
A2AServerTaskCompletedEvent,
A2AServerTaskFailedEvent,
A2AServerTaskStartedEvent,
)
from crewai.task import Task
if TYPE_CHECKING:
from crewai.agent import Agent
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
def _parse_redis_url(url: str) -> dict[str, Any]:
from urllib.parse import urlparse
parsed = urlparse(url)
config: dict[str, Any] = {
"cache": "aiocache.RedisCache",
"endpoint": parsed.hostname or "localhost",
"port": parsed.port or 6379,
}
if parsed.path and parsed.path != "/":
try:
config["db"] = int(parsed.path.lstrip("/"))
except ValueError:
pass
if parsed.password:
config["password"] = parsed.password
return config
_redis_url = os.environ.get("REDIS_URL")
caches.set_config(
{
"default": _parse_redis_url(_redis_url)
if _redis_url
else {
"cache": "aiocache.SimpleMemoryCache",
}
}
)
def cancellable(
fn: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T]]:
"""Decorator that enables cancellation for A2A task execution.
Runs a cancellation watcher concurrently with the wrapped function.
When a cancel event is published, the execution is cancelled.
Args:
fn: The async function to wrap.
Returns:
Wrapped function with cancellation support.
"""
@wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrap function with cancellation monitoring."""
context: RequestContext | None = None
for arg in args:
if isinstance(arg, RequestContext):
context = arg
break
if context is None:
context = cast(RequestContext | None, kwargs.get("context"))
if context is None:
return await fn(*args, **kwargs)
task_id = context.task_id
cache = caches.get("default")
async def poll_for_cancel() -> bool:
"""Poll cache for cancellation flag."""
while True:
if await cache.get(f"cancel:{task_id}"):
return True
await asyncio.sleep(0.1)
async def watch_for_cancel() -> bool:
"""Watch for cancellation events via pub/sub or polling."""
if isinstance(cache, SimpleMemoryCache):
return await poll_for_cancel()
try:
client = cache.client
pubsub = client.pubsub()
await pubsub.subscribe(f"cancel:{task_id}")
async for message in pubsub.listen():
if message["type"] == "message":
return True
except Exception as e:
logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
return await poll_for_cancel()
return False
execute_task = asyncio.create_task(fn(*args, **kwargs))
cancel_watch = asyncio.create_task(watch_for_cancel())
try:
done, _ = await asyncio.wait(
[execute_task, cancel_watch],
return_when=asyncio.FIRST_COMPLETED,
)
if cancel_watch in done:
execute_task.cancel()
try:
await execute_task
except asyncio.CancelledError:
pass
raise asyncio.CancelledError(f"Task {task_id} was cancelled")
cancel_watch.cancel()
return execute_task.result()
finally:
await cache.delete(f"cancel:{task_id}")
return wrapper
@cancellable
async def execute(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
) -> None:
"""Execute an A2A task using a CrewAI agent.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
TODOs:
* need to impl both of structured output and file inputs, depends on `file_inputs` for
`crewai.task.Task`, pass the below two to Task. both utils in `a2a.utils.parts`
* structured outputs ingestion, `structured_inputs = get_data_parts(parts=context.message.parts)`
* file inputs ingestion, `file_inputs = get_file_parts(parts=context.message.parts)`
"""
user_message = context.get_user_input()
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
msg = "task_id and context_id are required"
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(a2a_task_id="", a2a_context_id="", error=msg),
)
raise ServerError(InvalidParamsError(message=msg)) from None
task = Task(
description=user_message,
expected_output="Response to the user's request",
agent=agent,
)
crewai_event_bus.emit(
agent,
A2AServerTaskStartedEvent(a2a_task_id=task_id, a2a_context_id=context_id),
)
try:
result = await agent.aexecute_task(task=task, tools=agent.tools)
result_str = str(result)
history: list[Message] = [context.message] if context.message else []
history.append(new_agent_text_message(result_str, context_id, task_id))
await event_queue.enqueue_event(
A2ATask(
id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.input_required),
artifacts=[new_text_artifact(result_str, f"result_{task_id}")],
history=history,
)
)
crewai_event_bus.emit(
agent,
A2AServerTaskCompletedEvent(
a2a_task_id=task_id, a2a_context_id=context_id, result=str(result)
),
)
except asyncio.CancelledError:
crewai_event_bus.emit(
agent,
A2AServerTaskCanceledEvent(a2a_task_id=task_id, a2a_context_id=context_id),
)
raise
except Exception as e:
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(
a2a_task_id=task_id, a2a_context_id=context_id, error=str(e)
),
)
raise ServerError(
error=InternalError(message=f"Task execution failed: {e}")
) from e
async def cancel(
context: RequestContext,
event_queue: EventQueue,
) -> A2ATask | None:
"""Cancel an A2A task.
Publishes a cancel event that the cancellable decorator listens for.
Args:
context: The A2A request context containing task information.
event_queue: The event queue for sending the cancellation status.
Returns:
The canceled task with updated status.
"""
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
if context.current_task and context.current_task.status.state in (
TaskState.completed,
TaskState.failed,
TaskState.canceled,
):
return context.current_task
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.canceled),
final=True,
)
)
if context.current_task:
context.current_task.status = TaskStatus(state=TaskState.canceled)
return context.current_task
return None

View File

@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any
from a2a.types import Role, TaskState
from pydantic import BaseModel, ValidationError
from crewai.a2a.config import A2AConfig
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.templates import (
@@ -26,13 +26,16 @@ from crewai.a2a.templates import (
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE,
)
from crewai.a2a.types import AgentResponseProtocol
from crewai.a2a.utils import (
aexecute_a2a_delegation,
from crewai.a2a.utils.agent_card import (
afetch_agent_card,
execute_a2a_delegation,
fetch_agent_card,
get_a2a_agents_and_response_model,
inject_a2a_server_methods,
)
from crewai.a2a.utils.delegation import (
aexecute_a2a_delegation,
execute_a2a_delegation,
)
from crewai.a2a.utils.response_model import get_a2a_agents_and_response_model
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationCompletedEvent,
@@ -122,10 +125,12 @@ def wrap_agent_with_a2a_instance(
agent, "aexecute_task", MethodType(aexecute_task_with_a2a, agent)
)
inject_a2a_server_methods(agent)
def _fetch_card_from_config(
config: A2AConfig,
) -> tuple[A2AConfig, AgentCard | Exception]:
config: A2AConfig | A2AClientConfig,
) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]:
"""Fetch agent card from A2A config.
Args:
@@ -146,7 +151,7 @@ def _fetch_card_from_config(
def _fetch_agent_cards_concurrently(
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
) -> tuple[dict[str, AgentCard], dict[str, str]]:
"""Fetch agent cards concurrently for multiple A2A agents.
@@ -181,7 +186,7 @@ def _fetch_agent_cards_concurrently(
def _execute_task_with_a2a(
self: Agent,
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
original_fn: Callable[..., str],
task: Task,
agent_response_model: type[BaseModel],
@@ -270,7 +275,7 @@ def _execute_task_with_a2a(
def _augment_prompt_with_a2a(
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
task_description: str,
agent_cards: dict[str, AgentCard],
conversation_history: list[Message] | None = None,
@@ -523,11 +528,11 @@ def _prepare_delegation_context(
task: Task,
original_task_description: str | None,
) -> tuple[
list[A2AConfig],
list[A2AConfig | A2AClientConfig],
type[BaseModel],
str,
str,
A2AConfig,
A2AConfig | A2AClientConfig,
str | None,
str | None,
dict[str, Any] | None,
@@ -591,7 +596,7 @@ def _handle_task_completion(
task: Task,
task_id_config: str | None,
reference_task_ids: list[str],
agent_config: A2AConfig,
agent_config: A2AConfig | A2AClientConfig,
turn_num: int,
) -> tuple[str | None, str | None, list[str]]:
"""Handle task completion state including reference task updates.
@@ -631,7 +636,7 @@ def _handle_agent_response_and_continue(
a2a_result: TaskStateResult,
agent_id: str,
agent_cards: dict[str, AgentCard] | None,
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
original_task_description: str,
conversation_history: list[Message],
turn_num: int,
@@ -868,8 +873,8 @@ def _delegate_to_a2a(
async def _afetch_card_from_config(
config: A2AConfig,
) -> tuple[A2AConfig, AgentCard | Exception]:
config: A2AConfig | A2AClientConfig,
) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]:
"""Fetch agent card from A2A config asynchronously."""
try:
card = await afetch_agent_card(
@@ -883,7 +888,7 @@ async def _afetch_card_from_config(
async def _afetch_agent_cards_concurrently(
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
) -> tuple[dict[str, AgentCard], dict[str, str]]:
"""Fetch agent cards concurrently for multiple A2A agents using asyncio."""
agent_cards: dict[str, AgentCard] = {}
@@ -908,7 +913,7 @@ async def _afetch_agent_cards_concurrently(
async def _aexecute_task_with_a2a(
self: Agent,
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
original_fn: Callable[..., Coroutine[Any, Any, str]],
task: Task,
agent_response_model: type[BaseModel],
@@ -987,7 +992,7 @@ async def _ahandle_agent_response_and_continue(
a2a_result: TaskStateResult,
agent_id: str,
agent_cards: dict[str, AgentCard] | None,
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
original_task_description: str,
conversation_history: list[Message],
turn_num: int,

View File

@@ -17,7 +17,6 @@ from urllib.parse import urlparse
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.a2a.config import A2AConfig
from crewai.agent.utils import (
ahandle_knowledge_retrieval,
apply_training_data,
@@ -73,11 +72,19 @@ from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_F
from crewai.utilities.converter import Converter
from crewai.utilities.guardrail_types import GuardrailType
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.prompts import Prompts
from crewai.utilities.prompts import Prompts, StandardPromptResult, SystemPromptResult
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
try:
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
except ImportError:
A2AClientConfig = Any
A2AConfig = Any
A2AServerConfig = Any
if TYPE_CHECKING:
from crewai_tools import CodeInterpreterTool
@@ -218,9 +225,18 @@ class Agent(BaseAgent):
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
a2a: list[A2AConfig] | A2AConfig | None = Field(
a2a: (
list[A2AConfig | A2AServerConfig | A2AClientConfig]
| A2AConfig
| A2AServerConfig
| A2AClientConfig
| None
) = Field(
default=None,
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. Can be a single A2AConfig or a dict mapping agent IDs to configs.",
description="""
A2A (Agent-to-Agent) configuration for delegating tasks to remote agents.
Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
""",
)
executor_class: type[CrewAgentExecutor] | type[CrewAgentExecutorFlow] = Field(
default=CrewAgentExecutor,
@@ -733,7 +749,7 @@ class Agent(BaseAgent):
if self.agent_executor is not None:
self._update_executor_parameters(
task=task,
tools=parsed_tools,
tools=parsed_tools, # type: ignore[arg-type]
raw_tools=raw_tools,
prompt=prompt,
stop_words=stop_words,
@@ -742,7 +758,7 @@ class Agent(BaseAgent):
else:
self.agent_executor = self.executor_class(
llm=cast(BaseLLM, self.llm),
task=task,
task=task, # type: ignore[arg-type]
i18n=self.i18n,
agent=self,
crew=self.crew,
@@ -765,11 +781,11 @@ class Agent(BaseAgent):
def _update_executor_parameters(
self,
task: Task | None,
tools: list,
tools: list[BaseTool],
raw_tools: list[BaseTool],
prompt: dict,
prompt: SystemPromptResult | StandardPromptResult,
stop_words: list[str],
rpm_limit_fn: Callable | None,
rpm_limit_fn: Callable | None, # type: ignore[type-arg]
) -> None:
"""Update executor parameters without recreating instance.

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.8.0"
"crewai[tools]==1.8.1"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.8.0"
"crewai[tools]==1.8.1"
]
[project.scripts]

View File

@@ -1,3 +1,20 @@
from crewai.events.types.a2a_events import (
A2AConversationCompletedEvent,
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
A2APollingStartedEvent,
A2APollingStatusEvent,
A2APushNotificationReceivedEvent,
A2APushNotificationRegisteredEvent,
A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
A2AServerTaskCanceledEvent,
A2AServerTaskCompletedEvent,
A2AServerTaskFailedEvent,
A2AServerTaskStartedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
@@ -76,7 +93,22 @@ from crewai.events.types.tool_usage_events import (
EventTypes = (
CrewKickoffStartedEvent
A2AConversationCompletedEvent
| A2AConversationStartedEvent
| A2ADelegationCompletedEvent
| A2ADelegationStartedEvent
| A2AMessageSentEvent
| A2APollingStartedEvent
| A2APollingStatusEvent
| A2APushNotificationReceivedEvent
| A2APushNotificationRegisteredEvent
| A2APushNotificationTimeoutEvent
| A2AResponseReceivedEvent
| A2AServerTaskCanceledEvent
| A2AServerTaskCompletedEvent
| A2AServerTaskFailedEvent
| A2AServerTaskStartedEvent
| CrewKickoffStartedEvent
| CrewKickoffCompletedEvent
| CrewKickoffFailedEvent
| CrewTestStartedEvent

View File

@@ -210,3 +210,37 @@ class A2APushNotificationTimeoutEvent(A2AEventBase):
type: str = "a2a_push_notification_timeout"
task_id: str
timeout_seconds: float
class A2AServerTaskStartedEvent(A2AEventBase):
"""Event emitted when an A2A server task execution starts."""
type: str = "a2a_server_task_started"
a2a_task_id: str
a2a_context_id: str
class A2AServerTaskCompletedEvent(A2AEventBase):
"""Event emitted when an A2A server task execution completes."""
type: str = "a2a_server_task_completed"
a2a_task_id: str
a2a_context_id: str
result: str
class A2AServerTaskCanceledEvent(A2AEventBase):
"""Event emitted when an A2A server task execution is canceled."""
type: str = "a2a_server_task_canceled"
a2a_task_id: str
a2a_context_id: str
class A2AServerTaskFailedEvent(A2AEventBase):
"""Event emitted when an A2A server task execution fails."""
type: str = "a2a_server_task_failed"
a2a_task_id: str
a2a_context_id: str
error: str

View File

@@ -1,8 +1,6 @@
"""Utilities for creating and manipulating types."""
from typing import Annotated, Final, Literal
from typing_extensions import TypeAliasType
from typing import Annotated, Final, Literal, cast
_DYNAMIC_LITERAL_ALIAS: Final[Literal["DynamicLiteral"]] = "DynamicLiteral"
@@ -20,6 +18,11 @@ def create_literals_from_strings(
Returns:
Literal type for each A2A agent ID
Raises:
ValueError: If values is empty (Literal requires at least one value)
"""
unique_values: tuple[str, ...] = tuple(dict.fromkeys(values))
return Literal.__getitem__(unique_values)
if not unique_values:
raise ValueError("Cannot create Literal type from empty values")
return cast(type, Literal.__getitem__(unique_values))

View File

@@ -229,48 +229,6 @@ def enforce_rpm_limit(
request_within_rpm_limit()
def _extract_tools_from_context(
executor_context: CrewAgentExecutor | LiteAgent | None,
) -> list[dict[str, Any]] | None:
"""Extract tools from executor context and convert to LLM-compatible format.
Args:
executor_context: The executor context containing tools.
Returns:
List of tool dictionaries in LLM-compatible format, or None if no tools.
"""
if executor_context is None:
return None
# Get tools from executor context
# CrewAgentExecutor has 'tools' attribute, LiteAgent has '_parsed_tools'
tools: list[CrewStructuredTool] | None = None
if hasattr(executor_context, "tools"):
context_tools = executor_context.tools
if isinstance(context_tools, list) and len(context_tools) > 0:
tools = context_tools
if tools is None and hasattr(executor_context, "_parsed_tools"):
parsed_tools = executor_context._parsed_tools
if isinstance(parsed_tools, list) and len(parsed_tools) > 0:
tools = parsed_tools
if not tools:
return None
# Convert CrewStructuredTool to dict format expected by LLM
tool_dicts: list[dict[str, Any]] = []
for tool in tools:
tool_dict: dict[str, Any] = {
"name": tool.name,
"description": tool.description,
"args_schema": tool.args_schema,
}
tool_dicts.append(tool_dict)
return tool_dicts if tool_dicts else None
def get_llm_response(
llm: LLM | BaseLLM,
messages: list[LLMMessage],
@@ -306,29 +264,14 @@ def get_llm_response(
raise ValueError("LLM call blocked by before_llm_call hook")
messages = executor_context.messages
# Extract tools from executor context for native function calling support
tools = _extract_tools_from_context(executor_context)
try:
# Only pass tools parameter if tools are available to maintain backward compatibility
# with code that checks "tools" in kwargs
if tools is not None:
answer = llm.call(
messages,
tools=tools,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
else:
answer = llm.call(
messages,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
answer = llm.call(
messages,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
except Exception as e:
raise e
if not answer:
@@ -349,7 +292,7 @@ async def aget_llm_response(
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | LiteAgent | None = None,
executor_context: CrewAgentExecutor | None = None,
) -> str:
"""Call the LLM asynchronously and return the response.
@@ -375,29 +318,14 @@ async def aget_llm_response(
raise ValueError("LLM call blocked by before_llm_call hook")
messages = executor_context.messages
# Extract tools from executor context for native function calling support
tools = _extract_tools_from_context(executor_context)
try:
# Only pass tools parameter if tools are available to maintain backward compatibility
# with code that checks "tools" in kwargs
if tools is not None:
answer = await llm.acall(
messages,
tools=tools,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
else:
answer = await llm.acall(
messages,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
answer = await llm.acall(
messages,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
except Exception as e:
raise e
if not answer:

View File

@@ -0,0 +1,325 @@
"""Tests for A2A agent card utilities."""
from __future__ import annotations
from a2a.types import AgentCard, AgentSkill
from crewai import Agent
from crewai.a2a.config import A2AClientConfig, A2AServerConfig
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
class TestInjectA2AServerMethods:
"""Tests for inject_a2a_server_methods function."""
def test_agent_with_server_config_gets_to_agent_card_method(self) -> None:
"""Agent with A2AServerConfig should have to_agent_card method injected."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
assert hasattr(agent, "to_agent_card")
assert callable(agent.to_agent_card)
def test_agent_without_server_config_no_injection(self) -> None:
"""Agent without A2AServerConfig should not get to_agent_card method."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AClientConfig(endpoint="http://example.com"),
)
assert not hasattr(agent, "to_agent_card")
def test_agent_without_a2a_no_injection(self) -> None:
"""Agent without any a2a config should not get to_agent_card method."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
)
assert not hasattr(agent, "to_agent_card")
def test_agent_with_mixed_configs_gets_injection(self) -> None:
"""Agent with list containing A2AServerConfig should get to_agent_card."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=[
A2AClientConfig(endpoint="http://example.com"),
A2AServerConfig(name="My Agent"),
],
)
assert hasattr(agent, "to_agent_card")
assert callable(agent.to_agent_card)
def test_manual_injection_on_plain_agent(self) -> None:
"""inject_a2a_server_methods should work when called manually."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
)
# Manually set server config and inject
object.__setattr__(agent, "a2a", A2AServerConfig())
inject_a2a_server_methods(agent)
assert hasattr(agent, "to_agent_card")
assert callable(agent.to_agent_card)
class TestToAgentCard:
"""Tests for the injected to_agent_card method."""
def test_returns_agent_card(self) -> None:
"""to_agent_card should return an AgentCard instance."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert isinstance(card, AgentCard)
def test_uses_agent_role_as_name(self) -> None:
"""AgentCard name should default to agent role."""
agent = Agent(
role="Data Analyst",
goal="Analyze data",
backstory="Expert analyst",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.name == "Data Analyst"
def test_uses_server_config_name(self) -> None:
"""AgentCard name should prefer A2AServerConfig.name over role."""
agent = Agent(
role="Data Analyst",
goal="Analyze data",
backstory="Expert analyst",
a2a=A2AServerConfig(name="Custom Agent Name"),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.name == "Custom Agent Name"
def test_uses_goal_as_description(self) -> None:
"""AgentCard description should include agent goal."""
agent = Agent(
role="Test Agent",
goal="Accomplish important tasks",
backstory="Has extensive experience",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert "Accomplish important tasks" in card.description
def test_uses_server_config_description(self) -> None:
"""AgentCard description should prefer A2AServerConfig.description."""
agent = Agent(
role="Test Agent",
goal="Accomplish important tasks",
backstory="Has extensive experience",
a2a=A2AServerConfig(description="Custom description"),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.description == "Custom description"
def test_uses_provided_url(self) -> None:
"""AgentCard url should use the provided URL."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://my-server.com:9000")
assert card.url == "http://my-server.com:9000"
def test_uses_server_config_url(self) -> None:
"""AgentCard url should prefer A2AServerConfig.url over provided URL."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(url="http://configured-url.com"),
)
card = agent.to_agent_card("http://fallback-url.com")
assert card.url == "http://configured-url.com/"
def test_generates_default_skill(self) -> None:
"""AgentCard should have at least one skill based on agent role."""
agent = Agent(
role="Research Assistant",
goal="Help with research",
backstory="Skilled researcher",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert len(card.skills) >= 1
skill = card.skills[0]
assert skill.name == "Research Assistant"
assert skill.description == "Help with research"
def test_uses_server_config_skills(self) -> None:
"""AgentCard skills should prefer A2AServerConfig.skills."""
custom_skill = AgentSkill(
id="custom-skill",
name="Custom Skill",
description="A custom skill",
tags=["custom"],
)
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(skills=[custom_skill]),
)
card = agent.to_agent_card("http://localhost:8000")
assert len(card.skills) == 1
assert card.skills[0].id == "custom-skill"
assert card.skills[0].name == "Custom Skill"
def test_includes_custom_version(self) -> None:
"""AgentCard should include version from A2AServerConfig."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(version="2.0.0"),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.version == "2.0.0"
def test_default_version(self) -> None:
"""AgentCard should have default version 1.0.0."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
assert card.version == "1.0.0"
class TestAgentCardJsonStructure:
"""Tests for the JSON structure of AgentCard."""
def test_json_has_required_fields(self) -> None:
"""AgentCard JSON should contain all required A2A protocol fields."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump()
assert "name" in json_data
assert "description" in json_data
assert "url" in json_data
assert "version" in json_data
assert "skills" in json_data
assert "capabilities" in json_data
assert "defaultInputModes" in json_data
assert "defaultOutputModes" in json_data
def test_json_skills_structure(self) -> None:
"""Each skill in JSON should have required fields."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump()
assert len(json_data["skills"]) >= 1
skill = json_data["skills"][0]
assert "id" in skill
assert "name" in skill
assert "description" in skill
assert "tags" in skill
def test_json_capabilities_structure(self) -> None:
"""Capabilities in JSON should have expected fields."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump()
capabilities = json_data["capabilities"]
assert "streaming" in capabilities
assert "pushNotifications" in capabilities
def test_json_serializable(self) -> None:
"""AgentCard should be JSON serializable."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_str = card.model_dump_json()
assert isinstance(json_str, str)
assert "Test Agent" in json_str
assert "http://localhost:8000" in json_str
def test_json_excludes_none_values(self) -> None:
"""AgentCard JSON with exclude_none should omit None fields."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
a2a=A2AServerConfig(),
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump(exclude_none=True)
assert "provider" not in json_data
assert "documentationUrl" not in json_data
assert "iconUrl" not in json_data

View File

@@ -0,0 +1,370 @@
"""Tests for A2A task utilities."""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus
from crewai.a2a.utils.task import cancel, cancellable, execute
@pytest.fixture
def mock_agent() -> MagicMock:
"""Create a mock CrewAI agent."""
agent = MagicMock()
agent.role = "Test Agent"
agent.tools = []
agent.aexecute_task = AsyncMock(return_value="Task completed successfully")
return agent
@pytest.fixture
def mock_task() -> MagicMock:
"""Create a mock Task."""
return MagicMock()
@pytest.fixture
def mock_context() -> MagicMock:
"""Create a mock RequestContext."""
context = MagicMock(spec=RequestContext)
context.task_id = "test-task-123"
context.context_id = "test-context-456"
context.get_user_input.return_value = "Test user message"
context.message = MagicMock(spec=Message)
context.current_task = None
return context
@pytest.fixture
def mock_event_queue() -> AsyncMock:
"""Create a mock EventQueue."""
queue = AsyncMock(spec=EventQueue)
queue.enqueue_event = AsyncMock()
return queue
@pytest_asyncio.fixture(autouse=True)
async def clear_cache(mock_context: MagicMock) -> None:
"""Clear cancel flag from cache before each test."""
from aiocache import caches
cache = caches.get("default")
await cache.delete(f"cancel:{mock_context.task_id}")
class TestCancellableDecorator:
"""Tests for the cancellable decorator."""
@pytest.mark.asyncio
async def test_executes_function_without_context(self) -> None:
"""Function executes normally when no RequestContext is provided."""
call_count = 0
@cancellable
async def my_func(value: int) -> int:
nonlocal call_count
call_count += 1
return value * 2
result = await my_func(5)
assert result == 10
assert call_count == 1
@pytest.mark.asyncio
async def test_executes_function_with_context(self, mock_context: MagicMock) -> None:
"""Function executes normally with RequestContext when not cancelled."""
@cancellable
async def my_func(context: RequestContext) -> str:
await asyncio.sleep(0.01)
return "completed"
result = await my_func(mock_context)
assert result == "completed"
@pytest.mark.asyncio
async def test_cancellation_raises_cancelled_error(
self, mock_context: MagicMock
) -> None:
"""Function raises CancelledError when cancel flag is set."""
from aiocache import caches
cache = caches.get("default")
@cancellable
async def slow_func(context: RequestContext) -> str:
await asyncio.sleep(1.0)
return "should not reach"
await cache.set(f"cancel:{mock_context.task_id}", True)
with pytest.raises(asyncio.CancelledError):
await slow_func(mock_context)
@pytest.mark.asyncio
async def test_cleanup_removes_cancel_flag(self, mock_context: MagicMock) -> None:
"""Cancel flag is cleaned up after execution."""
from aiocache import caches
cache = caches.get("default")
@cancellable
async def quick_func(context: RequestContext) -> str:
return "done"
await quick_func(mock_context)
flag = await cache.get(f"cancel:{mock_context.task_id}")
assert flag is None
@pytest.mark.asyncio
async def test_extracts_context_from_kwargs(self, mock_context: MagicMock) -> None:
"""Context can be passed as keyword argument."""
@cancellable
async def my_func(value: int, context: RequestContext | None = None) -> int:
return value + 1
result = await my_func(10, context=mock_context)
assert result == 11
class TestExecute:
"""Tests for the execute function."""
@pytest.mark.asyncio
async def test_successful_execution(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute completes successfully and enqueues completed task."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
mock_agent.aexecute_task.assert_called_once()
mock_event_queue.enqueue_event.assert_called_once()
assert mock_bus.emit.call_count == 2
@pytest.mark.asyncio
async def test_emits_started_event(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskStartedEvent."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
first_call = mock_bus.emit.call_args_list[0]
event = first_call[0][1]
assert event.type == "a2a_server_task_started"
assert event.a2a_task_id == mock_context.task_id
assert event.a2a_context_id == mock_context.context_id
@pytest.mark.asyncio
async def test_emits_completed_event(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskCompletedEvent on success."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
second_call = mock_bus.emit.call_args_list[1]
event = second_call[0][1]
assert event.type == "a2a_server_task_completed"
assert event.a2a_task_id == mock_context.task_id
assert event.result == "Task completed successfully"
@pytest.mark.asyncio
async def test_emits_failed_event_on_exception(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskFailedEvent on exception."""
mock_agent.aexecute_task = AsyncMock(side_effect=ValueError("Test error"))
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(Exception):
await execute(mock_agent, mock_context, mock_event_queue)
failed_call = mock_bus.emit.call_args_list[1]
event = failed_call[0][1]
assert event.type == "a2a_server_task_failed"
assert "Test error" in event.error
@pytest.mark.asyncio
async def test_emits_canceled_event_on_cancellation(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Execute emits A2AServerTaskCanceledEvent on CancelledError."""
mock_agent.aexecute_task = AsyncMock(side_effect=asyncio.CancelledError())
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(asyncio.CancelledError):
await execute(mock_agent, mock_context, mock_event_queue)
canceled_call = mock_bus.emit.call_args_list[1]
event = canceled_call[0][1]
assert event.type == "a2a_server_task_canceled"
assert event.a2a_task_id == mock_context.task_id
class TestCancel:
"""Tests for the cancel function."""
@pytest.mark.asyncio
async def test_sets_cancel_flag_in_cache(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel sets the cancel flag in cache."""
from aiocache import caches
cache = caches.get("default")
await cancel(mock_context, mock_event_queue)
flag = await cache.get(f"cancel:{mock_context.task_id}")
assert flag is True
@pytest.mark.asyncio
async def test_enqueues_task_status_update_event(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel enqueues TaskStatusUpdateEvent with canceled state."""
await cancel(mock_context, mock_event_queue)
mock_event_queue.enqueue_event.assert_called_once()
event = mock_event_queue.enqueue_event.call_args[0][0]
assert event.task_id == mock_context.task_id
assert event.context_id == mock_context.context_id
assert event.status.state == TaskState.canceled
assert event.final is True
@pytest.mark.asyncio
async def test_returns_none_when_no_current_task(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel returns None when context has no current_task."""
mock_context.current_task = None
result = await cancel(mock_context, mock_event_queue)
assert result is None
@pytest.mark.asyncio
async def test_returns_updated_task_when_current_task_exists(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel returns updated task when context has current_task."""
current_task = MagicMock(spec=A2ATask)
current_task.status = TaskStatus(state=TaskState.working)
mock_context.current_task = current_task
result = await cancel(mock_context, mock_event_queue)
assert result is current_task
assert result.status.state == TaskState.canceled
@pytest.mark.asyncio
async def test_cleanup_after_cancel(
self,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
) -> None:
"""Cancel flag persists for cancellable decorator to detect."""
from aiocache import caches
cache = caches.get("default")
await cancel(mock_context, mock_event_queue)
flag = await cache.get(f"cancel:{mock_context.task_id}")
assert flag is True
await cache.delete(f"cancel:{mock_context.task_id}")
class TestExecuteAndCancelIntegration:
"""Integration tests for execute and cancel working together."""
@pytest.mark.asyncio
async def test_cancel_stops_running_execute(
self,
mock_agent: MagicMock,
mock_context: MagicMock,
mock_event_queue: AsyncMock,
mock_task: MagicMock,
) -> None:
"""Calling cancel stops a running execute."""
async def slow_task(**kwargs: Any) -> str:
await asyncio.sleep(2.0)
return "should not complete"
mock_agent.aexecute_task = slow_task
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus"),
):
execute_task = asyncio.create_task(
execute(mock_agent, mock_context, mock_event_queue)
)
await asyncio.sleep(0.1)
await cancel(mock_context, mock_event_queue)
with pytest.raises(asyncio.CancelledError):
await execute_task

View File

@@ -1,457 +0,0 @@
"""Unit tests for agent_utils module.
Tests the utility functions for agent execution including tool extraction
and LLM response handling.
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from pydantic import BaseModel, Field
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.agent_utils import (
_extract_tools_from_context,
aget_llm_response,
get_llm_response,
)
from crewai.utilities.printer import Printer
class MockArgsSchema(BaseModel):
"""Mock args schema for testing."""
query: str = Field(description="The search query")
class TestExtractToolsFromContext:
"""Test _extract_tools_from_context function."""
def test_returns_none_when_context_is_none(self):
"""Test that None is returned when executor_context is None."""
result = _extract_tools_from_context(None)
assert result is None
def test_returns_none_when_no_tools_attribute(self):
"""Test that None is returned when context has no tools."""
mock_context = Mock(spec=[])
result = _extract_tools_from_context(mock_context)
assert result is None
def test_returns_none_when_tools_is_empty(self):
"""Test that None is returned when tools list is empty."""
mock_context = Mock()
mock_context.tools = []
result = _extract_tools_from_context(mock_context)
assert result is None
def test_extracts_tools_from_crew_agent_executor(self):
"""Test tool extraction from CrewAgentExecutor (has 'tools' attribute)."""
mock_tool = CrewStructuredTool(
name="search_tool",
description="A tool for searching",
args_schema=MockArgsSchema,
func=lambda query: f"Results for {query}",
)
mock_context = Mock()
mock_context.tools = [mock_tool]
result = _extract_tools_from_context(mock_context)
assert result is not None
assert len(result) == 1
assert result[0]["name"] == "search_tool"
assert result[0]["description"] == "A tool for searching"
assert result[0]["args_schema"] == MockArgsSchema
def test_extracts_tools_from_lite_agent(self):
"""Test tool extraction from LiteAgent (has '_parsed_tools' attribute)."""
mock_tool = CrewStructuredTool(
name="calculator_tool",
description="A tool for calculations",
args_schema=MockArgsSchema,
func=lambda query: f"Calculated {query}",
)
mock_context = Mock(spec=["_parsed_tools"])
mock_context._parsed_tools = [mock_tool]
result = _extract_tools_from_context(mock_context)
assert result is not None
assert len(result) == 1
assert result[0]["name"] == "calculator_tool"
assert result[0]["description"] == "A tool for calculations"
assert result[0]["args_schema"] == MockArgsSchema
def test_extracts_multiple_tools(self):
"""Test extraction of multiple tools."""
tool1 = CrewStructuredTool(
name="tool1",
description="First tool",
args_schema=MockArgsSchema,
func=lambda query: "result1",
)
tool2 = CrewStructuredTool(
name="tool2",
description="Second tool",
args_schema=MockArgsSchema,
func=lambda query: "result2",
)
mock_context = Mock()
mock_context.tools = [tool1, tool2]
result = _extract_tools_from_context(mock_context)
assert result is not None
assert len(result) == 2
assert result[0]["name"] == "tool1"
assert result[1]["name"] == "tool2"
def test_prefers_tools_over_parsed_tools(self):
"""Test that 'tools' attribute is preferred over '_parsed_tools'."""
tool_from_tools = CrewStructuredTool(
name="from_tools",
description="Tool from tools attribute",
args_schema=MockArgsSchema,
func=lambda query: "from_tools",
)
tool_from_parsed = CrewStructuredTool(
name="from_parsed",
description="Tool from _parsed_tools attribute",
args_schema=MockArgsSchema,
func=lambda query: "from_parsed",
)
mock_context = Mock()
mock_context.tools = [tool_from_tools]
mock_context._parsed_tools = [tool_from_parsed]
result = _extract_tools_from_context(mock_context)
assert result is not None
assert len(result) == 1
assert result[0]["name"] == "from_tools"
class TestGetLlmResponse:
"""Test get_llm_response function."""
@pytest.fixture
def mock_llm(self):
"""Create a mock LLM."""
llm = Mock()
llm.call = Mock(return_value="LLM response")
return llm
@pytest.fixture
def mock_printer(self):
"""Create a mock printer."""
return Mock(spec=Printer)
def test_passes_tools_to_llm_call(self, mock_llm, mock_printer):
"""Test that tools are extracted and passed to llm.call()."""
mock_tool = CrewStructuredTool(
name="test_tool",
description="A test tool",
args_schema=MockArgsSchema,
func=lambda query: "result",
)
mock_context = Mock()
mock_context.tools = [mock_tool]
mock_context.messages = [{"role": "user", "content": "test"}]
mock_context.before_llm_call_hooks = []
mock_context.after_llm_call_hooks = []
with patch(
"crewai.utilities.agent_utils._setup_before_llm_call_hooks",
return_value=True,
):
with patch(
"crewai.utilities.agent_utils._setup_after_llm_call_hooks",
return_value="LLM response",
):
result = get_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=mock_printer,
executor_context=mock_context,
)
# Verify llm.call was called with tools parameter
mock_llm.call.assert_called_once()
call_kwargs = mock_llm.call.call_args[1]
assert "tools" in call_kwargs
assert call_kwargs["tools"] is not None
assert len(call_kwargs["tools"]) == 1
assert call_kwargs["tools"][0]["name"] == "test_tool"
def test_does_not_pass_tools_when_no_context(self, mock_llm, mock_printer):
"""Test that tools parameter is not passed when no executor_context."""
result = get_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=mock_printer,
executor_context=None,
)
mock_llm.call.assert_called_once()
call_kwargs = mock_llm.call.call_args[1]
# tools should NOT be in kwargs when there are no tools
# This maintains backward compatibility with code that checks "tools" in kwargs
assert "tools" not in call_kwargs
def test_does_not_pass_tools_when_context_has_no_tools(
self, mock_llm, mock_printer
):
"""Test that tools parameter is not passed when context has no tools."""
mock_context = Mock()
mock_context.tools = []
mock_context.messages = [{"role": "user", "content": "test"}]
mock_context.before_llm_call_hooks = []
mock_context.after_llm_call_hooks = []
with patch(
"crewai.utilities.agent_utils._setup_before_llm_call_hooks",
return_value=True,
):
with patch(
"crewai.utilities.agent_utils._setup_after_llm_call_hooks",
return_value="LLM response",
):
result = get_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=mock_printer,
executor_context=mock_context,
)
mock_llm.call.assert_called_once()
call_kwargs = mock_llm.call.call_args[1]
# tools should NOT be in kwargs when there are no tools
# This maintains backward compatibility with code that checks "tools" in kwargs
assert "tools" not in call_kwargs
class TestAgetLlmResponse:
"""Test aget_llm_response async function."""
@pytest.fixture
def mock_llm(self):
"""Create a mock LLM with async call."""
llm = Mock()
llm.acall = AsyncMock(return_value="Async LLM response")
return llm
@pytest.fixture
def mock_printer(self):
"""Create a mock printer."""
return Mock(spec=Printer)
@pytest.mark.asyncio
async def test_passes_tools_to_llm_acall(self, mock_llm, mock_printer):
"""Test that tools are extracted and passed to llm.acall()."""
mock_tool = CrewStructuredTool(
name="async_test_tool",
description="An async test tool",
args_schema=MockArgsSchema,
func=lambda query: "async result",
)
mock_context = Mock()
mock_context.tools = [mock_tool]
mock_context.messages = [{"role": "user", "content": "async test"}]
mock_context.before_llm_call_hooks = []
mock_context.after_llm_call_hooks = []
with patch(
"crewai.utilities.agent_utils._setup_before_llm_call_hooks",
return_value=True,
):
with patch(
"crewai.utilities.agent_utils._setup_after_llm_call_hooks",
return_value="Async LLM response",
):
result = await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "async test"}],
callbacks=[],
printer=mock_printer,
executor_context=mock_context,
)
# Verify llm.acall was called with tools parameter
mock_llm.acall.assert_called_once()
call_kwargs = mock_llm.acall.call_args[1]
assert "tools" in call_kwargs
assert call_kwargs["tools"] is not None
assert len(call_kwargs["tools"]) == 1
assert call_kwargs["tools"][0]["name"] == "async_test_tool"
@pytest.mark.asyncio
async def test_does_not_pass_tools_when_no_context(self, mock_llm, mock_printer):
"""Test that tools parameter is not passed when no executor_context."""
result = await aget_llm_response(
llm=mock_llm,
messages=[{"role": "user", "content": "test"}],
callbacks=[],
printer=mock_printer,
executor_context=None,
)
mock_llm.acall.assert_called_once()
call_kwargs = mock_llm.acall.call_args[1]
# tools should NOT be in kwargs when there are no tools
# This maintains backward compatibility with code that checks "tools" in kwargs
assert "tools" not in call_kwargs
class TestToolsPassedToGeminiModels:
"""Test that tools are properly passed for Gemini models.
This test class specifically addresses GitHub issue #4238 where
Gemini models fail with UNEXPECTED_TOOL_CALL errors because tools
were not being passed to llm.call().
"""
@pytest.fixture
def mock_gemini_llm(self):
"""Create a mock Gemini LLM."""
llm = Mock()
llm.model = "gemini/gemini-2.0-flash-exp"
llm.call = Mock(return_value="Gemini response with tool call")
return llm
@pytest.fixture
def mock_printer(self):
"""Create a mock printer."""
return Mock(spec=Printer)
@pytest.fixture
def delegation_tools(self):
"""Create mock delegation tools similar to hierarchical crew setup."""
class DelegateWorkArgsSchema(BaseModel):
task: str = Field(description="The task to delegate")
context: str = Field(description="Context for the task")
coworker: str = Field(description="The coworker to delegate to")
class AskQuestionArgsSchema(BaseModel):
question: str = Field(description="The question to ask")
context: str = Field(description="Context for the question")
coworker: str = Field(description="The coworker to ask")
delegate_tool = CrewStructuredTool(
name="Delegate work to coworker",
description="Delegate a specific task to one of your coworkers",
args_schema=DelegateWorkArgsSchema,
func=lambda task, context, coworker: f"Delegated {task} to {coworker}",
)
ask_question_tool = CrewStructuredTool(
name="Ask question to coworker",
description="Ask a specific question to one of your coworkers",
args_schema=AskQuestionArgsSchema,
func=lambda question, context, coworker: f"Asked {question} to {coworker}",
)
return [delegate_tool, ask_question_tool]
def test_gemini_receives_tools_for_hierarchical_crew(
self, mock_gemini_llm, mock_printer, delegation_tools
):
"""Test that Gemini models receive tools when used in hierarchical crew.
This test verifies the fix for issue #4238 where the manager agent
in a hierarchical crew would fail because tools weren't passed to
the Gemini model, causing UNEXPECTED_TOOL_CALL errors.
"""
mock_context = Mock()
mock_context.tools = delegation_tools
mock_context.messages = [
{"role": "system", "content": "You are a manager agent"},
{"role": "user", "content": "Coordinate the team to answer this question"},
]
mock_context.before_llm_call_hooks = []
mock_context.after_llm_call_hooks = []
with patch(
"crewai.utilities.agent_utils._setup_before_llm_call_hooks",
return_value=True,
):
with patch(
"crewai.utilities.agent_utils._setup_after_llm_call_hooks",
return_value="Gemini response with tool call",
):
result = get_llm_response(
llm=mock_gemini_llm,
messages=mock_context.messages,
callbacks=[],
printer=mock_printer,
executor_context=mock_context,
)
# Verify that tools were passed to the Gemini model
mock_gemini_llm.call.assert_called_once()
call_kwargs = mock_gemini_llm.call.call_args[1]
assert "tools" in call_kwargs
assert call_kwargs["tools"] is not None
assert len(call_kwargs["tools"]) == 2
# Verify the delegation tools are properly formatted
tool_names = [t["name"] for t in call_kwargs["tools"]]
assert "Delegate work to coworker" in tool_names
assert "Ask question to coworker" in tool_names
# Verify each tool has the required fields
for tool_dict in call_kwargs["tools"]:
assert "name" in tool_dict
assert "description" in tool_dict
assert "args_schema" in tool_dict
def test_tool_dict_format_compatible_with_llm_providers(
self, mock_gemini_llm, mock_printer, delegation_tools
):
"""Test that extracted tools are in a format compatible with LLM providers.
The tool dictionaries should have 'name', 'description', and 'args_schema'
fields that can be processed by the LLM's _prepare_completion_params method.
"""
mock_context = Mock()
mock_context.tools = delegation_tools
mock_context.messages = [{"role": "user", "content": "test"}]
mock_context.before_llm_call_hooks = []
mock_context.after_llm_call_hooks = []
with patch(
"crewai.utilities.agent_utils._setup_before_llm_call_hooks",
return_value=True,
):
with patch(
"crewai.utilities.agent_utils._setup_after_llm_call_hooks",
return_value="response",
):
get_llm_response(
llm=mock_gemini_llm,
messages=mock_context.messages,
callbacks=[],
printer=mock_printer,
executor_context=mock_context,
)
call_kwargs = mock_gemini_llm.call.call_args[1]
tools = call_kwargs["tools"]
for tool_dict in tools:
# Verify the format matches what extract_tool_info() in common.py expects
assert isinstance(tool_dict["name"], str)
assert isinstance(tool_dict["description"], str)
# args_schema should be a Pydantic model class
assert hasattr(tool_dict["args_schema"], "model_json_schema")

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools."""
__version__ = "1.8.0"
__version__ = "1.8.1"

View File

@@ -117,7 +117,7 @@ show_error_codes = true
warn_unused_ignores = true
python_version = "3.12"
exclude = "(?x)(^lib/crewai/src/crewai/cli/templates/ | ^lib/crewai/tests/ | ^lib/crewai-tools/tests/)"
plugins = ["pydantic.mypy", "crewai.mypy"]
plugins = ["pydantic.mypy"]
[tool.bandit]