Enhance MCP tool resolution and related events (#4580)

* feat: enhance MCP tool resolution

* feat: emit event when MCP configuration fails

* feat: emit event when MCP tool execution has failed

* style: resolve linter issues

* refactor: use clear and natural mcp tool name resolution

* test: fix broken tests

* fix: resolve MCP connection leaks, slug validation, duplicate connections, and httpx exception handling

---------

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
Co-authored-by: Greyson LaLonde <greyson@crewai.com>
This commit is contained in:
Lucas Gomide
2026-02-26 18:59:30 -03:00
committed by GitHub
parent c4a328c9d5
commit d259150d8d
17 changed files with 2112 additions and 611 deletions

View File

@@ -8,11 +8,9 @@ import time
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
cast,
)
from urllib.parse import urlparse
from pydantic import (
BaseModel,
@@ -61,16 +59,8 @@ from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.lite_agent_output import LiteAgentOutput
from crewai.llms.base_llm import BaseLLM
from crewai.mcp import (
MCPClient,
MCPServerConfig,
MCPServerHTTP,
MCPServerSSE,
MCPServerStdio,
)
from crewai.mcp.transports.http import HTTPTransport
from crewai.mcp.transports.sse import SSETransport
from crewai.mcp.transports.stdio import StdioTransport
from crewai.mcp import MCPServerConfig
from crewai.mcp.tool_resolver import MCPToolResolver
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.fingerprint import Fingerprint
from crewai.tools.agent_tools.agent_tools import AgentTools
@@ -111,18 +101,8 @@ if TYPE_CHECKING:
from crewai.utilities.types import LLMMessage
# MCP Connection timeout constants (in seconds)
MCP_CONNECTION_TIMEOUT: Final[int] = 10
MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30
MCP_DISCOVERY_TIMEOUT: Final[int] = 15
MCP_MAX_RETRIES: Final[int] = 3
_passthrough_exceptions: tuple[type[Exception], ...] = ()
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
_mcp_schema_cache: dict[str, Any] = {}
_cache_ttl: Final[int] = 300 # 5 minutes
class Agent(BaseAgent):
"""Represents an agent in a system.
@@ -154,7 +134,7 @@ class Agent(BaseAgent):
model_config = ConfigDict()
_times_executed: int = PrivateAttr(default=0)
_mcp_clients: list[Any] = PrivateAttr(default_factory=list)
_mcp_resolver: MCPToolResolver | None = PrivateAttr(default=None)
_last_messages: list[LLMMessage] = PrivateAttr(default_factory=list)
max_execution_time: int | None = Field(
default=None,
@@ -934,544 +914,17 @@ class Agent(BaseAgent):
def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
"""Convert MCP server references/configs to CrewAI tools.
Supports both string references (backwards compatible) and structured
configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE).
Args:
mcps: List of MCP server references (strings) or configurations.
Returns:
List of BaseTool instances from MCP servers.
Delegates to :class:`~crewai.mcp.tool_resolver.MCPToolResolver`.
"""
all_tools = []
clients = []
for mcp_config in mcps:
if isinstance(mcp_config, str):
tools = self._get_mcp_tools_from_string(mcp_config)
else:
tools, client = self._get_native_mcp_tools(mcp_config)
if client:
clients.append(client)
all_tools.extend(tools)
# Store clients for cleanup
self._mcp_clients.extend(clients)
return all_tools
self._cleanup_mcp_clients()
self._mcp_resolver = MCPToolResolver(agent=self, logger=self._logger)
return self._mcp_resolver.resolve(mcps)
def _cleanup_mcp_clients(self) -> None:
"""Cleanup MCP client connections after task execution."""
if not self._mcp_clients:
return
async def _disconnect_all() -> None:
for client in self._mcp_clients:
if client and hasattr(client, "connected") and client.connected:
await client.disconnect()
try:
asyncio.run(_disconnect_all())
except Exception as e:
self._logger.log("error", f"Error during MCP client cleanup: {e}")
finally:
self._mcp_clients.clear()
def _get_mcp_tools_from_string(self, mcp_ref: str) -> list[BaseTool]:
"""Get tools from legacy string-based MCP references.
This method maintains backwards compatibility with string-based
MCP references (https://... and crewai-amp:...).
Args:
mcp_ref: String reference to MCP server.
Returns:
List of BaseTool instances.
"""
if mcp_ref.startswith("crewai-amp:"):
return self._get_amp_mcp_tools(mcp_ref)
if mcp_ref.startswith("https://"):
return self._get_external_mcp_tools(mcp_ref)
return []
def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
"""Get tools from external HTTPS MCP server with graceful error handling."""
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
# Parse server URL and optional tool name
if "#" in mcp_ref:
server_url, specific_tool = mcp_ref.split("#", 1)
else:
server_url, specific_tool = mcp_ref, None
server_params = {"url": server_url}
server_name = self._extract_server_name(server_url)
try:
# Get tool schemas with timeout and error handling
tool_schemas = self._get_mcp_tool_schemas(server_params)
if not tool_schemas:
self._logger.log(
"warning", f"No tools discovered from MCP server: {server_url}"
)
return []
tools = []
for tool_name, schema in tool_schemas.items():
# Skip if specific tool requested and this isn't it
if specific_tool and tool_name != specific_tool:
continue
try:
wrapper = MCPToolWrapper(
mcp_server_params=server_params,
tool_name=tool_name,
tool_schema=schema,
server_name=server_name,
)
tools.append(wrapper)
except Exception as e:
self._logger.log(
"warning",
f"Failed to create MCP tool wrapper for {tool_name}: {e}",
)
continue
if specific_tool and not tools:
self._logger.log(
"warning",
f"Specific tool '{specific_tool}' not found on MCP server: {server_url}",
)
return cast(list[BaseTool], tools)
except Exception as e:
self._logger.log(
"warning", f"Failed to connect to MCP server {server_url}: {e}"
)
return []
def _get_native_mcp_tools(
self, mcp_config: MCPServerConfig
) -> tuple[list[BaseTool], Any | None]:
"""Get tools from MCP server using structured configuration.
This method creates an MCP client based on the configuration type,
connects to the server, discovers tools, applies filtering, and
returns wrapped tools along with the client instance for cleanup.
Args:
mcp_config: MCP server configuration (MCPServerStdio, MCPServerHTTP, or MCPServerSSE).
Returns:
Tuple of (list of BaseTool instances, MCPClient instance for cleanup).
"""
from crewai.tools.base_tool import BaseTool
from crewai.tools.mcp_native_tool import MCPNativeTool
transport: StdioTransport | HTTPTransport | SSETransport
if isinstance(mcp_config, MCPServerStdio):
transport = StdioTransport(
command=mcp_config.command,
args=mcp_config.args,
env=mcp_config.env,
)
server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}"
elif isinstance(mcp_config, MCPServerHTTP):
transport = HTTPTransport(
url=mcp_config.url,
headers=mcp_config.headers,
streamable=mcp_config.streamable,
)
server_name = self._extract_server_name(mcp_config.url)
elif isinstance(mcp_config, MCPServerSSE):
transport = SSETransport(
url=mcp_config.url,
headers=mcp_config.headers,
)
server_name = self._extract_server_name(mcp_config.url)
else:
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
client = MCPClient(
transport=transport,
cache_tools_list=mcp_config.cache_tools_list,
)
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
"""Async helper to connect and list tools in same event loop."""
try:
if not client.connected:
await client.connect()
tools_list = await client.list_tools()
try:
await client.disconnect()
# Small delay to allow background tasks to finish cleanup
# This helps prevent "cancel scope in different task" errors
# when asyncio.run() closes the event loop
await asyncio.sleep(0.1)
except Exception as e:
self._logger.log("error", f"Error during disconnect: {e}")
return tools_list
except Exception as e:
if client.connected:
await client.disconnect()
await asyncio.sleep(0.1)
raise RuntimeError(
f"Error during setup client and list tools: {e}"
) from e
try:
try:
asyncio.get_running_loop()
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
asyncio.run, _setup_client_and_list_tools()
)
tools_list = future.result()
except RuntimeError:
try:
tools_list = asyncio.run(_setup_client_and_list_tools())
except RuntimeError as e:
error_msg = str(e).lower()
if "cancel scope" in error_msg or "task" in error_msg:
raise ConnectionError(
"MCP connection failed due to event loop cleanup issues. "
"This may be due to authentication errors or server unavailability."
) from e
except asyncio.CancelledError as e:
raise ConnectionError(
"MCP connection was cancelled. This may indicate an authentication "
"error or server unavailability."
) from e
if mcp_config.tool_filter:
filtered_tools = []
for tool in tools_list:
if callable(mcp_config.tool_filter):
try:
from crewai.mcp.filters import ToolFilterContext
context = ToolFilterContext(
agent=self,
server_name=server_name,
run_context=None,
)
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
except (TypeError, AttributeError):
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
else:
# Not callable - include tool
filtered_tools.append(tool)
tools_list = filtered_tools
tools = []
for tool_def in tools_list:
tool_name = tool_def.get("name", "")
if not tool_name:
continue
# Convert inputSchema to Pydantic model if present
args_schema = None
if tool_def.get("inputSchema"):
args_schema = self._json_schema_to_pydantic(
tool_name, tool_def["inputSchema"]
)
tool_schema = {
"description": tool_def.get("description", ""),
"args_schema": args_schema,
}
try:
native_tool = MCPNativeTool(
mcp_client=client,
tool_name=tool_name,
tool_schema=tool_schema,
server_name=server_name,
)
tools.append(native_tool)
except Exception as e:
self._logger.log("error", f"Failed to create native MCP tool: {e}")
continue
return cast(list[BaseTool], tools), client
except Exception as e:
if client.connected:
asyncio.run(client.disconnect())
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
"""Get tools from CrewAI AMP MCP marketplace."""
# Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"
amp_part = amp_ref.replace("crewai-amp:", "")
if "#" in amp_part:
mcp_name, specific_tool = amp_part.split("#", 1)
else:
mcp_name, specific_tool = amp_part, None
# Call AMP API to get MCP server URLs
mcp_servers = self._fetch_amp_mcp_servers(mcp_name)
tools = []
for server_config in mcp_servers:
server_ref = server_config["url"]
if specific_tool:
server_ref += f"#{specific_tool}"
server_tools = self._get_external_mcp_tools(server_ref)
tools.extend(server_tools)
return tools
@staticmethod
def _extract_server_name(server_url: str) -> str:
"""Extract clean server name from URL for tool prefixing."""
parsed = urlparse(server_url)
domain = parsed.netloc.replace(".", "_")
path = parsed.path.replace("/", "_").strip("_")
return f"{domain}_{path}" if path else domain
def _get_mcp_tool_schemas(
self, server_params: dict[str, Any]
) -> dict[str, dict[str, Any]]:
"""Get tool schemas from MCP server for wrapper creation with caching."""
server_url = server_params["url"]
# Check cache first
cache_key = server_url
current_time = time.time()
if cache_key in _mcp_schema_cache:
cached_data, cache_time = _mcp_schema_cache[cache_key]
if current_time - cache_time < _cache_ttl:
self._logger.log(
"debug", f"Using cached MCP tool schemas for {server_url}"
)
return cached_data # type: ignore[no-any-return]
try:
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
# Cache successful results
_mcp_schema_cache[cache_key] = (schemas, current_time)
return schemas
except Exception as e:
# Log warning but don't raise - this allows graceful degradation
self._logger.log(
"warning", f"Failed to get MCP tool schemas from {server_url}: {e}"
)
return {}
async def _get_mcp_tool_schemas_async(
self, server_params: dict[str, Any]
) -> dict[str, dict[str, Any]]:
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
server_url = server_params["url"]
return await self._retry_mcp_discovery(
self._discover_mcp_tools_with_timeout, server_url
)
async def _retry_mcp_discovery(
self, operation_func: Any, server_url: str
) -> dict[str, dict[str, Any]]:
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
last_error = None
for attempt in range(MCP_MAX_RETRIES):
# Execute single attempt outside try-except loop structure
result, error, should_retry = await self._attempt_mcp_discovery(
operation_func, server_url
)
# Success case - return immediately
if result is not None:
return result
# Non-retryable error - raise immediately
if not should_retry:
raise RuntimeError(error)
# Retryable error - continue with backoff
last_error = error
if attempt < MCP_MAX_RETRIES - 1:
wait_time = 2**attempt # Exponential backoff
await asyncio.sleep(wait_time)
raise RuntimeError(
f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}"
)
@staticmethod
async def _attempt_mcp_discovery(
operation_func: Any, server_url: str
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
try:
result = await operation_func(server_url)
return result, "", False
except ImportError:
return (
None,
"MCP library not available. Please install with: pip install mcp",
False,
)
except asyncio.TimeoutError:
return (
None,
f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds",
True,
)
except Exception as e:
error_str = str(e).lower()
# Classify errors as retryable or non-retryable
if "authentication" in error_str or "unauthorized" in error_str:
return None, f"Authentication failed for MCP server: {e!s}", False
if "connection" in error_str or "network" in error_str:
return None, f"Network connection failed: {e!s}", True
if "json" in error_str or "parsing" in error_str:
return None, f"Server response parsing error: {e!s}", True
return None, f"MCP discovery error: {e!s}", False
async def _discover_mcp_tools_with_timeout(
self, server_url: str
) -> dict[str, dict[str, Any]]:
"""Discover MCP tools with timeout wrapper."""
return await asyncio.wait_for(
self._discover_mcp_tools(server_url), timeout=MCP_DISCOVERY_TIMEOUT
)
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]:
"""Discover tools from MCP server with proper timeout handling."""
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
async with streamablehttp_client(server_url) as (read, write, _):
async with ClientSession(read, write) as session:
# Initialize the connection with timeout
await asyncio.wait_for(
session.initialize(), timeout=MCP_CONNECTION_TIMEOUT
)
# List available tools with timeout
tools_result = await asyncio.wait_for(
session.list_tools(),
timeout=MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT,
)
schemas = {}
for tool in tools_result.tools:
args_schema = None
if hasattr(tool, "inputSchema") and tool.inputSchema:
args_schema = self._json_schema_to_pydantic(
sanitize_tool_name(tool.name), tool.inputSchema
)
schemas[sanitize_tool_name(tool.name)] = {
"description": getattr(tool, "description", ""),
"args_schema": args_schema,
}
return schemas
def _json_schema_to_pydantic(
self, tool_name: str, json_schema: dict[str, Any]
) -> type:
"""Convert JSON Schema to Pydantic model for tool arguments.
Args:
tool_name: Name of the tool (used for model naming)
json_schema: JSON Schema dict with 'properties', 'required', etc.
Returns:
Pydantic BaseModel class
"""
from pydantic import Field, create_model
properties = json_schema.get("properties", {})
required_fields = json_schema.get("required", [])
field_definitions: dict[str, Any] = {}
for field_name, field_schema in properties.items():
field_type = self._json_type_to_python(field_schema)
field_description = field_schema.get("description", "")
is_required = field_name in required_fields
if is_required:
field_definitions[field_name] = (
field_type,
Field(..., description=field_description),
)
else:
field_definitions[field_name] = (
field_type | None,
Field(default=None, description=field_description),
)
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
"""Convert JSON Schema type to Python type.
Args:
field_schema: JSON Schema field definition
Returns:
Python type
"""
json_type = field_schema.get("type")
if "anyOf" in field_schema:
types: list[type] = []
for option in field_schema["anyOf"]:
if "const" in option:
types.append(str)
else:
types.append(self._json_type_to_python(option))
unique_types = list(set(types))
if len(unique_types) > 1:
result: Any = unique_types[0]
for t in unique_types[1:]:
result = result | t
return result # type: ignore[no-any-return]
return unique_types[0]
type_mapping: dict[str | None, type] = {
"string": str,
"number": float,
"integer": int,
"boolean": bool,
"array": list,
"object": dict,
}
return type_mapping.get(json_type, Any)
@staticmethod
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
"""Fetch MCP server configurations from CrewAI AMP API."""
# TODO: Implement AMP API call to "integrations/mcps" endpoint
# Should return list of server configs with URLs
return []
if self._mcp_resolver is not None:
self._mcp_resolver.cleanup()
self._mcp_resolver = None
@staticmethod
def get_multimodal_tools() -> Sequence[BaseTool]:

View File

@@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Literal
import re
from typing import Any, Final, Literal
import uuid
from pydantic import (
@@ -36,6 +37,11 @@ from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.string_utils import interpolate_only
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$"
)
PlatformApp = Literal[
"asana",
"box",
@@ -197,7 +203,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
)
mcps: list[str | MCPServerConfig] | None = Field(
default=None,
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
)
memory: Any = Field(
default=None,
@@ -276,14 +282,16 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
validated_mcps: list[str | MCPServerConfig] = []
for mcp in mcps:
if isinstance(mcp, str):
if mcp.startswith(("https://", "crewai-amp:")):
if mcp.startswith("https://"):
validated_mcps.append(mcp)
elif _SLUG_RE.match(mcp):
validated_mcps.append(mcp)
else:
raise ValueError(
f"Invalid MCP reference: {mcp}. "
"String references must start with 'https://' or 'crewai-amp:'"
f"Invalid MCP reference: {mcp!r}. "
"String references must be an 'https://' URL or a valid "
"slug (e.g. 'notion', 'notion#search', 'crewai-amp:notion')."
)
elif isinstance(mcp, (MCPServerConfig)):
validated_mcps.append(mcp)
else:

View File

@@ -190,6 +190,15 @@ class PlusAPI:
timeout=30,
)
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
    """Request the MCP server configurations for the given slugs.

    Args:
        slugs: Slugs to look up; sent as one comma-separated ``slugs``
            query parameter.

    Returns:
        The raw HTTP response of the GET request.
    """
    endpoint = f"{self.INTEGRATIONS_RESOURCE}/mcp_configs"
    query = {"slugs": ",".join(slugs)}
    return self._make_request("GET", endpoint, params=query, timeout=30)
def get_triggers(self) -> httpx.Response:
"""Get all available triggers from integrations."""
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")

View File

@@ -63,6 +63,7 @@ from crewai.events.types.logging_events import (
AgentLogsStartedEvent,
)
from crewai.events.types.mcp_events import (
MCPConfigFetchFailedEvent,
MCPConnectionCompletedEvent,
MCPConnectionFailedEvent,
MCPConnectionStartedEvent,
@@ -165,6 +166,7 @@ __all__ = [
"LiteAgentExecutionCompletedEvent",
"LiteAgentExecutionErrorEvent",
"LiteAgentExecutionStartedEvent",
"MCPConfigFetchFailedEvent",
"MCPConnectionCompletedEvent",
"MCPConnectionFailedEvent",
"MCPConnectionStartedEvent",

View File

@@ -68,6 +68,7 @@ from crewai.events.types.logging_events import (
AgentLogsStartedEvent,
)
from crewai.events.types.mcp_events import (
MCPConfigFetchFailedEvent,
MCPConnectionCompletedEvent,
MCPConnectionFailedEvent,
MCPConnectionStartedEvent,
@@ -665,6 +666,16 @@ class EventListener(BaseEventListener):
event.error_type,
)
@crewai_event_bus.on(MCPConfigFetchFailedEvent)
def on_mcp_config_fetch_failed(
    _: Any, event: MCPConfigFetchFailedEvent
) -> None:
    """Forward a failed MCP config fetch to the console formatter."""
    failure_details = (event.slug, event.error, event.error_type)
    self.formatter.handle_mcp_config_fetch_failed(*failure_details)
@crewai_event_bus.on(MCPToolExecutionStartedEvent)
def on_mcp_tool_execution_started(
_: Any, event: MCPToolExecutionStartedEvent

View File

@@ -67,6 +67,7 @@ from crewai.events.types.llm_guardrail_events import (
LLMGuardrailStartedEvent,
)
from crewai.events.types.mcp_events import (
MCPConfigFetchFailedEvent,
MCPConnectionCompletedEvent,
MCPConnectionFailedEvent,
MCPConnectionStartedEvent,
@@ -181,4 +182,5 @@ EventTypes = (
| MCPToolExecutionStartedEvent
| MCPToolExecutionCompletedEvent
| MCPToolExecutionFailedEvent
| MCPConfigFetchFailedEvent
)

View File

@@ -83,3 +83,16 @@ class MCPToolExecutionFailedEvent(MCPEvent):
error_type: str | None = None # "timeout", "validation", "server_error", etc.
started_at: datetime | None = None
failed_at: datetime | None = None
class MCPConfigFetchFailedEvent(BaseEvent):
    """Event emitted when fetching an AMP MCP server config fails.

    This covers cases where the slug is not connected, the API call
    failed, or native MCP resolution failed after config was fetched.
    """

    # Event type discriminator.
    type: str = "mcp_config_fetch_failed"
    # Slug of the MCP integration whose configuration could not be resolved.
    slug: str
    # Human-readable description of the failure.
    error: str
    # Optional coarse failure category.
    error_type: str | None = None  # "not_connected", "api_error", "connection_failed"

View File

@@ -1512,6 +1512,34 @@ To enable tracing, do any one of these:
self.print(panel)
self.print()
def handle_mcp_config_fetch_failed(
    self,
    slug: str,
    error: str = "",
    error_type: str | None = None,
) -> None:
    """Render a red panel describing a failed MCP config fetch.

    Does nothing unless the formatter is verbose.

    Args:
        slug: Slug of the MCP server whose config could not be fetched.
        error: Error message; previews longer than 500 characters are
            truncated and suffixed with ``...``.
        error_type: Optional failure category, shown when provided.
    """
    if not self.verbose:
        return

    # (text, style) pairs, appended to the panel body in this exact order.
    segments: list[tuple[str, str]] = [
        ("MCP Config Fetch Failed\n\n", "red bold"),
        ("Server: ", "white"),
        (f"{slug}\n", "red"),
    ]
    if error_type:
        segments += [
            ("Error Type: ", "white"),
            (f"{error_type}\n", "red"),
        ]
    if error:
        shown = error if len(error) <= 500 else error[:500] + "..."
        segments += [
            ("\nError: ", "white bold"),
            (f"{shown}\n", "red"),
        ]

    content = Text()
    for piece, piece_style in segments:
        content.append(piece, style=piece_style)

    self.print(self.create_panel(content, "❌ MCP Config Failed", "red"))
    self.print()
def handle_mcp_tool_execution_started(
self,
server_name: str,

View File

@@ -18,6 +18,7 @@ from crewai.mcp.filters import (
create_dynamic_tool_filter,
create_static_tool_filter,
)
from crewai.mcp.tool_resolver import MCPToolResolver
from crewai.mcp.transports.base import BaseTransport, TransportType
@@ -28,6 +29,7 @@ __all__ = [
"MCPServerHTTP",
"MCPServerSSE",
"MCPServerStdio",
"MCPToolResolver",
"StaticToolFilter",
"ToolFilter",
"ToolFilterContext",

View File

@@ -6,7 +6,7 @@ from contextlib import AsyncExitStack
from datetime import datetime
import logging
import time
from typing import Any
from typing import Any, NamedTuple
from typing_extensions import Self
@@ -34,6 +34,13 @@ from crewai.mcp.transports.stdio import StdioTransport
from crewai.utilities.string_utils import sanitize_tool_name
class _MCPToolResult(NamedTuple):
    """Internal result from an MCP tool call, carrying the ``isError`` flag."""

    # Stringified content extracted from the tool response.
    content: str
    # True when the MCP response reported ``isError``.
    is_error: bool
# MCP Connection timeout constants (in seconds)
MCP_CONNECTION_TIMEOUT = 30 # Increased for slow servers
MCP_TOOL_EXECUTION_TIMEOUT = 30
@@ -420,6 +427,7 @@ class MCPClient:
return [
{
"name": sanitize_tool_name(tool.name),
"original_name": tool.name,
"description": getattr(tool, "description", ""),
"inputSchema": getattr(tool, "inputSchema", {}),
}
@@ -461,29 +469,46 @@ class MCPClient:
)
try:
result = await self._retry_operation(
tool_result: _MCPToolResult = await self._retry_operation(
lambda: self._call_tool_impl(tool_name, cleaned_arguments),
timeout=self.execution_timeout,
)
completed_at = datetime.now()
execution_duration_ms = (completed_at - started_at).total_seconds() * 1000
crewai_event_bus.emit(
self,
MCPToolExecutionCompletedEvent(
server_name=server_name,
server_url=server_url,
transport_type=transport_type,
tool_name=tool_name,
tool_args=cleaned_arguments,
result=result,
started_at=started_at,
completed_at=completed_at,
execution_duration_ms=execution_duration_ms,
),
)
finished_at = datetime.now()
execution_duration_ms = (finished_at - started_at).total_seconds() * 1000
return result
if tool_result.is_error:
crewai_event_bus.emit(
self,
MCPToolExecutionFailedEvent(
server_name=server_name,
server_url=server_url,
transport_type=transport_type,
tool_name=tool_name,
tool_args=cleaned_arguments,
error=tool_result.content,
error_type="tool_error",
started_at=started_at,
failed_at=finished_at,
),
)
else:
crewai_event_bus.emit(
self,
MCPToolExecutionCompletedEvent(
server_name=server_name,
server_url=server_url,
transport_type=transport_type,
tool_name=tool_name,
tool_args=cleaned_arguments,
result=tool_result.content,
started_at=started_at,
completed_at=finished_at,
execution_duration_ms=execution_duration_ms,
),
)
return tool_result.content
except Exception as e:
failed_at = datetime.now()
error_type = (
@@ -564,23 +589,27 @@ class MCPClient:
return cleaned
async def _call_tool_impl(self, tool_name: str, arguments: dict[str, Any]) -> Any:
async def _call_tool_impl(
self, tool_name: str, arguments: dict[str, Any]
) -> _MCPToolResult:
"""Internal implementation of call_tool."""
result = await asyncio.wait_for(
self.session.call_tool(tool_name, arguments),
timeout=self.execution_timeout,
)
is_error = getattr(result, "isError", False) or False
# Extract result content
if hasattr(result, "content") and result.content:
if isinstance(result.content, list) and len(result.content) > 0:
content_item = result.content[0]
if hasattr(content_item, "text"):
return str(content_item.text)
return str(content_item)
return str(result.content)
return _MCPToolResult(str(content_item.text), is_error)
return _MCPToolResult(str(content_item), is_error)
return _MCPToolResult(str(result.content), is_error)
return str(result)
return _MCPToolResult(str(result), is_error)
async def list_prompts(self) -> list[dict[str, Any]]:
"""List available prompts from MCP server.

View File

@@ -0,0 +1,592 @@
"""MCP tool resolution for CrewAI agents.
This module extracts all MCP-related tool resolution logic from the Agent class
into a standalone MCPToolResolver. It handles three flavours of MCP reference:
1. Native configs: MCPServerStdio / MCPServerHTTP / MCPServerSSE objects.
2. HTTPS URLs: e.g. "https://mcp.example.com/api"
3. AMP references: e.g. "notion" or "notion#search" (legacy "crewai-amp:" prefix also works)
"""
from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING, Any, Final, cast
from urllib.parse import urlparse
from crewai.mcp.client import MCPClient
from crewai.mcp.config import (
MCPServerConfig,
MCPServerHTTP,
MCPServerSSE,
MCPServerStdio,
)
from crewai.mcp.transports.http import HTTPTransport
from crewai.mcp.transports.sse import SSETransport
from crewai.mcp.transports.stdio import StdioTransport
if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
from crewai.utilities.logger import Logger
# MCP timeout constants (in seconds).
MCP_CONNECTION_TIMEOUT: Final[int] = 10
MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30
MCP_DISCOVERY_TIMEOUT: Final[int] = 15
# NOTE(review): presumably the maximum number of discovery attempts; the
# retry loop is not visible here, so confirm against its usage.
MCP_MAX_RETRIES: Final[int] = 3
# Module-level, in-memory cache for MCP tool schemas, shared by all resolvers.
# NOTE(review): entry shape is not visible here; confirm against the code
# that reads and writes the cache.
_mcp_schema_cache: dict[str, Any] = {}
_cache_ttl: Final[int] = 300  # 5 minutes
class MCPToolResolver:
"""Resolves MCP server references / configs into CrewAI ``BaseTool`` instances.
Typical lifecycle::
resolver = MCPToolResolver(agent=my_agent, logger=my_agent._logger)
tools = resolver.resolve(my_agent.mcps)
# … agent executes tasks using *tools* …
resolver.cleanup()
The resolver owns the MCP client connections it creates and is responsible
for tearing them down via :meth:`cleanup`.
"""
def __init__(self, agent: Any, logger: Logger) -> None:
self._agent = agent
self._logger = logger
self._clients: list[Any] = []
@property
def clients(self) -> list[Any]:
    """Return a shallow copy of the MCP clients currently tracked.

    Mutating the returned list does not affect the resolver's own list.
    """
    return [*self._clients]
def resolve(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
    """Turn MCP server references and configs into CrewAI tools.

    Strings starting with ``https://`` are resolved as external servers,
    all other strings are collected as AMP references and resolved
    together after the loop, and non-string entries are resolved as
    native configurations. Clients returned by the native and AMP paths
    are recorded on the resolver.

    Args:
        mcps: MCP server references (strings) or configuration objects.

    Returns:
        Tools from external and native entries in input order, followed
        by the tools of all AMP references.
    """
    resolved: list[BaseTool] = []
    pending_amp: list[tuple[str, str | None]] = []

    for entry in mcps:
        if not isinstance(entry, str):
            native_tools, native_client = self._resolve_native(entry)
            resolved.extend(native_tools)
            if native_client:
                self._clients.append(native_client)
            continue

        if entry.startswith("https://"):
            resolved.extend(self._resolve_external(entry))
        else:
            pending_amp.append(self._parse_amp_ref(entry))

    # AMP references are resolved in a single batch once all are collected.
    if pending_amp:
        amp_tools, amp_clients = self._resolve_amp(pending_amp)
        resolved.extend(amp_tools)
        self._clients.extend(amp_clients)

    return resolved
def cleanup(self) -> None:
    """Disconnect all tracked MCP clients and forget them.

    Each client is disconnected independently, so a failure on one
    client is logged and does not prevent the others from being
    disconnected. When called while an event loop is already running in
    the current thread, where ``asyncio.run`` is not allowed, the
    disconnects run on a short-lived worker thread instead.

    The tracked client list is always emptied, even if errors occurred.
    """
    if not self._clients:
        return

    # Snapshot so the coroutine never iterates a list being cleared.
    pending = list(self._clients)

    async def _disconnect_all() -> None:
        for client in pending:
            if not (client and getattr(client, "connected", False)):
                continue
            try:
                await client.disconnect()
            except Exception as e:
                # Keep going: one bad client must not leave the rest connected.
                self._logger.log("error", f"Error during MCP client cleanup: {e}")

    def _run_blocking() -> None:
        asyncio.run(_disconnect_all())

    try:
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: safe to drive one directly.
            _run_blocking()
        else:
            # asyncio.run() refuses to start inside a running loop, so use
            # a fresh loop on a worker thread and wait for it to finish.
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                executor.submit(_run_blocking).result()
    except Exception as e:
        self._logger.log("error", f"Error during MCP client cleanup: {e}")
    finally:
        self._clients.clear()
@staticmethod
def _parse_amp_ref(mcp_config: str) -> tuple[str, str | None]:
"""Parse an AMP reference into *(slug, optional tool name)*.
Accepts both bare slugs (``"notion"``, ``"notion#search"``) and the
legacy ``"crewai-amp:notion"`` form.
"""
bare = mcp_config.removeprefix("crewai-amp:")
slug, _, specific_tool = bare.partition("#")
return slug, specific_tool or None
def _resolve_amp(
    self, amp_refs: list[tuple[str, str | None]]
) -> tuple[list[BaseTool], list[Any]]:
    """Fetch AMP configs in bulk and return their tools and clients.

    Each unique slug is resolved only once (a single connection per
    server); the per-reference tool filter is applied afterwards. A
    failure for one slug, whether a missing config, a config that cannot
    be built, or a failed native resolution, emits an
    ``MCPConfigFetchFailedEvent`` and does not affect the other slugs.

    Args:
        amp_refs: ``(slug, tool_name)`` pairs; ``tool_name`` is ``None``
            when every tool of the server is wanted.

    Returns:
        A tuple of the selected tools, in reference order, and the
        clients created while resolving.
    """
    from crewai.events.event_bus import crewai_event_bus
    from crewai.events.types.mcp_events import MCPConfigFetchFailedEvent

    # dict.fromkeys de-duplicates slugs while preserving first-seen order.
    unique_slugs = list(dict.fromkeys(slug for slug, _ in amp_refs))
    amp_configs_map = self._fetch_amp_mcp_configs(unique_slugs)

    all_clients: list[Any] = []
    tools_by_slug: dict[str, list[BaseTool]] = {}

    for slug in unique_slugs:
        config_dict = amp_configs_map.get(slug)
        if not config_dict:
            crewai_event_bus.emit(
                self,
                MCPConfigFetchFailedEvent(
                    slug=slug,
                    error=f"Config for '{slug}' not found. Make sure it is connected in your account.",
                    error_type="not_connected",
                ),
            )
            continue

        try:
            # Building the config is inside the try so a malformed config
            # for one slug cannot abort resolution of the remaining slugs.
            mcp_server_config = self._build_mcp_config_from_dict(config_dict)
            tools, client = self._resolve_native(mcp_server_config)
        except Exception as e:
            crewai_event_bus.emit(
                self,
                MCPConfigFetchFailedEvent(
                    slug=slug,
                    error=str(e),
                    error_type="connection_failed",
                ),
            )
            continue

        tools_by_slug[slug] = tools
        if client:
            all_clients.append(client)

    all_tools: list[BaseTool] = []
    for slug, specific_tool in amp_refs:
        slug_tools = tools_by_slug.get(slug)
        if slug_tools is None:
            # Resolution failed for this slug; the event was already emitted.
            continue

        if not specific_tool:
            all_tools.extend(slug_tools)
            continue

        # NOTE(review): suffix matching also selects any tool whose name
        # merely ends with "_<tool>" (e.g. "x_advanced_search" for
        # "search"); confirm against the native tool naming scheme.
        matching = [t for t in slug_tools if t.name.endswith(f"_{specific_tool}")]
        if not matching:
            self._logger.log(
                "warning",
                f"Specific tool '{specific_tool}' not found on MCP server: {slug}",
            )
        all_tools.extend(matching)

    return all_tools, all_clients
def _fetch_amp_mcp_configs(self, slugs: list[str]) -> dict[str, dict[str, Any]]:
    """Fetch MCP server configurations via CrewAI+ API.

    Sends a GET request to the CrewAI+ mcps/configs endpoint with
    comma-separated slugs. CrewAI+ proxies the request to crewai-oauth.

    API-level failures return ``{}``; individual slugs will then
    surface as ``MCPConfigFetchFailedEvent`` in :meth:`_resolve_amp`.

    Args:
        slugs: Slugs of the MCP servers whose configs should be fetched.

    Returns:
        The ``"configs"`` mapping from the response body, or ``{}`` when the
        request fails, the status is not 200, or the body carries no usable
        ``"configs"`` mapping.
    """
    import httpx

    try:
        from crewai_tools.tools.crewai_platform_tools.misc import (
            get_platform_integration_token,
        )

        from crewai.cli.plus_api import PlusAPI

        plus_api = PlusAPI(api_key=get_platform_integration_token())
        response = plus_api.get_mcp_configs(slugs)

        if response.status_code != 200:
            self._logger.log(
                "debug",
                f"Failed to fetch MCP configs: HTTP {response.status_code}",
            )
            return {}

        configs = response.json().get("configs")
        if not isinstance(configs, dict):
            # ``"configs": null`` (or any non-mapping) must not leak to the
            # caller, which immediately does ``.get(slug)`` on the result.
            self._logger.log(
                "debug",
                "Failed to fetch MCP configs: response has no 'configs' mapping",
            )
            return {}
        return configs

    except httpx.HTTPError as e:
        self._logger.log("debug", f"Failed to fetch MCP configs: {e}")
        return {}
    except Exception as e:
        # Covers a missing token, missing optional packages and bad JSON.
        self._logger.log("debug", f"Cannot fetch AMP MCP configs: {e}")
        return {}
def _resolve_external(self, mcp_ref: str) -> list[BaseTool]:
    """Turn an MCP server URL reference into wrapped tools.

    The reference may carry a ``#tool`` fragment, in which case only the
    tool with exactly that name is wrapped. Every failure is logged as a
    warning and results in fewer (or no) tools rather than an exception.

    Args:
        mcp_ref: Server URL, optionally followed by ``#<tool name>``.

    Returns:
        The ``MCPToolWrapper`` instances that could be created.
    """
    from crewai.tools.mcp_tool_wrapper import MCPToolWrapper

    server_url, _, fragment = mcp_ref.partition("#")
    specific_tool = fragment or None

    server_params = {"url": server_url}
    server_name = self._extract_server_name(server_url)

    try:
        tool_schemas = self._get_mcp_tool_schemas(server_params)
        if not tool_schemas:
            self._logger.log(
                "warning", f"No tools discovered from MCP server: {server_url}"
            )
            return []

        # Narrow to the requested tool up front when a fragment was given.
        candidates = {
            name: schema
            for name, schema in tool_schemas.items()
            if not specific_tool or name == specific_tool
        }

        wrappers = []
        for tool_name, schema in candidates.items():
            try:
                wrappers.append(
                    MCPToolWrapper(
                        mcp_server_params=server_params,
                        tool_name=tool_name,
                        tool_schema=schema,
                        server_name=server_name,
                    )
                )
            except Exception as e:
                # One bad tool must not prevent the others from loading.
                self._logger.log(
                    "warning",
                    f"Failed to create MCP tool wrapper for {tool_name}: {e}",
                )

        if specific_tool and not wrappers:
            self._logger.log(
                "warning",
                f"Specific tool '{specific_tool}' not found on MCP server: {server_url}",
            )
        return cast(list[BaseTool], wrappers)

    except Exception as e:
        self._logger.log(
            "warning", f"Failed to connect to MCP server {server_url}: {e}"
        )
        return []
def _resolve_native(
    self, mcp_config: MCPServerConfig
) -> tuple[list[BaseTool], Any | None]:
    """Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup.

    Builds the transport matching the config type, connects once to list the
    server's tools (disconnecting again afterwards), applies
    ``mcp_config.tool_filter`` and wraps each remaining tool definition in an
    ``MCPNativeTool`` bound to the client.

    Args:
        mcp_config: Stdio, HTTP or SSE server configuration.

    Returns:
        ``(tools, client)`` where ``client`` is the ``MCPClient`` the tools
        were created with.

    Raises:
        ValueError: If ``mcp_config`` is not a supported config type.
        RuntimeError: If listing or wrapping the tools fails; the original
            error is chained as ``__cause__``.
    """
    from crewai.tools.base_tool import BaseTool
    from crewai.tools.mcp_native_tool import MCPNativeTool

    transport: StdioTransport | HTTPTransport | SSETransport
    if isinstance(mcp_config, MCPServerStdio):
        transport = StdioTransport(
            command=mcp_config.command,
            args=mcp_config.args,
            env=mcp_config.env,
        )
        server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}"
    elif isinstance(mcp_config, MCPServerHTTP):
        transport = HTTPTransport(
            url=mcp_config.url,
            headers=mcp_config.headers,
            streamable=mcp_config.streamable,
        )
        server_name = self._extract_server_name(mcp_config.url)
    elif isinstance(mcp_config, MCPServerSSE):
        transport = SSETransport(
            url=mcp_config.url,
            headers=mcp_config.headers,
        )
        server_name = self._extract_server_name(mcp_config.url)
    else:
        raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")

    client = MCPClient(
        transport=transport,
        cache_tools_list=mcp_config.cache_tools_list,
    )

    async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
        try:
            if not client.connected:
                await client.connect()

            tools_list = await client.list_tools()

            try:
                await client.disconnect()
                await asyncio.sleep(0.1)
            except Exception as e:
                # A failed disconnect does not invalidate the listed tools.
                self._logger.log("error", f"Error during disconnect: {e}")

            return tools_list
        except Exception as e:
            if client.connected:
                await client.disconnect()
                await asyncio.sleep(0.1)
            raise RuntimeError(
                f"Error during setup client and list tools: {e}"
            ) from e

    def _run_blocking(coro_factory: Any) -> Any:
        """Run ``coro_factory()`` to completion from synchronous code."""
        # Only the loop probe is guarded: a RuntimeError raised by the
        # coroutine itself must not be mistaken for "no running loop".
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_is_running = False
        else:
            loop_is_running = True

        if not loop_is_running:
            return asyncio.run(coro_factory())

        # asyncio.run() cannot be used inside a running loop, so run the
        # coroutine on a fresh loop in a worker thread.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor() as executor:
            return executor.submit(asyncio.run, coro_factory()).result()

    try:
        try:
            tools_list = _run_blocking(_setup_client_and_list_tools)
        except RuntimeError as e:
            error_msg = str(e).lower()
            if "cancel scope" in error_msg or "task" in error_msg:
                raise ConnectionError(
                    "MCP connection failed due to event loop cleanup issues. "
                    "This may be due to authentication errors or server unavailability."
                ) from e
            # Any other RuntimeError must propagate; swallowing it would
            # leave ``tools_list`` unbound and hide the real failure.
            raise
        except asyncio.CancelledError as e:
            raise ConnectionError(
                "MCP connection was cancelled. This may indicate an authentication "
                "error or server unavailability."
            ) from e

        if mcp_config.tool_filter:
            filtered_tools = []
            for tool in tools_list:
                if callable(mcp_config.tool_filter):
                    try:
                        from crewai.mcp.filters import ToolFilterContext

                        context = ToolFilterContext(
                            agent=self._agent,
                            server_name=server_name,
                            run_context=None,
                        )
                        if mcp_config.tool_filter(context, tool):  # type: ignore[call-arg, arg-type]
                            filtered_tools.append(tool)
                    except (TypeError, AttributeError):
                        # Fall back to the single-argument filter signature.
                        if mcp_config.tool_filter(tool):  # type: ignore[call-arg, arg-type]
                            filtered_tools.append(tool)
                else:
                    filtered_tools.append(tool)
            tools_list = filtered_tools

        tools = []
        for tool_def in tools_list:
            tool_name = tool_def.get("name", "")
            original_tool_name = tool_def.get("original_name", tool_name)
            if not tool_name:
                continue

            args_schema = None
            if tool_def.get("inputSchema"):
                args_schema = self._json_schema_to_pydantic(
                    tool_name, tool_def["inputSchema"]
                )

            tool_schema = {
                "description": tool_def.get("description", ""),
                "args_schema": args_schema,
            }

            try:
                native_tool = MCPNativeTool(
                    mcp_client=client,
                    tool_name=tool_name,
                    tool_schema=tool_schema,
                    server_name=server_name,
                    original_tool_name=original_tool_name,
                )
                tools.append(native_tool)
            except Exception as e:
                self._logger.log("error", f"Failed to create native MCP tool: {e}")
                continue

        return cast(list[BaseTool], tools), client
    except Exception as e:
        if client.connected:
            # Best-effort cleanup: a failing disconnect (for example when
            # called from inside a running loop) must not mask ``e``.
            try:
                _run_blocking(client.disconnect)
            except Exception as cleanup_error:
                self._logger.log(
                    "error", f"Error during disconnect: {cleanup_error}"
                )
        raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
@staticmethod
def _build_mcp_config_from_dict(
    config_dict: dict[str, Any],
) -> MCPServerConfig:
    """Build the typed server config for a crewai-oauth config dict.

    Args:
        config_dict: Raw config. ``"url"`` is required; ``"type"``
            (``"sse"`` or anything else, default ``"http"``), ``"headers"``,
            ``"streamable"`` and ``"cache_tools_list"`` are optional.

    Returns:
        An ``MCPServerSSE`` when ``type`` is ``"sse"``, otherwise an
        ``MCPServerHTTP``.

    Raises:
        KeyError: If ``config_dict`` has no ``"url"`` entry.
    """
    is_sse = config_dict.get("type", "http") == "sse"
    url = config_dict["url"]
    headers = config_dict.get("headers")
    cache_tools_list = config_dict.get("cache_tools_list", False)

    if not is_sse:
        # Every non-"sse" type is treated as (streamable) HTTP.
        return MCPServerHTTP(
            url=url,
            headers=headers,
            streamable=config_dict.get("streamable", True),
            cache_tools_list=cache_tools_list,
        )
    return MCPServerSSE(
        url=url,
        headers=headers,
        cache_tools_list=cache_tools_list,
    )
@staticmethod
def _extract_server_name(server_url: str) -> str:
"""Extract clean server name from URL for tool prefixing."""
parsed = urlparse(server_url)
domain = parsed.netloc.replace(".", "_")
path = parsed.path.replace("/", "_").strip("_")
return f"{domain}_{path}" if path else domain
def _get_mcp_tool_schemas(
    self, server_params: dict[str, Any]
) -> dict[str, dict[str, Any]]:
    """Get tool schemas from MCP server with caching.

    A cached result younger than ``_cache_ttl`` seconds is returned as-is.
    Otherwise the schemas are discovered and cached under the server URL.
    Discovery works whether or not an event loop is already running in the
    calling thread.

    Args:
        server_params: Server parameters; must contain ``"url"``.

    Returns:
        Mapping of tool name to schema info, or ``{}`` when discovery fails
        (the failure is logged as a warning and is not cached).
    """
    server_url = server_params["url"]

    cache_key = server_url
    current_time = time.time()

    if cache_key in _mcp_schema_cache:
        cached_data, cache_time = _mcp_schema_cache[cache_key]
        if current_time - cache_time < _cache_ttl:
            self._logger.log(
                "debug", f"Using cached MCP tool schemas for {server_url}"
            )
            return cached_data  # type: ignore[no-any-return]

    try:
        # asyncio.run() raises when a loop is already running in this
        # thread (e.g. when invoked from async code), so probe first.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_is_running = False
        else:
            loop_is_running = True

        if loop_is_running:
            # Run discovery on a fresh loop in a worker thread; this blocks
            # the caller until the schemas are available.
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                schemas = executor.submit(
                    asyncio.run, self._get_mcp_tool_schemas_async(server_params)
                ).result()
        else:
            schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))

        _mcp_schema_cache[cache_key] = (schemas, current_time)
        return schemas
    except Exception as e:
        self._logger.log(
            "warning", f"Failed to get MCP tool schemas from {server_url}: {e}"
        )
        return {}
