diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index e121a9771..21edbd160 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -8,11 +8,9 @@ import time from typing import ( TYPE_CHECKING, Any, - Final, Literal, cast, ) -from urllib.parse import urlparse from pydantic import ( BaseModel, @@ -61,16 +59,8 @@ from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.lite_agent_output import LiteAgentOutput from crewai.llms.base_llm import BaseLLM -from crewai.mcp import ( - MCPClient, - MCPServerConfig, - MCPServerHTTP, - MCPServerSSE, - MCPServerStdio, -) -from crewai.mcp.transports.http import HTTPTransport -from crewai.mcp.transports.sse import SSETransport -from crewai.mcp.transports.stdio import StdioTransport +from crewai.mcp import MCPServerConfig +from crewai.mcp.tool_resolver import MCPToolResolver from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.fingerprint import Fingerprint from crewai.tools.agent_tools.agent_tools import AgentTools @@ -111,18 +101,8 @@ if TYPE_CHECKING: from crewai.utilities.types import LLMMessage -# MCP Connection timeout constants (in seconds) -MCP_CONNECTION_TIMEOUT: Final[int] = 10 -MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30 -MCP_DISCOVERY_TIMEOUT: Final[int] = 15 -MCP_MAX_RETRIES: Final[int] = 3 - _passthrough_exceptions: tuple[type[Exception], ...] = () -# Simple in-memory cache for MCP tool schemas (duration: 5 minutes) -_mcp_schema_cache: dict[str, Any] = {} -_cache_ttl: Final[int] = 300 # 5 minutes - class Agent(BaseAgent): """Represents an agent in a system. @@ -154,7 +134,7 @@ class Agent(BaseAgent): model_config = ConfigDict() _times_executed: int = PrivateAttr(default=0) - _mcp_clients: list[Any] = PrivateAttr(default_factory=list) + _mcp_resolver: MCPToolResolver | None = PrivateAttr(default=None) _last_messages: list[LLMMessage] = PrivateAttr(default_factory=list) max_execution_time: int | None = Field( default=None, @@ -934,544 +914,17 @@ class Agent(BaseAgent): def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]: """Convert MCP server references/configs to CrewAI tools. - Supports both string references (backwards compatible) and structured - configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE). - - Args: - mcps: List of MCP server references (strings) or configurations. - - Returns: - List of BaseTool instances from MCP servers. + Delegates to :class:`~crewai.mcp.tool_resolver.MCPToolResolver`. """ - all_tools = [] - clients = [] - - for mcp_config in mcps: - if isinstance(mcp_config, str): - tools = self._get_mcp_tools_from_string(mcp_config) - else: - tools, client = self._get_native_mcp_tools(mcp_config) - if client: - clients.append(client) - - all_tools.extend(tools) - - # Store clients for cleanup - self._mcp_clients.extend(clients) - return all_tools + self._cleanup_mcp_clients() + self._mcp_resolver = MCPToolResolver(agent=self, logger=self._logger) + return self._mcp_resolver.resolve(mcps) def _cleanup_mcp_clients(self) -> None: """Cleanup MCP client connections after task execution.""" - if not self._mcp_clients: - return - - async def _disconnect_all() -> None: - for client in self._mcp_clients: - if client and hasattr(client, "connected") and client.connected: - await client.disconnect() - - try: - asyncio.run(_disconnect_all()) - except Exception as e: - self._logger.log("error", f"Error during MCP client cleanup: {e}") - finally: - self._mcp_clients.clear() - - def _get_mcp_tools_from_string(self, mcp_ref: str) -> list[BaseTool]: - """Get tools from legacy string-based MCP references. - - This method maintains backwards compatibility with string-based - MCP references (https://... and crewai-amp:...). - - Args: - mcp_ref: String reference to MCP server. - - Returns: - List of BaseTool instances. - """ - if mcp_ref.startswith("crewai-amp:"): - return self._get_amp_mcp_tools(mcp_ref) - if mcp_ref.startswith("https://"): - return self._get_external_mcp_tools(mcp_ref) - return [] - - def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]: - """Get tools from external HTTPS MCP server with graceful error handling.""" - from crewai.tools.mcp_tool_wrapper import MCPToolWrapper - - # Parse server URL and optional tool name - if "#" in mcp_ref: - server_url, specific_tool = mcp_ref.split("#", 1) - else: - server_url, specific_tool = mcp_ref, None - - server_params = {"url": server_url} - server_name = self._extract_server_name(server_url) - - try: - # Get tool schemas with timeout and error handling - tool_schemas = self._get_mcp_tool_schemas(server_params) - - if not tool_schemas: - self._logger.log( - "warning", f"No tools discovered from MCP server: {server_url}" - ) - return [] - - tools = [] - for tool_name, schema in tool_schemas.items(): - # Skip if specific tool requested and this isn't it - if specific_tool and tool_name != specific_tool: - continue - - try: - wrapper = MCPToolWrapper( - mcp_server_params=server_params, - tool_name=tool_name, - tool_schema=schema, - server_name=server_name, - ) - tools.append(wrapper) - except Exception as e: - self._logger.log( - "warning", - f"Failed to create MCP tool wrapper for {tool_name}: {e}", - ) - continue - - if specific_tool and not tools: - self._logger.log( - "warning", - f"Specific tool '{specific_tool}' not found on MCP server: {server_url}", - ) - - return cast(list[BaseTool], tools) - - except Exception as e: - self._logger.log( - "warning", f"Failed to connect to MCP server {server_url}: {e}" - ) - return [] - - def _get_native_mcp_tools( - self, mcp_config: MCPServerConfig - ) -> tuple[list[BaseTool], Any | None]: - """Get tools from MCP server using structured configuration. - - This method creates an MCP client based on the configuration type, - connects to the server, discovers tools, applies filtering, and - returns wrapped tools along with the client instance for cleanup. - - Args: - mcp_config: MCP server configuration (MCPServerStdio, MCPServerHTTP, or MCPServerSSE). - - Returns: - Tuple of (list of BaseTool instances, MCPClient instance for cleanup). - """ - from crewai.tools.base_tool import BaseTool - from crewai.tools.mcp_native_tool import MCPNativeTool - - transport: StdioTransport | HTTPTransport | SSETransport - if isinstance(mcp_config, MCPServerStdio): - transport = StdioTransport( - command=mcp_config.command, - args=mcp_config.args, - env=mcp_config.env, - ) - server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}" - elif isinstance(mcp_config, MCPServerHTTP): - transport = HTTPTransport( - url=mcp_config.url, - headers=mcp_config.headers, - streamable=mcp_config.streamable, - ) - server_name = self._extract_server_name(mcp_config.url) - elif isinstance(mcp_config, MCPServerSSE): - transport = SSETransport( - url=mcp_config.url, - headers=mcp_config.headers, - ) - server_name = self._extract_server_name(mcp_config.url) - else: - raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}") - - client = MCPClient( - transport=transport, - cache_tools_list=mcp_config.cache_tools_list, - ) - - async def _setup_client_and_list_tools() -> list[dict[str, Any]]: - """Async helper to connect and list tools in same event loop.""" - - try: - if not client.connected: - await client.connect() - - tools_list = await client.list_tools() - - try: - await client.disconnect() - # Small delay to allow background tasks to finish cleanup - # This helps prevent "cancel scope in different task" errors - # when asyncio.run() closes the event loop - await asyncio.sleep(0.1) - except Exception as e: - self._logger.log("error", f"Error during disconnect: {e}") - - return tools_list - except Exception as e: - if client.connected: - await client.disconnect() - await asyncio.sleep(0.1) - raise RuntimeError( - f"Error during setup client and list tools: {e}" - ) from e - - try: - try: - asyncio.get_running_loop() - import concurrent.futures - - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit( - asyncio.run, _setup_client_and_list_tools() - ) - tools_list = future.result() - except RuntimeError: - try: - tools_list = asyncio.run(_setup_client_and_list_tools()) - except RuntimeError as e: - error_msg = str(e).lower() - if "cancel scope" in error_msg or "task" in error_msg: - raise ConnectionError( - "MCP connection failed due to event loop cleanup issues. " - "This may be due to authentication errors or server unavailability." - ) from e - except asyncio.CancelledError as e: - raise ConnectionError( - "MCP connection was cancelled. This may indicate an authentication " - "error or server unavailability." - ) from e - - if mcp_config.tool_filter: - filtered_tools = [] - for tool in tools_list: - if callable(mcp_config.tool_filter): - try: - from crewai.mcp.filters import ToolFilterContext - - context = ToolFilterContext( - agent=self, - server_name=server_name, - run_context=None, - ) - if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type] - filtered_tools.append(tool) - except (TypeError, AttributeError): - if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type] - filtered_tools.append(tool) - else: - # Not callable - include tool - filtered_tools.append(tool) - tools_list = filtered_tools - - tools = [] - for tool_def in tools_list: - tool_name = tool_def.get("name", "") - if not tool_name: - continue - - # Convert inputSchema to Pydantic model if present - args_schema = None - if tool_def.get("inputSchema"): - args_schema = self._json_schema_to_pydantic( - tool_name, tool_def["inputSchema"] - ) - - tool_schema = { - "description": tool_def.get("description", ""), - "args_schema": args_schema, - } - - try: - native_tool = MCPNativeTool( - mcp_client=client, - tool_name=tool_name, - tool_schema=tool_schema, - server_name=server_name, - ) - tools.append(native_tool) - except Exception as e: - self._logger.log("error", f"Failed to create native MCP tool: {e}") - continue - - return cast(list[BaseTool], tools), client - except Exception as e: - if client.connected: - asyncio.run(client.disconnect()) - - raise RuntimeError(f"Failed to get native MCP tools: {e}") from e - - def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]: - """Get tools from CrewAI AMP MCP marketplace.""" - # Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name" - amp_part = amp_ref.replace("crewai-amp:", "") - if "#" in amp_part: - mcp_name, specific_tool = amp_part.split("#", 1) - else: - mcp_name, specific_tool = amp_part, None - - # Call AMP API to get MCP server URLs - mcp_servers = self._fetch_amp_mcp_servers(mcp_name) - - tools = [] - for server_config in mcp_servers: - server_ref = server_config["url"] - if specific_tool: - server_ref += f"#{specific_tool}" - server_tools = self._get_external_mcp_tools(server_ref) - tools.extend(server_tools) - - return tools - - @staticmethod - def _extract_server_name(server_url: str) -> str: - """Extract clean server name from URL for tool prefixing.""" - - parsed = urlparse(server_url) - domain = parsed.netloc.replace(".", "_") - path = parsed.path.replace("/", "_").strip("_") - return f"{domain}_{path}" if path else domain - - def _get_mcp_tool_schemas( - self, server_params: dict[str, Any] - ) -> dict[str, dict[str, Any]]: - """Get tool schemas from MCP server for wrapper creation with caching.""" - server_url = server_params["url"] - - # Check cache first - cache_key = server_url - current_time = time.time() - - if cache_key in _mcp_schema_cache: - cached_data, cache_time = _mcp_schema_cache[cache_key] - if current_time - cache_time < _cache_ttl: - self._logger.log( - "debug", f"Using cached MCP tool schemas for {server_url}" - ) - return cached_data # type: ignore[no-any-return] - - try: - schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params)) - - # Cache successful results - _mcp_schema_cache[cache_key] = (schemas, current_time) - - return schemas - except Exception as e: - # Log warning but don't raise - this allows graceful degradation - self._logger.log( - "warning", f"Failed to get MCP tool schemas from {server_url}: {e}" - ) - return {} - - async def _get_mcp_tool_schemas_async( - self, server_params: dict[str, Any] - ) -> dict[str, dict[str, Any]]: - """Async implementation of MCP tool schema retrieval with timeouts and retries.""" - server_url = server_params["url"] - return await self._retry_mcp_discovery( - self._discover_mcp_tools_with_timeout, server_url - ) - - async def _retry_mcp_discovery( - self, operation_func: Any, server_url: str - ) -> dict[str, dict[str, Any]]: - """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop.""" - last_error = None - - for attempt in range(MCP_MAX_RETRIES): - # Execute single attempt outside try-except loop structure - result, error, should_retry = await self._attempt_mcp_discovery( - operation_func, server_url - ) - - # Success case - return immediately - if result is not None: - return result - - # Non-retryable error - raise immediately - if not should_retry: - raise RuntimeError(error) - - # Retryable error - continue with backoff - last_error = error - if attempt < MCP_MAX_RETRIES - 1: - wait_time = 2**attempt # Exponential backoff - await asyncio.sleep(wait_time) - - raise RuntimeError( - f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}" - ) - - @staticmethod - async def _attempt_mcp_discovery( - operation_func: Any, server_url: str - ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]: - """Attempt single MCP discovery operation and return (result, error_message, should_retry).""" - try: - result = await operation_func(server_url) - return result, "", False - - except ImportError: - return ( - None, - "MCP library not available. Please install with: pip install mcp", - False, - ) - - except asyncio.TimeoutError: - return ( - None, - f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds", - True, - ) - - except Exception as e: - error_str = str(e).lower() - - # Classify errors as retryable or non-retryable - if "authentication" in error_str or "unauthorized" in error_str: - return None, f"Authentication failed for MCP server: {e!s}", False - if "connection" in error_str or "network" in error_str: - return None, f"Network connection failed: {e!s}", True - if "json" in error_str or "parsing" in error_str: - return None, f"Server response parsing error: {e!s}", True - return None, f"MCP discovery error: {e!s}", False - - async def _discover_mcp_tools_with_timeout( - self, server_url: str - ) -> dict[str, dict[str, Any]]: - """Discover MCP tools with timeout wrapper.""" - return await asyncio.wait_for( - self._discover_mcp_tools(server_url), timeout=MCP_DISCOVERY_TIMEOUT - ) - - async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]: - """Discover tools from MCP server with proper timeout handling.""" - from mcp import ClientSession - from mcp.client.streamable_http import streamablehttp_client - - async with streamablehttp_client(server_url) as (read, write, _): - async with ClientSession(read, write) as session: - # Initialize the connection with timeout - await asyncio.wait_for( - session.initialize(), timeout=MCP_CONNECTION_TIMEOUT - ) - - # List available tools with timeout - tools_result = await asyncio.wait_for( - session.list_tools(), - timeout=MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT, - ) - - schemas = {} - for tool in tools_result.tools: - args_schema = None - if hasattr(tool, "inputSchema") and tool.inputSchema: - args_schema = self._json_schema_to_pydantic( - sanitize_tool_name(tool.name), tool.inputSchema - ) - - schemas[sanitize_tool_name(tool.name)] = { - "description": getattr(tool, "description", ""), - "args_schema": args_schema, - } - return schemas - - def _json_schema_to_pydantic( - self, tool_name: str, json_schema: dict[str, Any] - ) -> type: - """Convert JSON Schema to Pydantic model for tool arguments. - - Args: - tool_name: Name of the tool (used for model naming) - json_schema: JSON Schema dict with 'properties', 'required', etc. - - Returns: - Pydantic BaseModel class - """ - from pydantic import Field, create_model - - properties = json_schema.get("properties", {}) - required_fields = json_schema.get("required", []) - - field_definitions: dict[str, Any] = {} - - for field_name, field_schema in properties.items(): - field_type = self._json_type_to_python(field_schema) - field_description = field_schema.get("description", "") - - is_required = field_name in required_fields - - if is_required: - field_definitions[field_name] = ( - field_type, - Field(..., description=field_description), - ) - else: - field_definitions[field_name] = ( - field_type | None, - Field(default=None, description=field_description), - ) - - model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema" - return create_model(model_name, **field_definitions) # type: ignore[no-any-return] - - def _json_type_to_python(self, field_schema: dict[str, Any]) -> type: - """Convert JSON Schema type to Python type. - - Args: - field_schema: JSON Schema field definition - - Returns: - Python type - """ - - json_type = field_schema.get("type") - - if "anyOf" in field_schema: - types: list[type] = [] - for option in field_schema["anyOf"]: - if "const" in option: - types.append(str) - else: - types.append(self._json_type_to_python(option)) - unique_types = list(set(types)) - if len(unique_types) > 1: - result: Any = unique_types[0] - for t in unique_types[1:]: - result = result | t - return result # type: ignore[no-any-return] - return unique_types[0] - - type_mapping: dict[str | None, type] = { - "string": str, - "number": float, - "integer": int, - "boolean": bool, - "array": list, - "object": dict, - } - - return type_mapping.get(json_type, Any) - - @staticmethod - def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]: - """Fetch MCP server configurations from CrewAI AMP API.""" - # TODO: Implement AMP API call to "integrations/mcps" endpoint - # Should return list of server configs with URLs - return [] + if self._mcp_resolver is not None: + self._mcp_resolver.cleanup() + self._mcp_resolver = None @staticmethod def get_multimodal_tools() -> Sequence[BaseTool]: diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index 286f244ed..8b2b9737c 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod from collections.abc import Callable from copy import copy as shallow_copy from hashlib import md5 -from typing import Any, Literal +import re +from typing import Any, Final, Literal import uuid from pydantic import ( @@ -36,6 +37,11 @@ from crewai.utilities.rpm_controller import RPMController from crewai.utilities.string_utils import interpolate_only +_SLUG_RE: Final[re.Pattern[str]] = re.compile( + r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$" +) + + PlatformApp = Literal[ "asana", "box", @@ -197,7 +203,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): ) mcps: list[str | MCPServerConfig] | None = Field( default=None, - description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.", + description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.", ) memory: Any = Field( default=None, @@ -276,14 +282,16 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): validated_mcps: list[str | MCPServerConfig] = [] for mcp in mcps: if isinstance(mcp, str): - if mcp.startswith(("https://", "crewai-amp:")): + if mcp.startswith("https://"): + validated_mcps.append(mcp) + elif _SLUG_RE.match(mcp): validated_mcps.append(mcp) else: raise ValueError( - f"Invalid MCP reference: {mcp}. " - "String references must start with 'https://' or 'crewai-amp:'" + f"Invalid MCP reference: {mcp!r}. " + "String references must be an 'https://' URL or a valid " + "slug (e.g. 'notion', 'notion#search', 'crewai-amp:notion')." ) - elif isinstance(mcp, (MCPServerConfig)): validated_mcps.append(mcp) else: diff --git a/lib/crewai/src/crewai/cli/plus_api.py b/lib/crewai/src/crewai/cli/plus_api.py index cbe402eff..17884ffc2 100644 --- a/lib/crewai/src/crewai/cli/plus_api.py +++ b/lib/crewai/src/crewai/cli/plus_api.py @@ -190,6 +190,15 @@ class PlusAPI: timeout=30, ) + def get_mcp_configs(self, slugs: list[str]) -> httpx.Response: + """Get MCP server configurations for the given slugs.""" + return self._make_request( + "GET", + f"{self.INTEGRATIONS_RESOURCE}/mcp_configs", + params={"slugs": ",".join(slugs)}, + timeout=30, + ) + def get_triggers(self) -> httpx.Response: """Get all available triggers from integrations.""" return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps") diff --git a/lib/crewai/src/crewai/events/__init__.py b/lib/crewai/src/crewai/events/__init__.py index a6f213a54..36933fc45 100644 --- a/lib/crewai/src/crewai/events/__init__.py +++ b/lib/crewai/src/crewai/events/__init__.py @@ -63,6 +63,7 @@ from crewai.events.types.logging_events import ( AgentLogsStartedEvent, ) from crewai.events.types.mcp_events import ( + MCPConfigFetchFailedEvent, MCPConnectionCompletedEvent, MCPConnectionFailedEvent, MCPConnectionStartedEvent, @@ -165,6 +166,7 @@ __all__ = [ "LiteAgentExecutionCompletedEvent", "LiteAgentExecutionErrorEvent", "LiteAgentExecutionStartedEvent", + "MCPConfigFetchFailedEvent", "MCPConnectionCompletedEvent", "MCPConnectionFailedEvent", "MCPConnectionStartedEvent", diff --git a/lib/crewai/src/crewai/events/event_listener.py b/lib/crewai/src/crewai/events/event_listener.py index 5f22d0188..09dc25316 100644 --- a/lib/crewai/src/crewai/events/event_listener.py +++ b/lib/crewai/src/crewai/events/event_listener.py @@ -68,6 +68,7 @@ from crewai.events.types.logging_events import ( AgentLogsStartedEvent, ) from crewai.events.types.mcp_events import ( + MCPConfigFetchFailedEvent, MCPConnectionCompletedEvent, MCPConnectionFailedEvent, MCPConnectionStartedEvent, @@ -665,6 +666,16 @@ class EventListener(BaseEventListener): event.error_type, ) + @crewai_event_bus.on(MCPConfigFetchFailedEvent) + def on_mcp_config_fetch_failed( + _: Any, event: MCPConfigFetchFailedEvent + ) -> None: + self.formatter.handle_mcp_config_fetch_failed( + event.slug, + event.error, + event.error_type, + ) + @crewai_event_bus.on(MCPToolExecutionStartedEvent) def on_mcp_tool_execution_started( _: Any, event: MCPToolExecutionStartedEvent diff --git a/lib/crewai/src/crewai/events/event_types.py b/lib/crewai/src/crewai/events/event_types.py index 5fca4bd7d..63b6cdfc8 100644 --- a/lib/crewai/src/crewai/events/event_types.py +++ b/lib/crewai/src/crewai/events/event_types.py @@ -67,6 +67,7 @@ from crewai.events.types.llm_guardrail_events import ( LLMGuardrailStartedEvent, ) from crewai.events.types.mcp_events import ( + MCPConfigFetchFailedEvent, MCPConnectionCompletedEvent, MCPConnectionFailedEvent, MCPConnectionStartedEvent, @@ -181,4 +182,5 @@ EventTypes = ( | MCPToolExecutionStartedEvent | MCPToolExecutionCompletedEvent | MCPToolExecutionFailedEvent + | MCPConfigFetchFailedEvent ) diff --git a/lib/crewai/src/crewai/events/types/mcp_events.py b/lib/crewai/src/crewai/events/types/mcp_events.py index d360aa62a..d6ca9b99a 100644 --- a/lib/crewai/src/crewai/events/types/mcp_events.py +++ b/lib/crewai/src/crewai/events/types/mcp_events.py @@ -83,3 +83,16 @@ class MCPToolExecutionFailedEvent(MCPEvent): error_type: str | None = None # "timeout", "validation", "server_error", etc. started_at: datetime | None = None failed_at: datetime | None = None + + +class MCPConfigFetchFailedEvent(BaseEvent): + """Event emitted when fetching an AMP MCP server config fails. + + This covers cases where the slug is not connected, the API call + failed, or native MCP resolution failed after config was fetched. + """ + + type: str = "mcp_config_fetch_failed" + slug: str + error: str + error_type: str | None = None # "not_connected", "api_error", "connection_failed" diff --git a/lib/crewai/src/crewai/events/utils/console_formatter.py b/lib/crewai/src/crewai/events/utils/console_formatter.py index 157d812ef..77cc76f4b 100644 --- a/lib/crewai/src/crewai/events/utils/console_formatter.py +++ b/lib/crewai/src/crewai/events/utils/console_formatter.py @@ -1512,6 +1512,34 @@ To enable tracing, do any one of these: self.print(panel) self.print() + def handle_mcp_config_fetch_failed( + self, + slug: str, + error: str = "", + error_type: str | None = None, + ) -> None: + """Handle MCP config fetch failed event (AMP resolution failures).""" + if not self.verbose: + return + + content = Text() + content.append("MCP Config Fetch Failed\n\n", style="red bold") + content.append("Server: ", style="white") + content.append(f"{slug}\n", style="red") + + if error_type: + content.append("Error Type: ", style="white") + content.append(f"{error_type}\n", style="red") + + if error: + content.append("\nError: ", style="white bold") + error_preview = error[:500] + "..." if len(error) > 500 else error + content.append(f"{error_preview}\n", style="red") + + panel = self.create_panel(content, "❌ MCP Config Failed", "red") + self.print(panel) + self.print() + def handle_mcp_tool_execution_started( self, server_name: str, diff --git a/lib/crewai/src/crewai/mcp/__init__.py b/lib/crewai/src/crewai/mcp/__init__.py index 282cb1f56..e078919fd 100644 --- a/lib/crewai/src/crewai/mcp/__init__.py +++ b/lib/crewai/src/crewai/mcp/__init__.py @@ -18,6 +18,7 @@ from crewai.mcp.filters import ( create_dynamic_tool_filter, create_static_tool_filter, ) +from crewai.mcp.tool_resolver import MCPToolResolver from crewai.mcp.transports.base import BaseTransport, TransportType @@ -28,6 +29,7 @@ __all__ = [ "MCPServerHTTP", "MCPServerSSE", "MCPServerStdio", + "MCPToolResolver", "StaticToolFilter", "ToolFilter", "ToolFilterContext", diff --git a/lib/crewai/src/crewai/mcp/client.py b/lib/crewai/src/crewai/mcp/client.py index f608933f6..2b5d75371 100644 --- a/lib/crewai/src/crewai/mcp/client.py +++ b/lib/crewai/src/crewai/mcp/client.py @@ -6,7 +6,7 @@ from contextlib import AsyncExitStack from datetime import datetime import logging import time -from typing import Any +from typing import Any, NamedTuple from typing_extensions import Self @@ -34,6 +34,13 @@ from crewai.mcp.transports.stdio import StdioTransport from crewai.utilities.string_utils import sanitize_tool_name +class _MCPToolResult(NamedTuple): + """Internal result from an MCP tool call, carrying the ``isError`` flag.""" + + content: str + is_error: bool + + # MCP Connection timeout constants (in seconds) MCP_CONNECTION_TIMEOUT = 30 # Increased for slow servers MCP_TOOL_EXECUTION_TIMEOUT = 30 @@ -420,6 +427,7 @@ class MCPClient: return [ { "name": sanitize_tool_name(tool.name), + "original_name": tool.name, "description": getattr(tool, "description", ""), "inputSchema": getattr(tool, "inputSchema", {}), } @@ -461,29 +469,46 @@ class MCPClient: ) try: - result = await self._retry_operation( + tool_result: _MCPToolResult = await self._retry_operation( lambda: self._call_tool_impl(tool_name, cleaned_arguments), timeout=self.execution_timeout, ) - completed_at = datetime.now() - execution_duration_ms = (completed_at - started_at).total_seconds() * 1000 - crewai_event_bus.emit( - self, - MCPToolExecutionCompletedEvent( - server_name=server_name, - server_url=server_url, - transport_type=transport_type, - tool_name=tool_name, - tool_args=cleaned_arguments, - result=result, - started_at=started_at, - completed_at=completed_at, - execution_duration_ms=execution_duration_ms, - ), - ) + finished_at = datetime.now() + execution_duration_ms = (finished_at - started_at).total_seconds() * 1000 - return result + if tool_result.is_error: + crewai_event_bus.emit( + self, + MCPToolExecutionFailedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + tool_name=tool_name, + tool_args=cleaned_arguments, + error=tool_result.content, + error_type="tool_error", + started_at=started_at, + failed_at=finished_at, + ), + ) + else: + crewai_event_bus.emit( + self, + MCPToolExecutionCompletedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + tool_name=tool_name, + tool_args=cleaned_arguments, + result=tool_result.content, + started_at=started_at, + completed_at=finished_at, + execution_duration_ms=execution_duration_ms, + ), + ) + + return tool_result.content except Exception as e: failed_at = datetime.now() error_type = ( @@ -564,23 +589,27 @@ class MCPClient: return cleaned - async def _call_tool_impl(self, tool_name: str, arguments: dict[str, Any]) -> Any: + async def _call_tool_impl( + self, tool_name: str, arguments: dict[str, Any] + ) -> _MCPToolResult: """Internal implementation of call_tool.""" result = await asyncio.wait_for( self.session.call_tool(tool_name, arguments), timeout=self.execution_timeout, ) + is_error = getattr(result, "isError", False) or False + # Extract result content if hasattr(result, "content") and result.content: if isinstance(result.content, list) and len(result.content) > 0: content_item = result.content[0] if hasattr(content_item, "text"): - return str(content_item.text) - return str(content_item) - return str(result.content) + return _MCPToolResult(str(content_item.text), is_error) + return _MCPToolResult(str(content_item), is_error) + return _MCPToolResult(str(result.content), is_error) - return str(result) + return _MCPToolResult(str(result), is_error) async def list_prompts(self) -> list[dict[str, Any]]: """List available prompts from MCP server. diff --git a/lib/crewai/src/crewai/mcp/tool_resolver.py b/lib/crewai/src/crewai/mcp/tool_resolver.py new file mode 100644 index 000000000..34af189f2 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/tool_resolver.py @@ -0,0 +1,592 @@ +"""MCP tool resolution for CrewAI agents. + +This module extracts all MCP-related tool resolution logic from the Agent class +into a standalone MCPToolResolver. It handles three flavours of MCP reference: + + 1. Native configs: MCPServerStdio / MCPServerHTTP / MCPServerSSE objects. + 2. HTTPS URLs: e.g. "https://mcp.example.com/api" + 3. AMP references: e.g. "notion" or "notion#search" (legacy "crewai-amp:" prefix also works) +""" + +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING, Any, Final, cast +from urllib.parse import urlparse + +from crewai.mcp.client import MCPClient +from crewai.mcp.config import ( + MCPServerConfig, + MCPServerHTTP, + MCPServerSSE, + MCPServerStdio, +) +from crewai.mcp.transports.http import HTTPTransport +from crewai.mcp.transports.sse import SSETransport +from crewai.mcp.transports.stdio import StdioTransport + + +if TYPE_CHECKING: + from crewai.tools.base_tool import BaseTool + from crewai.utilities.logger import Logger + +MCP_CONNECTION_TIMEOUT: Final[int] = 10 +MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30 +MCP_DISCOVERY_TIMEOUT: Final[int] = 15 +MCP_MAX_RETRIES: Final[int] = 3 + +_mcp_schema_cache: dict[str, Any] = {} +_cache_ttl: Final[int] = 300 # 5 minutes + + +class MCPToolResolver: + """Resolves MCP server references / configs into CrewAI ``BaseTool`` instances. + + Typical lifecycle:: + + resolver = MCPToolResolver(agent=my_agent, logger=my_agent._logger) + tools = resolver.resolve(my_agent.mcps) + # … agent executes tasks using *tools* … + resolver.cleanup() + + The resolver owns the MCP client connections it creates and is responsible + for tearing them down via :meth:`cleanup`. + """ + + def __init__(self, agent: Any, logger: Logger) -> None: + self._agent = agent + self._logger = logger + self._clients: list[Any] = [] + + @property + def clients(self) -> list[Any]: + return list(self._clients) + + def resolve(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]: + """Convert MCP server references/configs to CrewAI tools.""" + all_tools: list[BaseTool] = [] + amp_refs: list[tuple[str, str | None]] = [] + + for mcp_config in mcps: + if isinstance(mcp_config, str) and mcp_config.startswith("https://"): + all_tools.extend(self._resolve_external(mcp_config)) + elif isinstance(mcp_config, str): + amp_refs.append(self._parse_amp_ref(mcp_config)) + else: + tools, client = self._resolve_native(mcp_config) + all_tools.extend(tools) + if client: + self._clients.append(client) + + if amp_refs: + tools, clients = self._resolve_amp(amp_refs) + all_tools.extend(tools) + self._clients.extend(clients) + + return all_tools + + def cleanup(self) -> None: + """Disconnect all MCP client connections.""" + if not self._clients: + return + + async def _disconnect_all() -> None: + for client in self._clients: + if client and hasattr(client, "connected") and client.connected: + await client.disconnect() + + try: + asyncio.run(_disconnect_all()) + except Exception as e: + self._logger.log("error", f"Error during MCP client cleanup: {e}") + finally: + self._clients.clear() + + @staticmethod + def _parse_amp_ref(mcp_config: str) -> tuple[str, str | None]: + """Parse an AMP reference into *(slug, optional tool name)*. + + Accepts both bare slugs (``"notion"``, ``"notion#search"``) and the + legacy ``"crewai-amp:notion"`` form. + """ + bare = mcp_config.removeprefix("crewai-amp:") + slug, _, specific_tool = bare.partition("#") + return slug, specific_tool or None + + def _resolve_amp( + self, amp_refs: list[tuple[str, str | None]] + ) -> tuple[list[BaseTool], list[Any]]: + """Fetch AMP configs in bulk and return their tools and clients. + + Resolves each unique slug only once (single connection per server), + then applies per-ref tool filters to select specific tools. + """ + from crewai.events.event_bus import crewai_event_bus + from crewai.events.types.mcp_events import MCPConfigFetchFailedEvent + + unique_slugs = list(dict.fromkeys(slug for slug, _ in amp_refs)) + amp_configs_map = self._fetch_amp_mcp_configs(unique_slugs) + + all_tools: list[BaseTool] = [] + all_clients: list[Any] = [] + + resolved_cache: dict[str, tuple[list[BaseTool], Any | None]] = {} + + for slug in unique_slugs: + config_dict = amp_configs_map.get(slug) + if not config_dict: + crewai_event_bus.emit( + self, + MCPConfigFetchFailedEvent( + slug=slug, + error=f"Config for '{slug}' not found. Make sure it is connected in your account.", + error_type="not_connected", + ), + ) + continue + + mcp_server_config = self._build_mcp_config_from_dict(config_dict) + + try: + tools, client = self._resolve_native(mcp_server_config) + resolved_cache[slug] = (tools, client) + if client: + all_clients.append(client) + except Exception as e: + crewai_event_bus.emit( + self, + MCPConfigFetchFailedEvent( + slug=slug, + error=str(e), + error_type="connection_failed", + ), + ) + + for slug, specific_tool in amp_refs: + cached = resolved_cache.get(slug) + if not cached: + continue + + slug_tools, _ = cached + if specific_tool: + all_tools.extend( + t for t in slug_tools if t.name.endswith(f"_{specific_tool}") + ) + else: + all_tools.extend(slug_tools) + + return all_tools, all_clients + + def _fetch_amp_mcp_configs(self, slugs: list[str]) -> dict[str, dict[str, Any]]: + """Fetch MCP server configurations via CrewAI+ API. + + Sends a GET request to the CrewAI+ mcps/configs endpoint with + comma-separated slugs. CrewAI+ proxies the request to crewai-oauth. + + API-level failures return ``{}``; individual slugs will then + surface as ``MCPConfigFetchFailedEvent`` in :meth:`_resolve_amp`. + """ + import httpx + + try: + from crewai_tools.tools.crewai_platform_tools.misc import ( + get_platform_integration_token, + ) + + from crewai.cli.plus_api import PlusAPI + + plus_api = PlusAPI(api_key=get_platform_integration_token()) + response = plus_api.get_mcp_configs(slugs) + + if response.status_code == 200: + configs: dict[str, dict[str, Any]] = response.json().get("configs", {}) + return configs + + self._logger.log( + "debug", + f"Failed to fetch MCP configs: HTTP {response.status_code}", + ) + return {} + + except httpx.HTTPError as e: + self._logger.log("debug", f"Failed to fetch MCP configs: {e}") + return {} + except Exception as e: + self._logger.log("debug", f"Cannot fetch AMP MCP configs: {e}") + return {} + + def _resolve_external(self, mcp_ref: str) -> list[BaseTool]: + """Resolve an HTTPS MCP server URL into tools.""" + from crewai.tools.mcp_tool_wrapper import MCPToolWrapper + + if "#" in mcp_ref: + server_url, specific_tool = mcp_ref.split("#", 1) + else: + server_url, specific_tool = mcp_ref, None + + server_params = {"url": server_url} + server_name = self._extract_server_name(server_url) + + try: + tool_schemas = self._get_mcp_tool_schemas(server_params) + + if not tool_schemas: + self._logger.log( + "warning", f"No tools discovered from MCP server: {server_url}" + ) + return [] + + tools = [] + for tool_name, schema in tool_schemas.items(): + if specific_tool and tool_name != specific_tool: + continue + + try: + wrapper = MCPToolWrapper( + mcp_server_params=server_params, + tool_name=tool_name, + tool_schema=schema, + server_name=server_name, + ) + tools.append(wrapper) + except Exception as e: + self._logger.log( + "warning", + f"Failed to create MCP tool wrapper for {tool_name}: {e}", + ) + continue + + if specific_tool and not tools: + self._logger.log( + "warning", + f"Specific tool '{specific_tool}' not found on MCP server: {server_url}", + ) + + return cast(list[BaseTool], tools) + + except Exception as e: + self._logger.log( + "warning", f"Failed to connect to MCP server {server_url}: {e}" + ) + return [] + + def _resolve_native( + self, mcp_config: MCPServerConfig + ) -> tuple[list[BaseTool], Any | None]: + """Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup.""" + from crewai.tools.base_tool import BaseTool + from crewai.tools.mcp_native_tool import MCPNativeTool + + transport: StdioTransport | HTTPTransport | SSETransport + if isinstance(mcp_config, MCPServerStdio): + transport = StdioTransport( + command=mcp_config.command, + args=mcp_config.args, + env=mcp_config.env, + ) + server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}" + elif isinstance(mcp_config, MCPServerHTTP): + transport = HTTPTransport( + url=mcp_config.url, + headers=mcp_config.headers, + streamable=mcp_config.streamable, + ) + server_name = self._extract_server_name(mcp_config.url) + elif isinstance(mcp_config, MCPServerSSE): + transport = SSETransport( + url=mcp_config.url, + headers=mcp_config.headers, + ) + server_name = self._extract_server_name(mcp_config.url) + else: + raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}") + + client = MCPClient( + transport=transport, + cache_tools_list=mcp_config.cache_tools_list, + ) + + async def _setup_client_and_list_tools() -> list[dict[str, Any]]: + try: + if not client.connected: + await client.connect() + + tools_list = await client.list_tools() + + try: + await client.disconnect() + await asyncio.sleep(0.1) + except Exception as e: + self._logger.log("error", f"Error during disconnect: {e}") + + return tools_list + except Exception as e: + if client.connected: + await client.disconnect() + await asyncio.sleep(0.1) + raise RuntimeError( + f"Error during setup client and list tools: {e}" + ) from e + + try: + try: + asyncio.get_running_loop() + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit( + asyncio.run, _setup_client_and_list_tools() + ) + tools_list = future.result() + except RuntimeError: + try: + tools_list = asyncio.run(_setup_client_and_list_tools()) + except RuntimeError as e: + error_msg = str(e).lower() + if "cancel scope" in error_msg or "task" in error_msg: + raise ConnectionError( + "MCP connection failed due to event loop cleanup issues. " + "This may be due to authentication errors or server unavailability." + ) from e + except asyncio.CancelledError as e: + raise ConnectionError( + "MCP connection was cancelled. This may indicate an authentication " + "error or server unavailability." + ) from e + + if mcp_config.tool_filter: + filtered_tools = [] + for tool in tools_list: + if callable(mcp_config.tool_filter): + try: + from crewai.mcp.filters import ToolFilterContext + + context = ToolFilterContext( + agent=self._agent, + server_name=server_name, + run_context=None, + ) + if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type] + filtered_tools.append(tool) + except (TypeError, AttributeError): + if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type] + filtered_tools.append(tool) + else: + filtered_tools.append(tool) + tools_list = filtered_tools + + tools = [] + for tool_def in tools_list: + tool_name = tool_def.get("name", "") + original_tool_name = tool_def.get("original_name", tool_name) + if not tool_name: + continue + + args_schema = None + if tool_def.get("inputSchema"): + args_schema = self._json_schema_to_pydantic( + tool_name, tool_def["inputSchema"] + ) + + tool_schema = { + "description": tool_def.get("description", ""), + "args_schema": args_schema, + } + + try: + native_tool = MCPNativeTool( + mcp_client=client, + tool_name=tool_name, + tool_schema=tool_schema, + server_name=server_name, + original_tool_name=original_tool_name, + ) + tools.append(native_tool) + except Exception as e: + self._logger.log("error", f"Failed to create native MCP tool: {e}") + continue + + return cast(list[BaseTool], tools), client + except Exception as e: + if client.connected: + asyncio.run(client.disconnect()) + + raise RuntimeError(f"Failed to get native MCP tools: {e}") from e + + @staticmethod + def _build_mcp_config_from_dict( + config_dict: dict[str, Any], + ) -> MCPServerConfig: + """Convert a config dict from crewai-oauth into an MCPServerConfig.""" + config_type = config_dict.get("type", "http") + + if config_type == "sse": + return MCPServerSSE( + url=config_dict["url"], + headers=config_dict.get("headers"), + cache_tools_list=config_dict.get("cache_tools_list", False), + ) + + return MCPServerHTTP( + url=config_dict["url"], + headers=config_dict.get("headers"), + streamable=config_dict.get("streamable", True), + cache_tools_list=config_dict.get("cache_tools_list", False), + ) + + @staticmethod + def _extract_server_name(server_url: str) -> str: + """Extract clean server name from URL for tool prefixing.""" + parsed = urlparse(server_url) + domain = parsed.netloc.replace(".", "_") + path = parsed.path.replace("/", "_").strip("_") + return f"{domain}_{path}" if path else domain + + def _get_mcp_tool_schemas( + self, server_params: dict[str, Any] + ) -> dict[str, dict[str, Any]]: + """Get tool schemas from MCP server with caching.""" + server_url = server_params["url"] + + cache_key = server_url + current_time = time.time() + + if cache_key in _mcp_schema_cache: + cached_data, cache_time = _mcp_schema_cache[cache_key] + if current_time - cache_time < _cache_ttl: + self._logger.log( + "debug", f"Using cached MCP tool schemas for {server_url}" + ) + return cached_data # type: ignore[no-any-return] + + try: + schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params)) + _mcp_schema_cache[cache_key] = (schemas, current_time) + return schemas + except Exception as e: + self._logger.log( + "warning", f"Failed to get MCP tool schemas from {server_url}: {e}" + ) + return {} + + async def _get_mcp_tool_schemas_async( + self, server_params: dict[str, Any] + ) -> dict[str, dict[str, Any]]: + """Async implementation of MCP tool schema retrieval.""" + server_url = server_params["url"] + return await self._retry_mcp_discovery( + self._discover_mcp_tools_with_timeout, server_url + ) + + async def _retry_mcp_discovery( + self, operation_func: Any, server_url: str + ) -> dict[str, dict[str, Any]]: + """Retry MCP discovery with exponential backoff.""" + last_error = None + + for attempt in range(MCP_MAX_RETRIES): + result, error, should_retry = await self._attempt_mcp_discovery( + operation_func, server_url + ) + + if result is not None: + return result + + if not should_retry: + raise RuntimeError(error) + + last_error = error + if attempt < MCP_MAX_RETRIES - 1: + wait_time = 2**attempt + await asyncio.sleep(wait_time) + + raise RuntimeError( + f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}" + ) + + @staticmethod + async def _attempt_mcp_discovery( + operation_func: Any, server_url: str + ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]: + """Attempt single MCP discovery; returns *(result, error_message, should_retry)*.""" + try: + result = await operation_func(server_url) + return result, "", False + + except ImportError: + return ( + None, + "MCP library not available. Please install with: pip install mcp", + False, + ) + + except asyncio.TimeoutError: + return ( + None, + f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds", + True, + ) + + except Exception as e: + error_str = str(e).lower() + + if "authentication" in error_str or "unauthorized" in error_str: + return None, f"Authentication failed for MCP server: {e!s}", False + if "connection" in error_str or "network" in error_str: + return None, f"Network connection failed: {e!s}", True + if "json" in error_str or "parsing" in error_str: + return None, f"Server response parsing error: {e!s}", True + return None, f"MCP discovery error: {e!s}", False + + async def _discover_mcp_tools_with_timeout( + self, server_url: str + ) -> dict[str, dict[str, Any]]: + """Discover MCP tools with timeout wrapper.""" + return await asyncio.wait_for( + self._discover_mcp_tools(server_url), timeout=MCP_DISCOVERY_TIMEOUT + ) + + async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]: + """Discover tools from an MCP server (HTTPS / streamable-HTTP path).""" + from mcp import ClientSession + from mcp.client.streamable_http import streamablehttp_client + + from crewai.utilities.string_utils import sanitize_tool_name + + async with streamablehttp_client(server_url) as (read, write, _): + async with ClientSession(read, write) as session: + await asyncio.wait_for( + session.initialize(), timeout=MCP_CONNECTION_TIMEOUT + ) + + tools_result = await asyncio.wait_for( + session.list_tools(), + timeout=MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT, + ) + + schemas = {} + for tool in tools_result.tools: + args_schema = None + if hasattr(tool, "inputSchema") and tool.inputSchema: + args_schema = self._json_schema_to_pydantic( + sanitize_tool_name(tool.name), tool.inputSchema + ) + + schemas[sanitize_tool_name(tool.name)] = { + "description": getattr(tool, "description", ""), + "args_schema": args_schema, + } + return schemas + + @staticmethod + def _json_schema_to_pydantic(tool_name: str, json_schema: dict[str, Any]) -> type: + """Convert JSON Schema to a Pydantic model for tool arguments.""" + from crewai.utilities.pydantic_schema_utils import create_model_from_schema + + model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema" + return create_model_from_schema( + json_schema, + model_name=model_name, + enrich_descriptions=True, + ) diff --git a/lib/crewai/src/crewai/tools/mcp_native_tool.py b/lib/crewai/src/crewai/tools/mcp_native_tool.py index f25b2f4d7..d14c26a5a 100644 --- a/lib/crewai/src/crewai/tools/mcp_native_tool.py +++ b/lib/crewai/src/crewai/tools/mcp_native_tool.py @@ -27,14 +27,16 @@ class MCPNativeTool(BaseTool): tool_name: str, tool_schema: dict[str, Any], server_name: str, + original_tool_name: str | None = None, ) -> None: """Initialize native MCP tool. Args: mcp_client: MCPClient instance with active session. - tool_name: Original name of the tool on the MCP server. + tool_name: Name of the tool (may be prefixed). tool_schema: Schema information for the tool. server_name: Name of the MCP server for prefixing. + original_tool_name: Original name of the tool on the MCP server. """ # Create tool name with server prefix to avoid conflicts prefixed_name = f"{server_name}_{tool_name}" @@ -57,7 +59,7 @@ class MCPNativeTool(BaseTool): # Set instance attributes after super().__init__ self._mcp_client = mcp_client - self._original_tool_name = tool_name + self._original_tool_name = original_tool_name or tool_name self._server_name = server_name # self._logger = logging.getLogger(__name__) diff --git a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py index 4548ab9ce..87d80da81 100644 --- a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py +++ b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py @@ -491,10 +491,66 @@ FORMAT_TYPE_MAP: dict[str, type[Any]] = { } +def build_rich_field_description(prop_schema: dict[str, Any]) -> str: + """Build a comprehensive field description including constraints. + + Embeds format, enum, pattern, min/max, and example constraints into the + description text so that LLMs can understand tool parameter requirements + without inspecting the raw JSON Schema. + + Args: + prop_schema: Property schema with description and constraints. + + Returns: + Enhanced description with format, enum, and other constraints. + """ + parts: list[str] = [] + + description = prop_schema.get("description", "") + if description: + parts.append(description) + + format_type = prop_schema.get("format") + if format_type: + parts.append(f"Format: {format_type}") + + enum_values = prop_schema.get("enum") + if enum_values: + enum_str = ", ".join(repr(v) for v in enum_values) + parts.append(f"Allowed values: [{enum_str}]") + + pattern = prop_schema.get("pattern") + if pattern: + parts.append(f"Pattern: {pattern}") + + minimum = prop_schema.get("minimum") + maximum = prop_schema.get("maximum") + if minimum is not None: + parts.append(f"Minimum: {minimum}") + if maximum is not None: + parts.append(f"Maximum: {maximum}") + + min_length = prop_schema.get("minLength") + max_length = prop_schema.get("maxLength") + if min_length is not None: + parts.append(f"Min length: {min_length}") + if max_length is not None: + parts.append(f"Max length: {max_length}") + + examples = prop_schema.get("examples") + if examples: + examples_str = ", ".join(repr(e) for e in examples[:3]) + parts.append(f"Examples: {examples_str}") + + return ". ".join(parts) if parts else "" + + def create_model_from_schema( # type: ignore[no-any-unimported] json_schema: dict[str, Any], *, root_schema: dict[str, Any] | None = None, + model_name: str | None = None, + enrich_descriptions: bool = False, __config__: ConfigDict | None = None, __base__: type[BaseModel] | None = None, __module__: str = __name__, @@ -512,6 +568,13 @@ def create_model_from_schema( # type: ignore[no-any-unimported] json_schema: A dictionary representing the JSON schema. root_schema: The root schema containing $defs. If not provided, the current schema is treated as the root schema. + model_name: Override for the model name. If not provided, the schema + ``title`` field is used, falling back to ``"DynamicModel"``. + enrich_descriptions: When True, augment field descriptions with + constraint info (format, enum, pattern, min/max, examples) via + :func:`build_rich_field_description`. Useful for LLM-facing tool + schemas where constraints in the description help the model + understand parameter requirements. __config__: Pydantic configuration for the generated model. __base__: Base class for the generated model. Defaults to BaseModel. __module__: Module name for the generated model class. @@ -548,10 +611,14 @@ def create_model_from_schema( # type: ignore[no-any-unimported] if "title" not in json_schema and "title" in (root_schema or {}): json_schema["title"] = (root_schema or {}).get("title") - model_name = json_schema.get("title") or "DynamicModel" + effective_name = model_name or json_schema.get("title") or "DynamicModel" field_definitions = { name: _json_schema_to_pydantic_field( - name, prop, json_schema.get("required", []), effective_root + name, + prop, + json_schema.get("required", []), + effective_root, + enrich_descriptions=enrich_descriptions, ) for name, prop in (json_schema.get("properties", {}) or {}).items() } @@ -559,7 +626,7 @@ def create_model_from_schema( # type: ignore[no-any-unimported] effective_config = __config__ or ConfigDict(extra="forbid") return create_model_base( - model_name, + effective_name, __config__=effective_config, __base__=__base__, __module__=__module__, @@ -574,6 +641,8 @@ def _json_schema_to_pydantic_field( json_schema: dict[str, Any], required: list[str], root_schema: dict[str, Any], + *, + enrich_descriptions: bool = False, ) -> Any: """Convert a JSON schema property to a Pydantic field definition. @@ -582,20 +651,29 @@ def _json_schema_to_pydantic_field( json_schema: The JSON schema for this field. required: List of required field names. root_schema: The root schema for resolving $ref. + enrich_descriptions: When True, embed constraints in the description. Returns: A tuple of (type, Field) for use with create_model. """ - type_ = _json_schema_to_pydantic_type(json_schema, root_schema, name_=name.title()) - description = json_schema.get("description") - examples = json_schema.get("examples") + type_ = _json_schema_to_pydantic_type( + json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions + ) is_required = name in required field_params: dict[str, Any] = {} schema_extra: dict[str, Any] = {} - if description: - field_params["description"] = description + if enrich_descriptions: + rich_desc = build_rich_field_description(json_schema) + if rich_desc: + field_params["description"] = rich_desc + else: + description = json_schema.get("description") + if description: + field_params["description"] = description + + examples = json_schema.get("examples") if examples: schema_extra["examples"] = examples @@ -711,6 +789,7 @@ def _json_schema_to_pydantic_type( root_schema: dict[str, Any], *, name_: str | None = None, + enrich_descriptions: bool = False, ) -> Any: """Convert a JSON schema to a Python/Pydantic type. @@ -718,6 +797,7 @@ def _json_schema_to_pydantic_type( json_schema: The JSON schema to convert. root_schema: The root schema for resolving $ref. name_: Optional name for nested models. + enrich_descriptions: Propagated to nested model creation. Returns: A Python type corresponding to the JSON schema. @@ -725,7 +805,9 @@ def _json_schema_to_pydantic_type( ref = json_schema.get("$ref") if ref: ref_schema = _resolve_ref(ref, root_schema) - return _json_schema_to_pydantic_type(ref_schema, root_schema, name_=name_) + return _json_schema_to_pydantic_type( + ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions + ) enum_values = json_schema.get("enum") if enum_values: @@ -740,7 +822,10 @@ def _json_schema_to_pydantic_type( if any_of_schemas: any_of_types = [ _json_schema_to_pydantic_type( - schema, root_schema, name_=f"{name_ or 'Union'}Option{i}" + schema, + root_schema, + name_=f"{name_ or 'Union'}Option{i}", + enrich_descriptions=enrich_descriptions, ) for i, schema in enumerate(any_of_schemas) ] @@ -750,10 +835,14 @@ def _json_schema_to_pydantic_type( if all_of_schemas: if len(all_of_schemas) == 1: return _json_schema_to_pydantic_type( - all_of_schemas[0], root_schema, name_=name_ + all_of_schemas[0], root_schema, name_=name_, + enrich_descriptions=enrich_descriptions, ) merged = _merge_all_of_schemas(all_of_schemas, root_schema) - return _json_schema_to_pydantic_type(merged, root_schema, name_=name_) + return _json_schema_to_pydantic_type( + merged, root_schema, name_=name_, + enrich_descriptions=enrich_descriptions, + ) type_ = json_schema.get("type") @@ -769,7 +858,8 @@ def _json_schema_to_pydantic_type( items_schema = json_schema.get("items") if items_schema: item_type = _json_schema_to_pydantic_type( - items_schema, root_schema, name_=name_ + items_schema, root_schema, name_=name_, + enrich_descriptions=enrich_descriptions, ) return list[item_type] # type: ignore[valid-type] return list @@ -779,7 +869,10 @@ def _json_schema_to_pydantic_type( json_schema_ = json_schema.copy() if json_schema_.get("title") is None: json_schema_["title"] = name_ or "DynamicModel" - return create_model_from_schema(json_schema_, root_schema=root_schema) + return create_model_from_schema( + json_schema_, root_schema=root_schema, + enrich_descriptions=enrich_descriptions, + ) return dict if type_ == "null": return None diff --git a/lib/crewai/tests/agents/test_lite_agent.py b/lib/crewai/tests/agents/test_lite_agent.py index c99b5c534..ac03ffc28 100644 --- a/lib/crewai/tests/agents/test_lite_agent.py +++ b/lib/crewai/tests/agents/test_lite_agent.py @@ -659,7 +659,7 @@ def test_agent_kickoff_with_platform_tools(mock_get, mock_post): @patch.dict("os.environ", {"EXA_API_KEY": "test_exa_key"}) -@patch("crewai.agent.Agent._get_external_mcp_tools") +@patch("crewai.agent.Agent.get_mcp_tools") @pytest.mark.vcr() def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools): """Test that Agent.kickoff() properly integrates MCP tools with LiteAgent""" @@ -691,7 +691,7 @@ def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools): assert result.raw is not None # Verify MCP tools were retrieved - mock_get_mcp_tools.assert_called_once_with("https://mcp.exa.ai/mcp?api_key=test_exa_key&profile=research") + mock_get_mcp_tools.assert_called_once_with(["https://mcp.exa.ai/mcp?api_key=test_exa_key&profile=research"]) # ============================================================================ diff --git a/lib/crewai/tests/mcp/test_amp_mcp.py b/lib/crewai/tests/mcp/test_amp_mcp.py new file mode 100644 index 000000000..3c4001f3c --- /dev/null +++ b/lib/crewai/tests/mcp/test_amp_mcp.py @@ -0,0 +1,373 @@ +"""Tests for AMP MCP config fetching and tool resolution.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from crewai.agent.core import Agent +from crewai.mcp.config import MCPServerHTTP, MCPServerSSE +from crewai.mcp.tool_resolver import MCPToolResolver +from crewai.tools.base_tool import BaseTool + + +@pytest.fixture +def agent(): + return Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + ) + + +@pytest.fixture +def resolver(agent): + return MCPToolResolver(agent=agent, logger=agent._logger) + + +@pytest.fixture +def mock_tool_definitions(): + return [ + { + "name": "search", + "description": "Search tool", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"], + }, + }, + { + "name": "create_page", + "description": "Create a page", + "inputSchema": {}, + }, + ] + + +class TestBuildMCPConfigFromDict: + def test_builds_http_config(self): + config_dict = { + "type": "http", + "url": "https://mcp.example.com/api", + "headers": {"Authorization": "Bearer token123"}, + "streamable": True, + "cache_tools_list": False, + } + + result = MCPToolResolver._build_mcp_config_from_dict(config_dict) + + assert isinstance(result, MCPServerHTTP) + assert result.url == "https://mcp.example.com/api" + assert result.headers == {"Authorization": "Bearer token123"} + assert result.streamable is True + assert result.cache_tools_list is False + + def test_builds_sse_config(self): + config_dict = { + "type": "sse", + "url": "https://mcp.example.com/sse", + "headers": {"Authorization": "Bearer token123"}, + "cache_tools_list": True, + } + + result = MCPToolResolver._build_mcp_config_from_dict(config_dict) + + assert isinstance(result, MCPServerSSE) + assert result.url == "https://mcp.example.com/sse" + assert result.headers == {"Authorization": "Bearer token123"} + assert result.cache_tools_list is True + + def test_defaults_to_http(self): + config_dict = { + "url": "https://mcp.example.com/api", + } + + result = MCPToolResolver._build_mcp_config_from_dict(config_dict) + + assert isinstance(result, MCPServerHTTP) + assert result.streamable is True + + def test_http_defaults(self): + config_dict = { + "type": "http", + "url": "https://mcp.example.com/api", + } + + result = MCPToolResolver._build_mcp_config_from_dict(config_dict) + + assert result.headers is None + assert result.streamable is True + assert result.cache_tools_list is False + + +class TestFetchAmpMCPConfigs: + @patch("crewai.cli.plus_api.PlusAPI") + @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key") + def test_fetches_configs_successfully(self, mock_get_token, mock_plus_api_class, resolver): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "configs": { + "notion": { + "type": "sse", + "url": "https://mcp.notion.so/sse", + "headers": {"Authorization": "Bearer notion-token"}, + }, + "github": { + "type": "http", + "url": "https://mcp.github.com/api", + "headers": {"Authorization": "Bearer gh-token"}, + }, + }, + } + mock_plus_api = MagicMock() + mock_plus_api.get_mcp_configs.return_value = mock_response + mock_plus_api_class.return_value = mock_plus_api + + result = resolver._fetch_amp_mcp_configs(["notion", "github"]) + + assert "notion" in result + assert "github" in result + assert result["notion"]["url"] == "https://mcp.notion.so/sse" + mock_plus_api_class.assert_called_once_with(api_key="test-api-key") + mock_plus_api.get_mcp_configs.assert_called_once_with(["notion", "github"]) + + @patch("crewai.cli.plus_api.PlusAPI") + @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key") + def test_omits_missing_slugs(self, mock_get_token, mock_plus_api_class, resolver): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "configs": {"notion": {"type": "sse", "url": "https://mcp.notion.so/sse"}}, + } + mock_plus_api = MagicMock() + mock_plus_api.get_mcp_configs.return_value = mock_response + mock_plus_api_class.return_value = mock_plus_api + + result = resolver._fetch_amp_mcp_configs(["notion", "missing-server"]) + + assert "notion" in result + assert "missing-server" not in result + + @patch("crewai.cli.plus_api.PlusAPI") + @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key") + def test_returns_empty_on_http_error(self, mock_get_token, mock_plus_api_class, resolver): + mock_response = MagicMock() + mock_response.status_code = 500 + mock_plus_api = MagicMock() + mock_plus_api.get_mcp_configs.return_value = mock_response + mock_plus_api_class.return_value = mock_plus_api + + result = resolver._fetch_amp_mcp_configs(["notion"]) + + assert result == {} + + @patch("crewai.cli.plus_api.PlusAPI") + @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key") + def test_returns_empty_on_network_error(self, mock_get_token, mock_plus_api_class, resolver): + import httpx + + mock_plus_api = MagicMock() + mock_plus_api.get_mcp_configs.side_effect = httpx.ConnectError("Connection refused") + mock_plus_api_class.return_value = mock_plus_api + + result = resolver._fetch_amp_mcp_configs(["notion"]) + + assert result == {} + + @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", side_effect=Exception("No token")) + def test_returns_empty_when_no_token(self, mock_get_token, resolver): + result = resolver._fetch_amp_mcp_configs(["notion"]) + + assert result == {} + + +class TestParseAmpRef: + def test_bare_slug(self): + slug, tool = MCPToolResolver._parse_amp_ref("notion") + assert slug == "notion" + assert tool is None + + def test_bare_slug_with_tool(self): + slug, tool = MCPToolResolver._parse_amp_ref("notion#search") + assert slug == "notion" + assert tool == "search" + + def test_bare_slug_with_empty_tool(self): + slug, tool = MCPToolResolver._parse_amp_ref("notion#") + assert slug == "notion" + assert tool is None + + def test_legacy_prefix_slug(self): + slug, tool = MCPToolResolver._parse_amp_ref("crewai-amp:notion") + assert slug == "notion" + assert tool is None + + def test_legacy_prefix_with_tool(self): + slug, tool = MCPToolResolver._parse_amp_ref("crewai-amp:notion#search") + assert slug == "notion" + assert tool == "search" + + +class TestGetMCPToolsAmpIntegration: + @patch("crewai.mcp.tool_resolver.MCPClient") + @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") + def test_single_request_for_multiple_amp_refs( + self, mock_fetch, mock_client_class, agent, mock_tool_definitions + ): + mock_fetch.return_value = { + "notion": { + "type": "sse", + "url": "https://mcp.notion.so/sse", + "headers": {"Authorization": "Bearer token"}, + }, + "github": { + "type": "http", + "url": "https://mcp.github.com/api", + "headers": {"Authorization": "Bearer gh-token"}, + "streamable": True, + }, + } + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools(["notion", "github"]) + + mock_fetch.assert_called_once_with(["notion", "github"]) + assert len(tools) == 4 # 2 tools per server + + @patch("crewai.mcp.tool_resolver.MCPClient") + @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") + def test_tool_filter_with_hash_syntax( + self, mock_fetch, mock_client_class, agent, mock_tool_definitions + ): + mock_fetch.return_value = { + "notion": { + "type": "sse", + "url": "https://mcp.notion.so/sse", + "headers": {"Authorization": "Bearer token"}, + }, + } + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools(["notion#search"]) + + mock_fetch.assert_called_once_with(["notion"]) + assert len(tools) == 1 + assert tools[0].name == "mcp_notion_so_sse_search" + + @patch("crewai.mcp.tool_resolver.MCPClient") + @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") + def test_deduplicates_slugs( + self, mock_fetch, mock_client_class, agent, mock_tool_definitions + ): + mock_fetch.return_value = { + "notion": { + "type": "sse", + "url": "https://mcp.notion.so/sse", + "headers": {"Authorization": "Bearer token"}, + }, + } + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools(["notion#search", "notion#create_page"]) + + mock_fetch.assert_called_once_with(["notion"]) + assert len(tools) == 2 + + @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") + def test_skips_missing_configs_gracefully(self, mock_fetch, agent): + mock_fetch.return_value = {} + + tools = agent.get_mcp_tools(["missing-server"]) + + assert tools == [] + + @patch("crewai.mcp.tool_resolver.MCPClient") + @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") + def test_legacy_crewai_amp_prefix_still_works( + self, mock_fetch, mock_client_class, agent, mock_tool_definitions + ): + mock_fetch.return_value = { + "notion": { + "type": "sse", + "url": "https://mcp.notion.so/sse", + "headers": {"Authorization": "Bearer token"}, + }, + } + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools(["crewai-amp:notion"]) + + mock_fetch.assert_called_once_with(["notion"]) + assert len(tools) == 2 + + @patch("crewai.mcp.tool_resolver.MCPClient") + @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") + @patch.object(MCPToolResolver, "_resolve_external") + def test_non_amp_items_unaffected( + self, + mock_external, + mock_fetch, + mock_client_class, + agent, + mock_tool_definitions, + ): + mock_fetch.return_value = { + "notion": { + "type": "sse", + "url": "https://mcp.notion.so/sse", + }, + } + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + mock_external_tool = MagicMock(spec=BaseTool) + mock_external.return_value = [mock_external_tool] + + http_config = MCPServerHTTP( + url="https://other.mcp.com/api", + headers={"Authorization": "Bearer other"}, + ) + + tools = agent.get_mcp_tools( + [ + "notion", + "https://external.mcp.com/api", + http_config, + ] + ) + + mock_fetch.assert_called_once_with(["notion"]) + mock_external.assert_called_once_with("https://external.mcp.com/api") + # 2 from notion + 1 from external + 2 from http_config + assert len(tools) == 5 diff --git a/lib/crewai/tests/mcp/test_mcp_config.py b/lib/crewai/tests/mcp/test_mcp_config.py index e55a7d504..24fc59769 100644 --- a/lib/crewai/tests/mcp/test_mcp_config.py +++ b/lib/crewai/tests/mcp/test_mcp_config.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest from crewai.agent.core import Agent @@ -46,7 +46,7 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions): ) - with patch("crewai.agent.core.MCPClient") as mock_client_class: + with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: mock_client = AsyncMock() mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) mock_client.connected = False # Will trigger connect @@ -82,7 +82,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions): mcps=[http_config], ) - with patch("crewai.agent.core.MCPClient") as mock_client_class: + with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: mock_client = AsyncMock() mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) mock_client.connected = False # Will trigger connect @@ -117,7 +117,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions): mcps=[sse_config], ) - with patch("crewai.agent.core.MCPClient") as mock_client_class: + with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: mock_client = AsyncMock() mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) mock_client.connected = False @@ -141,7 +141,7 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions): """Test MCPNativeTool execution in synchronous context (normal crew execution).""" http_config = MCPServerHTTP(url="https://api.example.com/mcp") - with patch("crewai.agent.core.MCPClient") as mock_client_class: + with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: mock_client = AsyncMock() mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) mock_client.connected = False @@ -173,7 +173,7 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions): """Test MCPNativeTool execution in async context (e.g., from a Flow).""" http_config = MCPServerHTTP(url="https://api.example.com/mcp") - with patch("crewai.agent.core.MCPClient") as mock_client_class: + with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: mock_client = AsyncMock() mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) mock_client.connected = False diff --git a/lib/crewai/tests/utilities/test_pydantic_schema_utils.py b/lib/crewai/tests/utilities/test_pydantic_schema_utils.py new file mode 100644 index 000000000..98a5e6aa5 --- /dev/null +++ b/lib/crewai/tests/utilities/test_pydantic_schema_utils.py @@ -0,0 +1,884 @@ +"""Tests for pydantic_schema_utils module. + +Covers: +- create_model_from_schema: type mapping, required/optional, enums, formats, + nested objects, arrays, unions, allOf, $ref, model_name, enrich_descriptions +- Schema transformation helpers: resolve_refs, force_additional_properties_false, + strip_unsupported_formats, ensure_type_in_schemas, convert_oneof_to_anyof, + ensure_all_properties_required, strip_null_from_types, build_rich_field_description +- End-to-end MCP tool schema conversion +""" + +from __future__ import annotations + +import datetime +from copy import deepcopy +from typing import Any + +import pytest +from pydantic import BaseModel + +from crewai.utilities.pydantic_schema_utils import ( + build_rich_field_description, + convert_oneof_to_anyof, + create_model_from_schema, + ensure_all_properties_required, + ensure_type_in_schemas, + force_additional_properties_false, + resolve_refs, + strip_null_from_types, + strip_unsupported_formats, +) + + +class TestSimpleTypes: + def test_string_field(self) -> None: + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + Model = create_model_from_schema(schema) + obj = Model(name="Alice") + assert obj.name == "Alice" + + def test_integer_field(self) -> None: + schema = { + "type": "object", + "properties": {"count": {"type": "integer"}}, + "required": ["count"], + } + Model = create_model_from_schema(schema) + obj = Model(count=42) + assert obj.count == 42 + + def test_number_field(self) -> None: + schema = { + "type": "object", + "properties": {"score": {"type": "number"}}, + "required": ["score"], + } + Model = create_model_from_schema(schema) + obj = Model(score=3.14) + assert obj.score == pytest.approx(3.14) + + def test_boolean_field(self) -> None: + schema = { + "type": "object", + "properties": {"active": {"type": "boolean"}}, + "required": ["active"], + } + Model = create_model_from_schema(schema) + assert Model(active=True).active is True + + def test_null_field(self) -> None: + schema = { + "type": "object", + "properties": {"value": {"type": "null"}}, + "required": ["value"], + } + Model = create_model_from_schema(schema) + obj = Model(value=None) + assert obj.value is None + + +class TestRequiredOptional: + def test_required_field_has_no_default(self) -> None: + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + Model = create_model_from_schema(schema) + with pytest.raises(Exception): + Model() + + def test_optional_field_defaults_to_none(self) -> None: + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": [], + } + Model = create_model_from_schema(schema) + obj = Model() + assert obj.name is None + + def test_mixed_required_optional(self) -> None: + schema = { + "type": "object", + "properties": { + "id": {"type": "integer"}, + "label": {"type": "string"}, + }, + "required": ["id"], + } + Model = create_model_from_schema(schema) + obj = Model(id=1) + assert obj.id == 1 + assert obj.label is None + + +class TestEnumLiteral: + def test_string_enum(self) -> None: + schema = { + "type": "object", + "properties": { + "color": {"type": "string", "enum": ["red", "green", "blue"]}, + }, + "required": ["color"], + } + Model = create_model_from_schema(schema) + obj = Model(color="red") + assert obj.color == "red" + + def test_string_enum_rejects_invalid(self) -> None: + schema = { + "type": "object", + "properties": { + "color": {"type": "string", "enum": ["red", "green", "blue"]}, + }, + "required": ["color"], + } + Model = create_model_from_schema(schema) + with pytest.raises(Exception): + Model(color="yellow") + + def test_const_value(self) -> None: + schema = { + "type": "object", + "properties": { + "kind": {"const": "fixed"}, + }, + "required": ["kind"], + } + Model = create_model_from_schema(schema) + obj = Model(kind="fixed") + assert obj.kind == "fixed" + + +class TestFormatMapping: + def test_date_format(self) -> None: + schema = { + "type": "object", + "properties": { + "birthday": {"type": "string", "format": "date"}, + }, + "required": ["birthday"], + } + Model = create_model_from_schema(schema) + obj = Model(birthday=datetime.date(2000, 1, 15)) + assert obj.birthday == datetime.date(2000, 1, 15) + + def test_datetime_format(self) -> None: + schema = { + "type": "object", + "properties": { + "created_at": {"type": "string", "format": "date-time"}, + }, + "required": ["created_at"], + } + Model = create_model_from_schema(schema) + dt = datetime.datetime(2025, 6, 1, 12, 0, 0) + obj = Model(created_at=dt) + assert obj.created_at == dt + + def test_time_format(self) -> None: + schema = { + "type": "object", + "properties": { + "alarm": {"type": "string", "format": "time"}, + }, + "required": ["alarm"], + } + Model = create_model_from_schema(schema) + t = datetime.time(8, 30) + obj = Model(alarm=t) + assert obj.alarm == t + + +class TestNestedObjects: + def test_nested_object_creates_model(self) -> None: + schema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"}, + }, + "required": ["street", "city"], + }, + }, + "required": ["address"], + } + Model = create_model_from_schema(schema) + obj = Model(address={"street": "123 Main", "city": "Springfield"}) + assert obj.address.street == "123 Main" + assert obj.address.city == "Springfield" + + def test_object_without_properties_returns_dict(self) -> None: + schema = { + "type": "object", + "properties": { + "metadata": {"type": "object"}, + }, + "required": ["metadata"], + } + Model = create_model_from_schema(schema) + obj = Model(metadata={"key": "value"}) + assert obj.metadata == {"key": "value"} + + +class TestTypedArrays: + def test_array_of_strings(self) -> None: + schema = { + "type": "object", + "properties": { + "tags": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["tags"], + } + Model = create_model_from_schema(schema) + obj = Model(tags=["a", "b", "c"]) + assert obj.tags == ["a", "b", "c"] + + def test_array_of_objects(self) -> None: + schema = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": {"id": {"type": "integer"}}, + "required": ["id"], + }, + }, + }, + "required": ["items"], + } + Model = create_model_from_schema(schema) + obj = Model(items=[{"id": 1}, {"id": 2}]) + assert len(obj.items) == 2 + assert obj.items[0].id == 1 + + def test_untyped_array(self) -> None: + schema = { + "type": "object", + "properties": {"data": {"type": "array"}}, + "required": ["data"], + } + Model = create_model_from_schema(schema) + obj = Model(data=[1, "two", 3.0]) + assert obj.data == [1, "two", 3.0] + + +class TestUnionTypes: + def test_anyof_string_or_integer(self) -> None: + schema = { + "type": "object", + "properties": { + "value": { + "anyOf": [{"type": "string"}, {"type": "integer"}], + }, + }, + "required": ["value"], + } + Model = create_model_from_schema(schema) + assert Model(value="hello").value == "hello" + assert Model(value=42).value == 42 + + def test_oneof(self) -> None: + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [{"type": "string"}, {"type": "number"}], + }, + }, + "required": ["value"], + } + Model = create_model_from_schema(schema) + assert Model(value="hello").value == "hello" + assert Model(value=3.14).value == pytest.approx(3.14) + + +class TestAllOfMerging: + def test_allof_merges_properties(self) -> None: + schema = { + "type": "object", + "allOf": [ + { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + { + "type": "object", + "properties": {"age": {"type": "integer"}}, + "required": ["age"], + }, + ], + } + Model = create_model_from_schema(schema) + obj = Model(name="Alice", age=30) + assert obj.name == "Alice" + assert obj.age == 30 + + def test_single_allof(self) -> None: + schema = { + "type": "object", + "properties": { + "item": { + "allOf": [ + { + "type": "object", + "properties": {"id": {"type": "integer"}}, + "required": ["id"], + } + ] + } + }, + "required": ["item"], + } + Model = create_model_from_schema(schema) + obj = Model(item={"id": 1}) + assert obj.item.id == 1 + + +# --------------------------------------------------------------------------- +# $ref resolution +# --------------------------------------------------------------------------- + + +class TestRefResolution: + def test_ref_in_property(self) -> None: + schema = { + "type": "object", + "properties": { + "item": {"$ref": "#/$defs/Item"}, + }, + "required": ["item"], + "$defs": { + "Item": { + "type": "object", + "title": "Item", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + }, + } + Model = create_model_from_schema(schema) + obj = Model(item={"name": "Widget"}) + assert obj.item.name == "Widget" + + +# --------------------------------------------------------------------------- +# model_name parameter +# --------------------------------------------------------------------------- + + +class TestModelName: + def test_model_name_override(self) -> None: + schema = { + "type": "object", + "title": "OriginalName", + "properties": {"x": {"type": "integer"}}, + "required": ["x"], + } + Model = create_model_from_schema(schema, model_name="CustomSchema") + assert Model.__name__ == "CustomSchema" + + def test_model_name_fallback_to_title(self) -> None: + schema = { + "type": "object", + "title": "FromTitle", + "properties": {"x": {"type": "integer"}}, + "required": ["x"], + } + Model = create_model_from_schema(schema) + assert Model.__name__ == "FromTitle" + + def test_model_name_fallback_to_dynamic(self) -> None: + schema = { + "type": "object", + "properties": {"x": {"type": "integer"}}, + "required": ["x"], + } + Model = create_model_from_schema(schema) + assert Model.__name__ == "DynamicModel" + + +# --------------------------------------------------------------------------- +# enrich_descriptions +# --------------------------------------------------------------------------- + + +class TestEnrichDescriptions: + def test_enriched_description_includes_constraints(self) -> None: + schema = { + "type": "object", + "properties": { + "score": { + "type": "integer", + "description": "The score value", + "minimum": 0, + "maximum": 100, + }, + }, + "required": ["score"], + } + Model = create_model_from_schema(schema, enrich_descriptions=True) + field_info = Model.model_fields["score"] + assert "Minimum: 0" in field_info.description + assert "Maximum: 100" in field_info.description + assert "The score value" in field_info.description + + def test_default_does_not_enrich(self) -> None: + schema = { + "type": "object", + "properties": { + "score": { + "type": "integer", + "description": "The score value", + "minimum": 0, + }, + }, + "required": ["score"], + } + Model = create_model_from_schema(schema, enrich_descriptions=False) + field_info = Model.model_fields["score"] + assert field_info.description == "The score value" + + def test_enriched_description_propagates_to_nested(self) -> None: + schema = { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": { + "level": { + "type": "integer", + "description": "Level", + "minimum": 1, + "maximum": 10, + }, + }, + "required": ["level"], + }, + }, + "required": ["config"], + } + Model = create_model_from_schema(schema, enrich_descriptions=True) + nested_model = Model.model_fields["config"].annotation + nested_field = nested_model.model_fields["level"] + assert "Minimum: 1" in nested_field.description + assert "Maximum: 10" in nested_field.description + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_empty_properties(self) -> None: + schema = {"type": "object", "properties": {}, "required": []} + Model = create_model_from_schema(schema) + obj = Model() + assert obj is not None + + def test_no_properties_key(self) -> None: + schema = {"type": "object"} + Model = create_model_from_schema(schema) + obj = Model() + assert obj is not None + + def test_unknown_type_raises(self) -> None: + schema = { + "type": "object", + "properties": { + "weird": {"type": "hyperspace"}, + }, + "required": ["weird"], + } + with pytest.raises(ValueError, match="Unsupported JSON schema type"): + create_model_from_schema(schema) + + +# --------------------------------------------------------------------------- +# build_rich_field_description +# --------------------------------------------------------------------------- + + +class TestBuildRichFieldDescription: + def test_description_only(self) -> None: + assert build_rich_field_description({"description": "A name"}) == "A name" + + def test_empty_schema(self) -> None: + assert build_rich_field_description({}) == "" + + def test_format(self) -> None: + desc = build_rich_field_description({"format": "date-time"}) + assert "Format: date-time" in desc + + def test_enum(self) -> None: + desc = build_rich_field_description({"enum": ["a", "b"]}) + assert "Allowed values:" in desc + assert "'a'" in desc + assert "'b'" in desc + + def test_pattern(self) -> None: + desc = build_rich_field_description({"pattern": "^[a-z]+$"}) + assert "Pattern: ^[a-z]+$" in desc + + def test_min_max(self) -> None: + desc = build_rich_field_description({"minimum": 0, "maximum": 100}) + assert "Minimum: 0" in desc + assert "Maximum: 100" in desc + + def test_min_max_length(self) -> None: + desc = build_rich_field_description({"minLength": 1, "maxLength": 255}) + assert "Min length: 1" in desc + assert "Max length: 255" in desc + + def test_examples(self) -> None: + desc = build_rich_field_description({"examples": ["foo", "bar", "baz", "extra"]}) + assert "Examples:" in desc + assert "'foo'" in desc + assert "'baz'" in desc + # Only first 3 shown + assert "'extra'" not in desc + + def test_combined_constraints(self) -> None: + desc = build_rich_field_description({ + "description": "A score", + "minimum": 0, + "maximum": 10, + "format": "int32", + }) + assert desc.startswith("A score") + assert "Minimum: 0" in desc + assert "Maximum: 10" in desc + assert "Format: int32" in desc + + +# --------------------------------------------------------------------------- +# Schema transformation functions +# --------------------------------------------------------------------------- + + +class TestResolveRefs: + def test_basic_ref_resolution(self) -> None: + schema = { + "type": "object", + "properties": {"item": {"$ref": "#/$defs/Item"}}, + "$defs": { + "Item": {"type": "object", "properties": {"id": {"type": "integer"}}}, + }, + } + resolved = resolve_refs(schema) + assert "$ref" not in resolved["properties"]["item"] + assert resolved["properties"]["item"]["type"] == "object" + + def test_nested_ref_resolution(self) -> None: + schema = { + "type": "object", + "properties": {"wrapper": {"$ref": "#/$defs/Wrapper"}}, + "$defs": { + "Wrapper": { + "type": "object", + "properties": {"inner": {"$ref": "#/$defs/Inner"}}, + }, + "Inner": {"type": "string"}, + }, + } + resolved = resolve_refs(schema) + wrapper = resolved["properties"]["wrapper"] + assert wrapper["properties"]["inner"]["type"] == "string" + + def test_missing_ref_raises(self) -> None: + schema = { + "properties": {"x": {"$ref": "#/$defs/Missing"}}, + "$defs": {}, + } + with pytest.raises(KeyError, match="Missing"): + resolve_refs(schema) + + def test_no_refs_unchanged(self) -> None: + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + resolved = resolve_refs(schema) + assert resolved == schema + + +class TestForceAdditionalPropertiesFalse: + def test_adds_to_object(self) -> None: + schema = {"type": "object", "properties": {"x": {"type": "integer"}}} + result = force_additional_properties_false(deepcopy(schema)) + assert result["additionalProperties"] is False + + def test_adds_empty_properties_and_required(self) -> None: + schema = {"type": "object"} + result = force_additional_properties_false(deepcopy(schema)) + assert result["properties"] == {} + assert result["required"] == [] + + def test_recursive_nested_objects(self) -> None: + schema = { + "type": "object", + "properties": { + "child": { + "type": "object", + "properties": {"id": {"type": "integer"}}, + }, + }, + } + result = force_additional_properties_false(deepcopy(schema)) + assert result["additionalProperties"] is False + assert result["properties"]["child"]["additionalProperties"] is False + + def test_does_not_affect_non_objects(self) -> None: + schema = {"type": "string"} + result = force_additional_properties_false(deepcopy(schema)) + assert "additionalProperties" not in result + + +class TestStripUnsupportedFormats: + def test_removes_email_format(self) -> None: + schema = {"type": "string", "format": "email"} + result = strip_unsupported_formats(deepcopy(schema)) + assert "format" not in result + + def test_keeps_date_time(self) -> None: + schema = {"type": "string", "format": "date-time"} + result = strip_unsupported_formats(deepcopy(schema)) + assert result["format"] == "date-time" + + def test_keeps_date(self) -> None: + schema = {"type": "string", "format": "date"} + result = strip_unsupported_formats(deepcopy(schema)) + assert result["format"] == "date" + + def test_removes_uri_format(self) -> None: + schema = {"type": "string", "format": "uri"} + result = strip_unsupported_formats(deepcopy(schema)) + assert "format" not in result + + def test_recursive(self) -> None: + schema = { + "type": "object", + "properties": { + "email": {"type": "string", "format": "email"}, + "created": {"type": "string", "format": "date-time"}, + }, + } + result = strip_unsupported_formats(deepcopy(schema)) + assert "format" not in result["properties"]["email"] + assert result["properties"]["created"]["format"] == "date-time" + + +class TestEnsureTypeInSchemas: + def test_empty_schema_in_anyof_gets_type(self) -> None: + schema = {"anyOf": [{}, {"type": "string"}]} + result = ensure_type_in_schemas(deepcopy(schema)) + assert result["anyOf"][0] == {"type": "object"} + + def test_empty_schema_in_oneof_gets_type(self) -> None: + schema = {"oneOf": [{}, {"type": "integer"}]} + result = ensure_type_in_schemas(deepcopy(schema)) + assert result["oneOf"][0] == {"type": "object"} + + def test_non_empty_unchanged(self) -> None: + schema = {"anyOf": [{"type": "string"}, {"type": "integer"}]} + result = ensure_type_in_schemas(deepcopy(schema)) + assert result == schema + + +class TestConvertOneofToAnyof: + def test_converts_top_level(self) -> None: + schema = {"oneOf": [{"type": "string"}, {"type": "integer"}]} + result = convert_oneof_to_anyof(deepcopy(schema)) + assert "oneOf" not in result + assert "anyOf" in result + assert len(result["anyOf"]) == 2 + + def test_converts_nested(self) -> None: + schema = { + "type": "object", + "properties": { + "value": {"oneOf": [{"type": "string"}, {"type": "number"}]}, + }, + } + result = convert_oneof_to_anyof(deepcopy(schema)) + assert "anyOf" in result["properties"]["value"] + assert "oneOf" not in result["properties"]["value"] + + +class TestEnsureAllPropertiesRequired: + def test_makes_all_required(self) -> None: + schema = { + "type": "object", + "properties": {"a": {"type": "string"}, "b": {"type": "integer"}}, + "required": ["a"], + } + result = ensure_all_properties_required(deepcopy(schema)) + assert set(result["required"]) == {"a", "b"} + + def test_recursive(self) -> None: + schema = { + "type": "object", + "properties": { + "child": { + "type": "object", + "properties": {"x": {"type": "integer"}, "y": {"type": "integer"}}, + "required": [], + }, + }, + } + result = ensure_all_properties_required(deepcopy(schema)) + assert set(result["properties"]["child"]["required"]) == {"x", "y"} + + +class TestStripNullFromTypes: + def test_strips_null_from_anyof(self) -> None: + schema = { + "anyOf": [{"type": "string"}, {"type": "null"}], + } + result = strip_null_from_types(deepcopy(schema)) + assert "anyOf" not in result + assert result["type"] == "string" + + def test_strips_null_from_type_array(self) -> None: + schema = {"type": ["string", "null"]} + result = strip_null_from_types(deepcopy(schema)) + assert result["type"] == "string" + + def test_multiple_non_null_in_anyof(self) -> None: + schema = { + "anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "null"}], + } + result = strip_null_from_types(deepcopy(schema)) + assert len(result["anyOf"]) == 2 + + def test_no_null_unchanged(self) -> None: + schema = {"type": "string"} + result = strip_null_from_types(deepcopy(schema)) + assert result == schema + + +class TestEndToEndMCPSchema: + """Realistic MCP tool schema exercising multiple features simultaneously.""" + + MCP_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query", + "minLength": 1, + "maxLength": 500, + }, + "max_results": { + "type": "integer", + "description": "Maximum results", + "minimum": 1, + "maximum": 100, + }, + "format": { + "type": "string", + "enum": ["json", "csv", "xml"], + "description": "Output format", + }, + "filters": { + "type": "object", + "properties": { + "date_from": {"type": "string", "format": "date"}, + "date_to": {"type": "string", "format": "date"}, + "categories": { + "type": "array", + "items": {"type": "string"}, + }, + }, + "required": ["date_from"], + }, + "sort_order": { + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + }, + "required": ["query", "format", "filters"], + } + + def test_model_creation(self) -> None: + Model = create_model_from_schema(self.MCP_SCHEMA) + assert Model is not None + assert issubclass(Model, BaseModel) + + def test_valid_input_accepted(self) -> None: + Model = create_model_from_schema(self.MCP_SCHEMA) + obj = Model( + query="test search", + format="json", + filters={"date_from": "2025-01-01"}, + ) + assert obj.query == "test search" + assert obj.format == "json" + + def test_invalid_enum_rejected(self) -> None: + Model = create_model_from_schema(self.MCP_SCHEMA) + with pytest.raises(Exception): + Model( + query="test", + format="yaml", + filters={"date_from": "2025-01-01"}, + ) + + def test_model_name_for_mcp_tool(self) -> None: + Model = create_model_from_schema( + self.MCP_SCHEMA, model_name="search_toolSchema" + ) + assert Model.__name__ == "search_toolSchema" + + def test_enriched_descriptions_for_mcp(self) -> None: + Model = create_model_from_schema( + self.MCP_SCHEMA, enrich_descriptions=True + ) + query_field = Model.model_fields["query"] + assert "Min length: 1" in query_field.description + assert "Max length: 500" in query_field.description + + max_results_field = Model.model_fields["max_results"] + assert "Minimum: 1" in max_results_field.description + assert "Maximum: 100" in max_results_field.description + + format_field = Model.model_fields["format"] + assert "Allowed values:" in format_field.description + + def test_optional_fields_accept_none(self) -> None: + Model = create_model_from_schema(self.MCP_SCHEMA) + obj = Model( + query="test", + format="csv", + filters={"date_from": "2025-01-01"}, + max_results=None, + sort_order=None, + ) + assert obj.max_results is None + assert obj.sort_order is None + + def test_nested_filters_validated(self) -> None: + Model = create_model_from_schema(self.MCP_SCHEMA) + obj = Model( + query="test", + format="xml", + filters={ + "date_from": "2025-01-01", + "date_to": "2025-12-31", + "categories": ["news", "tech"], + }, + ) + assert obj.filters.date_from == datetime.date(2025, 1, 1) + assert obj.filters.categories == ["news", "tech"]