mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 15:22:37 +00:00
refactor: extract shared message sending logic
This commit is contained in:
@@ -2,10 +2,21 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart
|
from a2a.types import (
|
||||||
|
AgentCard,
|
||||||
|
Message,
|
||||||
|
Part,
|
||||||
|
Role,
|
||||||
|
Task,
|
||||||
|
TaskArtifactUpdateEvent,
|
||||||
|
TaskState,
|
||||||
|
TaskStatusUpdateEvent,
|
||||||
|
TextPart,
|
||||||
|
)
|
||||||
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
||||||
@@ -14,6 +25,10 @@ from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from a2a.types import Task as A2ATask
|
from a2a.types import Task as A2ATask
|
||||||
|
|
||||||
|
SendMessageEvent = (
|
||||||
|
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TERMINAL_STATES: frozenset[TaskState] = frozenset(
|
TERMINAL_STATES: frozenset[TaskState] = frozenset(
|
||||||
{
|
{
|
||||||
@@ -221,3 +236,78 @@ def process_task_state(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def send_message_and_get_task_id(
|
||||||
|
event_stream: AsyncIterator[SendMessageEvent],
|
||||||
|
new_messages: list[Message],
|
||||||
|
agent_card: AgentCard,
|
||||||
|
turn_number: int,
|
||||||
|
is_multiturn: bool,
|
||||||
|
agent_role: str | None,
|
||||||
|
) -> str | TaskStateResult:
|
||||||
|
"""Send message and process initial response.
|
||||||
|
|
||||||
|
Handles the common pattern of sending a message and either:
|
||||||
|
- Getting an immediate Message response (task completed synchronously)
|
||||||
|
- Getting a Task that needs polling/waiting for completion
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_stream: Async iterator from client.send_message()
|
||||||
|
new_messages: List to collect messages (modified in place)
|
||||||
|
agent_card: The agent card
|
||||||
|
turn_number: Current turn number
|
||||||
|
is_multiturn: Whether multi-turn conversation
|
||||||
|
agent_role: Agent role for logging
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
|
||||||
|
"""
|
||||||
|
async for event in event_stream:
|
||||||
|
if isinstance(event, Message):
|
||||||
|
new_messages.append(event)
|
||||||
|
result_parts = [
|
||||||
|
part.root.text for part in event.parts if part.root.kind == "text"
|
||||||
|
]
|
||||||
|
response_text = " ".join(result_parts) if result_parts else ""
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
None,
|
||||||
|
A2AResponseReceivedEvent(
|
||||||
|
response=response_text,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
status="completed",
|
||||||
|
agent_role=agent_role,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.completed,
|
||||||
|
result=response_text,
|
||||||
|
history=new_messages,
|
||||||
|
agent_card=agent_card,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(event, tuple):
|
||||||
|
a2a_task, _ = event
|
||||||
|
|
||||||
|
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
||||||
|
result = process_task_state(
|
||||||
|
a2a_task=a2a_task,
|
||||||
|
new_messages=new_messages,
|
||||||
|
agent_card=agent_card,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
agent_role=agent_role,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
|
||||||
|
return a2a_task.id
|
||||||
|
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error="No task ID received from initial message",
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|||||||
@@ -21,13 +21,13 @@ from crewai.a2a.task_helpers import (
|
|||||||
TERMINAL_STATES,
|
TERMINAL_STATES,
|
||||||
TaskStateResult,
|
TaskStateResult,
|
||||||
process_task_state,
|
process_task_state,
|
||||||
|
send_message_and_get_task_id,
|
||||||
)
|
)
|
||||||
from crewai.a2a.updates.base import PollingHandlerKwargs
|
from crewai.a2a.updates.base import PollingHandlerKwargs
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.a2a_events import (
|
from crewai.events.types.a2a_events import (
|
||||||
A2APollingStartedEvent,
|
A2APollingStartedEvent,
|
||||||
A2APollingStatusEvent,
|
A2APollingStatusEvent,
|
||||||
A2AResponseReceivedEvent,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -81,10 +81,7 @@ async def _poll_task_until_complete(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if task.status.state in TERMINAL_STATES:
|
if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
||||||
return task
|
|
||||||
|
|
||||||
if task.status.state in ACTIONABLE_STATES:
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
if elapsed > polling_timeout:
|
if elapsed > polling_timeout:
|
||||||
@@ -133,57 +130,19 @@ class PollingHandler:
|
|||||||
history_length = kwargs.get("history_length", 100)
|
history_length = kwargs.get("history_length", 100)
|
||||||
max_polls = kwargs.get("max_polls")
|
max_polls = kwargs.get("max_polls")
|
||||||
|
|
||||||
task_id: str | None = None
|
result_or_task_id = await send_message_and_get_task_id(
|
||||||
|
event_stream=client.send_message(message),
|
||||||
|
new_messages=new_messages,
|
||||||
|
agent_card=agent_card,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
agent_role=agent_role,
|
||||||
|
)
|
||||||
|
|
||||||
async for event in client.send_message(message):
|
if not isinstance(result_or_task_id, str):
|
||||||
if isinstance(event, Message):
|
return result_or_task_id
|
||||||
new_messages.append(event)
|
|
||||||
result_parts = [
|
|
||||||
part.root.text for part in event.parts if part.root.kind == "text"
|
|
||||||
]
|
|
||||||
response_text = " ".join(result_parts) if result_parts else ""
|
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
task_id = result_or_task_id
|
||||||
None,
|
|
||||||
A2AResponseReceivedEvent(
|
|
||||||
response=response_text,
|
|
||||||
turn_number=turn_number,
|
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
status="completed",
|
|
||||||
agent_role=agent_role,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return TaskStateResult(
|
|
||||||
status=TaskState.completed,
|
|
||||||
result=response_text,
|
|
||||||
history=new_messages,
|
|
||||||
agent_card=agent_card,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(event, tuple):
|
|
||||||
a2a_task, _ = event
|
|
||||||
task_id = a2a_task.id
|
|
||||||
|
|
||||||
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
|
||||||
result = process_task_state(
|
|
||||||
a2a_task=a2a_task,
|
|
||||||
new_messages=new_messages,
|
|
||||||
agent_card=agent_card,
|
|
||||||
turn_number=turn_number,
|
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
agent_role=agent_role,
|
|
||||||
)
|
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
break
|
|
||||||
|
|
||||||
if not task_id:
|
|
||||||
return TaskStateResult(
|
|
||||||
status=TaskState.failed,
|
|
||||||
error="No task ID received from initial message",
|
|
||||||
history=new_messages,
|
|
||||||
)
|
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
agent_branch,
|
agent_branch,
|
||||||
|
|||||||
Reference in New Issue
Block a user