async def _get_mcp_tool_schemas_async(
self, server_params: dict[str, Any]
) -> dict[str, dict[str, Any]]:
"""Async implementation of MCP tool schema retrieval."""
server_url = server_params["url"]
return await self._retry_mcp_discovery(
self._discover_mcp_tools_with_timeout, server_url
)
async def _retry_mcp_discovery(
    self, operation_func: Any, server_url: str
) -> dict[str, dict[str, Any]]:
    """Run MCP discovery, retrying retryable failures with backoff.

    Up to ``MCP_MAX_RETRIES`` attempts are made. Between attempts the
    coroutine sleeps ``2**attempt`` seconds (1, 2, 4, ...); there is no
    sleep after the final attempt.

    Args:
        operation_func: Async callable taking the server URL.
        server_url: URL handed to ``operation_func``.

    Returns:
        The first non-``None`` discovery result.

    Raises:
        RuntimeError: Immediately for a non-retryable failure, or after all
            attempts have failed.
    """
    final_attempt = MCP_MAX_RETRIES - 1
    most_recent_error = None

    for attempt_index in range(MCP_MAX_RETRIES):
        outcome, message, retryable = await self._attempt_mcp_discovery(
            operation_func, server_url
        )

        if outcome is not None:
            return outcome
        if not retryable:
            raise RuntimeError(message)

        most_recent_error = message
        if attempt_index != final_attempt:
            await asyncio.sleep(2**attempt_index)

    raise RuntimeError(
        f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {most_recent_error}"
    )
