feat: inject server methods

This commit is contained in:
Greyson LaLonde
2026-01-14 05:46:35 -05:00
parent f806cf5bfb
commit 68df061c20
6 changed files with 87 additions and 41 deletions

View File

@@ -36,7 +36,10 @@ def _get_default_update_config() -> UpdateConfig:
@deprecated(
"Use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead.",
"""
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead.
""",
category=FutureWarning,
)
class A2AConfig(BaseModel):

View File

@@ -1,6 +1,17 @@
"""Type definitions for A2A protocol message parts."""
from typing import Annotated, Any, Literal, Protocol, TypedDict, runtime_checkable
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Literal,
Protocol,
TypeAlias,
TypedDict,
runtime_checkable,
)
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired
@@ -16,6 +27,10 @@ from crewai.a2a.updates import (
)
if TYPE_CHECKING:
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
@@ -72,3 +87,6 @@ HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
StreamingConfig: StreamingHandler,
PushNotificationConfig: PushNotificationHandler,
}
A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig

View File

@@ -4,7 +4,8 @@ from __future__ import annotations
from pydantic import BaseModel, Field, create_model
from crewai.a2a.config import A2AConfig
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.a2a.types import A2AClientConfigTypes, A2AConfigTypes
from crewai.types.utils import create_literals_from_strings
@@ -46,36 +47,45 @@ def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel]:
def extract_a2a_agent_ids_from_config(
a2a_config: list[A2AConfig] | A2AConfig | None,
) -> tuple[list[A2AConfig], tuple[str, ...]]:
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]:
"""Extract A2A agent IDs from A2A configuration.
Filters out A2AServerConfig since it doesn't have an endpoint for delegation.
Args:
a2a_config: A2A configuration.
a2a_config: A2A configuration (any type).
Returns:
Tuple of A2A configs list and agent endpoint IDs.
Tuple of client A2A configs list and agent endpoint IDs.
"""
if a2a_config is None:
return [], ()
if isinstance(a2a_config, A2AConfig):
a2a_agents = [a2a_config]
configs: list[A2AConfigTypes]
if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)):
configs = [a2a_config]
else:
a2a_agents = a2a_config
return a2a_agents, tuple(config.endpoint for config in a2a_agents)
configs = a2a_config
# Filter to only client configs (those with endpoint)
client_configs: list[A2AClientConfigTypes] = [
config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
]
return client_configs, tuple(config.endpoint for config in client_configs)
def get_a2a_agents_and_response_model(
a2a_config: list[A2AConfig] | A2AConfig | None,
) -> tuple[list[A2AConfig], type[BaseModel]]:
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], type[BaseModel]]:
"""Get A2A agent configs and response model.
Args:
a2a_config: A2A configuration.
a2a_config: A2A configuration (any type).
Returns:
Tuple of A2A configs and response model.
Tuple of client A2A configs and response model.
"""
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)

View File

