refactor: extract shared message sending logic

This commit is contained in:
Greyson LaLonde
2026-01-05 22:26:03 -05:00
parent 3607993e7e
commit 33d73c0be1
2 changed files with 104 additions and 55 deletions

View File

@@ -2,10 +2,21 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, NotRequired, TypedDict
import uuid
from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart
from a2a.types import (
AgentCard,
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
@@ -14,6 +25,10 @@ from crewai.events.types.a2a_events import A2AResponseReceivedEvent
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
SendMessageEvent = (
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
)
TERMINAL_STATES: frozenset[TaskState] = frozenset(
{
@@ -221,3 +236,78 @@ def process_task_state(
)
return None
async def send_message_and_get_task_id(
event_stream: AsyncIterator[SendMessageEvent],
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
) -> str | TaskStateResult:
"""Send message and process initial response.
Handles the common pattern of sending a message and either:
- Getting an immediate Message response (task completed synchronously)
- Getting a Task that needs polling/waiting for completion
Args:
event_stream: Async iterator from client.send_message()
new_messages: List to collect messages (modified in place)
agent_card: The agent card
turn_number: Current turn number
is_multiturn: Whether multi-turn conversation
agent_role: Agent role for logging
Returns:
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
"""
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
]
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="completed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card,
)
if isinstance(event, tuple):
a2a_task, _ = event
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
return a2a_task.id
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
)

View File

@@ -21,13 +21,13 @@ from crewai.a2a.task_helpers import (
TERMINAL_STATES,
TaskStateResult,
process_task_state,
send_message_and_get_task_id,
)
from crewai.a2a.updates.base import PollingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2APollingStartedEvent,
A2APollingStatusEvent,
A2AResponseReceivedEvent,
)
@@ -81,10 +81,7 @@ async def _poll_task_until_complete(
),
)
if task.status.state in TERMINAL_STATES:
return task
if task.status.state in ACTIONABLE_STATES:
if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
return task
if elapsed > polling_timeout:
@@ -133,57 +130,19 @@ class PollingHandler:
history_length = kwargs.get("history_length", 100)
max_polls = kwargs.get("max_polls")
task_id: str | None = None
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
async for event in client.send_message(message):
if isinstance(event, Message):
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
]
response_text = " ".join(result_parts) if result_parts else ""
if not isinstance(result_or_task_id, str):
return result_or_task_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="completed",
agent_role=agent_role,
),
)
return TaskStateResult(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card,
)
if isinstance(event, tuple):
a2a_task, _ = event
task_id = a2a_task.id
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
)
if result:
return result
break
if not task_id:
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
)
task_id = result_or_task_id
crewai_event_bus.emit(
agent_branch,