mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
chore: refactor handlers to unified protocol
This commit is contained in:
@@ -15,6 +15,23 @@ if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
|
||||
TERMINAL_STATES: frozenset[TaskState] = frozenset(
|
||||
{
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.rejected,
|
||||
TaskState.canceled,
|
||||
}
|
||||
)
|
||||
|
||||
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
|
||||
{
|
||||
TaskState.input_required,
|
||||
TaskState.auth_required,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TaskStateResult(TypedDict):
|
||||
"""Result dictionary from processing A2A task state."""
|
||||
|
||||
@@ -154,8 +171,8 @@ def process_task_state(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=response_text))],
|
||||
context_id=getattr(a2a_task, "context_id", None),
|
||||
task_id=getattr(a2a_task, "task_id", None),
|
||||
context_id=a2a_task.context_id,
|
||||
task_id=a2a_task.id,
|
||||
)
|
||||
new_messages.append(agent_message)
|
||||
|
||||
|
||||
@@ -1,15 +1,33 @@
|
||||
"""A2A update mechanism configuration types."""
|
||||
|
||||
from crewai.a2a.updates.base import (
|
||||
BaseHandlerKwargs,
|
||||
PollingHandlerKwargs,
|
||||
PushNotificationHandlerKwargs,
|
||||
StreamingHandlerKwargs,
|
||||
UpdateHandler,
|
||||
)
|
||||
from crewai.a2a.updates.polling.config import PollingConfig
|
||||
from crewai.a2a.updates.polling.handler import PollingHandler
|
||||
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
|
||||
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
|
||||
from crewai.a2a.updates.streaming.config import StreamingConfig
|
||||
from crewai.a2a.updates.streaming.handler import StreamingHandler
|
||||
|
||||
|
||||
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
|
||||
|
||||
__all__ = [
|
||||
"BaseHandlerKwargs",
|
||||
"PollingConfig",
|
||||
"PollingHandler",
|
||||
"PollingHandlerKwargs",
|
||||
"PushNotificationConfig",
|
||||
"PushNotificationHandler",
|
||||
"PushNotificationHandlerKwargs",
|
||||
"StreamingConfig",
|
||||
"StreamingHandler",
|
||||
"StreamingHandlerKwargs",
|
||||
"UpdateConfig",
|
||||
"UpdateHandler",
|
||||
]
|
||||
|
||||
68
lib/crewai/src/crewai/a2a/updates/base.py
Normal file
68
lib/crewai/src/crewai/a2a/updates/base.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Base types for A2A update mechanism handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import Client
|
||||
from a2a.types import AgentCard, Message
|
||||
|
||||
from crewai.a2a.task_helpers import TaskStateResult
|
||||
|
||||
|
||||
class BaseHandlerKwargs(TypedDict, total=False):
|
||||
"""Base kwargs shared by all handlers."""
|
||||
|
||||
turn_number: int
|
||||
is_multiturn: bool
|
||||
agent_role: str | None
|
||||
|
||||
|
||||
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for polling handler."""
|
||||
|
||||
polling_interval: float
|
||||
polling_timeout: float
|
||||
endpoint: str
|
||||
agent_branch: Any
|
||||
history_length: int
|
||||
max_polls: int | None
|
||||
|
||||
|
||||
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for streaming handler."""
|
||||
|
||||
context_id: str | None
|
||||
task_id: str | None
|
||||
|
||||
|
||||
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
|
||||
"""Kwargs for push notification handler."""
|
||||
|
||||
|
||||
class UpdateHandler(Protocol):
|
||||
"""Protocol for A2A update mechanism handlers."""
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Any,
|
||||
) -> TaskStateResult:
|
||||
"""Execute the update mechanism and return result.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages (modified in place).
|
||||
agent_card: The agent card.
|
||||
**kwargs: Additional handler-specific parameters.
|
||||
|
||||
Returns:
|
||||
Result dictionary with status, result/error, and history.
|
||||
"""
|
||||
...
|
||||
@@ -15,9 +15,11 @@ class PollingConfig(BaseModel):
|
||||
history_length: Number of messages to retrieve per poll.
|
||||
"""
|
||||
|
||||
interval: float = Field(default=2.0, description="Seconds between poll attempts")
|
||||
timeout: float | None = Field(default=None, description="Max seconds to poll")
|
||||
max_polls: int | None = Field(default=None, description="Max poll attempts")
|
||||
history_length: int = Field(
|
||||
default=100, description="Messages to retrieve per poll"
|
||||
interval: float = Field(
|
||||
default=2.0, gt=0, description="Seconds between poll attempts"
|
||||
)
|
||||
timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll")
|
||||
max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts")
|
||||
history_length: int = Field(
|
||||
default=100, gt=0, description="Messages to retrieve per poll"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Unpack
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.types import (
|
||||
@@ -15,7 +15,13 @@ from a2a.types import (
|
||||
)
|
||||
|
||||
from crewai.a2a.errors import A2APollingTimeoutError
|
||||
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
|
||||
from crewai.a2a.task_helpers import (
|
||||
ACTIONABLE_STATES,
|
||||
TERMINAL_STATES,
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
)
|
||||
from crewai.a2a.updates.base import PollingHandlerKwargs
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2APollingStartedEvent,
|
||||
@@ -28,15 +34,7 @@ if TYPE_CHECKING:
|
||||
from a2a.types import Task as A2ATask
|
||||
|
||||
|
||||
TERMINAL_STATES = {
|
||||
TaskState.completed,
|
||||
TaskState.failed,
|
||||
TaskState.rejected,
|
||||
TaskState.canceled,
|
||||
}
|
||||
|
||||
|
||||
async def poll_task_until_complete(
|
||||
async def _poll_task_until_complete(
|
||||
client: Client,
|
||||
task_id: str,
|
||||
polling_interval: float,
|
||||
@@ -85,7 +83,7 @@ async def poll_task_until_complete(
|
||||
if task.status.state in TERMINAL_STATES:
|
||||
return task
|
||||
|
||||
if task.status.state in {TaskState.input_required, TaskState.auth_required}:
|
||||
if task.status.state in ACTIONABLE_STATES:
|
||||
return task
|
||||
|
||||
if elapsed > polling_timeout:
|
||||
@@ -101,128 +99,123 @@ async def poll_task_until_complete(
|
||||
await asyncio.sleep(polling_interval)
|
||||
|
||||
|
||||
async def execute_polling_delegation(
|
||||
client: Client,
|
||||
message: Message,
|
||||
polling_interval: float,
|
||||
polling_timeout: float,
|
||||
endpoint: str,
|
||||
agent_branch: Any | None,
|
||||
turn_number: int,
|
||||
is_multiturn: bool,
|
||||
agent_role: str | None,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
history_length: int = 100,
|
||||
max_polls: int | None = None,
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using polling for updates.
|
||||
class PollingHandler:
|
||||
"""Polling-based update handler."""
|
||||
|
||||
Args:
|
||||
client: A2A client instance
|
||||
message: Message to send
|
||||
polling_interval: Seconds between poll attempts
|
||||
polling_timeout: Max seconds before timeout
|
||||
endpoint: A2A agent endpoint URL
|
||||
agent_branch: Agent tree branch for logging
|
||||
turn_number: Current turn number
|
||||
is_multiturn: Whether this is a multi-turn conversation
|
||||
agent_role: Agent role for logging
|
||||
new_messages: List to collect messages
|
||||
agent_card: The agent card
|
||||
history_length: Number of messages to retrieve per poll
|
||||
max_polls: Max number of poll attempts (None = unlimited)
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Unpack[PollingHandlerKwargs],
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using polling for updates.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history
|
||||
"""
|
||||
task_id: str | None = None
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages.
|
||||
agent_card: The agent card.
|
||||
**kwargs: Polling-specific parameters.
|
||||
|
||||
async for event in client.send_message(message):
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
result_parts = [
|
||||
part.root.text for part in event.parts if part.root.kind == "text"
|
||||
]
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
"""
|
||||
polling_interval = kwargs.get("polling_interval", 2.0)
|
||||
polling_timeout = kwargs.get("polling_timeout", 300.0)
|
||||
endpoint = kwargs.get("endpoint", "")
|
||||
agent_branch = kwargs.get("agent_branch")
|
||||
turn_number = kwargs.get("turn_number", 0)
|
||||
is_multiturn = kwargs.get("is_multiturn", False)
|
||||
agent_role = kwargs.get("agent_role")
|
||||
history_length = kwargs.get("history_length", 100)
|
||||
max_polls = kwargs.get("max_polls")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
status="completed",
|
||||
agent_role=agent_role,
|
||||
),
|
||||
)
|
||||
task_id: str | None = None
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card,
|
||||
)
|
||||
async for event in client.send_message(message):
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
result_parts = [
|
||||
part.root.text for part in event.parts if part.root.kind == "text"
|
||||
]
|
||||
response_text = " ".join(result_parts) if result_parts else ""
|
||||
|
||||
if isinstance(event, tuple):
|
||||
a2a_task, _ = event
|
||||
task_id = a2a_task.id
|
||||
|
||||
if a2a_task.status.state in TERMINAL_STATES | {
|
||||
TaskState.input_required,
|
||||
TaskState.auth_required,
|
||||
}:
|
||||
result = process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=response_text,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
status="completed",
|
||||
agent_role=agent_role,
|
||||
),
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
break
|
||||
|
||||
if not task_id:
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error="No task ID received from initial message",
|
||||
history=new_messages,
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
result=response_text,
|
||||
history=new_messages,
|
||||
agent_card=agent_card,
|
||||
)
|
||||
|
||||
if isinstance(event, tuple):
|
||||
a2a_task, _ = event
|
||||
task_id = a2a_task.id
|
||||
|
||||
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
|
||||
result = process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
break
|
||||
|
||||
if not task_id:
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error="No task ID received from initial message",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APollingStartedEvent(
|
||||
task_id=task_id,
|
||||
polling_interval=polling_interval,
|
||||
endpoint=endpoint,
|
||||
),
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent_branch,
|
||||
A2APollingStartedEvent(
|
||||
final_task = await _poll_task_until_complete(
|
||||
client=client,
|
||||
task_id=task_id,
|
||||
polling_interval=polling_interval,
|
||||
endpoint=endpoint,
|
||||
),
|
||||
)
|
||||
polling_timeout=polling_timeout,
|
||||
agent_branch=agent_branch,
|
||||
history_length=history_length,
|
||||
max_polls=max_polls,
|
||||
)
|
||||
|
||||
final_task = await poll_task_until_complete(
|
||||
client=client,
|
||||
task_id=task_id,
|
||||
polling_interval=polling_interval,
|
||||
polling_timeout=polling_timeout,
|
||||
agent_branch=agent_branch,
|
||||
history_length=history_length,
|
||||
max_polls=max_polls,
|
||||
)
|
||||
result = process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = process_task_state(
|
||||
a2a_task=final_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Unexpected task state: {final_task.status.state}",
|
||||
history=new_messages,
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=f"Unexpected task state: {final_task.status.state}",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
from crewai.a2a.auth.schemas import AuthScheme
|
||||
|
||||
@@ -17,7 +17,7 @@ class PushNotificationConfig(BaseModel):
|
||||
authentication: Auth scheme for the callback endpoint.
|
||||
"""
|
||||
|
||||
url: str = Field(description="Callback URL for push notifications")
|
||||
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
|
||||
id: str | None = Field(default=None, description="Unique config identifier")
|
||||
token: str | None = Field(default=None, description="Validation token")
|
||||
authentication: AuthScheme | None = Field(
|
||||
|
||||
@@ -1,3 +1,40 @@
|
||||
"""Push notification (webhook) update mechanism handler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Unpack
|
||||
|
||||
from a2a.client import Client
|
||||
from a2a.types import AgentCard, Message
|
||||
|
||||
from crewai.a2a.task_helpers import TaskStateResult
|
||||
from crewai.a2a.updates.base import PushNotificationHandlerKwargs
|
||||
|
||||
|
||||
class PushNotificationHandler:
|
||||
"""Push notification (webhook) based update handler."""
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Unpack[PushNotificationHandlerKwargs],
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using push notifications for updates.
|
||||
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages.
|
||||
agent_card: The agent card.
|
||||
**kwargs: Push notification-specific parameters.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Push notifications not yet implemented.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Push notification update mechanism is not yet implemented. "
|
||||
"Use PollingConfig or StreamingConfig instead."
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Unpack
|
||||
import uuid
|
||||
|
||||
from a2a.client import Client
|
||||
@@ -17,136 +18,126 @@ from a2a.types import (
|
||||
TextPart,
|
||||
)
|
||||
|
||||
from crewai.a2a.task_helpers import TaskStateResult, process_task_state
|
||||
from crewai.a2a.task_helpers import (
|
||||
ACTIONABLE_STATES,
|
||||
TERMINAL_STATES,
|
||||
TaskStateResult,
|
||||
process_task_state,
|
||||
)
|
||||
from crewai.a2a.updates.base import StreamingHandlerKwargs
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import A2AResponseReceivedEvent
|
||||
|
||||
|
||||
async def execute_streaming_delegation(
|
||||
client: Client,
|
||||
message: Message,
|
||||
context_id: str | None,
|
||||
task_id: str | None,
|
||||
turn_number: int,
|
||||
is_multiturn: bool,
|
||||
agent_role: str | None,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using SSE streaming for updates.
|
||||
class StreamingHandler:
|
||||
"""SSE streaming-based update handler."""
|
||||
|
||||
Args:
|
||||
client: A2A client instance
|
||||
message: Message to send
|
||||
context_id: Context ID for correlation
|
||||
task_id: Task ID for correlation
|
||||
turn_number: Current turn number
|
||||
is_multiturn: Whether this is a multi-turn conversation
|
||||
agent_role: Agent role for logging
|
||||
new_messages: List to collect messages
|
||||
agent_card: The agent card
|
||||
@staticmethod
|
||||
async def execute(
|
||||
client: Client,
|
||||
message: Message,
|
||||
new_messages: list[Message],
|
||||
agent_card: AgentCard,
|
||||
**kwargs: Unpack[StreamingHandlerKwargs],
|
||||
) -> TaskStateResult:
|
||||
"""Execute A2A delegation using SSE streaming for updates.
|
||||
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history
|
||||
"""
|
||||
result_parts: list[str] = []
|
||||
final_result: TaskStateResult | None = None
|
||||
event_stream = client.send_message(message)
|
||||
Args:
|
||||
client: A2A client instance.
|
||||
message: Message to send.
|
||||
new_messages: List to collect messages.
|
||||
agent_card: The agent card.
|
||||
**kwargs: Streaming-specific parameters.
|
||||
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
for part in event.parts:
|
||||
if part.root.kind == "text":
|
||||
text = part.root.text
|
||||
result_parts.append(text)
|
||||
Returns:
|
||||
Dictionary with status, result/error, and history.
|
||||
"""
|
||||
context_id = kwargs.get("context_id")
|
||||
task_id = kwargs.get("task_id")
|
||||
turn_number = kwargs.get("turn_number", 0)
|
||||
is_multiturn = kwargs.get("is_multiturn", False)
|
||||
agent_role = kwargs.get("agent_role")
|
||||
|
||||
elif isinstance(event, tuple):
|
||||
a2a_task, update = event
|
||||
result_parts: list[str] = []
|
||||
final_result: TaskStateResult | None = None
|
||||
event_stream = client.send_message(message)
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
try:
|
||||
async for event in event_stream:
|
||||
if isinstance(event, Message):
|
||||
new_messages.append(event)
|
||||
for part in event.parts:
|
||||
if part.root.kind == "text":
|
||||
text = part.root.text
|
||||
result_parts.append(text)
|
||||
|
||||
elif isinstance(event, tuple):
|
||||
a2a_task, update = event
|
||||
|
||||
if isinstance(update, TaskArtifactUpdateEvent):
|
||||
artifact = update.artifact
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
for part in artifact.parts
|
||||
if part.root.kind == "text"
|
||||
)
|
||||
|
||||
is_final_update = False
|
||||
if isinstance(update, TaskStatusUpdateEvent):
|
||||
is_final_update = update.final
|
||||
|
||||
if (
|
||||
not is_final_update
|
||||
and a2a_task.status.state
|
||||
not in TERMINAL_STATES | ACTIONABLE_STATES
|
||||
):
|
||||
continue
|
||||
|
||||
final_result = process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
result_parts=result_parts,
|
||||
)
|
||||
if final_result:
|
||||
break
|
||||
|
||||
is_final_update = False
|
||||
if isinstance(update, TaskStatusUpdateEvent):
|
||||
is_final_update = update.final
|
||||
except A2AClientHTTPError as e:
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
|
||||
if not is_final_update and a2a_task.status.state not in [
|
||||
TaskState.completed,
|
||||
TaskState.input_required,
|
||||
TaskState.failed,
|
||||
TaskState.rejected,
|
||||
TaskState.auth_required,
|
||||
TaskState.canceled,
|
||||
]:
|
||||
continue
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
final_result = process_task_state(
|
||||
a2a_task=a2a_task,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
agent_role=agent_role,
|
||||
result_parts=result_parts,
|
||||
)
|
||||
if final_result:
|
||||
break
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except A2AClientHTTPError as e:
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
if final_result:
|
||||
return final_result
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
agent_role=agent_role,
|
||||
),
|
||||
)
|
||||
return TaskStateResult(
|
||||
status=TaskState.failed,
|
||||
error=error_msg,
|
||||
status=TaskState.completed,
|
||||
result=" ".join(result_parts) if result_parts else "",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
current_exception: Exception | BaseException | None = e
|
||||
while current_exception:
|
||||
if hasattr(current_exception, "response"):
|
||||
response = current_exception.response
|
||||
if hasattr(response, "text"):
|
||||
break
|
||||
if current_exception and hasattr(current_exception, "__cause__"):
|
||||
current_exception = current_exception.__cause__
|
||||
raise
|
||||
|
||||
finally:
|
||||
if hasattr(event_stream, "aclose"):
|
||||
await event_stream.aclose()
|
||||
|
||||
if final_result:
|
||||
return final_result
|
||||
|
||||
return TaskStateResult(
|
||||
status=TaskState.completed,
|
||||
result=" ".join(result_parts) if result_parts else "",
|
||||
history=new_messages,
|
||||
)
|
||||
|
||||
@@ -34,9 +34,15 @@ from crewai.a2a.auth.utils import (
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.a2a.task_helpers import TaskStateResult
|
||||
from crewai.a2a.types import PartsDict, PartsMetadataDict
|
||||
from crewai.a2a.updates import PollingConfig, UpdateConfig
|
||||
from crewai.a2a.updates.polling.handler import execute_polling_delegation
|
||||
from crewai.a2a.updates.streaming.handler import execute_streaming_delegation
|
||||
from crewai.a2a.updates import (
|
||||
PollingConfig,
|
||||
PollingHandler,
|
||||
PushNotificationConfig,
|
||||
PushNotificationHandler,
|
||||
StreamingConfig,
|
||||
StreamingHandler,
|
||||
UpdateConfig,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConversationStartedEvent,
|
||||
@@ -53,6 +59,31 @@ if TYPE_CHECKING:
|
||||
from crewai.a2a.auth.schemas import AuthScheme
|
||||
|
||||
|
||||
HandlerType = (
|
||||
type[PollingHandler] | type[StreamingHandler] | type[PushNotificationHandler]
|
||||
)
|
||||
|
||||
HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
|
||||
PollingConfig: PollingHandler,
|
||||
StreamingConfig: StreamingHandler,
|
||||
PushNotificationConfig: PushNotificationHandler,
|
||||
}
|
||||
|
||||
|
||||
def get_handler(config: UpdateConfig | None) -> HandlerType:
|
||||
"""Get the handler class for a given update config.
|
||||
|
||||
Args:
|
||||
config: Update mechanism configuration.
|
||||
|
||||
Returns:
|
||||
Handler class for the config type, defaults to StreamingHandler.
|
||||
"""
|
||||
if config is None:
|
||||
return StreamingHandler
|
||||
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _fetch_agent_card_cached(
|
||||
endpoint: str,
|
||||
@@ -448,14 +479,28 @@ async def _execute_a2a_delegation_async(
|
||||
),
|
||||
)
|
||||
|
||||
polling_config = updates if isinstance(updates, PollingConfig) else None
|
||||
use_polling = polling_config is not None
|
||||
polling_interval = polling_config.interval if polling_config else 2.0
|
||||
effective_polling_timeout = (
|
||||
polling_config.timeout
|
||||
if polling_config and polling_config.timeout
|
||||
else float(timeout)
|
||||
)
|
||||
handler = get_handler(updates)
|
||||
use_polling = isinstance(updates, PollingConfig)
|
||||
|
||||
handler_kwargs: dict[str, Any] = {
|
||||
"turn_number": turn_number,
|
||||
"is_multiturn": is_multiturn,
|
||||
"agent_role": agent_role,
|
||||
"context_id": context_id,
|
||||
"task_id": task_id,
|
||||
"endpoint": endpoint,
|
||||
"agent_branch": agent_branch,
|
||||
}
|
||||
|
||||
if isinstance(updates, PollingConfig):
|
||||
handler_kwargs.update(
|
||||
{
|
||||
"polling_interval": updates.interval,
|
||||
"polling_timeout": updates.timeout or float(timeout),
|
||||
"history_length": updates.history_length,
|
||||
"max_polls": updates.max_polls,
|
||||
}
|
||||
)
|
||||
|
||||
async with _create_a2a_client(
|
||||
agent_card=agent_card,
|
||||
@@ -466,33 +511,12 @@ async def _execute_a2a_delegation_async(
|
||||
auth=auth,
|
||||
use_polling=use_polling,
|
||||
) as client:
|
||||
if use_polling and polling_config:
|
||||
return await execute_polling_delegation(
|
||||
client=client,
|
||||
message=message,
|
||||
polling_interval=polling_interval,
|
||||
polling_timeout=effective_polling_timeout,
|
||||
endpoint=endpoint,
|
||||
agent_branch=agent_branch,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
history_length=polling_config.history_length,
|
||||
max_polls=polling_config.max_polls,
|
||||
)
|
||||
|
||||
return await execute_streaming_delegation(
|
||||
return await handler.execute(
|
||||
client=client,
|
||||
message=message,
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
agent_role=agent_role,
|
||||
new_messages=new_messages,
|
||||
agent_card=agent_card,
|
||||
**handler_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user