@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any
from a2a.types import Role, TaskState
from pydantic import BaseModel, ValidationError
from crewai.a2a.config import A2AConfig
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.templates import (
@@ -26,7 +26,11 @@ from crewai.a2a.templates import (
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE,
)
from crewai.a2a.types import AgentResponseProtocol
from crewai.a2a.utils.agent_card import afetch_agent_card, fetch_agent_card
from crewai.a2a.utils.agent_card import (
afetch_agent_card,
fetch_agent_card,
inject_a2a_server_methods,
)
from crewai.a2a.utils.delegation import (
aexecute_a2a_delegation,
execute_a2a_delegation,
@@ -121,10 +125,12 @@ def wrap_agent_with_a2a_instance(
agent, "aexecute_task", MethodType(aexecute_task_with_a2a, agent)
)
inject_a2a_server_methods(agent)
def _fetch_card_from_config(
config: A2AConfig,
) -> tuple[A2AConfig, AgentCard | Exception]:
config: A2AConfig | A2AClientConfig,
) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]:
"""Fetch agent card from A2A config.
Args:
@@ -145,7 +151,7 @@ def _fetch_card_from_config(
def _fetch_agent_cards_concurrently(
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
) -> tuple[dict[str, AgentCard], dict[str, str]]:
"""Fetch agent cards concurrently for multiple A2A agents.
@@ -180,7 +186,7 @@ def _fetch_agent_cards_concurrently(
def _execute_task_with_a2a(
self: Agent,
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
original_fn: Callable[..., str],
task: Task,
agent_response_model: type[BaseModel],
@@ -269,7 +275,7 @@ def _execute_task_with_a2a(
def _augment_prompt_with_a2a(
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
task_description: str,
agent_cards: dict[str, AgentCard],
conversation_history: list[Message] | None = None,
@@ -522,11 +528,11 @@ def _prepare_delegation_context(
task: Task,
original_task_description: str | None,
) -> tuple[
list[A2AConfig],
list[A2AConfig | A2AClientConfig],
type[BaseModel],
str,
str,
A2AConfig,
A2AConfig | A2AClientConfig,
str | None,
str | None,
dict[str, Any] | None,
@@ -590,7 +596,7 @@ def _handle_task_completion(
task: Task,
task_id_config: str | None,
reference_task_ids: list[str],
agent_config: A2AConfig,
agent_config: A2AConfig | A2AClientConfig,
turn_num: int,
) -> tuple[str | None, str | None, list[str]]:
"""Handle task completion state including reference task updates.
@@ -630,7 +636,7 @@ def _handle_agent_response_and_continue(
a2a_result: TaskStateResult,
agent_id: str,
agent_cards: dict[str, AgentCard] | None,
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
original_task_description: str,
conversation_history: list[Message],
turn_num: int,
@@ -867,8 +873,8 @@ def _delegate_to_a2a(
async def _afetch_card_from_config(
config: A2AConfig,
) -> tuple[A2AConfig, AgentCard | Exception]:
config: A2AConfig | A2AClientConfig,
) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]:
"""Fetch agent card from A2A config asynchronously."""
try:
card = await afetch_agent_card(
@@ -882,7 +888,7 @@ async def _afetch_card_from_config(
async def _afetch_agent_cards_concurrently(
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
) -> tuple[dict[str, AgentCard], dict[str, str]]:
"""Fetch agent cards concurrently for multiple A2A agents using asyncio."""
agent_cards: dict[str, AgentCard] = {}
@@ -907,7 +913,7 @@ async def _afetch_agent_cards_concurrently(
async def _aexecute_task_with_a2a(
self: Agent,
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
original_fn: Callable[..., Coroutine[Any, Any, str]],
task: Task,
agent_response_model: type[BaseModel],
@@ -986,7 +992,7 @@ async def _ahandle_agent_response_and_continue(
a2a_result: TaskStateResult,
agent_id: str,
agent_cards: dict[str, AgentCard] | None,
a2a_agents: list[A2AConfig],
a2a_agents: list[A2AConfig | A2AClientConfig],
original_task_description: str,
conversation_history: list[Message],
turn_num: int,

View File

@@ -17,7 +17,7 @@ from urllib.parse import urlparse
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.a2a.config import A2AConfig
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.agent.utils import (
ahandle_knowledge_retrieval,
apply_training_data,
@@ -73,7 +73,7 @@ from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_F
from crewai.utilities.converter import Converter
from crewai.utilities.guardrail_types import GuardrailType
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.prompts import Prompts
from crewai.utilities.prompts import Prompts, StandardPromptResult, SystemPromptResult
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -218,9 +218,18 @@ class Agent(BaseAgent):
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
a2a: list[A2AConfig] | A2AConfig | None = Field(
a2a: (
list[A2AConfig | A2AServerConfig | A2AClientConfig]
| A2AConfig
| A2AServerConfig
| A2AClientConfig
| None
) = Field(
default=None,
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. Can be a single A2AConfig or a dict mapping agent IDs to configs.",
description="""
A2A (Agent-to-Agent) configuration for delegating tasks to remote agents.
Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
""",
)
executor_class: type[CrewAgentExecutor] | type[CrewAgentExecutorFlow] = Field(
default=CrewAgentExecutor,
@@ -733,7 +742,7 @@ class Agent(BaseAgent):
if self.agent_executor is not None:
self._update_executor_parameters(
task=task,
tools=parsed_tools,
tools=parsed_tools, # type: ignore[arg-type]
raw_tools=raw_tools,
prompt=prompt,
stop_words=stop_words,
@@ -742,7 +751,7 @@ class Agent(BaseAgent):
else:
self.agent_executor = self.executor_class(
llm=cast(BaseLLM, self.llm),
task=task,
task=task, # type: ignore[arg-type]
i18n=self.i18n,
agent=self,
crew=self.crew,
@@ -765,11 +774,11 @@ class Agent(BaseAgent):
def _update_executor_parameters(
self,
task: Task | None,
tools: list,
tools: list[BaseTool],
raw_tools: list[BaseTool],
prompt: dict,
prompt: SystemPromptResult | StandardPromptResult,
stop_words: list[str],
rpm_limit_fn: Callable | None,
rpm_limit_fn: Callable | None, # type: ignore[type-arg]
) -> None:
"""Update executor parameters without recreating instance.