@staticmethod
async def _attempt_mcp_discovery(
operation_func: Any, server_url: str
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
"""Attempt single MCP discovery; returns *(result, error_message, should_retry)*."""
try:
result = await operation_func(server_url)
return result, "", False
except ImportError:
return (
None,
"MCP library not available. Please install with: pip install mcp",
False,
)
except asyncio.TimeoutError:
return (
None,
f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds",
True,
)
except Exception as e:
error_str = str(e).lower()
if "authentication" in error_str or "unauthorized" in error_str:
return None, f"Authentication failed for MCP server: {e!s}", False
if "connection" in error_str or "network" in error_str:
return None, f"Network connection failed: {e!s}", True
if "json" in error_str or "parsing" in error_str:
return None, f"Server response parsing error: {e!s}", True
return None, f"MCP discovery error: {e!s}", False
async def _discover_mcp_tools_with_timeout(
    self, server_url: str
) -> dict[str, dict[str, Any]]:
    """Run ``_discover_mcp_tools`` bounded by ``MCP_DISCOVERY_TIMEOUT``.

    Args:
        server_url: URL of the MCP server to query.

    Returns:
        The schemas produced by ``_discover_mcp_tools``.

    Raises:
        asyncio.TimeoutError: If discovery exceeds the timeout.
    """
    discovery = self._discover_mcp_tools(server_url)
    return await asyncio.wait_for(discovery, timeout=MCP_DISCOVERY_TIMEOUT)
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]:
    """List the tools of a streamable-HTTP MCP server as schema dicts.

    Opens a streamable-HTTP client session, initializes it within
    ``MCP_CONNECTION_TIMEOUT`` seconds and lists the tools within the
    remainder of ``MCP_DISCOVERY_TIMEOUT``.

    Args:
        server_url: URL of the MCP server.

    Returns:
        Mapping of sanitized tool name to a dict with ``"description"`` and
        ``"args_schema"`` (``None`` when the tool has no input schema).
    """
    from mcp import ClientSession
    from mcp.client.streamable_http import streamablehttp_client

    from crewai.utilities.string_utils import sanitize_tool_name

    listing_timeout = MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT

    async with streamablehttp_client(server_url) as (read, write, _):
        async with ClientSession(read, write) as session:
            await asyncio.wait_for(
                session.initialize(), timeout=MCP_CONNECTION_TIMEOUT
            )
            listing = await asyncio.wait_for(
                session.list_tools(),
                timeout=listing_timeout,
            )

            discovered = {}
            for remote_tool in listing.tools:
                clean_name = sanitize_tool_name(remote_tool.name)
                input_schema = getattr(remote_tool, "inputSchema", None)
                model = (
                    self._json_schema_to_pydantic(clean_name, input_schema)
                    if input_schema
                    else None
                )
                discovered[clean_name] = {
                    "description": getattr(remote_tool, "description", ""),
                    "args_schema": model,
                }
            return discovered
