From 12d60a483a442b224b6ddc564ccb86bf0bf2a177 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Mon, 5 Jan 2026 18:50:12 -0500 Subject: [PATCH] feat: add polling and streaming handlers --- .../src/crewai/a2a/updates/polling/handler.py | 228 ++++++++++++++++++ .../a2a/updates/push_notifications/handler.py | 3 + .../crewai/a2a/updates/streaming/handler.py | 152 ++++++++++++ .../src/crewai/events/types/a2a_events.py | 34 ++- 4 files changed, 416 insertions(+), 1 deletion(-) create mode 100644 lib/crewai/src/crewai/a2a/updates/polling/handler.py create mode 100644 lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py create mode 100644 lib/crewai/src/crewai/a2a/updates/streaming/handler.py diff --git a/lib/crewai/src/crewai/a2a/updates/polling/handler.py b/lib/crewai/src/crewai/a2a/updates/polling/handler.py new file mode 100644 index 000000000..7099538eb --- /dev/null +++ b/lib/crewai/src/crewai/a2a/updates/polling/handler.py @@ -0,0 +1,228 @@ +"""Polling update mechanism handler.""" + +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING, Any + +from a2a.client import Client +from a2a.types import ( + AgentCard, + Message, + TaskQueryParams, + TaskState, +) + +from crewai.a2a.errors import A2APollingTimeoutError +from crewai.a2a.task_helpers import TaskStateResult, process_task_state +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2APollingStartedEvent, + A2APollingStatusEvent, + A2AResponseReceivedEvent, +) + + +if TYPE_CHECKING: + from a2a.types import Task as A2ATask + + +TERMINAL_STATES = { + TaskState.completed, + TaskState.failed, + TaskState.rejected, + TaskState.canceled, +} + + +async def poll_task_until_complete( + client: Client, + task_id: str, + polling_interval: float, + polling_timeout: float, + agent_branch: Any | None = None, + history_length: int = 100, + max_polls: int | None = None, +) -> A2ATask: + """Poll task status until terminal state reached. + + Args: + client: A2A client instance + task_id: Task ID to poll + polling_interval: Seconds between poll attempts + polling_timeout: Max seconds before timeout + agent_branch: Agent tree branch for logging + history_length: Number of messages to retrieve per poll + max_polls: Max number of poll attempts (None = unlimited) + + Returns: + Final task object in terminal state + + Raises: + A2APollingTimeoutError: If polling exceeds timeout or max_polls + """ + start_time = time.monotonic() + poll_count = 0 + + while True: + poll_count += 1 + task = await client.get_task( + TaskQueryParams(id=task_id, history_length=history_length) + ) + + elapsed = time.monotonic() - start_time + crewai_event_bus.emit( + agent_branch, + A2APollingStatusEvent( + task_id=task_id, + state=str(task.status.state.value) if task.status.state else "unknown", + elapsed_seconds=elapsed, + poll_count=poll_count, + ), + ) + + if task.status.state in TERMINAL_STATES: + return task + + if task.status.state in {TaskState.input_required, TaskState.auth_required}: + return task + + if elapsed > polling_timeout: + raise A2APollingTimeoutError( + f"Polling timeout after {polling_timeout}s ({poll_count} polls)" + ) + + if max_polls and poll_count >= max_polls: + raise A2APollingTimeoutError( + f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s" + ) + + await asyncio.sleep(polling_interval) + + +async def execute_polling_delegation( + client: Client, + message: Message, + polling_interval: float, + polling_timeout: float, + endpoint: str, + agent_branch: Any | None, + turn_number: int, + is_multiturn: bool, + agent_role: str | None, + new_messages: list[Message], + agent_card: AgentCard, + history_length: int = 100, + max_polls: int | None = None, +) -> TaskStateResult: + """Execute A2A delegation using polling for updates. + + Args: + client: A2A client instance + message: Message to send + polling_interval: Seconds between poll attempts + polling_timeout: Max seconds before timeout + endpoint: A2A agent endpoint URL + agent_branch: Agent tree branch for logging + turn_number: Current turn number + is_multiturn: Whether this is a multi-turn conversation + agent_role: Agent role for logging + new_messages: List to collect messages + agent_card: The agent card + history_length: Number of messages to retrieve per poll + max_polls: Max number of poll attempts (None = unlimited) + + Returns: + Dictionary with status, result/error, and history + """ + task_id: str | None = None + + async for event in client.send_message(message): + if isinstance(event, Message): + new_messages.append(event) + result_parts = [ + part.root.text for part in event.parts if part.root.kind == "text" + ] + response_text = " ".join(result_parts) if result_parts else "" + + crewai_event_bus.emit( + None, + A2AResponseReceivedEvent( + response=response_text, + turn_number=turn_number, + is_multiturn=is_multiturn, + status="completed", + agent_role=agent_role, + ), + ) + + return TaskStateResult( + status=TaskState.completed, + result=response_text, + history=new_messages, + agent_card=agent_card, + ) + + if isinstance(event, tuple): + a2a_task, _ = event + task_id = a2a_task.id + + if a2a_task.status.state in TERMINAL_STATES | { + TaskState.input_required, + TaskState.auth_required, + }: + result = process_task_state( + a2a_task=a2a_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) + if result: + return result + break + + if not task_id: + return TaskStateResult( + status=TaskState.failed, + error="No task ID received from initial message", + history=new_messages, + ) + + crewai_event_bus.emit( + agent_branch, + A2APollingStartedEvent( + task_id=task_id, + polling_interval=polling_interval, + endpoint=endpoint, + ), + ) + + final_task = await poll_task_until_complete( + client=client, + task_id=task_id, + polling_interval=polling_interval, + polling_timeout=polling_timeout, + agent_branch=agent_branch, + history_length=history_length, + max_polls=max_polls, + ) + + result = process_task_state( + a2a_task=final_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) + if result: + return result + + return TaskStateResult( + status=TaskState.failed, + error=f"Unexpected task state: {final_task.status.state}", + history=new_messages, + ) diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py new file mode 100644 index 000000000..cff96bfaa --- /dev/null +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py @@ -0,0 +1,3 @@ +"""Push notification (webhook) update mechanism handler.""" + +from __future__ import annotations diff --git a/lib/crewai/src/crewai/a2a/updates/streaming/handler.py b/lib/crewai/src/crewai/a2a/updates/streaming/handler.py new file mode 100644 index 000000000..b453c687c --- /dev/null +++ b/lib/crewai/src/crewai/a2a/updates/streaming/handler.py @@ -0,0 +1,152 @@ +"""Streaming (SSE) update mechanism handler.""" + +from __future__ import annotations + +import uuid + +from a2a.client import Client +from a2a.client.errors import A2AClientHTTPError +from a2a.types import ( + AgentCard, + Message, + Part, + Role, + TaskArtifactUpdateEvent, + TaskState, + TaskStatusUpdateEvent, + TextPart, +) + +from crewai.a2a.task_helpers import TaskStateResult, process_task_state +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import A2AResponseReceivedEvent + + +async def execute_streaming_delegation( + client: Client, + message: Message, + context_id: str | None, + task_id: str | None, + turn_number: int, + is_multiturn: bool, + agent_role: str | None, + new_messages: list[Message], + agent_card: AgentCard, +) -> TaskStateResult: + """Execute A2A delegation using SSE streaming for updates. + + Args: + client: A2A client instance + message: Message to send + context_id: Context ID for correlation + task_id: Task ID for correlation + turn_number: Current turn number + is_multiturn: Whether this is a multi-turn conversation + agent_role: Agent role for logging + new_messages: List to collect messages + agent_card: The agent card + + Returns: + Dictionary with status, result/error, and history + """ + result_parts: list[str] = [] + final_result: TaskStateResult | None = None + event_stream = client.send_message(message) + + try: + async for event in event_stream: + if isinstance(event, Message): + new_messages.append(event) + for part in event.parts: + if part.root.kind == "text": + text = part.root.text + result_parts.append(text) + + elif isinstance(event, tuple): + a2a_task, update = event + + if isinstance(update, TaskArtifactUpdateEvent): + artifact = update.artifact + result_parts.extend( + part.root.text + for part in artifact.parts + if part.root.kind == "text" + ) + + is_final_update = False + if isinstance(update, TaskStatusUpdateEvent): + is_final_update = update.final + + if not is_final_update and a2a_task.status.state not in [ + TaskState.completed, + TaskState.input_required, + TaskState.failed, + TaskState.rejected, + TaskState.auth_required, + TaskState.canceled, + ]: + continue + + final_result = process_task_state( + a2a_task=a2a_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + result_parts=result_parts, + ) + if final_result: + break + + except A2AClientHTTPError as e: + error_msg = f"HTTP Error {e.status_code}: {e!s}" + + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + None, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=turn_number, + is_multiturn=is_multiturn, + status="failed", + agent_role=agent_role, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) + + except Exception as e: + current_exception: Exception | BaseException | None = e + while current_exception: + if hasattr(current_exception, "response"): + response = current_exception.response + if hasattr(response, "text"): + break + if current_exception and hasattr(current_exception, "__cause__"): + current_exception = current_exception.__cause__ + raise + + finally: + if hasattr(event_stream, "aclose"): + await event_stream.aclose() + + if final_result: + return final_result + + return TaskStateResult( + status=TaskState.completed, + result=" ".join(result_parts) if result_parts else "", + history=new_messages, + ) diff --git a/lib/crewai/src/crewai/events/types/a2a_events.py b/lib/crewai/src/crewai/events/types/a2a_events.py index baafd53c3..6afd1533d 100644 --- a/lib/crewai/src/crewai/events/types/a2a_events.py +++ b/lib/crewai/src/crewai/events/types/a2a_events.py @@ -15,7 +15,7 @@ class A2AEventBase(BaseEvent): from_task: Any | None = None from_agent: Any | None = None - def __init__(self, **data): + def __init__(self, **data: Any) -> None: """Initialize A2A event, extracting task and agent metadata.""" if data.get("from_task"): task = data["from_task"] @@ -139,3 +139,35 @@ class A2AConversationCompletedEvent(A2AEventBase): final_result: str | None = None error: str | None = None total_turns: int + + +class A2APollingStartedEvent(A2AEventBase): + """Event emitted when polling mode begins for A2A delegation. + + Attributes: + task_id: A2A task ID being polled + polling_interval: Seconds between poll attempts + endpoint: A2A agent endpoint URL + """ + + type: str = "a2a_polling_started" + task_id: str + polling_interval: float + endpoint: str + + +class A2APollingStatusEvent(A2AEventBase): + """Event emitted on each polling iteration. + + Attributes: + task_id: A2A task ID being polled + state: Current task state from remote agent + elapsed_seconds: Time since polling started + poll_count: Number of polls completed + """ + + type: str = "a2a_polling_status" + task_id: str + state: str + elapsed_seconds: float + poll_count: int