chore: refactor handlers to unified protocol

This commit is contained in:
Greyson LaLonde
2026-01-05 20:07:48 -05:00
parent c75391dfbb
commit 33caeeba28
9 changed files with 428 additions and 278 deletions

View File

@@ -15,6 +15,23 @@ if TYPE_CHECKING:
from a2a.types import Task as A2ATask from a2a.types import Task as A2ATask
TERMINAL_STATES: frozenset[TaskState] = frozenset(
{
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
}
)
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
{
TaskState.input_required,
TaskState.auth_required,
}
)
class TaskStateResult(TypedDict): class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state.""" """Result dictionary from processing A2A task state."""
@@ -154,8 +171,8 @@ def process_task_state(
role=Role.agent, role=Role.agent,
message_id=str(uuid.uuid4()), message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))], parts=[Part(root=TextPart(text=response_text))],
context_id=getattr(a2a_task, "context_id", None), context_id=a2a_task.context_id,
task_id=getattr(a2a_task, "task_id", None), task_id=a2a_task.id,
) )
new_messages.append(agent_message) new_messages.append(agent_message)

View File

@@ -1,15 +1,33 @@
"""A2A update mechanism configuration types.""" """A2A update mechanism configuration types."""
from crewai.a2a.updates.base import (
BaseHandlerKwargs,
PollingHandlerKwargs,
PushNotificationHandlerKwargs,
StreamingHandlerKwargs,
UpdateHandler,
)
from crewai.a2a.updates.polling.config import PollingConfig from crewai.a2a.updates.polling.config import PollingConfig
from crewai.a2a.updates.polling.handler import PollingHandler
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
from crewai.a2a.updates.streaming.config import StreamingConfig from crewai.a2a.updates.streaming.config import StreamingConfig
from crewai.a2a.updates.streaming.handler import StreamingHandler
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
__all__ = [ __all__ = [
"BaseHandlerKwargs",
"PollingConfig", "PollingConfig",
"PollingHandler",
"PollingHandlerKwargs",
"PushNotificationConfig", "PushNotificationConfig",
"PushNotificationHandler",
"PushNotificationHandlerKwargs",
"StreamingConfig", "StreamingConfig",
"StreamingHandler",
"StreamingHandlerKwargs",
"UpdateConfig", "UpdateConfig",
"UpdateHandler",
] ]

View File

@@ -0,0 +1,68 @@
"""Base types for A2A update mechanism handlers."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
if TYPE_CHECKING:
from a2a.client import Client
from a2a.types import AgentCard, Message
from crewai.a2a.task_helpers import TaskStateResult
class BaseHandlerKwargs(TypedDict, total=False):
"""Base kwargs shared by all handlers."""
turn_number: int
is_multiturn: bool
agent_role: str | None
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for polling handler."""
polling_interval: float
polling_timeout: float
endpoint: str
agent_branch: Any
history_length: int
max_polls: int | None
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for streaming handler."""
context_id: str | None
task_id: str | None
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for push notification handler."""
class UpdateHandler(Protocol):
"""Protocol for A2A update mechanism handlers."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Any,
) -> TaskStateResult:
"""Execute the update mechanism and return result.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
**kwargs: Additional handler-specific parameters.
Returns:
Result dictionary with status, result/error, and history.
"""
...

View File