@staticmethod
def _json_schema_to_pydantic(tool_name: str, json_schema: dict[str, Any]) -> type:
    """Create the Pydantic argument model for a tool's JSON Schema.

    The model is named ``"<tool_name>Schema"`` with hyphens and spaces in
    the tool name replaced by underscores, and is built with
    ``enrich_descriptions=True``.

    Args:
        tool_name: Name of the tool the schema belongs to.
        json_schema: JSON Schema describing the tool's arguments.

    Returns:
        The model class returned by ``create_model_from_schema``.
    """
    from crewai.utilities.pydantic_schema_utils import create_model_from_schema

    name_stem = tool_name
    for unsafe_char in ("-", " "):
        name_stem = name_stem.replace(unsafe_char, "_")

    return create_model_from_schema(
        json_schema,
        model_name=f"{name_stem}Schema",
        enrich_descriptions=True,
    )

View File

@@ -27,14 +27,16 @@ class MCPNativeTool(BaseTool):
tool_name: str,
tool_schema: dict[str, Any],
server_name: str,
original_tool_name: str | None = None,
) -> None:
"""Initialize native MCP tool.
Args:
mcp_client: MCPClient instance with active session.
tool_name: Original name of the tool on the MCP server.
tool_name: Name of the tool (may be prefixed).
tool_schema: Schema information for the tool.
server_name: Name of the MCP server for prefixing.
original_tool_name: Original name of the tool on the MCP server.
"""
# Create tool name with server prefix to avoid conflicts
prefixed_name = f"{server_name}_{tool_name}"
@@ -57,7 +59,7 @@ class MCPNativeTool(BaseTool):
# Set instance attributes after super().__init__
self._mcp_client = mcp_client
self._original_tool_name = tool_name
self._original_tool_name = original_tool_name or tool_name
self._server_name = server_name
# self._logger = logging.getLogger(__name__)

