From f53b8755daad9e8b87fa5e82ea490547cd00eba8 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 6 Jan 2026 18:38:10 -0500 Subject: [PATCH] fix: ensure failed states are handled for push, poll --- .../src/crewai/a2a/updates/polling/handler.py | 119 ++++++++++------ .../a2a/updates/push_notifications/handler.py | 133 +++++++++++------- 2 files changed, 161 insertions(+), 91 deletions(-) diff --git a/lib/crewai/src/crewai/a2a/updates/polling/handler.py b/lib/crewai/src/crewai/a2a/updates/polling/handler.py index 753a09730..1338b2b3a 100644 --- a/lib/crewai/src/crewai/a2a/updates/polling/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/polling/handler.py @@ -5,13 +5,18 @@ from __future__ import annotations import asyncio import time from typing import TYPE_CHECKING, Any +import uuid from a2a.client import Client +from a2a.client.errors import A2AClientHTTPError from a2a.types import ( AgentCard, Message, + Part, + Role, TaskQueryParams, TaskState, + TextPart, ) from typing_extensions import Unpack @@ -28,6 +33,7 @@ from crewai.events.event_bus import crewai_event_bus from crewai.events.types.a2a_events import ( A2APollingStartedEvent, A2APollingStatusEvent, + A2AResponseReceivedEvent, ) @@ -129,53 +135,84 @@ class PollingHandler: agent_role = kwargs.get("agent_role") history_length = kwargs.get("history_length", 100) max_polls = kwargs.get("max_polls") + context_id = kwargs.get("context_id") + task_id = kwargs.get("task_id") - result_or_task_id = await send_message_and_get_task_id( - event_stream=client.send_message(message), - new_messages=new_messages, - agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - ) + try: + result_or_task_id = await send_message_and_get_task_id( + event_stream=client.send_message(message), + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) - if not isinstance(result_or_task_id, str): - return result_or_task_id + if not isinstance(result_or_task_id, str): + return result_or_task_id - task_id = result_or_task_id + task_id = result_or_task_id - crewai_event_bus.emit( - agent_branch, - A2APollingStartedEvent( + crewai_event_bus.emit( + agent_branch, + A2APollingStartedEvent( + task_id=task_id, + polling_interval=polling_interval, + endpoint=endpoint, + ), + ) + + final_task = await _poll_task_until_complete( + client=client, task_id=task_id, polling_interval=polling_interval, - endpoint=endpoint, - ), - ) + polling_timeout=polling_timeout, + agent_branch=agent_branch, + history_length=history_length, + max_polls=max_polls, + ) - final_task = await _poll_task_until_complete( - client=client, - task_id=task_id, - polling_interval=polling_interval, - polling_timeout=polling_timeout, - agent_branch=agent_branch, - history_length=history_length, - max_polls=max_polls, - ) + result = process_task_state( + a2a_task=final_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) + if result: + return result - result = process_task_state( - a2a_task=final_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - ) - if result: - return result + return TaskStateResult( + status=TaskState.failed, + error=f"Unexpected task state: {final_task.status.state}", + history=new_messages, + ) - return TaskStateResult( - status=TaskState.failed, - error=f"Unexpected task state: {final_task.status.state}", - history=new_messages, - ) + except A2AClientHTTPError as e: + error_msg = f"HTTP Error {e.status_code}: {e!s}" + + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=turn_number, + is_multiturn=is_multiturn, + status="failed", + agent_role=agent_role, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + ) diff --git a/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py index 189e8aee8..04db239f2 100644 --- a/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/push_notifications/handler.py @@ -4,12 +4,17 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, Any +import uuid from a2a.client import Client +from a2a.client.errors import A2AClientHTTPError from a2a.types import ( AgentCard, Message, + Part, + Role, TaskState, + TextPart, ) from typing_extensions import Unpack @@ -26,6 +31,7 @@ from crewai.events.event_bus import crewai_event_bus from crewai.events.types.a2a_events import ( A2APushNotificationRegisteredEvent, A2APushNotificationTimeoutEvent, + A2AResponseReceivedEvent, ) @@ -107,6 +113,8 @@ class PushNotificationHandler: turn_number = kwargs.get("turn_number", 0) is_multiturn = kwargs.get("is_multiturn", False) agent_role = kwargs.get("agent_role") + context_id = kwargs.get("context_id") + task_id = kwargs.get("task_id") if config is None: return TaskStateResult( @@ -122,66 +130,91 @@ class PushNotificationHandler: history=new_messages, ) - # Note: Push notification config is now included in the initial send_message - # request via ClientConfig.push_notification_configs, so no separate - # set_task_callback call is needed. This avoids race conditions where - # the task completes before the callback is registered. - result_or_task_id = await send_message_and_get_task_id( - event_stream=client.send_message(message), - new_messages=new_messages, - agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - ) + try: + result_or_task_id = await send_message_and_get_task_id( + event_stream=client.send_message(message), + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) - if not isinstance(result_or_task_id, str): - return result_or_task_id + if not isinstance(result_or_task_id, str): + return result_or_task_id - task_id = result_or_task_id + task_id = result_or_task_id - crewai_event_bus.emit( - agent_branch, - A2APushNotificationRegisteredEvent( + crewai_event_bus.emit( + agent_branch, + A2APushNotificationRegisteredEvent( + task_id=task_id, + callback_url=str(config.url), + ), + ) + + logger.debug( + "Push notification callback for task %s configured at %s (via initial request)", + task_id, + config.url, + ) + + final_task = await _wait_for_push_result( task_id=task_id, - callback_url=str(config.url), - ), - ) + result_store=result_store, + timeout=polling_timeout, + poll_interval=polling_interval, + agent_branch=agent_branch, + ) - logger.debug( - "Push notification callback for task %s configured at %s (via initial request)", - task_id, - config.url, - ) + if final_task is None: + return TaskStateResult( + status=TaskState.failed, + error=f"Push notification timeout after {polling_timeout}s", + history=new_messages, + ) - final_task = await _wait_for_push_result( - task_id=task_id, - result_store=result_store, - timeout=polling_timeout, - poll_interval=polling_interval, - agent_branch=agent_branch, - ) + result = process_task_state( + a2a_task=final_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) + if result: + return result - if final_task is None: return TaskStateResult( status=TaskState.failed, - error=f"Push notification timeout after {polling_timeout}s", + error=f"Unexpected task state: {final_task.status.state}", history=new_messages, ) - result = process_task_state( - a2a_task=final_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - ) - if result: - return result + except A2AClientHTTPError as e: + error_msg = f"HTTP Error {e.status_code}: {e!s}" - return TaskStateResult( - status=TaskState.failed, - error=f"Unexpected task state: {final_task.status.state}", - history=new_messages, - ) + error_message = Message( + role=Role.agent, + message_id=str(uuid.uuid4()), + parts=[Part(root=TextPart(text=error_msg))], + context_id=context_id, + task_id=task_id, + ) + new_messages.append(error_message) + + crewai_event_bus.emit( + agent_branch, + A2AResponseReceivedEvent( + response=error_msg, + turn_number=turn_number, + is_multiturn=is_multiturn, + status="failed", + agent_role=agent_role, + ), + ) + return TaskStateResult( + status=TaskState.failed, + error=error_msg, + history=new_messages, + )