fix: ensure failed states are handled for push, poll

This commit is contained in:
Greyson LaLonde
2026-01-06 18:38:10 -05:00
parent 514df45c7d
commit f53b8755da
2 changed files with 161 additions and 91 deletions

View File

@@ -5,13 +5,18 @@ from __future__ import annotations
import asyncio import asyncio
import time import time
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import ( from a2a.types import (
AgentCard, AgentCard,
Message, Message,
Part,
Role,
TaskQueryParams, TaskQueryParams,
TaskState, TaskState,
TextPart,
) )
from typing_extensions import Unpack from typing_extensions import Unpack
@@ -28,6 +33,7 @@ from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import ( from crewai.events.types.a2a_events import (
A2APollingStartedEvent, A2APollingStartedEvent,
A2APollingStatusEvent, A2APollingStatusEvent,
A2AResponseReceivedEvent,
) )
@@ -129,53 +135,84 @@ class PollingHandler:
agent_role = kwargs.get("agent_role") agent_role = kwargs.get("agent_role")
history_length = kwargs.get("history_length", 100) history_length = kwargs.get("history_length", 100)
max_polls = kwargs.get("max_polls") max_polls = kwargs.get("max_polls")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
result_or_task_id = await send_message_and_get_task_id( try:
event_stream=client.send_message(message), result_or_task_id = await send_message_and_get_task_id(
new_messages=new_messages, event_stream=client.send_message(message),
agent_card=agent_card, new_messages=new_messages,
turn_number=turn_number, agent_card=agent_card,
is_multiturn=is_multiturn, turn_number=turn_number,
agent_role=agent_role, is_multiturn=is_multiturn,
) agent_role=agent_role,
)
if not isinstance(result_or_task_id, str): if not isinstance(result_or_task_id, str):
return result_or_task_id return result_or_task_id
task_id = result_or_task_id task_id = result_or_task_id
crewai_event_bus.emit( crewai_event_bus.emit(
agent_branch, agent_branch,
A2APollingStartedEvent( A2APollingStartedEvent(
task_id=task_id,
polling_interval=polling_interval,
endpoint=endpoint,
),
)
final_task = await _poll_task_until_complete(
client=client,
task_id=task_id, task_id=task_id,
polling_interval=polling_interval, polling_interval=polling_interval,
endpoint=endpoint, polling_timeout=polling_timeout,
), agent_branch=agent_branch,
) history_length=history_length,
max_polls=max_polls,
)
final_task = await _poll_task_until_complete( result = process_task_state(
client=client, a2a_task=final_task,
task_id=task_id, new_messages=new_messages,
polling_interval=polling_interval, agent_card=agent_card,
polling_timeout=polling_timeout, turn_number=turn_number,
agent_branch=agent_branch, is_multiturn=is_multiturn,
history_length=history_length, agent_role=agent_role,
max_polls=max_polls, )
) if result:
return result
result = process_task_state( return TaskStateResult(
a2a_task=final_task, status=TaskState.failed,
new_messages=new_messages, error=f"Unexpected task state: {final_task.status.state}",
agent_card=agent_card, history=new_messages,
turn_number=turn_number, )
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
return TaskStateResult( except A2AClientHTTPError as e:
status=TaskState.failed, error_msg = f"HTTP Error {e.status_code}: {e!s}"
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages, error_message = Message(
) role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)

View File

@@ -4,12 +4,17 @@ from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import ( from a2a.types import (
AgentCard, AgentCard,
Message, Message,
Part,
Role,
TaskState, TaskState,
TextPart,
) )
from typing_extensions import Unpack from typing_extensions import Unpack
@@ -26,6 +31,7 @@ from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import ( from crewai.events.types.a2a_events import (
A2APushNotificationRegisteredEvent, A2APushNotificationRegisteredEvent,
A2APushNotificationTimeoutEvent, A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
) )
@@ -107,6 +113,8 @@ class PushNotificationHandler:
turn_number = kwargs.get("turn_number", 0) turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False) is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role") agent_role = kwargs.get("agent_role")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
if config is None: if config is None:
return TaskStateResult( return TaskStateResult(
@@ -122,66 +130,91 @@ class PushNotificationHandler:
history=new_messages, history=new_messages,
) )
# Note: Push notification config is now included in the initial send_message try:
# request via ClientConfig.push_notification_configs, so no separate result_or_task_id = await send_message_and_get_task_id(
# set_task_callback call is needed. This avoids race conditions where event_stream=client.send_message(message),
# the task completes before the callback is registered. new_messages=new_messages,
result_or_task_id = await send_message_and_get_task_id( agent_card=agent_card,
event_stream=client.send_message(message), turn_number=turn_number,
new_messages=new_messages, is_multiturn=is_multiturn,
agent_card=agent_card, agent_role=agent_role,
turn_number=turn_number, )
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if not isinstance(result_or_task_id, str): if not isinstance(result_or_task_id, str):
return result_or_task_id return result_or_task_id
task_id = result_or_task_id task_id = result_or_task_id
crewai_event_bus.emit( crewai_event_bus.emit(
agent_branch, agent_branch,
A2APushNotificationRegisteredEvent( A2APushNotificationRegisteredEvent(
task_id=task_id,
callback_url=str(config.url),
),
)
logger.debug(
"Push notification callback for task %s configured at %s (via initial request)",
task_id,
config.url,
)
final_task = await _wait_for_push_result(
task_id=task_id, task_id=task_id,
callback_url=str(config.url), result_store=result_store,
), timeout=polling_timeout,
) poll_interval=polling_interval,
agent_branch=agent_branch,
)
logger.debug( if final_task is None:
"Push notification callback for task %s configured at %s (via initial request)", return TaskStateResult(
task_id, status=TaskState.failed,
config.url, error=f"Push notification timeout after {polling_timeout}s",
) history=new_messages,
)
final_task = await _wait_for_push_result( result = process_task_state(
task_id=task_id, a2a_task=final_task,
result_store=result_store, new_messages=new_messages,
timeout=polling_timeout, agent_card=agent_card,
poll_interval=polling_interval, turn_number=turn_number,
agent_branch=agent_branch, is_multiturn=is_multiturn,
) agent_role=agent_role,
)
if result:
return result
if final_task is None:
return TaskStateResult( return TaskStateResult(
status=TaskState.failed, status=TaskState.failed,
error=f"Push notification timeout after {polling_timeout}s", error=f"Unexpected task state: {final_task.status.state}",
history=new_messages, history=new_messages,
) )
result = process_task_state( except A2AClientHTTPError as e:
a2a_task=final_task, error_msg = f"HTTP Error {e.status_code}: {e!s}"
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
return TaskStateResult( error_message = Message(
status=TaskState.failed, role=Role.agent,
error=f"Unexpected task state: {final_task.status.state}", message_id=str(uuid.uuid4()),
history=new_messages, parts=[Part(root=TextPart(text=error_msg))],
) context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)