View File

@@ -491,10 +491,66 @@ FORMAT_TYPE_MAP: dict[str, type[Any]] = {
}
def build_rich_field_description(prop_schema: dict[str, Any]) -> str:
    """Compose a field description that spells out the schema's constraints.

    The base description (if any) is followed by format, allowed enum
    values, pattern, numeric bounds, length bounds and up to three examples,
    so an LLM can read the requirements without the raw JSON Schema.

    Args:
        prop_schema: Property schema with description and constraints.

    Returns:
        The segments joined with ``". "``, or an empty string when the
        schema provides none of them.
    """
    lookup = prop_schema.get
    segments: list[str] = []

    if text := lookup("description", ""):
        segments.append(text)
    if format_name := lookup("format"):
        segments.append(f"Format: {format_name}")
    if allowed := lookup("enum"):
        rendered = ", ".join(repr(value) for value in allowed)
        segments.append(f"Allowed values: [{rendered}]")
    if regex := lookup("pattern"):
        segments.append(f"Pattern: {regex}")

    # Bounds are tested against None so that a legitimate 0 is still shown.
    bound_labels = (
        ("minimum", "Minimum"),
        ("maximum", "Maximum"),
        ("minLength", "Min length"),
        ("maxLength", "Max length"),
    )
    for key, label in bound_labels:
        bound = lookup(key)
        if bound is not None:
            segments.append(f"{label}: {bound}")

    if samples := lookup("examples"):
        shown = ", ".join(repr(sample) for sample in samples[:3])
        segments.append(f"Examples: {shown}")

    return ". ".join(segments)
def create_model_from_schema( # type: ignore[no-any-unimported]
json_schema: dict[str, Any],
*,
root_schema: dict[str, Any] | None = None,
model_name: str | None = None,
enrich_descriptions: bool = False,
__config__: ConfigDict | None = None,
__base__: type[BaseModel] | None = None,
__module__: str = __name__,
@@ -512,6 +568,13 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
json_schema: A dictionary representing the JSON schema.
root_schema: The root schema containing $defs. If not provided, the
current schema is treated as the root schema.
model_name: Override for the model name. If not provided, the schema
``title`` field is used, falling back to ``"DynamicModel"``.
enrich_descriptions: When True, augment field descriptions with
constraint info (format, enum, pattern, min/max, examples) via
:func:`build_rich_field_description`. Useful for LLM-facing tool
schemas where constraints in the description help the model
understand parameter requirements.
__config__: Pydantic configuration for the generated model.
__base__: Base class for the generated model. Defaults to BaseModel.
__module__: Module name for the generated model class.
@@ -548,10 +611,14 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
if "title" not in json_schema and "title" in (root_schema or {}):
json_schema["title"] = (root_schema or {}).get("title")
model_name = json_schema.get("title") or "DynamicModel"
effective_name = model_name or json_schema.get("title") or "DynamicModel"
field_definitions = {
name: _json_schema_to_pydantic_field(
name, prop, json_schema.get("required", []), effective_root
name,
prop,
json_schema.get("required", []),
effective_root,
enrich_descriptions=enrich_descriptions,
)
for name, prop in (json_schema.get("properties", {}) or {}).items()
}
@@ -559,7 +626,7 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
effective_config = __config__ or ConfigDict(extra="forbid")
return create_model_base(
model_name,
effective_name,
__config__=effective_config,
__base__=__base__,
__module__=__module__,
@@ -574,6 +641,8 @@ def _json_schema_to_pydantic_field(
json_schema: dict[str, Any],
required: list[str],
root_schema: dict[str, Any],
*,
enrich_descriptions: bool = False,
) -> Any:
"""Convert a JSON schema property to a Pydantic field definition.
@@ -582,20 +651,29 @@ def _json_schema_to_pydantic_field(
json_schema: The JSON schema for this field.
required: List of required field names.
root_schema: The root schema for resolving $ref.
enrich_descriptions: When True, embed constraints in the description.
Returns:
A tuple of (type, Field) for use with create_model.
"""
type_ = _json_schema_to_pydantic_type(json_schema, root_schema, name_=name.title())
description = json_schema.get("description")
examples = json_schema.get("examples")
type_ = _json_schema_to_pydantic_type(
json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
)
is_required = name in required
field_params: dict[str, Any] = {}
schema_extra: dict[str, Any] = {}
if description:
field_params["description"] = description
if enrich_descriptions:
rich_desc = build_rich_field_description(json_schema)
if rich_desc:
field_params["description"] = rich_desc
else:
description = json_schema.get("description")
if description:
field_params["description"] = description
examples = json_schema.get("examples")
if examples:
schema_extra["examples"] = examples
@@ -711,6 +789,7 @@ def _json_schema_to_pydantic_type(
root_schema: dict[str, Any],
*,
name_: str | None = None,
enrich_descriptions: bool = False,
) -> Any:
"""Convert a JSON schema to a Python/Pydantic type.
@@ -718,6 +797,7 @@ def _json_schema_to_pydantic_type(
json_schema: The JSON schema to convert.
root_schema: The root schema for resolving $ref.
name_: Optional name for nested models.
enrich_descriptions: Propagated to nested model creation.
Returns:
A Python type corresponding to the JSON schema.
@@ -725,7 +805,9 @@ def _json_schema_to_pydantic_type(
ref = json_schema.get("$ref")
if ref:
ref_schema = _resolve_ref(ref, root_schema)
return _json_schema_to_pydantic_type(ref_schema, root_schema, name_=name_)
return _json_schema_to_pydantic_type(
ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
)
enum_values = json_schema.get("enum")
if enum_values:
@@ -740,7 +822,10 @@ def _json_schema_to_pydantic_type(
if any_of_schemas:
any_of_types = [
_json_schema_to_pydantic_type(
schema, root_schema, name_=f"{name_ or 'Union'}Option{i}"
schema,
root_schema,
name_=f"{name_ or 'Union'}Option{i}",
enrich_descriptions=enrich_descriptions,
)
for i, schema in enumerate(any_of_schemas)
]
@@ -750,10 +835,14 @@ def _json_schema_to_pydantic_type(
if all_of_schemas:
if len(all_of_schemas) == 1:
return _json_schema_to_pydantic_type(
all_of_schemas[0], root_schema, name_=name_
all_of_schemas[0], root_schema, name_=name_,
enrich_descriptions=enrich_descriptions,
)
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
return _json_schema_to_pydantic_type(merged, root_schema, name_=name_)
return _json_schema_to_pydantic_type(
merged, root_schema, name_=name_,
enrich_descriptions=enrich_descriptions,
)
type_ = json_schema.get("type")
@@ -769,7 +858,8 @@ def _json_schema_to_pydantic_type(
items_schema = json_schema.get("items")
if items_schema:
item_type = _json_schema_to_pydantic_type(
items_schema, root_schema, name_=name_
items_schema, root_schema, name_=name_,
enrich_descriptions=enrich_descriptions,
)
return list[item_type] # type: ignore[valid-type]
return list
@@ -779,7 +869,10 @@ def _json_schema_to_pydantic_type(
json_schema_ = json_schema.copy()
if json_schema_.get("title") is None:
json_schema_["title"] = name_ or "DynamicModel"
return create_model_from_schema(json_schema_, root_schema=root_schema)
return create_model_from_schema(
json_schema_, root_schema=root_schema,
enrich_descriptions=enrich_descriptions,
)
return dict
if type_ == "null":
return None

View File

@@ -659,7 +659,7 @@ def test_agent_kickoff_with_platform_tools(mock_get, mock_post):
@patch.dict("os.environ", {"EXA_API_KEY": "test_exa_key"})
@patch("crewai.agent.Agent._get_external_mcp_tools")
@patch("crewai.agent.Agent.get_mcp_tools")
@pytest.mark.vcr()
def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
"""Test that Agent.kickoff() properly integrates MCP tools with LiteAgent"""
@@ -691,7 +691,7 @@ def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
assert result.raw is not None
# Verify MCP tools were retrieved
mock_get_mcp_tools.assert_called_once_with("https://mcp.exa.ai/mcp?api_key=test_exa_key&profile=research")
mock_get_mcp_tools.assert_called_once_with(["https://mcp.exa.ai/mcp?api_key=test_exa_key&profile=research"])
# ============================================================================

View File

@@ -0,0 +1,373 @@
"""Tests for AMP MCP config fetching and tool resolution."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from crewai.agent.core import Agent
from crewai.mcp.config import MCPServerHTTP, MCPServerSSE
from crewai.mcp.tool_resolver import MCPToolResolver
from crewai.tools.base_tool import BaseTool
@pytest.fixture
def agent():
    """Provide a minimal ``Agent`` for the MCP resolution tests."""
    minimal_agent = Agent(
        role="Test Agent",
        goal="Test goal",
        backstory="Test backstory",
    )
    return minimal_agent
@pytest.fixture
def resolver(agent):
    """Provide an ``MCPToolResolver`` bound to the ``agent`` fixture and its logger."""
    agent_logger = agent._logger
    return MCPToolResolver(agent=agent, logger=agent_logger)
@pytest.fixture
def mock_tool_definitions():
    """Provide two tool definitions: one with an input schema, one without."""
    search_tool = {
        "name": "search",
        "description": "Search tool",
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"}
            },
            "required": ["query"],
        },
    }
    create_page_tool = {
        "name": "create_page",
        "description": "Create a page",
        "inputSchema": {},
    }
    return [search_tool, create_page_tool]
class TestBuildMCPConfigFromDict:
    """Check ``MCPToolResolver._build_mcp_config_from_dict`` type selection and defaults."""

    def test_builds_http_config(self):
        built = MCPToolResolver._build_mcp_config_from_dict(
            {
                "type": "http",
                "url": "https://mcp.example.com/api",
                "headers": {"Authorization": "Bearer token123"},
                "streamable": True,
                "cache_tools_list": False,
            }
        )

        assert isinstance(built, MCPServerHTTP)
        assert built.url == "https://mcp.example.com/api"
        assert built.headers == {"Authorization": "Bearer token123"}
        assert built.streamable is True
        assert built.cache_tools_list is False

    def test_builds_sse_config(self):
        built = MCPToolResolver._build_mcp_config_from_dict(
            {
                "type": "sse",
                "url": "https://mcp.example.com/sse",
                "headers": {"Authorization": "Bearer token123"},
                "cache_tools_list": True,
            }
        )

        assert isinstance(built, MCPServerSSE)
        assert built.url == "https://mcp.example.com/sse"
        assert built.headers == {"Authorization": "Bearer token123"}
        assert built.cache_tools_list is True

    def test_defaults_to_http(self):
        # No "type" key at all: the builder must fall back to HTTP.
        built = MCPToolResolver._build_mcp_config_from_dict(
            {"url": "https://mcp.example.com/api"}
        )

        assert isinstance(built, MCPServerHTTP)
        assert built.streamable is True

    def test_http_defaults(self):
        built = MCPToolResolver._build_mcp_config_from_dict(
            {"type": "http", "url": "https://mcp.example.com/api"}
        )

        assert built.headers is None
        assert built.streamable is True
        assert built.cache_tools_list is False
