chore: use TaskStateResult and TaskState enum

Greyson LaLonde
2026-01-05 18:58:12 -05:00
parent 12d60a483a
commit f478004e11
3 changed files with 96 additions and 299 deletions


@@ -5,17 +5,19 @@ This module is separate from experimental.a2a to avoid circular imports.
from __future__ import annotations
from typing import Annotated
from typing import Annotated, ClassVar
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
HttpUrl,
TypeAdapter,
)
from crewai.a2a.auth.schemas import AuthScheme
from crewai.a2a.updates import StreamingConfig, UpdateConfig
http_url_adapter = TypeAdapter(HttpUrl)
@@ -33,18 +35,21 @@ class A2AConfig(BaseModel):
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme (Bearer, OAuth2, API Key, HTTP Basic/Digest).
timeout: Request timeout in seconds (default: 120).
max_turns: Maximum conversation turns with A2A agent (default: 10).
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue (default: True).
trust_remote_completion_status: If True, return A2A agent's result directly when status is "completed"; if False, always ask server agent to respond (default: False).
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: AuthScheme | None = Field(
default=None,
description="Authentication scheme (Bearer, OAuth2, API Key, HTTP Basic/Digest)",
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
@@ -52,13 +57,17 @@ class A2AConfig(BaseModel):
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses. When specified, the A2A agent is expected to return JSON matching this schema.",
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise an error immediately when the A2A agent is unreachable. If False, skip the A2A agent and continue execution.",
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description='If True, return the A2A agent\'s result directly when status is "completed" without asking the server agent to respond. If False, always ask the server agent to respond, allowing it to potentially delegate again.',
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=StreamingConfig,
description="Update mechanism config",
)
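For context, a minimal sketch of constructing this config after the change. Field names and defaults come from the A2AConfig diff above, and the import path appears in the diff itself; the endpoint URL is hypothetical, and updates falls back to StreamingConfig when omitted.

# Illustrative sketch only -- not part of this commit.
# Field names/defaults are taken from the A2AConfig diff above; the endpoint
# URL is hypothetical.
from crewai.a2a.config import A2AConfig

config = A2AConfig(
    endpoint="https://agents.example.com/.well-known/agent.json",  # hypothetical endpoint
    timeout=120,                           # request timeout in seconds
    max_turns=10,                          # max conversation turns with the A2A agent
    fail_fast=True,                        # raise immediately if the agent is unreachable
    trust_remote_completion_status=False,  # always ask the server agent to respond
    # updates defaults to StreamingConfig(); see the PollingConfig sketch further down.
)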


@@ -10,16 +10,12 @@ import time
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client, ClientConfig, ClientFactory
from a2a.client.errors import A2AClientHTTPError
from a2a.client import A2AClientHTTPError, Client, ClientConfig, ClientFactory
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
TransportProtocol,
)
@@ -36,20 +32,23 @@ from crewai.a2a.auth.utils import (
validate_auth_against_agent_card,
)
from crewai.a2a.config import A2AConfig
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.types import PartsDict, PartsMetadataDict
from crewai.a2a.updates import PollingConfig, UpdateConfig
from crewai.a2a.updates.polling.handler import execute_polling_delegation
from crewai.a2a.updates.streaming.handler import execute_streaming_delegation
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
A2AResponseReceivedEvent,
)
from crewai.types.utils import create_literals_from_strings
if TYPE_CHECKING:
from a2a.types import Message, Task as A2ATask
from a2a.types import Message
from crewai.a2a.auth.schemas import AuthScheme
@@ -235,26 +234,20 @@ def execute_a2a_delegation(
agent_branch: Any | None = None,
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
) -> dict[str, Any]:
updates: UpdateConfig | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent with multi-turn support.
Handles:
- AgentCard discovery
- Authentication setup
- Message creation and sending
- Response parsing
- Multi-turn conversations
Args:
endpoint: A2A agent endpoint URL (AgentCard URL)
auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest)
endpoint: A2A agent endpoint URL
auth: Optional AuthScheme for authentication
timeout: Request timeout in seconds
task_description: The task to delegate
context: Optional context information
context_id: Context ID for correlating messages/tasks
task_id: Specific task identifier
reference_task_ids: List of related task IDs
metadata: Additional metadata (external_id, request_id, etc.)
metadata: Additional metadata
extensions: Protocol extensions for custom fields
conversation_history: Previous Message objects from conversation
agent_id: Agent identifier for logging
@@ -262,16 +255,10 @@ def execute_a2a_delegation(
agent_branch: Optional agent tree branch for logging
response_model: Optional Pydantic model for structured outputs
turn_number: Optional turn number for multi-turn conversations
updates: Update mechanism config from A2AConfig.updates
Returns:
Dictionary with:
- status: "completed", "input_required", "failed", etc.
- result: Result string (if completed)
- error: Error message (if failed)
- history: List of new Message objects from this exchange
Raises:
ImportError: If a2a-sdk is not installed
TaskStateResult with status, result/error, history, and agent_card
"""
is_multiturn = bool(conversation_history and len(conversation_history) > 0)
if turn_number is None:
@@ -311,6 +298,7 @@ def execute_a2a_delegation(
agent_id=agent_id,
agent_role=agent_role,
response_model=response_model,
updates=updates,
)
)
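A caller-side sketch of the new TaskStateResult return value. Which keyword arguments are strictly required is not visible in this diff, so the call below assumes a minimal set, and the module path of execute_a2a_delegation is an assumption (file names are not shown in this view); the dict-style access mirrors the a2a_result["status"] and .get("result") usage later in this commit.

# Illustrative sketch only -- not part of this commit.
# The minimal argument set and the import path of execute_a2a_delegation are
# assumptions; status values are the TaskState enum members this commit adopts.
from a2a.types import TaskState
from crewai.a2a.delegation import execute_a2a_delegation  # module path assumed

result = execute_a2a_delegation(
    endpoint="https://agents.example.com/.well-known/agent.json",  # hypothetical
    auth=None,
    timeout=120,
    task_description="Summarize the latest build failures",  # hypothetical task
    agent_id="remote-analyst",                                # hypothetical id
)

if result["status"] == TaskState.completed:
    print(result.get("result", ""))
elif result["status"] == TaskState.input_required:
    print("More input needed:", result.get("error", ""))
else:
    print("Delegation did not complete:", result.get("error", ""))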
@@ -347,7 +335,8 @@ async def _execute_a2a_delegation_async(
agent_id: str | None = None,
agent_role: str | None = None,
response_model: type[BaseModel] | None = None,
) -> dict[str, Any]:
updates: UpdateConfig | None = None,
) -> TaskStateResult:
"""Async implementation of A2A delegation with multi-turn support.
Args:
@@ -368,9 +357,10 @@ async def _execute_a2a_delegation_async(
agent_id: Agent identifier for logging
agent_role: Agent role for logging
response_model: Optional Pydantic model for structured outputs
updates: Update mechanism config
Returns:
Dictionary with status, result/error, and new history
TaskStateResult with status, result/error, history, and agent_card
"""
if auth:
auth_data = auth.model_dump_json(
@@ -458,201 +448,52 @@ async def _execute_a2a_delegation_async(
),
)
polling_config = updates if isinstance(updates, PollingConfig) else None
use_polling = polling_config is not None
polling_interval = polling_config.interval if polling_config else 2.0
effective_polling_timeout = (
polling_config.timeout
if polling_config and polling_config.timeout
else float(timeout)
)
async with _create_a2a_client(
agent_card=agent_card,
transport_protocol=transport_protocol,
timeout=timeout,
headers=headers,
streaming=True,
streaming=not use_polling,
auth=auth,
use_polling=use_polling,
) as client:
result_parts: list[str] = []
final_result: dict[str, Any] | None = None
event_stream = client.send_message(message)
if use_polling and polling_config:
return await execute_polling_delegation(
client=client,
message=message,
polling_interval=polling_interval,
polling_timeout=effective_polling_timeout,
endpoint=endpoint,
agent_branch=agent_branch,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
new_messages=new_messages,
agent_card=agent_card,
history_length=polling_config.history_length,
max_polls=polling_config.max_polls,
)
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
elif isinstance(event, tuple):
a2a_task, update = event
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
is_final_update = False
if isinstance(update, TaskStatusUpdateEvent):
is_final_update = update.final
if not is_final_update and a2a_task.status.state not in [
TaskState.completed,
TaskState.input_required,
TaskState.failed,
TaskState.rejected,
TaskState.auth_required,
TaskState.canceled,
]:
continue
if a2a_task.status.state == TaskState.completed:
extracted_parts = _extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="completed",
agent_role=agent_role,
),
)
final_result = {
"status": "completed",
"result": response_text,
"history": new_messages,
"agent_card": agent_card,
}
break
if a2a_task.status.state == TaskState.input_required:
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = _extract_error_message(
a2a_task, "Additional input required"
)
if response_text and not a2a_task.history:
agent_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))],
context_id=a2a_task.context_id
if hasattr(a2a_task, "context_id")
else None,
task_id=a2a_task.task_id
if hasattr(a2a_task, "task_id")
else None,
)
new_messages.append(agent_message)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="input_required",
agent_role=agent_role,
),
)
final_result = {
"status": "input_required",
"error": response_text,
"history": new_messages,
"agent_card": agent_card,
}
break
if a2a_task.status.state in [TaskState.failed, TaskState.rejected]:
error_msg = _extract_error_message(
a2a_task, "Task failed without error message"
)
if a2a_task.history:
new_messages.extend(a2a_task.history)
final_result = {
"status": "failed",
"error": error_msg,
"history": new_messages,
}
break
if a2a_task.status.state == TaskState.auth_required:
error_msg = _extract_error_message(
a2a_task, "Authentication required"
)
final_result = {
"status": "auth_required",
"error": error_msg,
"history": new_messages,
}
break
if a2a_task.status.state == TaskState.canceled:
error_msg = _extract_error_message(
a2a_task, "Task was canceled"
)
final_result = {
"status": "canceled",
"error": error_msg,
"history": new_messages,
}
break
except Exception as e:
if isinstance(e, A2AClientHTTPError):
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
),
)
return {
"status": "failed",
"error": error_msg,
"history": new_messages,
}
current_exception: Exception | BaseException | None = e
while current_exception:
if hasattr(current_exception, "response"):
response = current_exception.response
if hasattr(response, "text"):
break
if current_exception and hasattr(current_exception, "__cause__"):
current_exception = current_exception.__cause__
raise
finally:
if hasattr(event_stream, "aclose"):
await event_stream.aclose()
if final_result:
return final_result
return {
"status": "completed",
"result": " ".join(result_parts) if result_parts else "",
"history": new_messages,
}
return await execute_streaming_delegation(
client=client,
message=message,
context_id=context_id,
task_id=task_id,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
new_messages=new_messages,
agent_card=agent_card,
)
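For reference, a hedged sketch of the two update modes from the configuration side. The 2.0-second interval and the timeout fallback mirror the dispatch code above; the exact constructor keywords of PollingConfig are an assumption based on the attributes read there (interval, timeout, history_length, max_polls).

# Illustrative sketch only -- not part of this commit.
# PollingConfig keyword arguments are assumed from the attributes the dispatch
# code above reads; StreamingConfig() is the A2AConfig default.
from crewai.a2a.updates import PollingConfig, StreamingConfig

streaming_updates = StreamingConfig()  # event-stream path -> execute_streaming_delegation
polling_updates = PollingConfig(       # polling path -> execute_polling_delegation
    interval=2.0,                      # matches the 2.0s fallback used above
    max_polls=30,                      # hypothetical cap
)
# When PollingConfig.timeout is unset, the handler falls back to the request timeout.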
@asynccontextmanager
@@ -663,6 +504,7 @@ async def _create_a2a_client(
headers: MutableMapping[str, str],
streaming: bool,
auth: AuthScheme | None = None,
use_polling: bool = False,
) -> AsyncIterator[Client]:
"""Create and configure an A2A client.
@@ -673,6 +515,7 @@ async def _create_a2a_client(
headers: HTTP headers (already with auth applied)
streaming: Enable streaming responses
auth: Optional AuthScheme for client configuration
use_polling: Enable polling mode
Yields:
Configured A2A client instance
@@ -688,7 +531,8 @@ async def _create_a2a_client(
config = ClientConfig(
httpx_client=httpx_client,
supported_transports=[str(transport_protocol.value)],
streaming=streaming,
streaming=streaming and not use_polling,
polling=use_polling,
accepted_output_modes=["application/json"],
)
@@ -697,66 +541,6 @@ async def _create_a2a_client(
yield client
def _extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
"""Extract result parts from A2A task history and artifacts.
Args:
a2a_task: A2A Task object with history and artifacts
Returns:
List of result text parts
"""
result_parts: list[str] = []
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
if history_msg.role == Role.agent:
result_parts.extend(
part.root.text
for part in history_msg.parts
if part.root.kind == "text"
)
break
if a2a_task.artifacts:
result_parts.extend(
part.root.text
for artifact in a2a_task.artifacts
for part in artifact.parts
if part.root.kind == "text"
)
return result_parts
def _extract_error_message(a2a_task: A2ATask, default: str) -> str:
"""Extract error message from A2A task.
Args:
a2a_task: A2A Task object
default: Default message if no error found
Returns:
Error message string
"""
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
if msg:
for part in msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return str(msg)
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
for part in history_msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return default
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel]:
"""Create a dynamic AgentResponse model with Literal types for agent IDs.


@@ -9,13 +9,14 @@ from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
from a2a.types import Role
from a2a.types import Role, TaskState
from pydantic import BaseModel, ValidationError
from crewai.a2a.config import A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.templates import (
AVAILABLE_AGENTS_TEMPLATE,
CONVERSATION_TURN_INFO_TEMPLATE,
@@ -346,7 +347,7 @@ IMPORTANT: You have the ability to delegate this task to remote A2A agents.
def _parse_agent_response(
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
) -> BaseModel | str:
) -> BaseModel | str | dict[str, Any]:
"""Parse LLM output as AgentResponse or return raw agent response.
Args:
@@ -354,7 +355,7 @@ def _parse_agent_response(
agent_response_model: The agent response model
Returns:
Parsed AgentResponse or string
Parsed AgentResponse, or raw result if parsing fails
"""
if agent_response_model:
try:
@@ -363,13 +364,13 @@ def _parse_agent_response(
if isinstance(raw_result, dict):
return agent_response_model.model_validate(raw_result)
except ValidationError:
return cast(str, raw_result)
return cast(str, raw_result)
return raw_result
return raw_result
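A small illustration of the widened fallback: a value that fails validation is now returned as-is (str or dict) rather than being forced through cast(str, ...). The AgentResponse model below is a hypothetical stand-in for the dynamically created one.

# Illustrative sketch only -- not part of this commit.
# AgentResponse stands in for the dynamic model from create_agent_response_model.
from pydantic import BaseModel

class AgentResponse(BaseModel):
    agent_id: str
    message: str

parsed = _parse_agent_response({"agent_id": "a2a-1", "message": "done"}, AgentResponse)
# -> AgentResponse(agent_id='a2a-1', message='done')

fallback = _parse_agent_response({"unexpected": "shape"}, AgentResponse)
# -> {'unexpected': 'shape'}  (previously mis-typed as str via cast)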
def _handle_agent_response_and_continue(
self: Agent,
a2a_result: dict[str, Any],
a2a_result: TaskStateResult,
agent_id: str,
agent_cards: dict[str, AgentCard] | None,
a2a_agents: list[A2AConfig],
@@ -568,6 +569,7 @@ def _delegate_to_a2a(
agent_branch=agent_branch,
response_model=agent_config.response_model,
turn_number=turn_num + 1,
updates=agent_config.updates,
)
conversation_history = a2a_result.get("history", [])
@@ -579,11 +581,8 @@ def _delegate_to_a2a(
if latest_message.context_id is not None:
context_id = latest_message.context_id
if a2a_result["status"] in ["completed", "input_required"]:
if (
a2a_result["status"] == "completed"
and agent_config.trust_remote_completion_status
):
if a2a_result["status"] in [TaskState.completed, TaskState.input_required]:
if a2a_result["status"] == TaskState.completed:
if (
task_id_config is not None
and task_id_config not in reference_task_ids
@@ -592,7 +591,12 @@ def _delegate_to_a2a(
if task.config is None:
task.config = {}
task.config["reference_task_ids"] = reference_task_ids
task_id_config = None
if (
a2a_result["status"] == TaskState.completed
and agent_config.trust_remote_completion_status
):
result_text = a2a_result.get("result", "")
final_turn_number = turn_num + 1
crewai_event_bus.emit(
@@ -604,7 +608,7 @@ def _delegate_to_a2a(
total_turns=final_turn_number,
),
)
return cast(str, result_text)
return str(result_text)
final_result, next_request = _handle_agent_response_and_continue(
self=self,