mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-26 08:38:15 +00:00
fix: ensure failed states are handled for push, poll
This commit is contained in:
@@ -5,13 +5,18 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
import uuid
|
||||||
|
|
||||||
from a2a.client import Client
|
from a2a.client import Client
|
||||||
|
from a2a.client.errors import A2AClientHTTPError
|
||||||
from a2a.types import (
|
from a2a.types import (
|
||||||
AgentCard,
|
AgentCard,
|
||||||
Message,
|
Message,
|
||||||
|
Part,
|
||||||
|
Role,
|
||||||
TaskQueryParams,
|
TaskQueryParams,
|
||||||
TaskState,
|
TaskState,
|
||||||
|
TextPart,
|
||||||
)
|
)
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
@@ -28,6 +33,7 @@ from crewai.events.event_bus import crewai_event_bus
|
|||||||
from crewai.events.types.a2a_events import (
|
from crewai.events.types.a2a_events import (
|
||||||
A2APollingStartedEvent,
|
A2APollingStartedEvent,
|
||||||
A2APollingStatusEvent,
|
A2APollingStatusEvent,
|
||||||
|
A2AResponseReceivedEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -129,53 +135,84 @@ class PollingHandler:
|
|||||||
agent_role = kwargs.get("agent_role")
|
agent_role = kwargs.get("agent_role")
|
||||||
history_length = kwargs.get("history_length", 100)
|
history_length = kwargs.get("history_length", 100)
|
||||||
max_polls = kwargs.get("max_polls")
|
max_polls = kwargs.get("max_polls")
|
||||||
|
context_id = kwargs.get("context_id")
|
||||||
|
task_id = kwargs.get("task_id")
|
||||||
|
|
||||||
result_or_task_id = await send_message_and_get_task_id(
|
try:
|
||||||
event_stream=client.send_message(message),
|
result_or_task_id = await send_message_and_get_task_id(
|
||||||
new_messages=new_messages,
|
event_stream=client.send_message(message),
|
||||||
agent_card=agent_card,
|
new_messages=new_messages,
|
||||||
turn_number=turn_number,
|
agent_card=agent_card,
|
||||||
is_multiturn=is_multiturn,
|
turn_number=turn_number,
|
||||||
agent_role=agent_role,
|
is_multiturn=is_multiturn,
|
||||||
)
|
agent_role=agent_role,
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(result_or_task_id, str):
|
if not isinstance(result_or_task_id, str):
|
||||||
return result_or_task_id
|
return result_or_task_id
|
||||||
|
|
||||||
task_id = result_or_task_id
|
task_id = result_or_task_id
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent_branch,
|
agent_branch,
|
||||||
A2APollingStartedEvent(
|
A2APollingStartedEvent(
|
||||||
|
task_id=task_id,
|
||||||
|
polling_interval=polling_interval,
|
||||||
|
endpoint=endpoint,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
final_task = await _poll_task_until_complete(
|
||||||
|
client=client,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
polling_interval=polling_interval,
|
polling_interval=polling_interval,
|
||||||
endpoint=endpoint,
|
polling_timeout=polling_timeout,
|
||||||
),
|
agent_branch=agent_branch,
|
||||||
)
|
history_length=history_length,
|
||||||
|
max_polls=max_polls,
|
||||||
|
)
|
||||||
|
|
||||||
final_task = await _poll_task_until_complete(
|
result = process_task_state(
|
||||||
client=client,
|
a2a_task=final_task,
|
||||||
task_id=task_id,
|
new_messages=new_messages,
|
||||||
polling_interval=polling_interval,
|
agent_card=agent_card,
|
||||||
polling_timeout=polling_timeout,
|
turn_number=turn_number,
|
||||||
agent_branch=agent_branch,
|
is_multiturn=is_multiturn,
|
||||||
history_length=history_length,
|
agent_role=agent_role,
|
||||||
max_polls=max_polls,
|
)
|
||||||
)
|
if result:
|
||||||
|
return result
|
||||||
|
|
||||||
result = process_task_state(
|
return TaskStateResult(
|
||||||
a2a_task=final_task,
|
status=TaskState.failed,
|
||||||
new_messages=new_messages,
|
error=f"Unexpected task state: {final_task.status.state}",
|
||||||
agent_card=agent_card,
|
history=new_messages,
|
||||||
turn_number=turn_number,
|
)
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
agent_role=agent_role,
|
|
||||||
)
|
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
|
|
||||||
return TaskStateResult(
|
except A2AClientHTTPError as e:
|
||||||
status=TaskState.failed,
|
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||||
error=f"Unexpected task state: {final_task.status.state}",
|
|
||||||
history=new_messages,
|
error_message = Message(
|
||||||
)
|
role=Role.agent,
|
||||||
|
message_id=str(uuid.uuid4()),
|
||||||
|
parts=[Part(root=TextPart(text=error_msg))],
|
||||||
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=error_msg,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="failed",
|
||||||
|
agent_role=agent_role,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error=error_msg,
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,12 +4,17 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
import uuid
|
||||||
|
|
||||||
from a2a.client import Client
|
from a2a.client import Client
|
||||||
|
from a2a.client.errors import A2AClientHTTPError
|
||||||
from a2a.types import (
|
from a2a.types import (
|
||||||
AgentCard,
|
AgentCard,
|
||||||
Message,
|
Message,
|
||||||
|
Part,
|
||||||
|
Role,
|
||||||
TaskState,
|
TaskState,
|
||||||
|
TextPart,
|
||||||
)
|
)
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
@@ -26,6 +31,7 @@ from crewai.events.event_bus import crewai_event_bus
|
|||||||
from crewai.events.types.a2a_events import (
|
from crewai.events.types.a2a_events import (
|
||||||
A2APushNotificationRegisteredEvent,
|
A2APushNotificationRegisteredEvent,
|
||||||
A2APushNotificationTimeoutEvent,
|
A2APushNotificationTimeoutEvent,
|
||||||
|
A2AResponseReceivedEvent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -107,6 +113,8 @@ class PushNotificationHandler:
|
|||||||
turn_number = kwargs.get("turn_number", 0)
|
turn_number = kwargs.get("turn_number", 0)
|
||||||
is_multiturn = kwargs.get("is_multiturn", False)
|
is_multiturn = kwargs.get("is_multiturn", False)
|
||||||
agent_role = kwargs.get("agent_role")
|
agent_role = kwargs.get("agent_role")
|
||||||
|
context_id = kwargs.get("context_id")
|
||||||
|
task_id = kwargs.get("task_id")
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
@@ -122,66 +130,91 @@ class PushNotificationHandler:
|
|||||||
history=new_messages,
|
history=new_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: Push notification config is now included in the initial send_message
|
try:
|
||||||
# request via ClientConfig.push_notification_configs, so no separate
|
result_or_task_id = await send_message_and_get_task_id(
|
||||||
# set_task_callback call is needed. This avoids race conditions where
|
event_stream=client.send_message(message),
|
||||||
# the task completes before the callback is registered.
|
new_messages=new_messages,
|
||||||
result_or_task_id = await send_message_and_get_task_id(
|
agent_card=agent_card,
|
||||||
event_stream=client.send_message(message),
|
turn_number=turn_number,
|
||||||
new_messages=new_messages,
|
is_multiturn=is_multiturn,
|
||||||
agent_card=agent_card,
|
agent_role=agent_role,
|
||||||
turn_number=turn_number,
|
)
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
agent_role=agent_role,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(result_or_task_id, str):
|
if not isinstance(result_or_task_id, str):
|
||||||
return result_or_task_id
|
return result_or_task_id
|
||||||
|
|
||||||
task_id = result_or_task_id
|
task_id = result_or_task_id
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent_branch,
|
agent_branch,
|
||||||
A2APushNotificationRegisteredEvent(
|
A2APushNotificationRegisteredEvent(
|
||||||
|
task_id=task_id,
|
||||||
|
callback_url=str(config.url),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Push notification callback for task %s configured at %s (via initial request)",
|
||||||
|
task_id,
|
||||||
|
config.url,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_task = await _wait_for_push_result(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
callback_url=str(config.url),
|
result_store=result_store,
|
||||||
),
|
timeout=polling_timeout,
|
||||||
)
|
poll_interval=polling_interval,
|
||||||
|
agent_branch=agent_branch,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
if final_task is None:
|
||||||
"Push notification callback for task %s configured at %s (via initial request)",
|
return TaskStateResult(
|
||||||
task_id,
|
status=TaskState.failed,
|
||||||
config.url,
|
error=f"Push notification timeout after {polling_timeout}s",
|
||||||
)
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
final_task = await _wait_for_push_result(
|
result = process_task_state(
|
||||||
task_id=task_id,
|
a2a_task=final_task,
|
||||||
result_store=result_store,
|
new_messages=new_messages,
|
||||||
timeout=polling_timeout,
|
agent_card=agent_card,
|
||||||
poll_interval=polling_interval,
|
turn_number=turn_number,
|
||||||
agent_branch=agent_branch,
|
is_multiturn=is_multiturn,
|
||||||
)
|
agent_role=agent_role,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
|
||||||
if final_task is None:
|
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
status=TaskState.failed,
|
status=TaskState.failed,
|
||||||
error=f"Push notification timeout after {polling_timeout}s",
|
error=f"Unexpected task state: {final_task.status.state}",
|
||||||
history=new_messages,
|
history=new_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = process_task_state(
|
except A2AClientHTTPError as e:
|
||||||
a2a_task=final_task,
|
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||||
new_messages=new_messages,
|
|
||||||
agent_card=agent_card,
|
|
||||||
turn_number=turn_number,
|
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
agent_role=agent_role,
|
|
||||||
)
|
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
|
|
||||||
return TaskStateResult(
|
error_message = Message(
|
||||||
status=TaskState.failed,
|
role=Role.agent,
|
||||||
error=f"Unexpected task state: {final_task.status.state}",
|
message_id=str(uuid.uuid4()),
|
||||||
history=new_messages,
|
parts=[Part(root=TextPart(text=error_msg))],
|
||||||
)
|
context_id=context_id,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
new_messages.append(error_message)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=error_msg,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="failed",
|
||||||
|
agent_role=agent_role,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error=error_msg,
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user