chore: refactor handlers to unified protocol

This commit is contained in:
Greyson LaLonde
2026-01-05 20:07:48 -05:00
parent c75391dfbb
commit 33caeeba28
9 changed files with 428 additions and 278 deletions

View File

@@ -15,6 +15,23 @@ if TYPE_CHECKING:
from a2a.types import Task as A2ATask
TERMINAL_STATES: frozenset[TaskState] = frozenset(
{
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
}
)
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
{
TaskState.input_required,
TaskState.auth_required,
}
)
class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state."""
@@ -154,8 +171,8 @@ def process_task_state(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))],
context_id=getattr(a2a_task, "context_id", None),
task_id=getattr(a2a_task, "task_id", None),
context_id=a2a_task.context_id,
task_id=a2a_task.id,
)
new_messages.append(agent_message)

View File

@@ -1,15 +1,33 @@
"""A2A update mechanism configuration types."""
from crewai.a2a.updates.base import (
BaseHandlerKwargs,
PollingHandlerKwargs,
PushNotificationHandlerKwargs,
StreamingHandlerKwargs,
UpdateHandler,
)
from crewai.a2a.updates.polling.config import PollingConfig
from crewai.a2a.updates.polling.handler import PollingHandler
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
from crewai.a2a.updates.streaming.config import StreamingConfig
from crewai.a2a.updates.streaming.handler import StreamingHandler
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
__all__ = [
"BaseHandlerKwargs",
"PollingConfig",
"PollingHandler",
"PollingHandlerKwargs",
"PushNotificationConfig",
"PushNotificationHandler",
"PushNotificationHandlerKwargs",
"StreamingConfig",
"StreamingHandler",
"StreamingHandlerKwargs",
"UpdateConfig",
"UpdateHandler",
]

View File

@@ -0,0 +1,68 @@
"""Base types for A2A update mechanism handlers."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
if TYPE_CHECKING:
from a2a.client import Client
from a2a.types import AgentCard, Message
from crewai.a2a.task_helpers import TaskStateResult
class BaseHandlerKwargs(TypedDict, total=False):
"""Base kwargs shared by all handlers."""
turn_number: int
is_multiturn: bool
agent_role: str | None
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for polling handler."""
polling_interval: float
polling_timeout: float
endpoint: str
agent_branch: Any
history_length: int
max_polls: int | None
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for streaming handler."""
context_id: str | None
task_id: str | None
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for push notification handler."""
class UpdateHandler(Protocol):
"""Protocol for A2A update mechanism handlers."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Any,
) -> TaskStateResult:
"""Execute the update mechanism and return result.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
**kwargs: Additional handler-specific parameters.
Returns:
Result dictionary with status, result/error, and history.
"""
...

View File