class TestFetchAmpMCPConfigs:
    """Exercise ``MCPToolResolver._fetch_amp_mcp_configs`` against a stubbed PlusAPI."""

    @staticmethod
    def _stub_plus_api(plus_api_class, *, status_code=None, payload=None, error=None):
        """Make ``plus_api_class()`` return a stub client and hand that client back."""
        client = MagicMock()
        if error is not None:
            client.get_mcp_configs.side_effect = error
        else:
            response = MagicMock()
            response.status_code = status_code
            if payload is not None:
                response.json.return_value = payload
            client.get_mcp_configs.return_value = response
        plus_api_class.return_value = client
        return client

    @patch("crewai.cli.plus_api.PlusAPI")
    @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
    def test_fetches_configs_successfully(self, mock_get_token, mock_plus_api_class, resolver):
        payload = {
            "configs": {
                "notion": {
                    "type": "sse",
                    "url": "https://mcp.notion.so/sse",
                    "headers": {"Authorization": "Bearer notion-token"},
                },
                "github": {
                    "type": "http",
                    "url": "https://mcp.github.com/api",
                    "headers": {"Authorization": "Bearer gh-token"},
                },
            },
        }
        plus_api = self._stub_plus_api(
            mock_plus_api_class, status_code=200, payload=payload
        )

        fetched = resolver._fetch_amp_mcp_configs(["notion", "github"])

        assert "notion" in fetched
        assert "github" in fetched
        assert fetched["notion"]["url"] == "https://mcp.notion.so/sse"
        mock_plus_api_class.assert_called_once_with(api_key="test-api-key")
        plus_api.get_mcp_configs.assert_called_once_with(["notion", "github"])

    @patch("crewai.cli.plus_api.PlusAPI")
    @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
    def test_omits_missing_slugs(self, mock_get_token, mock_plus_api_class, resolver):
        payload = {
            "configs": {"notion": {"type": "sse", "url": "https://mcp.notion.so/sse"}},
        }
        self._stub_plus_api(mock_plus_api_class, status_code=200, payload=payload)

        fetched = resolver._fetch_amp_mcp_configs(["notion", "missing-server"])

        assert "notion" in fetched
        assert "missing-server" not in fetched

    @patch("crewai.cli.plus_api.PlusAPI")
    @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
    def test_returns_empty_on_http_error(self, mock_get_token, mock_plus_api_class, resolver):
        self._stub_plus_api(mock_plus_api_class, status_code=500)

        assert resolver._fetch_amp_mcp_configs(["notion"]) == {}

    @patch("crewai.cli.plus_api.PlusAPI")
    @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
    def test_returns_empty_on_network_error(self, mock_get_token, mock_plus_api_class, resolver):
        import httpx

        self._stub_plus_api(
            mock_plus_api_class, error=httpx.ConnectError("Connection refused")
        )

        assert resolver._fetch_amp_mcp_configs(["notion"]) == {}

    @patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", side_effect=Exception("No token"))
    def test_returns_empty_when_no_token(self, mock_get_token, resolver):
        assert resolver._fetch_amp_mcp_configs(["notion"]) == {}
class TestParseAmpRef:
    """Check ``MCPToolResolver._parse_amp_ref`` for bare and legacy references."""

    def test_bare_slug(self):
        assert MCPToolResolver._parse_amp_ref("notion") == ("notion", None)

    def test_bare_slug_with_tool(self):
        assert MCPToolResolver._parse_amp_ref("notion#search") == ("notion", "search")

    def test_bare_slug_with_empty_tool(self):
        # A trailing "#" with nothing after it means "no specific tool".
        assert MCPToolResolver._parse_amp_ref("notion#") == ("notion", None)

    def test_legacy_prefix_slug(self):
        assert MCPToolResolver._parse_amp_ref("crewai-amp:notion") == ("notion", None)

    def test_legacy_prefix_with_tool(self):
        parsed = MCPToolResolver._parse_amp_ref("crewai-amp:notion#search")
        assert parsed == ("notion", "search")
class TestGetMCPToolsAmpIntegration:
    """Check ``Agent.get_mcp_tools`` for AMP references with a stubbed MCP client."""

    @staticmethod
    def _notion_config(*, with_headers=True):
        """Return a fresh SSE config dict for the ``notion`` slug."""
        config = {"type": "sse", "url": "https://mcp.notion.so/sse"}
        if with_headers:
            config["headers"] = {"Authorization": "Bearer token"}
        return config

    @staticmethod
    def _install_client(client_class, tool_definitions):
        """Make ``client_class()`` return a disconnected client listing ``tool_definitions``."""
        client = AsyncMock()
        client.list_tools = AsyncMock(return_value=tool_definitions)
        client.connected = False
        client.connect = AsyncMock()
        client.disconnect = AsyncMock()
        client_class.return_value = client
        return client

    @patch("crewai.mcp.tool_resolver.MCPClient")
    @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
    def test_single_request_for_multiple_amp_refs(
        self, mock_fetch, mock_client_class, agent, mock_tool_definitions
    ):
        mock_fetch.return_value = {
            "notion": self._notion_config(),
            "github": {
                "type": "http",
                "url": "https://mcp.github.com/api",
                "headers": {"Authorization": "Bearer gh-token"},
                "streamable": True,
            },
        }
        self._install_client(mock_client_class, mock_tool_definitions)

        resolved = agent.get_mcp_tools(["notion", "github"])

        mock_fetch.assert_called_once_with(["notion", "github"])
        assert len(resolved) == 4  # 2 tools per server

    @patch("crewai.mcp.tool_resolver.MCPClient")
    @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
    def test_tool_filter_with_hash_syntax(
        self, mock_fetch, mock_client_class, agent, mock_tool_definitions
    ):
        mock_fetch.return_value = {"notion": self._notion_config()}
        self._install_client(mock_client_class, mock_tool_definitions)

        resolved = agent.get_mcp_tools(["notion#search"])

        mock_fetch.assert_called_once_with(["notion"])
        assert len(resolved) == 1
        assert resolved[0].name == "mcp_notion_so_sse_search"

    @patch("crewai.mcp.tool_resolver.MCPClient")
    @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
    def test_deduplicates_slugs(
        self, mock_fetch, mock_client_class, agent, mock_tool_definitions
    ):
        mock_fetch.return_value = {"notion": self._notion_config()}
        self._install_client(mock_client_class, mock_tool_definitions)

        resolved = agent.get_mcp_tools(["notion#search", "notion#create_page"])

        # Two references to the same slug must trigger a single fetch.
        mock_fetch.assert_called_once_with(["notion"])
        assert len(resolved) == 2

    @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
    def test_skips_missing_configs_gracefully(self, mock_fetch, agent):
        mock_fetch.return_value = {}

        assert agent.get_mcp_tools(["missing-server"]) == []

    @patch("crewai.mcp.tool_resolver.MCPClient")
    @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
    def test_legacy_crewai_amp_prefix_still_works(
        self, mock_fetch, mock_client_class, agent, mock_tool_definitions
    ):
        mock_fetch.return_value = {"notion": self._notion_config()}
        self._install_client(mock_client_class, mock_tool_definitions)

        resolved = agent.get_mcp_tools(["crewai-amp:notion"])

        mock_fetch.assert_called_once_with(["notion"])
        assert len(resolved) == 2

    @patch("crewai.mcp.tool_resolver.MCPClient")
    @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
    @patch.object(MCPToolResolver, "_resolve_external")
    def test_non_amp_items_unaffected(
        self,
        mock_external,
        mock_fetch,
        mock_client_class,
        agent,
        mock_tool_definitions,
    ):
        mock_fetch.return_value = {"notion": self._notion_config(with_headers=False)}
        self._install_client(mock_client_class, mock_tool_definitions)
        mock_external.return_value = [MagicMock(spec=BaseTool)]

        http_config = MCPServerHTTP(
            url="https://other.mcp.com/api",
            headers={"Authorization": "Bearer other"},
        )

        resolved = agent.get_mcp_tools(
            [
                "notion",
                "https://external.mcp.com/api",
                http_config,
            ]
        )

        mock_fetch.assert_called_once_with(["notion"])
        mock_external.assert_called_once_with("https://external.mcp.com/api")
        # 2 from notion + 1 from external + 2 from http_config
        assert len(resolved) == 5

View File

@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import pytest
from crewai.agent.core import Agent
@@ -46,7 +46,7 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions):
)
with patch("crewai.agent.core.MCPClient") as mock_client_class:
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False # Will trigger connect
@@ -82,7 +82,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions):
mcps=[http_config],
)
with patch("crewai.agent.core.MCPClient") as mock_client_class:
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False # Will trigger connect
@@ -117,7 +117,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions):
mcps=[sse_config],
)
with patch("crewai.agent.core.MCPClient") as mock_client_class:
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False
@@ -141,7 +141,7 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
"""Test MCPNativeTool execution in synchronous context (normal crew execution)."""
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
with patch("crewai.agent.core.MCPClient") as mock_client_class:
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False
@@ -173,7 +173,7 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
"""Test MCPNativeTool execution in async context (e.g., from a Flow)."""
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
with patch("crewai.agent.core.MCPClient") as mock_client_class:
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False

View File

@@ -0,0 +1,884 @@
"""Tests for pydantic_schema_utils module.
Covers:
- create_model_from_schema: type mapping, required/optional, enums, formats,
nested objects, arrays, unions, allOf, $ref, model_name, enrich_descriptions
- Schema transformation helpers: resolve_refs, force_additional_properties_false,
strip_unsupported_formats, ensure_type_in_schemas, convert_oneof_to_anyof,
ensure_all_properties_required, strip_null_from_types, build_rich_field_description
- End-to-end MCP tool schema conversion
"""
from __future__ import annotations

import datetime
from copy import deepcopy
from typing import Any

import pytest
from pydantic import BaseModel, ValidationError

from crewai.utilities.pydantic_schema_utils import (
    build_rich_field_description,
    convert_oneof_to_anyof,
    create_model_from_schema,
    ensure_all_properties_required,
    ensure_type_in_schemas,
    force_additional_properties_false,
    resolve_refs,
    strip_null_from_types,
    strip_unsupported_formats,
)


class TestSimpleTypes:
    """Primitive JSON schema types produce fields that hold a matching value."""

    @staticmethod
    def _model_with_required(field: str, json_type: str) -> type:
        """Build a model exposing one required property of ``json_type``."""
        return create_model_from_schema(
            {
                "type": "object",
                "properties": {field: {"type": json_type}},
                "required": [field],
            }
        )

    def test_string_field(self) -> None:
        """A ``string`` property keeps the text it was given."""
        instance = self._model_with_required("name", "string")(name="Alice")
        assert instance.name == "Alice"

    def test_integer_field(self) -> None:
        """An ``integer`` property keeps the int it was given."""
        instance = self._model_with_required("count", "integer")(count=42)
        assert instance.count == 42

    def test_number_field(self) -> None:
        """A ``number`` property keeps the float it was given."""
        instance = self._model_with_required("score", "number")(score=3.14)
        assert instance.score == pytest.approx(3.14)

    def test_boolean_field(self) -> None:
        """A ``boolean`` property keeps ``True`` as the ``True`` singleton."""
        model_cls = self._model_with_required("active", "boolean")
        assert model_cls(active=True).active is True

    def test_null_field(self) -> None:
        """A ``null`` property accepts ``None``."""
        instance = self._model_with_required("value", "null")(value=None)
        assert instance.value is None


class TestRequiredOptional:
    """Check how the schema's ``required`` list affects the generated fields."""

    def test_required_field_has_no_default(self) -> None:
        """Omitting a property listed in ``required`` is rejected."""
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
        Model = create_model_from_schema(schema)
        # Expect pydantic's validation failure specifically; a blanket
        # ``Exception`` would also let unrelated errors satisfy the test.
        with pytest.raises(ValidationError):
            Model()

    def test_optional_field_defaults_to_none(self) -> None:
        """A property absent from ``required`` can be omitted and reads as ``None``."""
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": [],
        }
        Model = create_model_from_schema(schema)
        obj = Model()
        assert obj.name is None

    def test_mixed_required_optional(self) -> None:
        """Required and optional properties can coexist in one schema."""
        schema = {
            "type": "object",
            "properties": {
                "id": {"type": "integer"},
                "label": {"type": "string"},
            },
            "required": ["id"],
        }
        Model = create_model_from_schema(schema)
        obj = Model(id=1)
        assert obj.id == 1
        assert obj.label is None


