from __future__ import annotations

from collections.abc import Callable, Coroutine
import logging
from typing import TYPE_CHECKING, Any


# NOTE(review): `BaseTool` is referenced below but its import is not part of
# the visible hunks; presumably it is imported between the stdlib imports and
# this point -- confirm against the full module.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from mcp import StdioServerParameters
    import mcp.types
    from mcpadapt.core import MCPAdapt, ToolAdapter


try:
    from typing import ForwardRef, Union

    import jsonref
    from mcp import StdioServerParameters
    import mcp.types
    from mcpadapt.core import MCPAdapt, ToolAdapter
    from pydantic import BaseModel, Field, create_model

    MCP_AVAILABLE = True
except ImportError:
    MCP_AVAILABLE = False


# JSON Schema primitive type name -> Python type used for the generated field.
JSON_TYPE_MAPPING: dict[str, type] = {
    "string": str,
    "number": float,
    "integer": int,
    "boolean": bool,
    "object": dict,
    "array": list,
}


def _resolve_refs_to_dict(obj: Any, _active: frozenset[int] = frozenset()) -> Any:
    """Recursively copy dict/list-like objects (incl. JsonRef proxies) to plain containers.

    Args:
        obj: The value to convert. Anything that is not a dict or list is
            returned as-is.
        _active: Internal. ``id()`` of every container currently being
            converted on the path from the root to ``obj``.

    Returns:
        A structure made of plain ``dict``/``list`` objects. A container that
        contains itself (a recursive schema) is cut at the point of recursion
        and replaced by an empty container instead of recursing forever.
    """
    if isinstance(obj, dict):
        marker = id(obj)
        if marker in _active:
            # An empty schema accepts anything, which is the safest stand-in
            # for "the same schema again".
            return {}
        path = _active | {marker}
        return {k: _resolve_refs_to_dict(v, path) for k, v in obj.items()}
    if isinstance(obj, list):
        marker = id(obj)
        if marker in _active:
            return []
        path = _active | {marker}
        return [_resolve_refs_to_dict(item, path) for item in obj]
    return obj


def _resolve_all_refs(schema: dict[str, Any]) -> dict[str, Any]:
    """Resolve all $ref references in a JSON schema using jsonref.

    This function fully resolves all JSON Schema $ref references, including
    internal references that point to paths within the schema itself
    (e.g., '#/properties/geometry/anyOf/0/items').

    Args:
        schema: The JSON schema with potential $ref references.

    Returns:
        A new schema dict with all $ref references resolved and inlined and
        the top-level ``$defs`` entry removed. When the MCP dependencies are
        not installed the input object itself is returned unchanged.
    """
    if not MCP_AVAILABLE:
        return schema

    resolved = jsonref.replace_refs(schema, lazy_load=False)
    result = _resolve_refs_to_dict(resolved)
    # Everything has been inlined, so the definitions table is dead weight.
    result.pop("$defs", None)
    return result


def _json_pointer_tokens(ref: str) -> list[str]:
    """Split a local ``$ref`` such as ``'#/$defs/Point'`` into unescaped path tokens."""
    pointer = ref.removeprefix("#")
    return [
        token.replace("~1", "/").replace("~0", "~")
        for token in pointer.split("/")
        if token
    ]


def _lookup_json_pointer(root: Any, ref: str) -> Any:
    """Follow a local ``$ref`` through nested dicts *and* lists.

    Args:
        root: The schema document the pointer is relative to.
        ref: The ``$ref`` string (e.g. ``'#/properties/a/anyOf/0/items'``).

    Returns:
        The referenced node, or ``{}`` when any step of the path is missing.
    """
    node = root
    for token in _json_pointer_tokens(ref):
        if isinstance(node, dict):
            if token not in node:
                return {}
            node = node[token]
        elif isinstance(node, list):
            if not token.isdigit() or int(token) >= len(node):
                return {}
            node = node[int(token)]
        else:
            return {}
    return node


