mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-07 10:12:38 +00:00
Compare commits
9 Commits
devin/1778
...
devin/1777
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f12dc9f993 | ||
|
|
24d2d8dabb | ||
|
|
f5331f3a46 | ||
|
|
6b0ddfa3f2 | ||
|
|
856954a311 | ||
|
|
ec98238985 | ||
|
|
d05a3415f8 | ||
|
|
54470f4932 | ||
|
|
bec175ec9a |
@@ -100,7 +100,7 @@ anthropic = [
|
||||
"anthropic~=0.73.0",
|
||||
]
|
||||
a2a = [
|
||||
"a2a-sdk~=0.3.10",
|
||||
"a2a-sdk>=1.0.0,<2",
|
||||
"httpx-auth~=0.23.1",
|
||||
"httpx-sse~=0.4.0",
|
||||
"aiocache[redis,memcached]~=0.12.3",
|
||||
|
||||
318
lib/crewai/src/crewai/a2a/_compat.py
Normal file
318
lib/crewai/src/crewai/a2a/_compat.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""Compatibility layer for a2a-sdk v0.3 → v1.0 migration.
|
||||
|
||||
Centralizes import aliases and helper functions so the rest of the
|
||||
a2a module can use a single import regardless of SDK version.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from a2a.client.errors import A2AClientError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error re-exports
|
||||
# In v0.3 the class was called A2AClientHTTPError; v1.0 renamed it to
|
||||
# A2AClientError. We expose the new name *and* an alias used across the
|
||||
# codebase so callers can migrate incrementally.
|
||||
# ---------------------------------------------------------------------------
|
||||
A2AClientHTTPError = A2AClientError # back-compat alias
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type helpers - Protobuf Part access
|
||||
# In v0.3 Part was a Pydantic discriminated-union with ``part.root.kind``
|
||||
# and ``part.root.text``; in v1.0 Part is a protobuf message with a
|
||||
# ``content`` oneof.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from a2a.types import Part # noqa: E402
|
||||
|
||||
|
||||
def part_is_text(part: Part) -> bool:
|
||||
"""Return True when the Part carries text content."""
|
||||
return part.HasField("text")
|
||||
|
||||
|
||||
def part_text(part: Part) -> str:
|
||||
"""Return the text payload of a Part (assumes text content)."""
|
||||
return part.text
|
||||
|
||||
|
||||
def part_has_data(part: Part) -> bool:
|
||||
"""Return True when the Part carries structured data."""
|
||||
return part.HasField("data")
|
||||
|
||||
|
||||
def part_has_file(part: Part) -> bool:
|
||||
"""Return True when the Part carries a file (url or raw bytes)."""
|
||||
return part.HasField("url") or part.HasField("raw")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enum value aliases
|
||||
# v0.3: TaskState.completed, Role.user (lower snake_case strings)
|
||||
# v1.0: TaskState.TASK_STATE_COMPLETED, Role.ROLE_USER (SCREAMING_SNAKE_CASE)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from a2a.types import Role, TaskState # noqa: E402
|
||||
|
||||
|
||||
# TaskState aliases
|
||||
TASK_STATE_SUBMITTED = TaskState.TASK_STATE_SUBMITTED
|
||||
TASK_STATE_WORKING = TaskState.TASK_STATE_WORKING
|
||||
TASK_STATE_COMPLETED = TaskState.TASK_STATE_COMPLETED
|
||||
TASK_STATE_FAILED = TaskState.TASK_STATE_FAILED
|
||||
TASK_STATE_CANCELED = TaskState.TASK_STATE_CANCELED
|
||||
TASK_STATE_INPUT_REQUIRED = TaskState.TASK_STATE_INPUT_REQUIRED
|
||||
TASK_STATE_AUTH_REQUIRED = TaskState.TASK_STATE_AUTH_REQUIRED
|
||||
TASK_STATE_REJECTED = TaskState.TASK_STATE_REJECTED
|
||||
|
||||
# Role aliases
|
||||
ROLE_USER = Role.ROLE_USER
|
||||
ROLE_AGENT = Role.ROLE_AGENT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Protobuf object helpers
|
||||
# Protobuf objects don't have model_dump() / model_copy().
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from google.protobuf.json_format import ( # noqa: E402
|
||||
MessageToDict,
|
||||
)
|
||||
|
||||
|
||||
def proto_to_json(msg: Any) -> str:
|
||||
"""Serialize a protobuf message to a JSON string.
|
||||
|
||||
Replaces ``msg.model_dump_json(...)`` from v0.3 Pydantic models.
|
||||
"""
|
||||
from google.protobuf.json_format import MessageToJson
|
||||
|
||||
return MessageToJson(msg, preserving_proto_field_name=True, indent=2)
|
||||
|
||||
|
||||
def agent_card_to_dict(agent_card: Any, *, exclude_none: bool = True) -> dict[str, Any]:
|
||||
"""Serialize a protobuf AgentCard to a plain dict.
|
||||
|
||||
Works like ``agent_card.model_dump(exclude_none=True)`` did in v0.3.
|
||||
"""
|
||||
return MessageToDict(
|
||||
agent_card,
|
||||
preserving_proto_field_name=True,
|
||||
always_print_fields_with_no_presence=not exclude_none,
|
||||
)
|
||||
|
||||
|
||||
def proto_copy(msg: Any) -> Any:
|
||||
"""Return a deep copy of a protobuf message (replaces model_copy)."""
|
||||
new = type(msg)()
|
||||
new.CopyFrom(msg)
|
||||
return new
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message / Part construction helpers
|
||||
# v0.3: Message(role=Role.user, parts=[Part(root=TextPart(text=...))])
|
||||
# v1.0: Message(role=Role.ROLE_USER, parts=[Part(text=...)])
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from a2a.types import Message # noqa: E402
|
||||
|
||||
|
||||
def new_text_part(text: str, **kwargs: Any) -> Part:
|
||||
"""Create a Part with text content (v1.0 style)."""
|
||||
return Part(text=text, **kwargs)
|
||||
|
||||
|
||||
def new_text_message(
|
||||
text: str,
|
||||
*,
|
||||
role: Any = ROLE_AGENT,
|
||||
message_id: str | None = None,
|
||||
context_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Message:
|
||||
"""Create a Message with a single text Part."""
|
||||
import uuid as _uuid
|
||||
|
||||
return Message(
|
||||
role=role,
|
||||
message_id=message_id or str(_uuid.uuid4()),
|
||||
parts=[Part(text=text)],
|
||||
context_id=context_id or "",
|
||||
task_id=task_id or "",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def make_send_request(message: Message) -> Any:
|
||||
"""Wrap a Message in a SendMessageRequest (v1.0 API).
|
||||
|
||||
In v0.3, ``client.send_message(message)`` accepted a bare ``Message``.
|
||||
In v1.0, it expects ``SendMessageRequest(message=message)``.
|
||||
"""
|
||||
from a2a.types import SendMessageRequest
|
||||
|
||||
return SendMessageRequest(message=message)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentCard field access helpers
|
||||
# v0.3: agent_card.url, agent_card.preferred_transport, agent_card.additional_interfaces
|
||||
# v1.0: agent_card.supported_interfaces, interface.url, interface.protocol_binding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from a2a.types import AgentCard, AgentInterface # noqa: E402
|
||||
|
||||
|
||||
def agent_card_url(agent_card: AgentCard) -> str:
|
||||
"""Get the primary URL from an AgentCard.
|
||||
|
||||
In v0.3 this was ``agent_card.url``.
|
||||
In v1.0 the URL lives inside ``supported_interfaces``.
|
||||
"""
|
||||
if agent_card.supported_interfaces:
|
||||
return agent_card.supported_interfaces[0].url
|
||||
return ""
|
||||
|
||||
|
||||
def agent_card_preferred_transport(agent_card: AgentCard) -> str:
|
||||
"""Get the preferred transport protocol from an AgentCard.
|
||||
|
||||
In v0.3 this was ``agent_card.preferred_transport``.
|
||||
In v1.0 it's the protocol_binding of the first supported_interface.
|
||||
"""
|
||||
if agent_card.supported_interfaces:
|
||||
return agent_card.supported_interfaces[0].protocol_binding
|
||||
return "JSONRPC"
|
||||
|
||||
|
||||
def agent_card_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
|
||||
"""Get all interfaces from an AgentCard.
|
||||
|
||||
In v0.3 these were split between the primary url and
|
||||
``agent_card.additional_interfaces``.
|
||||
In v1.0 everything is in ``supported_interfaces``.
|
||||
"""
|
||||
return (
|
||||
list(agent_card.supported_interfaces) if agent_card.supported_interfaces else []
|
||||
)
|
||||
|
||||
|
||||
def agent_card_protocol_version(agent_card: AgentCard) -> str:
|
||||
"""Get the protocol version from an AgentCard.
|
||||
|
||||
In v0.3 this was ``agent_card.protocol_version``.
|
||||
In v1.0 it's per-interface in ``interface.protocol_version``.
|
||||
"""
|
||||
if agent_card.supported_interfaces:
|
||||
return agent_card.supported_interfaces[0].protocol_version or ""
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StreamResponse helpers
|
||||
# v0.3: send_message returned AsyncIterator[tuple[Task, Update] | Message]
|
||||
# v1.0: send_message returns AsyncIterator[StreamResponse]
|
||||
# ---------------------------------------------------------------------------
|
||||
from a2a.types import ( # noqa: E402
|
||||
StreamResponse,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
|
||||
|
||||
def is_stream_message(chunk: StreamResponse) -> bool:
|
||||
"""Check if a StreamResponse contains a Message."""
|
||||
return chunk.HasField("message")
|
||||
|
||||
|
||||
def is_stream_task(chunk: StreamResponse) -> bool:
|
||||
"""Check if a StreamResponse contains a Task."""
|
||||
return chunk.HasField("task")
|
||||
|
||||
|
||||
def is_stream_status_update(chunk: StreamResponse) -> bool:
|
||||
"""Check if a StreamResponse contains a TaskStatusUpdateEvent."""
|
||||
return chunk.HasField("status_update")
|
||||
|
||||
|
||||
def is_stream_artifact_update(chunk: StreamResponse) -> bool:
|
||||
"""Check if a StreamResponse contains a TaskArtifactUpdateEvent."""
|
||||
return chunk.HasField("artifact_update")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client configuration helpers
|
||||
# v0.3: ClientConfig.supported_transports, push_notification_configs (list)
|
||||
# v1.0: ClientConfig.supported_protocol_bindings, push_notification_config (singular)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from a2a.client import ClientConfig # noqa: E402
|
||||
from a2a.types import TaskPushNotificationConfig # noqa: E402
|
||||
|
||||
|
||||
def create_client_config(
|
||||
*,
|
||||
httpx_client: Any = None,
|
||||
supported_transports: list[str] | None = None,
|
||||
streaming: bool = True,
|
||||
polling: bool = False,
|
||||
accepted_output_modes: list[str] | None = None,
|
||||
push_notification_config: TaskPushNotificationConfig | None = None,
|
||||
grpc_channel_factory: Any = None,
|
||||
) -> ClientConfig:
|
||||
"""Create a ClientConfig compatible with a2a-sdk v1.0."""
|
||||
return ClientConfig(
|
||||
httpx_client=httpx_client,
|
||||
supported_protocol_bindings=supported_transports or ["JSONRPC"],
|
||||
streaming=streaming,
|
||||
polling=polling,
|
||||
accepted_output_modes=accepted_output_modes
|
||||
or ["text/plain", "application/json"],
|
||||
push_notification_config=push_notification_config,
|
||||
grpc_channel_factory=grpc_channel_factory,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GetTaskRequest / SubscribeToTaskRequest
|
||||
# v0.3: TaskQueryParams, TaskIdParams
|
||||
# v1.0: GetTaskRequest, SubscribeToTaskRequest
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from a2a.types import GetTaskRequest, SubscribeToTaskRequest # noqa: E402
|
||||
|
||||
|
||||
# Expose v0.3 names as aliases for the v1.0 types
|
||||
TaskQueryParams = GetTaskRequest
|
||||
TaskIdParams = SubscribeToTaskRequest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task status helpers
|
||||
# v1.0 TaskStatusUpdateEvent no longer has a `final` field. Finality is
|
||||
# determined by the task state being terminal.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TERMINAL_STATES: frozenset[int] = frozenset(
|
||||
{
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_FAILED,
|
||||
TASK_STATE_REJECTED,
|
||||
TASK_STATE_CANCELED,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_status_update_final(update: TaskStatusUpdateEvent) -> bool:
|
||||
"""Determine if a status update is final.
|
||||
|
||||
In v0.3 this was ``update.final``. In v1.0 finality is inferred from
|
||||
the task state being terminal.
|
||||
"""
|
||||
if update.status and update.status.state:
|
||||
return update.status.state in TERMINAL_STATES
|
||||
return False
|
||||
@@ -11,12 +11,13 @@ import re
|
||||
import threading
|
||||
from typing import Final, Literal, cast
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.client.errors import A2AClientError
|
||||
from a2a.types import (
|
||||
APIKeySecurityScheme,
|
||||
AgentCard,
|
||||
HTTPAuthSecurityScheme,
|
||||
OAuth2SecurityScheme,
|
||||
SecurityScheme,
|
||||
)
|
||||
from httpx import AsyncClient, Response
|
||||
|
||||
@@ -112,7 +113,7 @@ def _raise_auth_mismatch(
|
||||
f"AgentCard requires {required} authentication, "
|
||||
f"but {type(provided_auth).__name__} was provided"
|
||||
)
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
raise A2AClientError(msg)
|
||||
|
||||
|
||||
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
|
||||
@@ -159,25 +160,44 @@ def validate_auth_against_agent_card(
|
||||
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
|
||||
"""
|
||||
|
||||
if not agent_card.security or not agent_card.security_schemes:
|
||||
if not agent_card.security_requirements or not agent_card.security_schemes:
|
||||
return
|
||||
|
||||
if not auth:
|
||||
msg = "AgentCard requires authentication but no auth scheme provided"
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
raise A2AClientError(msg)
|
||||
|
||||
first_security_req = agent_card.security[0] if agent_card.security else {}
|
||||
first_security_req = (
|
||||
agent_card.security_requirements[0]
|
||||
if agent_card.security_requirements
|
||||
else None
|
||||
)
|
||||
if first_security_req is None:
|
||||
return
|
||||
|
||||
for scheme_name in first_security_req.keys():
|
||||
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
|
||||
for scheme_name in first_security_req.schemes.keys():
|
||||
security_scheme_wrapper: SecurityScheme | None = (
|
||||
agent_card.security_schemes.get(scheme_name)
|
||||
)
|
||||
if not security_scheme_wrapper:
|
||||
continue
|
||||
|
||||
scheme = security_scheme_wrapper.root
|
||||
scheme_field = security_scheme_wrapper.WhichOneof("scheme")
|
||||
if scheme_field is None:
|
||||
continue
|
||||
|
||||
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
|
||||
if not isinstance(auth, allowed_classes):
|
||||
_raise_auth_mismatch(allowed_classes, auth)
|
||||
scheme = getattr(security_scheme_wrapper, scheme_field)
|
||||
|
||||
if isinstance(scheme, OAuth2SecurityScheme):
|
||||
allowed = _SCHEME_AUTH_MAPPING.get(OAuth2SecurityScheme)
|
||||
if allowed and not isinstance(auth, allowed):
|
||||
_raise_auth_mismatch(allowed, auth)
|
||||
return
|
||||
|
||||
if isinstance(scheme, APIKeySecurityScheme):
|
||||
allowed = _SCHEME_AUTH_MAPPING.get(APIKeySecurityScheme)
|
||||
if allowed and not isinstance(auth, allowed):
|
||||
_raise_auth_mismatch(allowed, auth)
|
||||
return
|
||||
|
||||
if isinstance(scheme, HTTPAuthSecurityScheme):
|
||||
@@ -188,7 +208,7 @@ def validate_auth_against_agent_card(
|
||||
return
|
||||
|
||||
msg = "Could not validate auth against AgentCard security requirements"
|
||||
raise A2AClientHTTPError(401, msg)
|
||||
raise A2AClientError(msg)
|
||||
|
||||
|
||||
async def retry_on_401(
|
||||
|
||||
@@ -568,7 +568,9 @@ class A2AServerConfig(BaseModel):
|
||||
auth: Authentication scheme for A2A endpoints.
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||
extra="forbid", arbitrary_types_allowed=True
|
||||
)
|
||||
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -10,9 +10,7 @@ See: https://a2a-protocol.org/latest/topics/extensions/
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
|
||||
from a2a.client.interceptors import BeforeArgs, ClientCallInterceptor
|
||||
from a2a.extensions.common import (
|
||||
HTTP_EXTENSION_HEADER,
|
||||
)
|
||||
@@ -63,30 +61,15 @@ class ExtensionsMiddleware(ClientCallInterceptor):
|
||||
"""
|
||||
self._extensions = extensions
|
||||
|
||||
async def intercept(
|
||||
self,
|
||||
method_name: str,
|
||||
request_payload: dict[str, Any],
|
||||
http_kwargs: dict[str, Any],
|
||||
agent_card: AgentCard | None,
|
||||
context: ClientCallContext | None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Add extensions header to the request.
|
||||
async def before(self, args: BeforeArgs) -> None:
|
||||
"""Add extensions header before the request is sent.
|
||||
|
||||
Args:
|
||||
method_name: The A2A method being called.
|
||||
request_payload: The JSON-RPC request payload.
|
||||
http_kwargs: HTTP request kwargs (headers, etc).
|
||||
agent_card: The target agent's card.
|
||||
context: Optional call context.
|
||||
|
||||
Returns:
|
||||
Tuple of (request_payload, modified_http_kwargs).
|
||||
args: The BeforeArgs containing method, input, agent_card, etc.
|
||||
"""
|
||||
if self._extensions:
|
||||
headers = http_kwargs.setdefault("headers", {})
|
||||
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
|
||||
return request_payload, http_kwargs
|
||||
if self._extensions and isinstance(args.input, dict):
|
||||
metadata = args.input.setdefault("metadata", {})
|
||||
metadata[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
|
||||
|
||||
|
||||
def validate_required_extensions(
|
||||
|
||||
@@ -4,22 +4,32 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.client.errors import A2AClientError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
StreamResponse,
|
||||
)
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
ROLE_AGENT,
|
||||
TASK_STATE_AUTH_REQUIRED,
|
||||
TASK_STATE_CANCELED,
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_FAILED,
|
||||
TASK_STATE_INPUT_REQUIRED,
|
||||
TASK_STATE_REJECTED,
|
||||
TASK_STATE_SUBMITTED,
|
||||
TASK_STATE_WORKING,
|
||||
agent_card_to_dict,
|
||||
is_stream_message,
|
||||
is_stream_task,
|
||||
new_text_message,
|
||||
part_is_text,
|
||||
part_text,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConnectionErrorEvent,
|
||||
@@ -30,31 +40,29 @@ from crewai.events.types.a2a_events import (
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
SendMessageEvent = (
|
||||
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
|
||||
)
|
||||
SendMessageEvent = StreamResponse
|
||||
|
||||
|
||||
TERMINAL_STATES: frozenset[TaskState] = frozenset(
|
||||
TERMINAL_STATES: frozenset[int] = frozenset(
|
||||
{
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.rejected,
|
||||
TaskState.canceled,
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_FAILED,
|
||||
TASK_STATE_REJECTED,
|
||||
TASK_STATE_CANCELED,
|
||||
}
|
||||
)
|
||||
|
||||
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
|
||||
ACTIONABLE_STATES: frozenset[int] = frozenset(
|
||||
{
|
||||
TaskState.input_required,
|
||||
TaskState.auth_required,
|
||||
TASK_STATE_INPUT_REQUIRED,
|
||||
TASK_STATE_AUTH_REQUIRED,
|
||||
}
|
||||
)
|
||||
|
||||
PENDING_STATES: frozenset[TaskState] = frozenset(
|
||||
PENDING_STATES: frozenset[int] = frozenset(
|
||||
{
|
||||
TaskState.submitted,
|
||||
TaskState.working,
|
||||
TASK_STATE_SUBMITTED,
|
||||
TASK_STATE_WORKING,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -62,7 +70,7 @@ PENDING_STATES: frozenset[TaskState] = frozenset(
|
||||
class TaskStateResult(TypedDict):
|
||||
"""Result dictionary from processing A2A task state."""
|
||||
|
||||
status: TaskState
|
||||
status: int
|
||||
history: list[Message]
|
||||
result: NotRequired[str]
|
||||
error: NotRequired[str]
|
||||
@@ -83,26 +91,22 @@ def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
|
||||
|
||||
if a2a_task.status and a2a_task.status.message:
|
||||
msg = a2a_task.status.message
|
||||
result_parts.extend(
|
||||
part.root.text for part in msg.parts if part.root.kind == "text"
|
||||
)
|
||||
result_parts.extend(part_text(part) for part in msg.parts if part_is_text(part))
|
||||
|
||||
if not result_parts and a2a_task.history:
|
||||
for history_msg in reversed(a2a_task.history):
|
||||
if history_msg.role == Role.agent:
|
||||
if history_msg.role == ROLE_AGENT:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in history_msg.parts
|
||||
if part.root.kind == "text"
|
||||
part_text(part) for part in history_msg.parts if part_is_text(part)
|
||||
)
|
||||
break
|
||||
|
||||
if a2a_task.artifacts:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
part_text(part)
|
||||
for artifact in a2a_task.artifacts
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
if part_is_text(part)
|
||||
)
|
||||
|
||||
return result_parts
|
||||
@@ -122,15 +126,15 @@ def extract_error_message(a2a_task: A2ATask, default: str) -> str:
|
||||
msg = a2a_task.status.message
|
||||
if msg:
|
||||
for part in msg.parts:
|
||||
if part.root.kind == "text":
|
||||
return str(part.root.text)
|
||||
return str(msg)
|
||||
if part_is_text(part):
|
||||
return str(part_text(part))
|
||||
return str(msg)
|
||||
|
||||
if a2a_task.history:
|
||||
for history_msg in reversed(a2a_task.history):
|
||||
for part in history_msg.parts:
|
||||
if part.root.kind == "text":
|
||||
return str(part.root.text)
|
||||
if part_is_text(part):
|
||||
return str(part_text(part))
|
||||
|
||||
return default
|
||||
|
||||
@@ -174,7 +178,7 @@ def process_task_state(
|
||||
if result_parts is None:
|
||||
result_parts = []
|
||||
|
||||
if a2a_task.status.state == TaskState.completed:
|
||||
if a2a_task.status.state == TASK_STATE_COMPLETED:
|
||||
if not result_parts:
|
||||
extracted_parts = extract_task_result_parts(a2a_task)
|
||||
result_parts.extend(extracted_parts)
|
||||
@@ -204,22 +208,21 @@ def process_task_state(
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
status=TASK_STATE_COMPLETED,
|
||||
agent_card=agent_card_to_dict(agent_card),
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state == TaskState.input_required:
|
||||
if a2a_task.status.state == TASK_STATE_INPUT_REQUIRED:
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
|
||||
response_text = extract_error_message(a2a_task, "Additional input required")
|
||||
if response_text and not a2a_task.history:
|
||||
agent_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=response_text))],
|
||||
agent_message = new_text_message(
|
||||
response_text,
|
||||
role=ROLE_AGENT,
|
||||
context_id=a2a_task.context_id,
|
||||
task_id=a2a_task.id,
|
||||
)
|
||||
@@ -247,34 +250,34 @@ def process_task_state(
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.input_required,
|
||||
status=TASK_STATE_INPUT_REQUIRED,
|
||||
error=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
agent_card=agent_card_to_dict(agent_card),
|
||||
)
|
||||
|
||||
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
|
||||
if a2a_task.status.state in {TASK_STATE_FAILED, TASK_STATE_REJECTED}:
|
||||
error_msg = extract_error_message(a2a_task, "Task failed without error message")
|
||||
if a2a_task.history:
|
||||
new_messages.extend(a2a_task.history)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state == TaskState.auth_required:
|
||||
if a2a_task.status.state == TASK_STATE_AUTH_REQUIRED:
|
||||
error_msg = extract_error_message(a2a_task, "Authentication required")
|
||||
return TaskStateResult(
|
||||
status=TaskState.auth_required,
|
||||
status=TASK_STATE_AUTH_REQUIRED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if a2a_task.status.state == TaskState.canceled:
|
||||
if a2a_task.status.state == TASK_STATE_CANCELED:
|
||||
error_msg = extract_error_message(a2a_task, "Task was canceled")
|
||||
return TaskStateResult(
|
||||
status=TaskState.canceled,
|
||||
status=TASK_STATE_CANCELED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
@@ -286,7 +289,7 @@ def process_task_state(
|
||||
|
||||
|
||||
async def send_message_and_get_task_id(
|
||||
event_stream: AsyncIterator[SendMessageEvent],
|
||||
event_stream: AsyncIterator[StreamResponse],
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
turn_number: int,
|
||||
@@ -321,11 +324,12 @@ async def send_message_and_get_task_id(
|
||||
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
|
||||
"""
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, Message):
|
||||
async for chunk in event_stream:
|
||||
if is_stream_message(chunk):
|
||||
event = chunk.message
|
||||
new_messages.append(event)
|
||||
result_parts = [
|
||||
part.root.text for part in event.parts if part.root.kind == "text"
|
||||
part_text(part) for part in event.parts if part_is_text(part)
|
||||
]
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
|
||||
@@ -348,14 +352,14 @@ async def send_message_and_get_task_id(
|
||||
)
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
status=TASK_STATE_COMPLETED,
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
agent_card=agent_card_to_dict(agent_card),
|
||||
)
|
||||
|
||||
if isinstance(event, tuple):
|
||||
a2a_task, _ = event
|
||||
if is_stream_task(chunk):
|
||||
a2a_task = chunk.task
|
||||
|
||||
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
||||
result = process_task_state(
|
||||
@@ -376,18 +380,17 @@ async def send_message_and_get_task_id(
|
||||
return a2a_task.id
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error="No task ID received from initial message",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
except A2AClientError as e:
|
||||
error_msg = f"A2A Client Error: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
error_message = new_text_message(
|
||||
error_msg,
|
||||
role=ROLE_AGENT,
|
||||
context_id=context_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
@@ -397,8 +400,7 @@ async def send_message_and_get_task_id(
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint or "",
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
error_type="client_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="send_message",
|
||||
context_id=context_id,
|
||||
@@ -423,7 +425,7 @@ async def send_message_and_get_task_id(
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
@@ -431,10 +433,9 @@ async def send_message_and_get_task_id(
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during send_message: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
error_message = new_text_message(
|
||||
error_msg,
|
||||
role=ROLE_AGENT,
|
||||
context_id=context_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
@@ -469,7 +470,7 @@ async def send_message_and_get_task_id(
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
@@ -5,21 +5,21 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.client.errors import A2AClientError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
GetTaskRequest,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
TaskQueryParams,
|
||||
TaskState,
|
||||
TextPart,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
ROLE_AGENT,
|
||||
TASK_STATE_FAILED,
|
||||
new_text_message,
|
||||
)
|
||||
from crewai.a2a.errors import A2APollingTimeoutError
|
||||
from crewai.a2a.task_helpers import (
|
||||
ACTIONABLE_STATES,
|
||||
@@ -84,7 +84,7 @@ async def _poll_task_until_complete(
|
||||
while True:
|
||||
poll_count += 1
|
||||
task = await client.get_task(
|
||||
TaskQueryParams(id=task_id, history_length=history_length)
|
||||
GetTaskRequest(id=task_id, history_length=history_length)
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
@@ -94,7 +94,7 @@ async def _poll_task_until_complete(
|
||||
A2APollingStatusEvent(
|
||||
task_id=task_id,
|
||||
context_id=effective_context_id,
|
||||
state=str(task.status.state.value),
|
||||
state=str(task.status.state),
|
||||
elapsed_seconds=elapsed,
|
||||
poll_count=poll_count,
|
||||
endpoint=endpoint,
|
||||
@@ -158,9 +158,11 @@ class PollingHandler:
|
||||
from_task = kwargs.get("from_task")
|
||||
from_agent = kwargs.get("from_agent")
|
||||
|
||||
from crewai.a2a._compat import make_send_request
|
||||
|
||||
try:
|
||||
result_or_task_id = await send_message_and_get_task_id(
|
||||
event_stream=client.send_message(message),
|
||||
event_stream=client.send_message(make_send_request(message)),
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
@@ -222,7 +224,7 @@ class PollingHandler:
|
||||
return result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=f"Unexpected task state: {final_task.status.state}",
|
||||
history=new_messages,
|
||||
)
|
||||
@@ -230,10 +232,9 @@ class PollingHandler:
|
||||
except A2APollingTimeoutError as e:
|
||||
error_msg = str(e)
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
error_message = new_text_message(
|
||||
error_msg,
|
||||
role=ROLE_AGENT,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
@@ -256,18 +257,17 @@ class PollingHandler:
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
except A2AClientError as e:
|
||||
error_msg = f"A2A Client Error: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
error_message = new_text_message(
|
||||
error_msg,
|
||||
role=ROLE_AGENT,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
@@ -278,8 +278,7 @@ class PollingHandler:
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=endpoint,
|
||||
error=str(e),
|
||||
error_type="http_error",
|
||||
status_code=e.status_code,
|
||||
error_type="client_error",
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
operation="polling",
|
||||
context_id=context_id,
|
||||
@@ -305,7 +304,7 @@ class PollingHandler:
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
@@ -313,10 +312,9 @@ class PollingHandler:
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during polling: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
error_message = new_text_message(
|
||||
error_msg,
|
||||
role=ROLE_AGENT,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
@@ -353,7 +351,7 @@ class PollingHandler:
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
from a2a.types import PushNotificationAuthenticationInfo
|
||||
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
|
||||
|
||||
from crewai.a2a.updates.base import PushNotificationResultStore
|
||||
@@ -46,7 +45,7 @@ class PushNotificationConfig(BaseModel):
|
||||
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
|
||||
id: str | None = Field(default=None, description="Unique config identifier")
|
||||
token: str | None = Field(default=None, description="Validation token")
|
||||
authentication: PushNotificationAuthenticationInfo | None = Field(
|
||||
authentication: Any | None = Field(
|
||||
default=None, description="Auth info for agent to use when calling webhook"
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
|
||||
@@ -4,20 +4,20 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.client.errors import A2AClientError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
TaskState,
|
||||
TextPart,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
ROLE_AGENT,
|
||||
TASK_STATE_FAILED,
|
||||
new_text_message,
|
||||
)
|
||||
from crewai.a2a.task_helpers import (
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
@@ -69,10 +69,9 @@ def _handle_push_error(
|
||||
Returns:
|
||||
TaskStateResult with failed status.
|
||||
"""
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
error_message = new_text_message(
|
||||
error_msg,
|
||||
role=ROLE_AGENT,
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
@@ -110,7 +109,7 @@ def _handle_push_error(
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
@@ -221,7 +220,7 @@ class PushNotificationHandler:
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
@@ -245,14 +244,16 @@ class PushNotificationHandler:
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
from crewai.a2a._compat import make_send_request
|
||||
|
||||
try:
|
||||
result_or_task_id = await send_message_and_get_task_id(
|
||||
event_stream=client.send_message(message),
|
||||
event_stream=client.send_message(make_send_request(message)),
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
@@ -304,7 +305,7 @@ class PushNotificationHandler:
|
||||
|
||||
if final_task is None:
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=f"Push notification timeout after {polling_timeout}s",
|
||||
history=new_messages,
|
||||
)
|
||||
@@ -325,21 +326,20 @@ class PushNotificationHandler:
|
||||
return result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=f"Unexpected task state: {final_task.status.state}",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
except A2AClientError as e:
|
||||
return _handle_push_error(
|
||||
error=e,
|
||||
error_msg=f"HTTP Error {e.status_code}: {e!s}",
|
||||
error_type="http_error",
|
||||
error_msg=f"A2A Client Error: {e!s}",
|
||||
error_type="client_error",
|
||||
new_messages=new_messages,
|
||||
agent_branch=agent_branch,
|
||||
params=params,
|
||||
task_id=task_id,
|
||||
status_code=e.status_code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -5,25 +5,30 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Final
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.client.errors import A2AClientError
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
GetTaskRequest,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
SubscribeToTaskRequest,
|
||||
Task,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskIdParams,
|
||||
TaskQueryParams,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TextPart,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
ROLE_AGENT,
|
||||
TASK_STATE_FAILED,
|
||||
is_stream_artifact_update,
|
||||
is_stream_message,
|
||||
is_stream_status_update,
|
||||
is_stream_task,
|
||||
new_text_message,
|
||||
part_is_text,
|
||||
part_text,
|
||||
)
|
||||
from crewai.a2a.task_helpers import (
|
||||
ACTIONABLE_STATES,
|
||||
TERMINAL_STATES,
|
||||
@@ -50,6 +55,16 @@ MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
|
||||
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
|
||||
|
||||
|
||||
def _extract_text_from_artifact(artifact: TaskArtifactUpdateEvent) -> list[str]:
|
||||
"""Extract text parts from an artifact update event."""
|
||||
parts: list[str] = []
|
||||
if artifact.artifact and artifact.artifact.parts:
|
||||
parts.extend(
|
||||
part_text(part) for part in artifact.artifact.parts if part_is_text(part)
|
||||
)
|
||||
return parts
|
||||
|
||||
|
||||
class StreamingHandler:
|
||||
"""SSE streaming-based update handler."""
|
||||
|
||||
@@ -86,7 +101,7 @@ class StreamingHandler:
|
||||
params = extract_common_params(kwargs) # type: ignore[arg-type]
|
||||
|
||||
try:
|
||||
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
|
||||
a2a_task: Task = await client.get_task(GetTaskRequest(id=task_id))
|
||||
|
||||
if a2a_task.status.state in TERMINAL_STATES:
|
||||
logger.info(
|
||||
@@ -138,26 +153,18 @@ class StreamingHandler:
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(backoff)
|
||||
|
||||
event_stream = client.resubscribe(TaskIdParams(id=task_id))
|
||||
event_stream = client.subscribe(SubscribeToTaskRequest(id=task_id))
|
||||
|
||||
async for event in event_stream:
|
||||
if isinstance(event, tuple):
|
||||
resubscribed_task, update = event
|
||||
async for chunk in event_stream:
|
||||
if is_stream_task(chunk):
|
||||
resubscribed_task = chunk.task
|
||||
|
||||
is_final_update = (
|
||||
process_status_update(update, result_parts)
|
||||
if isinstance(update, TaskStatusUpdateEvent)
|
||||
else False
|
||||
if is_stream_status_update(chunk):
|
||||
update = chunk.status_update
|
||||
is_final_update = process_status_update(
|
||||
update, result_parts
|
||||
)
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
if (
|
||||
is_final_update
|
||||
or resubscribed_task.status.state
|
||||
@@ -178,15 +185,20 @@ class StreamingHandler:
|
||||
is_final=is_final_update,
|
||||
)
|
||||
|
||||
elif isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
if is_stream_artifact_update(chunk):
|
||||
artifact = chunk.artifact_update
|
||||
result_parts.extend(_extract_text_from_artifact(artifact))
|
||||
|
||||
if is_stream_message(chunk):
|
||||
msg = chunk.message
|
||||
new_messages.append(msg)
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in event.parts
|
||||
if part.root.kind == "text"
|
||||
part_text(part)
|
||||
for part in msg.parts
|
||||
if part_is_text(part)
|
||||
)
|
||||
|
||||
final_task = await client.get_task(TaskQueryParams(id=task_id))
|
||||
final_task = await client.get_task(GetTaskRequest(id=task_id))
|
||||
return process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
@@ -258,9 +270,12 @@ class StreamingHandler:
|
||||
|
||||
result_parts: list[str] = []
|
||||
final_result: TaskStateResult | None = None
|
||||
event_stream = client.send_message(message)
|
||||
from crewai.a2a._compat import make_send_request
|
||||
|
||||
event_stream = client.send_message(make_send_request(message))
|
||||
chunk_index = 0
|
||||
current_task_id: str | None = task_id
|
||||
current_task: Task | None = None
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
@@ -278,22 +293,25 @@ class StreamingHandler:
|
||||
)
|
||||
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, tuple):
|
||||
a2a_task, _ = event
|
||||
current_task_id = a2a_task.id
|
||||
async for chunk in event_stream:
|
||||
# Extract task from task payload
|
||||
if is_stream_task(chunk):
|
||||
current_task = chunk.task
|
||||
current_task_id = current_task.id
|
||||
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
message_context_id = event.context_id or params.context_id
|
||||
for part in event.parts:
|
||||
if part.root.kind == "text":
|
||||
text = part.root.text
|
||||
# Handle standalone message responses
|
||||
if is_stream_message(chunk):
|
||||
msg = chunk.message
|
||||
new_messages.append(msg)
|
||||
message_context_id = msg.context_id or params.context_id
|
||||
for part in msg.parts:
|
||||
if part_is_text(part):
|
||||
text = part_text(part)
|
||||
result_parts.append(text)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AStreamingChunkEvent(
|
||||
task_id=event.task_id or task_id,
|
||||
task_id=msg.task_id or task_id,
|
||||
context_id=message_context_id,
|
||||
chunk=text,
|
||||
chunk_index=chunk_index,
|
||||
@@ -307,38 +325,40 @@ class StreamingHandler:
|
||||
)
|
||||
chunk_index += 1
|
||||
|
||||
elif isinstance(event, tuple):
|
||||
a2a_task, update = event
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
# Handle artifact updates
|
||||
elif is_stream_artifact_update(chunk):
|
||||
artifact_update = chunk.artifact_update
|
||||
artifact = artifact_update.artifact
|
||||
if artifact and artifact.parts:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
part_text(part)
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
if part_is_text(part)
|
||||
)
|
||||
artifact_size = None
|
||||
if artifact.parts:
|
||||
artifact_size = sum(
|
||||
len(p.root.text.encode())
|
||||
if p.root.kind == "text"
|
||||
else len(getattr(p.root, "data", b""))
|
||||
len(part_text(p).encode())
|
||||
if part_is_text(p)
|
||||
else len(getattr(p, "raw", b""))
|
||||
for p in artifact.parts
|
||||
)
|
||||
effective_context_id = a2a_task.context_id or params.context_id
|
||||
effective_context_id = (
|
||||
current_task.context_id if current_task else None
|
||||
) or params.context_id
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AArtifactReceivedEvent(
|
||||
task_id=a2a_task.id,
|
||||
task_id=artifact_update.task_id or current_task_id,
|
||||
artifact_id=artifact.artifact_id,
|
||||
artifact_name=artifact.name,
|
||||
artifact_description=artifact.description,
|
||||
mime_type=artifact.parts[0].root.kind
|
||||
if artifact.parts
|
||||
mime_type="text"
|
||||
if artifact.parts and part_is_text(artifact.parts[0])
|
||||
else None,
|
||||
size_bytes=artifact_size,
|
||||
append=update.append or False,
|
||||
last_chunk=update.last_chunk or False,
|
||||
append=artifact_update.append or False,
|
||||
last_chunk=artifact_update.last_chunk or False,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
context_id=effective_context_id,
|
||||
@@ -349,86 +369,63 @@ class StreamingHandler:
|
||||
),
|
||||
)
|
||||
|
||||
is_final_update = (
|
||||
process_status_update(update, result_parts)
|
||||
if isinstance(update, TaskStatusUpdateEvent)
|
||||
else False
|
||||
)
|
||||
# Handle status updates
|
||||
elif is_stream_status_update(chunk):
|
||||
update = chunk.status_update
|
||||
is_final_update = process_status_update(update, result_parts)
|
||||
|
||||
if (
|
||||
not is_final_update
|
||||
and a2a_task.status.state
|
||||
not in TERMINAL_STATES | ACTIONABLE_STATES
|
||||
if current_task and (
|
||||
is_final_update
|
||||
or current_task.status.state
|
||||
in TERMINAL_STATES | ACTIONABLE_STATES
|
||||
):
|
||||
final_result = process_task_state(
|
||||
a2a_task=current_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=is_final_update,
|
||||
)
|
||||
elif not current_task and is_final_update:
|
||||
pass
|
||||
else:
|
||||
continue
|
||||
|
||||
final_result = process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=params.turn_number,
|
||||
is_multiturn=params.is_multiturn,
|
||||
agent_role=params.agent_role,
|
||||
result_parts=result_parts,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
is_final=is_final_update,
|
||||
)
|
||||
if final_result:
|
||||
break
|
||||
except A2AClientError as e:
|
||||
logger.warning(
|
||||
"Stream interrupted",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
},
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
if current_task_id:
|
||||
logger.info(
|
||||
"Stream interrupted with HTTP error, attempting recovery",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error": str(e),
|
||||
"status_code": e.status_code,
|
||||
},
|
||||
recovery_result = await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**kwargs,
|
||||
)
|
||||
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
|
||||
recovered_result = (
|
||||
await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**recovery_kwargs,
|
||||
)
|
||||
)
|
||||
if recovered_result:
|
||||
logger.info(
|
||||
"Successfully recovered task after HTTP error",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status": str(recovered_result.get("status")),
|
||||
},
|
||||
)
|
||||
return recovered_result
|
||||
if recovery_result:
|
||||
return recovery_result
|
||||
|
||||
logger.warning(
|
||||
"Failed to recover from HTTP error, returning failure",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status_code": e.status_code,
|
||||
"original_error": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
error_type = "http_error"
|
||||
status_code = e.status_code
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
error_msg = f"A2A Client Error: {e!s}"
|
||||
error_message = new_text_message(
|
||||
error_msg,
|
||||
role=ROLE_AGENT,
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
task_id=current_task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
@@ -437,12 +434,11 @@ class StreamingHandler:
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
error_type="client_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
task_id=current_task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
@@ -464,116 +460,40 @@ class StreamingHandler:
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
|
||||
error_type = type(e).__name__.lower()
|
||||
if current_task_id:
|
||||
logger.info(
|
||||
f"Stream interrupted with {error_type}, attempting recovery",
|
||||
extra={"task_id": current_task_id, "error": str(e)},
|
||||
)
|
||||
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
|
||||
recovered_result = (
|
||||
await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**recovery_kwargs,
|
||||
)
|
||||
)
|
||||
if recovered_result:
|
||||
logger.info(
|
||||
f"Successfully recovered task after {error_type}",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"status": str(recovered_result.get("status")),
|
||||
},
|
||||
)
|
||||
return recovered_result
|
||||
|
||||
logger.warning(
|
||||
f"Failed to recover from {error_type}, returning failure",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error_type": error_type,
|
||||
"original_error": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
error_msg = f"Connection error during streaming: {e!s}"
|
||||
status_code = None
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="failed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Unexpected error during streaming",
|
||||
logger.warning(
|
||||
"Unexpected stream error",
|
||||
extra={
|
||||
"task_id": current_task_id,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
"endpoint": params.endpoint,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
|
||||
error_type = "unexpected_error"
|
||||
status_code = None
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
if current_task_id:
|
||||
recovery_result = await StreamingHandler._try_recover_from_interruption(
|
||||
client=client,
|
||||
task_id=current_task_id,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
result_parts=result_parts,
|
||||
**kwargs,
|
||||
)
|
||||
if recovery_result:
|
||||
return recovery_result
|
||||
|
||||
error_msg = f"Unexpected error during streaming: {e!s}"
|
||||
error_message = new_text_message(
|
||||
error_msg,
|
||||
role=ROLE_AGENT,
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
task_id=current_task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
@@ -582,12 +502,11 @@ class StreamingHandler:
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(e),
|
||||
error_type=error_type,
|
||||
status_code=status_code,
|
||||
error_type="unexpected_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="streaming",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
task_id=current_task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
@@ -609,38 +528,33 @@ class StreamingHandler:
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
status=TASK_STATE_FAILED,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
finally:
|
||||
aclose = getattr(event_stream, "aclose", None)
|
||||
if aclose:
|
||||
try:
|
||||
await aclose()
|
||||
except Exception as close_error:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AConnectionErrorEvent(
|
||||
endpoint=params.endpoint,
|
||||
error=str(close_error),
|
||||
error_type="stream_close_error",
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
operation="stream_close",
|
||||
context_id=params.context_id,
|
||||
task_id=task_id,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
if final_result:
|
||||
return final_result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
result=" ".join(result_parts) if result_parts else "",
|
||||
history=new_messages,
|
||||
agent_card=agent_card.model_dump(exclude_none=True),
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=params.turn_number,
|
||||
context_id=params.context_id,
|
||||
is_multiturn=params.is_multiturn,
|
||||
status="completed",
|
||||
final=True,
|
||||
agent_role=params.agent_role,
|
||||
endpoint=params.endpoint,
|
||||
a2a_agent_name=params.a2a_agent_name,
|
||||
from_task=params.from_task,
|
||||
from_agent=params.from_agent,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TASK_STATE_FAILED,
|
||||
error="Stream ended without terminal state",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from a2a.types import TaskStatusUpdateEvent
|
||||
|
||||
from crewai.a2a._compat import is_status_update_final, part_is_text, part_text
|
||||
|
||||
|
||||
def process_status_update(
|
||||
update: TaskStatusUpdateEvent,
|
||||
@@ -18,11 +20,11 @@ def process_status_update(
|
||||
Returns:
|
||||
True if this is a final update, False otherwise.
|
||||
"""
|
||||
is_final = update.final
|
||||
is_final = is_status_update_final(update)
|
||||
if update.status and update.status.message and update.status.message.parts:
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
part_text(part)
|
||||
for part in update.status.message.parts
|
||||
if part.root.kind == "text" and part.root.text
|
||||
if part_is_text(part) and part_text(part)
|
||||
)
|
||||
return is_final
|
||||
|
||||
@@ -12,12 +12,18 @@ import time
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from a2a.client.errors import A2AClientHTTPError
|
||||
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
|
||||
from a2a.client.errors import A2AClientError
|
||||
from a2a.types import AgentCapabilities, AgentCard, AgentInterface, AgentSkill
|
||||
from aiocache import cached # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
from google.protobuf.json_format import ParseDict
|
||||
import httpx
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
agent_card_protocol_version,
|
||||
agent_card_to_dict,
|
||||
proto_copy,
|
||||
)
|
||||
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai.a2a.auth.utils import (
|
||||
_auth_store,
|
||||
@@ -277,9 +283,9 @@ async def _afetch_agent_card_impl(
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
agent_card = AgentCard.model_validate(response.json())
|
||||
agent_card = ParseDict(response.json(), AgentCard())
|
||||
fetch_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||
agent_card_dict = agent_card_to_dict(agent_card)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
@@ -287,7 +293,7 @@ async def _afetch_agent_card_impl(
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=agent_card.name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
protocol_version=agent_card_protocol_version(agent_card),
|
||||
provider=agent_card_dict.get("provider"),
|
||||
cached=False,
|
||||
fetch_time_ms=fetch_time_ms,
|
||||
@@ -326,7 +332,7 @@ async def _afetch_agent_card_impl(
|
||||
),
|
||||
)
|
||||
|
||||
raise A2AClientHTTPError(401, msg) from e
|
||||
raise A2AClientError(msg) from e
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
@@ -470,7 +476,9 @@ def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard:
|
||||
return AgentCard(
|
||||
name=crew_name,
|
||||
description=" ".join(description_parts),
|
||||
url=url,
|
||||
supported_interfaces=[
|
||||
AgentInterface(url=url, protocol_binding="JSONRPC"),
|
||||
],
|
||||
version="1.0.0",
|
||||
capabilities=AgentCapabilities(
|
||||
streaming=True,
|
||||
@@ -540,28 +548,43 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
|
||||
if ext.uri not in existing_uris:
|
||||
existing_exts.append(ext)
|
||||
|
||||
capabilities = capabilities.model_copy(update={"extensions": existing_exts})
|
||||
capabilities = proto_copy(capabilities)
|
||||
del capabilities.extensions[:]
|
||||
capabilities.extensions.extend(existing_exts)
|
||||
|
||||
primary_interface = AgentInterface(
|
||||
url=server_config.url or url,
|
||||
protocol_binding=server_config.transport.preferred or "JSONRPC",
|
||||
protocol_version=server_config.protocol_version or "",
|
||||
)
|
||||
interfaces = [primary_interface]
|
||||
if server_config.additional_interfaces:
|
||||
interfaces.extend(server_config.additional_interfaces)
|
||||
|
||||
card = AgentCard(
|
||||
name=name,
|
||||
description=description,
|
||||
url=server_config.url or url,
|
||||
supported_interfaces=interfaces,
|
||||
version=server_config.version,
|
||||
capabilities=capabilities,
|
||||
default_input_modes=server_config.default_input_modes,
|
||||
default_output_modes=server_config.default_output_modes,
|
||||
skills=skills,
|
||||
preferred_transport=server_config.transport.preferred,
|
||||
protocol_version=server_config.protocol_version,
|
||||
provider=server_config.provider,
|
||||
documentation_url=server_config.documentation_url,
|
||||
icon_url=server_config.icon_url,
|
||||
additional_interfaces=server_config.additional_interfaces,
|
||||
security=server_config.security,
|
||||
security_schemes=server_config.security_schemes,
|
||||
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
|
||||
documentation_url=server_config.documentation_url or "",
|
||||
icon_url=server_config.icon_url or "",
|
||||
)
|
||||
|
||||
if server_config.provider:
|
||||
card.provider.CopyFrom(server_config.provider)
|
||||
|
||||
if server_config.security_schemes:
|
||||
for k, v in server_config.security_schemes.items():
|
||||
card.security_schemes[k].CopyFrom(v)
|
||||
|
||||
if server_config.security:
|
||||
for req in server_config.security:
|
||||
card.security_requirements.append(req)
|
||||
|
||||
if server_config.signing_config:
|
||||
signature = sign_agent_card(
|
||||
card,
|
||||
@@ -569,9 +592,11 @@ def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
|
||||
key_id=server_config.signing_config.key_id,
|
||||
algorithm=server_config.signing_config.algorithm,
|
||||
)
|
||||
card = card.model_copy(update={"signatures": [signature]})
|
||||
del card.signatures[:]
|
||||
card.signatures.append(signature)
|
||||
elif server_config.signatures:
|
||||
card = card.model_copy(update={"signatures": server_config.signatures})
|
||||
del card.signatures[:]
|
||||
card.signatures.extend(server_config.signatures)
|
||||
|
||||
return card
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from a2a.types import AgentCard, AgentCardSignature
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
import jwt
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -58,7 +59,11 @@ def _serialize_agent_card(agent_card: AgentCard) -> str:
|
||||
Returns:
|
||||
Canonical JSON string representation.
|
||||
"""
|
||||
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
|
||||
card_dict = MessageToDict(
|
||||
agent_card,
|
||||
preserving_proto_field_name=True,
|
||||
)
|
||||
card_dict.pop("signatures", None)
|
||||
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
|
||||
@@ -264,7 +264,7 @@ def negotiate_content_types(
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AContentTypeNegotiatedEvent(
|
||||
endpoint=endpoint or agent_card.url,
|
||||
endpoint=endpoint or "",
|
||||
a2a_agent_name=a2a_agent_name or agent_card.name,
|
||||
skill_name=skill_name,
|
||||
client_input_modes=client_input_modes,
|
||||
@@ -303,22 +303,21 @@ def get_part_content_type(part: Part) -> str:
|
||||
"""Extract MIME type from an A2A Part.
|
||||
|
||||
Args:
|
||||
part: A Part object containing TextPart, DataPart, or FilePart.
|
||||
part: A Part object (protobuf oneof: text, data, raw, url).
|
||||
|
||||
Returns:
|
||||
The MIME type string for this part.
|
||||
"""
|
||||
root = part.root
|
||||
if root.kind == "text":
|
||||
if part.HasField("text"):
|
||||
return TEXT_PLAIN
|
||||
if root.kind == "data":
|
||||
metadata = root.metadata or {}
|
||||
if part.HasField("data"):
|
||||
metadata = dict(part.metadata) if part.metadata else {}
|
||||
mime = metadata.get("mimeType", "")
|
||||
if mime == APPLICATION_A2UI_JSON:
|
||||
return APPLICATION_A2UI_JSON
|
||||
return APPLICATION_JSON
|
||||
if root.kind == "file":
|
||||
return root.file.mime_type or APPLICATION_OCTET_STREAM
|
||||
if part.HasField("raw") or part.HasField("url"):
|
||||
return part.media_type or APPLICATION_OCTET_STREAM
|
||||
return APPLICATION_OCTET_STREAM
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import AsyncIterator, Callable, MutableMapping
|
||||
import concurrent.futures
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -12,20 +11,27 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client, ClientConfig, ClientFactory
|
||||
from a2a.client import Client, ClientFactory
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
FilePart,
|
||||
FileWithBytes,
|
||||
Message,
|
||||
Part,
|
||||
PushNotificationConfig as A2APushNotificationConfig,
|
||||
Role,
|
||||
TextPart,
|
||||
TaskPushNotificationConfig,
|
||||
)
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
ROLE_USER,
|
||||
agent_card_interfaces,
|
||||
agent_card_protocol_version,
|
||||
agent_card_to_dict,
|
||||
agent_card_url,
|
||||
create_client_config,
|
||||
new_text_part,
|
||||
proto_copy,
|
||||
)
|
||||
from crewai.a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
|
||||
from crewai.a2a.auth.utils import (
|
||||
_auth_store,
|
||||
@@ -41,8 +47,6 @@ from crewai.a2a.task_helpers import TaskStateResult
|
||||
from crewai.a2a.types import (
|
||||
HANDLER_REGISTRY,
|
||||
HandlerType,
|
||||
PartsDict,
|
||||
PartsMetadataDict,
|
||||
TransportType,
|
||||
)
|
||||
from crewai.a2a.updates import (
|
||||
@@ -107,13 +111,13 @@ def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]:
|
||||
parts: list[Part] = []
|
||||
for name, file_input in input_files.items():
|
||||
content_bytes = file_input.read()
|
||||
content_base64 = base64.b64encode(content_bytes).decode()
|
||||
file_with_bytes = FileWithBytes(
|
||||
bytes=content_base64,
|
||||
mimeType=file_input.content_type,
|
||||
name=file_input.filename or name,
|
||||
parts.append(
|
||||
Part(
|
||||
raw=content_bytes,
|
||||
media_type=file_input.content_type or "application/octet-stream",
|
||||
filename=file_input.filename or name,
|
||||
)
|
||||
)
|
||||
parts.append(Part(root=FilePart(file=file_with_bytes)))
|
||||
|
||||
return parts
|
||||
|
||||
@@ -301,7 +305,7 @@ async def aexecute_a2a_delegation(
|
||||
|
||||
is_multiturn = len(conversation_history) > 0
|
||||
if turn_number is None:
|
||||
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
|
||||
turn_number = len([m for m in conversation_history if m.role == ROLE_USER]) + 1
|
||||
|
||||
try:
|
||||
result = await _aexecute_a2a_delegation_impl(
|
||||
@@ -349,7 +353,7 @@ async def aexecute_a2a_delegation(
|
||||
)
|
||||
raise
|
||||
|
||||
agent_card_data = result.get("agent_card")
|
||||
agent_card_data: dict[str, Any] | None = result.get("agent_card")
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationCompletedEvent(
|
||||
@@ -423,14 +427,14 @@ async def _aexecute_a2a_delegation_impl(
|
||||
|
||||
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
|
||||
if unsupported_exts:
|
||||
ext_uris = [ext.uri for ext in unsupported_exts]
|
||||
ext_uris = [e.uri for e in unsupported_exts]
|
||||
raise ValueError(
|
||||
f"Agent requires extensions not supported by client: {ext_uris}"
|
||||
)
|
||||
|
||||
negotiated: NegotiatedTransport | None = None
|
||||
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
|
||||
effective_url = endpoint
|
||||
effective_url = agent_card_url(agent_card) or endpoint
|
||||
|
||||
client_transports: list[str] = (
|
||||
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
|
||||
@@ -456,9 +460,9 @@ async def _aexecute_a2a_delegation_impl(
|
||||
"endpoint": endpoint,
|
||||
"client_transports": client_transports,
|
||||
"server_transports": [
|
||||
iface.transport for iface in agent_card.additional_interfaces or []
|
||||
]
|
||||
+ [agent_card.preferred_transport or "JSONRPC"],
|
||||
iface.protocol_binding
|
||||
for iface in agent_card_interfaces(agent_card)
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -476,11 +480,9 @@ async def _aexecute_a2a_delegation_impl(
|
||||
|
||||
headers, _ = await _prepare_auth_headers(auth, timeout)
|
||||
|
||||
a2a_agent_name = None
|
||||
if agent_card.name:
|
||||
a2a_agent_name = agent_card.name
|
||||
a2a_agent_name = agent_card.name or None
|
||||
|
||||
agent_card_dict = agent_card.model_dump(exclude_none=True)
|
||||
agent_card_dict = agent_card_to_dict(agent_card)
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2ADelegationStartedEvent(
|
||||
@@ -492,7 +494,7 @@ async def _aexecute_a2a_delegation_impl(
|
||||
turn_number=turn_number,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
protocol_version=agent_card_protocol_version(agent_card),
|
||||
provider=agent_card_dict.get("provider"),
|
||||
skill_id=skill_id,
|
||||
metadata=metadata,
|
||||
@@ -512,7 +514,7 @@ async def _aexecute_a2a_delegation_impl(
|
||||
context_id=context_id,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card_dict,
|
||||
protocol_version=agent_card.protocol_version,
|
||||
protocol_version=agent_card_protocol_version(agent_card),
|
||||
provider=agent_card_dict.get("provider"),
|
||||
skill_id=skill_id,
|
||||
reference_task_ids=reference_task_ids,
|
||||
@@ -534,26 +536,18 @@ async def _aexecute_a2a_delegation_impl(
|
||||
if first_task_id := conversation_history[0].task_id:
|
||||
task_id = first_task_id
|
||||
|
||||
parts: PartsDict = {"text": message_text}
|
||||
if response_model:
|
||||
parts.update(
|
||||
{
|
||||
"metadata": PartsMetadataDict(
|
||||
mimeType="application/json",
|
||||
schema=response_model.model_json_schema(),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
message_metadata = metadata.copy() if metadata else {}
|
||||
if skill_id:
|
||||
message_metadata["skill_id"] = skill_id
|
||||
if response_model:
|
||||
message_metadata["mimeType"] = "application/json"
|
||||
message_metadata["schema"] = response_model.model_json_schema()
|
||||
|
||||
parts_list: list[Part] = [Part(root=TextPart(**parts))]
|
||||
parts_list: list[Part] = [new_text_part(message_text)]
|
||||
parts_list.extend(_create_file_parts(input_files))
|
||||
|
||||
message = Message(
|
||||
role=Role.user,
|
||||
role=ROLE_USER,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=parts_list,
|
||||
context_id=context_id,
|
||||
@@ -625,8 +619,11 @@ async def _aexecute_a2a_delegation_impl(
|
||||
use_streaming = not use_polling and push_config_for_client is None
|
||||
|
||||
client_agent_card = agent_card
|
||||
if effective_url != agent_card.url:
|
||||
client_agent_card = agent_card.model_copy(update={"url": effective_url})
|
||||
card_url = agent_card_url(agent_card)
|
||||
if effective_url != card_url:
|
||||
client_agent_card = proto_copy(agent_card)
|
||||
if client_agent_card.supported_interfaces:
|
||||
client_agent_card.supported_interfaces[0].url = effective_url
|
||||
|
||||
async with _create_a2a_client(
|
||||
agent_card=client_agent_card,
|
||||
@@ -649,7 +646,7 @@ async def _aexecute_a2a_delegation_impl(
|
||||
**handler_kwargs,
|
||||
)
|
||||
result["a2a_agent_name"] = a2a_agent_name
|
||||
result["agent_card"] = agent_card.model_dump(exclude_none=True)
|
||||
result["agent_card"] = agent_card_to_dict(agent_card)
|
||||
return result
|
||||
|
||||
|
||||
@@ -933,15 +930,12 @@ async def _create_a2a_client(
|
||||
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
|
||||
configure_auth_client(auth, httpx_client)
|
||||
|
||||
push_configs: list[A2APushNotificationConfig] = []
|
||||
push_config: TaskPushNotificationConfig | None = None
|
||||
if push_notification_config is not None:
|
||||
push_configs.append(
|
||||
A2APushNotificationConfig(
|
||||
url=str(push_notification_config.url),
|
||||
id=push_notification_config.id,
|
||||
token=push_notification_config.token,
|
||||
authentication=push_notification_config.authentication,
|
||||
)
|
||||
push_config = TaskPushNotificationConfig(
|
||||
url=str(push_notification_config.url),
|
||||
id=push_notification_config.id or "",
|
||||
token=push_notification_config.token or "",
|
||||
)
|
||||
|
||||
grpc_channel_factory = None
|
||||
@@ -951,13 +945,14 @@ async def _create_a2a_client(
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
config = ClientConfig(
|
||||
config = create_client_config(
|
||||
httpx_client=httpx_client,
|
||||
supported_transports=[transport_protocol],
|
||||
streaming=streaming and not use_polling,
|
||||
polling=use_polling,
|
||||
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
|
||||
push_notification_configs=push_configs,
|
||||
accepted_output_modes=accepted_output_modes
|
||||
or list(DEFAULT_CLIENT_OUTPUT_MODES),
|
||||
push_notification_config=push_config,
|
||||
grpc_channel_factory=grpc_channel_factory,
|
||||
)
|
||||
|
||||
@@ -965,6 +960,6 @@ async def _create_a2a_client(
|
||||
client = factory.create(agent_card)
|
||||
|
||||
if client_extensions:
|
||||
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
|
||||
await client.add_interceptor(ExtensionsMiddleware(client_extensions))
|
||||
|
||||
yield client
|
||||
|
||||
@@ -13,13 +13,15 @@ import os
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from a2a.helpers.proto_helpers import (
|
||||
new_artifact,
|
||||
new_text_artifact,
|
||||
new_text_message as new_agent_text_message,
|
||||
)
|
||||
from a2a.server.agent_execution import RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
from a2a.types import (
|
||||
Artifact,
|
||||
FileWithBytes,
|
||||
FileWithUri,
|
||||
InternalError,
|
||||
InvalidParamsError,
|
||||
Message,
|
||||
Part,
|
||||
@@ -28,18 +30,20 @@ from a2a.types import (
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from a2a.utils import (
|
||||
get_data_parts,
|
||||
get_file_parts,
|
||||
new_agent_text_message,
|
||||
new_data_artifact,
|
||||
new_text_artifact,
|
||||
)
|
||||
from a2a.utils.errors import ServerError
|
||||
from a2a.utils.errors import A2AError as ServerError
|
||||
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
TASK_STATE_CANCELED,
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_FAILED,
|
||||
part_has_data,
|
||||
part_has_file,
|
||||
part_is_text,
|
||||
proto_copy,
|
||||
)
|
||||
from crewai.a2a.utils.agent_card import _get_server_config
|
||||
from crewai.a2a.utils.content_type import validate_message_parts
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
@@ -64,6 +68,28 @@ P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _get_data_parts(parts: list[Part]) -> list[dict[str, Any]]:
|
||||
"""Extract data parts from a list of protobuf Parts.
|
||||
|
||||
In a2a-sdk v1.0, data is stored via the ``data`` oneof field on Part
|
||||
(a ``google.protobuf.Value``).
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for part in parts:
|
||||
if part_has_data(part):
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
val = MessageToDict(part.data)
|
||||
if isinstance(val, dict):
|
||||
result.append(val)
|
||||
return result
|
||||
|
||||
|
||||
def _get_file_parts(parts: list[Part]) -> list[Part]:
|
||||
"""Return parts that carry file content (raw bytes or url)."""
|
||||
return [p for p in parts if part_has_file(p)]
|
||||
|
||||
|
||||
class RedisCacheConfig(TypedDict, total=False):
|
||||
"""Configuration for aiocache Redis backend."""
|
||||
|
||||
@@ -196,12 +222,12 @@ def cancellable(
|
||||
|
||||
|
||||
def _convert_a2a_files_to_file_inputs(
|
||||
a2a_files: list[FileWithBytes | FileWithUri],
|
||||
a2a_files: list[Part],
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a2a file types to crewai FileInput dict.
|
||||
"""Convert a2a file parts to crewai FileInput dict.
|
||||
|
||||
Args:
|
||||
a2a_files: List of FileWithBytes or FileWithUri from a2a SDK.
|
||||
a2a_files: List of Parts that carry file content (raw or url).
|
||||
|
||||
Returns:
|
||||
Dictionary mapping file names to FileInput objects.
|
||||
@@ -213,15 +239,13 @@ def _convert_a2a_files_to_file_inputs(
|
||||
return {}
|
||||
|
||||
file_dict: dict[str, Any] = {}
|
||||
for idx, a2a_file in enumerate(a2a_files):
|
||||
if isinstance(a2a_file, FileWithBytes):
|
||||
file_bytes = base64.b64decode(a2a_file.bytes)
|
||||
name = a2a_file.name or f"file_{idx}"
|
||||
file_source = FileBytes(data=file_bytes, filename=a2a_file.name)
|
||||
for idx, part in enumerate(a2a_files):
|
||||
name = part.filename or f"file_{idx}"
|
||||
if part.HasField("raw"):
|
||||
file_source = FileBytes(data=part.raw, filename=name)
|
||||
file_dict[name] = File(source=file_source)
|
||||
elif isinstance(a2a_file, FileWithUri):
|
||||
name = a2a_file.name or f"file_{idx}"
|
||||
file_dict[name] = File(source=a2a_file.uri)
|
||||
elif part.HasField("url"):
|
||||
file_dict[name] = File(source=part.url)
|
||||
|
||||
return file_dict
|
||||
|
||||
@@ -239,8 +263,9 @@ def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
|
||||
JSON schema dict if found, None otherwise.
|
||||
"""
|
||||
for part in parts:
|
||||
if part.root.kind == "text" and part.root.metadata:
|
||||
schema = part.root.metadata.get("schema")
|
||||
if part_is_text(part) and part.metadata:
|
||||
metadata_dict = dict(part.metadata)
|
||||
schema = metadata_dict.get("schema")
|
||||
if schema and isinstance(schema, dict):
|
||||
return schema # type: ignore[no-any-return]
|
||||
return None
|
||||
@@ -261,9 +286,17 @@ def _create_result_artifact(
|
||||
"""
|
||||
artifact_name = f"result_{task_id}"
|
||||
if isinstance(result, dict):
|
||||
return new_data_artifact(artifact_name, result)
|
||||
from google.protobuf import struct_pb2
|
||||
|
||||
val = struct_pb2.Value()
|
||||
val.struct_value.update(result)
|
||||
return new_artifact([Part(data=val)], artifact_name)
|
||||
if isinstance(result, BaseModel):
|
||||
return new_data_artifact(artifact_name, result.model_dump())
|
||||
from google.protobuf import struct_pb2
|
||||
|
||||
val = struct_pb2.Value()
|
||||
val.struct_value.update(result.model_dump())
|
||||
return new_artifact([Part(data=val)], artifact_name)
|
||||
return new_text_artifact(artifact_name, str(result))
|
||||
|
||||
|
||||
@@ -330,7 +363,7 @@ async def _execute_impl(
|
||||
|
||||
response_model: type[BaseModel] | None = None
|
||||
structured_inputs: list[dict[str, Any]] = []
|
||||
a2a_files: list[FileWithBytes | FileWithUri] = []
|
||||
a2a_files: list[Part] = []
|
||||
|
||||
if context.message and context.message.parts:
|
||||
schema = _extract_response_schema(context.message.parts)
|
||||
@@ -343,8 +376,8 @@ async def _execute_impl(
|
||||
extra={"error": str(e), "schema_title": schema.get("title")},
|
||||
)
|
||||
|
||||
structured_inputs = get_data_parts(context.message.parts)
|
||||
a2a_files = get_file_parts(context.message.parts)
|
||||
structured_inputs = _get_data_parts(context.message.parts)
|
||||
a2a_files = _get_file_parts(context.message.parts)
|
||||
|
||||
task_id = context.task_id
|
||||
context_id = context.context_id
|
||||
@@ -387,12 +420,14 @@ async def _execute_impl(
|
||||
)
|
||||
result_str = str(result)
|
||||
history: list[Message] = [context.message] if context.message else []
|
||||
history.append(new_agent_text_message(result_str, context_id, task_id))
|
||||
history.append(
|
||||
new_agent_text_message(result_str, context_id=context_id, task_id=task_id)
|
||||
)
|
||||
await event_queue.enqueue_event(
|
||||
A2ATask(
|
||||
id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
status=TaskStatus(state=TASK_STATE_COMPLETED),
|
||||
artifacts=[_create_result_artifact(result, task_id)],
|
||||
history=history,
|
||||
)
|
||||
@@ -429,9 +464,7 @@ async def _execute_impl(
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
raise ServerError(
|
||||
error=InternalError(message=f"Task execution failed: {e}")
|
||||
) from e
|
||||
raise ServerError(f"Task execution failed: {e}") from e
|
||||
|
||||
|
||||
async def execute_with_extensions(
|
||||
@@ -476,9 +509,9 @@ async def cancel(
|
||||
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
|
||||
|
||||
if context.current_task and context.current_task.status.state in (
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.canceled,
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_FAILED,
|
||||
TASK_STATE_CANCELED,
|
||||
):
|
||||
return context.current_task
|
||||
|
||||
@@ -492,13 +525,12 @@ async def cancel(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.canceled),
|
||||
final=True,
|
||||
status=TaskStatus(state=TASK_STATE_CANCELED),
|
||||
)
|
||||
)
|
||||
|
||||
if context.current_task:
|
||||
context.current_task.status = TaskStatus(state=TaskState.canceled)
|
||||
context.current_task.status.CopyFrom(TaskStatus(state=TASK_STATE_CANCELED))
|
||||
return context.current_task
|
||||
return None
|
||||
|
||||
@@ -571,7 +603,7 @@ def list_tasks(
|
||||
|
||||
result: list[A2ATask] = []
|
||||
for task in page:
|
||||
task = task.model_copy(deep=True)
|
||||
task = proto_copy(task)
|
||||
if history_length is not None and task.history:
|
||||
task.history = task.history[-history_length:]
|
||||
if not include_artifacts:
|
||||
|
||||
@@ -12,6 +12,11 @@ from typing import Final, Literal
|
||||
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
agent_card_interfaces,
|
||||
agent_card_preferred_transport,
|
||||
agent_card_url,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
|
||||
|
||||
@@ -85,23 +90,21 @@ def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
|
||||
List of AgentInterface objects representing all available endpoints.
|
||||
"""
|
||||
interfaces: list[AgentInterface] = []
|
||||
|
||||
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
|
||||
interfaces.append(
|
||||
AgentInterface(
|
||||
transport=primary_transport,
|
||||
url=agent_card.url,
|
||||
for interface in agent_card_interfaces(agent_card):
|
||||
is_duplicate = any(
|
||||
i.url == interface.url and i.protocol_binding == interface.protocol_binding
|
||||
for i in interfaces
|
||||
)
|
||||
)
|
||||
if not is_duplicate:
|
||||
interfaces.append(interface)
|
||||
|
||||
if agent_card.additional_interfaces:
|
||||
for interface in agent_card.additional_interfaces:
|
||||
is_duplicate = any(
|
||||
i.url == interface.url and i.transport == interface.transport
|
||||
for i in interfaces
|
||||
if not interfaces:
|
||||
interfaces.append(
|
||||
AgentInterface(
|
||||
url=agent_card_url(agent_card),
|
||||
protocol_binding=JSONRPC_TRANSPORT,
|
||||
)
|
||||
if not is_duplicate:
|
||||
interfaces.append(interface)
|
||||
)
|
||||
|
||||
return interfaces
|
||||
|
||||
@@ -149,11 +152,11 @@ def negotiate_transport(
|
||||
)
|
||||
|
||||
server_interfaces = _get_server_interfaces(agent_card)
|
||||
server_transports = [i.transport.upper() for i in server_interfaces]
|
||||
server_transports = [i.protocol_binding.upper() for i in server_interfaces]
|
||||
|
||||
transport_to_interface: dict[str, AgentInterface] = {}
|
||||
for interface in server_interfaces:
|
||||
transport_upper = interface.transport.upper()
|
||||
transport_upper = interface.protocol_binding.upper()
|
||||
if transport_upper not in transport_to_interface:
|
||||
transport_to_interface[transport_upper] = interface
|
||||
|
||||
@@ -162,19 +165,21 @@ def negotiate_transport(
|
||||
if client_preferred and client_preferred in transport_to_interface:
|
||||
interface = transport_to_interface[client_preferred]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
transport=interface.protocol_binding,
|
||||
url=interface.url,
|
||||
source="client_preferred",
|
||||
)
|
||||
else:
|
||||
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
|
||||
server_preferred = (
|
||||
agent_card_preferred_transport(agent_card) or JSONRPC_TRANSPORT
|
||||
).upper()
|
||||
if (
|
||||
server_preferred in client_transports
|
||||
and server_preferred in transport_to_interface
|
||||
):
|
||||
interface = transport_to_interface[server_preferred]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
transport=interface.protocol_binding,
|
||||
url=interface.url,
|
||||
source="server_preferred",
|
||||
)
|
||||
@@ -183,7 +188,7 @@ def negotiate_transport(
|
||||
if transport in transport_to_interface:
|
||||
interface = transport_to_interface[transport]
|
||||
result = NegotiatedTransport(
|
||||
transport=interface.transport,
|
||||
transport=interface.protocol_binding,
|
||||
url=interface.url,
|
||||
source="fallback",
|
||||
)
|
||||
@@ -199,14 +204,14 @@ def negotiate_transport(
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2ATransportNegotiatedEvent(
|
||||
endpoint=endpoint or agent_card.url,
|
||||
endpoint=endpoint or agent_card_url(agent_card),
|
||||
a2a_agent_name=a2a_agent_name or agent_card.name,
|
||||
negotiated_transport=result.transport,
|
||||
negotiated_url=result.url,
|
||||
source=result.source,
|
||||
client_supported_transports=client_transports,
|
||||
server_supported_transports=server_transports,
|
||||
server_preferred_transport=agent_card.preferred_transport
|
||||
server_preferred_transport=agent_card_preferred_transport(agent_card)
|
||||
or JSONRPC_TRANSPORT,
|
||||
client_preferred_transport=client_preferred,
|
||||
),
|
||||
|
||||
@@ -14,9 +14,18 @@ import json
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
|
||||
from a2a.types import Role, TaskState
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
ROLE_AGENT,
|
||||
ROLE_USER,
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_INPUT_REQUIRED,
|
||||
agent_card_to_dict,
|
||||
part_is_text,
|
||||
part_text,
|
||||
proto_to_json,
|
||||
)
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig
|
||||
from crewai.a2a.extensions.base import (
|
||||
A2AExtension,
|
||||
@@ -681,7 +690,7 @@ def _augment_prompt_with_a2a(
|
||||
}
|
||||
agents_text += f"\n{json.dumps(filtered, indent=2)}\n"
|
||||
else:
|
||||
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
|
||||
agents_text += f"\n{proto_to_json(card)}\n"
|
||||
|
||||
failed_agents = failed_agents or {}
|
||||
if failed_agents:
|
||||
@@ -695,7 +704,7 @@ def _augment_prompt_with_a2a(
|
||||
|
||||
if conversation_history:
|
||||
for msg in conversation_history:
|
||||
history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n"
|
||||
history_text += f"\n{proto_to_json(msg)}\n"
|
||||
|
||||
history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute(
|
||||
previous_a2a_conversation=history_text
|
||||
@@ -780,9 +789,9 @@ def _handle_max_turns_exceeded(
|
||||
"""
|
||||
if conversation_history:
|
||||
for msg in reversed(conversation_history):
|
||||
if msg.role == Role.agent:
|
||||
if msg.role == ROLE_AGENT:
|
||||
text_parts = [
|
||||
part.root.text for part in msg.parts if part.root.kind == "text"
|
||||
part_text(part) for part in msg.parts if part_is_text(part)
|
||||
]
|
||||
final_message = (
|
||||
" ".join(text_parts) if text_parts else "Conversation completed"
|
||||
@@ -985,7 +994,9 @@ def _init_delegation_state(
|
||||
reference_task_ids=list(ctx.reference_task_ids),
|
||||
conversation_history=[],
|
||||
agent_card=current_agent_card,
|
||||
agent_card_dict=current_agent_card.model_dump() if current_agent_card else None,
|
||||
agent_card_dict=agent_card_to_dict(current_agent_card)
|
||||
if current_agent_card
|
||||
else None,
|
||||
agent_name=current_agent_card.name if current_agent_card else None,
|
||||
)
|
||||
|
||||
@@ -1110,7 +1121,7 @@ def _handle_task_completion(
|
||||
- remote_notice: Template notice about remote agent response
|
||||
"""
|
||||
remote_notice = ""
|
||||
if a2a_result["status"] == TaskState.completed:
|
||||
if a2a_result["status"] == TASK_STATE_COMPLETED:
|
||||
remote_notice = REMOTE_AGENT_RESPONSE_NOTICE
|
||||
|
||||
if task_id_config is not None and task_id_config not in reference_task_ids:
|
||||
@@ -1294,7 +1305,7 @@ def _delegate_to_a2a(
|
||||
extensions=ctx.extensions,
|
||||
conversation_history=conversation_history,
|
||||
agent_id=ctx.agent_id,
|
||||
agent_role=Role.user,
|
||||
agent_role=ROLE_USER,
|
||||
agent_branch=agent_branch,
|
||||
response_model=ctx.agent_config.response_model,
|
||||
turn_number=turn_num + 1,
|
||||
@@ -1316,7 +1327,10 @@ def _delegate_to_a2a(
|
||||
if latest_message.context_id is not None:
|
||||
context_id = latest_message.context_id
|
||||
|
||||
if a2a_result["status"] in [TaskState.completed, TaskState.input_required]:
|
||||
if a2a_result["status"] in [
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_INPUT_REQUIRED,
|
||||
]:
|
||||
trusted_result, task_id, reference_task_ids, remote_notice = (
|
||||
_handle_task_completion(
|
||||
a2a_result,
|
||||
@@ -1649,7 +1663,7 @@ async def _adelegate_to_a2a(
|
||||
extensions=ctx.extensions,
|
||||
conversation_history=conversation_history,
|
||||
agent_id=ctx.agent_id,
|
||||
agent_role=Role.user,
|
||||
agent_role=ROLE_USER,
|
||||
agent_branch=agent_branch,
|
||||
response_model=ctx.agent_config.response_model,
|
||||
turn_number=turn_num + 1,
|
||||
@@ -1671,7 +1685,10 @@ async def _adelegate_to_a2a(
|
||||
if latest_message.context_id is not None:
|
||||
context_id = latest_message.context_id
|
||||
|
||||
if a2a_result["status"] in [TaskState.completed, TaskState.input_required]:
|
||||
if a2a_result["status"] in [
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_INPUT_REQUIRED,
|
||||
]:
|
||||
trusted_result, task_id, reference_task_ids, remote_notice = (
|
||||
_handle_task_completion(
|
||||
a2a_result,
|
||||
|
||||
@@ -3,12 +3,23 @@ from __future__ import annotations
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from a2a.client import ClientFactory
|
||||
from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart
|
||||
from a2a.client import A2ACardResolver, ClientFactory
|
||||
from a2a.types import AgentCapabilities, AgentCard, AgentInterface, Message, Part, Role, TaskState
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
ROLE_AGENT,
|
||||
ROLE_USER,
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_FAILED,
|
||||
agent_card_url,
|
||||
make_send_request,
|
||||
new_text_message,
|
||||
new_text_part,
|
||||
)
|
||||
from crewai.a2a.updates.polling.handler import PollingHandler
|
||||
from crewai.a2a.updates.streaming.handler import StreamingHandler
|
||||
|
||||
@@ -17,27 +28,31 @@ A2A_TEST_ENDPOINT = os.getenv("A2A_TEST_ENDPOINT", "http://localhost:9999")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def a2a_client():
|
||||
async def card_resolver():
|
||||
"""Create an A2ACardResolver for the test server."""
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
resolver = A2ACardResolver(http_client, A2A_TEST_ENDPOINT)
|
||||
yield resolver
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def agent_card(card_resolver) -> AgentCard:
|
||||
"""Fetch the real agent card from the server."""
|
||||
return await card_resolver.get_agent_card()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def a2a_client(agent_card):
|
||||
"""Create A2A client for test server."""
|
||||
client = await ClientFactory.connect(A2A_TEST_ENDPOINT)
|
||||
factory = ClientFactory()
|
||||
client = factory.create(agent_card)
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_message() -> Message:
|
||||
"""Create a simple test message."""
|
||||
return Message(
|
||||
role=Role.user,
|
||||
parts=[Part(root=TextPart(text="What is 2 + 2?"))],
|
||||
message_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def agent_card(a2a_client) -> AgentCard:
|
||||
"""Fetch the real agent card from the server."""
|
||||
return await a2a_client.get_card()
|
||||
return new_text_message("What is 2 + 2?", role=ROLE_USER)
|
||||
|
||||
|
||||
class TestA2AAgentCardFetching:
|
||||
@@ -45,13 +60,13 @@ class TestA2AAgentCardFetching:
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_agent_card(self, a2a_client) -> None:
|
||||
async def test_fetch_agent_card(self, card_resolver) -> None:
|
||||
"""Test fetching an agent card from the server."""
|
||||
card = await a2a_client.get_card()
|
||||
card = await card_resolver.get_agent_card()
|
||||
|
||||
assert card is not None
|
||||
assert card.name == "GPT Assistant"
|
||||
assert card.url is not None
|
||||
assert card.supported_interfaces is not None
|
||||
assert card.capabilities is not None
|
||||
assert card.capabilities.streaming is True
|
||||
|
||||
@@ -80,7 +95,7 @@ class TestA2APollingIntegration:
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["status"] == TaskState.completed
|
||||
assert result["status"] == TASK_STATE_COMPLETED
|
||||
assert result.get("result") is not None
|
||||
assert "4" in result["result"]
|
||||
|
||||
@@ -104,11 +119,11 @@ class TestA2AStreamingIntegration:
|
||||
message=test_message,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
endpoint=agent_card.url,
|
||||
endpoint=agent_card_url(agent_card),
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["status"] == TaskState.completed
|
||||
assert result["status"] == TASK_STATE_COMPLETED
|
||||
assert result.get("result") is not None
|
||||
|
||||
|
||||
@@ -123,19 +138,19 @@ class TestA2ATaskOperations:
|
||||
test_message: Message,
|
||||
) -> None:
|
||||
"""Test sending a message and getting a response."""
|
||||
from a2a.types import Task
|
||||
from a2a.types import StreamResponse, Task
|
||||
|
||||
from crewai.a2a._compat import is_stream_task
|
||||
|
||||
final_task: Task | None = None
|
||||
async for event in a2a_client.send_message(test_message):
|
||||
if isinstance(event, tuple) and len(event) >= 1:
|
||||
task, _ = event
|
||||
if isinstance(task, Task):
|
||||
final_task = task
|
||||
async for event in a2a_client.send_message(make_send_request(test_message)):
|
||||
if isinstance(event, StreamResponse) and is_stream_task(event):
|
||||
final_task = event.task
|
||||
|
||||
assert final_task is not None
|
||||
assert final_task.id is not None
|
||||
assert final_task.id != ""
|
||||
assert final_task.status is not None
|
||||
assert final_task.status.state == TaskState.completed
|
||||
assert final_task.status.state == TaskState.TASK_STATE_COMPLETED
|
||||
|
||||
|
||||
class TestA2APushNotificationHandler:
|
||||
@@ -148,17 +163,19 @@ class TestA2APushNotificationHandler:
|
||||
@pytest.fixture
|
||||
def mock_agent_card(self) -> AgentCard:
|
||||
"""Create a minimal valid agent card for testing."""
|
||||
from a2a.types import AgentCapabilities
|
||||
|
||||
return AgentCard(
|
||||
name="Test Agent",
|
||||
description="Test agent for push notification tests",
|
||||
url="http://localhost:9999",
|
||||
supported_interfaces=[
|
||||
AgentInterface(
|
||||
url="http://localhost:9999",
|
||||
protocol_binding="JSONRPC",
|
||||
),
|
||||
],
|
||||
version="1.0.0",
|
||||
capabilities=AgentCapabilities(streaming=True, push_notifications=True),
|
||||
default_input_modes=["text"],
|
||||
default_output_modes=["text"],
|
||||
skills=[],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
@@ -169,7 +186,7 @@ class TestA2APushNotificationHandler:
|
||||
return Task(
|
||||
id="task-123",
|
||||
context_id="ctx-123",
|
||||
status=TaskStatus(state=TaskState.working),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -181,7 +198,7 @@ class TestA2APushNotificationHandler:
|
||||
"""Test that push handler waits for result from store."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from a2a.types import Task, TaskStatus
|
||||
from a2a.types import StreamResponse, Task, TaskStatus
|
||||
from pydantic import AnyHttpUrl
|
||||
|
||||
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
@@ -190,15 +207,14 @@ class TestA2APushNotificationHandler:
|
||||
completed_task = Task(
|
||||
id="task-123",
|
||||
context_id="ctx-123",
|
||||
status=TaskStatus(state=TaskState.completed),
|
||||
history=[],
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
)
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.wait_for_result = AsyncMock(return_value=completed_task)
|
||||
|
||||
async def mock_send_message(*args, **kwargs):
|
||||
yield (mock_task, None)
|
||||
yield StreamResponse(task=mock_task)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.send_message = mock_send_message
|
||||
@@ -209,11 +225,7 @@ class TestA2APushNotificationHandler:
|
||||
result_store=mock_store,
|
||||
)
|
||||
|
||||
test_msg = Message(
|
||||
role=Role.user,
|
||||
parts=[Part(root=TextPart(text="What is 2+2?"))],
|
||||
message_id="msg-001",
|
||||
)
|
||||
test_msg = new_text_message("What is 2+2?", role=ROLE_USER)
|
||||
|
||||
new_messages: list[Message] = []
|
||||
|
||||
@@ -226,7 +238,7 @@ class TestA2APushNotificationHandler:
|
||||
result_store=mock_store,
|
||||
polling_timeout=30.0,
|
||||
polling_interval=1.0,
|
||||
endpoint=mock_agent_card.url,
|
||||
endpoint=agent_card_url(mock_agent_card),
|
||||
)
|
||||
|
||||
mock_store.wait_for_result.assert_called_once_with(
|
||||
@@ -235,7 +247,7 @@ class TestA2APushNotificationHandler:
|
||||
poll_interval=1.0,
|
||||
)
|
||||
|
||||
assert result["status"] == TaskState.completed
|
||||
assert result["status"] == TASK_STATE_COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_handler_returns_failure_on_timeout(
|
||||
@@ -245,7 +257,7 @@ class TestA2APushNotificationHandler:
|
||||
"""Test that push handler returns failure when result store times out."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from a2a.types import Task, TaskStatus
|
||||
from a2a.types import StreamResponse, Task, TaskStatus
|
||||
from pydantic import AnyHttpUrl
|
||||
|
||||
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
@@ -257,11 +269,11 @@ class TestA2APushNotificationHandler:
|
||||
working_task = Task(
|
||||
id="task-456",
|
||||
context_id="ctx-456",
|
||||
status=TaskStatus(state=TaskState.working),
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
|
||||
)
|
||||
|
||||
async def mock_send_message(*args, **kwargs):
|
||||
yield (working_task, None)
|
||||
yield StreamResponse(task=working_task)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.send_message = mock_send_message
|
||||
@@ -272,11 +284,7 @@ class TestA2APushNotificationHandler:
|
||||
result_store=mock_store,
|
||||
)
|
||||
|
||||
test_msg = Message(
|
||||
role=Role.user,
|
||||
parts=[Part(root=TextPart(text="test"))],
|
||||
message_id="msg-002",
|
||||
)
|
||||
test_msg = new_text_message("test", role=ROLE_USER)
|
||||
|
||||
new_messages: list[Message] = []
|
||||
|
||||
@@ -289,10 +297,10 @@ class TestA2APushNotificationHandler:
|
||||
result_store=mock_store,
|
||||
polling_timeout=5.0,
|
||||
polling_interval=0.5,
|
||||
endpoint=mock_agent_card.url,
|
||||
endpoint=agent_card_url(mock_agent_card),
|
||||
)
|
||||
|
||||
assert result["status"] == TaskState.failed
|
||||
assert result["status"] == TASK_STATE_FAILED
|
||||
assert "timeout" in result.get("error", "").lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -307,11 +315,7 @@ class TestA2APushNotificationHandler:
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
test_msg = Message(
|
||||
role=Role.user,
|
||||
parts=[Part(root=TextPart(text="test"))],
|
||||
message_id="msg-003",
|
||||
)
|
||||
test_msg = new_text_message("test", role=ROLE_USER)
|
||||
|
||||
new_messages: list[Message] = []
|
||||
|
||||
@@ -320,8 +324,8 @@ class TestA2APushNotificationHandler:
|
||||
message=test_msg,
|
||||
new_messages=new_messages,
|
||||
agent_card=mock_agent_card,
|
||||
endpoint=mock_agent_card.url,
|
||||
endpoint=agent_card_url(mock_agent_card),
|
||||
)
|
||||
|
||||
assert result["status"] == TaskState.failed
|
||||
assert result["status"] == TASK_STATE_FAILED
|
||||
assert "config" in result.get("error", "").lower()
|
||||
|
||||
517
lib/crewai/tests/a2a/test_a2a_sdk_v1_compat.py
Normal file
517
lib/crewai/tests/a2a/test_a2a_sdk_v1_compat.py
Normal file
@@ -0,0 +1,517 @@
|
||||
"""Tests for a2a-sdk v1.0 compatibility.
|
||||
|
||||
These tests validate that crewai.a2a modules correctly import and work with
|
||||
a2a-sdk v1.0.x (protobuf-based types). They cover the core issue described
|
||||
in https://github.com/crewAIInc/crewAI/issues/5607:
|
||||
|
||||
ImportError: cannot import name 'A2AClientHTTPError' from 'a2a.client.errors'
|
||||
|
||||
The migration from a2a-sdk ~0.3.10 to >=1.0.0,<2 introduced major breaking
|
||||
changes including protobuf-based types, renamed error classes, and new enum
|
||||
value conventions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestSdkV1Imports:
|
||||
"""Verify that old v0.3 names no longer exist, and our compat layer works."""
|
||||
|
||||
def test_a2a_client_error_importable(self) -> None:
|
||||
"""A2AClientError (renamed from A2AClientHTTPError) should be importable."""
|
||||
from a2a.client.errors import A2AClientError
|
||||
|
||||
assert A2AClientError is not None
|
||||
|
||||
def test_old_a2a_client_http_error_removed(self) -> None:
|
||||
"""A2AClientHTTPError no longer exists in a2a-sdk v1.0."""
|
||||
with pytest.raises(ImportError):
|
||||
from a2a.client.errors import A2AClientHTTPError # noqa: F401
|
||||
|
||||
def test_compat_alias_maps_to_new_error(self) -> None:
|
||||
"""Our _compat alias should map to the new error class."""
|
||||
from a2a.client.errors import A2AClientError
|
||||
|
||||
from crewai.a2a._compat import A2AClientHTTPError
|
||||
|
||||
assert A2AClientHTTPError is A2AClientError
|
||||
|
||||
def test_text_part_removed_in_v1(self) -> None:
|
||||
"""TextPart no longer exists as a separate type in a2a-sdk v1.0."""
|
||||
with pytest.raises(ImportError):
|
||||
from a2a.types import TextPart # noqa: F401
|
||||
|
||||
def test_protobuf_types_importable(self) -> None:
|
||||
"""Key protobuf types should be importable from a2a.types."""
|
||||
from a2a.types import ( # noqa: F401
|
||||
AgentCapabilities,
|
||||
AgentCard,
|
||||
AgentInterface,
|
||||
GetTaskRequest,
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
StreamResponse,
|
||||
SubscribeToTaskRequest,
|
||||
Task,
|
||||
TaskPushNotificationConfig,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
|
||||
|
||||
class TestCompatLayer:
|
||||
"""Tests for the crewai.a2a._compat compatibility layer."""
|
||||
|
||||
def test_role_constants(self) -> None:
|
||||
"""ROLE_USER and ROLE_AGENT should be valid Role enum values."""
|
||||
from a2a.types import Role
|
||||
|
||||
from crewai.a2a._compat import ROLE_AGENT, ROLE_USER
|
||||
|
||||
assert ROLE_USER == Role.ROLE_USER
|
||||
assert ROLE_AGENT == Role.ROLE_AGENT
|
||||
|
||||
def test_task_state_constants(self) -> None:
|
||||
"""TASK_STATE_* should be valid TaskState enum values."""
|
||||
from a2a.types import TaskState
|
||||
|
||||
from crewai.a2a._compat import (
|
||||
TASK_STATE_CANCELED,
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_FAILED,
|
||||
TASK_STATE_INPUT_REQUIRED,
|
||||
TASK_STATE_REJECTED,
|
||||
TASK_STATE_SUBMITTED,
|
||||
TASK_STATE_WORKING,
|
||||
)
|
||||
|
||||
assert TASK_STATE_SUBMITTED == TaskState.TASK_STATE_SUBMITTED
|
||||
assert TASK_STATE_WORKING == TaskState.TASK_STATE_WORKING
|
||||
assert TASK_STATE_COMPLETED == TaskState.TASK_STATE_COMPLETED
|
||||
assert TASK_STATE_FAILED == TaskState.TASK_STATE_FAILED
|
||||
assert TASK_STATE_CANCELED == TaskState.TASK_STATE_CANCELED
|
||||
assert TASK_STATE_INPUT_REQUIRED == TaskState.TASK_STATE_INPUT_REQUIRED
|
||||
assert TASK_STATE_REJECTED == TaskState.TASK_STATE_REJECTED
|
||||
|
||||
def test_terminal_states(self) -> None:
|
||||
"""TERMINAL_STATES should include completed, failed, rejected, canceled."""
|
||||
from crewai.a2a._compat import (
|
||||
TASK_STATE_CANCELED,
|
||||
TASK_STATE_COMPLETED,
|
||||
TASK_STATE_FAILED,
|
||||
TASK_STATE_REJECTED,
|
||||
TERMINAL_STATES,
|
||||
)
|
||||
|
||||
assert TASK_STATE_COMPLETED in TERMINAL_STATES
|
||||
assert TASK_STATE_FAILED in TERMINAL_STATES
|
||||
assert TASK_STATE_REJECTED in TERMINAL_STATES
|
||||
assert TASK_STATE_CANCELED in TERMINAL_STATES
|
||||
|
||||
|
||||
class TestPartHelpers:
|
||||
"""Tests for protobuf Part helpers."""
|
||||
|
||||
def test_new_text_part(self) -> None:
|
||||
"""new_text_part should create a Part with text field set."""
|
||||
from crewai.a2a._compat import new_text_part, part_is_text, part_text
|
||||
|
||||
part = new_text_part("hello world")
|
||||
assert part_is_text(part)
|
||||
assert part_text(part) == "hello world"
|
||||
|
||||
def test_part_is_text_false_for_non_text(self) -> None:
|
||||
"""part_is_text should return False for non-text parts."""
|
||||
from a2a.types import Part
|
||||
from google.protobuf.struct_pb2 import Value
|
||||
|
||||
from crewai.a2a._compat import part_is_text
|
||||
|
||||
v = Value()
|
||||
v.string_value = "test"
|
||||
part = Part(data=v)
|
||||
assert not part_is_text(part)
|
||||
|
||||
def test_part_has_data(self) -> None:
|
||||
"""part_has_data should detect data parts."""
|
||||
from a2a.types import Part
|
||||
from google.protobuf.struct_pb2 import Value
|
||||
|
||||
from crewai.a2a._compat import part_has_data
|
||||
|
||||
v = Value()
|
||||
v.string_value = "test"
|
||||
part = Part(data=v)
|
||||
assert part_has_data(part)
|
||||
|
||||
def test_part_has_file(self) -> None:
|
||||
"""part_has_file should detect raw/url file parts."""
|
||||
from a2a.types import Part
|
||||
|
||||
from crewai.a2a._compat import part_has_file
|
||||
|
||||
raw_part = Part(raw=b"file content", media_type="application/pdf")
|
||||
assert part_has_file(raw_part)
|
||||
|
||||
url_part = Part(url="https://example.com/file.pdf", media_type="application/pdf")
|
||||
assert part_has_file(url_part)
|
||||
|
||||
|
||||
class TestMessageHelpers:
|
||||
"""Tests for protobuf Message helpers."""
|
||||
|
||||
def test_new_text_message(self) -> None:
|
||||
"""new_text_message should create a Message with a text Part."""
|
||||
from crewai.a2a._compat import (
|
||||
ROLE_USER,
|
||||
new_text_message,
|
||||
part_is_text,
|
||||
part_text,
|
||||
)
|
||||
|
||||
msg = new_text_message("test message", role=ROLE_USER)
|
||||
assert msg.role == ROLE_USER
|
||||
assert len(msg.parts) == 1
|
||||
assert part_is_text(msg.parts[0])
|
||||
assert part_text(msg.parts[0]) == "test message"
|
||||
|
||||
def test_new_text_message_with_context_and_task(self) -> None:
|
||||
"""new_text_message should accept context_id and task_id."""
|
||||
from crewai.a2a._compat import ROLE_AGENT, new_text_message
|
||||
|
||||
msg = new_text_message(
|
||||
"response",
|
||||
role=ROLE_AGENT,
|
||||
context_id="ctx-123",
|
||||
task_id="task-456",
|
||||
)
|
||||
assert msg.context_id == "ctx-123"
|
||||
assert msg.task_id == "task-456"
|
||||
|
||||
|
||||
class TestAgentCardHelpers:
|
||||
"""Tests for protobuf AgentCard helpers."""
|
||||
|
||||
def test_agent_card_to_dict(self) -> None:
|
||||
"""agent_card_to_dict should serialize an AgentCard to a plain dict."""
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.a2a._compat import agent_card_to_dict
|
||||
|
||||
card = AgentCard(
|
||||
name="Test Agent",
|
||||
description="A test agent",
|
||||
supported_interfaces=[
|
||||
AgentInterface(url="http://localhost:9999", protocol_binding="JSONRPC"),
|
||||
],
|
||||
version="1.0.0",
|
||||
)
|
||||
result = agent_card_to_dict(card)
|
||||
assert isinstance(result, dict)
|
||||
assert result["name"] == "Test Agent"
|
||||
assert result["description"] == "A test agent"
|
||||
|
||||
def test_agent_card_url(self) -> None:
|
||||
"""agent_card_url should return the URL from the first interface."""
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.a2a._compat import agent_card_url
|
||||
|
||||
card = AgentCard(
|
||||
name="Test",
|
||||
supported_interfaces=[
|
||||
AgentInterface(url="http://localhost:9999", protocol_binding="JSONRPC"),
|
||||
],
|
||||
)
|
||||
assert agent_card_url(card) == "http://localhost:9999"
|
||||
|
||||
def test_agent_card_url_empty_when_no_interfaces(self) -> None:
|
||||
"""agent_card_url should return empty string if no interfaces."""
|
||||
from a2a.types import AgentCard
|
||||
|
||||
from crewai.a2a._compat import agent_card_url
|
||||
|
||||
card = AgentCard(name="No Interfaces")
|
||||
assert agent_card_url(card) == ""
|
||||
|
||||
def test_agent_card_preferred_transport(self) -> None:
|
||||
"""agent_card_preferred_transport should return protocol_binding."""
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.a2a._compat import agent_card_preferred_transport
|
||||
|
||||
card = AgentCard(
|
||||
name="Test",
|
||||
supported_interfaces=[
|
||||
AgentInterface(url="http://localhost", protocol_binding="GRPC"),
|
||||
],
|
||||
)
|
||||
assert agent_card_preferred_transport(card) == "GRPC"
|
||||
|
||||
def test_agent_card_interfaces(self) -> None:
|
||||
"""agent_card_interfaces should return all interfaces."""
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.a2a._compat import agent_card_interfaces
|
||||
|
||||
card = AgentCard(
|
||||
name="Test",
|
||||
supported_interfaces=[
|
||||
AgentInterface(url="http://a.com", protocol_binding="JSONRPC"),
|
||||
AgentInterface(url="http://b.com", protocol_binding="GRPC"),
|
||||
],
|
||||
)
|
||||
interfaces = agent_card_interfaces(card)
|
||||
assert len(interfaces) == 2
|
||||
|
||||
def test_agent_card_protocol_version(self) -> None:
|
||||
"""agent_card_protocol_version should return protocol version from first interface."""
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.a2a._compat import agent_card_protocol_version
|
||||
|
||||
card = AgentCard(
|
||||
name="Test",
|
||||
supported_interfaces=[
|
||||
AgentInterface(
|
||||
url="http://localhost",
|
||||
protocol_binding="JSONRPC",
|
||||
protocol_version="0.3",
|
||||
),
|
||||
],
|
||||
)
|
||||
assert agent_card_protocol_version(card) == "0.3"
|
||||
|
||||
|
||||
class TestProtoCopy:
|
||||
"""Tests for protobuf deep copy helper."""
|
||||
|
||||
def test_proto_copy_creates_independent_copy(self) -> None:
|
||||
"""proto_copy should create a deep copy of a protobuf message."""
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.a2a._compat import proto_copy
|
||||
|
||||
original = AgentCard(
|
||||
name="Original",
|
||||
supported_interfaces=[
|
||||
AgentInterface(url="http://original.com", protocol_binding="JSONRPC"),
|
||||
],
|
||||
)
|
||||
copy = proto_copy(original)
|
||||
copy.name = "Modified"
|
||||
|
||||
assert original.name == "Original"
|
||||
assert copy.name == "Modified"
|
||||
|
||||
|
||||
class TestStreamResponseHelpers:
|
||||
"""Tests for StreamResponse event helpers."""
|
||||
|
||||
def test_is_stream_message(self) -> None:
|
||||
"""is_stream_message should detect messages in StreamResponse."""
|
||||
from a2a.types import Message, StreamResponse
|
||||
|
||||
from crewai.a2a._compat import ROLE_AGENT, is_stream_message, new_text_part
|
||||
|
||||
msg = Message(
|
||||
role=ROLE_AGENT,
|
||||
parts=[new_text_part("hello")],
|
||||
message_id=str(uuid.uuid4()),
|
||||
)
|
||||
sr = StreamResponse(message=msg)
|
||||
assert is_stream_message(sr)
|
||||
|
||||
def test_is_stream_task(self) -> None:
|
||||
"""is_stream_task should detect tasks in StreamResponse."""
|
||||
from a2a.types import StreamResponse, Task, TaskState, TaskStatus
|
||||
|
||||
from crewai.a2a._compat import is_stream_task
|
||||
|
||||
task = Task(
|
||||
id="task-1",
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
)
|
||||
sr = StreamResponse(task=task)
|
||||
assert is_stream_task(sr)
|
||||
|
||||
def test_is_stream_status_update(self) -> None:
|
||||
"""is_stream_status_update should detect status updates."""
|
||||
from a2a.types import StreamResponse, TaskState, TaskStatus, TaskStatusUpdateEvent
|
||||
|
||||
from crewai.a2a._compat import is_stream_status_update
|
||||
|
||||
update = TaskStatusUpdateEvent(
|
||||
task_id="task-1",
|
||||
context_id="ctx-1",
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
|
||||
)
|
||||
sr = StreamResponse(status_update=update)
|
||||
assert is_stream_status_update(sr)
|
||||
|
||||
|
||||
class TestStatusUpdateFinality:
|
||||
"""Tests for status update finality detection."""
|
||||
|
||||
def test_completed_is_final(self) -> None:
|
||||
"""Completed status should be final."""
|
||||
from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent
|
||||
|
||||
from crewai.a2a._compat import is_status_update_final
|
||||
|
||||
update = TaskStatusUpdateEvent(
|
||||
task_id="t1",
|
||||
context_id="c1",
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
|
||||
)
|
||||
assert is_status_update_final(update) is True
|
||||
|
||||
def test_working_is_not_final(self) -> None:
|
||||
"""Working status should not be final."""
|
||||
from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent
|
||||
|
||||
from crewai.a2a._compat import is_status_update_final
|
||||
|
||||
update = TaskStatusUpdateEvent(
|
||||
task_id="t1",
|
||||
context_id="c1",
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
|
||||
)
|
||||
assert is_status_update_final(update) is False
|
||||
|
||||
def test_failed_is_final(self) -> None:
|
||||
"""Failed status should be final."""
|
||||
from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent
|
||||
|
||||
from crewai.a2a._compat import is_status_update_final
|
||||
|
||||
update = TaskStatusUpdateEvent(
|
||||
task_id="t1",
|
||||
context_id="c1",
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_FAILED),
|
||||
)
|
||||
assert is_status_update_final(update) is True
|
||||
|
||||
|
||||
class TestClientConfigHelper:
|
||||
"""Tests for client configuration helper."""
|
||||
|
||||
def test_create_client_config(self) -> None:
|
||||
"""create_client_config should produce a valid ClientConfig."""
|
||||
from crewai.a2a._compat import create_client_config
|
||||
|
||||
config = create_client_config(
|
||||
supported_transports=["JSONRPC", "GRPC"],
|
||||
streaming=True,
|
||||
polling=False,
|
||||
)
|
||||
assert config.supported_protocol_bindings == ["JSONRPC", "GRPC"]
|
||||
assert config.streaming is True
|
||||
assert config.polling is False
|
||||
|
||||
|
||||
class TestProtoToJson:
|
||||
"""Tests for proto_to_json serialization."""
|
||||
|
||||
def test_proto_to_json(self) -> None:
|
||||
"""proto_to_json should serialize a protobuf to JSON string."""
|
||||
from a2a.types import AgentCard, AgentInterface
|
||||
|
||||
from crewai.a2a._compat import proto_to_json
|
||||
|
||||
card = AgentCard(
|
||||
name="Test Agent",
|
||||
supported_interfaces=[
|
||||
AgentInterface(url="http://localhost:9999", protocol_binding="JSONRPC"),
|
||||
],
|
||||
)
|
||||
json_str = proto_to_json(card)
|
||||
assert isinstance(json_str, str)
|
||||
assert "Test Agent" in json_str
|
||||
|
||||
|
||||
class TestModuleImports:
|
||||
"""Verify all crewai.a2a submodules import without error under v1.0."""
|
||||
|
||||
def test_import_compat(self) -> None:
|
||||
from crewai.a2a._compat import A2AClientHTTPError # noqa: F401
|
||||
|
||||
def test_import_task_helpers(self) -> None:
|
||||
from crewai.a2a.task_helpers import process_task_state # noqa: F401
|
||||
|
||||
def test_import_polling_handler(self) -> None:
|
||||
from crewai.a2a.updates.polling.handler import PollingHandler # noqa: F401
|
||||
|
||||
def test_import_streaming_handler(self) -> None:
|
||||
from crewai.a2a.updates.streaming.handler import StreamingHandler # noqa: F401
|
||||
|
||||
def test_import_push_handler(self) -> None:
|
||||
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler # noqa: F401
|
||||
|
||||
def test_import_auth_utils(self) -> None:
|
||||
from crewai.a2a.auth.utils import validate_auth_against_agent_card # noqa: F401
|
||||
|
||||
def test_import_delegation(self) -> None:
|
||||
from crewai.a2a.utils.delegation import execute_a2a_delegation # noqa: F401
|
||||
|
||||
def test_import_transport(self) -> None:
|
||||
from crewai.a2a.utils.transport import negotiate_transport # noqa: F401
|
||||
|
||||
def test_import_agent_card(self) -> None:
|
||||
from crewai.a2a.utils.agent_card import afetch_agent_card # noqa: F401
|
||||
|
||||
def test_import_agent_card_signing(self) -> None:
|
||||
from crewai.a2a.utils.agent_card_signing import sign_agent_card # noqa: F401
|
||||
|
||||
def test_import_wrapper(self) -> None:
|
||||
from crewai.a2a.wrapper import wrap_agent_with_a2a_instance # noqa: F401
|
||||
|
||||
def test_import_extensions_registry(self) -> None:
|
||||
from crewai.a2a.extensions.registry import ExtensionsMiddleware # noqa: F401
|
||||
|
||||
def test_import_content_type(self) -> None:
|
||||
from crewai.a2a.utils.content_type import get_part_content_type # noqa: F401
|
||||
|
||||
|
||||
class TestGetPartContentType:
|
||||
"""Tests for get_part_content_type with v1.0 protobuf Parts."""
|
||||
|
||||
def test_text_part_returns_text_plain(self) -> None:
|
||||
from a2a.types import Part
|
||||
|
||||
from crewai.a2a.utils.content_type import get_part_content_type
|
||||
|
||||
part = Part(text="hello")
|
||||
assert get_part_content_type(part) == "text/plain"
|
||||
|
||||
def test_data_part_returns_application_json(self) -> None:
|
||||
from a2a.types import Part
|
||||
from google.protobuf.struct_pb2 import Value
|
||||
|
||||
from crewai.a2a.utils.content_type import get_part_content_type
|
||||
|
||||
v = Value()
|
||||
v.string_value = "test"
|
||||
part = Part(data=v)
|
||||
assert get_part_content_type(part) == "application/json"
|
||||
|
||||
def test_raw_part_returns_media_type(self) -> None:
|
||||
from a2a.types import Part
|
||||
|
||||
from crewai.a2a.utils.content_type import get_part_content_type
|
||||
|
||||
part = Part(raw=b"pdf content", media_type="application/pdf")
|
||||
assert get_part_content_type(part) == "application/pdf"
|
||||
|
||||
def test_url_part_returns_media_type(self) -> None:
|
||||
from a2a.types import Part
|
||||
|
||||
from crewai.a2a.utils.content_type import get_part_content_type
|
||||
|
||||
part = Part(url="https://example.com/image.png", media_type="image/png")
|
||||
assert get_part_content_type(part) == "image/png"
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from a2a.types import AgentCard, AgentSkill
|
||||
|
||||
from crewai import Agent
|
||||
from crewai.a2a._compat import agent_card_to_dict, agent_card_url, proto_to_json
|
||||
from crewai.a2a.config import A2AClientConfig, A2AServerConfig
|
||||
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
|
||||
|
||||
@@ -154,7 +155,7 @@ class TestToAgentCard:
|
||||
|
||||
card = agent.to_agent_card("http://my-server.com:9000")
|
||||
|
||||
assert card.url == "http://my-server.com:9000"
|
||||
assert agent_card_url(card) == "http://my-server.com:9000"
|
||||
|
||||
def test_uses_server_config_url(self) -> None:
|
||||
"""AgentCard url should prefer A2AServerConfig.url over provided URL."""
|
||||
@@ -167,7 +168,8 @@ class TestToAgentCard:
|
||||
|
||||
card = agent.to_agent_card("http://fallback-url.com")
|
||||
|
||||
assert card.url == "http://configured-url.com/"
|
||||
url = agent_card_url(card)
|
||||
assert url.rstrip("/") == "http://configured-url.com"
|
||||
|
||||
def test_generates_default_skill(self) -> None:
|
||||
"""AgentCard should have at least one skill based on agent role."""
|
||||
@@ -246,16 +248,16 @@ class TestAgentCardJsonStructure:
|
||||
)
|
||||
|
||||
card = agent.to_agent_card("http://localhost:8000")
|
||||
json_data = card.model_dump()
|
||||
json_data = agent_card_to_dict(card)
|
||||
|
||||
assert "name" in json_data
|
||||
assert "description" in json_data
|
||||
assert "url" in json_data
|
||||
assert "supported_interfaces" in json_data
|
||||
assert "version" in json_data
|
||||
assert "skills" in json_data
|
||||
assert "capabilities" in json_data
|
||||
assert "defaultInputModes" in json_data
|
||||
assert "defaultOutputModes" in json_data
|
||||
assert "default_input_modes" in json_data
|
||||
assert "default_output_modes" in json_data
|
||||
|
||||
def test_json_skills_structure(self) -> None:
|
||||
"""Each skill in JSON should have required fields."""
|
||||
@@ -267,7 +269,7 @@ class TestAgentCardJsonStructure:
|
||||
)
|
||||
|
||||
card = agent.to_agent_card("http://localhost:8000")
|
||||
json_data = card.model_dump()
|
||||
json_data = agent_card_to_dict(card)
|
||||
|
||||
assert len(json_data["skills"]) >= 1
|
||||
skill = json_data["skills"][0]
|
||||
@@ -286,11 +288,11 @@ class TestAgentCardJsonStructure:
|
||||
)
|
||||
|
||||
card = agent.to_agent_card("http://localhost:8000")
|
||||
json_data = card.model_dump()
|
||||
json_data = agent_card_to_dict(card)
|
||||
|
||||
capabilities = json_data["capabilities"]
|
||||
assert "streaming" in capabilities
|
||||
assert "pushNotifications" in capabilities
|
||||
assert "push_notifications" in capabilities
|
||||
|
||||
def test_json_serializable(self) -> None:
|
||||
"""AgentCard should be JSON serializable."""
|
||||
@@ -302,14 +304,14 @@ class TestAgentCardJsonStructure:
|
||||
)
|
||||
|
||||
card = agent.to_agent_card("http://localhost:8000")
|
||||
json_str = card.model_dump_json()
|
||||
json_str = proto_to_json(card)
|
||||
|
||||
assert isinstance(json_str, str)
|
||||
assert "Test Agent" in json_str
|
||||
assert "http://localhost:8000" in json_str
|
||||
|
||||
def test_json_excludes_none_values(self) -> None:
|
||||
"""AgentCard JSON with exclude_none should omit None fields."""
|
||||
def test_json_excludes_unset_fields(self) -> None:
|
||||
"""AgentCard JSON should omit fields that were not explicitly set."""
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
@@ -318,8 +320,6 @@ class TestAgentCardJsonStructure:
|
||||
)
|
||||
|
||||
card = agent.to_agent_card("http://localhost:8000")
|
||||
json_data = card.model_dump(exclude_none=True)
|
||||
json_data = agent_card_to_dict(card)
|
||||
|
||||
assert "provider" not in json_data
|
||||
assert "documentationUrl" not in json_data
|
||||
assert "iconUrl" not in json_data
|
||||
|
||||
@@ -12,6 +12,7 @@ from a2a.server.agent_execution import RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus
|
||||
|
||||
from crewai.a2a._compat import TASK_STATE_CANCELED, TASK_STATE_WORKING
|
||||
from crewai.a2a.utils.task import cancel, cancellable, execute
|
||||
|
||||
|
||||
@@ -38,12 +39,13 @@ def mock_task(mock_context: MagicMock) -> MagicMock:
|
||||
@pytest.fixture
|
||||
def mock_context() -> MagicMock:
|
||||
"""Create a mock RequestContext."""
|
||||
from crewai.a2a._compat import ROLE_USER, new_text_message
|
||||
|
||||
context = MagicMock(spec=RequestContext)
|
||||
context.task_id = "test-task-123"
|
||||
context.context_id = "test-context-456"
|
||||
context.get_user_input.return_value = "Test user message"
|
||||
context.message = MagicMock(spec=Message)
|
||||
context.message.parts = []
|
||||
context.message = new_text_message("Test user message", role=ROLE_USER)
|
||||
context.current_task = None
|
||||
return context
|
||||
|
||||
@@ -291,8 +293,7 @@ class TestCancel:
|
||||
|
||||
assert event.task_id == mock_context.task_id
|
||||
assert event.context_id == mock_context.context_id
|
||||
assert event.status.state == TaskState.canceled
|
||||
assert event.final is True
|
||||
assert event.status.state == TASK_STATE_CANCELED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_current_task(
|
||||
@@ -315,13 +316,13 @@ class TestCancel:
|
||||
) -> None:
|
||||
"""Cancel returns updated task when context has current_task."""
|
||||
current_task = MagicMock(spec=A2ATask)
|
||||
current_task.status = TaskStatus(state=TaskState.working)
|
||||
current_task.status = TaskStatus(state=TASK_STATE_WORKING)
|
||||
mock_context.current_task = current_task
|
||||
|
||||
result = await cancel(mock_context, mock_event_queue)
|
||||
|
||||
assert result is current_task
|
||||
assert result.status.state == TaskState.canceled
|
||||
assert result.status.state == TASK_STATE_CANCELED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_after_cancel(
|
||||
|
||||
Reference in New Issue
Block a user