feat: add polling and streaming handlers

This commit is contained in:
Greyson LaLonde
2026-01-05 18:50:12 -05:00
parent 7589e524ab
commit 12d60a483a
4 changed files with 416 additions and 1 deletions

View File

@@ -0,0 +1,228 @@
"""Polling update mechanism handler."""
from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING, Any
from a2a.client import Client
from a2a.types import (
AgentCard,
Message,
TaskQueryParams,
TaskState,
)
from crewai.a2a.errors import A2APollingTimeoutError
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2APollingStartedEvent,
A2APollingStatusEvent,
A2AResponseReceivedEvent,
)
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
TERMINAL_STATES = {
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
}
async def poll_task_until_complete(
client: Client,
task_id: str,
polling_interval: float,
polling_timeout: float,
agent_branch: Any | None = None,
history_length: int = 100,
max_polls: int | None = None,
) -> A2ATask:
"""Poll task status until terminal state reached.
Args:
client: A2A client instance
task_id: Task ID to poll
polling_interval: Seconds between poll attempts
polling_timeout: Max seconds before timeout
agent_branch: Agent tree branch for logging
history_length: Number of messages to retrieve per poll
max_polls: Max number of poll attempts (None = unlimited)
Returns:
Final task object in terminal state
Raises:
A2APollingTimeoutError: If polling exceeds timeout or max_polls
"""
start_time = time.monotonic()
poll_count = 0
while True:
poll_count += 1
task = await client.get_task(
TaskQueryParams(id=task_id, history_length=history_length)
)
elapsed = time.monotonic() - start_time
crewai_event_bus.emit(
agent_branch,
A2APollingStatusEvent(
task_id=task_id,
state=str(task.status.state.value) if task.status.state else "unknown",
elapsed_seconds=elapsed,
poll_count=poll_count,
),
)
if task.status.state in TERMINAL_STATES:
return task
if task.status.state in {TaskState.input_required, TaskState.auth_required}:
return task
if elapsed > polling_timeout:
raise A2APollingTimeoutError(
f"Polling timeout after {polling_timeout}s ({poll_count} polls)"
)
if max_polls and poll_count >= max_polls:
raise A2APollingTimeoutError(
f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s"
)
await asyncio.sleep(polling_interval)
async def execute_polling_delegation(
client: Client,
message: Message,
polling_interval: float,
polling_timeout: float,
endpoint: str,
agent_branch: Any | None,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
new_messages: list[Message],
agent_card: AgentCard,
history_length: int = 100,
max_polls: int | None = None,
) -> TaskStateResult:
"""Execute A2A delegation using polling for updates.
Args:
client: A2A client instance
message: Message to send
polling_interval: Seconds between poll attempts
polling_timeout: Max seconds before timeout
endpoint: A2A agent endpoint URL
agent_branch: Agent tree branch for logging
turn_number: Current turn number
is_multiturn: Whether this is a multi-turn conversation
agent_role: Agent role for logging
new_messages: List to collect messages
agent_card: The agent card
history_length: Number of messages to retrieve per poll
max_polls: Max number of poll attempts (None = unlimited)
Returns:
Dictionary with status, result/error, and history
"""
task_id: str | None = None
async for event in client.send_message(message):
if isinstance(event, Message):
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
]
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="completed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card,
)
if isinstance(event, tuple):
a2a_task, _ = event
task_id = a2a_task.id
if a2a_task.status.state in TERMINAL_STATES | {
TaskState.input_required,
TaskState.auth_required,
}:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
break
if not task_id:
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
)
crewai_event_bus.emit(
agent_branch,
A2APollingStartedEvent(
task_id=task_id,
polling_interval=polling_interval,
endpoint=endpoint,
),
)
final_task = await poll_task_until_complete(
client=client,
task_id=task_id,
polling_interval=polling_interval,
polling_timeout=polling_timeout,
agent_branch=agent_branch,
history_length=history_length,
max_polls=max_polls,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)

View File

@@ -0,0 +1,3 @@
"""Push notification (webhook) update mechanism handler."""
from __future__ import annotations

View File

@@ -0,0 +1,152 @@
"""Streaming (SSE) update mechanism handler."""
from __future__ import annotations
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
async def execute_streaming_delegation(
client: Client,
message: Message,
context_id: str | None,
task_id: str | None,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
new_messages: list[Message],
agent_card: AgentCard,
) -> TaskStateResult:
"""Execute A2A delegation using SSE streaming for updates.
Args:
client: A2A client instance
message: Message to send
context_id: Context ID for correlation
task_id: Task ID for correlation
turn_number: Current turn number
is_multiturn: Whether this is a multi-turn conversation
agent_role: Agent role for logging
new_messages: List to collect messages
agent_card: The agent card
Returns:
Dictionary with status, result/error, and history
"""
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
elif isinstance(event, tuple):
a2a_task, update = event
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
is_final_update = False
if isinstance(update, TaskStatusUpdateEvent):
is_final_update = update.final
if not is_final_update and a2a_task.status.state not in [
TaskState.completed,
TaskState.input_required,
TaskState.failed,
TaskState.rejected,
TaskState.auth_required,
TaskState.canceled,
]:
continue
final_result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
result_parts=result_parts,
)
if final_result:
break
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
current_exception: Exception | BaseException | None = e
while current_exception:
if hasattr(current_exception, "response"):
response = current_exception.response
if hasattr(response, "text"):
break
if current_exception and hasattr(current_exception, "__cause__"):
current_exception = current_exception.__cause__
raise
finally:
if hasattr(event_stream, "aclose"):
await event_stream.aclose()
if final_result:
return final_result
return TaskStateResult(
status=TaskState.completed,
result=" ".join(result_parts) if result_parts else "",
history=new_messages,
)

View File

@@ -15,7 +15,7 @@ class A2AEventBase(BaseEvent):
from_task: Any | None = None
from_agent: Any | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
"""Initialize A2A event, extracting task and agent metadata."""
if data.get("from_task"):
task = data["from_task"]
@@ -139,3 +139,35 @@ class A2AConversationCompletedEvent(A2AEventBase):
final_result: str | None = None
error: str | None = None
total_turns: int
class A2APollingStartedEvent(A2AEventBase):
"""Event emitted when polling mode begins for A2A delegation.
Attributes:
task_id: A2A task ID being polled
polling_interval: Seconds between poll attempts
endpoint: A2A agent endpoint URL
"""
type: str = "a2a_polling_started"
task_id: str
polling_interval: float
endpoint: str
class A2APollingStatusEvent(A2AEventBase):
"""Event emitted on each polling iteration.
Attributes:
task_id: A2A task ID being polled
state: Current task state from remote agent
elapsed_seconds: Time since polling started
poll_count: Number of polls completed
"""
type: str = "a2a_polling_status"
task_id: str
state: str
elapsed_seconds: float
poll_count: int