def _create_model_from_schema(
    schema: dict[str, Any], model_name: str = "DynamicModel"
) -> type[BaseModel]:
    """Create a Pydantic model from a JSON schema definition.

    This is a simplified version that handles common JSON schema patterns
    without passing problematic extra fields to Pydantic's Field().

    Args:
        schema: The JSON schema definition.
        model_name: The name for the created model.

    Returns:
        A Pydantic BaseModel class.

    Raises:
        RuntimeError: If the MCP dependencies are not installed.
    """
    if not MCP_AVAILABLE:
        raise RuntimeError("MCP dependencies not available")

    # NOTE: models are cached under the *last* pointer segment of the ref (or
    # the $defs key), so two different refs ending in the same segment share
    # one model. Resolving refs up front (_resolve_all_refs) avoids this.
    created_models: dict[str, type[BaseModel]] = {}
    # Names whose model is currently being built; used to break $ref cycles.
    in_progress: set[str] = set()

    def process_schema(name: str, schema_def: dict[str, Any]) -> type[BaseModel]:
        if name in created_models:
            return created_models[name]

        in_progress.add(name)
        try:
            required = set(schema_def.get("required", []))
            fields: dict[str, Any] = {}
            for field_name, field_schema in schema_def.get("properties", {}).items():
                field_type, default = get_field_type(
                    field_name, field_schema, required
                )
                fields[field_name] = (
                    field_type,
                    Field(
                        default=default,
                        description=field_schema.get("description", ""),
                    ),
                )

            model: type[BaseModel] = create_model(
                schema_def.get("title", name),
                __doc__=schema_def.get("description", ""),
                **fields,
            )
        finally:
            in_progress.discard(name)

        created_models[name] = model
        return model

    def resolve_ref(ref: str) -> Any:
        tokens = _json_pointer_tokens(ref)
        ref_name = tokens[-1] if tokens else model_name
        if ref_name in in_progress:
            # Self-referential schema: fall back to Any rather than recurse.
            return Any
        if ref_name not in created_models:
            target = _lookup_json_pointer(schema, ref)
            process_schema(ref_name, target if isinstance(target, dict) else {})
        return created_models[ref_name]

    def union_of(members: list[Any]) -> Any:
        if not members:
            return Any
        if len(members) == 1:
            return members[0]
        return Union[tuple(members)]  # noqa: UP007

    def schema_type(field_schema: dict[str, Any]) -> tuple[Any, bool]:
        """Return ``(python_type, nullable)`` for one schema node."""
        if "$ref" in field_schema:
            return resolve_ref(field_schema["$ref"]), False

        if "anyOf" in field_schema:
            options = field_schema["anyOf"]
            members: list[Any] = []
            for option in options:
                option_type = option.get("type")
                if "type" in option and option_type != "null":
                    # A list-valued "type" is unhashable; treat it as Any.
                    members.append(
                        JSON_TYPE_MAPPING.get(option_type, Any)
                        if isinstance(option_type, str)
                        else Any
                    )
                elif "enum" in option:
                    members.append(str)
                elif "$ref" in option:
                    members.append(resolve_ref(option["$ref"]))
            nullable = any(opt.get("type") == "null" for opt in options)
            return union_of(members), nullable

        json_type = field_schema.get("type", "string")
        items = field_schema.get("items")
        if json_type == "array" and isinstance(items, dict):
            # Elements are nullable only when the item schema says so.
            item_type, item_nullable = schema_type(items)
            if item_nullable:
                item_type = item_type | None
            return list[item_type], False

        if isinstance(json_type, list):
            members = [JSON_TYPE_MAPPING.get(t, Any) for t in json_type if t != "null"]
            return union_of(members), "null" in json_type

        return JSON_TYPE_MAPPING.get(json_type, Any), False

    def get_field_type(
        field_name: str, field_schema: dict[str, Any], required: set[str]
    ) -> tuple[Any, Any]:
        field_type, nullable = schema_type(field_schema)
        # A "default" next to a $ref is ignored, as before.
        default = None if "$ref" in field_schema else field_schema.get("default")
        is_required = field_name in required and default is None

        # None is accepted when the schema declares the property nullable
        # (even if it is required) or when the property is optional.
        if nullable or not is_required:
            field_type = field_type | None

        return field_type, (... if is_required else default)

    for def_name, def_schema in schema.get("$defs", {}).items():
        process_schema(def_name, def_schema)

    return process_schema(model_name, schema)


def _schema_allows_null(prop_schema: dict[str, Any]) -> bool:
    """Tell whether a property schema lists ``null`` via a type list or ``anyOf``."""
    declared = prop_schema.get("type")
    if isinstance(declared, list):
        return "null" in declared
    if "anyOf" in prop_schema:
        return any(opt.get("type") == "null" for opt in prop_schema["anyOf"])
    return False


def _drop_disallowed_nulls(
    kwargs: dict[str, Any], properties: dict[str, Any]
) -> dict[str, Any]:
    """Remove ``None`` arguments for known properties that do not accept null.

    Arguments that are not ``None``, or whose key is not a declared property,
    are always kept.
    """
    return {
        key: value
        for key, value in kwargs.items()
        if value is not None
        or key not in properties
        or _schema_allows_null(properties[key])
    }


def _call_result_to_text(result: Any) -> str:
    """Render a tool call result as text.

    A single content item yields its ``text`` (or ``str(item)`` when it has no
    ``text`` attribute, e.g. non-text content). Otherwise the ``text`` of every
    item that has one is collected into a list and stringified.
    """
    contents = result.content
    if len(contents) == 1:
        only = contents[0]
        return only.text if hasattr(only, "text") else str(only)
    return str([item.text for item in contents if hasattr(item, "text")])


