Compare commits

...

9 Commits

Author SHA1 Message Date
Devin AI
f12dc9f993 Fix: replace new_data_artifact with new_artifact (available in a2a-sdk 1.0.0)
Co-Authored-By: João <joao@crewai.com>
2026-04-24 16:47:55 +00:00
Devin AI
24d2d8dabb Remove all type: ignore comments to avoid unused-ignore on some Python versions
Co-Authored-By: João <joao@crewai.com>
2026-04-24 16:40:28 +00:00
Devin AI
f5331f3a46 Fix lint: remove unused imports, fix E402 noqa, add type: ignore annotations
Co-Authored-By: João <joao@crewai.com>
2026-04-24 16:34:47 +00:00
Devin AI
6b0ddfa3f2 Add type: ignore for pre-existing protobuf stub issues
Co-Authored-By: João <joao@crewai.com>
2026-04-24 16:24:51 +00:00
Devin AI
856954a311 Fix unused type: ignore comments in _compat.py
Co-Authored-By: João <joao@crewai.com>
2026-04-24 16:22:06 +00:00
Devin AI
ec98238985 Apply ruff format to all modified a2a source files
Co-Authored-By: João <joao@crewai.com>
2026-04-24 16:15:38 +00:00
Devin AI
d05a3415f8 Fix lint: remove unused imports, fix formatting, fix ambiguous char
Co-Authored-By: João <joao@crewai.com>
2026-04-24 16:12:59 +00:00
Devin AI
54470f4932 Fix send_message to use SendMessageRequest wrapper, fix ServerError call
- Add make_send_request() helper in _compat.py for v1.0 API
- Update all handlers to wrap Message in SendMessageRequest
- Fix ServerError(error=...) → ServerError(message) in task.py
- Fix MessageToDict parameter name (always_print_fields_with_no_presence)
- Update integration tests for v1.0 client API (A2ACardResolver, ClientFactory)
- Fix test mocks to use real protobuf Message instead of MagicMock

Co-Authored-By: João <joao@crewai.com>
2026-04-24 16:09:29 +00:00
Devin AI
bec175ec9a Migrate crewai.a2a module to a2a-sdk v1.0.x
Fix #5607: CrewAI 1.14.2 is incompatible with a2a-sdk v1.0.1+

Breaking changes in a2a-sdk v1.0:
- A2AClientHTTPError renamed to A2AClientError
- Protobuf-based types replace Pydantic models
- Enum values changed to SCREAMING_SNAKE_CASE
- TextPart/DataPart/FilePart removed (Part uses oneof)
- AgentCard.url removed (use supported_interfaces)
- StreamResponse wraps all event types
- model_dump/model_copy replaced with protobuf serialization

Changes:
- Add _compat.py: centralized compatibility layer with helpers
- Update pyproject.toml: a2a-sdk>=1.0.0,<2
- Update all a2a module files to use protobuf API
- Update existing tests for v1.0 patterns
- Add comprehensive test_a2a_sdk_v1_compat.py (46 tests)

Co-Authored-By: João <joao@crewai.com>
2026-04-24 16:01:13 +00:00
23 changed files with 6288 additions and 5395 deletions

View File

