From f9977a5ebe9659b5b1f7452afe6c8353b140e24b Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Mon, 5 Jan 2026 22:27:13 -0500 Subject: [PATCH] feat: implement push notification handler --- .../a2a/updates/push_notifications/handler.py | 188 +++++++++++++++++- lib/crewai/src/crewai/a2a/utils.py | 9 + 2 files changed, 189 insertions(+), 8 deletions(-) diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py index 712d7c380..8217cf391 100644 --- a/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py @@ -2,13 +2,96 @@ from __future__ import annotations -from typing import Unpack +import logging +from typing import TYPE_CHECKING, Any from a2a.client import Client -from a2a.types import AgentCard, Message +from a2a.types import ( + AgentCard, + Message, + PushNotificationConfig as A2APushNotificationConfig, + TaskPushNotificationConfig, + TaskState, +) +from typing_extensions import Unpack -from crewai.a2a.task_helpers import TaskStateResult -from crewai.a2a.updates.base import PushNotificationHandlerKwargs +from crewai.a2a.task_helpers import ( + TaskStateResult, + process_task_state, + send_message_and_get_task_id, +) +from crewai.a2a.updates.base import ( + PushNotificationHandlerKwargs, + PushNotificationResultStore, +) +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.a2a_events import ( + A2APushNotificationRegisteredEvent, + A2APushNotificationTimeoutEvent, +) + + +if TYPE_CHECKING: + from a2a.types import Task as A2ATask + + from crewai.a2a.updates.push_notifications.config import PushNotificationConfig + + +logger = logging.getLogger(__name__) + + +def _build_a2a_push_config(config: PushNotificationConfig) -> A2APushNotificationConfig: + """Convert our config to A2A SDK's PushNotificationConfig. + + Args: + config: Our PushNotificationConfig. + + Returns: + A2A SDK PushNotificationConfig. + """ + return A2APushNotificationConfig( + url=str(config.url), + id=config.id, + token=config.token, + authentication=None, + ) + + +async def _wait_for_push_result( + task_id: str, + result_store: PushNotificationResultStore, + timeout: float, + poll_interval: float, + agent_branch: Any | None = None, +) -> A2ATask | None: + """Wait for push notification result. + + Args: + task_id: Task ID to wait for. + result_store: Store to retrieve results from. + timeout: Max seconds to wait. + poll_interval: Seconds between polling attempts. + agent_branch: Agent tree branch for logging. + + Returns: + Final task object, or None if timeout. + """ + task = await result_store.wait_for_result( + task_id=task_id, + timeout=timeout, + poll_interval=poll_interval, + ) + + if task is None: + crewai_event_bus.emit( + agent_branch, + A2APushNotificationTimeoutEvent( + task_id=task_id, + timeout_seconds=timeout, + ), + ) + + return task class PushNotificationHandler: @@ -31,10 +114,99 @@ class PushNotificationHandler: agent_card: The agent card. **kwargs: Push notification-specific parameters. + Returns: + Dictionary with status, result/error, and history. + Raises: - NotImplementedError: Push notifications not yet implemented. + ValueError: If result_store or config not provided. """ - raise NotImplementedError( - "Push notification update mechanism is not yet implemented. " - "Use PollingConfig or StreamingConfig instead." + config = kwargs.get("config") + result_store = kwargs.get("result_store") + polling_timeout = kwargs.get("polling_timeout", 300.0) + polling_interval = kwargs.get("polling_interval", 2.0) + agent_branch = kwargs.get("agent_branch") + turn_number = kwargs.get("turn_number", 0) + is_multiturn = kwargs.get("is_multiturn", False) + agent_role = kwargs.get("agent_role") + + if config is None: + return TaskStateResult( + status=TaskState.failed, + error="PushNotificationConfig is required for push notification handler", + history=new_messages, + ) + + if result_store is None: + return TaskStateResult( + status=TaskState.failed, + error="PushNotificationResultStore is required for push notification handler", + history=new_messages, + ) + + result_or_task_id = await send_message_and_get_task_id( + event_stream=client.send_message(message), + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) + + if not isinstance(result_or_task_id, str): + return result_or_task_id + + task_id = result_or_task_id + + a2a_push_config = _build_a2a_push_config(config) + await client.set_task_callback( + TaskPushNotificationConfig( + task_id=task_id, + push_notification_config=a2a_push_config, + ) + ) + + crewai_event_bus.emit( + agent_branch, + A2APushNotificationRegisteredEvent( + task_id=task_id, + callback_url=str(config.url), + ), + ) + + logger.debug( + "Registered push notification callback for task %s at %s", + task_id, + config.url, + ) + + final_task = await _wait_for_push_result( + task_id=task_id, + result_store=result_store, + timeout=polling_timeout, + poll_interval=polling_interval, + agent_branch=agent_branch, + ) + + if final_task is None: + return TaskStateResult( + status=TaskState.failed, + error=f"Push notification timeout after {polling_timeout}s", + history=new_messages, + ) + + result = process_task_state( + a2a_task=final_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) + if result: + return result + + return TaskStateResult( + status=TaskState.failed, + error=f"Unexpected task state: {final_task.status.state}", + history=new_messages, ) diff --git a/lib/crewai/src/crewai/a2a/utils.py b/lib/crewai/src/crewai/a2a/utils.py index 933f77b66..1b0ac3808 100644 --- a/lib/crewai/src/crewai/a2a/utils.py +++ b/lib/crewai/src/crewai/a2a/utils.py @@ -501,6 +501,15 @@ async def _execute_a2a_delegation_async( "max_polls": updates.max_polls, } ) + elif isinstance(updates, PushNotificationConfig): + handler_kwargs.update( + { + "config": updates, + "result_store": updates.result_store, + "polling_timeout": updates.timeout or float(timeout), + "polling_interval": updates.interval, + } + ) async with _create_a2a_client( agent_card=agent_card,