@@ -15,9 +15,11 @@ class PollingConfig(BaseModel):
history_length: Number of messages to retrieve per poll. history_length: Number of messages to retrieve per poll.
""" """
interval: float = Field(default=2.0, description="Seconds between poll attempts") interval: float = Field(
timeout: float | None = Field(default=None, description="Max seconds to poll") default=2.0, gt=0, description="Seconds between poll attempts"
max_polls: int | None = Field(default=None, description="Max poll attempts") )
history_length: int = Field( timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll")
default=100, description="Messages to retrieve per poll" max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts")
history_length: int = Field(
default=100, gt=0, description="Messages to retrieve per poll"
) )

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import time import time
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Unpack
from a2a.client import Client from a2a.client import Client
from a2a.types import ( from a2a.types import (
@@ -15,7 +15,13 @@ from a2a.types import (
) )
from crewai.a2a.errors import A2APollingTimeoutError from crewai.a2a.errors import A2APollingTimeoutError
from crewai.a2a.task_helpers import TaskStateResult, process_task_state from crewai.a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
)
from crewai.a2a.updates.base import PollingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import ( from crewai.events.types.a2a_events import (
A2APollingStartedEvent, A2APollingStartedEvent,
@@ -28,15 +34,7 @@ if TYPE_CHECKING:
from a2a.types import Task as A2ATask from a2a.types import Task as A2ATask
TERMINAL_STATES = { async def _poll_task_until_complete(
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
}
async def poll_task_until_complete(
client: Client, client: Client,
task_id: str, task_id: str,
polling_interval: float, polling_interval: float,
@@ -85,7 +83,7 @@ async def poll_task_until_complete(
if task.status.state in TERMINAL_STATES: if task.status.state in TERMINAL_STATES:
return task return task
if task.status.state in {TaskState.input_required, TaskState.auth_required}: if task.status.state in ACTIONABLE_STATES:
return task return task
if elapsed > polling_timeout: if elapsed > polling_timeout:
@@ -101,128 +99,123 @@ async def poll_task_until_complete(
await asyncio.sleep(polling_interval) await asyncio.sleep(polling_interval)
async def execute_polling_delegation( class PollingHandler:
client: Client, """Polling-based update handler."""
message: Message,
polling_interval: float,
polling_timeout: float,
endpoint: str,
agent_branch: Any | None,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
new_messages: list[Message],
agent_card: AgentCard,
history_length: int = 100,
max_polls: int | None = None,
) -> TaskStateResult:
"""Execute A2A delegation using polling for updates.
Args: @staticmethod
client: A2A client instance async def execute(
message: Message to send client: Client,
polling_interval: Seconds between poll attempts message: Message,
polling_timeout: Max seconds before timeout new_messages: list[Message],
endpoint: A2A agent endpoint URL agent_card: AgentCard,
agent_branch: Agent tree branch for logging **kwargs: Unpack[PollingHandlerKwargs],
turn_number: Current turn number ) -> TaskStateResult:
is_multiturn: Whether this is a multi-turn conversation """Execute A2A delegation using polling for updates.
agent_role: Agent role for logging
new_messages: List to collect messages
agent_card: The agent card
history_length: Number of messages to retrieve per poll
max_polls: Max number of poll attempts (None = unlimited)
Returns: Args:
Dictionary with status, result/error, and history client: A2A client instance.
""" message: Message to send.
task_id: str | None = None new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Polling-specific parameters.
async for event in client.send_message(message): Returns:
if isinstance(event, Message): Dictionary with status, result/error, and history.
new_messages.append(event) """
result_parts = [ polling_interval = kwargs.get("polling_interval", 2.0)
part.root.text for part in event.parts if part.root.kind == "text" polling_timeout = kwargs.get("polling_timeout", 300.0)
] endpoint = kwargs.get("endpoint", "")
response_text = " ".join(result_parts) if result_parts else "" agent_branch = kwargs.get("agent_branch")
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
history_length = kwargs.get("history_length", 100)
max_polls = kwargs.get("max_polls")
crewai_event_bus.emit( task_id: str | None = None
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="completed",
agent_role=agent_role,
),
)
return TaskStateResult( async for event in client.send_message(message):
status=TaskState.completed, if isinstance(event, Message):
result=response_text, new_messages.append(event)
history=new_messages, result_parts = [
agent_card=agent_card, part.root.text for part in event.parts if part.root.kind == "text"
) ]
response_text = " ".join(result_parts) if result_parts else ""
if isinstance(event, tuple): crewai_event_bus.emit(
a2a_task, _ = event None,
task_id = a2a_task.id A2AResponseReceivedEvent(
response=response_text,
if a2a_task.status.state in TERMINAL_STATES | { turn_number=turn_number,
TaskState.input_required, is_multiturn=is_multiturn,
TaskState.auth_required, status="completed",
}: agent_role=agent_role,
result = process_task_state( ),
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
) )
if result:
return result
break
if not task_id: return TaskStateResult(
return TaskStateResult( status=TaskState.completed,
status=TaskState.failed, result=response_text,
error="No task ID received from initial message", history=new_messages,
history=new_messages, agent_card=agent_card,
)
if isinstance(event, tuple):
a2a_task, _ = event
task_id = a2a_task.id
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
break
if not task_id:
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
)
crewai_event_bus.emit(
agent_branch,
A2APollingStartedEvent(
task_id=task_id,
polling_interval=polling_interval,
endpoint=endpoint,
),
) )
crewai_event_bus.emit( final_task = await _poll_task_until_complete(
agent_branch, client=client,
A2APollingStartedEvent(
task_id=task_id, task_id=task_id,
polling_interval=polling_interval, polling_interval=polling_interval,
endpoint=endpoint, polling_timeout=polling_timeout,
), agent_branch=agent_branch,
) history_length=history_length,
max_polls=max_polls,
)
final_task = await poll_task_until_complete( result = process_task_state(
client=client, a2a_task=final_task,
task_id=task_id, new_messages=new_messages,
polling_interval=polling_interval, agent_card=agent_card,
polling_timeout=polling_timeout, turn_number=turn_number,
agent_branch=agent_branch, is_multiturn=is_multiturn,
history_length=history_length, agent_role=agent_role,
max_polls=max_polls, )
) if result:
return result
result = process_task_state( return TaskStateResult(
a2a_task=final_task, status=TaskState.failed,
new_messages=new_messages, error=f"Unexpected task state: {final_task.status.state}",
agent_card=agent_card, history=new_messages,
turn_number=turn_number, )
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field from pydantic import AnyHttpUrl, BaseModel, Field
from crewai.a2a.auth.schemas import AuthScheme from crewai.a2a.auth.schemas import AuthScheme
@@ -17,7 +17,7 @@ class PushNotificationConfig(BaseModel):
authentication: Auth scheme for the callback endpoint. authentication: Auth scheme for the callback endpoint.
""" """
url: str = Field(description="Callback URL for push notifications") url: AnyHttpUrl = Field(description="Callback URL for push notifications")
id: str | None = Field(default=None, description="Unique config identifier") id: str | None = Field(default=None, description="Unique config identifier")
token: str | None = Field(default=None, description="Validation token") token: str | None = Field(default=None, description="Validation token")
authentication: AuthScheme | None = Field( authentication: AuthScheme | None = Field(

View File

@@ -1,3 +1,40 @@
"""Push notification (webhook) update mechanism handler.""" """Push notification (webhook) update mechanism handler."""
from __future__ import annotations from __future__ import annotations
from typing import Unpack
from a2a.client import Client
from a2a.types import AgentCard, Message
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.updates.base import PushNotificationHandlerKwargs
class PushNotificationHandler:
"""Push notification (webhook) based update handler."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[PushNotificationHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using push notifications for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Push notification-specific parameters.
Raises:
NotImplementedError: Push notifications not yet implemented.
"""
raise NotImplementedError(
"Push notification update mechanism is not yet implemented. "
"Use PollingConfig or StreamingConfig instead."
)

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Unpack
import uuid import uuid
from a2a.client import Client from a2a.client import Client
@@ -17,136 +18,126 @@ from a2a.types import (
TextPart, TextPart,
) )
from crewai.a2a.task_helpers import TaskStateResult, process_task_state from crewai.a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
)
from crewai.a2a.updates.base import StreamingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AResponseReceivedEvent from crewai.events.types.a2a_events import A2AResponseReceivedEvent
async def execute_streaming_delegation( class StreamingHandler:
client: Client, """SSE streaming-based update handler."""
message: Message,
context_id: str | None,
task_id: str | None,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
new_messages: list[Message],
agent_card: AgentCard,
) -> TaskStateResult:
"""Execute A2A delegation using SSE streaming for updates.
Args: @staticmethod
client: A2A client instance async def execute(
message: Message to send client: Client,
context_id: Context ID for correlation message: Message,
task_id: Task ID for correlation new_messages: list[Message],
turn_number: Current turn number agent_card: AgentCard,
is_multiturn: Whether this is a multi-turn conversation **kwargs: Unpack[StreamingHandlerKwargs],
agent_role: Agent role for logging ) -> TaskStateResult:
new_messages: List to collect messages """Execute A2A delegation using SSE streaming for updates.
agent_card: The agent card
Returns: Args:
Dictionary with status, result/error, and history client: A2A client instance.
""" message: Message to send.
result_parts: list[str] = [] new_messages: List to collect messages.
final_result: TaskStateResult | None = None agent_card: The agent card.
event_stream = client.send_message(message) **kwargs: Streaming-specific parameters.
try: Returns:
async for event in event_stream: Dictionary with status, result/error, and history.
if isinstance(event, Message): """
new_messages.append(event) context_id = kwargs.get("context_id")
for part in event.parts: task_id = kwargs.get("task_id")
if part.root.kind == "text": turn_number = kwargs.get("turn_number", 0)
text = part.root.text is_multiturn = kwargs.get("is_multiturn", False)
result_parts.append(text) agent_role = kwargs.get("agent_role")
elif isinstance(event, tuple): result_parts: list[str] = []
a2a_task, update = event final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
if isinstance(update, TaskArtifactUpdateEvent): try:
artifact = update.artifact async for event in event_stream:
result_parts.extend( if isinstance(event, Message):
part.root.text new_messages.append(event)
for part in artifact.parts for part in event.parts:
if part.root.kind == "text" if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
elif isinstance(event, tuple):
a2a_task, update = event
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
is_final_update = False
if isinstance(update, TaskStatusUpdateEvent):
is_final_update = update.final
if (
not is_final_update
and a2a_task.status.state
not in TERMINAL_STATES | ACTIONABLE_STATES
):
continue
final_result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
result_parts=result_parts,
) )
if final_result:
break
is_final_update = False except A2AClientHTTPError as e:
if isinstance(update, TaskStatusUpdateEvent): error_msg = f"HTTP Error {e.status_code}: {e!s}"
is_final_update = update.final
if not is_final_update and a2a_task.status.state not in [ error_message = Message(
TaskState.completed, role=Role.agent,
TaskState.input_required, message_id=str(uuid.uuid4()),
TaskState.failed, parts=[Part(root=TextPart(text=error_msg))],
TaskState.rejected, context_id=context_id,
TaskState.auth_required, task_id=task_id,
TaskState.canceled, )
]: new_messages.append(error_message)
continue
final_result = process_task_state( crewai_event_bus.emit(
a2a_task=a2a_task, None,
new_messages=new_messages, A2AResponseReceivedEvent(
agent_card=agent_card, response=error_msg,
turn_number=turn_number, turn_number=turn_number,
is_multiturn=is_multiturn, is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role, agent_role=agent_role,
result_parts=result_parts, ),
) )
if final_result: return TaskStateResult(
break status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except A2AClientHTTPError as e: if final_result:
error_msg = f"HTTP Error {e.status_code}: {e!s}" return final_result
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
),
)
return TaskStateResult( return TaskStateResult(
status=TaskState.failed, status=TaskState.completed,
error=error_msg, result=" ".join(result_parts) if result_parts else "",
history=new_messages, history=new_messages,
) )
except Exception as e:
current_exception: Exception | BaseException | None = e
while current_exception:
if hasattr(current_exception, "response"):
response = current_exception.response
if hasattr(response, "text"):
break
if current_exception and hasattr(current_exception, "__cause__"):
current_exception = current_exception.__cause__
raise
finally:
if hasattr(event_stream, "aclose"):
await event_stream.aclose()
if final_result:
return final_result
return TaskStateResult(
status=TaskState.completed,
result=" ".join(result_parts) if result_parts else "",
history=new_messages,
)

View File

@@ -34,9 +34,15 @@ from crewai.a2a.auth.utils import (
from crewai.a2a.config import A2AConfig from crewai.a2a.config import A2AConfig
from crewai.a2a.task_helpers import TaskStateResult from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.types import PartsDict, PartsMetadataDict from crewai.a2a.types import PartsDict, PartsMetadataDict
from crewai.a2a.updates import PollingConfig, UpdateConfig from crewai.a2a.updates import (
from crewai.a2a.updates.polling.handler import execute_polling_delegation PollingConfig,
from crewai.a2a.updates.streaming.handler import execute_streaming_delegation PollingHandler,
PushNotificationConfig,
PushNotificationHandler,
StreamingConfig,
StreamingHandler,
UpdateConfig,
)
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import ( from crewai.events.types.a2a_events import (
A2AConversationStartedEvent, A2AConversationStartedEvent,
@@ -53,6 +59,31 @@ if TYPE_CHECKING:
from crewai.a2a.auth.schemas import AuthScheme from crewai.a2a.auth.schemas import AuthScheme
HandlerType = (
type[PollingHandler] | type[StreamingHandler] | type[PushNotificationHandler]
)
HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
PollingConfig: PollingHandler,
StreamingConfig: StreamingHandler,
PushNotificationConfig: PushNotificationHandler,
}
def get_handler(config: UpdateConfig | None) -> HandlerType:
"""Get the handler class for a given update config.
Args:
config: Update mechanism configuration.
Returns:
Handler class for the config type, defaults to StreamingHandler.
"""
if config is None:
return StreamingHandler
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
@lru_cache() @lru_cache()
def _fetch_agent_card_cached( def _fetch_agent_card_cached(
endpoint: str, endpoint: str,
@@ -448,14 +479,28 @@ async def _execute_a2a_delegation_async(
), ),
) )
polling_config = updates if isinstance(updates, PollingConfig) else None handler = get_handler(updates)
use_polling = polling_config is not None use_polling = isinstance(updates, PollingConfig)
polling_interval = polling_config.interval if polling_config else 2.0
effective_polling_timeout = ( handler_kwargs: dict[str, Any] = {
polling_config.timeout "turn_number": turn_number,
if polling_config and polling_config.timeout "is_multiturn": is_multiturn,
else float(timeout) "agent_role": agent_role,
) "context_id": context_id,
"task_id": task_id,
"endpoint": endpoint,
"agent_branch": agent_branch,
}
if isinstance(updates, PollingConfig):
handler_kwargs.update(
{
"polling_interval": updates.interval,
"polling_timeout": updates.timeout or float(timeout),
"history_length": updates.history_length,
"max_polls": updates.max_polls,
}
)
async with _create_a2a_client( async with _create_a2a_client(
agent_card=agent_card, agent_card=agent_card,
@@ -466,33 +511,12 @@ async def _execute_a2a_delegation_async(
auth=auth, auth=auth,
use_polling=use_polling, use_polling=use_polling,
) as client: ) as client:
if use_polling and polling_config: return await handler.execute(
return await execute_polling_delegation(
client=client,
message=message,
polling_interval=polling_interval,
polling_timeout=effective_polling_timeout,
endpoint=endpoint,
agent_branch=agent_branch,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
new_messages=new_messages,
agent_card=agent_card,
history_length=polling_config.history_length,
max_polls=polling_config.max_polls,
)
return await execute_streaming_delegation(
client=client, client=client,
message=message, message=message,
context_id=context_id,
task_id=task_id,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
new_messages=new_messages, new_messages=new_messages,
agent_card=agent_card, agent_card=agent_card,
**handler_kwargs,
) )