# `ToolAdapter` is only bound when the optional imports above succeeded; fall
# back to `object` so that importing this module never raises NameError.
_ToolAdapterBase: Any = ToolAdapter if MCP_AVAILABLE else object


class CrewAIAdapterWithSchemaFix(_ToolAdapterBase):
    """Custom CrewAI adapter that properly handles complex JSON schemas.

    This adapter extends mcpadapt's ToolAdapter to fix issues with complex
    JSON schemas that contain internal $ref references (e.g., Mapbox MCP server).
    It fully resolves all $ref references before creating Pydantic models,
    preventing KeyError exceptions during JSON schema generation.
    """

    def adapt(
        self,
        func: Callable[[dict[str, Any] | None], mcp.types.CallToolResult],
        mcp_tool: mcp.types.Tool,
    ) -> BaseTool:
        """Adapt a MCP tool to a CrewAI tool with proper schema handling.

        Args:
            func: The function to adapt.
            mcp_tool: The MCP tool to adapt.

        Returns:
            A CrewAI tool.

        Raises:
            RuntimeError: If the MCP dependencies are not installed.
        """
        resolved_schema = _resolve_all_refs(mcp_tool.inputSchema)
        tool_input_model = _create_model_from_schema(resolved_schema)
        schema_properties = resolved_schema.get("properties", {})

        class CrewAIMCPTool(BaseTool):
            name: str = mcp_tool.name
            description: str = mcp_tool.description or ""
            args_schema: type[BaseModel] = tool_input_model

            def _run(self, *args: Any, **kwargs: Any) -> Any:
                # Positional args are ignored; only keyword args reach the tool.
                result = func(_drop_disallowed_nulls(kwargs, schema_properties))
                return _call_result_to_text(result)

            def _generate_description(self) -> None:
                try:
                    args_schema = {
                        k: v
                        for k, v in jsonref.replace_refs(
                            self.args_schema.model_json_schema()
                        ).items()
                        if k != "$defs"
                    }
                except Exception:
                    # Best effort: fall back to the resolved input schema, but
                    # leave a trace instead of swallowing the error silently.
                    logger.debug(
                        "Could not render args schema for MCP tool %r; "
                        "using resolved input schema instead",
                        self.name,
                        exc_info=True,
                    )
                    args_schema = resolved_schema
                self.description = f"Tool Name: {self.name}\nTool Arguments: {args_schema}\nTool Description: {self.description}"

        return CrewAIMCPTool()

    async def async_adapt(
        self,
        afunc: Callable[
            [dict[str, Any] | None], Coroutine[Any, Any, mcp.types.CallToolResult]
        ],
        mcp_tool: mcp.types.Tool,
    ) -> Any:
        """Always raise: asynchronous adaptation is not supported."""
        raise NotImplementedError("async is not supported by the CrewAI framework.")


# (class MCPServerAdapter follows in the original module; only the first line
# of its docstring is visible here, so it is intentionally left untouched.)
# --- Carried over from the surrounding patch (definitions only partly visible) ---
# mcp_adapter.py, inside MCPServerAdapter: the adapter handed to MCPAdapt changes from
#     MCPAdapt(self._serverparams, CrewAIAdapter(), connect_timeout)
# to
#     MCPAdapt(self._serverparams, CrewAIAdapterWithSchemaFix(), connect_timeout)
#
# mcp_adapter_test.py, tail of test_connect_timeout_passed_to_mcpadapt (unchanged):
#     MCPServerAdapter(serverparams, connect_timeout=5)
#     mock_mcpadapt.assert_called_once()
#     assert mock_mcpadapt.call_args[0][2] == 5
# -------------------------------------------------------------------------------

import asyncio
from textwrap import dedent
from unittest.mock import MagicMock, patch

from crewai_tools import MCPServerAdapter
from crewai_tools.adapters.mcp_adapter import (
    _resolve_all_refs,
    _create_model_from_schema,
    CrewAIAdapterWithSchemaFix,
)
from crewai_tools.adapters.tool_collection import ToolCollection
from mcp import StdioServerParameters
import pytest