@@ -15,9 +15,11 @@ class PollingConfig(BaseModel):
history_length: Number of messages to retrieve per poll.
"""
interval: float = Field(default=2.0, description="Seconds between poll attempts")
timeout: float | None = Field(default=None, description="Max seconds to poll")
max_polls: int | None = Field(default=None, description="Max poll attempts")
history_length: int = Field(
default=100, description="Messages to retrieve per poll"
interval: float = Field(
default=2.0, gt=0, description="Seconds between poll attempts"
)
timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll")
max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts")
history_length: int = Field(
default=100, gt=0, description="Messages to retrieve per poll"
)

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Unpack
from a2a.client import Client
from a2a.types import (
@@ -15,7 +15,13 @@ from a2a.types import (
)
from crewai.a2a.errors import A2APollingTimeoutError
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
from crewai.a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
)
from crewai.a2a.updates.base import PollingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2APollingStartedEvent,
@@ -28,15 +34,7 @@ if TYPE_CHECKING:
from a2a.types import Task as A2ATask
TERMINAL_STATES = {
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
}
async def poll_task_until_complete(
async def _poll_task_until_complete(
client: Client,
task_id: str,
polling_interval: float,
@@ -85,7 +83,7 @@ async def poll_task_until_complete(
if task.status.state in TERMINAL_STATES:
return task
if task.status.state in {TaskState.input_required, TaskState.auth_required}:
if task.status.state in ACTIONABLE_STATES:
return task
if elapsed > polling_timeout:
@@ -101,128 +99,123 @@ async def poll_task_until_complete(
await asyncio.sleep(polling_interval)
async def execute_polling_delegation(
client: Client,
message: Message,
polling_interval: float,
polling_timeout: float,
endpoint: str,
agent_branch: Any | None,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
new_messages: list[Message],
agent_card: AgentCard,
history_length: int = 100,
max_polls: int | None = None,
) -> TaskStateResult:
"""Execute A2A delegation using polling for updates.
class PollingHandler:
"""Polling-based update handler."""
Args:
client: A2A client instance
message: Message to send
polling_interval: Seconds between poll attempts
polling_timeout: Max seconds before timeout
endpoint: A2A agent endpoint URL
agent_branch: Agent tree branch for logging
turn_number: Current turn number
is_multiturn: Whether this is a multi-turn conversation
agent_role: Agent role for logging
new_messages: List to collect messages
agent_card: The agent card
history_length: Number of messages to retrieve per poll
max_polls: Max number of poll attempts (None = unlimited)
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[PollingHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using polling for updates.
Returns:
Dictionary with status, result/error, and history
"""
task_id: str | None = None
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Polling-specific parameters.
async for event in client.send_message(message):
if isinstance(event, Message):
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
]
response_text = " ".join(result_parts) if result_parts else ""
Returns:
Dictionary with status, result/error, and history.
"""
polling_interval = kwargs.get("polling_interval", 2.0)
polling_timeout = kwargs.get("polling_timeout", 300.0)
endpoint = kwargs.get("endpoint", "")
agent_branch = kwargs.get("agent_branch")
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
history_length = kwargs.get("history_length", 100)
max_polls = kwargs.get("max_polls")
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="completed",
agent_role=agent_role,
),
)
task_id: str | None = None
return TaskStateResult(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card,
)
async for event in client.send_message(message):
if isinstance(event, Message):
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
]
response_text = " ".join(result_parts) if result_parts else ""
if isinstance(event, tuple):
a2a_task, _ = event
task_id = a2a_task.id
if a2a_task.status.state in TERMINAL_STATES | {
TaskState.input_required,
TaskState.auth_required,
}:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="completed",
agent_role=agent_role,
),
)
if result:
return result
break
if not task_id:
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
return TaskStateResult(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card,
)
if isinstance(event, tuple):
a2a_task, _ = event
task_id = a2a_task.id
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
break
if not task_id:
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
)
crewai_event_bus.emit(
agent_branch,
A2APollingStartedEvent(
task_id=task_id,
polling_interval=polling_interval,
endpoint=endpoint,
),
)
crewai_event_bus.emit(
agent_branch,
A2APollingStartedEvent(
final_task = await _poll_task_until_complete(
client=client,
task_id=task_id,
polling_interval=polling_interval,
endpoint=endpoint,
),
)
polling_timeout=polling_timeout,
agent_branch=agent_branch,
history_length=history_length,
max_polls=max_polls,
)
final_task = await poll_task_until_complete(
client=client,
task_id=task_id,
polling_interval=polling_interval,
polling_timeout=polling_timeout,
agent_branch=agent_branch,
history_length=history_length,
max_polls=max_polls,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from pydantic import BaseModel, Field
from pydantic import AnyHttpUrl, BaseModel, Field
from crewai.a2a.auth.schemas import AuthScheme
@@ -17,7 +17,7 @@ class PushNotificationConfig(BaseModel):
authentication: Auth scheme for the callback endpoint.
"""
url: str = Field(description="Callback URL for push notifications")
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
id: str | None = Field(default=None, description="Unique config identifier")
token: str | None = Field(default=None, description="Validation token")
authentication: AuthScheme | None = Field(

View File

@@ -1,3 +1,40 @@
"""Push notification (webhook) update mechanism handler."""
from __future__ import annotations
from typing import Unpack
from a2a.client import Client
from a2a.types import AgentCard, Message
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.updates.base import PushNotificationHandlerKwargs
class PushNotificationHandler:
"""Push notification (webhook) based update handler."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[PushNotificationHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using push notifications for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Push notification-specific parameters.
Raises:
NotImplementedError: Push notifications not yet implemented.
"""
raise NotImplementedError(
"Push notification update mechanism is not yet implemented. "
"Use PollingConfig or StreamingConfig instead."
)

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from typing import Unpack
import uuid
from a2a.client import Client
@@ -17,136 +18,126 @@ from a2a.types import (
TextPart,
)
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
from crewai.a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
)
from crewai.a2a.updates.base import StreamingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
async def execute_streaming_delegation(
client: Client,
message: Message,
context_id: str | None,
task_id: str | None,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
new_messages: list[Message],
agent_card: AgentCard,
) -> TaskStateResult:
"""Execute A2A delegation using SSE streaming for updates.
class StreamingHandler:
"""SSE streaming-based update handler."""
Args:
client: A2A client instance
message: Message to send
context_id: Context ID for correlation
task_id: Task ID for correlation
turn_number: Current turn number
is_multiturn: Whether this is a multi-turn conversation
agent_role: Agent role for logging
new_messages: List to collect messages
agent_card: The agent card
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[StreamingHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using SSE streaming for updates.
Returns:
Dictionary with status, result/error, and history
"""
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Streaming-specific parameters.
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
Returns:
Dictionary with status, result/error, and history.
"""
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
elif isinstance(event, tuple):
a2a_task, update = event
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
elif isinstance(event, tuple):
a2a_task, update = event
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
is_final_update = False
if isinstance(update, TaskStatusUpdateEvent):
is_final_update = update.final
if (
not is_final_update
and a2a_task.status.state
not in TERMINAL_STATES | ACTIONABLE_STATES
):
continue
final_result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
result_parts=result_parts,
)
if final_result:
break
is_final_update = False
if isinstance(update, TaskStatusUpdateEvent):
is_final_update = update.final
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
if not is_final_update and a2a_task.status.state not in [
TaskState.completed,
TaskState.input_required,
TaskState.failed,
TaskState.rejected,
TaskState.auth_required,
TaskState.canceled,
]:
continue
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
final_result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
result_parts=result_parts,
)
if final_result:
break
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
if final_result:
return final_result
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
status=TaskState.completed,
result=" ".join(result_parts) if result_parts else "",
history=new_messages,
)
except Exception as e:
current_exception: Exception | BaseException | None = e
while current_exception:
if hasattr(current_exception, "response"):
response = current_exception.response
if hasattr(response, "text"):
break
if current_exception and hasattr(current_exception, "__cause__"):
current_exception = current_exception.__cause__
raise
finally:
if hasattr(event_stream, "aclose"):
await event_stream.aclose()
if final_result:
return final_result
return TaskStateResult(
status=TaskState.completed,
result=" ".join(result_parts) if result_parts else "",
history=new_messages,
)