@@ -100,7 +100,7 @@ anthropic = [
"anthropic~=0.73.0",
]
a2a = [
"a2a-sdk~=0.3.10",
"a2a-sdk>=1.0.0,<2",
"httpx-auth~=0.23.1",
"httpx-sse~=0.4.0",
"aiocache[redis,memcached]~=0.12.3",

View File

@@ -0,0 +1,318 @@
"""Compatibility layer for a2a-sdk v0.3 → v1.0 migration.
Centralizes import aliases and helper functions so the rest of the
a2a module can use a single import regardless of SDK version.
"""
from __future__ import annotations
from typing import Any
from a2a.client.errors import A2AClientError
# ---------------------------------------------------------------------------
# Error re-exports
# In v0.3 the class was called A2AClientHTTPError; v1.0 renamed it to
# A2AClientError. We expose the new name *and* an alias used across the
# codebase so callers can migrate incrementally.
# ---------------------------------------------------------------------------
A2AClientHTTPError = A2AClientError # back-compat alias
# ---------------------------------------------------------------------------
# Type helpers - Protobuf Part access
# In v0.3 Part was a Pydantic discriminated-union with ``part.root.kind``
# and ``part.root.text``; in v1.0 Part is a protobuf message with a
# ``content`` oneof.
# ---------------------------------------------------------------------------
from a2a.types import Part # noqa: E402
def part_is_text(part: Part) -> bool:
"""Return True when the Part carries text content."""
return part.HasField("text")
def part_text(part: Part) -> str:
"""Return the text payload of a Part (assumes text content)."""
return part.text
def part_has_data(part: Part) -> bool:
"""Return True when the Part carries structured data."""
return part.HasField("data")
def part_has_file(part: Part) -> bool:
"""Return True when the Part carries a file (url or raw bytes)."""
return part.HasField("url") or part.HasField("raw")
# ---------------------------------------------------------------------------
# Enum value aliases
# v0.3: TaskState.completed, Role.user (lower snake_case strings)
# v1.0: TaskState.TASK_STATE_COMPLETED, Role.ROLE_USER (SCREAMING_SNAKE_CASE)
# ---------------------------------------------------------------------------
from a2a.types import Role, TaskState # noqa: E402
# TaskState aliases
TASK_STATE_SUBMITTED = TaskState.TASK_STATE_SUBMITTED
TASK_STATE_WORKING = TaskState.TASK_STATE_WORKING
TASK_STATE_COMPLETED = TaskState.TASK_STATE_COMPLETED
TASK_STATE_FAILED = TaskState.TASK_STATE_FAILED
TASK_STATE_CANCELED = TaskState.TASK_STATE_CANCELED
TASK_STATE_INPUT_REQUIRED = TaskState.TASK_STATE_INPUT_REQUIRED
TASK_STATE_AUTH_REQUIRED = TaskState.TASK_STATE_AUTH_REQUIRED
TASK_STATE_REJECTED = TaskState.TASK_STATE_REJECTED
# Role aliases
ROLE_USER = Role.ROLE_USER
ROLE_AGENT = Role.ROLE_AGENT
# ---------------------------------------------------------------------------
# Protobuf object helpers
# Protobuf objects don't have model_dump() / model_copy().
# ---------------------------------------------------------------------------
from google.protobuf.json_format import ( # noqa: E402
MessageToDict,
)
def proto_to_json(msg: Any) -> str:
"""Serialize a protobuf message to a JSON string.
Replaces ``msg.model_dump_json(...)`` from v0.3 Pydantic models.
"""
from google.protobuf.json_format import MessageToJson
return MessageToJson(msg, preserving_proto_field_name=True, indent=2)
def agent_card_to_dict(agent_card: Any, *, exclude_none: bool = True) -> dict[str, Any]:
"""Serialize a protobuf AgentCard to a plain dict.
Works like ``agent_card.model_dump(exclude_none=True)`` did in v0.3.
"""
return MessageToDict(
agent_card,
preserving_proto_field_name=True,
always_print_fields_with_no_presence=not exclude_none,
)
def proto_copy(msg: Any) -> Any:
"""Return a deep copy of a protobuf message (replaces model_copy)."""
new = type(msg)()
new.CopyFrom(msg)
return new
# ---------------------------------------------------------------------------
# Message / Part construction helpers
# v0.3: Message(role=Role.user, parts=[Part(root=TextPart(text=...))])
# v1.0: Message(role=Role.ROLE_USER, parts=[Part(text=...)])
# ---------------------------------------------------------------------------
from a2a.types import Message # noqa: E402
def new_text_part(text: str, **kwargs: Any) -> Part:
"""Create a Part with text content (v1.0 style)."""
return Part(text=text, **kwargs)
def new_text_message(
text: str,
*,
role: Any = ROLE_AGENT,
message_id: str | None = None,
context_id: str | None = None,
task_id: str | None = None,
**kwargs: Any,
) -> Message:
"""Create a Message with a single text Part."""
import uuid as _uuid
return Message(
role=role,
message_id=message_id or str(_uuid.uuid4()),
parts=[Part(text=text)],
context_id=context_id or "",
task_id=task_id or "",
**kwargs,
)
def make_send_request(message: Message) -> Any:
"""Wrap a Message in a SendMessageRequest (v1.0 API).
In v0.3, ``client.send_message(message)`` accepted a bare ``Message``.
In v1.0, it expects ``SendMessageRequest(message=message)``.
"""
from a2a.types import SendMessageRequest
return SendMessageRequest(message=message)
# ---------------------------------------------------------------------------
# AgentCard field access helpers
# v0.3: agent_card.url, agent_card.preferred_transport, agent_card.additional_interfaces
# v1.0: agent_card.supported_interfaces, interface.url, interface.protocol_binding
# ---------------------------------------------------------------------------
from a2a.types import AgentCard, AgentInterface # noqa: E402
def agent_card_url(agent_card: AgentCard) -> str:
"""Get the primary URL from an AgentCard.
In v0.3 this was ``agent_card.url``.
In v1.0 the URL lives inside ``supported_interfaces``.
"""
if agent_card.supported_interfaces:
return agent_card.supported_interfaces[0].url
return ""
def agent_card_preferred_transport(agent_card: AgentCard) -> str:
"""Get the preferred transport protocol from an AgentCard.
In v0.3 this was ``agent_card.preferred_transport``.
In v1.0 it's the protocol_binding of the first supported_interface.
"""
if agent_card.supported_interfaces:
return agent_card.supported_interfaces[0].protocol_binding
return "JSONRPC"
def agent_card_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
"""Get all interfaces from an AgentCard.
In v0.3 these were split between the primary url and
``agent_card.additional_interfaces``.
In v1.0 everything is in ``supported_interfaces``.
"""
return (
list(agent_card.supported_interfaces) if agent_card.supported_interfaces else []
)
def agent_card_protocol_version(agent_card: AgentCard) -> str:
"""Get the protocol version from an AgentCard.
In v0.3 this was ``agent_card.protocol_version``.
In v1.0 it's per-interface in ``interface.protocol_version``.
"""
if agent_card.supported_interfaces:
return agent_card.supported_interfaces[0].protocol_version or ""
return ""
# ---------------------------------------------------------------------------
# StreamResponse helpers
# v0.3: send_message returned AsyncIterator[tuple[Task, Update] | Message]
# v1.0: send_message returns AsyncIterator[StreamResponse]
# ---------------------------------------------------------------------------
from a2a.types import ( # noqa: E402
StreamResponse,
TaskStatusUpdateEvent,
)
def is_stream_message(chunk: StreamResponse) -> bool:
"""Check if a StreamResponse contains a Message."""
return chunk.HasField("message")
def is_stream_task(chunk: StreamResponse) -> bool:
"""Check if a StreamResponse contains a Task."""
return chunk.HasField("task")
def is_stream_status_update(chunk: StreamResponse) -> bool:
"""Check if a StreamResponse contains a TaskStatusUpdateEvent."""
return chunk.HasField("status_update")
def is_stream_artifact_update(chunk: StreamResponse) -> bool:
"""Check if a StreamResponse contains a TaskArtifactUpdateEvent."""
return chunk.HasField("artifact_update")
# ---------------------------------------------------------------------------
# Client configuration helpers
# v0.3: ClientConfig.supported_transports, push_notification_configs (list)
# v1.0: ClientConfig.supported_protocol_bindings, push_notification_config (singular)
# ---------------------------------------------------------------------------
from a2a.client import ClientConfig # noqa: E402
from a2a.types import TaskPushNotificationConfig # noqa: E402
def create_client_config(
*,
httpx_client: Any = None,
supported_transports: list[str] | None = None,
streaming: bool = True,
polling: bool = False,
accepted_output_modes: list[str] | None = None,
push_notification_config: TaskPushNotificationConfig | None = None,
grpc_channel_factory: Any = None,
) -> ClientConfig:
"""Create a ClientConfig compatible with a2a-sdk v1.0."""
return ClientConfig(
httpx_client=httpx_client,
supported_protocol_bindings=supported_transports or ["JSONRPC"],
streaming=streaming,
polling=polling,
accepted_output_modes=accepted_output_modes
or ["text/plain", "application/json"],
push_notification_config=push_notification_config,
grpc_channel_factory=grpc_channel_factory,
)
# ---------------------------------------------------------------------------
# GetTaskRequest / SubscribeToTaskRequest
# v0.3: TaskQueryParams, TaskIdParams
# v1.0: GetTaskRequest, SubscribeToTaskRequest
# ---------------------------------------------------------------------------
from a2a.types import GetTaskRequest, SubscribeToTaskRequest # noqa: E402
# Expose v0.3 names as aliases for the v1.0 types
TaskQueryParams = GetTaskRequest
TaskIdParams = SubscribeToTaskRequest
# ---------------------------------------------------------------------------
# Task status helpers
# v1.0 TaskStatusUpdateEvent no longer has a `final` field. Finality is
# determined by the task state being terminal.
# ---------------------------------------------------------------------------
TERMINAL_STATES: frozenset[int] = frozenset(
{
TASK_STATE_COMPLETED,
TASK_STATE_FAILED,
TASK_STATE_REJECTED,
TASK_STATE_CANCELED,
}
)
def is_status_update_final(update: TaskStatusUpdateEvent) -> bool:
"""Determine if a status update is final.
In v0.3 this was ``update.final``. In v1.0 finality is inferred from
the task state being terminal.
"""
if update.status and update.status.state:
return update.status.state in TERMINAL_STATES
return False

View File

@@ -11,12 +11,13 @@ import re
import threading
from typing import Final, Literal, cast
from a2a.client.errors import A2AClientHTTPError
from a2a.client.errors import A2AClientError
from a2a.types import (
APIKeySecurityScheme,
AgentCard,
HTTPAuthSecurityScheme,
OAuth2SecurityScheme,
SecurityScheme,
)
from httpx import AsyncClient, Response
@@ -112,7 +113,7 @@ def _raise_auth_mismatch(
f"AgentCard requires {required} authentication, "
f"but {type(provided_auth).__name__} was provided"
)
raise A2AClientHTTPError(401, msg)
raise A2AClientError(msg)
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
@@ -159,25 +160,44 @@ def validate_auth_against_agent_card(
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
"""
if not agent_card.security or not agent_card.security_schemes:
if not agent_card.security_requirements or not agent_card.security_schemes:
return
if not auth:
msg = "AgentCard requires authentication but no auth scheme provided"
raise A2AClientHTTPError(401, msg)
raise A2AClientError(msg)
first_security_req = agent_card.security[0] if agent_card.security else {}
first_security_req = (
agent_card.security_requirements[0]
if agent_card.security_requirements
else None
)
if first_security_req is None:
return
for scheme_name in first_security_req.keys():
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
for scheme_name in first_security_req.schemes.keys():
security_scheme_wrapper: SecurityScheme | None = (
agent_card.security_schemes.get(scheme_name)
)
if not security_scheme_wrapper:
continue
scheme = security_scheme_wrapper.root
scheme_field = security_scheme_wrapper.WhichOneof("scheme")
if scheme_field is None:
continue
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
if not isinstance(auth, allowed_classes):
_raise_auth_mismatch(allowed_classes, auth)
scheme = getattr(security_scheme_wrapper, scheme_field)
if isinstance(scheme, OAuth2SecurityScheme):
allowed = _SCHEME_AUTH_MAPPING.get(OAuth2SecurityScheme)
if allowed and not isinstance(auth, allowed):
_raise_auth_mismatch(allowed, auth)
return
if isinstance(scheme, APIKeySecurityScheme):
allowed = _SCHEME_AUTH_MAPPING.get(APIKeySecurityScheme)
if allowed and not isinstance(auth, allowed):
_raise_auth_mismatch(allowed, auth)
return
if isinstance(scheme, HTTPAuthSecurityScheme):
@@ -188,7 +208,7 @@ def validate_auth_against_agent_card(
return
msg = "Could not validate auth against AgentCard security requirements"
raise A2AClientHTTPError(401, msg)
raise A2AClientError(msg)
async def retry_on_401(

View File

@@ -568,7 +568,9 @@ class A2AServerConfig(BaseModel):
auth: Authentication scheme for A2A endpoints.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="forbid", arbitrary_types_allowed=True
)
name: str | None = Field(
default=None,

View File

@@ -10,9 +10,7 @@ See: https://a2a-protocol.org/latest/topics/extensions/
from __future__ import annotations
from typing import Any
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.interceptors import BeforeArgs, ClientCallInterceptor
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
)
@@ -63,30 +61,15 @@ class ExtensionsMiddleware(ClientCallInterceptor):
"""
self._extensions = extensions
async def intercept(
self,
method_name: str,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any],
agent_card: AgentCard | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Add extensions header to the request.
async def before(self, args: BeforeArgs) -> None:
"""Add extensions header before the request is sent.
Args:
method_name: The A2A method being called.
request_payload: The JSON-RPC request payload.
http_kwargs: HTTP request kwargs (headers, etc).
agent_card: The target agent's card.
context: Optional call context.
Returns:
Tuple of (request_payload, modified_http_kwargs).
args: The BeforeArgs containing method, input, agent_card, etc.
"""
if self._extensions:
headers = http_kwargs.setdefault("headers", {})
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
return request_payload, http_kwargs
if self._extensions and isinstance(args.input, dict):
metadata = args.input.setdefault("metadata", {})
metadata[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
def validate_required_extensions(

View File

@@ -4,22 +4,32 @@ from __future__ import annotations
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client.errors import A2AClientHTTPError
from a2a.client.errors import A2AClientError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
StreamResponse,
)
from typing_extensions import NotRequired, TypedDict
from crewai.a2a._compat import (
ROLE_AGENT,
TASK_STATE_AUTH_REQUIRED,
TASK_STATE_CANCELED,
TASK_STATE_COMPLETED,
TASK_STATE_FAILED,
TASK_STATE_INPUT_REQUIRED,
TASK_STATE_REJECTED,
TASK_STATE_SUBMITTED,
TASK_STATE_WORKING,
agent_card_to_dict,
is_stream_message,
is_stream_task,
new_text_message,
part_is_text,
part_text,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
@@ -30,31 +40,29 @@ from crewai.events.types.a2a_events import (
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
SendMessageEvent = (
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
)
SendMessageEvent = StreamResponse
TERMINAL_STATES: frozenset[TaskState] = frozenset(
TERMINAL_STATES: frozenset[int] = frozenset(
{
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
TASK_STATE_COMPLETED,
TASK_STATE_FAILED,
TASK_STATE_REJECTED,
TASK_STATE_CANCELED,
}
)
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
ACTIONABLE_STATES: frozenset[int] = frozenset(
{
TaskState.input_required,
TaskState.auth_required,
TASK_STATE_INPUT_REQUIRED,
TASK_STATE_AUTH_REQUIRED,
}
)
PENDING_STATES: frozenset[TaskState] = frozenset(
PENDING_STATES: frozenset[int] = frozenset(
{
TaskState.submitted,
TaskState.working,
TASK_STATE_SUBMITTED,
TASK_STATE_WORKING,
}
)
@@ -62,7 +70,7 @@ PENDING_STATES: frozenset[TaskState] = frozenset(
class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state."""
status: TaskState
status: int
history: list[Message]
result: NotRequired[str]
error: NotRequired[str]
@@ -83,26 +91,22 @@ def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
result_parts.extend(
part.root.text for part in msg.parts if part.root.kind == "text"
)
result_parts.extend(part_text(part) for part in msg.parts if part_is_text(part))
if not result_parts and a2a_task.history:
for history_msg in reversed(a2a_task.history):
if history_msg.role == Role.agent:
if history_msg.role == ROLE_AGENT:
result_parts.extend(
part.root.text
for part in history_msg.parts
if part.root.kind == "text"
part_text(part) for part in history_msg.parts if part_is_text(part)
)
break
if a2a_task.artifacts:
result_parts.extend(
part.root.text
part_text(part)
for artifact in a2a_task.artifacts
for part in artifact.parts
if part.root.kind == "text"
if part_is_text(part)
)
return result_parts
@@ -122,15 +126,15 @@ def extract_error_message(a2a_task: A2ATask, default: str) -> str:
msg = a2a_task.status.message
if msg:
for part in msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return str(msg)
if part_is_text(part):
return str(part_text(part))
return str(msg)
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
for part in history_msg.parts:
if part.root.kind == "text":
return str(part.root.text)
if part_is_text(part):
return str(part_text(part))
return default
@@ -174,7 +178,7 @@ def process_task_state(
if result_parts is None:
result_parts = []
if a2a_task.status.state == TaskState.completed:
if a2a_task.status.state == TASK_STATE_COMPLETED:
if not result_parts:
extracted_parts = extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
@@ -204,22 +208,21 @@ def process_task_state(
)
return TaskStateResult(
status=TaskState.completed,
agent_card=agent_card.model_dump(exclude_none=True),
status=TASK_STATE_COMPLETED,
agent_card=agent_card_to_dict(agent_card),
result=response_text,
history=new_messages,
)
if a2a_task.status.state == TaskState.input_required:
if a2a_task.status.state == TASK_STATE_INPUT_REQUIRED:
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = extract_error_message(a2a_task, "Additional input required")
if response_text and not a2a_task.history:
agent_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))],
agent_message = new_text_message(
response_text,
role=ROLE_AGENT,
context_id=a2a_task.context_id,
task_id=a2a_task.id,
)
@@ -247,34 +250,34 @@ def process_task_state(
)
return TaskStateResult(
status=TaskState.input_required,
status=TASK_STATE_INPUT_REQUIRED,
error=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
agent_card=agent_card_to_dict(agent_card),
)
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
if a2a_task.status.state in {TASK_STATE_FAILED, TASK_STATE_REJECTED}:
error_msg = extract_error_message(a2a_task, "Task failed without error message")
if a2a_task.history:
new_messages.extend(a2a_task.history)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.auth_required:
if a2a_task.status.state == TASK_STATE_AUTH_REQUIRED:
error_msg = extract_error_message(a2a_task, "Authentication required")
return TaskStateResult(
status=TaskState.auth_required,
status=TASK_STATE_AUTH_REQUIRED,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.canceled:
if a2a_task.status.state == TASK_STATE_CANCELED:
error_msg = extract_error_message(a2a_task, "Task was canceled")
return TaskStateResult(
status=TaskState.canceled,
status=TASK_STATE_CANCELED,
error=error_msg,
history=new_messages,
)
@@ -286,7 +289,7 @@ def process_task_state(
async def send_message_and_get_task_id(
event_stream: AsyncIterator[SendMessageEvent],
event_stream: AsyncIterator[StreamResponse],
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
@@ -321,11 +324,12 @@ async def send_message_and_get_task_id(
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
"""
try:
async for event in event_stream:
if isinstance(event, Message):
async for chunk in event_stream:
if is_stream_message(chunk):
event = chunk.message
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
part_text(part) for part in event.parts if part_is_text(part)
]
response_text = " ".join(result_parts) if result_parts else ""
@@ -348,14 +352,14 @@ async def send_message_and_get_task_id(
)
return TaskStateResult(
status=TaskState.completed,
status=TASK_STATE_COMPLETED,
result=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
agent_card=agent_card_to_dict(agent_card),
)
if isinstance(event, tuple):
a2a_task, _ = event
if is_stream_task(chunk):
a2a_task = chunk.task
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
result = process_task_state(
@@ -376,18 +380,17 @@ async def send_message_and_get_task_id(
return a2a_task.id
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error="No task ID received from initial message",
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
except A2AClientError as e:
error_msg = f"A2A Client Error: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
error_message = new_text_message(
error_msg,
role=ROLE_AGENT,
context_id=context_id,
)
new_messages.append(error_message)
@@ -397,8 +400,7 @@ async def send_message_and_get_task_id(
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="http_error",
status_code=e.status_code,
error_type="client_error",
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
@@ -423,7 +425,7 @@ async def send_message_and_get_task_id(
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)
@@ -431,10 +433,9 @@ async def send_message_and_get_task_id(
except Exception as e:
error_msg = f"Unexpected error during send_message: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
error_message = new_text_message(
error_msg,
role=ROLE_AGENT,
context_id=context_id,
)
new_messages.append(error_message)
@@ -469,7 +470,7 @@ async def send_message_and_get_task_id(
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)

View File

@@ -5,21 +5,21 @@ from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.client.errors import A2AClientError
from a2a.types import (
AgentCard,
GetTaskRequest,
Message,
Part,
Role,
TaskQueryParams,
TaskState,
TextPart,
)
from typing_extensions import Unpack
from crewai.a2a._compat import (
ROLE_AGENT,
TASK_STATE_FAILED,
new_text_message,
)
from crewai.a2a.errors import A2APollingTimeoutError
from crewai.a2a.task_helpers import (
ACTIONABLE_STATES,
@@ -84,7 +84,7 @@ async def _poll_task_until_complete(
while True:
poll_count += 1
task = await client.get_task(
TaskQueryParams(id=task_id, history_length=history_length)
GetTaskRequest(id=task_id, history_length=history_length)
)
elapsed = time.monotonic() - start_time
@@ -94,7 +94,7 @@ async def _poll_task_until_complete(
A2APollingStatusEvent(
task_id=task_id,
context_id=effective_context_id,
state=str(task.status.state.value),
state=str(task.status.state),
elapsed_seconds=elapsed,
poll_count=poll_count,
endpoint=endpoint,
@@ -158,9 +158,11 @@ class PollingHandler:
from_task = kwargs.get("from_task")
from_agent = kwargs.get("from_agent")
from crewai.a2a._compat import make_send_request
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
event_stream=client.send_message(make_send_request(message)),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
@@ -222,7 +224,7 @@ class PollingHandler:
return result
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
@@ -230,10 +232,9 @@ class PollingHandler:
except A2APollingTimeoutError as e:
error_msg = str(e)
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
error_message = new_text_message(
error_msg,
role=ROLE_AGENT,
context_id=context_id,
task_id=task_id,
)
@@ -256,18 +257,17 @@ class PollingHandler:
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
except A2AClientError as e:
error_msg = f"A2A Client Error: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
error_message = new_text_message(
error_msg,
role=ROLE_AGENT,
context_id=context_id,
task_id=task_id,
)
@@ -278,8 +278,7 @@ class PollingHandler:
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="http_error",
status_code=e.status_code,
error_type="client_error",
a2a_agent_name=a2a_agent_name,
operation="polling",
context_id=context_id,
@@ -305,7 +304,7 @@ class PollingHandler:
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)
@@ -313,10 +312,9 @@ class PollingHandler:
except Exception as e:
error_msg = f"Unexpected error during polling: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
error_message = new_text_message(
error_msg,
role=ROLE_AGENT,
context_id=context_id,
task_id=task_id,
)
@@ -353,7 +351,7 @@ class PollingHandler:
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)

View File

@@ -2,9 +2,8 @@
from __future__ import annotations
from typing import Annotated
from typing import Annotated, Any
from a2a.types import PushNotificationAuthenticationInfo
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
from crewai.a2a.updates.base import PushNotificationResultStore
@@ -46,7 +45,7 @@ class PushNotificationConfig(BaseModel):
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
id: str | None = Field(default=None, description="Unique config identifier")
token: str | None = Field(default=None, description="Validation token")
authentication: PushNotificationAuthenticationInfo | None = Field(
authentication: Any | None = Field(
default=None, description="Auth info for agent to use when calling webhook"
)
timeout: float | None = Field(

View File

@@ -4,20 +4,20 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.client.errors import A2AClientError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskState,
TextPart,
)
from typing_extensions import Unpack
from crewai.a2a._compat import (
ROLE_AGENT,
TASK_STATE_FAILED,
new_text_message,
)
from crewai.a2a.task_helpers import (
TaskStateResult,
process_task_state,
@@ -69,10 +69,9 @@ def _handle_push_error(
Returns:
TaskStateResult with failed status.
"""
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
error_message = new_text_message(
error_msg,
role=ROLE_AGENT,
context_id=params.context_id,
task_id=task_id,
)
@@ -110,7 +109,7 @@ def _handle_push_error(
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)
@@ -221,7 +220,7 @@ class PushNotificationHandler:
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)
@@ -245,14 +244,16 @@ class PushNotificationHandler:
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)
from crewai.a2a._compat import make_send_request
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
event_stream=client.send_message(make_send_request(message)),
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
@@ -304,7 +305,7 @@ class PushNotificationHandler:
if final_task is None:
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=f"Push notification timeout after {polling_timeout}s",
history=new_messages,
)
@@ -325,21 +326,20 @@ class PushNotificationHandler:
return result
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
except A2AClientHTTPError as e:
except A2AClientError as e:
return _handle_push_error(
error=e,
error_msg=f"HTTP Error {e.status_code}: {e!s}",
error_type="http_error",
error_msg=f"A2A Client Error: {e!s}",
error_type="client_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
status_code=e.status_code,
)
except Exception as e:

View File

@@ -5,25 +5,30 @@ from __future__ import annotations
import asyncio
import logging
from typing import Final
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.client.errors import A2AClientError
from a2a.types import (
AgentCard,
GetTaskRequest,
Message,
Part,
Role,
SubscribeToTaskRequest,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskQueryParams,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from typing_extensions import Unpack
from crewai.a2a._compat import (
ROLE_AGENT,
TASK_STATE_FAILED,
is_stream_artifact_update,
is_stream_message,
is_stream_status_update,
is_stream_task,
new_text_message,
part_is_text,
part_text,
)
from crewai.a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
@@ -50,6 +55,16 @@ MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
def _extract_text_from_artifact(artifact: TaskArtifactUpdateEvent) -> list[str]:
"""Extract text parts from an artifact update event."""
parts: list[str] = []
if artifact.artifact and artifact.artifact.parts:
parts.extend(
part_text(part) for part in artifact.artifact.parts if part_is_text(part)
)
return parts
class StreamingHandler:
"""SSE streaming-based update handler."""
@@ -86,7 +101,7 @@ class StreamingHandler:
params = extract_common_params(kwargs) # type: ignore[arg-type]
try:
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
a2a_task: Task = await client.get_task(GetTaskRequest(id=task_id))
if a2a_task.status.state in TERMINAL_STATES:
logger.info(
@@ -138,26 +153,18 @@ class StreamingHandler:
if attempt > 0:
await asyncio.sleep(backoff)
event_stream = client.resubscribe(TaskIdParams(id=task_id))
event_stream = client.subscribe(SubscribeToTaskRequest(id=task_id))
async for event in event_stream:
if isinstance(event, tuple):
resubscribed_task, update = event
async for chunk in event_stream:
if is_stream_task(chunk):
resubscribed_task = chunk.task
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
if is_stream_status_update(chunk):
update = chunk.status_update
is_final_update = process_status_update(
update, result_parts
)
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
if (
is_final_update
or resubscribed_task.status.state
@@ -178,15 +185,20 @@ class StreamingHandler:
is_final=is_final_update,
)
elif isinstance(event, Message):
new_messages.append(event)
if is_stream_artifact_update(chunk):
artifact = chunk.artifact_update
result_parts.extend(_extract_text_from_artifact(artifact))
if is_stream_message(chunk):
msg = chunk.message
new_messages.append(msg)
result_parts.extend(
part.root.text
for part in event.parts
if part.root.kind == "text"
part_text(part)
for part in msg.parts
if part_is_text(part)
)
final_task = await client.get_task(TaskQueryParams(id=task_id))
final_task = await client.get_task(GetTaskRequest(id=task_id))
return process_task_state(
a2a_task=final_task,
new_messages=new_messages,
@@ -258,9 +270,12 @@ class StreamingHandler:
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
from crewai.a2a._compat import make_send_request
event_stream = client.send_message(make_send_request(message))
chunk_index = 0
current_task_id: str | None = task_id
current_task: Task | None = None
crewai_event_bus.emit(
agent_branch,
@@ -278,22 +293,25 @@ class StreamingHandler:
)
try:
async for event in event_stream:
if isinstance(event, tuple):
a2a_task, _ = event
current_task_id = a2a_task.id
async for chunk in event_stream:
# Extract task from task payload
if is_stream_task(chunk):
current_task = chunk.task
current_task_id = current_task.id
if isinstance(event, Message):
new_messages.append(event)
message_context_id = event.context_id or params.context_id
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
# Handle standalone message responses
if is_stream_message(chunk):
msg = chunk.message
new_messages.append(msg)
message_context_id = msg.context_id or params.context_id
for part in msg.parts:
if part_is_text(part):
text = part_text(part)
result_parts.append(text)
crewai_event_bus.emit(
agent_branch,
A2AStreamingChunkEvent(
task_id=event.task_id or task_id,
task_id=msg.task_id or task_id,
context_id=message_context_id,
chunk=text,
chunk_index=chunk_index,
@@ -307,38 +325,40 @@ class StreamingHandler:
)
chunk_index += 1
elif isinstance(event, tuple):
a2a_task, update = event
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
# Handle artifact updates
elif is_stream_artifact_update(chunk):
artifact_update = chunk.artifact_update
artifact = artifact_update.artifact
if artifact and artifact.parts:
result_parts.extend(
part.root.text
part_text(part)
for part in artifact.parts
if part.root.kind == "text"
if part_is_text(part)
)
artifact_size = None
if artifact.parts:
artifact_size = sum(
len(p.root.text.encode())
if p.root.kind == "text"
else len(getattr(p.root, "data", b""))
len(part_text(p).encode())
if part_is_text(p)
else len(getattr(p, "raw", b""))
for p in artifact.parts
)
effective_context_id = a2a_task.context_id or params.context_id
effective_context_id = (
current_task.context_id if current_task else None
) or params.context_id
crewai_event_bus.emit(
agent_branch,
A2AArtifactReceivedEvent(
task_id=a2a_task.id,
task_id=artifact_update.task_id or current_task_id,
artifact_id=artifact.artifact_id,
artifact_name=artifact.name,
artifact_description=artifact.description,
mime_type=artifact.parts[0].root.kind
if artifact.parts
mime_type="text"
if artifact.parts and part_is_text(artifact.parts[0])
else None,
size_bytes=artifact_size,
append=update.append or False,
last_chunk=update.last_chunk or False,
append=artifact_update.append or False,
last_chunk=artifact_update.last_chunk or False,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=effective_context_id,
@@ -349,86 +369,63 @@ class StreamingHandler:
),
)
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
# Handle status updates
elif is_stream_status_update(chunk):
update = chunk.status_update
is_final_update = process_status_update(update, result_parts)
if (
not is_final_update
and a2a_task.status.state
not in TERMINAL_STATES | ACTIONABLE_STATES
if current_task and (
is_final_update
or current_task.status.state
in TERMINAL_STATES | ACTIONABLE_STATES
):
final_result = process_task_state(
a2a_task=current_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
elif not current_task and is_final_update:
pass
else:
continue
final_result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
if final_result:
break
except A2AClientError as e:
logger.warning(
"Stream interrupted",
extra={
"task_id": current_task_id,
"error": str(e),
"error_type": type(e).__name__,
},
)
except A2AClientHTTPError as e:
if current_task_id:
logger.info(
"Stream interrupted with HTTP error, attempting recovery",
extra={
"task_id": current_task_id,
"error": str(e),
"status_code": e.status_code,
},
recovery_result = await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**kwargs,
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
"Successfully recovered task after HTTP error",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
if recovery_result:
return recovery_result
logger.warning(
"Failed to recover from HTTP error, returning failure",
extra={
"task_id": current_task_id,
"status_code": e.status_code,
"original_error": str(e),
},
)
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_type = "http_error"
status_code = e.status_code
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
error_msg = f"A2A Client Error: {e!s}"
error_message = new_text_message(
error_msg,
role=ROLE_AGENT,
context_id=params.context_id,
task_id=task_id,
task_id=current_task_id,
)
new_messages.append(error_message)
@@ -437,12 +434,11 @@ class StreamingHandler:
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
error_type="client_error",
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
task_id=current_task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
@@ -464,116 +460,40 @@ class StreamingHandler:
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
error_type = type(e).__name__.lower()
if current_task_id:
logger.info(
f"Stream interrupted with {error_type}, attempting recovery",
extra={"task_id": current_task_id, "error": str(e)},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
f"Successfully recovered task after {error_type}",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
f"Failed to recover from {error_type}, returning failure",
extra={
"task_id": current_task_id,
"error_type": error_type,
"original_error": str(e),
},
)
error_msg = f"Connection error during streaming: {e!s}"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)
except Exception as e:
logger.exception(
"Unexpected error during streaming",
logger.warning(
"Unexpected stream error",
extra={
"task_id": current_task_id,
"error": str(e),
"error_type": type(e).__name__,
"endpoint": params.endpoint,
},
exc_info=True,
)
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
error_type = "unexpected_error"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
if current_task_id:
recovery_result = await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**kwargs,
)
if recovery_result:
return recovery_result
error_msg = f"Unexpected error during streaming: {e!s}"
error_message = new_text_message(
error_msg,
role=ROLE_AGENT,
context_id=params.context_id,
task_id=task_id,
task_id=current_task_id,
)
new_messages.append(error_message)
@@ -582,12 +502,11 @@ class StreamingHandler:
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
error_type="unexpected_error",
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
task_id=current_task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
@@ -609,38 +528,33 @@ class StreamingHandler:
),
)
return TaskStateResult(
status=TaskState.failed,
status=TASK_STATE_FAILED,
error=error_msg,
history=new_messages,
)
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:
try:
await aclose()
except Exception as close_error:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(close_error),
error_type="stream_close_error",
a2a_agent_name=params.a2a_agent_name,
operation="stream_close",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
if final_result:
return final_result
return TaskStateResult(
status=TaskState.completed,
result=" ".join(result_parts) if result_parts else "",
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=response_text,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="completed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TASK_STATE_FAILED,
error="Stream ended without terminal state",
history=new_messages,
)

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from a2a.types import TaskStatusUpdateEvent
from crewai.a2a._compat import is_status_update_final, part_is_text, part_text
def process_status_update(
update: TaskStatusUpdateEvent,
@@ -18,11 +20,11 @@ def process_status_update(
Returns:
True if this is a final update, False otherwise.
"""
is_final = update.final
is_final = is_status_update_final(update)
if update.status and update.status.message and update.status.message.parts:
result_parts.extend(
part.root.text
part_text(part)
for part in update.status.message.parts
if part.root.kind == "text" and part.root.text
if part_is_text(part) and part_text(part)
)
return is_final

View File

@@ -12,12 +12,18 @@ import time
from types import MethodType
from typing import TYPE_CHECKING
from a2a.client.errors import A2AClientHTTPError
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from a2a.client.errors import A2AClientError
from a2a.types import AgentCapabilities, AgentCard, AgentInterface, AgentSkill
from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
from google.protobuf.json_format import ParseDict
import httpx
from crewai.a2a._compat import (
agent_card_protocol_version,
agent_card_to_dict,
proto_copy,
)
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
@@ -277,9 +283,9 @@ async def _afetch_agent_card_impl(
)
response.raise_for_status()
agent_card = AgentCard.model_validate(response.json())
agent_card = ParseDict(response.json(), AgentCard())
fetch_time_ms = (time.perf_counter() - start_time) * 1000
agent_card_dict = agent_card.model_dump(exclude_none=True)
agent_card_dict = agent_card_to_dict(agent_card)
crewai_event_bus.emit(
None,
@@ -287,7 +293,7 @@ async def _afetch_agent_card_impl(
endpoint=endpoint,
a2a_agent_name=agent_card.name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
protocol_version=agent_card_protocol_version(agent_card),
provider=agent_card_dict.get("provider"),
cached=False,
fetch_time_ms=fetch_time_ms,
@@ -326,7 +332,7 @@ async def _afetch_agent_card_impl(
),
)
raise A2AClientHTTPError(401, msg) from e
raise A2AClientError(msg) from e
crewai_event_bus.emit(
None,
@@ -470,7 +476,9 @@ def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard:
return AgentCard(
name=crew_name,
description=" ".join(description_parts),
url=url,
supported_interfaces=[
AgentInterface(url=url, protocol_binding="JSONRPC"),
],
version="1.0.0",
capabilities=AgentCapabilities(
streaming=True,
@@ -540,28 +548,43 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
if ext.uri not in existing_uris:
existing_exts.append(ext)
capabilities = capabilities.model_copy(update={"extensions": existing_exts})
capabilities = proto_copy(capabilities)
del capabilities.extensions[:]
capabilities.extensions.extend(existing_exts)
primary_interface = AgentInterface(
url=server_config.url or url,
protocol_binding=server_config.transport.preferred or "JSONRPC",
protocol_version=server_config.protocol_version or "",
)
interfaces = [primary_interface]
if server_config.additional_interfaces:
interfaces.extend(server_config.additional_interfaces)
card = AgentCard(
name=name,
description=description,
url=server_config.url or url,
supported_interfaces=interfaces,
version=server_config.version,
capabilities=capabilities,
default_input_modes=server_config.default_input_modes,
default_output_modes=server_config.default_output_modes,
skills=skills,
preferred_transport=server_config.transport.preferred,
protocol_version=server_config.protocol_version,
provider=server_config.provider,
documentation_url=server_config.documentation_url,
icon_url=server_config.icon_url,
additional_interfaces=server_config.additional_interfaces,
security=server_config.security,
security_schemes=server_config.security_schemes,
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
documentation_url=server_config.documentation_url or "",
icon_url=server_config.icon_url or "",
)
if server_config.provider:
card.provider.CopyFrom(server_config.provider)
if server_config.security_schemes:
for k, v in server_config.security_schemes.items():
card.security_schemes[k].CopyFrom(v)
if server_config.security:
for req in server_config.security:
card.security_requirements.append(req)
if server_config.signing_config:
signature = sign_agent_card(
card,
@@ -569,9 +592,11 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
key_id=server_config.signing_config.key_id,
algorithm=server_config.signing_config.algorithm,
)
card = card.model_copy(update={"signatures": [signature]})
del card.signatures[:]
card.signatures.append(signature)
elif server_config.signatures:
card = card.model_copy(update={"signatures": server_config.signatures})
del card.signatures[:]
card.signatures.extend(server_config.signatures)
return card

View File

@@ -18,6 +18,7 @@ import logging
from typing import Any, Literal
from a2a.types import AgentCard, AgentCardSignature
from google.protobuf.json_format import MessageToDict
import jwt
from pydantic import SecretStr
@@ -58,7 +59,11 @@ def _serialize_agent_card(agent_card: AgentCard) -> str:
Returns:
Canonical JSON string representation.
"""
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
card_dict = MessageToDict(
agent_card,
preserving_proto_field_name=True,
)
card_dict.pop("signatures", None)
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))

View File

@@ -264,7 +264,7 @@ def negotiate_content_types(
crewai_event_bus.emit(
None,
A2AContentTypeNegotiatedEvent(
endpoint=endpoint or agent_card.url,
endpoint=endpoint or "",
a2a_agent_name=a2a_agent_name or agent_card.name,
skill_name=skill_name,
client_input_modes=client_input_modes,
@@ -303,22 +303,21 @@ def get_part_content_type(part: Part) -> str:
"""Extract MIME type from an A2A Part.
Args:
part: A Part object containing TextPart, DataPart, or FilePart.
part: A Part object (protobuf oneof: text, data, raw, url).
Returns:
The MIME type string for this part.
"""
root = part.root
if root.kind == "text":
if part.HasField("text"):
return TEXT_PLAIN
if root.kind == "data":
metadata = root.metadata or {}
if part.HasField("data"):
metadata = dict(part.metadata) if part.metadata else {}
mime = metadata.get("mimeType", "")
if mime == APPLICATION_A2UI_JSON:
return APPLICATION_A2UI_JSON
return APPLICATION_JSON
if root.kind == "file":
return root.file.mime_type or APPLICATION_OCTET_STREAM
if part.HasField("raw") or part.HasField("url"):
return part.media_type or APPLICATION_OCTET_STREAM
return APPLICATION_OCTET_STREAM

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import asyncio
import base64
from collections.abc import AsyncIterator, Callable, MutableMapping
import concurrent.futures
from contextlib import asynccontextmanager
@@ -12,20 +11,27 @@ import logging
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
from a2a.client import Client, ClientConfig, ClientFactory
from a2a.client import Client, ClientFactory
from a2a.types import (
AgentCard,
FilePart,
FileWithBytes,
Message,
Part,
PushNotificationConfig as A2APushNotificationConfig,
Role,
TextPart,
TaskPushNotificationConfig,
)
import httpx
from pydantic import BaseModel
from crewai.a2a._compat import (
ROLE_USER,
agent_card_interfaces,
agent_card_protocol_version,
agent_card_to_dict,
agent_card_url,
create_client_config,
new_text_part,
proto_copy,
)
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
_auth_store,
@@ -41,8 +47,6 @@ from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.types import (
HANDLER_REGISTRY,
HandlerType,
PartsDict,
PartsMetadataDict,
TransportType,
)
from crewai.a2a.updates import (
@@ -107,13 +111,13 @@ def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]:
parts: list[Part] = []
for name, file_input in input_files.items():
content_bytes = file_input.read()
content_base64 = base64.b64encode(content_bytes).decode()
file_with_bytes = FileWithBytes(
bytes=content_base64,
mimeType=file_input.content_type,
name=file_input.filename or name,
parts.append(
Part(
raw=content_bytes,
media_type=file_input.content_type or "application/octet-stream",
filename=file_input.filename or name,
)
)
parts.append(Part(root=FilePart(file=file_with_bytes)))
return parts
@@ -301,7 +305,7 @@ async def aexecute_a2a_delegation(
is_multiturn = len(conversation_history) > 0
if turn_number is None:
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
turn_number = len([m for m in conversation_history if m.role == ROLE_USER]) + 1
try:
result = await _aexecute_a2a_delegation_impl(
@@ -349,7 +353,7 @@ async def aexecute_a2a_delegation(
)
raise
agent_card_data = result.get("agent_card")
agent_card_data: dict[str, Any] | None = result.get("agent_card")
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
@@ -423,14 +427,14 @@ async def _aexecute_a2a_delegation_impl(
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
if unsupported_exts:
ext_uris = [ext.uri for ext in unsupported_exts]
ext_uris = [e.uri for e in unsupported_exts]
raise ValueError(
f"Agent requires extensions not supported by client: {ext_uris}"
)
negotiated: NegotiatedTransport | None = None
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
effective_url = endpoint
effective_url = agent_card_url(agent_card) or endpoint
client_transports: list[str] = (
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
@@ -456,9 +460,9 @@ async def _aexecute_a2a_delegation_impl(
"endpoint": endpoint,
"client_transports": client_transports,
"server_transports": [
iface.transport for iface in agent_card.additional_interfaces or []
]
+ [agent_card.preferred_transport or "JSONRPC"],
iface.protocol_binding
for iface in agent_card_interfaces(agent_card)
],
},
)
@@ -476,11 +480,9 @@ async def _aexecute_a2a_delegation_impl(
headers, _ = await _prepare_auth_headers(auth, timeout)
a2a_agent_name = None
if agent_card.name:
a2a_agent_name = agent_card.name
a2a_agent_name = agent_card.name or None
agent_card_dict = agent_card.model_dump(exclude_none=True)
agent_card_dict = agent_card_to_dict(agent_card)
crewai_event_bus.emit(
agent_branch,
A2ADelegationStartedEvent(
@@ -492,7 +494,7 @@ async def _aexecute_a2a_delegation_impl(
turn_number=turn_number,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
protocol_version=agent_card_protocol_version(agent_card),
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
metadata=metadata,
@@ -512,7 +514,7 @@ async def _aexecute_a2a_delegation_impl(
context_id=context_id,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
protocol_version=agent_card_protocol_version(agent_card),
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
reference_task_ids=reference_task_ids,
@@ -534,26 +536,18 @@ async def _aexecute_a2a_delegation_impl(
if first_task_id := conversation_history[0].task_id:
task_id = first_task_id
parts: PartsDict = {"text": message_text}
if response_model:
parts.update(
{
"metadata": PartsMetadataDict(
mimeType="application/json",
schema=response_model.model_json_schema(),
)
}
)
message_metadata = metadata.copy() if metadata else {}
if skill_id:
message_metadata["skill_id"] = skill_id
if response_model:
message_metadata["mimeType"] = "application/json"
message_metadata["schema"] = response_model.model_json_schema()
parts_list: list[Part] = [Part(root=TextPart(**parts))]
parts_list: list[Part] = [new_text_part(message_text)]
parts_list.extend(_create_file_parts(input_files))
message = Message(
role=Role.user,
role=ROLE_USER,
message_id=str(uuid.uuid4()),
parts=parts_list,
context_id=context_id,
@@ -625,8 +619,11 @@ async def _aexecute_a2a_delegation_impl(
use_streaming = not use_polling and push_config_for_client is None
client_agent_card = agent_card
if effective_url != agent_card.url:
client_agent_card = agent_card.model_copy(update={"url": effective_url})
card_url = agent_card_url(agent_card)
if effective_url != card_url:
client_agent_card = proto_copy(agent_card)
if client_agent_card.supported_interfaces:
client_agent_card.supported_interfaces[0].url = effective_url
async with _create_a2a_client(
agent_card=client_agent_card,
@@ -649,7 +646,7 @@ async def _aexecute_a2a_delegation_impl(
**handler_kwargs,
)
result["a2a_agent_name"] = a2a_agent_name
result["agent_card"] = agent_card.model_dump(exclude_none=True)
result["agent_card"] = agent_card_to_dict(agent_card)
return result
@@ -933,15 +930,12 @@ async def _create_a2a_client(
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, httpx_client)
push_configs: list[A2APushNotificationConfig] = []
push_config: TaskPushNotificationConfig | None = None
if push_notification_config is not None:
push_configs.append(
A2APushNotificationConfig(
url=str(push_notification_config.url),
id=push_notification_config.id,
token=push_notification_config.token,
authentication=push_notification_config.authentication,
)
push_config = TaskPushNotificationConfig(
url=str(push_notification_config.url),
id=push_notification_config.id or "",
token=push_notification_config.token or "",
)
grpc_channel_factory = None
@@ -951,13 +945,14 @@ async def _create_a2a_client(
auth=auth,
)
config = ClientConfig(
config = create_client_config(
httpx_client=httpx_client,
supported_transports=[transport_protocol],
streaming=streaming and not use_polling,
polling=use_polling,
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
push_notification_configs=push_configs,
accepted_output_modes=accepted_output_modes
or list(DEFAULT_CLIENT_OUTPUT_MODES),
push_notification_config=push_config,
grpc_channel_factory=grpc_channel_factory,
)
@@ -965,6 +960,6 @@ async def _create_a2a_client(
client = factory.create(agent_card)
if client_extensions:
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
await client.add_interceptor(ExtensionsMiddleware(client_extensions))
yield client

View File

@@ -13,13 +13,15 @@ import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from urllib.parse import urlparse
from a2a.helpers.proto_helpers import (
new_artifact,
new_text_artifact,
new_text_message as new_agent_text_message,
)
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import (
Artifact,
FileWithBytes,
FileWithUri,
InternalError,
InvalidParamsError,
Message,
Part,
@@ -28,18 +30,20 @@ from a2a.types import (
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils import (
get_data_parts,
get_file_parts,
new_agent_text_message,
new_data_artifact,
new_text_artifact,
)
from a2a.utils.errors import ServerError
from a2a.utils.errors import A2AError as ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from pydantic import BaseModel
from typing_extensions import TypedDict
from crewai.a2a._compat import (
TASK_STATE_CANCELED,
TASK_STATE_COMPLETED,
TASK_STATE_FAILED,
part_has_data,
part_has_file,
part_is_text,
proto_copy,
)
from crewai.a2a.utils.agent_card import _get_server_config
from crewai.a2a.utils.content_type import validate_message_parts
from crewai.events.event_bus import crewai_event_bus
@@ -64,6 +68,28 @@ P = ParamSpec("P")
T = TypeVar("T")
def _get_data_parts(parts: list[Part]) -> list[dict[str, Any]]:
"""Extract data parts from a list of protobuf Parts.
In a2a-sdk v1.0, data is stored via the ``data`` oneof field on Part
(a ``google.protobuf.Value``).
"""
result: list[dict[str, Any]] = []
for part in parts:
if part_has_data(part):
from google.protobuf.json_format import MessageToDict
val = MessageToDict(part.data)
if isinstance(val, dict):
result.append(val)
return result
def _get_file_parts(parts: list[Part]) -> list[Part]:
"""Return parts that carry file content (raw bytes or url)."""
return [p for p in parts if part_has_file(p)]
class RedisCacheConfig(TypedDict, total=False):
"""Configuration for aiocache Redis backend."""
@@ -196,12 +222,12 @@ def cancellable(
def _convert_a2a_files_to_file_inputs(
a2a_files: list[FileWithBytes | FileWithUri],
a2a_files: list[Part],
) -> dict[str, Any]:
"""Convert a2a file types to crewai FileInput dict.
"""Convert a2a file parts to crewai FileInput dict.
Args:
a2a_files: List of FileWithBytes or FileWithUri from a2a SDK.
a2a_files: List of Parts that carry file content (raw or url).
Returns:
Dictionary mapping file names to FileInput objects.
@@ -213,15 +239,13 @@ def _convert_a2a_files_to_file_inputs(
return {}
file_dict: dict[str, Any] = {}
for idx, a2a_file in enumerate(a2a_files):
if isinstance(a2a_file, FileWithBytes):
file_bytes = base64.b64decode(a2a_file.bytes)
name = a2a_file.name or f"file_{idx}"
file_source = FileBytes(data=file_bytes, filename=a2a_file.name)
for idx, part in enumerate(a2a_files):
name = part.filename or f"file_{idx}"
if part.HasField("raw"):
file_source = FileBytes(data=part.raw, filename=name)
file_dict[name] = File(source=file_source)
elif isinstance(a2a_file, FileWithUri):
name = a2a_file.name or f"file_{idx}"
file_dict[name] = File(source=a2a_file.uri)
elif part.HasField("url"):
file_dict[name] = File(source=part.url)
return file_dict
@@ -239,8 +263,9 @@ def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
JSON schema dict if found, None otherwise.
"""
for part in parts:
if part.root.kind == "text" and part.root.metadata:
schema = part.root.metadata.get("schema")
if part_is_text(part) and part.metadata:
metadata_dict = dict(part.metadata)
schema = metadata_dict.get("schema")
if schema and isinstance(schema, dict):
return schema # type: ignore[no-any-return]
return None
@@ -261,9 +286,17 @@ def _create_result_artifact(
"""
artifact_name = f"result_{task_id}"
if isinstance(result, dict):
return new_data_artifact(artifact_name, result)
from google.protobuf import struct_pb2
val = struct_pb2.Value()
val.struct_value.update(result)
return new_artifact([Part(data=val)], artifact_name)
if isinstance(result, BaseModel):
return new_data_artifact(artifact_name, result.model_dump())
from google.protobuf import struct_pb2
val = struct_pb2.Value()
val.struct_value.update(result.model_dump())
return new_artifact([Part(data=val)], artifact_name)
return new_text_artifact(artifact_name, str(result))
@@ -330,7 +363,7 @@ async def _execute_impl(
response_model: type[BaseModel] | None = None
structured_inputs: list[dict[str, Any]] = []
a2a_files: list[FileWithBytes | FileWithUri] = []
a2a_files: list[Part] = []
if context.message and context.message.parts:
schema = _extract_response_schema(context.message.parts)
@@ -343,8 +376,8 @@ async def _execute_impl(
extra={"error": str(e), "schema_title": schema.get("title")},
)
structured_inputs = get_data_parts(context.message.parts)
a2a_files = get_file_parts(context.message.parts)
structured_inputs = _get_data_parts(context.message.parts)
a2a_files = _get_file_parts(context.message.parts)
task_id = context.task_id
context_id = context.context_id
@@ -387,12 +420,14 @@ async def _execute_impl(
)
result_str = str(result)
history: list[Message] = [context.message] if context.message else []
history.append(new_agent_text_message(result_str, context_id, task_id))
history.append(
new_agent_text_message(result_str, context_id=context_id, task_id=task_id)
)
await event_queue.enqueue_event(
A2ATask(
id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.completed),
status=TaskStatus(state=TASK_STATE_COMPLETED),
artifacts=[_create_result_artifact(result, task_id)],
history=history,
)
@@ -429,9 +464,7 @@ async def _execute_impl(
from_agent=agent,
),
)
raise ServerError(
error=InternalError(message=f"Task execution failed: {e}")
) from e
raise ServerError(f"Task execution failed: {e}") from e
async def execute_with_extensions(
@@ -476,9 +509,9 @@ async def cancel(
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
if context.current_task and context.current_task.status.state in (
TaskState.completed,
TaskState.failed,
TaskState.canceled,
TASK_STATE_COMPLETED,
TASK_STATE_FAILED,
TASK_STATE_CANCELED,
):
return context.current_task
@@ -492,13 +525,12 @@ async def cancel(
TaskStatusUpdateEvent(
task_id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.canceled),
final=True,
status=TaskStatus(state=TASK_STATE_CANCELED),
)
)
if context.current_task:
context.current_task.status = TaskStatus(state=TaskState.canceled)
context.current_task.status.CopyFrom(TaskStatus(state=TASK_STATE_CANCELED))
return context.current_task
return None
@@ -571,7 +603,7 @@ def list_tasks(
result: list[A2ATask] = []
for task in page:
task = task.model_copy(deep=True)
task = proto_copy(task)
if history_length is not None and task.history:
task.history = task.history[-history_length:]
if not include_artifacts:

View File

@@ -12,6 +12,11 @@ from typing import Final, Literal
from a2a.types import AgentCard, AgentInterface
from crewai.a2a._compat import (
agent_card_interfaces,
agent_card_preferred_transport,
agent_card_url,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
@@ -85,23 +90,21 @@ def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
List of AgentInterface objects representing all available endpoints.
"""
interfaces: list[AgentInterface] = []
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
interfaces.append(
AgentInterface(
transport=primary_transport,
url=agent_card.url,
for interface in agent_card_interfaces(agent_card):
is_duplicate = any(
i.url == interface.url and i.protocol_binding == interface.protocol_binding
for i in interfaces
)
)
if not is_duplicate:
interfaces.append(interface)
if agent_card.additional_interfaces:
for interface in agent_card.additional_interfaces:
is_duplicate = any(
i.url == interface.url and i.transport == interface.transport
for i in interfaces
if not interfaces:
interfaces.append(
AgentInterface(
url=agent_card_url(agent_card),
protocol_binding=JSONRPC_TRANSPORT,
)
if not is_duplicate:
interfaces.append(interface)
)
return interfaces
@@ -149,11 +152,11 @@ def negotiate_transport(
)
server_interfaces = _get_server_interfaces(agent_card)
server_transports = [i.transport.upper() for i in server_interfaces]
server_transports = [i.protocol_binding.upper() for i in server_interfaces]
transport_to_interface: dict[str, AgentInterface] = {}
for interface in server_interfaces:
transport_upper = interface.transport.upper()
transport_upper = interface.protocol_binding.upper()
if transport_upper not in transport_to_interface:
transport_to_interface[transport_upper] = interface
@@ -162,19 +165,21 @@ def negotiate_transport(
if client_preferred and client_preferred in transport_to_interface:
interface = transport_to_interface[client_preferred]
result = NegotiatedTransport(
transport=interface.transport,
transport=interface.protocol_binding,
url=interface.url,
source="client_preferred",
)
else:
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
server_preferred = (
agent_card_preferred_transport(agent_card) or JSONRPC_TRANSPORT
).upper()
if (
server_preferred in client_transports
and server_preferred in transport_to_interface
):
interface = transport_to_interface[server_preferred]
result = NegotiatedTransport(
transport=interface.transport,
transport=interface.protocol_binding,
url=interface.url,
source="server_preferred",
)
@@ -183,7 +188,7 @@ def negotiate_transport(
if transport in transport_to_interface:
interface = transport_to_interface[transport]
result = NegotiatedTransport(
transport=interface.transport,
transport=interface.protocol_binding,
url=interface.url,
source="fallback",
)
@@ -199,14 +204,14 @@ def negotiate_transport(
crewai_event_bus.emit(
None,
A2ATransportNegotiatedEvent(
endpoint=endpoint or agent_card.url,
endpoint=endpoint or agent_card_url(agent_card),
a2a_agent_name=a2a_agent_name or agent_card.name,
negotiated_transport=result.transport,
negotiated_url=result.url,
source=result.source,
client_supported_transports=client_transports,
server_supported_transports=server_transports,
server_preferred_transport=agent_card.preferred_transport
server_preferred_transport=agent_card_preferred_transport(agent_card)
or JSONRPC_TRANSPORT,
client_preferred_transport=client_preferred,
),

View File

@@ -14,9 +14,18 @@ import json
from types import MethodType
from typing import TYPE_CHECKING, Any, NamedTuple
from a2a.types import Role, TaskState
from pydantic import BaseModel, ValidationError
from crewai.a2a._compat import (
ROLE_AGENT,
ROLE_USER,
TASK_STATE_COMPLETED,
TASK_STATE_INPUT_REQUIRED,
agent_card_to_dict,
part_is_text,
part_text,
proto_to_json,
)
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.base import (
A2AExtension,
@@ -681,7 +690,7 @@ def _augment_prompt_with_a2a(
}
agents_text += f"\n{json.dumps(filtered, indent=2)}\n"
else:
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
agents_text += f"\n{proto_to_json(card)}\n"
failed_agents = failed_agents or {}
if failed_agents:
@@ -695,7 +704,7 @@ def _augment_prompt_with_a2a(
if conversation_history:
for msg in conversation_history:
history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n"
history_text += f"\n{proto_to_json(msg)}\n"
history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute(
previous_a2a_conversation=history_text
@@ -780,9 +789,9 @@ def _handle_max_turns_exceeded(
"""
if conversation_history:
for msg in reversed(conversation_history):
if msg.role == Role.agent:
if msg.role == ROLE_AGENT:
text_parts = [
part.root.text for part in msg.parts if part.root.kind == "text"
part_text(part) for part in msg.parts if part_is_text(part)
]
final_message = (
" ".join(text_parts) if text_parts else "Conversation completed"
@@ -985,7 +994,9 @@ def _init_delegation_state(
reference_task_ids=list(ctx.reference_task_ids),
conversation_history=[],
agent_card=current_agent_card,
agent_card_dict=current_agent_card.model_dump() if current_agent_card else None,
agent_card_dict=agent_card_to_dict(current_agent_card)
if current_agent_card
else None,
agent_name=current_agent_card.name if current_agent_card else None,
)
@@ -1110,7 +1121,7 @@ def _handle_task_completion(
- remote_notice: Template notice about remote agent response
"""
remote_notice = ""
if a2a_result["status"] == TaskState.completed:
if a2a_result["status"] == TASK_STATE_COMPLETED:
remote_notice = REMOTE_AGENT_RESPONSE_NOTICE
if task_id_config is not None and task_id_config not in reference_task_ids:
@@ -1294,7 +1305,7 @@ def _delegate_to_a2a(
extensions=ctx.extensions,
conversation_history=conversation_history,
agent_id=ctx.agent_id,
agent_role=Role.user,
agent_role=ROLE_USER,
agent_branch=agent_branch,
response_model=ctx.agent_config.response_model,
turn_number=turn_num + 1,
@@ -1316,7 +1327,10 @@ def _delegate_to_a2a(
if latest_message.context_id is not None:
context_id = latest_message.context_id
if a2a_result["status"] in [TaskState.completed, TaskState.input_required]:
if a2a_result["status"] in [
TASK_STATE_COMPLETED,
TASK_STATE_INPUT_REQUIRED,
]:
trusted_result, task_id, reference_task_ids, remote_notice = (
_handle_task_completion(
a2a_result,
@@ -1649,7 +1663,7 @@ async def _adelegate_to_a2a(
extensions=ctx.extensions,
conversation_history=conversation_history,
agent_id=ctx.agent_id,
agent_role=Role.user,
agent_role=ROLE_USER,
agent_branch=agent_branch,
response_model=ctx.agent_config.response_model,
turn_number=turn_num + 1,
@@ -1671,7 +1685,10 @@ async def _adelegate_to_a2a(
if latest_message.context_id is not None:
context_id = latest_message.context_id
if a2a_result["status"] in [TaskState.completed, TaskState.input_required]:
if a2a_result["status"] in [
TASK_STATE_COMPLETED,
TASK_STATE_INPUT_REQUIRED,
]:
trusted_result, task_id, reference_task_ids, remote_notice = (
_handle_task_completion(
a2a_result,

View File

@@ -3,12 +3,23 @@ from __future__ import annotations
import os
import uuid
import httpx
import pytest
import pytest_asyncio
from a2a.client import ClientFactory
from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart
from a2a.client import A2ACardResolver, ClientFactory
from a2a.types import AgentCapabilities, AgentCard, AgentInterface, Message, Part, Role, TaskState
from crewai.a2a._compat import (
ROLE_AGENT,
ROLE_USER,
TASK_STATE_COMPLETED,
TASK_STATE_FAILED,
agent_card_url,
make_send_request,
new_text_message,
new_text_part,
)
from crewai.a2a.updates.polling.handler import PollingHandler
from crewai.a2a.updates.streaming.handler import StreamingHandler
@@ -17,27 +28,31 @@ A2A_TEST_ENDPOINT = os.getenv("A2A_TEST_ENDPOINT", "http://localhost:9999")
@pytest_asyncio.fixture
async def a2a_client():
async def card_resolver():
"""Create an A2ACardResolver for the test server."""
async with httpx.AsyncClient() as http_client:
resolver = A2ACardResolver(http_client, A2A_TEST_ENDPOINT)
yield resolver
@pytest_asyncio.fixture
async def agent_card(card_resolver) -> AgentCard:
"""Fetch the real agent card from the server."""
return await card_resolver.get_agent_card()
@pytest_asyncio.fixture
async def a2a_client(agent_card):
"""Create A2A client for test server."""
client = await ClientFactory.connect(A2A_TEST_ENDPOINT)
factory = ClientFactory()
client = factory.create(agent_card)
yield client
await client.close()
@pytest.fixture
def test_message() -> Message:
"""Create a simple test message."""
return Message(
role=Role.user,
parts=[Part(root=TextPart(text="What is 2 + 2?"))],
message_id=str(uuid.uuid4()),
)
@pytest_asyncio.fixture
async def agent_card(a2a_client) -> AgentCard:
"""Fetch the real agent card from the server."""
return await a2a_client.get_card()
return new_text_message("What is 2 + 2?", role=ROLE_USER)
class TestA2AAgentCardFetching:
@@ -45,13 +60,13 @@ class TestA2AAgentCardFetching:
@pytest.mark.vcr()
@pytest.mark.asyncio
async def test_fetch_agent_card(self, a2a_client) -> None:
async def test_fetch_agent_card(self, card_resolver) -> None:
"""Test fetching an agent card from the server."""
card = await a2a_client.get_card()
card = await card_resolver.get_agent_card()
assert card is not None
assert card.name == "GPT Assistant"
assert card.url is not None
assert card.supported_interfaces is not None
assert card.capabilities is not None
assert card.capabilities.streaming is True
@@ -80,7 +95,7 @@ class TestA2APollingIntegration:
)
assert isinstance(result, dict)
assert result["status"] == TaskState.completed
assert result["status"] == TASK_STATE_COMPLETED
assert result.get("result") is not None
assert "4" in result["result"]
@@ -104,11 +119,11 @@ class TestA2AStreamingIntegration:
message=test_message,
new_messages=new_messages,
agent_card=agent_card,
endpoint=agent_card.url,
endpoint=agent_card_url(agent_card),
)
assert isinstance(result, dict)
assert result["status"] == TaskState.completed
assert result["status"] == TASK_STATE_COMPLETED
assert result.get("result") is not None
@@ -123,19 +138,19 @@ class TestA2ATaskOperations:
test_message: Message,
) -> None:
"""Test sending a message and getting a response."""
from a2a.types import Task
from a2a.types import StreamResponse, Task
from crewai.a2a._compat import is_stream_task
final_task: Task | None = None
async for event in a2a_client.send_message(test_message):
if isinstance(event, tuple) and len(event) >= 1:
task, _ = event
if isinstance(task, Task):
final_task = task
async for event in a2a_client.send_message(make_send_request(test_message)):
if isinstance(event, StreamResponse) and is_stream_task(event):
final_task = event.task
assert final_task is not None
assert final_task.id is not None
assert final_task.id != ""
assert final_task.status is not None
assert final_task.status.state == TaskState.completed
assert final_task.status.state == TaskState.TASK_STATE_COMPLETED
class TestA2APushNotificationHandler:
@@ -148,17 +163,19 @@ class TestA2APushNotificationHandler:
@pytest.fixture
def mock_agent_card(self) -> AgentCard:
"""Create a minimal valid agent card for testing."""
from a2a.types import AgentCapabilities
return AgentCard(
name="Test Agent",
description="Test agent for push notification tests",
url="http://localhost:9999",
supported_interfaces=[
AgentInterface(
url="http://localhost:9999",
protocol_binding="JSONRPC",
),
],
version="1.0.0",
capabilities=AgentCapabilities(streaming=True, push_notifications=True),
default_input_modes=["text"],
default_output_modes=["text"],
skills=[],
)
@pytest.fixture
@@ -169,7 +186,7 @@ class TestA2APushNotificationHandler:
return Task(
id="task-123",
context_id="ctx-123",
status=TaskStatus(state=TaskState.working),
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
)
@pytest.mark.asyncio
@@ -181,7 +198,7 @@ class TestA2APushNotificationHandler:
"""Test that push handler waits for result from store."""
from unittest.mock import AsyncMock, MagicMock
from a2a.types import Task, TaskStatus
from a2a.types import StreamResponse, Task, TaskStatus
from pydantic import AnyHttpUrl
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
@@ -190,15 +207,14 @@ class TestA2APushNotificationHandler:
completed_task = Task(
id="task-123",
context_id="ctx-123",
status=TaskStatus(state=TaskState.completed),
history=[],
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
)
mock_store = MagicMock()
mock_store.wait_for_result = AsyncMock(return_value=completed_task)
async def mock_send_message(*args, **kwargs):
yield (mock_task, None)
yield StreamResponse(task=mock_task)
mock_client = MagicMock()
mock_client.send_message = mock_send_message
@@ -209,11 +225,7 @@ class TestA2APushNotificationHandler:
result_store=mock_store,
)
test_msg = Message(
role=Role.user,
parts=[Part(root=TextPart(text="What is 2+2?"))],
message_id="msg-001",
)
test_msg = new_text_message("What is 2+2?", role=ROLE_USER)
new_messages: list[Message] = []
@@ -226,7 +238,7 @@ class TestA2APushNotificationHandler:
result_store=mock_store,
polling_timeout=30.0,
polling_interval=1.0,
endpoint=mock_agent_card.url,
endpoint=agent_card_url(mock_agent_card),
)
mock_store.wait_for_result.assert_called_once_with(
@@ -235,7 +247,7 @@ class TestA2APushNotificationHandler:
poll_interval=1.0,
)
assert result["status"] == TaskState.completed
assert result["status"] == TASK_STATE_COMPLETED
@pytest.mark.asyncio
async def test_push_handler_returns_failure_on_timeout(
@@ -245,7 +257,7 @@ class TestA2APushNotificationHandler:
"""Test that push handler returns failure when result store times out."""
from unittest.mock import AsyncMock, MagicMock
from a2a.types import Task, TaskStatus
from a2a.types import StreamResponse, Task, TaskStatus
from pydantic import AnyHttpUrl
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
@@ -257,11 +269,11 @@ class TestA2APushNotificationHandler:
working_task = Task(
id="task-456",
context_id="ctx-456",
status=TaskStatus(state=TaskState.working),
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
)
async def mock_send_message(*args, **kwargs):
yield (working_task, None)
yield StreamResponse(task=working_task)
mock_client = MagicMock()
mock_client.send_message = mock_send_message
@@ -272,11 +284,7 @@ class TestA2APushNotificationHandler:
result_store=mock_store,
)
test_msg = Message(
role=Role.user,
parts=[Part(root=TextPart(text="test"))],
message_id="msg-002",
)
test_msg = new_text_message("test", role=ROLE_USER)
new_messages: list[Message] = []
@@ -289,10 +297,10 @@ class TestA2APushNotificationHandler:
result_store=mock_store,
polling_timeout=5.0,
polling_interval=0.5,
endpoint=mock_agent_card.url,
endpoint=agent_card_url(mock_agent_card),
)
assert result["status"] == TaskState.failed
assert result["status"] == TASK_STATE_FAILED
assert "timeout" in result.get("error", "").lower()
@pytest.mark.asyncio
@@ -307,11 +315,7 @@ class TestA2APushNotificationHandler:
mock_client = MagicMock()
test_msg = Message(
role=Role.user,
parts=[Part(root=TextPart(text="test"))],
message_id="msg-003",
)
test_msg = new_text_message("test", role=ROLE_USER)
new_messages: list[Message] = []
@@ -320,8 +324,8 @@ class TestA2APushNotificationHandler:
message=test_msg,
new_messages=new_messages,
agent_card=mock_agent_card,
endpoint=mock_agent_card.url,
endpoint=agent_card_url(mock_agent_card),
)
assert result["status"] == TaskState.failed
assert result["status"] == TASK_STATE_FAILED
assert "config" in result.get("error", "").lower()

View File

@@ -0,0 +1,517 @@
"""Tests for a2a-sdk v1.0 compatibility.
These tests validate that crewai.a2a modules correctly import and work with
a2a-sdk v1.0.x (protobuf-based types). They cover the core issue described
in https://github.com/crewAIInc/crewAI/issues/5607:
ImportError: cannot import name 'A2AClientHTTPError' from 'a2a.client.errors'
The migration from a2a-sdk ~0.3.10 to >=1.0.0,<2 introduced major breaking
changes including protobuf-based types, renamed error classes, and new enum
value conventions.
"""
from __future__ import annotations
import uuid
import pytest
class TestSdkV1Imports:
"""Verify that old v0.3 names no longer exist, and our compat layer works."""
def test_a2a_client_error_importable(self) -> None:
"""A2AClientError (renamed from A2AClientHTTPError) should be importable."""
from a2a.client.errors import A2AClientError
assert A2AClientError is not None
def test_old_a2a_client_http_error_removed(self) -> None:
"""A2AClientHTTPError no longer exists in a2a-sdk v1.0."""
with pytest.raises(ImportError):
from a2a.client.errors import A2AClientHTTPError # noqa: F401
def test_compat_alias_maps_to_new_error(self) -> None:
"""Our _compat alias should map to the new error class."""
from a2a.client.errors import A2AClientError
from crewai.a2a._compat import A2AClientHTTPError
assert A2AClientHTTPError is A2AClientError
def test_text_part_removed_in_v1(self) -> None:
"""TextPart no longer exists as a separate type in a2a-sdk v1.0."""
with pytest.raises(ImportError):
from a2a.types import TextPart # noqa: F401
def test_protobuf_types_importable(self) -> None:
"""Key protobuf types should be importable from a2a.types."""
from a2a.types import ( # noqa: F401
AgentCapabilities,
AgentCard,
AgentInterface,
GetTaskRequest,
Message,
Part,
Role,
StreamResponse,
SubscribeToTaskRequest,
Task,
TaskPushNotificationConfig,
TaskState,
TaskStatusUpdateEvent,
)
class TestCompatLayer:
"""Tests for the crewai.a2a._compat compatibility layer."""
def test_role_constants(self) -> None:
"""ROLE_USER and ROLE_AGENT should be valid Role enum values."""
from a2a.types import Role
from crewai.a2a._compat import ROLE_AGENT, ROLE_USER
assert ROLE_USER == Role.ROLE_USER
assert ROLE_AGENT == Role.ROLE_AGENT
def test_task_state_constants(self) -> None:
"""TASK_STATE_* should be valid TaskState enum values."""
from a2a.types import TaskState
from crewai.a2a._compat import (
TASK_STATE_CANCELED,
TASK_STATE_COMPLETED,
TASK_STATE_FAILED,
TASK_STATE_INPUT_REQUIRED,
TASK_STATE_REJECTED,
TASK_STATE_SUBMITTED,
TASK_STATE_WORKING,
)
assert TASK_STATE_SUBMITTED == TaskState.TASK_STATE_SUBMITTED
assert TASK_STATE_WORKING == TaskState.TASK_STATE_WORKING
assert TASK_STATE_COMPLETED == TaskState.TASK_STATE_COMPLETED
assert TASK_STATE_FAILED == TaskState.TASK_STATE_FAILED
assert TASK_STATE_CANCELED == TaskState.TASK_STATE_CANCELED
assert TASK_STATE_INPUT_REQUIRED == TaskState.TASK_STATE_INPUT_REQUIRED
assert TASK_STATE_REJECTED == TaskState.TASK_STATE_REJECTED
def test_terminal_states(self) -> None:
"""TERMINAL_STATES should include completed, failed, rejected, canceled."""
from crewai.a2a._compat import (
TASK_STATE_CANCELED,
TASK_STATE_COMPLETED,
TASK_STATE_FAILED,
TASK_STATE_REJECTED,
TERMINAL_STATES,
)
assert TASK_STATE_COMPLETED in TERMINAL_STATES
assert TASK_STATE_FAILED in TERMINAL_STATES
assert TASK_STATE_REJECTED in TERMINAL_STATES
assert TASK_STATE_CANCELED in TERMINAL_STATES
class TestPartHelpers:
"""Tests for protobuf Part helpers."""
def test_new_text_part(self) -> None:
"""new_text_part should create a Part with text field set."""
from crewai.a2a._compat import new_text_part, part_is_text, part_text
part = new_text_part("hello world")
assert part_is_text(part)
assert part_text(part) == "hello world"
def test_part_is_text_false_for_non_text(self) -> None:
"""part_is_text should return False for non-text parts."""
from a2a.types import Part
from google.protobuf.struct_pb2 import Value
from crewai.a2a._compat import part_is_text
v = Value()
v.string_value = "test"
part = Part(data=v)
assert not part_is_text(part)
def test_part_has_data(self) -> None:
"""part_has_data should detect data parts."""
from a2a.types import Part
from google.protobuf.struct_pb2 import Value
from crewai.a2a._compat import part_has_data
v = Value()
v.string_value = "test"
part = Part(data=v)
assert part_has_data(part)
def test_part_has_file(self) -> None:
"""part_has_file should detect raw/url file parts."""
from a2a.types import Part
from crewai.a2a._compat import part_has_file
raw_part = Part(raw=b"file content", media_type="application/pdf")
assert part_has_file(raw_part)
url_part = Part(url="https://example.com/file.pdf", media_type="application/pdf")
assert part_has_file(url_part)
class TestMessageHelpers:
"""Tests for protobuf Message helpers."""
def test_new_text_message(self) -> None:
"""new_text_message should create a Message with a text Part."""
from crewai.a2a._compat import (
ROLE_USER,
new_text_message,
part_is_text,
part_text,
)
msg = new_text_message("test message", role=ROLE_USER)
assert msg.role == ROLE_USER
assert len(msg.parts) == 1
assert part_is_text(msg.parts[0])
assert part_text(msg.parts[0]) == "test message"
def test_new_text_message_with_context_and_task(self) -> None:
"""new_text_message should accept context_id and task_id."""
from crewai.a2a._compat import ROLE_AGENT, new_text_message
msg = new_text_message(
"response",
role=ROLE_AGENT,
context_id="ctx-123",
task_id="task-456",
)
assert msg.context_id == "ctx-123"
assert msg.task_id == "task-456"
class TestAgentCardHelpers:
"""Tests for protobuf AgentCard helpers."""
def test_agent_card_to_dict(self) -> None:
"""agent_card_to_dict should serialize an AgentCard to a plain dict."""
from a2a.types import AgentCard, AgentInterface
from crewai.a2a._compat import agent_card_to_dict
card = AgentCard(
name="Test Agent",
description="A test agent",
supported_interfaces=[
AgentInterface(url="http://localhost:9999", protocol_binding="JSONRPC"),
],
version="1.0.0",
)
result = agent_card_to_dict(card)
assert isinstance(result, dict)
assert result["name"] == "Test Agent"
assert result["description"] == "A test agent"
def test_agent_card_url(self) -> None:
"""agent_card_url should return the URL from the first interface."""
from a2a.types import AgentCard, AgentInterface
from crewai.a2a._compat import agent_card_url
card = AgentCard(
name="Test",
supported_interfaces=[
AgentInterface(url="http://localhost:9999", protocol_binding="JSONRPC"),
],
)
assert agent_card_url(card) == "http://localhost:9999"
def test_agent_card_url_empty_when_no_interfaces(self) -> None:
"""agent_card_url should return empty string if no interfaces."""
from a2a.types import AgentCard
from crewai.a2a._compat import agent_card_url
card = AgentCard(name="No Interfaces")
assert agent_card_url(card) == ""
def test_agent_card_preferred_transport(self) -> None:
"""agent_card_preferred_transport should return protocol_binding."""
from a2a.types import AgentCard, AgentInterface
from crewai.a2a._compat import agent_card_preferred_transport
card = AgentCard(
name="Test",
supported_interfaces=[
AgentInterface(url="http://localhost", protocol_binding="GRPC"),
],
)
assert agent_card_preferred_transport(card) == "GRPC"
def test_agent_card_interfaces(self) -> None:
"""agent_card_interfaces should return all interfaces."""
from a2a.types import AgentCard, AgentInterface
from crewai.a2a._compat import agent_card_interfaces
card = AgentCard(
name="Test",
supported_interfaces=[
AgentInterface(url="http://a.com", protocol_binding="JSONRPC"),
AgentInterface(url="http://b.com", protocol_binding="GRPC"),
],
)
interfaces = agent_card_interfaces(card)
assert len(interfaces) == 2
def test_agent_card_protocol_version(self) -> None:
"""agent_card_protocol_version should return protocol version from first interface."""
from a2a.types import AgentCard, AgentInterface
from crewai.a2a._compat import agent_card_protocol_version
card = AgentCard(
name="Test",
supported_interfaces=[
AgentInterface(
url="http://localhost",
protocol_binding="JSONRPC",
protocol_version="0.3",
),
],
)
assert agent_card_protocol_version(card) == "0.3"
class TestProtoCopy:
"""Tests for protobuf deep copy helper."""
def test_proto_copy_creates_independent_copy(self) -> None:
"""proto_copy should create a deep copy of a protobuf message."""
from a2a.types import AgentCard, AgentInterface
from crewai.a2a._compat import proto_copy
original = AgentCard(
name="Original",
supported_interfaces=[
AgentInterface(url="http://original.com", protocol_binding="JSONRPC"),
],
)
copy = proto_copy(original)
copy.name = "Modified"
assert original.name == "Original"
assert copy.name == "Modified"
class TestStreamResponseHelpers:
"""Tests for StreamResponse event helpers."""
def test_is_stream_message(self) -> None:
"""is_stream_message should detect messages in StreamResponse."""
from a2a.types import Message, StreamResponse
from crewai.a2a._compat import ROLE_AGENT, is_stream_message, new_text_part
msg = Message(
role=ROLE_AGENT,
parts=[new_text_part("hello")],
message_id=str(uuid.uuid4()),
)
sr = StreamResponse(message=msg)
assert is_stream_message(sr)
def test_is_stream_task(self) -> None:
"""is_stream_task should detect tasks in StreamResponse."""
from a2a.types import StreamResponse, Task, TaskState, TaskStatus
from crewai.a2a._compat import is_stream_task
task = Task(
id="task-1",
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
)
sr = StreamResponse(task=task)
assert is_stream_task(sr)
def test_is_stream_status_update(self) -> None:
"""is_stream_status_update should detect status updates."""
from a2a.types import StreamResponse, TaskState, TaskStatus, TaskStatusUpdateEvent
from crewai.a2a._compat import is_stream_status_update
update = TaskStatusUpdateEvent(
task_id="task-1",
context_id="ctx-1",
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
)
sr = StreamResponse(status_update=update)
assert is_stream_status_update(sr)
class TestStatusUpdateFinality:
"""Tests for status update finality detection."""
def test_completed_is_final(self) -> None:
"""Completed status should be final."""
from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent
from crewai.a2a._compat import is_status_update_final
update = TaskStatusUpdateEvent(
task_id="t1",
context_id="c1",
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
)
assert is_status_update_final(update) is True
def test_working_is_not_final(self) -> None:
"""Working status should not be final."""
from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent
from crewai.a2a._compat import is_status_update_final
update = TaskStatusUpdateEvent(
task_id="t1",
context_id="c1",
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
)
assert is_status_update_final(update) is False
def test_failed_is_final(self) -> None:
"""Failed status should be final."""
from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent
from crewai.a2a._compat import is_status_update_final
update = TaskStatusUpdateEvent(
task_id="t1",
context_id="c1",
status=TaskStatus(state=TaskState.TASK_STATE_FAILED),
)
assert is_status_update_final(update) is True
class TestClientConfigHelper:
"""Tests for client configuration helper."""
def test_create_client_config(self) -> None:
"""create_client_config should produce a valid ClientConfig."""
from crewai.a2a._compat import create_client_config
config = create_client_config(
supported_transports=["JSONRPC", "GRPC"],
streaming=True,
polling=False,
)
assert config.supported_protocol_bindings == ["JSONRPC", "GRPC"]
assert config.streaming is True
assert config.polling is False
class TestProtoToJson:
"""Tests for proto_to_json serialization."""
def test_proto_to_json(self) -> None:
"""proto_to_json should serialize a protobuf to JSON string."""
from a2a.types import AgentCard, AgentInterface
from crewai.a2a._compat import proto_to_json
card = AgentCard(
name="Test Agent",
supported_interfaces=[
AgentInterface(url="http://localhost:9999", protocol_binding="JSONRPC"),
],
)
json_str = proto_to_json(card)
assert isinstance(json_str, str)
assert "Test Agent" in json_str
class TestModuleImports:
"""Verify all crewai.a2a submodules import without error under v1.0."""
def test_import_compat(self) -> None:
from crewai.a2a._compat import A2AClientHTTPError # noqa: F401
def test_import_task_helpers(self) -> None:
from crewai.a2a.task_helpers import process_task_state # noqa: F401
def test_import_polling_handler(self) -> None:
from crewai.a2a.updates.polling.handler import PollingHandler # noqa: F401
def test_import_streaming_handler(self) -> None:
from crewai.a2a.updates.streaming.handler import StreamingHandler # noqa: F401
def test_import_push_handler(self) -> None:
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler # noqa: F401
def test_import_auth_utils(self) -> None:
from crewai.a2a.auth.utils import validate_auth_against_agent_card # noqa: F401
def test_import_delegation(self) -> None:
from crewai.a2a.utils.delegation import execute_a2a_delegation # noqa: F401
def test_import_transport(self) -> None:
from crewai.a2a.utils.transport import negotiate_transport # noqa: F401
def test_import_agent_card(self) -> None:
from crewai.a2a.utils.agent_card import afetch_agent_card # noqa: F401
def test_import_agent_card_signing(self) -> None:
from crewai.a2a.utils.agent_card_signing import sign_agent_card # noqa: F401
def test_import_wrapper(self) -> None:
from crewai.a2a.wrapper import wrap_agent_with_a2a_instance # noqa: F401
def test_import_extensions_registry(self) -> None:
from crewai.a2a.extensions.registry import ExtensionsMiddleware # noqa: F401
def test_import_content_type(self) -> None:
from crewai.a2a.utils.content_type import get_part_content_type # noqa: F401
class TestGetPartContentType:
"""Tests for get_part_content_type with v1.0 protobuf Parts."""
def test_text_part_returns_text_plain(self) -> None:
from a2a.types import Part
from crewai.a2a.utils.content_type import get_part_content_type
part = Part(text="hello")
assert get_part_content_type(part) == "text/plain"
def test_data_part_returns_application_json(self) -> None:
from a2a.types import Part
from google.protobuf.struct_pb2 import Value
from crewai.a2a.utils.content_type import get_part_content_type
v = Value()
v.string_value = "test"
part = Part(data=v)
assert get_part_content_type(part) == "application/json"
def test_raw_part_returns_media_type(self) -> None:
from a2a.types import Part
from crewai.a2a.utils.content_type import get_part_content_type
part = Part(raw=b"pdf content", media_type="application/pdf")
assert get_part_content_type(part) == "application/pdf"
def test_url_part_returns_media_type(self) -> None:
from a2a.types import Part
from crewai.a2a.utils.content_type import get_part_content_type
part = Part(url="https://example.com/image.png", media_type="image/png")
assert get_part_content_type(part) == "image/png"

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from a2a.types import AgentCard, AgentSkill
from crewai import Agent
from crewai.a2a._compat import agent_card_to_dict, agent_card_url, proto_to_json
from crewai.a2a.config import A2AClientConfig, A2AServerConfig
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
@@ -154,7 +155,7 @@ class TestToAgentCard:
card = agent.to_agent_card("http://my-server.com:9000")
assert card.url == "http://my-server.com:9000"
assert agent_card_url(card) == "http://my-server.com:9000"
def test_uses_server_config_url(self) -> None:
"""AgentCard url should prefer A2AServerConfig.url over provided URL."""
@@ -167,7 +168,8 @@ class TestToAgentCard:
card = agent.to_agent_card("http://fallback-url.com")
assert card.url == "http://configured-url.com/"
url = agent_card_url(card)
assert url.rstrip("/") == "http://configured-url.com"
def test_generates_default_skill(self) -> None:
"""AgentCard should have at least one skill based on agent role."""
@@ -246,16 +248,16 @@ class TestAgentCardJsonStructure:
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump()
json_data = agent_card_to_dict(card)
assert "name" in json_data
assert "description" in json_data
assert "url" in json_data
assert "supported_interfaces" in json_data
assert "version" in json_data
assert "skills" in json_data
assert "capabilities" in json_data
assert "defaultInputModes" in json_data
assert "defaultOutputModes" in json_data
assert "default_input_modes" in json_data
assert "default_output_modes" in json_data
def test_json_skills_structure(self) -> None:
"""Each skill in JSON should have required fields."""
@@ -267,7 +269,7 @@ class TestAgentCardJsonStructure:
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump()
json_data = agent_card_to_dict(card)
assert len(json_data["skills"]) >= 1
skill = json_data["skills"][0]
@@ -286,11 +288,11 @@ class TestAgentCardJsonStructure:
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump()
json_data = agent_card_to_dict(card)
capabilities = json_data["capabilities"]
assert "streaming" in capabilities
assert "pushNotifications" in capabilities
assert "push_notifications" in capabilities
def test_json_serializable(self) -> None:
"""AgentCard should be JSON serializable."""
@@ -302,14 +304,14 @@ class TestAgentCardJsonStructure:
)
card = agent.to_agent_card("http://localhost:8000")
json_str = card.model_dump_json()
json_str = proto_to_json(card)
assert isinstance(json_str, str)
assert "Test Agent" in json_str
assert "http://localhost:8000" in json_str
def test_json_excludes_none_values(self) -> None:
"""AgentCard JSON with exclude_none should omit None fields."""
def test_json_excludes_unset_fields(self) -> None:
"""AgentCard JSON should omit fields that were not explicitly set."""
agent = Agent(
role="Test Agent",
goal="Test goal",
@@ -318,8 +320,6 @@ class TestAgentCardJsonStructure:
)
card = agent.to_agent_card("http://localhost:8000")
json_data = card.model_dump(exclude_none=True)
json_data = agent_card_to_dict(card)
assert "provider" not in json_data
assert "documentationUrl" not in json_data
assert "iconUrl" not in json_data

View File

@@ -12,6 +12,7 @@ from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus
from crewai.a2a._compat import TASK_STATE_CANCELED, TASK_STATE_WORKING
from crewai.a2a.utils.task import cancel, cancellable, execute
@@ -38,12 +39,13 @@ def mock_task(mock_context: MagicMock) -> MagicMock:
@pytest.fixture
def mock_context() -> MagicMock:
"""Create a mock RequestContext."""
from crewai.a2a._compat import ROLE_USER, new_text_message
context = MagicMock(spec=RequestContext)
context.task_id = "test-task-123"
context.context_id = "test-context-456"
context.get_user_input.return_value = "Test user message"
context.message = MagicMock(spec=Message)
context.message.parts = []
context.message = new_text_message("Test user message", role=ROLE_USER)
context.current_task = None
return context
@@ -291,8 +293,7 @@ class TestCancel:
assert event.task_id == mock_context.task_id
assert event.context_id == mock_context.context_id
assert event.status.state == TaskState.canceled
assert event.final is True
assert event.status.state == TASK_STATE_CANCELED
@pytest.mark.asyncio
async def test_returns_none_when_no_current_task(
@@ -315,13 +316,13 @@ class TestCancel:
) -> None:
"""Cancel returns updated task when context has current_task."""
current_task = MagicMock(spec=A2ATask)
current_task.status = TaskStatus(state=TaskState.working)
current_task.status = TaskStatus(state=TASK_STATE_WORKING)
mock_context.current_task = current_task
result = await cancel(mock_context, mock_event_queue)
assert result is current_task
assert result.status.state == TaskState.canceled
assert result.status.state == TASK_STATE_CANCELED
@pytest.mark.asyncio
async def test_cleanup_after_cancel(

9496
uv.lock generated

File diff suppressed because it is too large Load Diff