mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
feat: add polling and streaming handlers
This commit is contained in:
228
lib/crewai/src/crewai/a2a/updates/polling/handler.py
Normal file
228
lib/crewai/src/crewai/a2a/updates/polling/handler.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""Polling update mechanism handler."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from a2a.client import Client
|
||||||
|
from a2a.types import (
|
||||||
|
AgentCard,
|
||||||
|
Message,
|
||||||
|
TaskQueryParams,
|
||||||
|
TaskState,
|
||||||
|
)
|
||||||
|
|
||||||
|
from crewai.a2a.errors import A2APollingTimeoutError
|
||||||
|
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.a2a_events import (
|
||||||
|
A2APollingStartedEvent,
|
||||||
|
A2APollingStatusEvent,
|
||||||
|
A2AResponseReceivedEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from a2a.types import Task as A2ATask
|
||||||
|
|
||||||
|
|
||||||
|
TERMINAL_STATES = {
|
||||||
|
TaskState.completed,
|
||||||
|
TaskState.failed,
|
||||||
|
TaskState.rejected,
|
||||||
|
TaskState.canceled,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def poll_task_until_complete(
|
||||||
|
client: Client,
|
||||||
|
task_id: str,
|
||||||
|
polling_interval: float,
|
||||||
|
polling_timeout: float,
|
||||||
|
agent_branch: Any | None = None,
|
||||||
|
history_length: int = 100,
|
||||||
|
max_polls: int | None = None,
|
||||||
|
) -> A2ATask:
|
||||||
|
"""Poll task status until terminal state reached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: A2A client instance
|
||||||
|
task_id: Task ID to poll
|
||||||
|
polling_interval: Seconds between poll attempts
|
||||||
|
polling_timeout: Max seconds before timeout
|
||||||
|
agent_branch: Agent tree branch for logging
|
||||||
|
history_length: Number of messages to retrieve per poll
|
||||||
|
max_polls: Max number of poll attempts (None = unlimited)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final task object in terminal state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
A2APollingTimeoutError: If polling exceeds timeout or max_polls
|
||||||
|
"""
|
||||||
|
start_time = time.monotonic()
|
||||||
|
poll_count = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
poll_count += 1
|
||||||
|
task = await client.get_task(
|
||||||
|
TaskQueryParams(id=task_id, history_length=history_length)
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start_time
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2APollingStatusEvent(
|
||||||
|
task_id=task_id,
|
||||||
|
state=str(task.status.state.value) if task.status.state else "unknown",
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
poll_count=poll_count,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if task.status.state in TERMINAL_STATES:
|
||||||
|
return task
|
||||||
|
|
||||||
|
if task.status.state in {TaskState.input_required, TaskState.auth_required}:
|
||||||
|
return task
|
||||||
|
|
||||||
|
if elapsed > polling_timeout:
|
||||||
|
raise A2APollingTimeoutError(
|
||||||
|
f"Polling timeout after {polling_timeout}s ({poll_count} polls)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if max_polls and poll_count >= max_polls:
|
||||||
|
raise A2APollingTimeoutError(
|
||||||
|
f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(polling_interval)
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_polling_delegation(
|
||||||
|
client: Client,
|
||||||
|
message: Message,
|
||||||
|
polling_interval: float,
|
||||||
|
polling_timeout: float,
|
||||||
|
endpoint: str,
|
||||||
|
agent_branch: Any | None,
|
||||||
|
turn_number: int,
|
||||||
|
is_multiturn: bool,
|
||||||
|
agent_role: str | None,
|
||||||
|
new_messages: list[Message],
|
||||||
|
agent_card: AgentCard,
|
||||||
|
history_length: int = 100,
|
||||||
|
max_polls: int | None = None,
|
||||||
|
) -> TaskStateResult:
|
||||||
|
"""Execute A2A delegation using polling for updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: A2A client instance
|
||||||
|
message: Message to send
|
||||||
|
polling_interval: Seconds between poll attempts
|
||||||
|
polling_timeout: Max seconds before timeout
|
||||||
|
endpoint: A2A agent endpoint URL
|
||||||
|
agent_branch: Agent tree branch for logging
|
||||||
|
turn_number: Current turn number
|
||||||
|
is_multiturn: Whether this is a multi-turn conversation
|
||||||
|
agent_role: Agent role for logging
|
||||||
|
new_messages: List to collect messages
|
||||||
|
agent_card: The agent card
|
||||||
|
history_length: Number of messages to retrieve per poll
|
||||||
|
max_polls: Max number of poll attempts (None = unlimited)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status, result/error, and history
|
||||||
|
"""
|
||||||
|
task_id: str | None = None
|
||||||
|
|
||||||
|
async for event in client.send_message(message):
|
||||||
|
if isinstance(event, Message):
|
||||||
|
new_messages.append(event)
|
||||||
|
result_parts = [
|
||||||
|
part.root.text for part in event.parts if part.root.kind == "text"
|
||||||
|
]
|
||||||
|
response_text = " ".join(result_parts) if result_parts else ""
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=response_text,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="completed",
|
||||||
|
agent_role=agent_role,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.completed,
|
||||||
|
result=response_text,
|
||||||
|
history=new_messages,
|
||||||
|
agent_card=agent_card,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(event, tuple):
|
||||||
|
a2a_task, _ = event
|
||||||
|
task_id = a2a_task.id
|
||||||
|
|
||||||
|
if a2a_task.status.state in TERMINAL_STATES | {
|
||||||
|
TaskState.input_required,
|
||||||
|
TaskState.auth_required,
|
||||||
|
}:
|
||||||
|
result = process_task_state(
|
||||||
|
a2a_task=a2a_task,
|
||||||
|
new_messages=new_messages,
|
||||||
|
agent_card=agent_card,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
agent_role=agent_role,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
break
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error="No task ID received from initial message",
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2APollingStartedEvent(
|
||||||
|
task_id=task_id,
|
||||||
|
polling_interval=polling_interval,
|
||||||
|
endpoint=endpoint,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
final_task = await poll_task_until_complete(
|
||||||
|
client=client,
|
||||||
|
task_id=task_id,
|
||||||
|
polling_interval=polling_interval,
|
||||||
|
polling_timeout=polling_timeout,
|
||||||
|
agent_branch=agent_branch,
|
||||||
|
history_length=history_length,
|
||||||
|
max_polls=max_polls,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = process_task_state(
|
||||||
|
a2a_task=final_task,
|
||||||
|
new_messages=new_messages,
|
||||||
|
agent_card=agent_card,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
agent_role=agent_role,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error=f"Unexpected task state: {final_task.status.state}",
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
"""Push notification (webhook) update mechanism handler."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
152
lib/crewai/src/crewai/a2a/updates/streaming/handler.py
Normal file
152
lib/crewai/src/crewai/a2a/updates/streaming/handler.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
"""Streaming (SSE) update mechanism handler."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from a2a.client import Client
|
||||||
|
from a2a.client.errors import A2AClientHTTPError
|
||||||
|
from a2a.types import (
|
||||||
|
AgentCard,
|
||||||
|
Message,
|
||||||
|
Part,
|
||||||
|
Role,
|
||||||
|
TaskArtifactUpdateEvent,
|
||||||
|
TaskState,
|
||||||
|
TaskStatusUpdateEvent,
|
||||||
|
TextPart,
|
||||||
|
)
|
||||||
|
|
||||||
|
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_streaming_delegation(
|
||||||
|
client: Client,
|
||||||
|
message: Message,
|
||||||
|
context_id: str | None,
|
||||||
|
task_id: str | None,
|
||||||
|
turn_number: int,
|
||||||
|
is_multiturn: bool,
|
||||||
|
agent_role: str | None,
|
||||||
|
new_messages: list[Message],
|
||||||
|
agent_card: AgentCard,
|
||||||
|
) -> TaskStateResult:
|
||||||
|
"""Execute A2A delegation using SSE streaming for updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: A2A client instance
|
||||||
|
message: Message to send
|
||||||
|
context_id: Context ID for correlation
|
||||||
|
task_id: Task ID for correlation
|
||||||
|
turn_number: Current turn number
|
||||||
|
is_multiturn: Whether this is a multi-turn conversation
|
||||||
|
agent_role: Agent role for logging
|
||||||
|
new_messages: List to collect messages
|
||||||
|
agent_card: The agent card
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with status, result/error, and history
|
||||||
|
"""
|
||||||
|
result_parts: list[str] = []
|
||||||
|
final_result: TaskStateResult | None = None
|
||||||
|
event_stream = client.send_message(message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for event in event_stream:
|
||||||
|
if isinstance(event, Message):
|
||||||
|
new_messages.append(event)
|
||||||
|
for part in event.parts:
|
||||||
|
if part.root.kind == "text":
|
||||||
|
text = part.root.text
|
||||||
|
result_parts.append(text)
|
||||||
|
|
||||||
|
elif isinstance(event, tuple):
|
||||||
|
a2a_task, update = event
|
||||||
|
|
||||||
|
if isinstance(update, TaskArtifactUpdateEvent):
|
||||||
|
artifact = update.artifact
|
||||||
|
result_parts.extend(
|
||||||
|
part.root.text
|
||||||
|
for part in artifact.parts
|
||||||
|
if part.root.kind == "text"
|
||||||
|
)
|
||||||
|
|
||||||
|
is_final_update = False
|
||||||
|
if isinstance(update, TaskStatusUpdateEvent):
|
||||||
|
is_final_update = update.final
|
||||||
|
|
||||||
|
if not is_final_update and a2a_task.status.state not in [
|
||||||
|
TaskState.completed,
|
||||||
|
TaskState.input_required,
|
||||||
|
TaskState.failed,
|
||||||
|
TaskState.rejected,
|
||||||
|
TaskState.auth_required,
|
||||||
|
TaskState.canceled,
|
||||||
|
]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
final_result = process_task_state(
|
||||||
|
a2a_task=a2a_task,
|
||||||
|
new_messages=new_messages,
|
||||||
|
agent_card=agent_card,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
agent_role=agent_role,
|
||||||
|
result_parts=result_parts,
|
||||||
|
)
|
||||||
|
if final_result:
|
||||||
|
break
|
||||||
|
|
||||||
|
except A2AClientHTTPError as e:
|
||||||
|
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||||
|
|
||||||
|
error_message = Message(
|
||||||
|
role=Role.agent,
|
||||||
|
message_id=str(uuid.uuid4()),
|
||||||
|
parts=[Part(root=TextPart(text=error_msg))],
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=error_msg,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="failed",
|
||||||
|
agent_role=agent_role,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error=error_msg,
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
current_exception: Exception | BaseException | None = e
|
||||||
|
while current_exception:
|
||||||
|
if hasattr(current_exception, "response"):
|
||||||
|
response = current_exception.response
|
||||||
|
if hasattr(response, "text"):
|
||||||
|
break
|
||||||
|
if current_exception and hasattr(current_exception, "__cause__"):
|
||||||
|
current_exception = current_exception.__cause__
|
||||||
|
raise
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if hasattr(event_stream, "aclose"):
|
||||||
|
await event_stream.aclose()
|
||||||
|
|
||||||
|
if final_result:
|
||||||
|
return final_result
|
||||||
|
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.completed,
|
||||||
|
result=" ".join(result_parts) if result_parts else "",
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
@@ -15,7 +15,7 @@ class A2AEventBase(BaseEvent):
|
|||||||
from_task: Any | None = None
|
from_task: Any | None = None
|
||||||
from_agent: Any | None = None
|
from_agent: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
"""Initialize A2A event, extracting task and agent metadata."""
|
"""Initialize A2A event, extracting task and agent metadata."""
|
||||||
if data.get("from_task"):
|
if data.get("from_task"):
|
||||||
task = data["from_task"]
|
task = data["from_task"]
|
||||||
@@ -139,3 +139,35 @@ class A2AConversationCompletedEvent(A2AEventBase):
|
|||||||
final_result: str | None = None
|
final_result: str | None = None
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
total_turns: int
|
total_turns: int
|
||||||
|
|
||||||
|
|
||||||
|
class A2APollingStartedEvent(A2AEventBase):
|
||||||
|
"""Event emitted when polling mode begins for A2A delegation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID being polled
|
||||||
|
polling_interval: Seconds between poll attempts
|
||||||
|
endpoint: A2A agent endpoint URL
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_polling_started"
|
||||||
|
task_id: str
|
||||||
|
polling_interval: float
|
||||||
|
endpoint: str
|
||||||
|
|
||||||
|
|
||||||
|
class A2APollingStatusEvent(A2AEventBase):
|
||||||
|
"""Event emitted on each polling iteration.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: A2A task ID being polled
|
||||||
|
state: Current task state from remote agent
|
||||||
|
elapsed_seconds: Time since polling started
|
||||||
|
poll_count: Number of polls completed
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "a2a_polling_status"
|
||||||
|
task_id: str
|
||||||
|
state: str
|
||||||
|
elapsed_seconds: float
|
||||||
|
poll_count: int
|
||||||
|
|||||||
Reference in New Issue
Block a user