From 67953b3a6a57d0f8b18474aea8f8835a5fcd99fd Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Wed, 7 Jan 2026 14:07:40 -0500 Subject: [PATCH] feat: a2a native async chain --- lib/crewai/src/crewai/a2a/types.py | 23 + lib/crewai/src/crewai/a2a/utils.py | 346 +++++++----- lib/crewai/src/crewai/a2a/wrapper.py | 792 +++++++++++++++++++++------ 3 files changed, 852 insertions(+), 309 deletions(-) diff --git a/lib/crewai/src/crewai/a2a/types.py b/lib/crewai/src/crewai/a2a/types.py index fca22d8bb..217b59467 100644 --- a/lib/crewai/src/crewai/a2a/types.py +++ b/lib/crewai/src/crewai/a2a/types.py @@ -4,6 +4,16 @@ from typing import Any, Literal, Protocol, TypedDict, runtime_checkable from typing_extensions import NotRequired +from crewai.a2a.updates import ( + PollingConfig, + PollingHandler, + PushNotificationConfig, + PushNotificationHandler, + StreamingConfig, + StreamingHandler, + UpdateConfig, +) + @runtime_checkable class AgentResponseProtocol(Protocol): @@ -36,3 +46,16 @@ class PartsDict(TypedDict): text: str metadata: NotRequired[PartsMetadataDict] + + +PollingHandlerType = type[PollingHandler] +StreamingHandlerType = type[StreamingHandler] +PushNotificationHandlerType = type[PushNotificationHandler] + +HandlerType = PollingHandlerType | StreamingHandlerType | PushNotificationHandlerType + +HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = { + PollingConfig: PollingHandler, + StreamingConfig: StreamingHandler, + PushNotificationConfig: PushNotificationHandler, +} diff --git a/lib/crewai/src/crewai/a2a/utils.py b/lib/crewai/src/crewai/a2a/utils.py index 42f5f44e0..4b3ba23e9 100644 --- a/lib/crewai/src/crewai/a2a/utils.py +++ b/lib/crewai/src/crewai/a2a/utils.py @@ -34,13 +34,15 @@ from crewai.a2a.auth.utils import ( ) from crewai.a2a.config import A2AConfig from crewai.a2a.task_helpers import TaskStateResult -from crewai.a2a.types import PartsDict, PartsMetadataDict +from crewai.a2a.types import ( + HANDLER_REGISTRY, + HandlerType, + PartsDict, + PartsMetadataDict, +) from crewai.a2a.updates import ( PollingConfig, - PollingHandler, PushNotificationConfig, - PushNotificationHandler, - StreamingConfig, StreamingHandler, UpdateConfig, ) @@ -60,17 +62,6 @@ if TYPE_CHECKING: from crewai.a2a.auth.schemas import AuthScheme -HandlerType = ( - type[PollingHandler] | type[StreamingHandler] | type[PushNotificationHandler] -) - -HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = { - PollingConfig: PollingHandler, - StreamingConfig: StreamingHandler, - PushNotificationConfig: PushNotificationHandler, -} - - def get_handler(config: UpdateConfig | None) -> HandlerType: """Get the handler class for a given update config. @@ -92,24 +83,14 @@ def _fetch_agent_card_cached( timeout: int, _ttl_hash: int, ) -> AgentCard: - """Cached version of fetch_agent_card with auth support. - - Args: - endpoint: A2A agent endpoint URL - auth_hash: Hash of the auth object - timeout: Request timeout - _ttl_hash: Time-based hash for cache invalidation - - Returns: - Cached AgentCard - """ + """Cached sync version of fetch_agent_card.""" auth = _auth_store.get(auth_hash) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: return loop.run_until_complete( - _fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout) + _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout) ) finally: loop.close() @@ -159,47 +140,74 @@ def fetch_agent_card( asyncio.set_event_loop(loop) try: return loop.run_until_complete( - _fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout) + afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout) ) finally: loop.close() +async def afetch_agent_card( + endpoint: str, + auth: AuthScheme | None = None, + timeout: int = 30, + use_cache: bool = True, +) -> AgentCard: + """Fetch AgentCard from an A2A endpoint asynchronously. + + Native async implementation. Use this when running in an async context. + + Args: + endpoint: A2A agent endpoint URL (AgentCard URL). + auth: Optional AuthScheme for authentication. + timeout: Request timeout in seconds. + use_cache: Whether to use caching (default True). + + Returns: + AgentCard object with agent capabilities and skills. + + Raises: + httpx.HTTPStatusError: If the request fails. + A2AClientHTTPError: If authentication fails. + """ + if use_cache: + if auth: + auth_data = auth.model_dump_json( + exclude={ + "_access_token", + "_token_expires_at", + "_refresh_token", + "_authorization_callback", + } + ) + auth_hash = hash((type(auth).__name__, auth_data)) + else: + auth_hash = 0 + _auth_store[auth_hash] = auth + agent_card: AgentCard = await _afetch_agent_card_cached( + endpoint, auth_hash, timeout + ) + return agent_card + + return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout) + + @cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator] -async def _fetch_agent_card_async_cached( +async def _afetch_agent_card_cached( endpoint: str, auth_hash: int, timeout: int, ) -> AgentCard: - """Cached async implementation of AgentCard fetching. - - Args: - endpoint: A2A agent endpoint URL - auth_hash: Hash of the auth object - timeout: Request timeout in seconds - - Returns: - Cached AgentCard object - """ + """Cached async implementation of AgentCard fetching.""" auth = _auth_store.get(auth_hash) - return await _fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout) + return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout) -async def _fetch_agent_card_async( +async def _afetch_agent_card_impl( endpoint: str, auth: AuthScheme | None, timeout: int, ) -> AgentCard: - """Async implementation of AgentCard fetching. - - Args: - endpoint: A2A agent endpoint URL - auth: Optional AuthScheme for authentication - timeout: Request timeout in seconds - - Returns: - AgentCard object - """ + """Internal async implementation of AgentCard fetching.""" if "/.well-known/agent-card.json" in endpoint: base_url = endpoint.replace("/.well-known/agent-card.json", "") agent_card_path = "/.well-known/agent-card.json" @@ -268,35 +276,114 @@ def execute_a2a_delegation( turn_number: int | None = None, updates: UpdateConfig | None = None, ) -> TaskStateResult: - """Execute a task delegation to a remote A2A agent with multi-turn support. + """Execute a task delegation to a remote A2A agent synchronously. + + This is the sync wrapper around aexecute_a2a_delegation. For async contexts, + use aexecute_a2a_delegation directly. Args: - endpoint: A2A agent endpoint URL - auth: Optional AuthScheme for authentication - timeout: Request timeout in seconds - task_description: The task to delegate - context: Optional context information - context_id: Context ID for correlating messages/tasks - task_id: Specific task identifier - reference_task_ids: List of related task IDs - metadata: Additional metadata - extensions: Protocol extensions for custom fields - conversation_history: Previous Message objects from conversation - agent_id: Agent identifier for logging - agent_role: Role of the CrewAI agent delegating the task - agent_branch: Optional agent tree branch for logging - response_model: Optional Pydantic model for structured outputs - turn_number: Optional turn number for multi-turn conversations - updates: Update mechanism config from A2AConfig.updates + endpoint: A2A agent endpoint URL. + auth: Optional AuthScheme for authentication. + timeout: Request timeout in seconds. + task_description: The task to delegate. + context: Optional context information. + context_id: Context ID for correlating messages/tasks. + task_id: Specific task identifier. + reference_task_ids: List of related task IDs. + metadata: Additional metadata. + extensions: Protocol extensions for custom fields. + conversation_history: Previous Message objects from conversation. + agent_id: Agent identifier for logging. + agent_role: Role of the CrewAI agent delegating the task. + agent_branch: Optional agent tree branch for logging. + response_model: Optional Pydantic model for structured outputs. + turn_number: Optional turn number for multi-turn conversations. + updates: Update mechanism config from A2AConfig.updates. Returns: - TaskStateResult with status, result/error, history, and agent_card + TaskStateResult with status, result/error, history, and agent_card. """ - is_multiturn = bool(conversation_history and len(conversation_history) > 0) - if turn_number is None: - turn_number = ( - len([m for m in (conversation_history or []) if m.role == Role.user]) + 1 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + aexecute_a2a_delegation( + endpoint=endpoint, + auth=auth, + timeout=timeout, + task_description=task_description, + context=context, + context_id=context_id, + task_id=task_id, + reference_task_ids=reference_task_ids, + metadata=metadata, + extensions=extensions, + conversation_history=conversation_history, + agent_id=agent_id, + agent_role=agent_role, + agent_branch=agent_branch, + response_model=response_model, + turn_number=turn_number, + updates=updates, + ) ) + finally: + loop.close() + + +async def aexecute_a2a_delegation( + endpoint: str, + auth: AuthScheme | None, + timeout: int, + task_description: str, + context: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + reference_task_ids: list[str] | None = None, + metadata: dict[str, Any] | None = None, + extensions: dict[str, Any] | None = None, + conversation_history: list[Message] | None = None, + agent_id: str | None = None, + agent_role: Role | None = None, + agent_branch: Any | None = None, + response_model: type[BaseModel] | None = None, + turn_number: int | None = None, + updates: UpdateConfig | None = None, +) -> TaskStateResult: + """Execute a task delegation to a remote A2A agent asynchronously. + + Native async implementation with multi-turn support. Use this when running + in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()). + + Args: + endpoint: A2A agent endpoint URL. + auth: Optional AuthScheme for authentication. + timeout: Request timeout in seconds. + task_description: The task to delegate. + context: Optional context information. + context_id: Context ID for correlating messages/tasks. + task_id: Specific task identifier. + reference_task_ids: List of related task IDs. + metadata: Additional metadata. + extensions: Protocol extensions for custom fields. + conversation_history: Previous Message objects from conversation. + agent_id: Agent identifier for logging. + agent_role: Role of the CrewAI agent delegating the task. + agent_branch: Optional agent tree branch for logging. + response_model: Optional Pydantic model for structured outputs. + turn_number: Optional turn number for multi-turn conversations. + updates: Update mechanism config from A2AConfig.updates. + + Returns: + TaskStateResult with status, result/error, history, and agent_card. + """ + if conversation_history is None: + conversation_history = [] + + is_multiturn = len(conversation_history) > 0 + if turn_number is None: + turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1 + crewai_event_bus.emit( agent_branch, A2ADelegationStartedEvent( @@ -308,48 +395,41 @@ def execute_a2a_delegation( ), ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - _execute_a2a_delegation_async( - endpoint=endpoint, - auth=auth, - timeout=timeout, - task_description=task_description, - context=context, - context_id=context_id, - task_id=task_id, - reference_task_ids=reference_task_ids, - metadata=metadata, - extensions=extensions, - conversation_history=conversation_history or [], - is_multiturn=is_multiturn, - turn_number=turn_number, - agent_branch=agent_branch, - agent_id=agent_id, - agent_role=agent_role, - response_model=response_model, - updates=updates, - ) - ) + result = await _aexecute_a2a_delegation_impl( + endpoint=endpoint, + auth=auth, + timeout=timeout, + task_description=task_description, + context=context, + context_id=context_id, + task_id=task_id, + reference_task_ids=reference_task_ids, + metadata=metadata, + extensions=extensions, + conversation_history=conversation_history, + is_multiturn=is_multiturn, + turn_number=turn_number, + agent_branch=agent_branch, + agent_id=agent_id, + agent_role=agent_role, + response_model=response_model, + updates=updates, + ) - crewai_event_bus.emit( - agent_branch, - A2ADelegationCompletedEvent( - status=result["status"], - result=result.get("result"), - error=result.get("error"), - is_multiturn=is_multiturn, - ), - ) + crewai_event_bus.emit( + agent_branch, + A2ADelegationCompletedEvent( + status=result["status"], + result=result.get("result"), + error=result.get("error"), + is_multiturn=is_multiturn, + ), + ) - return result - finally: - loop.close() + return result -async def _execute_a2a_delegation_async( +async def _aexecute_a2a_delegation_impl( endpoint: str, auth: AuthScheme | None, timeout: int, @@ -361,39 +441,15 @@ async def _execute_a2a_delegation_async( metadata: dict[str, Any] | None, extensions: dict[str, Any] | None, conversation_history: list[Message], - is_multiturn: bool = False, - turn_number: int = 1, - agent_branch: Any | None = None, - agent_id: str | None = None, - agent_role: str | None = None, - response_model: type[BaseModel] | None = None, - updates: UpdateConfig | None = None, + is_multiturn: bool, + turn_number: int, + agent_branch: Any | None, + agent_id: str | None, + agent_role: str | None, + response_model: type[BaseModel] | None, + updates: UpdateConfig | None, ) -> TaskStateResult: - """Async implementation of A2A delegation with multi-turn support. - - Args: - endpoint: A2A agent endpoint URL - auth: Optional AuthScheme for authentication - timeout: Request timeout in seconds - task_description: Task to delegate - context: Optional context - context_id: Context ID for correlation - task_id: Specific task identifier - reference_task_ids: Related task IDs - metadata: Additional metadata - extensions: Protocol extensions - conversation_history: Previous Message objects - is_multiturn: Whether this is a multi-turn conversation - turn_number: Current turn number - agent_branch: Agent tree branch for logging - agent_id: Agent identifier for logging - agent_role: Agent role for logging - response_model: Optional Pydantic model for structured outputs - updates: Update mechanism config - - Returns: - TaskStateResult with status, result/error, history, and agent_card - """ + """Internal async implementation of A2A delegation.""" if auth: auth_data = auth.model_dump_json( exclude={ @@ -407,7 +463,7 @@ async def _execute_a2a_delegation_async( else: auth_hash = 0 _auth_store[auth_hash] = auth - agent_card = await _fetch_agent_card_async_cached( + agent_card = await _afetch_agent_card_cached( endpoint=endpoint, auth_hash=auth_hash, timeout=timeout ) diff --git a/lib/crewai/src/crewai/a2a/wrapper.py b/lib/crewai/src/crewai/a2a/wrapper.py index 63c445921..358d0fc79 100644 --- a/lib/crewai/src/crewai/a2a/wrapper.py +++ b/lib/crewai/src/crewai/a2a/wrapper.py @@ -5,7 +5,8 @@ Wraps agent classes with A2A delegation capabilities. from __future__ import annotations -from collections.abc import Callable +import asyncio +from collections.abc import Callable, Coroutine from concurrent.futures import ThreadPoolExecutor, as_completed from functools import wraps from types import MethodType @@ -26,6 +27,8 @@ from crewai.a2a.templates import ( ) from crewai.a2a.types import AgentResponseProtocol from crewai.a2a.utils import ( + aexecute_a2a_delegation, + afetch_agent_card, execute_a2a_delegation, fetch_agent_card, get_a2a_agents_and_response_model, @@ -48,15 +51,15 @@ if TYPE_CHECKING: def wrap_agent_with_a2a_instance( agent: Agent, extension_registry: ExtensionRegistry | None = None ) -> None: - """Wrap an agent instance's execute_task method with A2A support. + """Wrap an agent instance's execute_task and aexecute_task methods with A2A support. This function modifies the agent instance by wrapping its execute_task - method to add A2A delegation capabilities. Should only be called when - the agent has a2a configuration set. + and aexecute_task methods to add A2A delegation capabilities. Should only + be called when the agent has a2a configuration set. Args: - agent: The agent instance to wrap - extension_registry: Optional registry of A2A extensions for injecting tools and custom logic + agent: The agent instance to wrap. + extension_registry: Optional registry of A2A extensions. """ if extension_registry is None: extension_registry = ExtensionRegistry() @@ -64,6 +67,7 @@ def wrap_agent_with_a2a_instance( extension_registry.inject_all_tools(agent) original_execute_task = agent.execute_task.__func__ # type: ignore[attr-defined] + original_aexecute_task = agent.aexecute_task.__func__ # type: ignore[attr-defined] @wraps(original_execute_task) def execute_task_with_a2a( @@ -72,17 +76,7 @@ def wrap_agent_with_a2a_instance( context: str | None = None, tools: list[BaseTool] | None = None, ) -> str: - """Execute task with A2A delegation support. - - Args: - self: The agent instance - task: The task to execute - context: Optional context for task execution - tools: Optional tools available to the agent - - Returns: - Task execution result - """ + """Execute task with A2A delegation support (sync).""" if not self.a2a: return original_execute_task(self, task, context, tools) # type: ignore[no-any-return] @@ -99,7 +93,34 @@ def wrap_agent_with_a2a_instance( extension_registry=extension_registry, ) + @wraps(original_aexecute_task) + async def aexecute_task_with_a2a( + self: Agent, + task: Task, + context: str | None = None, + tools: list[BaseTool] | None = None, + ) -> str: + """Execute task with A2A delegation support (async).""" + if not self.a2a: + return await original_aexecute_task(self, task, context, tools) # type: ignore[no-any-return] + + a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) + + return await _aexecute_task_with_a2a( + self=self, + a2a_agents=a2a_agents, + original_fn=original_aexecute_task, + task=task, + agent_response_model=agent_response_model, + context=context, + tools=tools, + extension_registry=extension_registry, + ) + object.__setattr__(agent, "execute_task", MethodType(execute_task_with_a2a, agent)) + object.__setattr__( + agent, "aexecute_task", MethodType(aexecute_task_with_a2a, agent) + ) def _fetch_card_from_config( @@ -353,15 +374,7 @@ IMPORTANT: You have the ability to delegate this task to remote A2A agents. def _parse_agent_response( raw_result: str | dict[str, Any], agent_response_model: type[BaseModel] ) -> BaseModel | str | dict[str, Any]: - """Parse LLM output as AgentResponse or return raw agent response. - - Args: - raw_result: Raw output from LLM - agent_response_model: The agent response model - - Returns: - Parsed AgentResponse, or raw result if parsing fails - """ + """Parse LLM output as AgentResponse or return raw agent response.""" if agent_response_model: try: if isinstance(raw_result, str): @@ -373,6 +386,246 @@ def _parse_agent_response( return raw_result +def _handle_max_turns_exceeded( + conversation_history: list[Message], + max_turns: int, +) -> str: + """Handle the case when max turns is exceeded. + + Shared logic for both sync and async delegation. + + Returns: + Final message if found in history. + + Raises: + Exception: If no final message found and max turns exceeded. + """ + if conversation_history: + for msg in reversed(conversation_history): + if msg.role == Role.agent: + text_parts = [ + part.root.text for part in msg.parts if part.root.kind == "text" + ] + final_message = ( + " ".join(text_parts) if text_parts else "Conversation completed" + ) + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="completed", + final_result=final_message, + error=None, + total_turns=max_turns, + ), + ) + return final_message + + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="failed", + final_result=None, + error=f"Conversation exceeded maximum turns ({max_turns})", + total_turns=max_turns, + ), + ) + raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})") + + +def _process_response_result( + raw_result: str, + disable_structured_output: bool, + turn_num: int, + agent_role: str, + agent_response_model: type[BaseModel], +) -> tuple[str | None, str | None]: + """Process LLM response and determine next action. + + Shared logic for both sync and async handlers. + + Returns: + Tuple of (final_result, next_request). + """ + if disable_structured_output: + final_turn_number = turn_num + 1 + result_text = str(raw_result) + crewai_event_bus.emit( + None, + A2AMessageSentEvent( + message=result_text, + turn_number=final_turn_number, + is_multiturn=True, + agent_role=agent_role, + ), + ) + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="completed", + final_result=result_text, + error=None, + total_turns=final_turn_number, + ), + ) + return result_text, None + + llm_response = _parse_agent_response( + raw_result=raw_result, agent_response_model=agent_response_model + ) + + if isinstance(llm_response, BaseModel) and isinstance( + llm_response, AgentResponseProtocol + ): + if not llm_response.is_a2a: + final_turn_number = turn_num + 1 + crewai_event_bus.emit( + None, + A2AMessageSentEvent( + message=str(llm_response.message), + turn_number=final_turn_number, + is_multiturn=True, + agent_role=agent_role, + ), + ) + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="completed", + final_result=str(llm_response.message), + error=None, + total_turns=final_turn_number, + ), + ) + return str(llm_response.message), None + return None, str(llm_response.message) + + return str(raw_result), None + + +def _prepare_agent_cards_dict( + a2a_result: TaskStateResult, + agent_id: str, + agent_cards: dict[str, AgentCard] | None, +) -> dict[str, AgentCard]: + """Prepare agent cards dictionary from result and existing cards. + + Shared logic for both sync and async response handlers. + """ + agent_cards_dict = agent_cards or {} + if "agent_card" in a2a_result and agent_id not in agent_cards_dict: + agent_cards_dict[agent_id] = a2a_result["agent_card"] + return agent_cards_dict + + +def _prepare_delegation_context( + self: Agent, + agent_response: AgentResponseProtocol, + task: Task, + original_task_description: str | None, +) -> tuple[ + list[A2AConfig], + type[BaseModel], + str, + str, + A2AConfig, + str | None, + str | None, + dict[str, Any] | None, + dict[str, Any] | None, + list[str], + str, + int, +]: + """Prepare delegation context from agent response and task. + + Shared logic for both sync and async delegation. + + Returns: + Tuple containing all the context values needed for delegation. + """ + a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) + agent_ids = tuple(config.endpoint for config in a2a_agents) + current_request = str(agent_response.message) + + if hasattr(agent_response, "a2a_ids") and agent_response.a2a_ids: + agent_id = agent_response.a2a_ids[0] + else: + agent_id = agent_ids[0] if agent_ids else "" + + if agent_id and agent_id not in agent_ids: + raise ValueError( + f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}" + ) + + agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents)) + task_config = task.config or {} + context_id = task_config.get("context_id") + task_id_config = task_config.get("task_id") + metadata = task_config.get("metadata") + extensions = task_config.get("extensions") + reference_task_ids = task_config.get("reference_task_ids", []) + + if original_task_description is None: + original_task_description = task.description + + max_turns = agent_config.max_turns + + return ( + a2a_agents, + agent_response_model, + current_request, + agent_id, + agent_config, + context_id, + task_id_config, + metadata, + extensions, + reference_task_ids, + original_task_description, + max_turns, + ) + + +def _handle_task_completion( + a2a_result: TaskStateResult, + task: Task, + task_id_config: str | None, + reference_task_ids: list[str], + agent_config: A2AConfig, + turn_num: int, +) -> tuple[str | None, str | None, list[str]]: + """Handle task completion state including reference task updates. + + Shared logic for both sync and async delegation. + + Returns: + Tuple of (result_if_trusted, updated_task_id, updated_reference_task_ids). + """ + if a2a_result["status"] == TaskState.completed: + if task_id_config is not None and task_id_config not in reference_task_ids: + reference_task_ids.append(task_id_config) + if task.config is None: + task.config = {} + task.config["reference_task_ids"] = reference_task_ids + task_id_config = None + + if agent_config.trust_remote_completion_status: + result_text = a2a_result.get("result", "") + final_turn_number = turn_num + 1 + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="completed", + final_result=result_text, + error=None, + total_turns=final_turn_number, + ), + ) + return str(result_text), task_id_config, reference_task_ids + + return None, task_id_config, reference_task_ids + + def _handle_agent_response_and_continue( self: Agent, a2a_result: TaskStateResult, @@ -413,9 +666,7 @@ def _handle_agent_response_and_continue( - final_result is not None if conversation should end - current_request is the next message to send if continuing """ - agent_cards_dict = agent_cards or {} - if "agent_card" in a2a_result and agent_id not in agent_cards_dict: - agent_cards_dict[agent_id] = a2a_result["agent_card"] + agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards) task.description, disable_structured_output = _augment_prompt_with_a2a( a2a_agents=a2a_agents, @@ -436,61 +687,14 @@ def _handle_agent_response_and_continue( if disable_structured_output: task.response_model = original_response_model - if disable_structured_output: - final_turn_number = turn_num + 1 - result_text = str(raw_result) - crewai_event_bus.emit( - None, - A2AMessageSentEvent( - message=result_text, - turn_number=final_turn_number, - is_multiturn=True, - agent_role=self.role, - ), - ) - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="completed", - final_result=result_text, - error=None, - total_turns=final_turn_number, - ), - ) - return result_text, None - - llm_response = _parse_agent_response( - raw_result=raw_result, agent_response_model=agent_response_model + return _process_response_result( + raw_result=raw_result, + disable_structured_output=disable_structured_output, + turn_num=turn_num, + agent_role=self.role, + agent_response_model=agent_response_model, ) - if isinstance(llm_response, BaseModel) and isinstance( - llm_response, AgentResponseProtocol - ): - if not llm_response.is_a2a: - final_turn_number = turn_num + 1 - crewai_event_bus.emit( - None, - A2AMessageSentEvent( - message=str(llm_response.message), - turn_number=final_turn_number, - is_multiturn=True, - agent_role=self.role, - ), - ) - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="completed", - final_result=str(llm_response.message), - error=None, - total_turns=final_turn_number, - ), - ) - return str(llm_response.message), None - return None, str(llm_response.message) - - return str(raw_result), None - def _delegate_to_a2a( self: Agent, @@ -522,34 +726,24 @@ def _delegate_to_a2a( Raises: ImportError: If a2a-sdk is not installed """ - a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a) - agent_ids = tuple(config.endpoint for config in a2a_agents) - current_request = str(agent_response.message) - - if hasattr(agent_response, "a2a_ids") and agent_response.a2a_ids: - agent_id = agent_response.a2a_ids[0] - else: - agent_id = agent_ids[0] if agent_ids else "" - - if agent_id and agent_id not in agent_ids: - raise ValueError( - f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}" - ) - - agent_config = next(filter(lambda x: x.endpoint == agent_id, a2a_agents)) - task_config = task.config or {} - context_id = task_config.get("context_id") - task_id_config = task_config.get("task_id") - metadata = task_config.get("metadata") - extensions = task_config.get("extensions") - - reference_task_ids = task_config.get("reference_task_ids", []) - - if original_task_description is None: - original_task_description = task.description + ( + a2a_agents, + agent_response_model, + current_request, + agent_id, + agent_config, + context_id, + task_id_config, + metadata, + extensions, + reference_task_ids, + original_task_description, + max_turns, + ) = _prepare_delegation_context( + self, agent_response, task, original_task_description + ) conversation_history: list[Message] = [] - max_turns = agent_config.max_turns try: for turn_num in range(max_turns): @@ -589,33 +783,18 @@ def _delegate_to_a2a( context_id = latest_message.context_id if a2a_result["status"] in [TaskState.completed, TaskState.input_required]: - if a2a_result["status"] == TaskState.completed: - if ( - task_id_config is not None - and task_id_config not in reference_task_ids - ): - reference_task_ids.append(task_id_config) - if task.config is None: - task.config = {} - task.config["reference_task_ids"] = reference_task_ids - task_id_config = None - - if ( - a2a_result["status"] == TaskState.completed - and agent_config.trust_remote_completion_status - ): - result_text = a2a_result.get("result", "") - final_turn_number = turn_num + 1 - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="completed", - final_result=result_text, - error=None, - total_turns=final_turn_number, - ), + trusted_result, task_id_config, reference_task_ids = ( + _handle_task_completion( + a2a_result, + task, + task_id_config, + reference_task_ids, + agent_config, + turn_num, ) - return str(result_text) + ) + if trusted_result is not None: + return trusted_result final_result, next_request = _handle_agent_response_and_continue( self=self, @@ -681,36 +860,321 @@ def _delegate_to_a2a( ) return f"A2A delegation failed: {error_msg}" - if conversation_history: - for msg in reversed(conversation_history): - if msg.role == Role.agent: - text_parts = [ - part.root.text for part in msg.parts if part.root.kind == "text" - ] - final_message = ( - " ".join(text_parts) if text_parts else "Conversation completed" - ) - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="completed", - final_result=final_message, - error=None, - total_turns=max_turns, - ), - ) - return final_message - - crewai_event_bus.emit( - None, - A2AConversationCompletedEvent( - status="failed", - final_result=None, - error=f"Conversation exceeded maximum turns ({max_turns})", - total_turns=max_turns, - ), - ) - raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})") + return _handle_max_turns_exceeded(conversation_history, max_turns) + + finally: + task.description = original_task_description + + +async def _afetch_card_from_config( + config: A2AConfig, +) -> tuple[A2AConfig, AgentCard | Exception]: + """Fetch agent card from A2A config asynchronously.""" + try: + card = await afetch_agent_card( + endpoint=config.endpoint, + auth=config.auth, + timeout=config.timeout, + ) + return config, card + except Exception as e: + return config, e + + +async def _afetch_agent_cards_concurrently( + a2a_agents: list[A2AConfig], +) -> tuple[dict[str, AgentCard], dict[str, str]]: + """Fetch agent cards concurrently for multiple A2A agents using asyncio.""" + agent_cards: dict[str, AgentCard] = {} + failed_agents: dict[str, str] = {} + + tasks = [_afetch_card_from_config(config) for config in a2a_agents] + results = await asyncio.gather(*tasks) + + for config, result in results: + if isinstance(result, Exception): + if config.fail_fast: + raise RuntimeError( + f"Failed to fetch agent card from {config.endpoint}. " + f"Ensure the A2A agent is running and accessible. Error: {result}" + ) from result + failed_agents[config.endpoint] = str(result) + else: + agent_cards[config.endpoint] = result + + return agent_cards, failed_agents + + +async def _aexecute_task_with_a2a( + self: Agent, + a2a_agents: list[A2AConfig], + original_fn: Callable[..., Coroutine[Any, Any, str]], + task: Task, + agent_response_model: type[BaseModel], + context: str | None, + tools: list[BaseTool] | None, + extension_registry: ExtensionRegistry, +) -> str: + """Async version of _execute_task_with_a2a.""" + original_description: str = task.description + original_output_pydantic = task.output_pydantic + original_response_model = task.response_model + + agent_cards, failed_agents = await _afetch_agent_cards_concurrently(a2a_agents) + + if not agent_cards and a2a_agents and failed_agents: + unavailable_agents_text = "" + for endpoint, error in failed_agents.items(): + unavailable_agents_text += f" - {endpoint}: {error}\n" + + notice = UNAVAILABLE_AGENTS_NOTICE_TEMPLATE.substitute( + unavailable_agents=unavailable_agents_text + ) + task.description = f"{original_description}{notice}" + + try: + return await original_fn(self, task, context, tools) + finally: + task.description = original_description + + task.description, _ = _augment_prompt_with_a2a( + a2a_agents=a2a_agents, + task_description=original_description, + agent_cards=agent_cards, + failed_agents=failed_agents, + extension_registry=extension_registry, + ) + task.response_model = agent_response_model + + try: + raw_result = await original_fn(self, task, context, tools) + agent_response = _parse_agent_response( + raw_result=raw_result, agent_response_model=agent_response_model + ) + + if extension_registry and isinstance(agent_response, BaseModel): + agent_response = extension_registry.process_response_with_all( + agent_response, {} + ) + + if isinstance(agent_response, BaseModel) and isinstance( + agent_response, AgentResponseProtocol + ): + if agent_response.is_a2a: + return await _adelegate_to_a2a( + self, + agent_response=agent_response, + task=task, + original_fn=original_fn, + context=context, + tools=tools, + agent_cards=agent_cards, + original_task_description=original_description, + extension_registry=extension_registry, + ) + return str(agent_response.message) + + return raw_result + finally: + task.description = original_description + task.output_pydantic = original_output_pydantic + task.response_model = original_response_model + + +async def _ahandle_agent_response_and_continue( + self: Agent, + a2a_result: TaskStateResult, + agent_id: str, + agent_cards: dict[str, AgentCard] | None, + a2a_agents: list[A2AConfig], + original_task_description: str, + conversation_history: list[Message], + turn_num: int, + max_turns: int, + task: Task, + original_fn: Callable[..., Coroutine[Any, Any, str]], + context: str | None, + tools: list[BaseTool] | None, + agent_response_model: type[BaseModel], + remote_task_completed: bool = False, +) -> tuple[str | None, str | None]: + """Async version of _handle_agent_response_and_continue.""" + agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards) + + task.description, disable_structured_output = _augment_prompt_with_a2a( + a2a_agents=a2a_agents, + task_description=original_task_description, + conversation_history=conversation_history, + turn_num=turn_num, + max_turns=max_turns, + agent_cards=agent_cards_dict, + remote_task_completed=remote_task_completed, + ) + + original_response_model = task.response_model + if disable_structured_output: + task.response_model = None + + raw_result = await original_fn(self, task, context, tools) + + if disable_structured_output: + task.response_model = original_response_model + + return _process_response_result( + raw_result=raw_result, + disable_structured_output=disable_structured_output, + turn_num=turn_num, + agent_role=self.role, + agent_response_model=agent_response_model, + ) + + +async def _adelegate_to_a2a( + self: Agent, + agent_response: AgentResponseProtocol, + task: Task, + original_fn: Callable[..., Coroutine[Any, Any, str]], + context: str | None, + tools: list[BaseTool] | None, + agent_cards: dict[str, AgentCard] | None = None, + original_task_description: str | None = None, + extension_registry: ExtensionRegistry | None = None, +) -> str: + """Async version of _delegate_to_a2a.""" + ( + a2a_agents, + agent_response_model, + current_request, + agent_id, + agent_config, + context_id, + task_id_config, + metadata, + extensions, + reference_task_ids, + original_task_description, + max_turns, + ) = _prepare_delegation_context( + self, agent_response, task, original_task_description + ) + + conversation_history: list[Message] = [] + + try: + for turn_num in range(max_turns): + console_formatter = getattr(crewai_event_bus, "_console", None) + agent_branch = None + if console_formatter: + agent_branch = getattr( + console_formatter, "current_agent_branch", None + ) or getattr(console_formatter, "current_task_branch", None) + + a2a_result = await aexecute_a2a_delegation( + endpoint=agent_config.endpoint, + auth=agent_config.auth, + timeout=agent_config.timeout, + task_description=current_request, + context_id=context_id, + task_id=task_id_config, + reference_task_ids=reference_task_ids, + metadata=metadata, + extensions=extensions, + conversation_history=conversation_history, + agent_id=agent_id, + agent_role=Role.user, + agent_branch=agent_branch, + response_model=agent_config.response_model, + turn_number=turn_num + 1, + updates=agent_config.updates, + ) + + conversation_history = a2a_result.get("history", []) + + if conversation_history: + latest_message = conversation_history[-1] + if latest_message.task_id is not None: + task_id_config = latest_message.task_id + if latest_message.context_id is not None: + context_id = latest_message.context_id + + if a2a_result["status"] in [TaskState.completed, TaskState.input_required]: + trusted_result, task_id_config, reference_task_ids = ( + _handle_task_completion( + a2a_result, + task, + task_id_config, + reference_task_ids, + agent_config, + turn_num, + ) + ) + if trusted_result is not None: + return trusted_result + + final_result, next_request = await _ahandle_agent_response_and_continue( + self=self, + a2a_result=a2a_result, + agent_id=agent_id, + agent_cards=agent_cards, + a2a_agents=a2a_agents, + original_task_description=original_task_description, + conversation_history=conversation_history, + turn_num=turn_num, + max_turns=max_turns, + task=task, + original_fn=original_fn, + context=context, + tools=tools, + agent_response_model=agent_response_model, + remote_task_completed=(a2a_result["status"] == TaskState.completed), + ) + + if final_result is not None: + return final_result + + if next_request is not None: + current_request = next_request + + continue + + error_msg = a2a_result.get("error", "Unknown error") + + final_result, next_request = await _ahandle_agent_response_and_continue( + self=self, + a2a_result=a2a_result, + agent_id=agent_id, + agent_cards=agent_cards, + a2a_agents=a2a_agents, + original_task_description=original_task_description, + conversation_history=conversation_history, + turn_num=turn_num, + max_turns=max_turns, + task=task, + original_fn=original_fn, + context=context, + tools=tools, + agent_response_model=agent_response_model, + ) + + if final_result is not None: + return final_result + + if next_request is not None: + current_request = next_request + continue + + crewai_event_bus.emit( + None, + A2AConversationCompletedEvent( + status="failed", + final_result=None, + error=error_msg, + total_turns=turn_num + 1, + ), + ) + return f"A2A delegation failed: {error_msg}" + + return _handle_max_turns_exceeded(conversation_history, max_turns) finally: task.description = original_task_description