mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: implement push notification handler
This commit is contained in:
@@ -2,13 +2,96 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Unpack
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.types import AgentCard, Message
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
Message,
|
||||
PushNotificationConfig as A2APushNotificationConfig,
|
||||
TaskPushNotificationConfig,
|
||||
TaskState,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.a2a.task_helpers import TaskStateResult
|
||||
from crewai.a2a.updates.base import PushNotificationHandlerKwargs
|
||||
from crewai.a2a.task_helpers import (
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
send_message_and_get_task_id,
|
||||
)
|
||||
from crewai.a2a.updates.base import (
|
||||
PushNotificationHandlerKwargs,
|
||||
PushNotificationResultStore,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2APushNotificationRegisteredEvent,
|
||||
A2APushNotificationTimeoutEvent,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_a2a_push_config(config: PushNotificationConfig) -> A2APushNotificationConfig:
|
||||
"""Convert our config to A2A SDK's PushNotificationConfig.
|
||||
|
||||
Args:
|
||||
config: Our PushNotificationConfig.
|
||||
|
||||
Returns:
|
||||
A2A SDK PushNotificationConfig.
|
||||
"""
|
||||
return A2APushNotificationConfig(
|
||||
url=str(config.url),
|
||||
id=config.id,
|
||||
token=config.token,
|
||||
authentication=None,
|
||||
)
|
||||
|
||||
|
||||
async def _wait_for_push_result(
|
||||
task_id: str,
|
||||
result_store: PushNotificationResultStore,
|
||||
timeout: float,
|
||||
poll_interval: float,
|
||||
agent_branch: Any | None = None,
|
||||
) -> A2ATask | None:
|
||||
"""Wait for push notification result.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to wait for.
|
||||
result_store: Store to retrieve results from.
|
||||
timeout: Max seconds to wait.
|
||||
poll_interval: Seconds between polling attempts.
|
||||
agent_branch: Agent tree branch for logging.
|
||||
|
||||
Returns:
|
||||
Final task object, or None if timeout.
|
||||
"""
|
||||
task = await result_store.wait_for_result(
|
||||
task_id=task_id,
|
||||
timeout=timeout,
|
||||
poll_interval=poll_interval,
|
||||
)
|
||||
|
||||
if task is None:
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APushNotificationTimeoutEvent(
|
||||
task_id=task_id,
|
||||
timeout_seconds=timeout,
|
||||
),
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
class PushNotificationHandler:
|
||||
@@ -31,10 +114,99 @@ class PushNotificationHandler:
|
||||
agent_card: The agent card.
|
||||
**kwargs: Push notification-specific parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Push notifications not yet implemented.
|
||||
ValueError: If result_store or config not provided.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Push notification update mechanism is not yet implemented. "
|
||||
"Use PollingConfig or StreamingConfig instead."
|
||||
config = kwargs.get("config")
|
||||
result_store = kwargs.get("result_store")
|
||||
polling_timeout = kwargs.get("polling_timeout", 300.0)
|
||||
polling_interval = kwargs.get("polling_interval", 2.0)
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
turn_number = kwargs.get("turn_number", 0)
|
||||
is_multiturn = kwargs.get("is_multiturn", False)
|
||||
agent_role = kwargs.get("agent_role")
|
||||
|
||||
if config is None:
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error="PushNotificationConfig is required for push notification handler",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
if result_store is None:
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error="PushNotificationResultStore is required for push notification handler",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
result_or_task_id = await send_message_and_get_task_id(
|
||||
event_stream=client.send_message(message),
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
)
|
||||
|
||||
if not isinstance(result_or_task_id, str):
|
||||
return result_or_task_id
|
||||
|
||||
task_id = result_or_task_id
|
||||
|
||||
a2a_push_config = _build_a2a_push_config(config)
|
||||
await client.set_task_callback(
|
||||
TaskPushNotificationConfig(
|
||||
task_id=task_id,
|
||||
push_notification_config=a2a_push_config,
|
||||
)
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APushNotificationRegisteredEvent(
|
||||
task_id=task_id,
|
||||
callback_url=str(config.url),
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Registered push notification callback for task %s at %s",
|
||||
task_id,
|
||||
config.url,
|
||||
)
|
||||
|
||||
final_task = await _wait_for_push_result(
|
||||
task_id=task_id,
|
||||
result_store=result_store,
|
||||
timeout=polling_timeout,
|
||||
poll_interval=polling_interval,
|
||||
agent_branch=agent_branch,
|
||||
)
|
||||
|
||||
if final_task is None:
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Push notification timeout after {polling_timeout}s",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
result = process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Unexpected task state: {final_task.status.state}",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
@@ -501,6 +501,15 @@ async def _execute_a2a_delegation_async(
|
||||
"max_polls": updates.max_polls,
|
||||
}
|
||||
)
|
||||
elif isinstance(updates, PushNotificationConfig):
|
||||
handler_kwargs.update(
|
||||
{
|
||||
"config": updates,
|
||||
"result_store": updates.result_store,
|
||||
"polling_timeout": updates.timeout or float(timeout),
|
||||
"polling_interval": updates.interval,
|
||||
}
|
||||
)
|
||||
|
||||
async with _create_a2a_client(
|
||||
agent_card=agent_card,
|
||||
|
||||
Reference in New Issue
Block a user