feat: implement push notification handler

This commit is contained in:
Greyson LaLonde
2026-01-05 22:27:13 -05:00
parent 33d73c0be1
commit f9977a5ebe
2 changed files with 189 additions and 8 deletions

View File

@@ -2,13 +2,96 @@
from __future__ import annotations
from typing import Unpack
import logging
from typing import TYPE_CHECKING, Any
from a2a.client import Client
from a2a.types import AgentCard, Message
from a2a.types import (
AgentCard,
Message,
PushNotificationConfig as A2APushNotificationConfig,
TaskPushNotificationConfig,
TaskState,
)
from typing_extensions import Unpack
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.updates.base import PushNotificationHandlerKwargs
from crewai.a2a.task_helpers import (
TaskStateResult,
process_task_state,
send_message_and_get_task_id,
)
from crewai.a2a.updates.base import (
PushNotificationHandlerKwargs,
PushNotificationResultStore,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2APushNotificationRegisteredEvent,
A2APushNotificationTimeoutEvent,
)
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
logger = logging.getLogger(__name__)
def _build_a2a_push_config(config: PushNotificationConfig) -> A2APushNotificationConfig:
"""Convert our config to A2A SDK's PushNotificationConfig.
Args:
config: Our PushNotificationConfig.
Returns:
A2A SDK PushNotificationConfig.
"""
return A2APushNotificationConfig(
url=str(config.url),
id=config.id,
token=config.token,
authentication=None,
)
async def _wait_for_push_result(
task_id: str,
result_store: PushNotificationResultStore,
timeout: float,
poll_interval: float,
agent_branch: Any | None = None,
) -> A2ATask | None:
"""Wait for push notification result.
Args:
task_id: Task ID to wait for.
result_store: Store to retrieve results from.
timeout: Max seconds to wait.
poll_interval: Seconds between polling attempts.
agent_branch: Agent tree branch for logging.
Returns:
Final task object, or None if timeout.
"""
task = await result_store.wait_for_result(
task_id=task_id,
timeout=timeout,
poll_interval=poll_interval,
)
if task is None:
crewai_event_bus.emit(
agent_branch,
A2APushNotificationTimeoutEvent(
task_id=task_id,
timeout_seconds=timeout,
),
)
return task
class PushNotificationHandler:
@@ -31,10 +114,99 @@ class PushNotificationHandler:
agent_card: The agent card.
**kwargs: Push notification-specific parameters.
Returns:
Dictionary with status, result/error, and history.
Raises:
NotImplementedError: Push notifications not yet implemented.
ValueError: If result_store or config not provided.
"""
raise NotImplementedError(
"Push notification update mechanism is not yet implemented. "
"Use PollingConfig or StreamingConfig instead."
config = kwargs.get("config")
result_store = kwargs.get("result_store")
polling_timeout = kwargs.get("polling_timeout", 300.0)
polling_interval = kwargs.get("polling_interval", 2.0)
agent_branch = kwargs.get("agent_branch")
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
if config is None:
return TaskStateResult(
status=TaskState.failed,
error="PushNotificationConfig is required for push notification handler",
history=new_messages,
)
if result_store is None:
return TaskStateResult(
status=TaskState.failed,
error="PushNotificationResultStore is required for push notification handler",
history=new_messages,
)
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if not isinstance(result_or_task_id, str):
return result_or_task_id
task_id = result_or_task_id
a2a_push_config = _build_a2a_push_config(config)
await client.set_task_callback(
TaskPushNotificationConfig(
task_id=task_id,
push_notification_config=a2a_push_config,
)
)
crewai_event_bus.emit(
agent_branch,
A2APushNotificationRegisteredEvent(
task_id=task_id,
callback_url=str(config.url),
),
)
logger.debug(
"Registered push notification callback for task %s at %s",
task_id,
config.url,
)
final_task = await _wait_for_push_result(
task_id=task_id,
result_store=result_store,
timeout=polling_timeout,
poll_interval=polling_interval,
agent_branch=agent_branch,
)
if final_task is None:
return TaskStateResult(
status=TaskState.failed,
error=f"Push notification timeout after {polling_timeout}s",
history=new_messages,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)

View File

@@ -501,6 +501,15 @@ async def _execute_a2a_delegation_async(
"max_polls": updates.max_polls,
}
)
elif isinstance(updates, PushNotificationConfig):
handler_kwargs.update(
{
"config": updates,
"result_store": updates.result_store,
"polling_timeout": updates.timeout or float(timeout),
"polling_interval": updates.interval,
}
)
async with _create_a2a_client(
agent_card=agent_card,