mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 14:52:36 +00:00
Enhance MCP tool resolution and related events (#4580)
* feat: enhance MCP tool resolution * feat: emit event when MCP configuration fails * feat: emit event when MCP tool execution has failed * style: resolve linter issues * refactor: use clear and natural mcp tool name resolution * test: fix broken tests * fix: resolve MCP connection leaks, slug validation, duplicate connections, and httpx exception handling --------- Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com> Co-authored-by: Greyson LaLonde <greyson@crewai.com>
This commit is contained in:
@@ -8,11 +8,9 @@ import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Final,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -61,16 +59,8 @@ from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.lite_agent_output import LiteAgentOutput
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.mcp import (
|
||||
MCPClient,
|
||||
MCPServerConfig,
|
||||
MCPServerHTTP,
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
from crewai.mcp import MCPServerConfig
|
||||
from crewai.mcp.tool_resolver import MCPToolResolver
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
@@ -111,18 +101,8 @@ if TYPE_CHECKING:
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
# MCP Connection timeout constants (in seconds)
|
||||
MCP_CONNECTION_TIMEOUT: Final[int] = 10
|
||||
MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30
|
||||
MCP_DISCOVERY_TIMEOUT: Final[int] = 15
|
||||
MCP_MAX_RETRIES: Final[int] = 3
|
||||
|
||||
_passthrough_exceptions: tuple[type[Exception], ...] = ()
|
||||
|
||||
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
||||
_mcp_schema_cache: dict[str, Any] = {}
|
||||
_cache_ttl: Final[int] = 300 # 5 minutes
|
||||
|
||||
|
||||
class Agent(BaseAgent):
|
||||
"""Represents an agent in a system.
|
||||
@@ -154,7 +134,7 @@ class Agent(BaseAgent):
|
||||
model_config = ConfigDict()
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
_mcp_clients: list[Any] = PrivateAttr(default_factory=list)
|
||||
_mcp_resolver: MCPToolResolver | None = PrivateAttr(default=None)
|
||||
_last_messages: list[LLMMessage] = PrivateAttr(default_factory=list)
|
||||
max_execution_time: int | None = Field(
|
||||
default=None,
|
||||
@@ -934,544 +914,17 @@ class Agent(BaseAgent):
|
||||
def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
|
||||
"""Convert MCP server references/configs to CrewAI tools.
|
||||
|
||||
Supports both string references (backwards compatible) and structured
|
||||
configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE).
|
||||
|
||||
Args:
|
||||
mcps: List of MCP server references (strings) or configurations.
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances from MCP servers.
|
||||
Delegates to :class:`~crewai.mcp.tool_resolver.MCPToolResolver`.
|
||||
"""
|
||||
all_tools = []
|
||||
clients = []
|
||||
|
||||
for mcp_config in mcps:
|
||||
if isinstance(mcp_config, str):
|
||||
tools = self._get_mcp_tools_from_string(mcp_config)
|
||||
else:
|
||||
tools, client = self._get_native_mcp_tools(mcp_config)
|
||||
if client:
|
||||
clients.append(client)
|
||||
|
||||
all_tools.extend(tools)
|
||||
|
||||
# Store clients for cleanup
|
||||
self._mcp_clients.extend(clients)
|
||||
return all_tools
|
||||
self._cleanup_mcp_clients()
|
||||
self._mcp_resolver = MCPToolResolver(agent=self, logger=self._logger)
|
||||
return self._mcp_resolver.resolve(mcps)
|
||||
|
||||
def _cleanup_mcp_clients(self) -> None:
|
||||
"""Cleanup MCP client connections after task execution."""
|
||||
if not self._mcp_clients:
|
||||
return
|
||||
|
||||
async def _disconnect_all() -> None:
|
||||
for client in self._mcp_clients:
|
||||
if client and hasattr(client, "connected") and client.connected:
|
||||
await client.disconnect()
|
||||
|
||||
try:
|
||||
asyncio.run(_disconnect_all())
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during MCP client cleanup: {e}")
|
||||
finally:
|
||||
self._mcp_clients.clear()
|
||||
|
||||
def _get_mcp_tools_from_string(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from legacy string-based MCP references.
|
||||
|
||||
This method maintains backwards compatibility with string-based
|
||||
MCP references (https://... and crewai-amp:...).
|
||||
|
||||
Args:
|
||||
mcp_ref: String reference to MCP server.
|
||||
|
||||
Returns:
|
||||
List of BaseTool instances.
|
||||
"""
|
||||
if mcp_ref.startswith("crewai-amp:"):
|
||||
return self._get_amp_mcp_tools(mcp_ref)
|
||||
if mcp_ref.startswith("https://"):
|
||||
return self._get_external_mcp_tools(mcp_ref)
|
||||
return []
|
||||
|
||||
def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from external HTTPS MCP server with graceful error handling."""
|
||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||
|
||||
# Parse server URL and optional tool name
|
||||
if "#" in mcp_ref:
|
||||
server_url, specific_tool = mcp_ref.split("#", 1)
|
||||
else:
|
||||
server_url, specific_tool = mcp_ref, None
|
||||
|
||||
server_params = {"url": server_url}
|
||||
server_name = self._extract_server_name(server_url)
|
||||
|
||||
try:
|
||||
# Get tool schemas with timeout and error handling
|
||||
tool_schemas = self._get_mcp_tool_schemas(server_params)
|
||||
|
||||
if not tool_schemas:
|
||||
self._logger.log(
|
||||
"warning", f"No tools discovered from MCP server: {server_url}"
|
||||
)
|
||||
return []
|
||||
|
||||
tools = []
|
||||
for tool_name, schema in tool_schemas.items():
|
||||
# Skip if specific tool requested and this isn't it
|
||||
if specific_tool and tool_name != specific_tool:
|
||||
continue
|
||||
|
||||
try:
|
||||
wrapper = MCPToolWrapper(
|
||||
mcp_server_params=server_params,
|
||||
tool_name=tool_name,
|
||||
tool_schema=schema,
|
||||
server_name=server_name,
|
||||
)
|
||||
tools.append(wrapper)
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Failed to create MCP tool wrapper for {tool_name}: {e}",
|
||||
)
|
||||
continue
|
||||
|
||||
if specific_tool and not tools:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Specific tool '{specific_tool}' not found on MCP server: {server_url}",
|
||||
)
|
||||
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning", f"Failed to connect to MCP server {server_url}: {e}"
|
||||
)
|
||||
return []
|
||||
|
||||
def _get_native_mcp_tools(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], Any | None]:
|
||||
"""Get tools from MCP server using structured configuration.
|
||||
|
||||
This method creates an MCP client based on the configuration type,
|
||||
connects to the server, discovers tools, applies filtering, and
|
||||
returns wrapped tools along with the client instance for cleanup.
|
||||
|
||||
Args:
|
||||
mcp_config: MCP server configuration (MCPServerStdio, MCPServerHTTP, or MCPServerSSE).
|
||||
|
||||
Returns:
|
||||
Tuple of (list of BaseTool instances, MCPClient instance for cleanup).
|
||||
"""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
args=mcp_config.args,
|
||||
env=mcp_config.env,
|
||||
)
|
||||
server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}"
|
||||
elif isinstance(mcp_config, MCPServerHTTP):
|
||||
transport = HTTPTransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
streamable=mcp_config.streamable,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
elif isinstance(mcp_config, MCPServerSSE):
|
||||
transport = SSETransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
|
||||
|
||||
client = MCPClient(
|
||||
transport=transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
|
||||
"""Async helper to connect and list tools in same event loop."""
|
||||
|
||||
try:
|
||||
if not client.connected:
|
||||
await client.connect()
|
||||
|
||||
tools_list = await client.list_tools()
|
||||
|
||||
try:
|
||||
await client.disconnect()
|
||||
# Small delay to allow background tasks to finish cleanup
|
||||
# This helps prevent "cancel scope in different task" errors
|
||||
# when asyncio.run() closes the event loop
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during disconnect: {e}")
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(
|
||||
f"Error during setup client and list tools: {e}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, _setup_client_and_list_tools()
|
||||
)
|
||||
tools_list = future.result()
|
||||
except RuntimeError:
|
||||
try:
|
||||
tools_list = asyncio.run(_setup_client_and_list_tools())
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e).lower()
|
||||
if "cancel scope" in error_msg or "task" in error_msg:
|
||||
raise ConnectionError(
|
||||
"MCP connection failed due to event loop cleanup issues. "
|
||||
"This may be due to authentication errors or server unavailability."
|
||||
) from e
|
||||
except asyncio.CancelledError as e:
|
||||
raise ConnectionError(
|
||||
"MCP connection was cancelled. This may indicate an authentication "
|
||||
"error or server unavailability."
|
||||
) from e
|
||||
|
||||
if mcp_config.tool_filter:
|
||||
filtered_tools = []
|
||||
for tool in tools_list:
|
||||
if callable(mcp_config.tool_filter):
|
||||
try:
|
||||
from crewai.mcp.filters import ToolFilterContext
|
||||
|
||||
context = ToolFilterContext(
|
||||
agent=self,
|
||||
server_name=server_name,
|
||||
run_context=None,
|
||||
)
|
||||
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
except (TypeError, AttributeError):
|
||||
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
# Not callable - include tool
|
||||
filtered_tools.append(tool)
|
||||
tools_list = filtered_tools
|
||||
|
||||
tools = []
|
||||
for tool_def in tools_list:
|
||||
tool_name = tool_def.get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
# Convert inputSchema to Pydantic model if present
|
||||
args_schema = None
|
||||
if tool_def.get("inputSchema"):
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
tool_name, tool_def["inputSchema"]
|
||||
)
|
||||
|
||||
tool_schema = {
|
||||
"description": tool_def.get("description", ""),
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
|
||||
try:
|
||||
native_tool = MCPNativeTool(
|
||||
mcp_client=client,
|
||||
tool_name=tool_name,
|
||||
tool_schema=tool_schema,
|
||||
server_name=server_name,
|
||||
)
|
||||
tools.append(native_tool)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Failed to create native MCP tool: {e}")
|
||||
continue
|
||||
|
||||
return cast(list[BaseTool], tools), client
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
asyncio.run(client.disconnect())
|
||||
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
|
||||
"""Get tools from CrewAI AMP MCP marketplace."""
|
||||
# Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"
|
||||
amp_part = amp_ref.replace("crewai-amp:", "")
|
||||
if "#" in amp_part:
|
||||
mcp_name, specific_tool = amp_part.split("#", 1)
|
||||
else:
|
||||
mcp_name, specific_tool = amp_part, None
|
||||
|
||||
# Call AMP API to get MCP server URLs
|
||||
mcp_servers = self._fetch_amp_mcp_servers(mcp_name)
|
||||
|
||||
tools = []
|
||||
for server_config in mcp_servers:
|
||||
server_ref = server_config["url"]
|
||||
if specific_tool:
|
||||
server_ref += f"#{specific_tool}"
|
||||
server_tools = self._get_external_mcp_tools(server_ref)
|
||||
tools.extend(server_tools)
|
||||
|
||||
return tools
|
||||
|
||||
@staticmethod
|
||||
def _extract_server_name(server_url: str) -> str:
|
||||
"""Extract clean server name from URL for tool prefixing."""
|
||||
|
||||
parsed = urlparse(server_url)
|
||||
domain = parsed.netloc.replace(".", "_")
|
||||
path = parsed.path.replace("/", "_").strip("_")
|
||||
return f"{domain}_{path}" if path else domain
|
||||
|
||||
def _get_mcp_tool_schemas(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Get tool schemas from MCP server for wrapper creation with caching."""
|
||||
server_url = server_params["url"]
|
||||
|
||||
# Check cache first
|
||||
cache_key = server_url
|
||||
current_time = time.time()
|
||||
|
||||
if cache_key in _mcp_schema_cache:
|
||||
cached_data, cache_time = _mcp_schema_cache[cache_key]
|
||||
if current_time - cache_time < _cache_ttl:
|
||||
self._logger.log(
|
||||
"debug", f"Using cached MCP tool schemas for {server_url}"
|
||||
)
|
||||
return cached_data # type: ignore[no-any-return]
|
||||
|
||||
try:
|
||||
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
||||
|
||||
# Cache successful results
|
||||
_mcp_schema_cache[cache_key] = (schemas, current_time)
|
||||
|
||||
return schemas
|
||||
except Exception as e:
|
||||
# Log warning but don't raise - this allows graceful degradation
|
||||
self._logger.log(
|
||||
"warning", f"Failed to get MCP tool schemas from {server_url}: {e}"
|
||||
)
|
||||
return {}
|
||||
|
||||
async def _get_mcp_tool_schemas_async(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
|
||||
server_url = server_params["url"]
|
||||
return await self._retry_mcp_discovery(
|
||||
self._discover_mcp_tools_with_timeout, server_url
|
||||
)
|
||||
|
||||
async def _retry_mcp_discovery(
|
||||
self, operation_func: Any, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(MCP_MAX_RETRIES):
|
||||
# Execute single attempt outside try-except loop structure
|
||||
result, error, should_retry = await self._attempt_mcp_discovery(
|
||||
operation_func, server_url
|
||||
)
|
||||
|
||||
# Success case - return immediately
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Non-retryable error - raise immediately
|
||||
if not should_retry:
|
||||
raise RuntimeError(error)
|
||||
|
||||
# Retryable error - continue with backoff
|
||||
last_error = error
|
||||
if attempt < MCP_MAX_RETRIES - 1:
|
||||
wait_time = 2**attempt # Exponential backoff
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _attempt_mcp_discovery(
|
||||
operation_func: Any, server_url: str
|
||||
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
||||
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
|
||||
try:
|
||||
result = await operation_func(server_url)
|
||||
return result, "", False
|
||||
|
||||
except ImportError:
|
||||
return (
|
||||
None,
|
||||
"MCP library not available. Please install with: pip install mcp",
|
||||
False,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return (
|
||||
None,
|
||||
f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds",
|
||||
True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
|
||||
# Classify errors as retryable or non-retryable
|
||||
if "authentication" in error_str or "unauthorized" in error_str:
|
||||
return None, f"Authentication failed for MCP server: {e!s}", False
|
||||
if "connection" in error_str or "network" in error_str:
|
||||
return None, f"Network connection failed: {e!s}", True
|
||||
if "json" in error_str or "parsing" in error_str:
|
||||
return None, f"Server response parsing error: {e!s}", True
|
||||
return None, f"MCP discovery error: {e!s}", False
|
||||
|
||||
async def _discover_mcp_tools_with_timeout(
|
||||
self, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Discover MCP tools with timeout wrapper."""
|
||||
return await asyncio.wait_for(
|
||||
self._discover_mcp_tools(server_url), timeout=MCP_DISCOVERY_TIMEOUT
|
||||
)
|
||||
|
||||
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]:
|
||||
"""Discover tools from MCP server with proper timeout handling."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
async with streamablehttp_client(server_url) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection with timeout
|
||||
await asyncio.wait_for(
|
||||
session.initialize(), timeout=MCP_CONNECTION_TIMEOUT
|
||||
)
|
||||
|
||||
# List available tools with timeout
|
||||
tools_result = await asyncio.wait_for(
|
||||
session.list_tools(),
|
||||
timeout=MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT,
|
||||
)
|
||||
|
||||
schemas = {}
|
||||
for tool in tools_result.tools:
|
||||
args_schema = None
|
||||
if hasattr(tool, "inputSchema") and tool.inputSchema:
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
sanitize_tool_name(tool.name), tool.inputSchema
|
||||
)
|
||||
|
||||
schemas[sanitize_tool_name(tool.name)] = {
|
||||
"description": getattr(tool, "description", ""),
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
return schemas
|
||||
|
||||
def _json_schema_to_pydantic(
|
||||
self, tool_name: str, json_schema: dict[str, Any]
|
||||
) -> type:
|
||||
"""Convert JSON Schema to Pydantic model for tool arguments.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool (used for model naming)
|
||||
json_schema: JSON Schema dict with 'properties', 'required', etc.
|
||||
|
||||
Returns:
|
||||
Pydantic BaseModel class
|
||||
"""
|
||||
from pydantic import Field, create_model
|
||||
|
||||
properties = json_schema.get("properties", {})
|
||||
required_fields = json_schema.get("required", [])
|
||||
|
||||
field_definitions: dict[str, Any] = {}
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
field_type = self._json_type_to_python(field_schema)
|
||||
field_description = field_schema.get("description", "")
|
||||
|
||||
is_required = field_name in required_fields
|
||||
|
||||
if is_required:
|
||||
field_definitions[field_name] = (
|
||||
field_type,
|
||||
Field(..., description=field_description),
|
||||
)
|
||||
else:
|
||||
field_definitions[field_name] = (
|
||||
field_type | None,
|
||||
Field(default=None, description=field_description),
|
||||
)
|
||||
|
||||
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
||||
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
|
||||
|
||||
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
|
||||
"""Convert JSON Schema type to Python type.
|
||||
|
||||
Args:
|
||||
field_schema: JSON Schema field definition
|
||||
|
||||
Returns:
|
||||
Python type
|
||||
"""
|
||||
|
||||
json_type = field_schema.get("type")
|
||||
|
||||
if "anyOf" in field_schema:
|
||||
types: list[type] = []
|
||||
for option in field_schema["anyOf"]:
|
||||
if "const" in option:
|
||||
types.append(str)
|
||||
else:
|
||||
types.append(self._json_type_to_python(option))
|
||||
unique_types = list(set(types))
|
||||
if len(unique_types) > 1:
|
||||
result: Any = unique_types[0]
|
||||
for t in unique_types[1:]:
|
||||
result = result | t
|
||||
return result # type: ignore[no-any-return]
|
||||
return unique_types[0]
|
||||
|
||||
type_mapping: dict[str | None, type] = {
|
||||
"string": str,
|
||||
"number": float,
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
|
||||
return type_mapping.get(json_type, Any)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
|
||||
"""Fetch MCP server configurations from CrewAI AMP API."""
|
||||
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
||||
# Should return list of server configs with URLs
|
||||
return []
|
||||
if self._mcp_resolver is not None:
|
||||
self._mcp_resolver.cleanup()
|
||||
self._mcp_resolver = None
|
||||
|
||||
@staticmethod
|
||||
def get_multimodal_tools() -> Sequence[BaseTool]:
|
||||
|
||||
@@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import Any, Literal
|
||||
import re
|
||||
from typing import Any, Final, Literal
|
||||
import uuid
|
||||
|
||||
from pydantic import (
|
||||
@@ -36,6 +37,11 @@ from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
|
||||
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
|
||||
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$"
|
||||
)
|
||||
|
||||
|
||||
PlatformApp = Literal[
|
||||
"asana",
|
||||
"box",
|
||||
@@ -197,7 +203,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
)
|
||||
mcps: list[str | MCPServerConfig] | None = Field(
|
||||
default=None,
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
|
||||
)
|
||||
memory: Any = Field(
|
||||
default=None,
|
||||
@@ -276,14 +282,16 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
validated_mcps: list[str | MCPServerConfig] = []
|
||||
for mcp in mcps:
|
||||
if isinstance(mcp, str):
|
||||
if mcp.startswith(("https://", "crewai-amp:")):
|
||||
if mcp.startswith("https://"):
|
||||
validated_mcps.append(mcp)
|
||||
elif _SLUG_RE.match(mcp):
|
||||
validated_mcps.append(mcp)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid MCP reference: {mcp}. "
|
||||
"String references must start with 'https://' or 'crewai-amp:'"
|
||||
f"Invalid MCP reference: {mcp!r}. "
|
||||
"String references must be an 'https://' URL or a valid "
|
||||
"slug (e.g. 'notion', 'notion#search', 'crewai-amp:notion')."
|
||||
)
|
||||
|
||||
elif isinstance(mcp, (MCPServerConfig)):
|
||||
validated_mcps.append(mcp)
|
||||
else:
|
||||
|
||||
@@ -190,6 +190,15 @@ class PlusAPI:
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
|
||||
"""Get MCP server configurations for the given slugs."""
|
||||
return self._make_request(
|
||||
"GET",
|
||||
f"{self.INTEGRATIONS_RESOURCE}/mcp_configs",
|
||||
params={"slugs": ",".join(slugs)},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def get_triggers(self) -> httpx.Response:
|
||||
"""Get all available triggers from integrations."""
|
||||
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
|
||||
|
||||
@@ -63,6 +63,7 @@ from crewai.events.types.logging_events import (
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.mcp_events import (
|
||||
MCPConfigFetchFailedEvent,
|
||||
MCPConnectionCompletedEvent,
|
||||
MCPConnectionFailedEvent,
|
||||
MCPConnectionStartedEvent,
|
||||
@@ -165,6 +166,7 @@ __all__ = [
|
||||
"LiteAgentExecutionCompletedEvent",
|
||||
"LiteAgentExecutionErrorEvent",
|
||||
"LiteAgentExecutionStartedEvent",
|
||||
"MCPConfigFetchFailedEvent",
|
||||
"MCPConnectionCompletedEvent",
|
||||
"MCPConnectionFailedEvent",
|
||||
"MCPConnectionStartedEvent",
|
||||
|
||||
@@ -68,6 +68,7 @@ from crewai.events.types.logging_events import (
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.mcp_events import (
|
||||
MCPConfigFetchFailedEvent,
|
||||
MCPConnectionCompletedEvent,
|
||||
MCPConnectionFailedEvent,
|
||||
MCPConnectionStartedEvent,
|
||||
@@ -665,6 +666,16 @@ class EventListener(BaseEventListener):
|
||||
event.error_type,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPConfigFetchFailedEvent)
|
||||
def on_mcp_config_fetch_failed(
|
||||
_: Any, event: MCPConfigFetchFailedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_mcp_config_fetch_failed(
|
||||
event.slug,
|
||||
event.error,
|
||||
event.error_type,
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MCPToolExecutionStartedEvent)
|
||||
def on_mcp_tool_execution_started(
|
||||
_: Any, event: MCPToolExecutionStartedEvent
|
||||
|
||||
@@ -67,6 +67,7 @@ from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.mcp_events import (
|
||||
MCPConfigFetchFailedEvent,
|
||||
MCPConnectionCompletedEvent,
|
||||
MCPConnectionFailedEvent,
|
||||
MCPConnectionStartedEvent,
|
||||
@@ -181,4 +182,5 @@ EventTypes = (
|
||||
| MCPToolExecutionStartedEvent
|
||||
| MCPToolExecutionCompletedEvent
|
||||
| MCPToolExecutionFailedEvent
|
||||
| MCPConfigFetchFailedEvent
|
||||
)
|
||||
|
||||
@@ -83,3 +83,16 @@ class MCPToolExecutionFailedEvent(MCPEvent):
|
||||
error_type: str | None = None # "timeout", "validation", "server_error", etc.
|
||||
started_at: datetime | None = None
|
||||
failed_at: datetime | None = None
|
||||
|
||||
|
||||
class MCPConfigFetchFailedEvent(BaseEvent):
|
||||
"""Event emitted when fetching an AMP MCP server config fails.
|
||||
|
||||
This covers cases where the slug is not connected, the API call
|
||||
failed, or native MCP resolution failed after config was fetched.
|
||||
"""
|
||||
|
||||
type: str = "mcp_config_fetch_failed"
|
||||
slug: str
|
||||
error: str
|
||||
error_type: str | None = None # "not_connected", "api_error", "connection_failed"
|
||||
|
||||
@@ -1512,6 +1512,34 @@ To enable tracing, do any one of these:
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_mcp_config_fetch_failed(
|
||||
self,
|
||||
slug: str,
|
||||
error: str = "",
|
||||
error_type: str | None = None,
|
||||
) -> None:
|
||||
"""Handle MCP config fetch failed event (AMP resolution failures)."""
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
content = Text()
|
||||
content.append("MCP Config Fetch Failed\n\n", style="red bold")
|
||||
content.append("Server: ", style="white")
|
||||
content.append(f"{slug}\n", style="red")
|
||||
|
||||
if error_type:
|
||||
content.append("Error Type: ", style="white")
|
||||
content.append(f"{error_type}\n", style="red")
|
||||
|
||||
if error:
|
||||
content.append("\nError: ", style="white bold")
|
||||
error_preview = error[:500] + "..." if len(error) > 500 else error
|
||||
content.append(f"{error_preview}\n", style="red")
|
||||
|
||||
panel = self.create_panel(content, "❌ MCP Config Failed", "red")
|
||||
self.print(panel)
|
||||
self.print()
|
||||
|
||||
def handle_mcp_tool_execution_started(
|
||||
self,
|
||||
server_name: str,
|
||||
|
||||
@@ -18,6 +18,7 @@ from crewai.mcp.filters import (
|
||||
create_dynamic_tool_filter,
|
||||
create_static_tool_filter,
|
||||
)
|
||||
from crewai.mcp.tool_resolver import MCPToolResolver
|
||||
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||
|
||||
|
||||
@@ -28,6 +29,7 @@ __all__ = [
|
||||
"MCPServerHTTP",
|
||||
"MCPServerSSE",
|
||||
"MCPServerStdio",
|
||||
"MCPToolResolver",
|
||||
"StaticToolFilter",
|
||||
"ToolFilter",
|
||||
"ToolFilterContext",
|
||||
|
||||
@@ -6,7 +6,7 @@ from contextlib import AsyncExitStack
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -34,6 +34,13 @@ from crewai.mcp.transports.stdio import StdioTransport
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
class _MCPToolResult(NamedTuple):
|
||||
"""Internal result from an MCP tool call, carrying the ``isError`` flag."""
|
||||
|
||||
content: str
|
||||
is_error: bool
|
||||
|
||||
|
||||
# MCP Connection timeout constants (in seconds)
|
||||
MCP_CONNECTION_TIMEOUT = 30 # Increased for slow servers
|
||||
MCP_TOOL_EXECUTION_TIMEOUT = 30
|
||||
@@ -420,6 +427,7 @@ class MCPClient:
|
||||
return [
|
||||
{
|
||||
"name": sanitize_tool_name(tool.name),
|
||||
"original_name": tool.name,
|
||||
"description": getattr(tool, "description", ""),
|
||||
"inputSchema": getattr(tool, "inputSchema", {}),
|
||||
}
|
||||
@@ -461,29 +469,46 @@ class MCPClient:
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._retry_operation(
|
||||
tool_result: _MCPToolResult = await self._retry_operation(
|
||||
lambda: self._call_tool_impl(tool_name, cleaned_arguments),
|
||||
timeout=self.execution_timeout,
|
||||
)
|
||||
|
||||
completed_at = datetime.now()
|
||||
execution_duration_ms = (completed_at - started_at).total_seconds() * 1000
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MCPToolExecutionCompletedEvent(
|
||||
server_name=server_name,
|
||||
server_url=server_url,
|
||||
transport_type=transport_type,
|
||||
tool_name=tool_name,
|
||||
tool_args=cleaned_arguments,
|
||||
result=result,
|
||||
started_at=started_at,
|
||||
completed_at=completed_at,
|
||||
execution_duration_ms=execution_duration_ms,
|
||||
),
|
||||
)
|
||||
finished_at = datetime.now()
|
||||
execution_duration_ms = (finished_at - started_at).total_seconds() * 1000
|
||||
|
||||
return result
|
||||
if tool_result.is_error:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MCPToolExecutionFailedEvent(
|
||||
server_name=server_name,
|
||||
server_url=server_url,
|
||||
transport_type=transport_type,
|
||||
tool_name=tool_name,
|
||||
tool_args=cleaned_arguments,
|
||||
error=tool_result.content,
|
||||
error_type="tool_error",
|
||||
started_at=started_at,
|
||||
failed_at=finished_at,
|
||||
),
|
||||
)
|
||||
else:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MCPToolExecutionCompletedEvent(
|
||||
server_name=server_name,
|
||||
server_url=server_url,
|
||||
transport_type=transport_type,
|
||||
tool_name=tool_name,
|
||||
tool_args=cleaned_arguments,
|
||||
result=tool_result.content,
|
||||
started_at=started_at,
|
||||
completed_at=finished_at,
|
||||
execution_duration_ms=execution_duration_ms,
|
||||
),
|
||||
)
|
||||
|
||||
return tool_result.content
|
||||
except Exception as e:
|
||||
failed_at = datetime.now()
|
||||
error_type = (
|
||||
@@ -564,23 +589,27 @@ class MCPClient:
|
||||
|
||||
return cleaned
|
||||
|
||||
async def _call_tool_impl(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
async def _call_tool_impl(
|
||||
self, tool_name: str, arguments: dict[str, Any]
|
||||
) -> _MCPToolResult:
|
||||
"""Internal implementation of call_tool."""
|
||||
result = await asyncio.wait_for(
|
||||
self.session.call_tool(tool_name, arguments),
|
||||
timeout=self.execution_timeout,
|
||||
)
|
||||
|
||||
is_error = getattr(result, "isError", False) or False
|
||||
|
||||
# Extract result content
|
||||
if hasattr(result, "content") and result.content:
|
||||
if isinstance(result.content, list) and len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
if hasattr(content_item, "text"):
|
||||
return str(content_item.text)
|
||||
return str(content_item)
|
||||
return str(result.content)
|
||||
return _MCPToolResult(str(content_item.text), is_error)
|
||||
return _MCPToolResult(str(content_item), is_error)
|
||||
return _MCPToolResult(str(result.content), is_error)
|
||||
|
||||
return str(result)
|
||||
return _MCPToolResult(str(result), is_error)
|
||||
|
||||
async def list_prompts(self) -> list[dict[str, Any]]:
|
||||
"""List available prompts from MCP server.
|
||||
|
||||
592
lib/crewai/src/crewai/mcp/tool_resolver.py
Normal file
592
lib/crewai/src/crewai/mcp/tool_resolver.py
Normal file
@@ -0,0 +1,592 @@
|
||||
"""MCP tool resolution for CrewAI agents.
|
||||
|
||||
This module extracts all MCP-related tool resolution logic from the Agent class
|
||||
into a standalone MCPToolResolver. It handles three flavours of MCP reference:
|
||||
|
||||
1. Native configs: MCPServerStdio / MCPServerHTTP / MCPServerSSE objects.
|
||||
2. HTTPS URLs: e.g. "https://mcp.example.com/api"
|
||||
3. AMP references: e.g. "notion" or "notion#search" (legacy "crewai-amp:" prefix also works)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Final, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai.mcp.client import MCPClient
|
||||
from crewai.mcp.config import (
|
||||
MCPServerConfig,
|
||||
MCPServerHTTP,
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
MCP_CONNECTION_TIMEOUT: Final[int] = 10
|
||||
MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30
|
||||
MCP_DISCOVERY_TIMEOUT: Final[int] = 15
|
||||
MCP_MAX_RETRIES: Final[int] = 3
|
||||
|
||||
_mcp_schema_cache: dict[str, Any] = {}
|
||||
_cache_ttl: Final[int] = 300 # 5 minutes
|
||||
|
||||
|
||||
class MCPToolResolver:
    """Turns MCP server references and configs into CrewAI ``BaseTool`` objects.

    Usage::

        resolver = MCPToolResolver(agent=my_agent, logger=my_agent._logger)
        tools = resolver.resolve(my_agent.mcps)
        # … agent executes tasks using *tools* …
        resolver.cleanup()

    Any MCP client connections opened while resolving are owned by this
    instance and must be released through :meth:`cleanup`.
    """

    def __init__(self, agent: Any, logger: Logger) -> None:
        # The agent is kept for tool-filter callbacks; the logger for diagnostics.
        self._agent = agent
        self._logger = logger
        # MCP clients opened during resolution, torn down in cleanup().
        self._clients: list[Any] = []

    @property
    def clients(self) -> list[Any]:
        """Snapshot copy of the currently tracked MCP clients."""
        return list(self._clients)
|
||||
|
||||
def resolve(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
|
||||
"""Convert MCP server references/configs to CrewAI tools."""
|
||||
all_tools: list[BaseTool] = []
|
||||
amp_refs: list[tuple[str, str | None]] = []
|
||||
|
||||
for mcp_config in mcps:
|
||||
if isinstance(mcp_config, str) and mcp_config.startswith("https://"):
|
||||
all_tools.extend(self._resolve_external(mcp_config))
|
||||
elif isinstance(mcp_config, str):
|
||||
amp_refs.append(self._parse_amp_ref(mcp_config))
|
||||
else:
|
||||
tools, client = self._resolve_native(mcp_config)
|
||||
all_tools.extend(tools)
|
||||
if client:
|
||||
self._clients.append(client)
|
||||
|
||||
if amp_refs:
|
||||
tools, clients = self._resolve_amp(amp_refs)
|
||||
all_tools.extend(tools)
|
||||
self._clients.extend(clients)
|
||||
|
||||
return all_tools
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Disconnect all MCP client connections."""
|
||||
if not self._clients:
|
||||
return
|
||||
|
||||
async def _disconnect_all() -> None:
|
||||
for client in self._clients:
|
||||
if client and hasattr(client, "connected") and client.connected:
|
||||
await client.disconnect()
|
||||
|
||||
try:
|
||||
asyncio.run(_disconnect_all())
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during MCP client cleanup: {e}")
|
||||
finally:
|
||||
self._clients.clear()
|
||||
|
||||
@staticmethod
|
||||
def _parse_amp_ref(mcp_config: str) -> tuple[str, str | None]:
|
||||
"""Parse an AMP reference into *(slug, optional tool name)*.
|
||||
|
||||
Accepts both bare slugs (``"notion"``, ``"notion#search"``) and the
|
||||
legacy ``"crewai-amp:notion"`` form.
|
||||
"""
|
||||
bare = mcp_config.removeprefix("crewai-amp:")
|
||||
slug, _, specific_tool = bare.partition("#")
|
||||
return slug, specific_tool or None
|
||||
|
||||
    def _resolve_amp(
        self, amp_refs: list[tuple[str, str | None]]
    ) -> tuple[list[BaseTool], list[Any]]:
        """Fetch AMP configs in bulk and return their tools and clients.

        Resolves each unique slug only once (single connection per server),
        then applies per-ref tool filters to select specific tools.

        Args:
            amp_refs: ``(slug, optional_tool_name)`` pairs from
                :meth:`_parse_amp_ref`.

        Returns:
            Tuple of ``(tools, clients)``; the clients must later be passed to
            cleanup.
        """
        from crewai.events.event_bus import crewai_event_bus
        from crewai.events.types.mcp_events import MCPConfigFetchFailedEvent

        # Deduplicate slugs while preserving first-seen order, then fetch all
        # configs in a single API round-trip.
        unique_slugs = list(dict.fromkeys(slug for slug, _ in amp_refs))
        amp_configs_map = self._fetch_amp_mcp_configs(unique_slugs)

        all_tools: list[BaseTool] = []
        all_clients: list[Any] = []

        # slug -> (tools, client) so each server is connected at most once.
        resolved_cache: dict[str, tuple[list[BaseTool], Any | None]] = {}

        # First pass: connect to every unique server and collect its tools.
        for slug in unique_slugs:
            config_dict = amp_configs_map.get(slug)
            if not config_dict:
                # Missing config: emit an event instead of raising so the
                # remaining slugs can still resolve.
                crewai_event_bus.emit(
                    self,
                    MCPConfigFetchFailedEvent(
                        slug=slug,
                        error=f"Config for '{slug}' not found. Make sure it is connected in your account.",
                        error_type="not_connected",
                    ),
                )
                continue

            mcp_server_config = self._build_mcp_config_from_dict(config_dict)

            try:
                tools, client = self._resolve_native(mcp_server_config)
                resolved_cache[slug] = (tools, client)
                if client:
                    all_clients.append(client)
            except Exception as e:
                # Connection/setup failure is likewise surfaced as an event;
                # the slug simply contributes no tools.
                crewai_event_bus.emit(
                    self,
                    MCPConfigFetchFailedEvent(
                        slug=slug,
                        error=str(e),
                        error_type="connection_failed",
                    ),
                )

        # Second pass: apply each reference's optional "#tool" filter against
        # the cached per-slug tool lists.
        for slug, specific_tool in amp_refs:
            cached = resolved_cache.get(slug)
            if not cached:
                continue

            slug_tools, _ = cached
            if specific_tool:
                # Tool names are server-prefixed, so match on the "_<tool>" suffix.
                all_tools.extend(
                    t for t in slug_tools if t.name.endswith(f"_{specific_tool}")
                )
            else:
                all_tools.extend(slug_tools)

        return all_tools, all_clients
|
||||
|
||||
def _fetch_amp_mcp_configs(self, slugs: list[str]) -> dict[str, dict[str, Any]]:
|
||||
"""Fetch MCP server configurations via CrewAI+ API.
|
||||
|
||||
Sends a GET request to the CrewAI+ mcps/configs endpoint with
|
||||
comma-separated slugs. CrewAI+ proxies the request to crewai-oauth.
|
||||
|
||||
API-level failures return ``{}``; individual slugs will then
|
||||
surface as ``MCPConfigFetchFailedEvent`` in :meth:`_resolve_amp`.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
try:
|
||||
from crewai_tools.tools.crewai_platform_tools.misc import (
|
||||
get_platform_integration_token,
|
||||
)
|
||||
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
|
||||
plus_api = PlusAPI(api_key=get_platform_integration_token())
|
||||
response = plus_api.get_mcp_configs(slugs)
|
||||
|
||||
if response.status_code == 200:
|
||||
configs: dict[str, dict[str, Any]] = response.json().get("configs", {})
|
||||
return configs
|
||||
|
||||
self._logger.log(
|
||||
"debug",
|
||||
f"Failed to fetch MCP configs: HTTP {response.status_code}",
|
||||
)
|
||||
return {}
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
self._logger.log("debug", f"Failed to fetch MCP configs: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
self._logger.log("debug", f"Cannot fetch AMP MCP configs: {e}")
|
||||
return {}
|
||||
|
||||
    def _resolve_external(self, mcp_ref: str) -> list[BaseTool]:
        """Resolve an HTTPS MCP server URL into tools.

        Supports an optional ``"#tool_name"`` suffix selecting a single tool.
        Any failure is logged as a warning and yields an empty list rather
        than raising.
        """
        from crewai.tools.mcp_tool_wrapper import MCPToolWrapper

        # Split off an optional "#tool" fragment selecting one specific tool.
        if "#" in mcp_ref:
            server_url, specific_tool = mcp_ref.split("#", 1)
        else:
            server_url, specific_tool = mcp_ref, None

        server_params = {"url": server_url}
        server_name = self._extract_server_name(server_url)

        try:
            # Discover (possibly cached) tool schemas for this server.
            tool_schemas = self._get_mcp_tool_schemas(server_params)

            if not tool_schemas:
                self._logger.log(
                    "warning", f"No tools discovered from MCP server: {server_url}"
                )
                return []

            tools = []
            for tool_name, schema in tool_schemas.items():
                if specific_tool and tool_name != specific_tool:
                    continue

                try:
                    wrapper = MCPToolWrapper(
                        mcp_server_params=server_params,
                        tool_name=tool_name,
                        tool_schema=schema,
                        server_name=server_name,
                    )
                    tools.append(wrapper)
                except Exception as e:
                    # One bad tool must not prevent the rest from loading.
                    self._logger.log(
                        "warning",
                        f"Failed to create MCP tool wrapper for {tool_name}: {e}",
                    )
                    continue

            # The requested "#tool" did not match anything on the server.
            if specific_tool and not tools:
                self._logger.log(
                    "warning",
                    f"Specific tool '{specific_tool}' not found on MCP server: {server_url}",
                )

            return cast(list[BaseTool], tools)

        except Exception as e:
            self._logger.log(
                "warning", f"Failed to connect to MCP server {server_url}: {e}"
            )
            return []
|
||||
|
||||
def _resolve_native(
|
||||
self, mcp_config: MCPServerConfig
|
||||
) -> tuple[list[BaseTool], Any | None]:
|
||||
"""Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup."""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
args=mcp_config.args,
|
||||
env=mcp_config.env,
|
||||
)
|
||||
server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}"
|
||||
elif isinstance(mcp_config, MCPServerHTTP):
|
||||
transport = HTTPTransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
streamable=mcp_config.streamable,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
elif isinstance(mcp_config, MCPServerSSE):
|
||||
transport = SSETransport(
|
||||
url=mcp_config.url,
|
||||
headers=mcp_config.headers,
|
||||
)
|
||||
server_name = self._extract_server_name(mcp_config.url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
|
||||
|
||||
client = MCPClient(
|
||||
transport=transport,
|
||||
cache_tools_list=mcp_config.cache_tools_list,
|
||||
)
|
||||
|
||||
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
|
||||
try:
|
||||
if not client.connected:
|
||||
await client.connect()
|
||||
|
||||
tools_list = await client.list_tools()
|
||||
|
||||
try:
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during disconnect: {e}")
|
||||
|
||||
return tools_list
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
await client.disconnect()
|
||||
await asyncio.sleep(0.1)
|
||||
raise RuntimeError(
|
||||
f"Error during setup client and list tools: {e}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, _setup_client_and_list_tools()
|
||||
)
|
||||
tools_list = future.result()
|
||||
except RuntimeError:
|
||||
try:
|
||||
tools_list = asyncio.run(_setup_client_and_list_tools())
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e).lower()
|
||||
if "cancel scope" in error_msg or "task" in error_msg:
|
||||
raise ConnectionError(
|
||||
"MCP connection failed due to event loop cleanup issues. "
|
||||
"This may be due to authentication errors or server unavailability."
|
||||
) from e
|
||||
except asyncio.CancelledError as e:
|
||||
raise ConnectionError(
|
||||
"MCP connection was cancelled. This may indicate an authentication "
|
||||
"error or server unavailability."
|
||||
) from e
|
||||
|
||||
if mcp_config.tool_filter:
|
||||
filtered_tools = []
|
||||
for tool in tools_list:
|
||||
if callable(mcp_config.tool_filter):
|
||||
try:
|
||||
from crewai.mcp.filters import ToolFilterContext
|
||||
|
||||
context = ToolFilterContext(
|
||||
agent=self._agent,
|
||||
server_name=server_name,
|
||||
run_context=None,
|
||||
)
|
||||
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
except (TypeError, AttributeError):
|
||||
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
filtered_tools.append(tool)
|
||||
tools_list = filtered_tools
|
||||
|
||||
tools = []
|
||||
for tool_def in tools_list:
|
||||
tool_name = tool_def.get("name", "")
|
||||
original_tool_name = tool_def.get("original_name", tool_name)
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
args_schema = None
|
||||
if tool_def.get("inputSchema"):
|
||||
args_schema = self._json_schema_to_pydantic(
|
||||
tool_name, tool_def["inputSchema"]
|
||||
)
|
||||
|
||||
tool_schema = {
|
||||
"description": tool_def.get("description", ""),
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
|
||||
try:
|
||||
native_tool = MCPNativeTool(
|
||||
mcp_client=client,
|
||||
tool_name=tool_name,
|
||||
tool_schema=tool_schema,
|
||||
server_name=server_name,
|
||||
original_tool_name=original_tool_name,
|
||||
)
|
||||
tools.append(native_tool)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Failed to create native MCP tool: {e}")
|
||||
continue
|
||||
|
||||
return cast(list[BaseTool], tools), client
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
asyncio.run(client.disconnect())
|
||||
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
@staticmethod
|
||||
def _build_mcp_config_from_dict(
|
||||
config_dict: dict[str, Any],
|
||||
) -> MCPServerConfig:
|
||||
"""Convert a config dict from crewai-oauth into an MCPServerConfig."""
|
||||
config_type = config_dict.get("type", "http")
|
||||
|
||||
if config_type == "sse":
|
||||
return MCPServerSSE(
|
||||
url=config_dict["url"],
|
||||
headers=config_dict.get("headers"),
|
||||
cache_tools_list=config_dict.get("cache_tools_list", False),
|
||||
)
|
||||
|
||||
return MCPServerHTTP(
|
||||
url=config_dict["url"],
|
||||
headers=config_dict.get("headers"),
|
||||
streamable=config_dict.get("streamable", True),
|
||||
cache_tools_list=config_dict.get("cache_tools_list", False),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_server_name(server_url: str) -> str:
|
||||
"""Extract clean server name from URL for tool prefixing."""
|
||||
parsed = urlparse(server_url)
|
||||
domain = parsed.netloc.replace(".", "_")
|
||||
path = parsed.path.replace("/", "_").strip("_")
|
||||
return f"{domain}_{path}" if path else domain
|
||||
|
||||
def _get_mcp_tool_schemas(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Get tool schemas from MCP server with caching."""
|
||||
server_url = server_params["url"]
|
||||
|
||||
cache_key = server_url
|
||||
current_time = time.time()
|
||||
|
||||
if cache_key in _mcp_schema_cache:
|
||||
cached_data, cache_time = _mcp_schema_cache[cache_key]
|
||||
if current_time - cache_time < _cache_ttl:
|
||||
self._logger.log(
|
||||
"debug", f"Using cached MCP tool schemas for {server_url}"
|
||||
)
|
||||
return cached_data # type: ignore[no-any-return]
|
||||
|
||||
try:
|
||||
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
||||
_mcp_schema_cache[cache_key] = (schemas, current_time)
|
||||
return schemas
|
||||
except Exception as e:
|
||||
self._logger.log(
|
||||
"warning", f"Failed to get MCP tool schemas from {server_url}: {e}"
|
||||
)
|
||||
return {}
|
||||
|
||||
async def _get_mcp_tool_schemas_async(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Async implementation of MCP tool schema retrieval."""
|
||||
server_url = server_params["url"]
|
||||
return await self._retry_mcp_discovery(
|
||||
self._discover_mcp_tools_with_timeout, server_url
|
||||
)
|
||||
|
||||
async def _retry_mcp_discovery(
|
||||
self, operation_func: Any, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Retry MCP discovery with exponential backoff."""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(MCP_MAX_RETRIES):
|
||||
result, error, should_retry = await self._attempt_mcp_discovery(
|
||||
operation_func, server_url
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
if not should_retry:
|
||||
raise RuntimeError(error)
|
||||
|
||||
last_error = error
|
||||
if attempt < MCP_MAX_RETRIES - 1:
|
||||
wait_time = 2**attempt
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _attempt_mcp_discovery(
|
||||
operation_func: Any, server_url: str
|
||||
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
||||
"""Attempt single MCP discovery; returns *(result, error_message, should_retry)*."""
|
||||
try:
|
||||
result = await operation_func(server_url)
|
||||
return result, "", False
|
||||
|
||||
except ImportError:
|
||||
return (
|
||||
None,
|
||||
"MCP library not available. Please install with: pip install mcp",
|
||||
False,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return (
|
||||
None,
|
||||
f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds",
|
||||
True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
|
||||
if "authentication" in error_str or "unauthorized" in error_str:
|
||||
return None, f"Authentication failed for MCP server: {e!s}", False
|
||||
if "connection" in error_str or "network" in error_str:
|
||||
return None, f"Network connection failed: {e!s}", True
|
||||
if "json" in error_str or "parsing" in error_str:
|
||||
return None, f"Server response parsing error: {e!s}", True
|
||||
return None, f"MCP discovery error: {e!s}", False
|
||||
|
||||
async def _discover_mcp_tools_with_timeout(
|
||||
self, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Discover MCP tools with timeout wrapper."""
|
||||
return await asyncio.wait_for(
|
||||
self._discover_mcp_tools(server_url), timeout=MCP_DISCOVERY_TIMEOUT
|
||||
)
|
||||
|
||||
    async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict[str, Any]]:
        """Discover tools from an MCP server (HTTPS / streamable-HTTP path).

        Opens a short-lived streamable-HTTP session, initializes it, lists the
        server's tools, and converts each tool's JSON input schema into a
        Pydantic ``args_schema`` model keyed by the sanitized tool name.
        """
        from mcp import ClientSession
        from mcp.client.streamable_http import streamablehttp_client

        from crewai.utilities.string_utils import sanitize_tool_name

        async with streamablehttp_client(server_url) as (read, write, _):
            async with ClientSession(read, write) as session:
                # The handshake must finish within the connection timeout.
                await asyncio.wait_for(
                    session.initialize(), timeout=MCP_CONNECTION_TIMEOUT
                )

                # Listing gets whatever discovery budget remains after the
                # connection handshake.
                tools_result = await asyncio.wait_for(
                    session.list_tools(),
                    timeout=MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT,
                )

                schemas = {}
                for tool in tools_result.tools:
                    args_schema = None
                    if hasattr(tool, "inputSchema") and tool.inputSchema:
                        args_schema = self._json_schema_to_pydantic(
                            sanitize_tool_name(tool.name), tool.inputSchema
                        )

                    schemas[sanitize_tool_name(tool.name)] = {
                        "description": getattr(tool, "description", ""),
                        "args_schema": args_schema,
                    }
                return schemas
|
||||
|
||||
@staticmethod
|
||||
def _json_schema_to_pydantic(tool_name: str, json_schema: dict[str, Any]) -> type:
|
||||
"""Convert JSON Schema to a Pydantic model for tool arguments."""
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
|
||||
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
||||
return create_model_from_schema(
|
||||
json_schema,
|
||||
model_name=model_name,
|
||||
enrich_descriptions=True,
|
||||
)
|
||||
@@ -27,14 +27,16 @@ class MCPNativeTool(BaseTool):
|
||||
tool_name: str,
|
||||
tool_schema: dict[str, Any],
|
||||
server_name: str,
|
||||
original_tool_name: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize native MCP tool.
|
||||
|
||||
Args:
|
||||
mcp_client: MCPClient instance with active session.
|
||||
tool_name: Original name of the tool on the MCP server.
|
||||
tool_name: Name of the tool (may be prefixed).
|
||||
tool_schema: Schema information for the tool.
|
||||
server_name: Name of the MCP server for prefixing.
|
||||
original_tool_name: Original name of the tool on the MCP server.
|
||||
"""
|
||||
# Create tool name with server prefix to avoid conflicts
|
||||
prefixed_name = f"{server_name}_{tool_name}"
|
||||
@@ -57,7 +59,7 @@ class MCPNativeTool(BaseTool):
|
||||
|
||||
# Set instance attributes after super().__init__
|
||||
self._mcp_client = mcp_client
|
||||
self._original_tool_name = tool_name
|
||||
self._original_tool_name = original_tool_name or tool_name
|
||||
self._server_name = server_name
|
||||
# self._logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -491,10 +491,66 @@ FORMAT_TYPE_MAP: dict[str, type[Any]] = {
|
||||
}
|
||||
|
||||
|
||||
def build_rich_field_description(prop_schema: dict[str, Any]) -> str:
    """Build a comprehensive field description including constraints.

    Embeds format, enum, pattern, min/max, and example constraints into the
    description text so that LLMs can understand tool parameter requirements
    without inspecting the raw JSON Schema.

    Args:
        prop_schema: Property schema with description and constraints.

    Returns:
        Enhanced description with format, enum, and other constraints.
    """
    pieces: list[str] = []

    base_description = prop_schema.get("description", "")
    if base_description:
        pieces.append(base_description)

    fmt = prop_schema.get("format")
    if fmt:
        pieces.append(f"Format: {fmt}")

    allowed = prop_schema.get("enum")
    if allowed:
        pieces.append("Allowed values: [" + ", ".join(repr(v) for v in allowed) + "]")

    regex = prop_schema.get("pattern")
    if regex:
        pieces.append(f"Pattern: {regex}")

    # Numeric bounds use an explicit None check so 0 is still reported.
    for label, key in (("Minimum", "minimum"), ("Maximum", "maximum")):
        bound = prop_schema.get(key)
        if bound is not None:
            pieces.append(f"{label}: {bound}")

    for label, key in (("Min length", "minLength"), ("Max length", "maxLength")):
        limit = prop_schema.get(key)
        if limit is not None:
            pieces.append(f"{label}: {limit}")

    sample_values = prop_schema.get("examples")
    if sample_values:
        # Cap at three examples to keep descriptions compact.
        pieces.append("Examples: " + ", ".join(repr(e) for e in sample_values[:3]))

    return ". ".join(pieces) if pieces else ""
|
||||
|
||||
|
||||
def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
json_schema: dict[str, Any],
|
||||
*,
|
||||
root_schema: dict[str, Any] | None = None,
|
||||
model_name: str | None = None,
|
||||
enrich_descriptions: bool = False,
|
||||
__config__: ConfigDict | None = None,
|
||||
__base__: type[BaseModel] | None = None,
|
||||
__module__: str = __name__,
|
||||
@@ -512,6 +568,13 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
json_schema: A dictionary representing the JSON schema.
|
||||
root_schema: The root schema containing $defs. If not provided, the
|
||||
current schema is treated as the root schema.
|
||||
model_name: Override for the model name. If not provided, the schema
|
||||
``title`` field is used, falling back to ``"DynamicModel"``.
|
||||
enrich_descriptions: When True, augment field descriptions with
|
||||
constraint info (format, enum, pattern, min/max, examples) via
|
||||
:func:`build_rich_field_description`. Useful for LLM-facing tool
|
||||
schemas where constraints in the description help the model
|
||||
understand parameter requirements.
|
||||
__config__: Pydantic configuration for the generated model.
|
||||
__base__: Base class for the generated model. Defaults to BaseModel.
|
||||
__module__: Module name for the generated model class.
|
||||
@@ -548,10 +611,14 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
if "title" not in json_schema and "title" in (root_schema or {}):
|
||||
json_schema["title"] = (root_schema or {}).get("title")
|
||||
|
||||
model_name = json_schema.get("title") or "DynamicModel"
|
||||
effective_name = model_name or json_schema.get("title") or "DynamicModel"
|
||||
field_definitions = {
|
||||
name: _json_schema_to_pydantic_field(
|
||||
name, prop, json_schema.get("required", []), effective_root
|
||||
name,
|
||||
prop,
|
||||
json_schema.get("required", []),
|
||||
effective_root,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
for name, prop in (json_schema.get("properties", {}) or {}).items()
|
||||
}
|
||||
@@ -559,7 +626,7 @@ def create_model_from_schema( # type: ignore[no-any-unimported]
|
||||
effective_config = __config__ or ConfigDict(extra="forbid")
|
||||
|
||||
return create_model_base(
|
||||
model_name,
|
||||
effective_name,
|
||||
__config__=effective_config,
|
||||
__base__=__base__,
|
||||
__module__=__module__,
|
||||
@@ -574,6 +641,8 @@ def _json_schema_to_pydantic_field(
|
||||
json_schema: dict[str, Any],
|
||||
required: list[str],
|
||||
root_schema: dict[str, Any],
|
||||
*,
|
||||
enrich_descriptions: bool = False,
|
||||
) -> Any:
|
||||
"""Convert a JSON schema property to a Pydantic field definition.
|
||||
|
||||
@@ -582,20 +651,29 @@ def _json_schema_to_pydantic_field(
|
||||
json_schema: The JSON schema for this field.
|
||||
required: List of required field names.
|
||||
root_schema: The root schema for resolving $ref.
|
||||
enrich_descriptions: When True, embed constraints in the description.
|
||||
|
||||
Returns:
|
||||
A tuple of (type, Field) for use with create_model.
|
||||
"""
|
||||
type_ = _json_schema_to_pydantic_type(json_schema, root_schema, name_=name.title())
|
||||
description = json_schema.get("description")
|
||||
examples = json_schema.get("examples")
|
||||
type_ = _json_schema_to_pydantic_type(
|
||||
json_schema, root_schema, name_=name.title(), enrich_descriptions=enrich_descriptions
|
||||
)
|
||||
is_required = name in required
|
||||
|
||||
field_params: dict[str, Any] = {}
|
||||
schema_extra: dict[str, Any] = {}
|
||||
|
||||
if description:
|
||||
field_params["description"] = description
|
||||
if enrich_descriptions:
|
||||
rich_desc = build_rich_field_description(json_schema)
|
||||
if rich_desc:
|
||||
field_params["description"] = rich_desc
|
||||
else:
|
||||
description = json_schema.get("description")
|
||||
if description:
|
||||
field_params["description"] = description
|
||||
|
||||
examples = json_schema.get("examples")
|
||||
if examples:
|
||||
schema_extra["examples"] = examples
|
||||
|
||||
@@ -711,6 +789,7 @@ def _json_schema_to_pydantic_type(
|
||||
root_schema: dict[str, Any],
|
||||
*,
|
||||
name_: str | None = None,
|
||||
enrich_descriptions: bool = False,
|
||||
) -> Any:
|
||||
"""Convert a JSON schema to a Python/Pydantic type.
|
||||
|
||||
@@ -718,6 +797,7 @@ def _json_schema_to_pydantic_type(
|
||||
json_schema: The JSON schema to convert.
|
||||
root_schema: The root schema for resolving $ref.
|
||||
name_: Optional name for nested models.
|
||||
enrich_descriptions: Propagated to nested model creation.
|
||||
|
||||
Returns:
|
||||
A Python type corresponding to the JSON schema.
|
||||
@@ -725,7 +805,9 @@ def _json_schema_to_pydantic_type(
|
||||
ref = json_schema.get("$ref")
|
||||
if ref:
|
||||
ref_schema = _resolve_ref(ref, root_schema)
|
||||
return _json_schema_to_pydantic_type(ref_schema, root_schema, name_=name_)
|
||||
return _json_schema_to_pydantic_type(
|
||||
ref_schema, root_schema, name_=name_, enrich_descriptions=enrich_descriptions
|
||||
)
|
||||
|
||||
enum_values = json_schema.get("enum")
|
||||
if enum_values:
|
||||
@@ -740,7 +822,10 @@ def _json_schema_to_pydantic_type(
|
||||
if any_of_schemas:
|
||||
any_of_types = [
|
||||
_json_schema_to_pydantic_type(
|
||||
schema, root_schema, name_=f"{name_ or 'Union'}Option{i}"
|
||||
schema,
|
||||
root_schema,
|
||||
name_=f"{name_ or 'Union'}Option{i}",
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
for i, schema in enumerate(any_of_schemas)
|
||||
]
|
||||
@@ -750,10 +835,14 @@ def _json_schema_to_pydantic_type(
|
||||
if all_of_schemas:
|
||||
if len(all_of_schemas) == 1:
|
||||
return _json_schema_to_pydantic_type(
|
||||
all_of_schemas[0], root_schema, name_=name_
|
||||
all_of_schemas[0], root_schema, name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
|
||||
return _json_schema_to_pydantic_type(merged, root_schema, name_=name_)
|
||||
return _json_schema_to_pydantic_type(
|
||||
merged, root_schema, name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
|
||||
type_ = json_schema.get("type")
|
||||
|
||||
@@ -769,7 +858,8 @@ def _json_schema_to_pydantic_type(
|
||||
items_schema = json_schema.get("items")
|
||||
if items_schema:
|
||||
item_type = _json_schema_to_pydantic_type(
|
||||
items_schema, root_schema, name_=name_
|
||||
items_schema, root_schema, name_=name_,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
return list[item_type] # type: ignore[valid-type]
|
||||
return list
|
||||
@@ -779,7 +869,10 @@ def _json_schema_to_pydantic_type(
|
||||
json_schema_ = json_schema.copy()
|
||||
if json_schema_.get("title") is None:
|
||||
json_schema_["title"] = name_ or "DynamicModel"
|
||||
return create_model_from_schema(json_schema_, root_schema=root_schema)
|
||||
return create_model_from_schema(
|
||||
json_schema_, root_schema=root_schema,
|
||||
enrich_descriptions=enrich_descriptions,
|
||||
)
|
||||
return dict
|
||||
if type_ == "null":
|
||||
return None
|
||||
|
||||
@@ -659,7 +659,7 @@ def test_agent_kickoff_with_platform_tools(mock_get, mock_post):
|
||||
|
||||
|
||||
@patch.dict("os.environ", {"EXA_API_KEY": "test_exa_key"})
|
||||
@patch("crewai.agent.Agent._get_external_mcp_tools")
|
||||
@patch("crewai.agent.Agent.get_mcp_tools")
|
||||
@pytest.mark.vcr()
|
||||
def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
|
||||
"""Test that Agent.kickoff() properly integrates MCP tools with LiteAgent"""
|
||||
@@ -691,7 +691,7 @@ def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
|
||||
assert result.raw is not None
|
||||
|
||||
# Verify MCP tools were retrieved
|
||||
mock_get_mcp_tools.assert_called_once_with("https://mcp.exa.ai/mcp?api_key=test_exa_key&profile=research")
|
||||
mock_get_mcp_tools.assert_called_once_with(["https://mcp.exa.ai/mcp?api_key=test_exa_key&profile=research"])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
373
lib/crewai/tests/mcp/test_amp_mcp.py
Normal file
373
lib/crewai/tests/mcp/test_amp_mcp.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""Tests for AMP MCP config fetching and tool resolution."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.mcp.config import MCPServerHTTP, MCPServerSSE
|
||||
from crewai.mcp.tool_resolver import MCPToolResolver
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
return Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resolver(agent):
|
||||
return MCPToolResolver(agent=agent, logger=agent._logger)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_definitions():
|
||||
return [
|
||||
{
|
||||
"name": "search",
|
||||
"description": "Search tool",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "create_page",
|
||||
"description": "Create a page",
|
||||
"inputSchema": {},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class TestBuildMCPConfigFromDict:
|
||||
def test_builds_http_config(self):
|
||||
config_dict = {
|
||||
"type": "http",
|
||||
"url": "https://mcp.example.com/api",
|
||||
"headers": {"Authorization": "Bearer token123"},
|
||||
"streamable": True,
|
||||
"cache_tools_list": False,
|
||||
}
|
||||
|
||||
result = MCPToolResolver._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
assert isinstance(result, MCPServerHTTP)
|
||||
assert result.url == "https://mcp.example.com/api"
|
||||
assert result.headers == {"Authorization": "Bearer token123"}
|
||||
assert result.streamable is True
|
||||
assert result.cache_tools_list is False
|
||||
|
||||
def test_builds_sse_config(self):
|
||||
config_dict = {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.example.com/sse",
|
||||
"headers": {"Authorization": "Bearer token123"},
|
||||
"cache_tools_list": True,
|
||||
}
|
||||
|
||||
result = MCPToolResolver._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
assert isinstance(result, MCPServerSSE)
|
||||
assert result.url == "https://mcp.example.com/sse"
|
||||
assert result.headers == {"Authorization": "Bearer token123"}
|
||||
assert result.cache_tools_list is True
|
||||
|
||||
def test_defaults_to_http(self):
|
||||
config_dict = {
|
||||
"url": "https://mcp.example.com/api",
|
||||
}
|
||||
|
||||
result = MCPToolResolver._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
assert isinstance(result, MCPServerHTTP)
|
||||
assert result.streamable is True
|
||||
|
||||
def test_http_defaults(self):
|
||||
config_dict = {
|
||||
"type": "http",
|
||||
"url": "https://mcp.example.com/api",
|
||||
}
|
||||
|
||||
result = MCPToolResolver._build_mcp_config_from_dict(config_dict)
|
||||
|
||||
assert result.headers is None
|
||||
assert result.streamable is True
|
||||
assert result.cache_tools_list is False
|
||||
|
||||
|
||||
class TestFetchAmpMCPConfigs:
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_fetches_configs_successfully(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"configs": {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer notion-token"},
|
||||
},
|
||||
"github": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.github.com/api",
|
||||
"headers": {"Authorization": "Bearer gh-token"},
|
||||
},
|
||||
},
|
||||
}
|
||||
mock_plus_api = MagicMock()
|
||||
mock_plus_api.get_mcp_configs.return_value = mock_response
|
||||
mock_plus_api_class.return_value = mock_plus_api
|
||||
|
||||
result = resolver._fetch_amp_mcp_configs(["notion", "github"])
|
||||
|
||||
assert "notion" in result
|
||||
assert "github" in result
|
||||
assert result["notion"]["url"] == "https://mcp.notion.so/sse"
|
||||
mock_plus_api_class.assert_called_once_with(api_key="test-api-key")
|
||||
mock_plus_api.get_mcp_configs.assert_called_once_with(["notion", "github"])
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_omits_missing_slugs(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"configs": {"notion": {"type": "sse", "url": "https://mcp.notion.so/sse"}},
|
||||
}
|
||||
mock_plus_api = MagicMock()
|
||||
mock_plus_api.get_mcp_configs.return_value = mock_response
|
||||
mock_plus_api_class.return_value = mock_plus_api
|
||||
|
||||
result = resolver._fetch_amp_mcp_configs(["notion", "missing-server"])
|
||||
|
||||
assert "notion" in result
|
||||
assert "missing-server" not in result
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_returns_empty_on_http_error(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_plus_api = MagicMock()
|
||||
mock_plus_api.get_mcp_configs.return_value = mock_response
|
||||
mock_plus_api_class.return_value = mock_plus_api
|
||||
|
||||
result = resolver._fetch_amp_mcp_configs(["notion"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI")
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", return_value="test-api-key")
|
||||
def test_returns_empty_on_network_error(self, mock_get_token, mock_plus_api_class, resolver):
|
||||
import httpx
|
||||
|
||||
mock_plus_api = MagicMock()
|
||||
mock_plus_api.get_mcp_configs.side_effect = httpx.ConnectError("Connection refused")
|
||||
mock_plus_api_class.return_value = mock_plus_api
|
||||
|
||||
result = resolver._fetch_amp_mcp_configs(["notion"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
@patch("crewai_tools.tools.crewai_platform_tools.misc.get_platform_integration_token", side_effect=Exception("No token"))
|
||||
def test_returns_empty_when_no_token(self, mock_get_token, resolver):
|
||||
result = resolver._fetch_amp_mcp_configs(["notion"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestParseAmpRef:
|
||||
def test_bare_slug(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("notion")
|
||||
assert slug == "notion"
|
||||
assert tool is None
|
||||
|
||||
def test_bare_slug_with_tool(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("notion#search")
|
||||
assert slug == "notion"
|
||||
assert tool == "search"
|
||||
|
||||
def test_bare_slug_with_empty_tool(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("notion#")
|
||||
assert slug == "notion"
|
||||
assert tool is None
|
||||
|
||||
def test_legacy_prefix_slug(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("crewai-amp:notion")
|
||||
assert slug == "notion"
|
||||
assert tool is None
|
||||
|
||||
def test_legacy_prefix_with_tool(self):
|
||||
slug, tool = MCPToolResolver._parse_amp_ref("crewai-amp:notion#search")
|
||||
assert slug == "notion"
|
||||
assert tool == "search"
|
||||
|
||||
|
||||
class TestGetMCPToolsAmpIntegration:
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_single_request_for_multiple_amp_refs(
|
||||
self, mock_fetch, mock_client_class, agent, mock_tool_definitions
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
},
|
||||
"github": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.github.com/api",
|
||||
"headers": {"Authorization": "Bearer gh-token"},
|
||||
"streamable": True,
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
tools = agent.get_mcp_tools(["notion", "github"])
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion", "github"])
|
||||
assert len(tools) == 4 # 2 tools per server
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_tool_filter_with_hash_syntax(
|
||||
self, mock_fetch, mock_client_class, agent, mock_tool_definitions
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
tools = agent.get_mcp_tools(["notion#search"])
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion"])
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "mcp_notion_so_sse_search"
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_deduplicates_slugs(
|
||||
self, mock_fetch, mock_client_class, agent, mock_tool_definitions
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
tools = agent.get_mcp_tools(["notion#search", "notion#create_page"])
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion"])
|
||||
assert len(tools) == 2
|
||||
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_skips_missing_configs_gracefully(self, mock_fetch, agent):
|
||||
mock_fetch.return_value = {}
|
||||
|
||||
tools = agent.get_mcp_tools(["missing-server"])
|
||||
|
||||
assert tools == []
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
def test_legacy_crewai_amp_prefix_still_works(
|
||||
self, mock_fetch, mock_client_class, agent, mock_tool_definitions
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
tools = agent.get_mcp_tools(["crewai-amp:notion"])
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion"])
|
||||
assert len(tools) == 2
|
||||
|
||||
@patch("crewai.mcp.tool_resolver.MCPClient")
|
||||
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
|
||||
@patch.object(MCPToolResolver, "_resolve_external")
|
||||
def test_non_amp_items_unaffected(
|
||||
self,
|
||||
mock_external,
|
||||
mock_fetch,
|
||||
mock_client_class,
|
||||
agent,
|
||||
mock_tool_definitions,
|
||||
):
|
||||
mock_fetch.return_value = {
|
||||
"notion": {
|
||||
"type": "sse",
|
||||
"url": "https://mcp.notion.so/sse",
|
||||
},
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_external_tool = MagicMock(spec=BaseTool)
|
||||
mock_external.return_value = [mock_external_tool]
|
||||
|
||||
http_config = MCPServerHTTP(
|
||||
url="https://other.mcp.com/api",
|
||||
headers={"Authorization": "Bearer other"},
|
||||
)
|
||||
|
||||
tools = agent.get_mcp_tools(
|
||||
[
|
||||
"notion",
|
||||
"https://external.mcp.com/api",
|
||||
http_config,
|
||||
]
|
||||
)
|
||||
|
||||
mock_fetch.assert_called_once_with(["notion"])
|
||||
mock_external.assert_called_once_with("https://external.mcp.com/api")
|
||||
# 2 from notion + 1 from external + 2 from http_config
|
||||
assert len(tools) == 5
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.agent.core import Agent
|
||||
@@ -46,7 +46,7 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
||||
)
|
||||
|
||||
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False # Will trigger connect
|
||||
@@ -82,7 +82,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions):
|
||||
mcps=[http_config],
|
||||
)
|
||||
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False # Will trigger connect
|
||||
@@ -117,7 +117,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions):
|
||||
mcps=[sse_config],
|
||||
)
|
||||
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
@@ -141,7 +141,7 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
|
||||
"""Test MCPNativeTool execution in synchronous context (normal crew execution)."""
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
@@ -173,7 +173,7 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
|
||||
"""Test MCPNativeTool execution in async context (e.g., from a Flow)."""
|
||||
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
||||
|
||||
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||
mock_client.connected = False
|
||||
|
||||
884
lib/crewai/tests/utilities/test_pydantic_schema_utils.py
Normal file
884
lib/crewai/tests/utilities/test_pydantic_schema_utils.py
Normal file
@@ -0,0 +1,884 @@
|
||||
"""Tests for pydantic_schema_utils module.
|
||||
|
||||
Covers:
|
||||
- create_model_from_schema: type mapping, required/optional, enums, formats,
|
||||
nested objects, arrays, unions, allOf, $ref, model_name, enrich_descriptions
|
||||
- Schema transformation helpers: resolve_refs, force_additional_properties_false,
|
||||
strip_unsupported_formats, ensure_type_in_schemas, convert_oneof_to_anyof,
|
||||
ensure_all_properties_required, strip_null_from_types, build_rich_field_description
|
||||
- End-to-end MCP tool schema conversion
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
build_rich_field_description,
|
||||
convert_oneof_to_anyof,
|
||||
create_model_from_schema,
|
||||
ensure_all_properties_required,
|
||||
ensure_type_in_schemas,
|
||||
force_additional_properties_false,
|
||||
resolve_refs,
|
||||
strip_null_from_types,
|
||||
strip_unsupported_formats,
|
||||
)
|
||||
|
||||
|
||||
class TestSimpleTypes:
|
||||
def test_string_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(name="Alice")
|
||||
assert obj.name == "Alice"
|
||||
|
||||
def test_integer_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
"required": ["count"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(count=42)
|
||||
assert obj.count == 42
|
||||
|
||||
def test_number_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"score": {"type": "number"}},
|
||||
"required": ["score"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(score=3.14)
|
||||
assert obj.score == pytest.approx(3.14)
|
||||
|
||||
def test_boolean_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"active": {"type": "boolean"}},
|
||||
"required": ["active"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model(active=True).active is True
|
||||
|
||||
def test_null_field(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"value": {"type": "null"}},
|
||||
"required": ["value"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(value=None)
|
||||
assert obj.value is None
|
||||
|
||||
|
||||
class TestRequiredOptional:
|
||||
def test_required_field_has_no_default(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
with pytest.raises(Exception):
|
||||
Model()
|
||||
|
||||
def test_optional_field_defaults_to_none(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": [],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model()
|
||||
assert obj.name is None
|
||||
|
||||
def test_mixed_required_optional(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer"},
|
||||
"label": {"type": "string"},
|
||||
},
|
||||
"required": ["id"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(id=1)
|
||||
assert obj.id == 1
|
||||
assert obj.label is None
|
||||
|
||||
|
||||
class TestEnumLiteral:
|
||||
def test_string_enum(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"color": {"type": "string", "enum": ["red", "green", "blue"]},
|
||||
},
|
||||
"required": ["color"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(color="red")
|
||||
assert obj.color == "red"
|
||||
|
||||
def test_string_enum_rejects_invalid(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"color": {"type": "string", "enum": ["red", "green", "blue"]},
|
||||
},
|
||||
"required": ["color"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
with pytest.raises(Exception):
|
||||
Model(color="yellow")
|
||||
|
||||
def test_const_value(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {"const": "fixed"},
|
||||
},
|
||||
"required": ["kind"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(kind="fixed")
|
||||
assert obj.kind == "fixed"
|
||||
|
||||
|
||||
class TestFormatMapping:
|
||||
def test_date_format(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"birthday": {"type": "string", "format": "date"},
|
||||
},
|
||||
"required": ["birthday"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(birthday=datetime.date(2000, 1, 15))
|
||||
assert obj.birthday == datetime.date(2000, 1, 15)
|
||||
|
||||
def test_datetime_format(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {"type": "string", "format": "date-time"},
|
||||
},
|
||||
"required": ["created_at"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
dt = datetime.datetime(2025, 6, 1, 12, 0, 0)
|
||||
obj = Model(created_at=dt)
|
||||
assert obj.created_at == dt
|
||||
|
||||
def test_time_format(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"alarm": {"type": "string", "format": "time"},
|
||||
},
|
||||
"required": ["alarm"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
t = datetime.time(8, 30)
|
||||
obj = Model(alarm=t)
|
||||
assert obj.alarm == t
|
||||
|
||||
|
||||
class TestNestedObjects:
|
||||
def test_nested_object_creates_model(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {"type": "string"},
|
||||
"city": {"type": "string"},
|
||||
},
|
||||
"required": ["street", "city"],
|
||||
},
|
||||
},
|
||||
"required": ["address"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(address={"street": "123 Main", "city": "Springfield"})
|
||||
assert obj.address.street == "123 Main"
|
||||
assert obj.address.city == "Springfield"
|
||||
|
||||
def test_object_without_properties_returns_dict(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {"type": "object"},
|
||||
},
|
||||
"required": ["metadata"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(metadata={"key": "value"})
|
||||
assert obj.metadata == {"key": "value"}
|
||||
|
||||
|
||||
class TestTypedArrays:
|
||||
def test_array_of_strings(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["tags"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(tags=["a", "b", "c"])
|
||||
assert obj.tags == ["a", "b", "c"]
|
||||
|
||||
def test_array_of_objects(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
"required": ["id"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["items"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(items=[{"id": 1}, {"id": 2}])
|
||||
assert len(obj.items) == 2
|
||||
assert obj.items[0].id == 1
|
||||
|
||||
def test_untyped_array(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"data": {"type": "array"}},
|
||||
"required": ["data"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(data=[1, "two", 3.0])
|
||||
assert obj.data == [1, "two", 3.0]
|
||||
|
||||
|
||||
class TestUnionTypes:
|
||||
def test_anyof_string_or_integer(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
},
|
||||
},
|
||||
"required": ["value"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model(value="hello").value == "hello"
|
||||
assert Model(value=42).value == 42
|
||||
|
||||
def test_oneof(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"oneOf": [{"type": "string"}, {"type": "number"}],
|
||||
},
|
||||
},
|
||||
"required": ["value"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model(value="hello").value == "hello"
|
||||
assert Model(value=3.14).value == pytest.approx(3.14)
|
||||
|
||||
|
||||
class TestAllOfMerging:
|
||||
def test_allof_merges_properties(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"allOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"age": {"type": "integer"}},
|
||||
"required": ["age"],
|
||||
},
|
||||
],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(name="Alice", age=30)
|
||||
assert obj.name == "Alice"
|
||||
assert obj.age == 30
|
||||
|
||||
def test_single_allof(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"item": {
|
||||
"allOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}},
|
||||
"required": ["id"],
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["item"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(item={"id": 1})
|
||||
assert obj.item.id == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# $ref resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRefResolution:
|
||||
def test_ref_in_property(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"item": {"$ref": "#/$defs/Item"},
|
||||
},
|
||||
"required": ["item"],
|
||||
"$defs": {
|
||||
"Item": {
|
||||
"type": "object",
|
||||
"title": "Item",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model(item={"name": "Widget"})
|
||||
assert obj.item.name == "Widget"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_name parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestModelName:
|
||||
def test_model_name_override(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"title": "OriginalName",
|
||||
"properties": {"x": {"type": "integer"}},
|
||||
"required": ["x"],
|
||||
}
|
||||
Model = create_model_from_schema(schema, model_name="CustomSchema")
|
||||
assert Model.__name__ == "CustomSchema"
|
||||
|
||||
def test_model_name_fallback_to_title(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"title": "FromTitle",
|
||||
"properties": {"x": {"type": "integer"}},
|
||||
"required": ["x"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model.__name__ == "FromTitle"
|
||||
|
||||
def test_model_name_fallback_to_dynamic(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer"}},
|
||||
"required": ["x"],
|
||||
}
|
||||
Model = create_model_from_schema(schema)
|
||||
assert Model.__name__ == "DynamicModel"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enrich_descriptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnrichDescriptions:
|
||||
def test_enriched_description_includes_constraints(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"score": {
|
||||
"type": "integer",
|
||||
"description": "The score value",
|
||||
"minimum": 0,
|
||||
"maximum": 100,
|
||||
},
|
||||
},
|
||||
"required": ["score"],
|
||||
}
|
||||
Model = create_model_from_schema(schema, enrich_descriptions=True)
|
||||
field_info = Model.model_fields["score"]
|
||||
assert "Minimum: 0" in field_info.description
|
||||
assert "Maximum: 100" in field_info.description
|
||||
assert "The score value" in field_info.description
|
||||
|
||||
def test_default_does_not_enrich(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"score": {
|
||||
"type": "integer",
|
||||
"description": "The score value",
|
||||
"minimum": 0,
|
||||
},
|
||||
},
|
||||
"required": ["score"],
|
||||
}
|
||||
Model = create_model_from_schema(schema, enrich_descriptions=False)
|
||||
field_info = Model.model_fields["score"]
|
||||
assert field_info.description == "The score value"
|
||||
|
||||
def test_enriched_description_propagates_to_nested(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level": {
|
||||
"type": "integer",
|
||||
"description": "Level",
|
||||
"minimum": 1,
|
||||
"maximum": 10,
|
||||
},
|
||||
},
|
||||
"required": ["level"],
|
||||
},
|
||||
},
|
||||
"required": ["config"],
|
||||
}
|
||||
Model = create_model_from_schema(schema, enrich_descriptions=True)
|
||||
nested_model = Model.model_fields["config"].annotation
|
||||
nested_field = nested_model.model_fields["level"]
|
||||
assert "Minimum: 1" in nested_field.description
|
||||
assert "Maximum: 10" in nested_field.description
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_empty_properties(self) -> None:
|
||||
schema = {"type": "object", "properties": {}, "required": []}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model()
|
||||
assert obj is not None
|
||||
|
||||
def test_no_properties_key(self) -> None:
|
||||
schema = {"type": "object"}
|
||||
Model = create_model_from_schema(schema)
|
||||
obj = Model()
|
||||
assert obj is not None
|
||||
|
||||
def test_unknown_type_raises(self) -> None:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"weird": {"type": "hyperspace"},
|
||||
},
|
||||
"required": ["weird"],
|
||||
}
|
||||
with pytest.raises(ValueError, match="Unsupported JSON schema type"):
|
||||
create_model_from_schema(schema)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_rich_field_description
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildRichFieldDescription:
    """Rendering of JSON-schema constraints into a human-readable description."""

    def test_description_only(self) -> None:
        assert build_rich_field_description({"description": "A name"}) == "A name"

    def test_empty_schema(self) -> None:
        assert build_rich_field_description({}) == ""

    def test_format(self) -> None:
        rendered = build_rich_field_description({"format": "date-time"})
        assert "Format: date-time" in rendered

    def test_enum(self) -> None:
        rendered = build_rich_field_description({"enum": ["a", "b"]})
        assert "Allowed values:" in rendered
        for token in ("'a'", "'b'"):
            assert token in rendered

    def test_pattern(self) -> None:
        rendered = build_rich_field_description({"pattern": "^[a-z]+$"})
        assert "Pattern: ^[a-z]+$" in rendered

    def test_min_max(self) -> None:
        rendered = build_rich_field_description({"minimum": 0, "maximum": 100})
        assert "Minimum: 0" in rendered
        assert "Maximum: 100" in rendered

    def test_min_max_length(self) -> None:
        rendered = build_rich_field_description({"minLength": 1, "maxLength": 255})
        assert "Min length: 1" in rendered
        assert "Max length: 255" in rendered

    def test_examples(self) -> None:
        rendered = build_rich_field_description(
            {"examples": ["foo", "bar", "baz", "extra"]}
        )
        assert "Examples:" in rendered
        assert "'foo'" in rendered
        assert "'baz'" in rendered
        # The helper previews at most the first three examples.
        assert "'extra'" not in rendered

    def test_combined_constraints(self) -> None:
        rendered = build_rich_field_description({
            "description": "A score",
            "minimum": 0,
            "maximum": 10,
            "format": "int32",
        })
        # The free-text description leads; constraint summaries follow.
        assert rendered.startswith("A score")
        for fragment in ("Minimum: 0", "Maximum: 10", "Format: int32"):
            assert fragment in rendered


# ---------------------------------------------------------------------------
# Schema transformation functions
# ---------------------------------------------------------------------------


class TestResolveRefs:
    """Inlining of local ``$ref`` pointers against the schema's ``$defs``."""

    def test_basic_ref_resolution(self) -> None:
        spec = {
            "type": "object",
            "properties": {"item": {"$ref": "#/$defs/Item"}},
            "$defs": {
                "Item": {"type": "object", "properties": {"id": {"type": "integer"}}},
            },
        }
        item = resolve_refs(spec)["properties"]["item"]
        # The pointer is replaced by the referenced definition itself.
        assert "$ref" not in item
        assert item["type"] == "object"

    def test_nested_ref_resolution(self) -> None:
        spec = {
            "type": "object",
            "properties": {"wrapper": {"$ref": "#/$defs/Wrapper"}},
            "$defs": {
                "Wrapper": {
                    "type": "object",
                    "properties": {"inner": {"$ref": "#/$defs/Inner"}},
                },
                "Inner": {"type": "string"},
            },
        }
        wrapper = resolve_refs(spec)["properties"]["wrapper"]
        # Refs inside referenced definitions are resolved transitively.
        assert wrapper["properties"]["inner"]["type"] == "string"

    def test_missing_ref_raises(self) -> None:
        spec = {"properties": {"x": {"$ref": "#/$defs/Missing"}}, "$defs": {}}
        with pytest.raises(KeyError, match="Missing"):
            resolve_refs(spec)

    def test_no_refs_unchanged(self) -> None:
        spec = {"type": "object", "properties": {"name": {"type": "string"}}}
        assert resolve_refs(spec) == spec


class TestForceAdditionalPropertiesFalse:
    """Strict-mode normalization: every object schema must forbid extra keys."""

    def test_adds_to_object(self) -> None:
        spec = {"type": "object", "properties": {"x": {"type": "integer"}}}
        out = force_additional_properties_false(deepcopy(spec))
        assert out["additionalProperties"] is False

    def test_adds_empty_properties_and_required(self) -> None:
        # A bare object schema gets explicit (empty) properties/required keys.
        out = force_additional_properties_false(deepcopy({"type": "object"}))
        assert out["properties"] == {}
        assert out["required"] == []

    def test_recursive_nested_objects(self) -> None:
        spec = {
            "type": "object",
            "properties": {
                "child": {"type": "object", "properties": {"id": {"type": "integer"}}},
            },
        }
        out = force_additional_properties_false(deepcopy(spec))
        # Both the root and the nested object are locked down.
        assert out["additionalProperties"] is False
        assert out["properties"]["child"]["additionalProperties"] is False

    def test_does_not_affect_non_objects(self) -> None:
        out = force_additional_properties_false(deepcopy({"type": "string"}))
        assert "additionalProperties" not in out


class TestStripUnsupportedFormats:
    """Only supported string formats (date / date-time) survive normalization."""

    def test_removes_email_format(self) -> None:
        out = strip_unsupported_formats(deepcopy({"type": "string", "format": "email"}))
        assert "format" not in out

    def test_keeps_date_time(self) -> None:
        out = strip_unsupported_formats(
            deepcopy({"type": "string", "format": "date-time"})
        )
        assert out["format"] == "date-time"

    def test_keeps_date(self) -> None:
        out = strip_unsupported_formats(deepcopy({"type": "string", "format": "date"}))
        assert out["format"] == "date"

    def test_removes_uri_format(self) -> None:
        out = strip_unsupported_formats(deepcopy({"type": "string", "format": "uri"}))
        assert "format" not in out

    def test_recursive(self) -> None:
        spec = {
            "type": "object",
            "properties": {
                "email": {"type": "string", "format": "email"},
                "created": {"type": "string", "format": "date-time"},
            },
        }
        out = strip_unsupported_formats(deepcopy(spec))
        # Unsupported formats are stripped per-property, supported ones kept.
        assert "format" not in out["properties"]["email"]
        assert out["properties"]["created"]["format"] == "date-time"


class TestEnsureTypeInSchemas:
    """Empty union members are given an explicit ``object`` type."""

    def test_empty_schema_in_anyof_gets_type(self) -> None:
        out = ensure_type_in_schemas(deepcopy({"anyOf": [{}, {"type": "string"}]}))
        assert out["anyOf"][0] == {"type": "object"}

    def test_empty_schema_in_oneof_gets_type(self) -> None:
        out = ensure_type_in_schemas(deepcopy({"oneOf": [{}, {"type": "integer"}]}))
        assert out["oneOf"][0] == {"type": "object"}

    def test_non_empty_unchanged(self) -> None:
        spec = {"anyOf": [{"type": "string"}, {"type": "integer"}]}
        # Members that already carry a type are left untouched.
        assert ensure_type_in_schemas(deepcopy(spec)) == spec


class TestConvertOneofToAnyof:
    """``oneOf`` unions are rewritten as the better-supported ``anyOf``."""

    def test_converts_top_level(self) -> None:
        spec = {"oneOf": [{"type": "string"}, {"type": "integer"}]}
        out = convert_oneof_to_anyof(deepcopy(spec))
        assert "oneOf" not in out
        assert "anyOf" in out
        # All union members are preserved in the conversion.
        assert len(out["anyOf"]) == 2

    def test_converts_nested(self) -> None:
        spec = {
            "type": "object",
            "properties": {
                "value": {"oneOf": [{"type": "string"}, {"type": "number"}]},
            },
        }
        value = convert_oneof_to_anyof(deepcopy(spec))["properties"]["value"]
        assert "anyOf" in value
        assert "oneOf" not in value


class TestEnsureAllPropertiesRequired:
    """Strict mode marks every declared property as required."""

    def test_makes_all_required(self) -> None:
        spec = {
            "type": "object",
            "properties": {"a": {"type": "string"}, "b": {"type": "integer"}},
            "required": ["a"],
        }
        out = ensure_all_properties_required(deepcopy(spec))
        # Compare as sets: the order of the required list is not part of the contract.
        assert set(out["required"]) == {"a", "b"}

    def test_recursive(self) -> None:
        spec = {
            "type": "object",
            "properties": {
                "child": {
                    "type": "object",
                    "properties": {"x": {"type": "integer"}, "y": {"type": "integer"}},
                    "required": [],
                },
            },
        }
        out = ensure_all_properties_required(deepcopy(spec))
        assert set(out["properties"]["child"]["required"]) == {"x", "y"}


class TestStripNullFromTypes:
    """Removal of ``null`` alternatives from unions and type arrays."""

    def test_strips_null_from_anyof(self) -> None:
        spec = {"anyOf": [{"type": "string"}, {"type": "null"}]}
        out = strip_null_from_types(deepcopy(spec))
        # A single surviving alternative collapses the union entirely.
        assert "anyOf" not in out
        assert out["type"] == "string"

    def test_strips_null_from_type_array(self) -> None:
        out = strip_null_from_types(deepcopy({"type": ["string", "null"]}))
        assert out["type"] == "string"

    def test_multiple_non_null_in_anyof(self) -> None:
        spec = {"anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "null"}]}
        out = strip_null_from_types(deepcopy(spec))
        # Two non-null members remain, so the union survives.
        assert len(out["anyOf"]) == 2

    def test_no_null_unchanged(self) -> None:
        spec = {"type": "string"}
        assert strip_null_from_types(deepcopy(spec)) == spec


class TestEndToEndMCPSchema:
    """Realistic MCP tool schema exercising multiple features simultaneously."""

    # Search-tool input schema mixing length/range constraints, an enum,
    # a nested object, an array, and a nullable union.
    MCP_SCHEMA: dict[str, Any] = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query",
                "minLength": 1,
                "maxLength": 500,
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum results",
                "minimum": 1,
                "maximum": 100,
            },
            "format": {
                "type": "string",
                "enum": ["json", "csv", "xml"],
                "description": "Output format",
            },
            "filters": {
                "type": "object",
                "properties": {
                    "date_from": {"type": "string", "format": "date"},
                    "date_to": {"type": "string", "format": "date"},
                    "categories": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["date_from"],
            },
            "sort_order": {"anyOf": [{"type": "string"}, {"type": "null"}]},
        },
        "required": ["query", "format", "filters"],
    }

    def test_model_creation(self) -> None:
        model_cls = create_model_from_schema(self.MCP_SCHEMA)
        assert model_cls is not None
        assert issubclass(model_cls, BaseModel)

    def test_valid_input_accepted(self) -> None:
        instance = create_model_from_schema(self.MCP_SCHEMA)(
            query="test search",
            format="json",
            filters={"date_from": "2025-01-01"},
        )
        assert instance.query == "test search"
        assert instance.format == "json"

    def test_invalid_enum_rejected(self) -> None:
        model_cls = create_model_from_schema(self.MCP_SCHEMA)
        with pytest.raises(Exception):
            # "yaml" is outside the declared enum {json, csv, xml}.
            model_cls(
                query="test",
                format="yaml",
                filters={"date_from": "2025-01-01"},
            )

    def test_model_name_for_mcp_tool(self) -> None:
        model_cls = create_model_from_schema(
            self.MCP_SCHEMA, model_name="search_toolSchema"
        )
        assert model_cls.__name__ == "search_toolSchema"

    def test_enriched_descriptions_for_mcp(self) -> None:
        fields = create_model_from_schema(
            self.MCP_SCHEMA, enrich_descriptions=True
        ).model_fields

        assert "Min length: 1" in fields["query"].description
        assert "Max length: 500" in fields["query"].description

        assert "Minimum: 1" in fields["max_results"].description
        assert "Maximum: 100" in fields["max_results"].description

        assert "Allowed values:" in fields["format"].description

    def test_optional_fields_accept_none(self) -> None:
        instance = create_model_from_schema(self.MCP_SCHEMA)(
            query="test",
            format="csv",
            filters={"date_from": "2025-01-01"},
            max_results=None,
            sort_order=None,
        )
        # Fields not listed in "required" tolerate explicit None.
        assert instance.max_results is None
        assert instance.sort_order is None

    def test_nested_filters_validated(self) -> None:
        instance = create_model_from_schema(self.MCP_SCHEMA)(
            query="test",
            format="xml",
            filters={
                "date_from": "2025-01-01",
                "date_to": "2025-12-31",
                "categories": ["news", "tech"],
            },
        )
        # The "date" format coerces the string into a real datetime.date.
        assert instance.filters.date_from == datetime.date(2025, 1, 1)
        assert instance.filters.categories == ["news", "tech"]