class TestEnumLiteral:
    """Check ``enum`` and ``const`` keywords on generated fields."""

    def test_string_enum(self) -> None:
        """A value listed in ``enum`` is accepted."""
        schema = {
            "type": "object",
            "properties": {
                "color": {"type": "string", "enum": ["red", "green", "blue"]},
            },
            "required": ["color"],
        }
        Model = create_model_from_schema(schema)
        obj = Model(color="red")
        assert obj.color == "red"

    def test_string_enum_rejects_invalid(self) -> None:
        """A value outside ``enum`` fails validation."""
        schema = {
            "type": "object",
            "properties": {
                "color": {"type": "string", "enum": ["red", "green", "blue"]},
            },
            "required": ["color"],
        }
        Model = create_model_from_schema(schema)
        # Expect pydantic's validation failure specifically; a blanket
        # ``Exception`` would also let unrelated errors satisfy the test.
        with pytest.raises(ValidationError):
            Model(color="yellow")

    def test_const_value(self) -> None:
        """The exact ``const`` value is accepted."""
        schema = {
            "type": "object",
            "properties": {
                "kind": {"const": "fixed"},
            },
            "required": ["kind"],
        }
        Model = create_model_from_schema(schema)
        obj = Model(kind="fixed")
        assert obj.kind == "fixed"


class TestFormatMapping:
    """String properties with date/time formats accept ``datetime`` objects."""

    @staticmethod
    def _model_for_format(field: str, fmt: str) -> type:
        """Build a model with one required string property carrying ``fmt``."""
        return create_model_from_schema(
            {
                "type": "object",
                "properties": {
                    field: {"type": "string", "format": fmt},
                },
                "required": [field],
            }
        )

    def test_date_format(self) -> None:
        """``format: date`` round-trips a ``datetime.date``."""
        model_cls = self._model_for_format("birthday", "date")
        instance = model_cls(birthday=datetime.date(2000, 1, 15))
        assert instance.birthday == datetime.date(2000, 1, 15)

    def test_datetime_format(self) -> None:
        """``format: date-time`` round-trips a ``datetime.datetime``."""
        model_cls = self._model_for_format("created_at", "date-time")
        moment = datetime.datetime(2025, 6, 1, 12, 0, 0)
        instance = model_cls(created_at=moment)
        assert instance.created_at == moment

    def test_time_format(self) -> None:
        """``format: time`` round-trips a ``datetime.time``."""
        model_cls = self._model_for_format("alarm", "time")
        wake_up = datetime.time(8, 30)
        instance = model_cls(alarm=wake_up)
        assert instance.alarm == wake_up


class TestNestedObjects:
    """Object-typed properties, with and without their own ``properties``."""

    def test_nested_object_creates_model(self) -> None:
        """A nested object with properties is reachable via attribute access."""
        address_schema = {
            "type": "object",
            "properties": {
                "street": {"type": "string"},
                "city": {"type": "string"},
            },
            "required": ["street", "city"],
        }
        model_cls = create_model_from_schema(
            {
                "type": "object",
                "properties": {
                    "address": address_schema,
                },
                "required": ["address"],
            }
        )
        instance = model_cls(address={"street": "123 Main", "city": "Springfield"})
        assert instance.address.street == "123 Main"
        assert instance.address.city == "Springfield"

    def test_object_without_properties_returns_dict(self) -> None:
        """An object with no ``properties`` keeps the dict it was given."""
        model_cls = create_model_from_schema(
            {
                "type": "object",
                "properties": {
                    "metadata": {"type": "object"},
                },
                "required": ["metadata"],
            }
        )
        instance = model_cls(metadata={"key": "value"})
        assert instance.metadata == {"key": "value"}


class TestTypedArrays:
    """Array properties with string items, object items, and no ``items``."""

    @staticmethod
    def _model_with_array(field: str, array_schema: dict) -> type:
        """Build a model whose single required property is ``array_schema``."""
        return create_model_from_schema(
            {
                "type": "object",
                "properties": {field: array_schema},
                "required": [field],
            }
        )

    def test_array_of_strings(self) -> None:
        """A list of strings is stored unchanged."""
        model_cls = self._model_with_array(
            "tags", {"type": "array", "items": {"type": "string"}}
        )
        instance = model_cls(tags=["a", "b", "c"])
        assert instance.tags == ["a", "b", "c"]

    def test_array_of_objects(self) -> None:
        """Dict elements of an object-typed array expose their properties as attributes."""
        element_schema = {
            "type": "object",
            "properties": {"id": {"type": "integer"}},
            "required": ["id"],
        }
        model_cls = self._model_with_array(
            "items", {"type": "array", "items": element_schema}
        )
        instance = model_cls(items=[{"id": 1}, {"id": 2}])
        assert len(instance.items) == 2
        assert instance.items[0].id == 1

    def test_untyped_array(self) -> None:
        """An array without ``items`` accepts mixed element types."""
        model_cls = self._model_with_array("data", {"type": "array"})
        instance = model_cls(data=[1, "two", 3.0])
        assert instance.data == [1, "two", 3.0]


class TestUnionTypes:
    """``anyOf`` / ``oneOf`` properties accept each listed alternative."""

    @staticmethod
    def _union_model(keyword: str, alternatives: list) -> type:
        """Build a model with a required ``value`` declared via ``keyword``."""
        return create_model_from_schema(
            {
                "type": "object",
                "properties": {
                    "value": {
                        keyword: alternatives,
                    },
                },
                "required": ["value"],
            }
        )

    def test_anyof_string_or_integer(self) -> None:
        """``anyOf`` of string and integer accepts both kinds of value."""
        model_cls = self._union_model(
            "anyOf", [{"type": "string"}, {"type": "integer"}]
        )
        assert model_cls(value="hello").value == "hello"
        assert model_cls(value=42).value == 42

    def test_oneof(self) -> None:
        """``oneOf`` of string and number accepts both kinds of value."""
        model_cls = self._union_model(
            "oneOf", [{"type": "string"}, {"type": "number"}]
        )
        assert model_cls(value="hello").value == "hello"
        assert model_cls(value=3.14).value == pytest.approx(3.14)


class TestAllOfMerging:
    """``allOf`` at the top level and inside a property."""

    def test_allof_merges_properties(self) -> None:
        """Properties from every ``allOf`` branch end up on the model."""
        name_part = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
        age_part = {
            "type": "object",
            "properties": {"age": {"type": "integer"}},
            "required": ["age"],
        }
        model_cls = create_model_from_schema(
            {
                "type": "object",
                "allOf": [name_part, age_part],
            }
        )
        instance = model_cls(name="Alice", age=30)
        assert instance.name == "Alice"
        assert instance.age == 30

    def test_single_allof(self) -> None:
        """A one-element ``allOf`` on a property behaves like that element."""
        only_branch = {
            "type": "object",
            "properties": {"id": {"type": "integer"}},
            "required": ["id"],
        }
        model_cls = create_model_from_schema(
            {
                "type": "object",
                "properties": {
                    "item": {
                        "allOf": [only_branch],
                    },
                },
                "required": ["item"],
            }
        )
        instance = model_cls(item={"id": 1})
        assert instance.item.id == 1


# ---------------------------------------------------------------------------
# $ref resolution
# ---------------------------------------------------------------------------
class TestRefResolution:
    """``$ref`` pointers into ``$defs`` are followed when building models."""

    def test_ref_in_property(self) -> None:
        """A property that references ``#/$defs/Item`` exposes Item's fields."""
        item_definition = {
            "type": "object",
            "title": "Item",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
        model_cls = create_model_from_schema(
            {
                "type": "object",
                "properties": {
                    "item": {"$ref": "#/$defs/Item"},
                },
                "required": ["item"],
                "$defs": {
                    "Item": item_definition,
                },
            }
        )
        instance = model_cls(item={"name": "Widget"})
        assert instance.item.name == "Widget"


# ---------------------------------------------------------------------------
# model_name parameter
# ---------------------------------------------------------------------------
class TestModelName:
    """How the generated class name is chosen."""

    @staticmethod
    def _schema(title: str | None = None) -> dict:
        """Return a one-integer-field schema, optionally carrying ``title``."""
        schema: dict = {"type": "object"}
        if title is not None:
            schema["title"] = title
        schema["properties"] = {"x": {"type": "integer"}}
        schema["required"] = ["x"]
        return schema

    def test_model_name_override(self) -> None:
        """An explicit ``model_name`` wins over the schema title."""
        model_cls = create_model_from_schema(
            self._schema("OriginalName"), model_name="CustomSchema"
        )
        assert model_cls.__name__ == "CustomSchema"

    def test_model_name_fallback_to_title(self) -> None:
        """Without ``model_name`` the schema title is used."""
        model_cls = create_model_from_schema(self._schema("FromTitle"))
        assert model_cls.__name__ == "FromTitle"

    def test_model_name_fallback_to_dynamic(self) -> None:
        """Without ``model_name`` or title the name is ``DynamicModel``."""
        model_cls = create_model_from_schema(self._schema())
        assert model_cls.__name__ == "DynamicModel"


# ---------------------------------------------------------------------------
# enrich_descriptions
# ---------------------------------------------------------------------------
class TestEnrichDescriptions:
    """Effect of the ``enrich_descriptions`` flag on field descriptions."""

    def test_enriched_description_includes_constraints(self) -> None:
        """With enrichment on, min/max and the base text all appear."""
        score_schema = {
            "type": "integer",
            "description": "The score value",
            "minimum": 0,
            "maximum": 100,
        }
        model_cls = create_model_from_schema(
            {
                "type": "object",
                "properties": {"score": score_schema},
                "required": ["score"],
            },
            enrich_descriptions=True,
        )
        description = model_cls.model_fields["score"].description
        for fragment in ("Minimum: 0", "Maximum: 100", "The score value"):
            assert fragment in description

    def test_default_does_not_enrich(self) -> None:
        """With ``enrich_descriptions=False`` the description is left as written."""
        score_schema = {
            "type": "integer",
            "description": "The score value",
            "minimum": 0,
        }
        model_cls = create_model_from_schema(
            {
                "type": "object",
                "properties": {"score": score_schema},
                "required": ["score"],
            },
            enrich_descriptions=False,
        )
        score_field = model_cls.model_fields["score"]
        assert score_field.description == "The score value"

    def test_enriched_description_propagates_to_nested(self) -> None:
        """Enrichment also reaches fields of a nested object model."""
        level_schema = {
            "type": "integer",
            "description": "Level",
            "minimum": 1,
            "maximum": 10,
        }
        config_schema = {
            "type": "object",
            "properties": {"level": level_schema},
            "required": ["level"],
        }
        model_cls = create_model_from_schema(
            {
                "type": "object",
                "properties": {"config": config_schema},
                "required": ["config"],
            },
            enrich_descriptions=True,
        )
        # The nested model class is read off the outer field's annotation.
        config_model = model_cls.model_fields["config"].annotation
        level_field = config_model.model_fields["level"]
        assert "Minimum: 1" in level_field.description
        assert "Maximum: 10" in level_field.description


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
    """Degenerate schemas and unsupported type names."""

    def test_empty_properties(self) -> None:
        """An object with an empty ``properties`` map is instantiable."""
        model_cls = create_model_from_schema(
            {"type": "object", "properties": {}, "required": []}
        )
        instance = model_cls()
        assert instance is not None

    def test_no_properties_key(self) -> None:
        """An object with no ``properties`` key at all is instantiable."""
        model_cls = create_model_from_schema({"type": "object"})
        instance = model_cls()
        assert instance is not None

    def test_unknown_type_raises(self) -> None:
        """An unrecognised ``type`` is reported as a ``ValueError``."""
        bad_schema = {
            "type": "object",
            "properties": {
                "weird": {"type": "hyperspace"},
            },
            "required": ["weird"],
        }
        with pytest.raises(ValueError, match="Unsupported JSON schema type"):
            create_model_from_schema(bad_schema)


# ---------------------------------------------------------------------------
# build_rich_field_description
# ---------------------------------------------------------------------------
class TestBuildRichFieldDescription:
    """Text fragments produced by ``build_rich_field_description``."""

    def test_description_only(self) -> None:
        """A bare description is returned verbatim."""
        text = build_rich_field_description({"description": "A name"})
        assert text == "A name"

    def test_empty_schema(self) -> None:
        """An empty schema yields an empty string."""
        text = build_rich_field_description({})
        assert text == ""

    def test_format(self) -> None:
        """The ``format`` keyword is mentioned."""
        text = build_rich_field_description({"format": "date-time"})
        assert "Format: date-time" in text

    def test_enum(self) -> None:
        """Every enum member is listed, quoted."""
        text = build_rich_field_description({"enum": ["a", "b"]})
        for fragment in ("Allowed values:", "'a'", "'b'"):
            assert fragment in text

    def test_pattern(self) -> None:
        """The regex ``pattern`` is mentioned."""
        text = build_rich_field_description({"pattern": "^[a-z]+$"})
        assert "Pattern: ^[a-z]+$" in text

    def test_min_max(self) -> None:
        """Numeric bounds are mentioned."""
        text = build_rich_field_description({"minimum": 0, "maximum": 100})
        for fragment in ("Minimum: 0", "Maximum: 100"):
            assert fragment in text

    def test_min_max_length(self) -> None:
        """Length bounds are mentioned."""
        text = build_rich_field_description({"minLength": 1, "maxLength": 255})
        for fragment in ("Min length: 1", "Max length: 255"):
            assert fragment in text

    def test_examples(self) -> None:
        """Examples are listed, but the fourth one is left out."""
        text = build_rich_field_description(
            {"examples": ["foo", "bar", "baz", "extra"]}
        )
        for fragment in ("Examples:", "'foo'", "'baz'"):
            assert fragment in text
        # Only first 3 shown
        assert "'extra'" not in text

    def test_combined_constraints(self) -> None:
        """The description comes first, followed by every constraint."""
        text = build_rich_field_description(
            {
                "description": "A score",
                "minimum": 0,
                "maximum": 10,
                "format": "int32",
            }
        )
        assert text.startswith("A score")
        for fragment in ("Minimum: 0", "Maximum: 10", "Format: int32"):
            assert fragment in text


# ---------------------------------------------------------------------------
# Schema transformation functions
# ---------------------------------------------------------------------------
class TestResolveRefs:
    """Inlining of ``$ref`` pointers by ``resolve_refs``."""

    def test_basic_ref_resolution(self) -> None:
        """A direct reference is replaced by the referenced definition."""
        source = {
            "type": "object",
            "properties": {"item": {"$ref": "#/$defs/Item"}},
            "$defs": {
                "Item": {"type": "object", "properties": {"id": {"type": "integer"}}},
            },
        }
        item = resolve_refs(source)["properties"]["item"]
        assert "$ref" not in item
        assert item["type"] == "object"

    def test_nested_ref_resolution(self) -> None:
        """A reference inside a referenced definition is resolved as well."""
        source = {
            "type": "object",
            "properties": {"wrapper": {"$ref": "#/$defs/Wrapper"}},
            "$defs": {
                "Wrapper": {
                    "type": "object",
                    "properties": {"inner": {"$ref": "#/$defs/Inner"}},
                },
                "Inner": {"type": "string"},
            },
        }
        output = resolve_refs(source)
        wrapper_schema = output["properties"]["wrapper"]
        assert wrapper_schema["properties"]["inner"]["type"] == "string"

    def test_missing_ref_raises(self) -> None:
        """A pointer to an undefined name raises ``KeyError``."""
        source = {
            "properties": {"x": {"$ref": "#/$defs/Missing"}},
            "$defs": {},
        }
        with pytest.raises(KeyError, match="Missing"):
            resolve_refs(source)

    def test_no_refs_unchanged(self) -> None:
        """A schema without references compares equal to the result."""
        source = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        }
        output = resolve_refs(source)
        assert output == source


