feat: inject server methods

Greyson LaLonde
2026-01-14 05:46:35 -05:00
parent f806cf5bfb
commit 68df061c20
6 changed files with 87 additions and 41 deletions

View File

@@ -36,7 +36,10 @@ def _get_default_update_config() -> UpdateConfig:
 @deprecated(
-    "Use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead.",
+    """
+    `crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
+    use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead.
+    """,
     category=FutureWarning,
 )
 class A2AConfig(BaseModel):
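
A minimal sketch of what callers of the deprecated class now see, assuming the decorator is the PEP 702-style `deprecated` (so the warning fires at instantiation) and that `endpoint` is the only required field; both assumptions are illustrative, not confirmed by this diff:

import warnings

from crewai.a2a.config import A2AClientConfig, A2AConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Constructing the legacy config emits the FutureWarning configured above.
    legacy = A2AConfig(endpoint="https://agents.example.com/a2a")  # illustrative endpoint

print([str(w.message) for w in caught if issubclass(w.category, FutureWarning)])

# Non-deprecated replacement for client-side (delegation) use:
client = A2AClientConfig(endpoint="https://agents.example.com/a2a")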

View File

@@ -1,6 +1,17 @@
"""Type definitions for A2A protocol message parts.""" """Type definitions for A2A protocol message parts."""
from typing import Annotated, Any, Literal, Protocol, TypedDict, runtime_checkable from __future__ import annotations
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Literal,
Protocol,
TypeAlias,
TypedDict,
runtime_checkable,
)
from pydantic import BeforeValidator, HttpUrl, TypeAdapter from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired from typing_extensions import NotRequired
@@ -16,6 +27,10 @@ from crewai.a2a.updates import (
 )
 
+if TYPE_CHECKING:
+    from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
+
 TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
 
 http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
@@ -72,3 +87,6 @@ HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
     StreamingConfig: StreamingHandler,
     PushNotificationConfig: PushNotificationHandler,
 }
+
+A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
+A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig
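
The `if TYPE_CHECKING:` block together with `from __future__ import annotations` lets the types module name the config classes without importing them at runtime, which is the usual way to break a circular import. A generic sketch of the pattern; the module and class names below are illustrative, not the crewai ones:

# types_module.py (illustrative)
from __future__ import annotations  # annotations are stored as strings, resolved lazily

from typing import TYPE_CHECKING, TypeAlias

if TYPE_CHECKING:
    # Seen only by type checkers, never executed, so it cannot create a
    # circular import with the config module.
    from config_module import ClientConfig, ServerConfig

# Quoted so the alias is also safe to evaluate at runtime.
ConfigTypes: TypeAlias = "ClientConfig | ServerConfig"


def config_kind(config: ConfigTypes) -> str:
    # At runtime the annotation is just a string; mypy resolves it against
    # the TYPE_CHECKING imports above.
    return type(config).__name__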

View File

@@ -4,7 +4,8 @@ from __future__ import annotations
 from pydantic import BaseModel, Field, create_model
 
-from crewai.a2a.config import A2AConfig
+from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
+from crewai.a2a.types import A2AClientConfigTypes, A2AConfigTypes
 from crewai.types.utils import create_literals_from_strings
@@ -46,36 +47,45 @@ def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel]:
 def extract_a2a_agent_ids_from_config(
-    a2a_config: list[A2AConfig] | A2AConfig | None,
-) -> tuple[list[A2AConfig], tuple[str, ...]]:
+    a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
+) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]:
     """Extract A2A agent IDs from A2A configuration.
 
+    Filters out A2AServerConfig since it doesn't have an endpoint for delegation.
+
     Args:
-        a2a_config: A2A configuration.
+        a2a_config: A2A configuration (any type).
 
     Returns:
-        Tuple of A2A configs list and agent endpoint IDs.
+        Tuple of client A2A configs list and agent endpoint IDs.
     """
     if a2a_config is None:
         return [], ()
 
-    if isinstance(a2a_config, A2AConfig):
-        a2a_agents = [a2a_config]
+    configs: list[A2AConfigTypes]
+    if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)):
+        configs = [a2a_config]
     else:
-        a2a_agents = a2a_config
+        configs = a2a_config
 
-    return a2a_agents, tuple(config.endpoint for config in a2a_agents)
+    # Filter to only client configs (those with endpoint)
+    client_configs: list[A2AClientConfigTypes] = [
+        config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
+    ]
+    return client_configs, tuple(config.endpoint for config in client_configs)
 
 
 def get_a2a_agents_and_response_model(
-    a2a_config: list[A2AConfig] | A2AConfig | None,
-) -> tuple[list[A2AConfig], type[BaseModel]]:
+    a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
+) -> tuple[list[A2AClientConfigTypes], type[BaseModel]]:
     """Get A2A agent configs and response model.
 
     Args:
-        a2a_config: A2A configuration.
+        a2a_config: A2A configuration (any type).
 
     Returns:
-        Tuple of A2A configs and response model.
+        Tuple of client A2A configs and response model.
     """
     a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
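
A self-contained sketch of the new filtering behaviour; the dataclasses stand in for the client/server configs and only the isinstance-based filter mirrors the code above:

from dataclasses import dataclass


@dataclass
class ClientConfig:
    endpoint: str  # delegable: has an endpoint to call


@dataclass
class ServerConfig:
    port: int = 8000  # illustrative field; no endpoint, so it cannot be delegated to


def extract_client_configs(
    configs: list[ClientConfig | ServerConfig],
) -> tuple[list[ClientConfig], tuple[str, ...]]:
    # Same shape as extract_a2a_agent_ids_from_config: keep only configs
    # that can be delegated to, then collect their endpoints.
    clients = [c for c in configs if isinstance(c, ClientConfig)]
    return clients, tuple(c.endpoint for c in clients)


mixed = [ClientConfig("https://a.example/a2a"), ServerConfig(), ClientConfig("https://b.example/a2a")]
clients, endpoints = extract_client_configs(mixed)
assert endpoints == ("https://a.example/a2a", "https://b.example/a2a")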

View File

@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any
 from a2a.types import Role, TaskState
 from pydantic import BaseModel, ValidationError
 
-from crewai.a2a.config import A2AConfig
+from crewai.a2a.config import A2AClientConfig, A2AConfig
 from crewai.a2a.extensions.base import ExtensionRegistry
 from crewai.a2a.task_helpers import TaskStateResult
 from crewai.a2a.templates import (
@@ -26,7 +26,11 @@ from crewai.a2a.templates import (
     UNAVAILABLE_AGENTS_NOTICE_TEMPLATE,
 )
 from crewai.a2a.types import AgentResponseProtocol
-from crewai.a2a.utils.agent_card import afetch_agent_card, fetch_agent_card
+from crewai.a2a.utils.agent_card import (
+    afetch_agent_card,
+    fetch_agent_card,
+    inject_a2a_server_methods,
+)
 from crewai.a2a.utils.delegation import (
     aexecute_a2a_delegation,
     execute_a2a_delegation,
@@ -121,10 +125,12 @@ def wrap_agent_with_a2a_instance(
agent, "aexecute_task", MethodType(aexecute_task_with_a2a, agent) agent, "aexecute_task", MethodType(aexecute_task_with_a2a, agent)
) )
inject_a2a_server_methods(agent)
def _fetch_card_from_config( def _fetch_card_from_config(
config: A2AConfig, config: A2AConfig | A2AClientConfig,
) -> tuple[A2AConfig, AgentCard | Exception]: ) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]:
"""Fetch agent card from A2A config. """Fetch agent card from A2A config.
Args: Args:
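
The wrapper above already rebinds `execute_task`/`aexecute_task` with `MethodType`, and `inject_a2a_server_methods` presumably attaches server-side entry points to the instance the same way. A generic sketch of that injection mechanism; the class and the injected method are illustrative:

from types import MethodType


class Agent:
    def __init__(self, role: str) -> None:
        self.role = role


def serve_agent_card(self: Agent) -> dict[str, str]:
    # Hypothetical server-side method: expose a minimal agent card.
    return {"role": self.role, "protocol": "a2a"}


def inject_server_methods(agent: Agent) -> None:
    # MethodType binds the function to this instance so `self` resolves to it;
    # object.__setattr__ is the usual escape hatch when the class restricts
    # attribute assignment (e.g. pydantic models).
    object.__setattr__(agent, "serve_agent_card", MethodType(serve_agent_card, agent))


agent = Agent(role="researcher")
inject_server_methods(agent)
print(agent.serve_agent_card())  # {'role': 'researcher', 'protocol': 'a2a'}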
@@ -145,7 +151,7 @@ def _fetch_card_from_config(
 def _fetch_agent_cards_concurrently(
-    a2a_agents: list[A2AConfig],
+    a2a_agents: list[A2AConfig | A2AClientConfig],
 ) -> tuple[dict[str, AgentCard], dict[str, str]]:
     """Fetch agent cards concurrently for multiple A2A agents.
@@ -180,7 +186,7 @@ def _fetch_agent_cards_concurrently(
 def _execute_task_with_a2a(
     self: Agent,
-    a2a_agents: list[A2AConfig],
+    a2a_agents: list[A2AConfig | A2AClientConfig],
     original_fn: Callable[..., str],
     task: Task,
     agent_response_model: type[BaseModel],
@@ -269,7 +275,7 @@ def _execute_task_with_a2a(
 def _augment_prompt_with_a2a(
-    a2a_agents: list[A2AConfig],
+    a2a_agents: list[A2AConfig | A2AClientConfig],
     task_description: str,
     agent_cards: dict[str, AgentCard],
     conversation_history: list[Message] | None = None,
@@ -522,11 +528,11 @@ def _prepare_delegation_context(
     task: Task,
     original_task_description: str | None,
 ) -> tuple[
-    list[A2AConfig],
+    list[A2AConfig | A2AClientConfig],
     type[BaseModel],
     str,
     str,
-    A2AConfig,
+    A2AConfig | A2AClientConfig,
     str | None,
     str | None,
     dict[str, Any] | None,
@@ -590,7 +596,7 @@ def _handle_task_completion(
     task: Task,
     task_id_config: str | None,
     reference_task_ids: list[str],
-    agent_config: A2AConfig,
+    agent_config: A2AConfig | A2AClientConfig,
     turn_num: int,
 ) -> tuple[str | None, str | None, list[str]]:
     """Handle task completion state including reference task updates.
@@ -630,7 +636,7 @@ def _handle_agent_response_and_continue(
     a2a_result: TaskStateResult,
     agent_id: str,
     agent_cards: dict[str, AgentCard] | None,
-    a2a_agents: list[A2AConfig],
+    a2a_agents: list[A2AConfig | A2AClientConfig],
     original_task_description: str,
     conversation_history: list[Message],
     turn_num: int,
@@ -867,8 +873,8 @@ def _delegate_to_a2a(
 async def _afetch_card_from_config(
-    config: A2AConfig,
-) -> tuple[A2AConfig, AgentCard | Exception]:
+    config: A2AConfig | A2AClientConfig,
+) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]:
     """Fetch agent card from A2A config asynchronously."""
     try:
         card = await afetch_agent_card(
@@ -882,7 +888,7 @@ async def _afetch_card_from_config(
 async def _afetch_agent_cards_concurrently(
-    a2a_agents: list[A2AConfig],
+    a2a_agents: list[A2AConfig | A2AClientConfig],
 ) -> tuple[dict[str, AgentCard], dict[str, str]]:
     """Fetch agent cards concurrently for multiple A2A agents using asyncio."""
     agent_cards: dict[str, AgentCard] = {}
@@ -907,7 +913,7 @@ async def _afetch_agent_cards_concurrently(
 async def _aexecute_task_with_a2a(
     self: Agent,
-    a2a_agents: list[A2AConfig],
+    a2a_agents: list[A2AConfig | A2AClientConfig],
     original_fn: Callable[..., Coroutine[Any, Any, str]],
     task: Task,
     agent_response_model: type[BaseModel],
@@ -986,7 +992,7 @@ async def _ahandle_agent_response_and_continue(
     a2a_result: TaskStateResult,
     agent_id: str,
     agent_cards: dict[str, AgentCard] | None,
-    a2a_agents: list[A2AConfig],
+    a2a_agents: list[A2AConfig | A2AClientConfig],
     original_task_description: str,
     conversation_history: list[Message],
     turn_num: int,

View File

@@ -17,7 +17,7 @@ from urllib.parse import urlparse
 from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
 from typing_extensions import Self
 
-from crewai.a2a.config import A2AConfig
+from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
 from crewai.agent.utils import (
     ahandle_knowledge_retrieval,
     apply_training_data,
@@ -73,7 +73,7 @@ from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_F
 from crewai.utilities.converter import Converter
 from crewai.utilities.guardrail_types import GuardrailType
 from crewai.utilities.llm_utils import create_llm
-from crewai.utilities.prompts import Prompts
+from crewai.utilities.prompts import Prompts, StandardPromptResult, SystemPromptResult
 from crewai.utilities.token_counter_callback import TokenCalcHandler
 from crewai.utilities.training_handler import CrewTrainingHandler
@@ -218,9 +218,18 @@ class Agent(BaseAgent):
     guardrail_max_retries: int = Field(
         default=3, description="Maximum number of retries when guardrail fails"
     )
-    a2a: list[A2AConfig] | A2AConfig | None = Field(
+    a2a: (
+        list[A2AConfig | A2AServerConfig | A2AClientConfig]
+        | A2AConfig
+        | A2AServerConfig
+        | A2AClientConfig
+        | None
+    ) = Field(
         default=None,
-        description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. Can be a single A2AConfig or a dict mapping agent IDs to configs.",
+        description="""
+        A2A (Agent-to-Agent) configuration for delegating tasks to remote agents.
+        Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
+        """,
     )
     executor_class: type[CrewAgentExecutor] | type[CrewAgentExecutorFlow] = Field(
         default=CrewAgentExecutor,
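
With the widened field type, an agent can mix client configs with a single server config. A hedged usage sketch; constructor arguments other than `endpoint` and the exact `A2AServerConfig` parameters are assumptions, not confirmed by this diff:

from crewai import Agent
from crewai.a2a.config import A2AClientConfig, A2AServerConfig

agent = Agent(
    role="researcher",
    goal="Answer research questions",
    backstory="Delegates specialist work to remote A2A agents.",
    a2a=[
        A2AClientConfig(endpoint="https://agents.example.com/search"),
        A2AClientConfig(endpoint="https://agents.example.com/summarize"),
        A2AServerConfig(),  # assumed to be constructible with defaults
    ],
)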
@@ -733,7 +742,7 @@ class Agent(BaseAgent):
         if self.agent_executor is not None:
             self._update_executor_parameters(
                 task=task,
-                tools=parsed_tools,
+                tools=parsed_tools,  # type: ignore[arg-type]
                 raw_tools=raw_tools,
                 prompt=prompt,
                 stop_words=stop_words,
@@ -742,7 +751,7 @@ class Agent(BaseAgent):
         else:
             self.agent_executor = self.executor_class(
                 llm=cast(BaseLLM, self.llm),
-                task=task,
+                task=task,  # type: ignore[arg-type]
                 i18n=self.i18n,
                 agent=self,
                 crew=self.crew,
@@ -765,11 +774,11 @@ class Agent(BaseAgent):
     def _update_executor_parameters(
         self,
         task: Task | None,
-        tools: list,
+        tools: list[BaseTool],
         raw_tools: list[BaseTool],
-        prompt: dict,
+        prompt: SystemPromptResult | StandardPromptResult,
         stop_words: list[str],
-        rpm_limit_fn: Callable | None,
+        rpm_limit_fn: Callable | None,  # type: ignore[type-arg]
     ) -> None:
         """Update executor parameters without recreating instance.

View File

@@ -117,7 +117,7 @@ show_error_codes = true
 warn_unused_ignores = true
 python_version = "3.12"
 exclude = "(?x)(^lib/crewai/src/crewai/cli/templates/ | ^lib/crewai/tests/ | ^lib/crewai-tools/tests/)"
-plugins = ["pydantic.mypy", "crewai.mypy"]
+plugins = ["pydantic.mypy"]
 
 [tool.bandit]