View File

@@ -34,9 +34,15 @@ from crewai.a2a.auth.utils import (
from crewai.a2a.config import A2AConfig
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.types import PartsDict, PartsMetadataDict
from crewai.a2a.updates import PollingConfig, UpdateConfig
from crewai.a2a.updates.polling.handler import execute_polling_delegation
from crewai.a2a.updates.streaming.handler import execute_streaming_delegation
from crewai.a2a.updates import (
PollingConfig,
PollingHandler,
PushNotificationConfig,
PushNotificationHandler,
StreamingConfig,
StreamingHandler,
UpdateConfig,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationStartedEvent,
@@ -53,6 +59,31 @@ if TYPE_CHECKING:
from crewai.a2a.auth.schemas import AuthScheme
HandlerType = (
type[PollingHandler] | type[StreamingHandler] | type[PushNotificationHandler]
)
HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
PollingConfig: PollingHandler,
StreamingConfig: StreamingHandler,
PushNotificationConfig: PushNotificationHandler,
}
def get_handler(config: UpdateConfig | None) -> HandlerType:
"""Get the handler class for a given update config.
Args:
config: Update mechanism configuration.
Returns:
Handler class for the config type, defaults to StreamingHandler.
"""
if config is None:
return StreamingHandler
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
@lru_cache()
def _fetch_agent_card_cached(
endpoint: str,
@@ -448,14 +479,28 @@ async def _execute_a2a_delegation_async(
),
)
polling_config = updates if isinstance(updates, PollingConfig) else None
use_polling = polling_config is not None
polling_interval = polling_config.interval if polling_config else 2.0
effective_polling_timeout = (
polling_config.timeout
if polling_config and polling_config.timeout
else float(timeout)
)
handler = get_handler(updates)
use_polling = isinstance(updates, PollingConfig)
handler_kwargs: dict[str, Any] = {
"turn_number": turn_number,
"is_multiturn": is_multiturn,
"agent_role": agent_role,
"context_id": context_id,
"task_id": task_id,
"endpoint": endpoint,
"agent_branch": agent_branch,
}
if isinstance(updates, PollingConfig):
handler_kwargs.update(
{
"polling_interval": updates.interval,
"polling_timeout": updates.timeout or float(timeout),
"history_length": updates.history_length,
"max_polls": updates.max_polls,
}
)
async with _create_a2a_client(
agent_card=agent_card,
@@ -466,33 +511,12 @@ async def _execute_a2a_delegation_async(
auth=auth,
use_polling=use_polling,
) as client:
if use_polling and polling_config:
return await execute_polling_delegation(
client=client,
message=message,
polling_interval=polling_interval,
polling_timeout=effective_polling_timeout,
endpoint=endpoint,
agent_branch=agent_branch,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
new_messages=new_messages,
agent_card=agent_card,
history_length=polling_config.history_length,
max_polls=polling_config.max_polls,
)
return await execute_streaming_delegation(
return await handler.execute(
client=client,
message=message,
context_id=context_id,
task_id=task_id,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
new_messages=new_messages,
agent_card=agent_card,
**handler_kwargs,
)