class TestForceAdditionalPropertiesFalse:
    """Behaviour of ``force_additional_properties_false`` on copies of schemas."""

    def test_adds_to_object(self) -> None:
        """An object schema gains ``additionalProperties: False``."""
        source = {"type": "object", "properties": {"x": {"type": "integer"}}}
        output = force_additional_properties_false(deepcopy(source))
        assert output["additionalProperties"] is False

    def test_adds_empty_properties_and_required(self) -> None:
        """A bare object schema gains empty ``properties`` and ``required``."""
        source = {"type": "object"}
        output = force_additional_properties_false(deepcopy(source))
        assert output["properties"] == {}
        assert output["required"] == []

    def test_recursive_nested_objects(self) -> None:
        """Both the outer object and its object-typed child are updated."""
        child_schema = {
            "type": "object",
            "properties": {"id": {"type": "integer"}},
        }
        source = {
            "type": "object",
            "properties": {
                "child": child_schema,
            },
        }
        output = force_additional_properties_false(deepcopy(source))
        assert output["additionalProperties"] is False
        assert output["properties"]["child"]["additionalProperties"] is False

    def test_does_not_affect_non_objects(self) -> None:
        """A string schema does not gain ``additionalProperties``."""
        source = {"type": "string"}
        output = force_additional_properties_false(deepcopy(source))
        assert "additionalProperties" not in output


class TestStripUnsupportedFormats:
    """Which ``format`` values ``strip_unsupported_formats`` drops or keeps."""

    def test_removes_email_format(self) -> None:
        """``email`` is removed."""
        source = {"type": "string", "format": "email"}
        output = strip_unsupported_formats(deepcopy(source))
        assert "format" not in output

    def test_keeps_date_time(self) -> None:
        """``date-time`` is retained."""
        source = {"type": "string", "format": "date-time"}
        output = strip_unsupported_formats(deepcopy(source))
        assert output["format"] == "date-time"

    def test_keeps_date(self) -> None:
        """``date`` is retained."""
        source = {"type": "string", "format": "date"}
        output = strip_unsupported_formats(deepcopy(source))
        assert output["format"] == "date"

    def test_removes_uri_format(self) -> None:
        """``uri`` is removed."""
        source = {"type": "string", "format": "uri"}
        output = strip_unsupported_formats(deepcopy(source))
        assert "format" not in output

    def test_recursive(self) -> None:
        """Nested properties are processed individually."""
        source = {
            "type": "object",
            "properties": {
                "email": {"type": "string", "format": "email"},
                "created": {"type": "string", "format": "date-time"},
            },
        }
        props = strip_unsupported_formats(deepcopy(source))["properties"]
        assert "format" not in props["email"]
        assert props["created"]["format"] == "date-time"


class TestEnsureTypeInSchemas:
    """Empty union members are given ``type: object`` by ``ensure_type_in_schemas``."""

    def test_empty_schema_in_anyof_gets_type(self) -> None:
        """An empty ``anyOf`` member becomes ``{"type": "object"}``."""
        source = {"anyOf": [{}, {"type": "string"}]}
        output = ensure_type_in_schemas(deepcopy(source))
        first_member = output["anyOf"][0]
        assert first_member == {"type": "object"}

    def test_empty_schema_in_oneof_gets_type(self) -> None:
        """An empty ``oneOf`` member becomes ``{"type": "object"}``."""
        source = {"oneOf": [{}, {"type": "integer"}]}
        output = ensure_type_in_schemas(deepcopy(source))
        first_member = output["oneOf"][0]
        assert first_member == {"type": "object"}

    def test_non_empty_unchanged(self) -> None:
        """Members that already declare a type are left as they are."""
        source = {"anyOf": [{"type": "string"}, {"type": "integer"}]}
        output = ensure_type_in_schemas(deepcopy(source))
        assert output == source


class TestConvertOneofToAnyof:
    """``convert_oneof_to_anyof`` renames ``oneOf`` to ``anyOf``."""

    def test_converts_top_level(self) -> None:
        """A top-level ``oneOf`` is replaced and keeps both members."""
        source = {"oneOf": [{"type": "string"}, {"type": "integer"}]}
        output = convert_oneof_to_anyof(deepcopy(source))
        assert "oneOf" not in output
        assert "anyOf" in output
        assert len(output["anyOf"]) == 2

    def test_converts_nested(self) -> None:
        """A ``oneOf`` inside a property is replaced too."""
        source = {
            "type": "object",
            "properties": {
                "value": {"oneOf": [{"type": "string"}, {"type": "number"}]},
            },
        }
        output = convert_oneof_to_anyof(deepcopy(source))
        value_schema = output["properties"]["value"]
        assert "anyOf" in value_schema
        assert "oneOf" not in value_schema


class TestEnsureAllPropertiesRequired:
    """``ensure_all_properties_required`` lists every property as required."""

    def test_makes_all_required(self) -> None:
        """A partially-required object ends up with all names required."""
        source = {
            "type": "object",
            "properties": {"a": {"type": "string"}, "b": {"type": "integer"}},
            "required": ["a"],
        }
        output = ensure_all_properties_required(deepcopy(source))
        required_names = set(output["required"])
        assert required_names == {"a", "b"}

    def test_recursive(self) -> None:
        """A nested object's properties become required as well."""
        child_schema = {
            "type": "object",
            "properties": {"x": {"type": "integer"}, "y": {"type": "integer"}},
            "required": [],
        }
        source = {
            "type": "object",
            "properties": {
                "child": child_schema,
            },
        }
        output = ensure_all_properties_required(deepcopy(source))
        child_required = set(output["properties"]["child"]["required"])
        assert child_required == {"x", "y"}


class TestStripNullFromTypes:
    """Removal of ``null`` alternatives by ``strip_null_from_types``."""

    def test_strips_null_from_anyof(self) -> None:
        """``anyOf`` of one type plus null collapses to that type."""
        source = {
            "anyOf": [{"type": "string"}, {"type": "null"}],
        }
        output = strip_null_from_types(deepcopy(source))
        assert "anyOf" not in output
        assert output["type"] == "string"

    def test_strips_null_from_type_array(self) -> None:
        """A ``["string", "null"]`` type list collapses to ``"string"``."""
        source = {"type": ["string", "null"]}
        output = strip_null_from_types(deepcopy(source))
        assert output["type"] == "string"

    def test_multiple_non_null_in_anyof(self) -> None:
        """With two non-null alternatives, ``anyOf`` keeps exactly those two."""
        source = {
            "anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "null"}],
        }
        output = strip_null_from_types(deepcopy(source))
        remaining = output["anyOf"]
        assert len(remaining) == 2

    def test_no_null_unchanged(self) -> None:
        """A schema with no null alternative compares equal to the result."""
        source = {"type": "string"}
        output = strip_null_from_types(deepcopy(source))
        assert output == source


class TestEndToEndMCPSchema:
    """Realistic MCP tool schema exercising multiple features simultaneously."""

    # Shared fixture schema: length/range constraints, an enum, a nested
    # object with date-formatted strings and a string array, and a nullable
    # ``anyOf``. Only ``query``, ``format`` and ``filters`` are required.
    MCP_SCHEMA: dict[str, Any] = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query",
                "minLength": 1,
                "maxLength": 500,
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum results",
                "minimum": 1,
                "maximum": 100,
            },
            "format": {
                "type": "string",
                "enum": ["json", "csv", "xml"],
                "description": "Output format",
            },
            "filters": {
                "type": "object",
                "properties": {
                    "date_from": {"type": "string", "format": "date"},
                    "date_to": {"type": "string", "format": "date"},
                    "categories": {
                        "type": "array",
                        "items": {"type": "string"},
                    },
                },
                "required": ["date_from"],
            },
            "sort_order": {
                "anyOf": [{"type": "string"}, {"type": "null"}],
            },
        },
        "required": ["query", "format", "filters"],
    }

    def test_model_creation(self) -> None:
        """The schema converts into a ``BaseModel`` subclass."""
        Model = create_model_from_schema(self.MCP_SCHEMA)
        assert Model is not None
        assert issubclass(Model, BaseModel)

    def test_valid_input_accepted(self) -> None:
        """Supplying only the required fields succeeds."""
        Model = create_model_from_schema(self.MCP_SCHEMA)
        obj = Model(
            query="test search",
            format="json",
            filters={"date_from": "2025-01-01"},
        )
        assert obj.query == "test search"
        assert obj.format == "json"

    def test_invalid_enum_rejected(self) -> None:
        """A ``format`` value outside the enum fails validation."""
        Model = create_model_from_schema(self.MCP_SCHEMA)
        # Expect pydantic's validation failure specifically; a blanket
        # ``Exception`` would also let unrelated errors satisfy the test.
        with pytest.raises(ValidationError):
            Model(
                query="test",
                format="yaml",
                filters={"date_from": "2025-01-01"},
            )

    def test_model_name_for_mcp_tool(self) -> None:
        """An explicit ``model_name`` becomes the class name."""
        Model = create_model_from_schema(
            self.MCP_SCHEMA, model_name="search_toolSchema"
        )
        assert Model.__name__ == "search_toolSchema"

    def test_enriched_descriptions_for_mcp(self) -> None:
        """With enrichment on, constraints show up in field descriptions."""
        Model = create_model_from_schema(
            self.MCP_SCHEMA, enrich_descriptions=True
        )
        query_field = Model.model_fields["query"]
        assert "Min length: 1" in query_field.description
        assert "Max length: 500" in query_field.description
        max_results_field = Model.model_fields["max_results"]
        assert "Minimum: 1" in max_results_field.description
        assert "Maximum: 100" in max_results_field.description
        format_field = Model.model_fields["format"]
        assert "Allowed values:" in format_field.description

    def test_optional_fields_accept_none(self) -> None:
        """Non-required fields may be passed ``None`` explicitly."""
        Model = create_model_from_schema(self.MCP_SCHEMA)
        obj = Model(
            query="test",
            format="csv",
            filters={"date_from": "2025-01-01"},
            max_results=None,
            sort_order=None,
        )
        assert obj.max_results is None
        assert obj.sort_order is None

    def test_nested_filters_validated(self) -> None:
        """Nested date strings are parsed and the string list is preserved."""
        Model = create_model_from_schema(self.MCP_SCHEMA)
        obj = Model(
            query="test",
            format="xml",
            filters={
                "date_from": "2025-01-01",
                "date_to": "2025-12-31",
                "categories": ["news", "tech"],
            },
        )
        assert obj.filters.date_from == datetime.date(2025, 1, 1)
        assert obj.filters.categories == ["news", "tech"]