class TestResolveAllRefs:
    """Tests for the _resolve_all_refs function that handles complex JSON schemas."""

    def test_resolve_simple_defs_ref(self):
        """Test resolving $ref that points to $defs."""
        schema = {
            "type": "object",
            "$defs": {
                "Point": {
                    "type": "array",
                    "items": {"type": "number"},
                }
            },
            "properties": {
                "location": {"$ref": "#/$defs/Point"}
            },
        }
        resolved = _resolve_all_refs(schema)
        assert "$defs" not in resolved
        assert resolved["properties"]["location"]["type"] == "array"
        assert resolved["properties"]["location"]["items"]["type"] == "number"

    def test_resolve_nested_anyof_refs(self):
        """Test resolving $ref inside anyOf (like Mapbox geometry schema)."""
        schema = {
            "type": "object",
            "$defs": {
                "Coordinate": {
                    "type": "array",
                    "items": {"type": "number"},
                    "minItems": 2,
                }
            },
            "properties": {
                "geometry": {
                    "anyOf": [
                        {
                            "type": "array",
                            "items": {"$ref": "#/$defs/Coordinate"},
                        },
                        {"type": "null"},
                    ]
                }
            },
        }
        resolved = _resolve_all_refs(schema)
        assert "$defs" not in resolved
        geometry_anyof = resolved["properties"]["geometry"]["anyOf"]
        array_option = geometry_anyof[0]
        assert array_option["type"] == "array"
        assert array_option["items"]["type"] == "array"
        assert array_option["items"]["items"]["type"] == "number"

    def test_resolve_internal_property_refs(self):
        """Test resolving $ref that points to internal properties."""
        schema = {
            "type": "object",
            "properties": {
                "coordinates": {
                    "type": "array",
                    "items": {"type": "number"},
                },
                "geometry": {
                    "anyOf": [
                        {
                            "type": "array",
                            "items": {"$ref": "#/properties/coordinates"},
                        },
                        {"type": "null"},
                    ]
                },
            },
        }
        resolved = _resolve_all_refs(schema)
        geometry_anyof = resolved["properties"]["geometry"]["anyOf"]
        array_option = geometry_anyof[0]
        assert array_option["items"]["type"] == "array"
        assert array_option["items"]["items"]["type"] == "number"

    def test_no_refs_returns_same_schema(self):
        """Test that a schema without refs keeps its property definitions."""
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
        }
        resolved = _resolve_all_refs(schema)
        assert resolved["properties"]["name"]["type"] == "string"
        assert resolved["properties"]["age"]["type"] == "integer"


class TestCreateModelFromSchema:
    """Tests for the _create_model_from_schema function."""

    def test_create_simple_model(self):
        """Test creating a model from a simple schema."""
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "The name"},
                "count": {"type": "integer"},
            },
            "required": ["name"],
        }
        model = _create_model_from_schema(schema)
        assert model.__name__ == "DynamicModel"
        instance = model(name="test")
        assert instance.name == "test"
        assert instance.count is None

    def test_create_model_with_anyof_nullable(self):
        """Test creating a model with anyOf that includes null."""
        schema = {
            "type": "object",
            "properties": {
                "value": {
                    "anyOf": [
                        {"type": "string"},
                        {"type": "null"},
                    ]
                }
            },
        }
        model = _create_model_from_schema(schema)
        instance = model(value="test")
        assert instance.value == "test"
        instance2 = model(value=None)
        assert instance2.value is None

    def test_create_model_with_array(self):
        """Test creating a model with array type."""
        schema = {
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": {"type": "string"},
                }
            },
        }
        model = _create_model_from_schema(schema)
        instance = model(items=["a", "b", "c"])
        assert instance.items == ["a", "b", "c"]

    def test_model_json_schema_does_not_raise(self):
        """Test that generated model's model_json_schema() doesn't raise KeyError.

        This is the core issue from GitHub issue #4312 - complex schemas with
        internal $ref references would cause KeyError when calling model_json_schema().
        The schema below therefore contains such an internal reference.
        """
        schema = {
            "type": "object",
            "properties": {
                "coordinates": {
                    "type": "array",
                    "items": {"type": "number"},
                },
                "geometry": {
                    "anyOf": [
                        {
                            "type": "array",
                            "items": {"$ref": "#/properties/coordinates"},
                        },
                        {"type": "null"},
                    ]
                },
            },
        }
        resolved = _resolve_all_refs(schema)
        model = _create_model_from_schema(resolved)
        json_schema = model.model_json_schema()
        assert "properties" in json_schema
        assert "geometry" in json_schema["properties"]


class TestCrewAIAdapterWithSchemaFix:
    """Tests for the CrewAIAdapterWithSchemaFix class."""

    def test_adapter_is_tool_adapter(self):
        """Test that the adapter is a valid ToolAdapter."""
        from mcpadapt.core import ToolAdapter

        adapter = CrewAIAdapterWithSchemaFix()
        assert isinstance(adapter, ToolAdapter)

    def test_async_adapt_raises_not_implemented(self):
        """Test that async_adapt raises NotImplementedError."""
        adapter = CrewAIAdapterWithSchemaFix()
        # asyncio.run creates and closes its own loop; get_event_loop() without
        # a running loop is deprecated.
        with pytest.raises(NotImplementedError):
            asyncio.run(adapter.async_adapt(None, None))