fix: ensure failed states are handled for push, poll

This commit is contained in:
Greyson LaLonde
2026-01-06 18:38:10 -05:00
parent 514df45c7d
commit f53b8755da
2 changed files with 161 additions and 91 deletions

View File

@@ -5,13 +5,18 @@ from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskQueryParams,
TaskState,
TextPart,
)
from typing_extensions import Unpack
@@ -28,6 +33,7 @@ from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2APollingStartedEvent,
A2APollingStatusEvent,
A2AResponseReceivedEvent,
)
@@ -129,53 +135,84 @@ class PollingHandler:
agent_role = kwargs.get("agent_role")
history_length = kwargs.get("history_length", 100)
max_polls = kwargs.get("max_polls")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if not isinstance(result_or_task_id, str):
return result_or_task_id
if not isinstance(result_or_task_id, str):
return result_or_task_id
task_id = result_or_task_id
task_id = result_or_task_id
crewai_event_bus.emit(
agent_branch,
A2APollingStartedEvent(
crewai_event_bus.emit(
agent_branch,
A2APollingStartedEvent(
task_id=task_id,
polling_interval=polling_interval,
endpoint=endpoint,
),
)
final_task = await _poll_task_until_complete(
client=client,
task_id=task_id,
polling_interval=polling_interval,
endpoint=endpoint,
),
)
polling_timeout=polling_timeout,
agent_branch=agent_branch,
history_length=history_length,
max_polls=max_polls,
)
final_task = await _poll_task_until_complete(
client=client,
task_id=task_id,
polling_interval=polling_interval,
polling_timeout=polling_timeout,
agent_branch=agent_branch,
history_length=history_length,
max_polls=max_polls,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)

View File

@@ -4,12 +4,17 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskState,
TextPart,
)
from typing_extensions import Unpack
@@ -26,6 +31,7 @@ from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2APushNotificationRegisteredEvent,
A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
)
@@ -107,6 +113,8 @@ class PushNotificationHandler:
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
if config is None:
return TaskStateResult(
@@ -122,66 +130,91 @@ class PushNotificationHandler:
history=new_messages,
)
# Note: Push notification config is now included in the initial send_message
# request via ClientConfig.push_notification_configs, so no separate
# set_task_callback call is needed. This avoids race conditions where
# the task completes before the callback is registered.
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if not isinstance(result_or_task_id, str):
return result_or_task_id
if not isinstance(result_or_task_id, str):
return result_or_task_id
task_id = result_or_task_id
task_id = result_or_task_id
crewai_event_bus.emit(
agent_branch,
A2APushNotificationRegisteredEvent(
crewai_event_bus.emit(
agent_branch,
A2APushNotificationRegisteredEvent(
task_id=task_id,
callback_url=str(config.url),
),
)
logger.debug(
"Push notification callback for task %s configured at %s (via initial request)",
task_id,
config.url,
)
final_task = await _wait_for_push_result(
task_id=task_id,
callback_url=str(config.url),
),
)
result_store=result_store,
timeout=polling_timeout,
poll_interval=polling_interval,
agent_branch=agent_branch,
)
logger.debug(
"Push notification callback for task %s configured at %s (via initial request)",
task_id,
config.url,
)
if final_task is None:
return TaskStateResult(
status=TaskState.failed,
error=f"Push notification timeout after {polling_timeout}s",
history=new_messages,
)
final_task = await _wait_for_push_result(
task_id=task_id,
result_store=result_store,
timeout=polling_timeout,
poll_interval=polling_interval,
agent_branch=agent_branch,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
if final_task is None:
return TaskStateResult(
status=TaskState.failed,
error=f"Push notification timeout after {polling_timeout}s",
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)