mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-24 07:38:14 +00:00
chore: refactor handlers to unified protocol
This commit is contained in:
@@ -15,6 +15,23 @@ if TYPE_CHECKING:
|
|||||||
from a2a.types import Task as A2ATask
|
from a2a.types import Task as A2ATask
|
||||||
|
|
||||||
|
|
||||||
|
TERMINAL_STATES: frozenset[TaskState] = frozenset(
|
||||||
|
{
|
||||||
|
TaskState.completed,
|
||||||
|
TaskState.failed,
|
||||||
|
TaskState.rejected,
|
||||||
|
TaskState.canceled,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
|
||||||
|
{
|
||||||
|
TaskState.input_required,
|
||||||
|
TaskState.auth_required,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TaskStateResult(TypedDict):
|
class TaskStateResult(TypedDict):
|
||||||
"""Result dictionary from processing A2A task state."""
|
"""Result dictionary from processing A2A task state."""
|
||||||
|
|
||||||
@@ -154,8 +171,8 @@ def process_task_state(
|
|||||||
role=Role.agent,
|
role=Role.agent,
|
||||||
message_id=str(uuid.uuid4()),
|
message_id=str(uuid.uuid4()),
|
||||||
parts=[Part(root=TextPart(text=response_text))],
|
parts=[Part(root=TextPart(text=response_text))],
|
||||||
context_id=getattr(a2a_task, "context_id", None),
|
context_id=a2a_task.context_id,
|
||||||
task_id=getattr(a2a_task, "task_id", None),
|
task_id=a2a_task.id,
|
||||||
)
|
)
|
||||||
new_messages.append(agent_message)
|
new_messages.append(agent_message)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,33 @@
|
|||||||
"""A2A update mechanism configuration types."""
|
"""A2A update mechanism configuration types."""
|
||||||
|
|
||||||
|
from crewai.a2a.updates.base import (
|
||||||
|
BaseHandlerKwargs,
|
||||||
|
PollingHandlerKwargs,
|
||||||
|
PushNotificationHandlerKwargs,
|
||||||
|
StreamingHandlerKwargs,
|
||||||
|
UpdateHandler,
|
||||||
|
)
|
||||||
from crewai.a2a.updates.polling.config import PollingConfig
|
from crewai.a2a.updates.polling.config import PollingConfig
|
||||||
|
from crewai.a2a.updates.polling.handler import PollingHandler
|
||||||
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
||||||
|
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
|
||||||
from crewai.a2a.updates.streaming.config import StreamingConfig
|
from crewai.a2a.updates.streaming.config import StreamingConfig
|
||||||
|
from crewai.a2a.updates.streaming.handler import StreamingHandler
|
||||||
|
|
||||||
|
|
||||||
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
|
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"BaseHandlerKwargs",
|
||||||
"PollingConfig",
|
"PollingConfig",
|
||||||
|
"PollingHandler",
|
||||||
|
"PollingHandlerKwargs",
|
||||||
"PushNotificationConfig",
|
"PushNotificationConfig",
|
||||||
|
"PushNotificationHandler",
|
||||||
|
"PushNotificationHandlerKwargs",
|
||||||
"StreamingConfig",
|
"StreamingConfig",
|
||||||
|
"StreamingHandler",
|
||||||
|
"StreamingHandlerKwargs",
|
||||||
"UpdateConfig",
|
"UpdateConfig",
|
||||||
|
"UpdateHandler",
|
||||||
]
|
]
|
||||||
|
|||||||
68
lib/crewai/src/crewai/a2a/updates/base.py
Normal file
68
lib/crewai/src/crewai/a2a/updates/base.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""Base types for A2A update mechanism handlers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from a2a.client import Client
|
||||||
|
from a2a.types import AgentCard, Message
|
||||||
|
|
||||||
|
from crewai.a2a.task_helpers import TaskStateResult
|
||||||
|
|
||||||
|
|
||||||
|
class BaseHandlerKwargs(TypedDict, total=False):
|
||||||
|
"""Base kwargs shared by all handlers."""
|
||||||
|
|
||||||
|
turn_number: int
|
||||||
|
is_multiturn: bool
|
||||||
|
agent_role: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||||
|
"""Kwargs for polling handler."""
|
||||||
|
|
||||||
|
polling_interval: float
|
||||||
|
polling_timeout: float
|
||||||
|
endpoint: str
|
||||||
|
agent_branch: Any
|
||||||
|
history_length: int
|
||||||
|
max_polls: int | None
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||||
|
"""Kwargs for streaming handler."""
|
||||||
|
|
||||||
|
context_id: str | None
|
||||||
|
task_id: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||||
|
"""Kwargs for push notification handler."""
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateHandler(Protocol):
|
||||||
|
"""Protocol for A2A update mechanism handlers."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def execute(
|
||||||
|
client: Client,
|
||||||
|
message: Message,
|
||||||
|
new_messages: list[Message],
|
||||||
|
agent_card: AgentCard,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> TaskStateResult:
|
||||||
|
"""Execute the update mechanism and return result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: A2A client instance.
|
||||||
|
message: Message to send.
|
||||||
|
new_messages: List to collect messages (modified in place).
|
||||||
|
agent_card: The agent card.
|
||||||
|
**kwargs: Additional handler-specific parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result dictionary with status, result/error, and history.
|
||||||
|
"""
|
||||||
|
...
|
||||||
@@ -15,9 +15,11 @@ class PollingConfig(BaseModel):
|
|||||||
history_length: Number of messages to retrieve per poll.
|
history_length: Number of messages to retrieve per poll.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
interval: float = Field(default=2.0, description="Seconds between poll attempts")
|
interval: float = Field(
|
||||||
timeout: float | None = Field(default=None, description="Max seconds to poll")
|
default=2.0, gt=0, description="Seconds between poll attempts"
|
||||||
max_polls: int | None = Field(default=None, description="Max poll attempts")
|
)
|
||||||
history_length: int = Field(
|
timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll")
|
||||||
default=100, description="Messages to retrieve per poll"
|
max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts")
|
||||||
|
history_length: int = Field(
|
||||||
|
default=100, gt=0, description="Messages to retrieve per poll"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Unpack
|
||||||
|
|
||||||
from a2a.client import Client
|
from a2a.client import Client
|
||||||
from a2a.types import (
|
from a2a.types import (
|
||||||
@@ -15,7 +15,13 @@ from a2a.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from crewai.a2a.errors import A2APollingTimeoutError
|
from crewai.a2a.errors import A2APollingTimeoutError
|
||||||
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
|
from crewai.a2a.task_helpers import (
|
||||||
|
ACTIONABLE_STATES,
|
||||||
|
TERMINAL_STATES,
|
||||||
|
TaskStateResult,
|
||||||
|
process_task_state,
|
||||||
|
)
|
||||||
|
from crewai.a2a.updates.base import PollingHandlerKwargs
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.a2a_events import (
|
from crewai.events.types.a2a_events import (
|
||||||
A2APollingStartedEvent,
|
A2APollingStartedEvent,
|
||||||
@@ -28,15 +34,7 @@ if TYPE_CHECKING:
|
|||||||
from a2a.types import Task as A2ATask
|
from a2a.types import Task as A2ATask
|
||||||
|
|
||||||
|
|
||||||
TERMINAL_STATES = {
|
async def _poll_task_until_complete(
|
||||||
TaskState.completed,
|
|
||||||
TaskState.failed,
|
|
||||||
TaskState.rejected,
|
|
||||||
TaskState.canceled,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def poll_task_until_complete(
|
|
||||||
client: Client,
|
client: Client,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
polling_interval: float,
|
polling_interval: float,
|
||||||
@@ -85,7 +83,7 @@ async def poll_task_until_complete(
|
|||||||
if task.status.state in TERMINAL_STATES:
|
if task.status.state in TERMINAL_STATES:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
if task.status.state in {TaskState.input_required, TaskState.auth_required}:
|
if task.status.state in ACTIONABLE_STATES:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
if elapsed > polling_timeout:
|
if elapsed > polling_timeout:
|
||||||
@@ -101,128 +99,123 @@ async def poll_task_until_complete(
|
|||||||
await asyncio.sleep(polling_interval)
|
await asyncio.sleep(polling_interval)
|
||||||
|
|
||||||
|
|
||||||
async def execute_polling_delegation(
|
class PollingHandler:
|
||||||
client: Client,
|
"""Polling-based update handler."""
|
||||||
message: Message,
|
|
||||||
polling_interval: float,
|
|
||||||
polling_timeout: float,
|
|
||||||
endpoint: str,
|
|
||||||
agent_branch: Any | None,
|
|
||||||
turn_number: int,
|
|
||||||
is_multiturn: bool,
|
|
||||||
agent_role: str | None,
|
|
||||||
new_messages: list[Message],
|
|
||||||
agent_card: AgentCard,
|
|
||||||
history_length: int = 100,
|
|
||||||
max_polls: int | None = None,
|
|
||||||
) -> TaskStateResult:
|
|
||||||
"""Execute A2A delegation using polling for updates.
|
|
||||||
|
|
||||||
Args:
|
@staticmethod
|
||||||
client: A2A client instance
|
async def execute(
|
||||||
message: Message to send
|
client: Client,
|
||||||
polling_interval: Seconds between poll attempts
|
message: Message,
|
||||||
polling_timeout: Max seconds before timeout
|
new_messages: list[Message],
|
||||||
endpoint: A2A agent endpoint URL
|
agent_card: AgentCard,
|
||||||
agent_branch: Agent tree branch for logging
|
**kwargs: Unpack[PollingHandlerKwargs],
|
||||||
turn_number: Current turn number
|
) -> TaskStateResult:
|
||||||
is_multiturn: Whether this is a multi-turn conversation
|
"""Execute A2A delegation using polling for updates.
|
||||||
agent_role: Agent role for logging
|
|
||||||
new_messages: List to collect messages
|
|
||||||
agent_card: The agent card
|
|
||||||
history_length: Number of messages to retrieve per poll
|
|
||||||
max_polls: Max number of poll attempts (None = unlimited)
|
|
||||||
|
|
||||||
Returns:
|
Args:
|
||||||
Dictionary with status, result/error, and history
|
client: A2A client instance.
|
||||||
"""
|
message: Message to send.
|
||||||
task_id: str | None = None
|
new_messages: List to collect messages.
|
||||||
|
agent_card: The agent card.
|
||||||
|
**kwargs: Polling-specific parameters.
|
||||||
|
|
||||||
async for event in client.send_message(message):
|
Returns:
|
||||||
if isinstance(event, Message):
|
Dictionary with status, result/error, and history.
|
||||||
new_messages.append(event)
|
"""
|
||||||
result_parts = [
|
polling_interval = kwargs.get("polling_interval", 2.0)
|
||||||
part.root.text for part in event.parts if part.root.kind == "text"
|
polling_timeout = kwargs.get("polling_timeout", 300.0)
|
||||||
]
|
endpoint = kwargs.get("endpoint", "")
|
||||||
response_text = " ".join(result_parts) if result_parts else ""
|
agent_branch = kwargs.get("agent_branch")
|
||||||
|
turn_number = kwargs.get("turn_number", 0)
|
||||||
|
is_multiturn = kwargs.get("is_multiturn", False)
|
||||||
|
agent_role = kwargs.get("agent_role")
|
||||||
|
history_length = kwargs.get("history_length", 100)
|
||||||
|
max_polls = kwargs.get("max_polls")
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
task_id: str | None = None
|
||||||
None,
|
|
||||||
A2AResponseReceivedEvent(
|
|
||||||
response=response_text,
|
|
||||||
turn_number=turn_number,
|
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
status="completed",
|
|
||||||
agent_role=agent_role,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return TaskStateResult(
|
async for event in client.send_message(message):
|
||||||
status=TaskState.completed,
|
if isinstance(event, Message):
|
||||||
result=response_text,
|
new_messages.append(event)
|
||||||
history=new_messages,
|
result_parts = [
|
||||||
agent_card=agent_card,
|
part.root.text for part in event.parts if part.root.kind == "text"
|
||||||
)
|
]
|
||||||
|
response_text = " ".join(result_parts) if result_parts else ""
|
||||||
|
|
||||||
if isinstance(event, tuple):
|
crewai_event_bus.emit(
|
||||||
a2a_task, _ = event
|
None,
|
||||||
task_id = a2a_task.id
|
A2AResponseReceivedEvent(
|
||||||
|
response=response_text,
|
||||||
if a2a_task.status.state in TERMINAL_STATES | {
|
turn_number=turn_number,
|
||||||
TaskState.input_required,
|
is_multiturn=is_multiturn,
|
||||||
TaskState.auth_required,
|
status="completed",
|
||||||
}:
|
agent_role=agent_role,
|
||||||
result = process_task_state(
|
),
|
||||||
a2a_task=a2a_task,
|
|
||||||
new_messages=new_messages,
|
|
||||||
agent_card=agent_card,
|
|
||||||
turn_number=turn_number,
|
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
agent_role=agent_role,
|
|
||||||
)
|
)
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
break
|
|
||||||
|
|
||||||
if not task_id:
|
return TaskStateResult(
|
||||||
return TaskStateResult(
|
status=TaskState.completed,
|
||||||
status=TaskState.failed,
|
result=response_text,
|
||||||
error="No task ID received from initial message",
|
history=new_messages,
|
||||||
history=new_messages,
|
agent_card=agent_card,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(event, tuple):
|
||||||
|
a2a_task, _ = event
|
||||||
|
task_id = a2a_task.id
|
||||||
|
|
||||||
|
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
||||||
|
result = process_task_state(
|
||||||
|
a2a_task=a2a_task,
|
||||||
|
new_messages=new_messages,
|
||||||
|
agent_card=agent_card,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
agent_role=agent_role,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
break
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
return TaskStateResult(
|
||||||
|
status=TaskState.failed,
|
||||||
|
error="No task ID received from initial message",
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
agent_branch,
|
||||||
|
A2APollingStartedEvent(
|
||||||
|
task_id=task_id,
|
||||||
|
polling_interval=polling_interval,
|
||||||
|
endpoint=endpoint,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
final_task = await _poll_task_until_complete(
|
||||||
agent_branch,
|
client=client,
|
||||||
A2APollingStartedEvent(
|
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
polling_interval=polling_interval,
|
polling_interval=polling_interval,
|
||||||
endpoint=endpoint,
|
polling_timeout=polling_timeout,
|
||||||
),
|
agent_branch=agent_branch,
|
||||||
)
|
history_length=history_length,
|
||||||
|
max_polls=max_polls,
|
||||||
|
)
|
||||||
|
|
||||||
final_task = await poll_task_until_complete(
|
result = process_task_state(
|
||||||
client=client,
|
a2a_task=final_task,
|
||||||
task_id=task_id,
|
new_messages=new_messages,
|
||||||
polling_interval=polling_interval,
|
agent_card=agent_card,
|
||||||
polling_timeout=polling_timeout,
|
turn_number=turn_number,
|
||||||
agent_branch=agent_branch,
|
is_multiturn=is_multiturn,
|
||||||
history_length=history_length,
|
agent_role=agent_role,
|
||||||
max_polls=max_polls,
|
)
|
||||||
)
|
if result:
|
||||||
|
return result
|
||||||
|
|
||||||
result = process_task_state(
|
return TaskStateResult(
|
||||||
a2a_task=final_task,
|
status=TaskState.failed,
|
||||||
new_messages=new_messages,
|
error=f"Unexpected task state: {final_task.status.state}",
|
||||||
agent_card=agent_card,
|
history=new_messages,
|
||||||
turn_number=turn_number,
|
)
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
agent_role=agent_role,
|
|
||||||
)
|
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
|
|
||||||
return TaskStateResult(
|
|
||||||
status=TaskState.failed,
|
|
||||||
error=f"Unexpected task state: {final_task.status.state}",
|
|
||||||
history=new_messages,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||||
|
|
||||||
from crewai.a2a.auth.schemas import AuthScheme
|
from crewai.a2a.auth.schemas import AuthScheme
|
||||||
|
|
||||||
@@ -17,7 +17,7 @@ class PushNotificationConfig(BaseModel):
|
|||||||
authentication: Auth scheme for the callback endpoint.
|
authentication: Auth scheme for the callback endpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url: str = Field(description="Callback URL for push notifications")
|
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
|
||||||
id: str | None = Field(default=None, description="Unique config identifier")
|
id: str | None = Field(default=None, description="Unique config identifier")
|
||||||
token: str | None = Field(default=None, description="Validation token")
|
token: str | None = Field(default=None, description="Validation token")
|
||||||
authentication: AuthScheme | None = Field(
|
authentication: AuthScheme | None = Field(
|
||||||
|
|||||||
@@ -1,3 +1,40 @@
|
|||||||
"""Push notification (webhook) update mechanism handler."""
|
"""Push notification (webhook) update mechanism handler."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Unpack
|
||||||
|
|
||||||
|
from a2a.client import Client
|
||||||
|
from a2a.types import AgentCard, Message
|
||||||
|
|
||||||
|
from crewai.a2a.task_helpers import TaskStateResult
|
||||||
|
from crewai.a2a.updates.base import PushNotificationHandlerKwargs
|
||||||
|
|
||||||
|
|
||||||
|
class PushNotificationHandler:
|
||||||
|
"""Push notification (webhook) based update handler."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def execute(
|
||||||
|
client: Client,
|
||||||
|
message: Message,
|
||||||
|
new_messages: list[Message],
|
||||||
|
agent_card: AgentCard,
|
||||||
|
**kwargs: Unpack[PushNotificationHandlerKwargs],
|
||||||
|
) -> TaskStateResult:
|
||||||
|
"""Execute A2A delegation using push notifications for updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: A2A client instance.
|
||||||
|
message: Message to send.
|
||||||
|
new_messages: List to collect messages.
|
||||||
|
agent_card: The agent card.
|
||||||
|
**kwargs: Push notification-specific parameters.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: Push notifications not yet implemented.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Push notification update mechanism is not yet implemented. "
|
||||||
|
"Use PollingConfig or StreamingConfig instead."
|
||||||
|
)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Unpack
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from a2a.client import Client
|
from a2a.client import Client
|
||||||
@@ -17,136 +18,126 @@ from a2a.types import (
|
|||||||
TextPart,
|
TextPart,
|
||||||
)
|
)
|
||||||
|
|
||||||
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
|
from crewai.a2a.task_helpers import (
|
||||||
|
ACTIONABLE_STATES,
|
||||||
|
TERMINAL_STATES,
|
||||||
|
TaskStateResult,
|
||||||
|
process_task_state,
|
||||||
|
)
|
||||||
|
from crewai.a2a.updates.base import StreamingHandlerKwargs
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
||||||
|
|
||||||
|
|
||||||
async def execute_streaming_delegation(
|
class StreamingHandler:
|
||||||
client: Client,
|
"""SSE streaming-based update handler."""
|
||||||
message: Message,
|
|
||||||
context_id: str | None,
|
|
||||||
task_id: str | None,
|
|
||||||
turn_number: int,
|
|
||||||
is_multiturn: bool,
|
|
||||||
agent_role: str | None,
|
|
||||||
new_messages: list[Message],
|
|
||||||
agent_card: AgentCard,
|
|
||||||
) -> TaskStateResult:
|
|
||||||
"""Execute A2A delegation using SSE streaming for updates.
|
|
||||||
|
|
||||||
Args:
|
@staticmethod
|
||||||
client: A2A client instance
|
async def execute(
|
||||||
message: Message to send
|
client: Client,
|
||||||
context_id: Context ID for correlation
|
message: Message,
|
||||||
task_id: Task ID for correlation
|
new_messages: list[Message],
|
||||||
turn_number: Current turn number
|
agent_card: AgentCard,
|
||||||
is_multiturn: Whether this is a multi-turn conversation
|
**kwargs: Unpack[StreamingHandlerKwargs],
|
||||||
agent_role: Agent role for logging
|
) -> TaskStateResult:
|
||||||
new_messages: List to collect messages
|
"""Execute A2A delegation using SSE streaming for updates.
|
||||||
agent_card: The agent card
|
|
||||||
|
|
||||||
Returns:
|
Args:
|
||||||
Dictionary with status, result/error, and history
|
client: A2A client instance.
|
||||||
"""
|
message: Message to send.
|
||||||
result_parts: list[str] = []
|
new_messages: List to collect messages.
|
||||||
final_result: TaskStateResult | None = None
|
agent_card: The agent card.
|
||||||
event_stream = client.send_message(message)
|
**kwargs: Streaming-specific parameters.
|
||||||
|
|
||||||
try:
|
Returns:
|
||||||
async for event in event_stream:
|
Dictionary with status, result/error, and history.
|
||||||
if isinstance(event, Message):
|
"""
|
||||||
new_messages.append(event)
|
context_id = kwargs.get("context_id")
|
||||||
for part in event.parts:
|
task_id = kwargs.get("task_id")
|
||||||
if part.root.kind == "text":
|
turn_number = kwargs.get("turn_number", 0)
|
||||||
text = part.root.text
|
is_multiturn = kwargs.get("is_multiturn", False)
|
||||||
result_parts.append(text)
|
agent_role = kwargs.get("agent_role")
|
||||||
|
|
||||||
elif isinstance(event, tuple):
|
result_parts: list[str] = []
|
||||||
a2a_task, update = event
|
final_result: TaskStateResult | None = None
|
||||||
|
event_stream = client.send_message(message)
|
||||||
|
|
||||||
if isinstance(update, TaskArtifactUpdateEvent):
|
try:
|
||||||
artifact = update.artifact
|
async for event in event_stream:
|
||||||
result_parts.extend(
|
if isinstance(event, Message):
|
||||||
part.root.text
|
new_messages.append(event)
|
||||||
for part in artifact.parts
|
for part in event.parts:
|
||||||
if part.root.kind == "text"
|
if part.root.kind == "text":
|
||||||
|
text = part.root.text
|
||||||
|
result_parts.append(text)
|
||||||
|
|
||||||
|
elif isinstance(event, tuple):
|
||||||
|
a2a_task, update = event
|
||||||
|
|
||||||
|
if isinstance(update, TaskArtifactUpdateEvent):
|
||||||
|
artifact = update.artifact
|
||||||
|
result_parts.extend(
|
||||||
|
part.root.text
|
||||||
|
for part in artifact.parts
|
||||||
|
if part.root.kind == "text"
|
||||||
|
)
|
||||||
|
|
||||||
|
is_final_update = False
|
||||||
|
if isinstance(update, TaskStatusUpdateEvent):
|
||||||
|
is_final_update = update.final
|
||||||
|
|
||||||
|
if (
|
||||||
|
not is_final_update
|
||||||
|
and a2a_task.status.state
|
||||||
|
not in TERMINAL_STATES | ACTIONABLE_STATES
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
final_result = process_task_state(
|
||||||
|
a2a_task=a2a_task,
|
||||||
|
new_messages=new_messages,
|
||||||
|
agent_card=agent_card,
|
||||||
|
turn_number=turn_number,
|
||||||
|
is_multiturn=is_multiturn,
|
||||||
|
agent_role=agent_role,
|
||||||
|
result_parts=result_parts,
|
||||||
)
|
)
|
||||||
|
if final_result:
|
||||||
|
break
|
||||||
|
|
||||||
is_final_update = False
|
except A2AClientHTTPError as e:
|
||||||
if isinstance(update, TaskStatusUpdateEvent):
|
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||||
is_final_update = update.final
|
|
||||||
|
|
||||||
if not is_final_update and a2a_task.status.state not in [
|
error_message = Message(
|
||||||
TaskState.completed,
|
role=Role.agent,
|
||||||
TaskState.input_required,
|
message_id=str(uuid.uuid4()),
|
||||||
TaskState.failed,
|
parts=[Part(root=TextPart(text=error_msg))],
|
||||||
TaskState.rejected,
|
context_id=context_id,
|
||||||
TaskState.auth_required,
|
task_id=task_id,
|
||||||
TaskState.canceled,
|
)
|
||||||
]:
|
new_messages.append(error_message)
|
||||||
continue
|
|
||||||
|
|
||||||
final_result = process_task_state(
|
crewai_event_bus.emit(
|
||||||
a2a_task=a2a_task,
|
None,
|
||||||
new_messages=new_messages,
|
A2AResponseReceivedEvent(
|
||||||
agent_card=agent_card,
|
response=error_msg,
|
||||||
turn_number=turn_number,
|
turn_number=turn_number,
|
||||||
is_multiturn=is_multiturn,
|
is_multiturn=is_multiturn,
|
||||||
|
status="failed",
|
||||||
agent_role=agent_role,
|
agent_role=agent_role,
|
||||||
result_parts=result_parts,
|
),
|
||||||
)
|
)
|
||||||
if final_result:
|
return TaskStateResult(
|
||||||
break
|
status=TaskState.failed,
|
||||||
|
error=error_msg,
|
||||||
|
history=new_messages,
|
||||||
|
)
|
||||||
|
|
||||||
except A2AClientHTTPError as e:
|
if final_result:
|
||||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
return final_result
|
||||||
|
|
||||||
error_message = Message(
|
|
||||||
role=Role.agent,
|
|
||||||
message_id=str(uuid.uuid4()),
|
|
||||||
parts=[Part(root=TextPart(text=error_msg))],
|
|
||||||
context_id=context_id,
|
|
||||||
task_id=task_id,
|
|
||||||
)
|
|
||||||
new_messages.append(error_message)
|
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
|
||||||
None,
|
|
||||||
A2AResponseReceivedEvent(
|
|
||||||
response=error_msg,
|
|
||||||
turn_number=turn_number,
|
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
status="failed",
|
|
||||||
agent_role=agent_role,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return TaskStateResult(
|
return TaskStateResult(
|
||||||
status=TaskState.failed,
|
status=TaskState.completed,
|
||||||
error=error_msg,
|
result=" ".join(result_parts) if result_parts else "",
|
||||||
history=new_messages,
|
history=new_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
current_exception: Exception | BaseException | None = e
|
|
||||||
while current_exception:
|
|
||||||
if hasattr(current_exception, "response"):
|
|
||||||
response = current_exception.response
|
|
||||||
if hasattr(response, "text"):
|
|
||||||
break
|
|
||||||
if current_exception and hasattr(current_exception, "__cause__"):
|
|
||||||
current_exception = current_exception.__cause__
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if hasattr(event_stream, "aclose"):
|
|
||||||
await event_stream.aclose()
|
|
||||||
|
|
||||||
if final_result:
|
|
||||||
return final_result
|
|
||||||
|
|
||||||
return TaskStateResult(
|
|
||||||
status=TaskState.completed,
|
|
||||||
result=" ".join(result_parts) if result_parts else "",
|
|
||||||
history=new_messages,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -34,9 +34,15 @@ from crewai.a2a.auth.utils import (
|
|||||||
from crewai.a2a.config import A2AConfig
|
from crewai.a2a.config import A2AConfig
|
||||||
from crewai.a2a.task_helpers import TaskStateResult
|
from crewai.a2a.task_helpers import TaskStateResult
|
||||||
from crewai.a2a.types import PartsDict, PartsMetadataDict
|
from crewai.a2a.types import PartsDict, PartsMetadataDict
|
||||||
from crewai.a2a.updates import PollingConfig, UpdateConfig
|
from crewai.a2a.updates import (
|
||||||
from crewai.a2a.updates.polling.handler import execute_polling_delegation
|
PollingConfig,
|
||||||
from crewai.a2a.updates.streaming.handler import execute_streaming_delegation
|
PollingHandler,
|
||||||
|
PushNotificationConfig,
|
||||||
|
PushNotificationHandler,
|
||||||
|
StreamingConfig,
|
||||||
|
StreamingHandler,
|
||||||
|
UpdateConfig,
|
||||||
|
)
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.a2a_events import (
|
from crewai.events.types.a2a_events import (
|
||||||
A2AConversationStartedEvent,
|
A2AConversationStartedEvent,
|
||||||
@@ -53,6 +59,31 @@ if TYPE_CHECKING:
|
|||||||
from crewai.a2a.auth.schemas import AuthScheme
|
from crewai.a2a.auth.schemas import AuthScheme
|
||||||
|
|
||||||
|
|
||||||
|
HandlerType = (
|
||||||
|
type[PollingHandler] | type[StreamingHandler] | type[PushNotificationHandler]
|
||||||
|
)
|
||||||
|
|
||||||
|
HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
|
||||||
|
PollingConfig: PollingHandler,
|
||||||
|
StreamingConfig: StreamingHandler,
|
||||||
|
PushNotificationConfig: PushNotificationHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_handler(config: UpdateConfig | None) -> HandlerType:
|
||||||
|
"""Get the handler class for a given update config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Update mechanism configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Handler class for the config type, defaults to StreamingHandler.
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
return StreamingHandler
|
||||||
|
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def _fetch_agent_card_cached(
|
def _fetch_agent_card_cached(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
@@ -448,14 +479,28 @@ async def _execute_a2a_delegation_async(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
polling_config = updates if isinstance(updates, PollingConfig) else None
|
handler = get_handler(updates)
|
||||||
use_polling = polling_config is not None
|
use_polling = isinstance(updates, PollingConfig)
|
||||||
polling_interval = polling_config.interval if polling_config else 2.0
|
|
||||||
effective_polling_timeout = (
|
handler_kwargs: dict[str, Any] = {
|
||||||
polling_config.timeout
|
"turn_number": turn_number,
|
||||||
if polling_config and polling_config.timeout
|
"is_multiturn": is_multiturn,
|
||||||
else float(timeout)
|
"agent_role": agent_role,
|
||||||
)
|
"context_id": context_id,
|
||||||
|
"task_id": task_id,
|
||||||
|
"endpoint": endpoint,
|
||||||
|
"agent_branch": agent_branch,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(updates, PollingConfig):
|
||||||
|
handler_kwargs.update(
|
||||||
|
{
|
||||||
|
"polling_interval": updates.interval,
|
||||||
|
"polling_timeout": updates.timeout or float(timeout),
|
||||||
|
"history_length": updates.history_length,
|
||||||
|
"max_polls": updates.max_polls,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
async with _create_a2a_client(
|
async with _create_a2a_client(
|
||||||
agent_card=agent_card,
|
agent_card=agent_card,
|
||||||
@@ -466,33 +511,12 @@ async def _execute_a2a_delegation_async(
|
|||||||
auth=auth,
|
auth=auth,
|
||||||
use_polling=use_polling,
|
use_polling=use_polling,
|
||||||
) as client:
|
) as client:
|
||||||
if use_polling and polling_config:
|
return await handler.execute(
|
||||||
return await execute_polling_delegation(
|
|
||||||
client=client,
|
|
||||||
message=message,
|
|
||||||
polling_interval=polling_interval,
|
|
||||||
polling_timeout=effective_polling_timeout,
|
|
||||||
endpoint=endpoint,
|
|
||||||
agent_branch=agent_branch,
|
|
||||||
turn_number=turn_number,
|
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
agent_role=agent_role,
|
|
||||||
new_messages=new_messages,
|
|
||||||
agent_card=agent_card,
|
|
||||||
history_length=polling_config.history_length,
|
|
||||||
max_polls=polling_config.max_polls,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await execute_streaming_delegation(
|
|
||||||
client=client,
|
client=client,
|
||||||
message=message,
|
message=message,
|
||||||
context_id=context_id,
|
|
||||||
task_id=task_id,
|
|
||||||
turn_number=turn_number,
|
|
||||||
is_multiturn=is_multiturn,
|
|
||||||
agent_role=agent_role,
|
|
||||||
new_messages=new_messages,
|
new_messages=new_messages,
|
||||||
agent_card=agent_card,
|
agent_card=agent_card,
|
||||||
|
**handler_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user