diff --git a/lib/crewai/src/crewai/a2a/task_helpers.py b/lib/crewai/src/crewai/a2a/task_helpers.py index 2d34b3c40..5352d7eb6 100644 --- a/lib/crewai/src/crewai/a2a/task_helpers.py +++ b/lib/crewai/src/crewai/a2a/task_helpers.py @@ -2,10 +2,21 @@ from __future__ import annotations +from collections.abc import AsyncIterator from typing import TYPE_CHECKING, NotRequired, TypedDict import uuid -from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart +from a2a.types import ( + AgentCard, + Message, + Part, + Role, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatusUpdateEvent, + TextPart, +) from crewai.events.event_bus import crewai_event_bus from crewai.events.types.a2a_events import A2AResponseReceivedEvent @@ -14,6 +25,10 @@ from crewai.events.types.a2a_events import A2AResponseReceivedEvent if TYPE_CHECKING: from a2a.types import Task as A2ATask +SendMessageEvent = ( + tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message +) + TERMINAL_STATES: frozenset[TaskState] = frozenset( { @@ -221,3 +236,78 @@ def process_task_state( ) return None + + +async def send_message_and_get_task_id( + event_stream: AsyncIterator[SendMessageEvent], + new_messages: list[Message], + agent_card: AgentCard, + turn_number: int, + is_multiturn: bool, + agent_role: str | None, +) -> str | TaskStateResult: + """Send message and process initial response. + + Handles the common pattern of sending a message and either: + - Getting an immediate Message response (task completed synchronously) + - Getting a Task that needs polling/waiting for completion + + Args: + event_stream: Async iterator from client.send_message() + new_messages: List to collect messages (modified in place) + agent_card: The agent card + turn_number: Current turn number + is_multiturn: Whether multi-turn conversation + agent_role: Agent role for logging + + Returns: + Task ID string if agent needs polling/waiting, or TaskStateResult if done. + """ + async for event in event_stream: + if isinstance(event, Message): + new_messages.append(event) + result_parts = [ + part.root.text for part in event.parts if part.root.kind == "text" + ] + response_text = " ".join(result_parts) if result_parts else "" + + crewai_event_bus.emit( + None, + A2AResponseReceivedEvent( + response=response_text, + turn_number=turn_number, + is_multiturn=is_multiturn, + status="completed", + agent_role=agent_role, + ), + ) + + return TaskStateResult( + status=TaskState.completed, + result=response_text, + history=new_messages, + agent_card=agent_card, + ) + + if isinstance(event, tuple): + a2a_task, _ = event + + if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES: + result = process_task_state( + a2a_task=a2a_task, + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) + if result: + return result + + return a2a_task.id + + return TaskStateResult( + status=TaskState.failed, + error="No task ID received from initial message", + history=new_messages, + ) diff --git a/lib/crewai/src/crewai/a2a/updates/polling/handler.py b/lib/crewai/src/crewai/a2a/updates/polling/handler.py index 84d41afcf..753a09730 100644 --- a/lib/crewai/src/crewai/a2a/updates/polling/handler.py +++ b/lib/crewai/src/crewai/a2a/updates/polling/handler.py @@ -21,13 +21,13 @@ from crewai.a2a.task_helpers import ( TERMINAL_STATES, TaskStateResult, process_task_state, + send_message_and_get_task_id, ) from crewai.a2a.updates.base import PollingHandlerKwargs from crewai.events.event_bus import crewai_event_bus from crewai.events.types.a2a_events import ( A2APollingStartedEvent, A2APollingStatusEvent, - A2AResponseReceivedEvent, ) @@ -81,10 +81,7 @@ async def _poll_task_until_complete( ), ) - if task.status.state in TERMINAL_STATES: - return task - - if task.status.state in ACTIONABLE_STATES: + if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES: return task if elapsed > polling_timeout: @@ -133,57 +130,19 @@ class PollingHandler: history_length = kwargs.get("history_length", 100) max_polls = kwargs.get("max_polls") - task_id: str | None = None + result_or_task_id = await send_message_and_get_task_id( + event_stream=client.send_message(message), + new_messages=new_messages, + agent_card=agent_card, + turn_number=turn_number, + is_multiturn=is_multiturn, + agent_role=agent_role, + ) - async for event in client.send_message(message): - if isinstance(event, Message): - new_messages.append(event) - result_parts = [ - part.root.text for part in event.parts if part.root.kind == "text" - ] - response_text = " ".join(result_parts) if result_parts else "" + if not isinstance(result_or_task_id, str): + return result_or_task_id - crewai_event_bus.emit( - None, - A2AResponseReceivedEvent( - response=response_text, - turn_number=turn_number, - is_multiturn=is_multiturn, - status="completed", - agent_role=agent_role, - ), - ) - - return TaskStateResult( - status=TaskState.completed, - result=response_text, - history=new_messages, - agent_card=agent_card, - ) - - if isinstance(event, tuple): - a2a_task, _ = event - task_id = a2a_task.id - - if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES: - result = process_task_state( - a2a_task=a2a_task, - new_messages=new_messages, - agent_card=agent_card, - turn_number=turn_number, - is_multiturn=is_multiturn, - agent_role=agent_role, - ) - if result: - return result - break - - if not task_id: - return TaskStateResult( - status=TaskState.failed, - error="No task ID received from initial message", - history=new_messages, - ) + task_id = result_or_task_id crewai_event_bus.emit( agent_branch,