diff --git a/lib/crewai/src/crewai/a2a/errors.py b/lib/crewai/src/crewai/a2a/errors.py new file mode 100644 index 000000000..e24e9c296 --- /dev/null +++ b/lib/crewai/src/crewai/a2a/errors.py @@ -0,0 +1,7 @@ +"""A2A protocol error types.""" + +from a2a.client.errors import A2AClientTimeoutError + + +class A2APollingTimeoutError(A2AClientTimeoutError): + """Raised when polling exceeds the configured timeout.""" diff --git a/lib/crewai/src/crewai/a2a/task_helpers.py b/lib/crewai/src/crewai/a2a/task_helpers.py new file mode 100644 index 000000000..26f7201ef --- /dev/null +++ b/lib/crewai/src/crewai/a2a/task_helpers.py @@ -0,0 +1,206 @@ +"""Helper functions for processing A2A task results.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, NotRequired, TypedDict +import uuid + +from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart + +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import A2AResponseReceivedEvent + + +if TYPE_CHECKING: + from a2a.types import Task as A2ATask + + +class TaskStateResult(TypedDict): + """Result dictionary from processing A2A task state.""" + + status: TaskState + history: list[Message] + result: NotRequired[str] + error: NotRequired[str] + agent_card: NotRequired[AgentCard] + + +def extract_task_result_parts(a2a_task: A2ATask) -> list[str]: + """Extract result parts from A2A task status message, history, and artifacts. + + Args: + a2a_task: A2A Task object with status, history, and artifacts + + Returns: + List of result text parts + """ + result_parts: list[str] = [] + + if a2a_task.status and a2a_task.status.message: + msg = a2a_task.status.message + result_parts.extend( + part.root.text for part in msg.parts if part.root.kind == "text" + ) + + if not result_parts and a2a_task.history: + for history_msg in reversed(a2a_task.history): + if history_msg.role == Role.agent: + result_parts.extend( + part.root.text + for part in history_msg.parts + if part.root.kind == "text" + ) + break + + if a2a_task.artifacts: + result_parts.extend( + part.root.text + for artifact in a2a_task.artifacts + for part in artifact.parts + if part.root.kind == "text" + ) + + return result_parts + + +def extract_error_message(a2a_task: A2ATask, default: str) -> str: + """Extract error message from A2A task. + + Args: + a2a_task: A2A Task object + default: Default message if no error found + + Returns: + Error message string + """ + if a2a_task.status and a2a_task.status.message: + msg = a2a_task.status.message + if msg: + for part in msg.parts: + if part.root.kind == "text": + return str(part.root.text) + return str(msg) + + if a2a_task.history: + for history_msg in reversed(a2a_task.history): + for part in history_msg.parts: + if part.root.kind == "text": + return str(part.root.text) + + return default + + +def process_task_state( + a2a_task: A2ATask, + new_messages: list[Message], + agent_card: AgentCard, + turn_number: int, + is_multiturn: bool, + agent_role: str | None, + result_parts: list[str] | None = None, +) -> TaskStateResult | None: + """Process A2A task state and return result dictionary. + + Shared logic for both polling and streaming handlers. + + Args: + a2a_task: The A2A task to process + new_messages: List to collect messages (modified in place) + agent_card: The agent card + turn_number: Current turn number + is_multiturn: Whether multi-turn conversation + agent_role: Agent role for logging + result_parts: Accumulated result parts (streaming passes accumulated, + polling passes None to extract from task) + + Returns: + Result dictionary if terminal/actionable state, None otherwise + """ + if result_parts is None: + result_parts = [] + + if a2a_task.status.state == TaskState.completed: + extracted_parts = extract_task_result_parts(a2a_task) + result_parts.extend(extracted_parts) + if a2a_task.history: + new_messages.extend(a2a_task.history) + + response_text = " ".join(result_parts) if result_parts else "" + crewai_event_bus.emit( + None, + A2AResponseReceivedEvent( + response=response_text, + turn_number=turn_number, + is_multiturn=is_multiturn, + status="completed", + agent_role=agent_role, + ), + ) + + return TaskStateResult( + status=TaskState.completed, + agent_card=agent_card, + result=response_text, + history=new_messages, + ) + + if a2a_task.status.state == TaskState.input_required: + if a2a_task.history: + new_messages.extend(a2a_task.history) + + response_text = extract_error_message(a2a_task, "Additional input required") + if response_text and not a2a_task.history: + agent_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=response_text))], + context_id=getattr(a2a_task, "context_id", None), + task_id=getattr(a2a_task, "task_id", None), + ) + new_messages.append(agent_message) + + crewai_event_bus.emit( + None, + A2AResponseReceivedEvent( + response=response_text, + turn_number=turn_number, + is_multiturn=is_multiturn, + status="input_required", + agent_role=agent_role, + ), + ) + + return TaskStateResult( + status=TaskState.input_required, + error=response_text, + history=new_messages, + agent_card=agent_card, + ) + + if a2a_task.status.state in {TaskState.failed, TaskState.rejected}: + error_msg = extract_error_message(a2a_task, "Task failed without error message") + if a2a_task.history: + new_messages.extend(a2a_task.history) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + if a2a_task.status.state == TaskState.auth_required: + error_msg = extract_error_message(a2a_task, "Authentication required") + return TaskStateResult( + status=TaskState.auth_required, + error=error_msg, + history=new_messages, + ) + + if a2a_task.status.state == TaskState.canceled: + error_msg = extract_error_message(a2a_task, "Task was canceled") + return TaskStateResult( + status=TaskState.canceled, + error=error_msg, + history=new_messages, + ) + + return None