feat: add shared task helpers and error types

This commit is contained in:
Greyson LaLonde
2026-01-05 18:47:28 -05:00
parent 2d09e6bbcd
commit 7589e524ab
2 changed files with 213 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
"""A2A protocol error types."""
from a2a.client.errors import A2AClientTimeoutError
class A2APollingTimeoutError(A2AClientTimeoutError):
"""Raised when polling exceeds the configured timeout."""

View File

@@ -0,0 +1,206 @@
"""Helper functions for processing A2A task results."""
from __future__ import annotations
from typing import TYPE_CHECKING, NotRequired, TypedDict
import uuid
from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state."""
status: TaskState
history: list[Message]
result: NotRequired[str]
error: NotRequired[str]
agent_card: NotRequired[AgentCard]
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
"""Extract result parts from A2A task status message, history, and artifacts.
Args:
a2a_task: A2A Task object with status, history, and artifacts
Returns:
List of result text parts
"""
result_parts: list[str] = []
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
result_parts.extend(
part.root.text for part in msg.parts if part.root.kind == "text"
)
if not result_parts and a2a_task.history:
for history_msg in reversed(a2a_task.history):
if history_msg.role == Role.agent:
result_parts.extend(
part.root.text
for part in history_msg.parts
if part.root.kind == "text"
)
break
if a2a_task.artifacts:
result_parts.extend(
part.root.text
for artifact in a2a_task.artifacts
for part in artifact.parts
if part.root.kind == "text"
)
return result_parts
def extract_error_message(a2a_task: A2ATask, default: str) -> str:
"""Extract error message from A2A task.
Args:
a2a_task: A2A Task object
default: Default message if no error found
Returns:
Error message string
"""
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
if msg:
for part in msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return str(msg)
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
for part in history_msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return default
def process_task_state(
a2a_task: A2ATask,
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
result_parts: list[str] | None = None,
) -> TaskStateResult | None:
"""Process A2A task state and return result dictionary.
Shared logic for both polling and streaming handlers.
Args:
a2a_task: The A2A task to process
new_messages: List to collect messages (modified in place)
agent_card: The agent card
turn_number: Current turn number
is_multiturn: Whether multi-turn conversation
agent_role: Agent role for logging
result_parts: Accumulated result parts (streaming passes accumulated,
polling passes None to extract from task)
Returns:
Result dictionary if terminal/actionable state, None otherwise
"""
if result_parts is None:
result_parts = []
if a2a_task.status.state == TaskState.completed:
extracted_parts = extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="completed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.completed,
agent_card=agent_card,
result=response_text,
history=new_messages,
)
if a2a_task.status.state == TaskState.input_required:
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = extract_error_message(a2a_task, "Additional input required")
if response_text and not a2a_task.history:
agent_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))],
context_id=getattr(a2a_task, "context_id", None),
task_id=getattr(a2a_task, "task_id", None),
)
new_messages.append(agent_message)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="input_required",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.input_required,
error=response_text,
history=new_messages,
agent_card=agent_card,
)
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
error_msg = extract_error_message(a2a_task, "Task failed without error message")
if a2a_task.history:
new_messages.extend(a2a_task.history)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.auth_required:
error_msg = extract_error_message(a2a_task, "Authentication required")
return TaskStateResult(
status=TaskState.auth_required,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.canceled:
error_msg = extract_error_message(a2a_task, "Task was canceled")
return TaskStateResult(
status=TaskState.canceled,
error=error_msg,
history=new_messages,
)
return None