mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-29 18:18:13 +00:00
feat: a2a extensions API and async agent card caching; fix task propagation & streaming
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Adds initial extensions API (with registry temporarily no-op), introduces aiocache for async caching, ensures reference task IDs propagate correctly, fixes streamed response model handling, updates streaming tests, and regenerates lockfiles.
@@ -95,6 +95,7 @@ a2a = [
     "a2a-sdk~=0.3.10",
     "httpx-auth~=0.23.1",
     "httpx-sse~=0.4.0",
+    "aiocache[redis,memcached]~=0.12.3",
 ]

4 lib/crewai/src/crewai/a2a/extensions/__init__.py Normal file
@@ -0,0 +1,4 @@
+"""A2A Protocol Extensions for CrewAI.
+
+This module contains extensions to the A2A (Agent-to-Agent) protocol.
+"""

193 lib/crewai/src/crewai/a2a/extensions/base.py Normal file
@@ -0,0 +1,193 @@
+"""Base extension interface for A2A wrapper integrations.
+
+This module defines the protocol for extending A2A wrapper functionality
+with custom logic for conversation processing, prompt augmentation, and
+agent response handling.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Protocol
+
+
+if TYPE_CHECKING:
+    from a2a.types import Message
+
+    from crewai.agent.core import Agent
+
+
+class ConversationState(Protocol):
+    """Protocol for extension-specific conversation state.
+
+    Extensions can define their own state classes that implement this protocol
+    to track conversation-specific data extracted from message history.
+    """
+
+    def is_ready(self) -> bool:
+        """Check if the state indicates readiness for some action.
+
+        Returns:
+            True if the state is ready, False otherwise.
+        """
+        ...
+
+
+class A2AExtension(Protocol):
+    """Protocol for A2A wrapper extensions.
+
+    Extensions can implement this protocol to inject custom logic into
+    the A2A conversation flow at various integration points.
+    """
+
+    def inject_tools(self, agent: Agent) -> None:
+        """Inject extension-specific tools into the agent.
+
+        Called when an agent is wrapped with A2A capabilities. Extensions
+        can add tools that enable extension-specific functionality.
+
+        Args:
+            agent: The agent instance to inject tools into.
+        """
+        ...
+
+    def extract_state_from_history(
+        self, conversation_history: Sequence[Message]
+    ) -> ConversationState | None:
+        """Extract extension-specific state from conversation history.
+
+        Called during prompt augmentation to allow extensions to analyze
+        the conversation history and extract relevant state information.
+
+        Args:
+            conversation_history: The sequence of A2A messages exchanged.
+
+        Returns:
+            Extension-specific conversation state, or None if no relevant state.
+        """
+        ...
+
+    def augment_prompt(
+        self,
+        base_prompt: str,
+        conversation_state: ConversationState | None,
+    ) -> str:
+        """Augment the task prompt with extension-specific instructions.
+
+        Called during prompt augmentation to allow extensions to add
+        custom instructions based on conversation state.
+
+        Args:
+            base_prompt: The base prompt to augment.
+            conversation_state: Extension-specific state from extract_state_from_history.
+
+        Returns:
+            The augmented prompt with extension-specific instructions.
+        """
+        ...
+
+    def process_response(
+        self,
+        agent_response: Any,
+        conversation_state: ConversationState | None,
+    ) -> Any:
+        """Process and potentially modify the agent response.
+
+        Called after parsing the agent's response, allowing extensions to
+        enhance or modify the response based on conversation state.
+
+        Args:
+            agent_response: The parsed agent response.
+            conversation_state: Extension-specific state from extract_state_from_history.
+
+        Returns:
+            The processed agent response (may be modified or original).
+        """
+        ...
+
+
+class ExtensionRegistry:
+    """Registry for managing A2A extensions.
+
+    Maintains a collection of extensions and provides methods to invoke
+    their hooks at various integration points.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the extension registry."""
+        self._extensions: list[A2AExtension] = []
+
+    def register(self, extension: A2AExtension) -> None:
+        """Register an extension.
+
+        Args:
+            extension: The extension to register.
+        """
+        self._extensions.append(extension)
+
+    def inject_all_tools(self, agent: Agent) -> None:
+        """Inject tools from all registered extensions.
+
+        Args:
+            agent: The agent instance to inject tools into.
+        """
+        for extension in self._extensions:
+            extension.inject_tools(agent)
+
+    def extract_all_states(
+        self, conversation_history: Sequence[Message]
+    ) -> dict[type[A2AExtension], ConversationState]:
+        """Extract conversation states from all registered extensions.
+
+        Args:
+            conversation_history: The sequence of A2A messages exchanged.
+
+        Returns:
+            Mapping of extension types to their conversation states.
+        """
+        states: dict[type[A2AExtension], ConversationState] = {}
+        for extension in self._extensions:
+            state = extension.extract_state_from_history(conversation_history)
+            if state is not None:
+                states[type(extension)] = state
+        return states
+
+    def augment_prompt_with_all(
+        self,
+        base_prompt: str,
+        extension_states: dict[type[A2AExtension], ConversationState],
+    ) -> str:
+        """Augment prompt with instructions from all registered extensions.
+
+        Args:
+            base_prompt: The base prompt to augment.
+            extension_states: Mapping of extension types to conversation states.
+
+        Returns:
+            The fully augmented prompt.
+        """
+        augmented = base_prompt
+        for extension in self._extensions:
+            state = extension_states.get(type(extension))
+            augmented = extension.augment_prompt(augmented, state)
+        return augmented
+
+    def process_response_with_all(
+        self,
+        agent_response: Any,
+        extension_states: dict[type[A2AExtension], ConversationState],
+    ) -> Any:
+        """Process response through all registered extensions.
+
+        Args:
+            agent_response: The parsed agent response.
+            extension_states: Mapping of extension types to conversation states.
+
+        Returns:
+            The processed agent response.
+        """
+        processed = agent_response
+        for extension in self._extensions:
+            state = extension_states.get(type(extension))
+            processed = extension.process_response(processed, state)
+        return processed
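Since both ConversationState and A2AExtension are structural Protocols, a concrete extension never has to inherit from them. A rough sketch of how a downstream extension could plug into this registry — the ApprovalExtension name, its state class, and the "APPROVED" marker are invented for illustration and are not part of this commit:

from dataclasses import dataclass

from crewai.a2a.extensions.base import ExtensionRegistry


@dataclass
class ApprovalState:
    approved: bool = False

    def is_ready(self) -> bool:
        # Satisfies ConversationState; "ready" here means approval was seen.
        return self.approved


class ApprovalExtension:
    """Hypothetical extension that watches the history for an approval marker."""

    def inject_tools(self, agent) -> None:
        # This sketch adds no tools; a real extension could append to agent.tools.
        pass

    def extract_state_from_history(self, conversation_history):
        # Naive scan; real code would inspect each Message's parts.
        for message in conversation_history:
            if "APPROVED" in str(message):
                return ApprovalState(approved=True)
        return None

    def augment_prompt(self, base_prompt, conversation_state):
        if conversation_state is not None and conversation_state.is_ready():
            return base_prompt + "\nThe remote agent approved; produce the final answer."
        return base_prompt

    def process_response(self, agent_response, conversation_state):
        # Pass the response through unchanged.
        return agent_response


registry = ExtensionRegistry()
registry.register(ApprovalExtension())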
34 lib/crewai/src/crewai/a2a/extensions/registry.py Normal file
@@ -0,0 +1,34 @@
+"""Extension registry factory for A2A configurations.
+
+This module provides utilities for creating extension registries from A2A configurations.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from crewai.a2a.extensions.base import ExtensionRegistry
+
+
+if TYPE_CHECKING:
+    from crewai.a2a.config import A2AConfig
+
+
+def create_extension_registry_from_config(
+    a2a_config: list[A2AConfig] | A2AConfig,
+) -> ExtensionRegistry:
+    """Create an extension registry from A2A configuration.
+
+    Args:
+        a2a_config: A2A configuration (single or list)
+
+    Returns:
+        Configured extension registry with all applicable extensions
+    """
+    registry = ExtensionRegistry()
+    configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
+
+    for _ in configs:
+        pass
+
+    return registry
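As the commit message notes, this factory is temporarily a no-op: the loop walks the configs without registering anything, so every wrapped agent currently receives an empty registry. A minimal sketch of the call site — the endpoint value is illustrative, and A2AConfig's constructor argument is assumed from the `config.endpoint` usage elsewhere in this diff:

from crewai.a2a.config import A2AConfig
from crewai.a2a.extensions.registry import create_extension_registry_from_config

config = A2AConfig(endpoint="https://example.com/a2a")  # illustrative endpoint
registry = create_extension_registry_from_config(config)
# No extensions are registered yet, so all registry hooks are pass-throughs.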
@@ -23,6 +23,8 @@ from a2a.types import (
     TextPart,
     TransportProtocol,
 )
+from aiocache import cached  # type: ignore[import-untyped]
+from aiocache.serializers import PickleSerializer  # type: ignore[import-untyped]
 import httpx
 from pydantic import BaseModel, Field, create_model

@@ -65,7 +67,7 @@ def _fetch_agent_card_cached(
         endpoint: A2A agent endpoint URL
         auth_hash: Hash of the auth object
         timeout: Request timeout
-        _ttl_hash: Time-based hash for cache invalidation (unused in body)
+        _ttl_hash: Time-based hash for cache invalidation

     Returns:
         Cached AgentCard

@@ -106,7 +108,18 @@ def fetch_agent_card(
         A2AClientHTTPError: If authentication fails
     """
    if use_cache:
-        auth_hash = hash((type(auth).__name__, id(auth))) if auth else 0
+        if auth:
+            auth_data = auth.model_dump_json(
+                exclude={
+                    "_access_token",
+                    "_token_expires_at",
+                    "_refresh_token",
+                    "_authorization_callback",
+                }
+            )
+            auth_hash = hash((type(auth).__name__, auth_data))
+        else:
+            auth_hash = 0
         _auth_store[auth_hash] = auth
         ttl_hash = int(time.time() // cache_ttl)
         return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
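The `ttl_hash` argument is the usual trick for bolting a TTL onto an argument-keyed cache: `int(time.time() // cache_ttl)` changes value once per TTL window, so it silently varies the cache key and stale entries stop being hit when the window rolls over. A self-contained sketch of the pattern — `expensive_fetch` is a hypothetical stand-in for the real network call:

import time
from functools import lru_cache


def expensive_fetch(key: str) -> str:
    # Stand-in for a slow lookup, e.g. an HTTP request.
    return f"value-for-{key}"


@lru_cache(maxsize=128)
def _cached_fetch(key: str, _ttl_hash: int) -> str:
    # _ttl_hash is never read; it exists only to vary the cache key over time.
    return expensive_fetch(key)


def fetch(key: str, ttl_seconds: int = 300) -> str:
    return _cached_fetch(key, int(time.time() // ttl_seconds))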
@@ -121,6 +134,26 @@ def fetch_agent_card(
         loop.close()


+@cached(ttl=300, serializer=PickleSerializer())  # type: ignore[untyped-decorator]
+async def _fetch_agent_card_async_cached(
+    endpoint: str,
+    auth_hash: int,
+    timeout: int,
+) -> AgentCard:
+    """Cached async implementation of AgentCard fetching.
+
+    Args:
+        endpoint: A2A agent endpoint URL
+        auth_hash: Hash of the auth object
+        timeout: Request timeout in seconds
+
+    Returns:
+        Cached AgentCard object
+    """
+    auth = _auth_store.get(auth_hash)
+    return await _fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout)
+
+
 async def _fetch_agent_card_async(
     endpoint: str,
     auth: AuthScheme | None,
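For the async path the commit leans on aiocache rather than hand-rolled TTL hashing: `@cached(ttl=300, serializer=PickleSerializer())` memoizes the coroutine by its arguments (Redis and Memcached backends are pulled in as extras), and pickling lets non-JSON values such as pydantic models round-trip through the cache. A standalone sketch of the same pattern, assuming only that aiocache is installed — the function and its payload are illustrative:

import asyncio

from aiocache import cached
from aiocache.serializers import PickleSerializer


@cached(ttl=300, serializer=PickleSerializer())
async def fetch_profile(user_id: int) -> dict:
    # Simulate a slow remote call; the result is cached per user_id for 300s.
    await asyncio.sleep(0.1)
    return {"id": user_id, "name": f"user-{user_id}"}


async def main() -> None:
    first = await fetch_profile(1)   # slow: does the real work
    second = await fetch_profile(1)  # fast: served from the cache
    assert first == second


asyncio.run(main())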
@@ -339,7 +372,22 @@ async def _execute_a2a_delegation_async(
     Returns:
         Dictionary with status, result/error, and new history
     """
-    agent_card = await _fetch_agent_card_async(endpoint, auth, timeout)
+    if auth:
+        auth_data = auth.model_dump_json(
+            exclude={
+                "_access_token",
+                "_token_expires_at",
+                "_refresh_token",
+                "_authorization_callback",
+            }
+        )
+        auth_hash = hash((type(auth).__name__, auth_data))
+    else:
+        auth_hash = 0
+    _auth_store[auth_hash] = auth
+    agent_card = await _fetch_agent_card_async_cached(
+        endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
+    )

     validate_auth_against_agent_card(agent_card, auth)

@@ -556,6 +604,34 @@ async def _execute_a2a_delegation_async(
                }
                break
        except Exception as e:
+            if isinstance(e, A2AClientHTTPError):
+                error_msg = f"HTTP Error {e.status_code}: {e!s}"
+
+                error_message = Message(
+                    role=Role.agent,
+                    message_id=str(uuid.uuid4()),
+                    parts=[Part(root=TextPart(text=error_msg))],
+                    context_id=context_id,
+                    task_id=task_id,
+                )
+                new_messages.append(error_message)
+
+                crewai_event_bus.emit(
+                    None,
+                    A2AResponseReceivedEvent(
+                        response=error_msg,
+                        turn_number=turn_number,
+                        is_multiturn=is_multiturn,
+                        status="failed",
+                        agent_role=agent_role,
+                    ),
+                )
+                return {
+                    "status": "failed",
+                    "error": error_msg,
+                    "history": new_messages,
+                }
+
            current_exception: Exception | BaseException | None = e
            while current_exception:
                if hasattr(current_exception, "response"):

@@ -752,4 +828,5 @@ def get_a2a_agents_and_response_model(
         Tuple of A2A agent IDs and response model
     """
     a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
+
     return a2a_agents, create_agent_response_model(agent_ids)
@@ -15,6 +15,7 @@ from a2a.types import Role
 from pydantic import BaseModel, ValidationError

 from crewai.a2a.config import A2AConfig
+from crewai.a2a.extensions.base import ExtensionRegistry
 from crewai.a2a.templates import (
     AVAILABLE_AGENTS_TEMPLATE,
     CONVERSATION_TURN_INFO_TEMPLATE,

@@ -42,7 +43,9 @@ if TYPE_CHECKING:
     from crewai.tools.base_tool import BaseTool


-def wrap_agent_with_a2a_instance(agent: Agent) -> None:
+def wrap_agent_with_a2a_instance(
+    agent: Agent, extension_registry: ExtensionRegistry | None = None
+) -> None:
     """Wrap an agent instance's execute_task method with A2A support.

     This function modifies the agent instance by wrapping its execute_task

@@ -51,7 +54,13 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:

     Args:
         agent: The agent instance to wrap
+        extension_registry: Optional registry of A2A extensions for injecting tools and custom logic
     """
+    if extension_registry is None:
+        extension_registry = ExtensionRegistry()
+
+    extension_registry.inject_all_tools(agent)
+
     original_execute_task = agent.execute_task.__func__  # type: ignore[attr-defined]

     @wraps(original_execute_task)

@@ -85,6 +94,7 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:
             agent_response_model=agent_response_model,
             context=context,
             tools=tools,
+            extension_registry=extension_registry,
         )

     object.__setattr__(agent, "execute_task", MethodType(execute_task_with_a2a, agent))

@@ -154,6 +164,7 @@ def _execute_task_with_a2a(
     agent_response_model: type[BaseModel],
     context: str | None,
     tools: list[BaseTool] | None,
+    extension_registry: ExtensionRegistry,
 ) -> str:
     """Wrap execute_task with A2A delegation logic.

@@ -165,6 +176,7 @@ def _execute_task_with_a2a(
         context: Optional context for task execution
         tools: Optional tools available to the agent
         agent_response_model: Optional agent response model
+        extension_registry: Registry of A2A extensions

     Returns:
         Task execution result (either from LLM or A2A agent)

@@ -190,11 +202,12 @@ def _execute_task_with_a2a(
         finally:
             task.description = original_description

-    task.description = _augment_prompt_with_a2a(
+    task.description, _ = _augment_prompt_with_a2a(
         a2a_agents=a2a_agents,
         task_description=original_description,
         agent_cards=agent_cards,
         failed_agents=failed_agents,
+        extension_registry=extension_registry,
     )
     task.response_model = agent_response_model

@@ -204,6 +217,11 @@ def _execute_task_with_a2a(
         raw_result=raw_result, agent_response_model=agent_response_model
     )

+    if extension_registry and isinstance(agent_response, BaseModel):
+        agent_response = extension_registry.process_response_with_all(
+            agent_response, {}
+        )
+
     if isinstance(agent_response, BaseModel) and isinstance(
         agent_response, AgentResponseProtocol
     ):

@@ -217,6 +235,7 @@ def _execute_task_with_a2a(
             tools=tools,
             agent_cards=agent_cards,
             original_task_description=original_description,
+            extension_registry=extension_registry,
         )
     return str(agent_response.message)
@@ -235,7 +254,8 @@ def _augment_prompt_with_a2a(
     turn_num: int = 0,
     max_turns: int | None = None,
     failed_agents: dict[str, str] | None = None,
-) -> str:
+    extension_registry: ExtensionRegistry | None = None,
+) -> tuple[str, bool]:
     """Add A2A delegation instructions to prompt.

     Args:

@@ -246,13 +266,14 @@ def _augment_prompt_with_a2a(
         turn_num: Current turn number (0-indexed)
         max_turns: Maximum allowed turns (from config)
         failed_agents: Dictionary mapping failed agent endpoints to error messages
+        extension_registry: Optional registry of A2A extensions

     Returns:
-        Augmented task description with A2A instructions
+        Tuple of (augmented prompt, disable_structured_output flag)
     """

     if not agent_cards:
-        return task_description
+        return task_description, False

     agents_text = ""

@@ -270,6 +291,7 @@ def _augment_prompt_with_a2a(
     agents_text = AVAILABLE_AGENTS_TEMPLATE.substitute(available_a2a_agents=agents_text)

     history_text = ""
+
     if conversation_history:
         for msg in conversation_history:
             history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n"

@@ -277,6 +299,15 @@ def _augment_prompt_with_a2a(
         history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute(
             previous_a2a_conversation=history_text
         )

+    extension_states = {}
+    disable_structured_output = False
+    if extension_registry and conversation_history:
+        extension_states = extension_registry.extract_all_states(conversation_history)
+        for state in extension_states.values():
+            if state.is_ready():
+                disable_structured_output = True
+                break
+
     turn_info = ""

     if max_turns is not None and conversation_history:

@@ -296,16 +327,22 @@ def _augment_prompt_with_a2a(
             warning=warning,
         )

-    return f"""{task_description}
+    augmented_prompt = f"""{task_description}

 IMPORTANT: You have the ability to delegate this task to remote A2A agents.

 {agents_text}
 {history_text}{turn_info}


 """

+    if extension_registry:
+        augmented_prompt = extension_registry.augment_prompt_with_all(
+            augmented_prompt, extension_states
+        )
+
+    return augmented_prompt, disable_structured_output


 def _parse_agent_response(
     raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
@@ -373,7 +410,7 @@ def _handle_agent_response_and_continue(
     if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
         agent_cards_dict[agent_id] = a2a_result["agent_card"]

-    task.description = _augment_prompt_with_a2a(
+    task.description, disable_structured_output = _augment_prompt_with_a2a(
         a2a_agents=a2a_agents,
         task_description=original_task_description,
         conversation_history=conversation_history,

@@ -382,7 +419,38 @@ def _handle_agent_response_and_continue(
         agent_cards=agent_cards_dict,
     )

+    original_response_model = task.response_model
+    if disable_structured_output:
+        task.response_model = None
+
     raw_result = original_fn(self, task, context, tools)
+
+    if disable_structured_output:
+        task.response_model = original_response_model
+
+    if disable_structured_output:
+        final_turn_number = turn_num + 1
+        result_text = str(raw_result)
+        crewai_event_bus.emit(
+            None,
+            A2AMessageSentEvent(
+                message=result_text,
+                turn_number=final_turn_number,
+                is_multiturn=True,
+                agent_role=self.role,
+            ),
+        )
+        crewai_event_bus.emit(
+            None,
+            A2AConversationCompletedEvent(
+                status="completed",
+                final_result=result_text,
+                error=None,
+                total_turns=final_turn_number,
+            ),
+        )
+        return result_text, None
+
     llm_response = _parse_agent_response(
         raw_result=raw_result, agent_response_model=agent_response_model
     )
@@ -425,6 +493,7 @@ def _delegate_to_a2a(
     tools: list[BaseTool] | None,
     agent_cards: dict[str, AgentCard] | None = None,
     original_task_description: str | None = None,
+    extension_registry: ExtensionRegistry | None = None,
 ) -> str:
     """Delegate to A2A agent with multi-turn conversation support.

@@ -437,6 +506,7 @@ def _delegate_to_a2a(
         tools: Optional tools available to the agent
         agent_cards: Pre-fetched agent cards from _execute_task_with_a2a
         original_task_description: The original task description before A2A augmentation
+        extension_registry: Optional registry of A2A extensions

     Returns:
         Result from A2A agent

@@ -447,9 +517,13 @@ def _delegate_to_a2a(
     a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
     agent_ids = tuple(config.endpoint for config in a2a_agents)
     current_request = str(agent_response.message)
-    agent_id = agent_response.a2a_ids[0]
+    if hasattr(agent_response, "a2a_ids") and agent_response.a2a_ids:
+        agent_id = agent_response.a2a_ids[0]
+    else:
+        agent_id = agent_ids[0] if agent_ids else ""

-    if agent_id not in agent_ids:
+    if agent_id and agent_id not in agent_ids:
         raise ValueError(
             f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}"
         )

@@ -458,10 +532,11 @@ def _delegate_to_a2a(
     task_config = task.config or {}
     context_id = task_config.get("context_id")
     task_id_config = task_config.get("task_id")
-    reference_task_ids = task_config.get("reference_task_ids")
     metadata = task_config.get("metadata")
     extensions = task_config.get("extensions")

+    reference_task_ids = task_config.get("reference_task_ids", [])
+
     if original_task_description is None:
         original_task_description = task.description
@@ -497,11 +572,27 @@ def _delegate_to_a2a(

         conversation_history = a2a_result.get("history", [])

+        if conversation_history:
+            latest_message = conversation_history[-1]
+            if latest_message.task_id is not None:
+                task_id_config = latest_message.task_id
+            if latest_message.context_id is not None:
+                context_id = latest_message.context_id
+
         if a2a_result["status"] in ["completed", "input_required"]:
             if (
                 a2a_result["status"] == "completed"
                 and agent_config.trust_remote_completion_status
             ):
+                if (
+                    task_id_config is not None
+                    and task_id_config not in reference_task_ids
+                ):
+                    reference_task_ids.append(task_id_config)
+                    if task.config is None:
+                        task.config = {}
+                    task.config["reference_task_ids"] = reference_task_ids
+
                 result_text = a2a_result.get("result", "")
                 final_turn_number = turn_num + 1
                 crewai_event_bus.emit(
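This hunk is the task-propagation fix from the commit title: `reference_task_ids` now defaults to a list, the latest message's `task_id`/`context_id` are carried forward, and each completed task ID is appended exactly once and written back into `task.config`, so later delegations can reference earlier A2A tasks. A toy reproduction of just the bookkeeping, with plain dicts standing in for the real task object:

task_config: dict = {}


def record_completed(task_config: dict, task_id: str | None) -> None:
    # Mirrors the diff: dedupe, then persist back onto the task config.
    reference_task_ids = task_config.get("reference_task_ids", [])
    if task_id is not None and task_id not in reference_task_ids:
        reference_task_ids.append(task_id)
    task_config["reference_task_ids"] = reference_task_ids


record_completed(task_config, "task-123")
record_completed(task_config, "task-123")  # appending twice does not duplicate
assert task_config["reference_task_ids"] == ["task-123"]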
@@ -513,7 +604,7 @@ def _delegate_to_a2a(
                         total_turns=final_turn_number,
                     ),
                 )
-                return result_text  # type: ignore[no-any-return]
+                return cast(str, result_text)

             final_result, next_request = _handle_agent_response_and_continue(
                 self=self,

@@ -541,6 +632,31 @@ def _delegate_to_a2a(
                 continue

         error_msg = a2a_result.get("error", "Unknown error")
+
+        final_result, next_request = _handle_agent_response_and_continue(
+            self=self,
+            a2a_result=a2a_result,
+            agent_id=agent_id,
+            agent_cards=agent_cards,
+            a2a_agents=a2a_agents,
+            original_task_description=original_task_description,
+            conversation_history=conversation_history,
+            turn_num=turn_num,
+            max_turns=max_turns,
+            task=task,
+            original_fn=original_fn,
+            context=context,
+            tools=tools,
+            agent_response_model=agent_response_model,
+        )
+
+        if final_result is not None:
+            return final_result
+
+        if next_request is not None:
+            current_request = next_request
+            continue
+
         crewai_event_bus.emit(
             None,
             A2AConversationCompletedEvent(

@@ -550,7 +666,7 @@ def _delegate_to_a2a(
                 total_turns=turn_num + 1,
             ),
         )
-        raise Exception(f"A2A delegation failed: {error_msg}")
+        return f"A2A delegation failed: {error_msg}"

     if conversation_history:
         for msg in reversed(conversation_history):
@@ -4,9 +4,8 @@ This metaclass enables extension capabilities for agents by detecting
 extension fields in class annotations and applying appropriate wrappers.
 """

-import warnings
-
 from functools import wraps
 from typing import Any
+import warnings

 from pydantic import model_validator
 from pydantic._internal._model_construction import ModelMetaclass

@@ -59,9 +58,15 @@ class AgentMeta(ModelMetaclass):

         a2a_value = getattr(self, "a2a", None)
         if a2a_value is not None:
+            from crewai.a2a.extensions.registry import (
+                create_extension_registry_from_config,
+            )
             from crewai.a2a.wrapper import wrap_agent_with_a2a_instance

-            wrap_agent_with_a2a_instance(self)
+            extension_registry = create_extension_registry_from_config(
+                a2a_value
+            )
+            wrap_agent_with_a2a_instance(self, extension_registry)

         return result
@@ -1,13 +1,14 @@
 from __future__ import annotations

-from collections.abc import AsyncIterator, Iterator
+from collections.abc import AsyncIterator
 import json
 import logging
 import os
 from typing import TYPE_CHECKING, Any

 import httpx
-from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI
+from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI, Stream
+from openai.lib.streaming.chat import ChatCompletionStream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
@@ -515,59 +516,52 @@ class OpenAICompletion(BaseLLM):
        tool_calls = {}

        if response_model:
-            completion_stream: Iterator[ChatCompletionChunk] = (
-                self.client.chat.completions.create(**params)
-            )
+            parse_params = {
+                k: v
+                for k, v in params.items()
+                if k not in ("response_format", "stream")
+            }

-            accumulated_content = ""
-            for chunk in completion_stream:
-                if not chunk.choices:
-                    continue
+            stream: ChatCompletionStream[BaseModel]
+            with self.client.beta.chat.completions.stream(
+                **parse_params, response_format=response_model
+            ) as stream:
+                for chunk in stream:
+                    if chunk.type == "content.delta":
+                        delta_content = chunk.delta
+                        if delta_content:
+                            self._emit_stream_chunk_event(
+                                chunk=delta_content,
+                                from_task=from_task,
+                                from_agent=from_agent,
+                            )

-                choice = chunk.choices[0]
-                delta: ChoiceDelta = choice.delta
+                final_completion = stream.get_final_completion()
+                if final_completion and final_completion.choices:
+                    parsed_result = final_completion.choices[0].message.parsed
+                    if parsed_result:
+                        structured_json = parsed_result.model_dump_json()
+                        self._emit_call_completed_event(
+                            response=structured_json,
+                            call_type=LLMCallType.LLM_CALL,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            messages=params["messages"],
+                        )
+                        return structured_json

-                if delta.content:
-                    accumulated_content += delta.content
-                    self._emit_stream_chunk_event(
-                        chunk=delta.content,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                    )
+                logging.error("Failed to get parsed result from stream")
+                return ""

-            try:
-                parsed_object = response_model.model_validate_json(accumulated_content)
-                structured_json = parsed_object.model_dump_json()
-
-                self._emit_call_completed_event(
-                    response=structured_json,
-                    call_type=LLMCallType.LLM_CALL,
-                    from_task=from_task,
-                    from_agent=from_agent,
-                    messages=params["messages"],
-                )
-
-                return structured_json
-            except Exception as e:
-                logging.error(f"Failed to parse structured output from stream: {e}")
-                self._emit_call_completed_event(
-                    response=accumulated_content,
-                    call_type=LLMCallType.LLM_CALL,
-                    from_task=from_task,
-                    from_agent=from_agent,
-                    messages=params["messages"],
-                )
-                return accumulated_content
-
-        stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
-            **params
+        completion_stream: Stream[ChatCompletionChunk] = (
+            self.client.chat.completions.create(**params)
         )

-        for chunk in stream:
-            if not chunk.choices:
+        for completion_chunk in completion_stream:
+            if not completion_chunk.choices:
                continue

-            choice = chunk.choices[0]
+            choice = completion_chunk.choices[0]
            chunk_delta: ChoiceDelta = choice.delta

            if chunk_delta.content:
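The streaming fix swaps manual accumulation of ChatCompletionChunk deltas for the openai-python streaming helper that the diff imports: `client.beta.chat.completions.stream(...)` accepts a pydantic `response_format`, yields typed events such as `content.delta`, and `get_final_completion()` returns a completion whose message carries the already-validated `.parsed` model, so the JSON no longer has to be re-validated by hand. A reduced sketch of the same call pattern — model name, prompt, and the Answer schema are illustrative:

from openai import OpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    answer: str
    confidence: float


client = OpenAI()

with client.beta.chat.completions.stream(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Answer briefly: is water wet?"}],
    response_format=Answer,
) as stream:
    for event in stream:
        if event.type == "content.delta":
            print(event.delta, end="", flush=True)  # incremental text
    # Called inside the context manager, as in the diff.
    final = stream.get_final_completion()

parsed = final.choices[0].message.parsed  # an Answer instance, or None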
@@ -505,30 +505,43 @@ def test_openai_streaming_with_response_model():

    llm = LLM(model="openai/gpt-4o", stream=True)

-    with patch.object(llm.client.chat.completions, "create") as mock_create:
+    with patch.object(llm.client.beta.chat.completions, "stream") as mock_stream:
+        # Create mock chunks with content.delta event structure
        mock_chunk1 = MagicMock()
-        mock_chunk1.choices = [
-            MagicMock(delta=MagicMock(content='{"answer": "test", ', tool_calls=None))
-        ]
+        mock_chunk1.type = "content.delta"
+        mock_chunk1.delta = '{"answer": "test", '

        mock_chunk2 = MagicMock()
-        mock_chunk2.choices = [
-            MagicMock(
-                delta=MagicMock(content='"confidence": 0.95}', tool_calls=None)
-            )
-        ]
+        mock_chunk2.type = "content.delta"
+        mock_chunk2.delta = '"confidence": 0.95}'

-        mock_create.return_value = iter([mock_chunk1, mock_chunk2])
+        # Create mock final completion with parsed result
+        mock_parsed = TestResponse(answer="test", confidence=0.95)
+        mock_message = MagicMock()
+        mock_message.parsed = mock_parsed
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_final_completion = MagicMock()
+        mock_final_completion.choices = [mock_choice]
+
+        # Create mock stream context manager
+        mock_stream_obj = MagicMock()
+        mock_stream_obj.__enter__ = MagicMock(return_value=mock_stream_obj)
+        mock_stream_obj.__exit__ = MagicMock(return_value=None)
+        mock_stream_obj.__iter__ = MagicMock(return_value=iter([mock_chunk1, mock_chunk2]))
+        mock_stream_obj.get_final_completion = MagicMock(return_value=mock_final_completion)
+
+        mock_stream.return_value = mock_stream_obj

        result = llm.call("Test question", response_model=TestResponse)

        assert result is not None
        assert isinstance(result, str)

-        assert mock_create.called
-        call_kwargs = mock_create.call_args[1]
+        assert mock_stream.called
+        call_kwargs = mock_stream.call_args[1]
        assert call_kwargs["model"] == "gpt-4o"
-        assert call_kwargs["stream"] is True
        assert call_kwargs["response_format"] == TestResponse

        assert "input" not in call_kwargs
        assert "text_format" not in call_kwargs