fix: handle complex JSON schemas in MCP adapter to prevent KeyError

Fixes #4312

The issue was that MCP servers with complex JSON schemas containing internal
$ref references (like Mapbox MCP server) would cause a KeyError when
Pydantic tried to generate JSON schemas from dynamically created models.

This fix:
- Creates a custom CrewAI adapter (CrewAIAdapterWithSchemaFix) that properly
  resolves all $ref references using jsonref before creating Pydantic models
- Implements _resolve_all_refs() to fully resolve JSON Schema references
- Implements _create_model_from_schema() that creates Pydantic models without
  passing problematic extra fields that cause issues during schema generation
- Adds comprehensive tests for the new schema handling functionality

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2026-01-30 17:27:59 +00:00
parent 85f31459c1
commit 5b296fa20f
2 changed files with 486 additions and 5 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Callable, Coroutine
import logging
from typing import TYPE_CHECKING, Any
@@ -14,20 +15,303 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from mcp import StdioServerParameters
from mcpadapt.core import MCPAdapt
from mcpadapt.crewai_adapter import CrewAIAdapter
import mcp.types
from mcpadapt.core import MCPAdapt, ToolAdapter
try:
from typing import ForwardRef, Union
import jsonref
from mcp import StdioServerParameters
from mcpadapt.core import MCPAdapt
from mcpadapt.crewai_adapter import CrewAIAdapter
import mcp.types
from mcpadapt.core import MCPAdapt, ToolAdapter
from pydantic import BaseModel, Field, create_model
MCP_AVAILABLE = True
except ImportError:
MCP_AVAILABLE = False
# Maps JSON Schema primitive type names to the Python types used as
# annotations when building Pydantic model fields dynamically.
# Unknown type names fall back to `Any` at the lookup sites.
JSON_TYPE_MAPPING: dict[str, type] = {
    "string": str,
    "number": float,
    "integer": int,
    "boolean": bool,
    "object": dict,
    "array": list,
}
def _resolve_refs_to_dict(obj: Any) -> Any:
"""Recursively convert JsonRef objects to regular dicts."""
if isinstance(obj, dict):
return {k: _resolve_refs_to_dict(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_resolve_refs_to_dict(item) for item in obj]
return obj
def _resolve_all_refs(schema: dict[str, Any]) -> dict[str, Any]:
    """Resolve every ``$ref`` in a JSON schema using jsonref.

    All JSON Schema references are fully inlined, including internal ones
    that point at paths inside the schema itself (e.g.
    ``#/properties/geometry/anyOf/0/items``). The now-redundant ``$defs``
    section is dropped from the result.

    Args:
        schema: The JSON schema, possibly containing ``$ref`` entries.

    Returns:
        A new schema dict with every ``$ref`` inlined and ``$defs`` removed.
    """
    if not MCP_AVAILABLE:
        # Without jsonref installed there is nothing we can resolve.
        return schema
    inlined = jsonref.replace_refs(schema, lazy_load=False)
    plain = _resolve_refs_to_dict(inlined)
    plain.pop("$defs", None)
    return plain
def _create_model_from_schema(
    schema: dict[str, Any], model_name: str = "DynamicModel"
) -> type[BaseModel]:
    """Create a Pydantic model from a JSON schema definition.

    Handles the common JSON-schema patterns emitted by MCP servers
    (``properties``/``required``, ``anyOf`` with ``null``, local ``$ref``s
    into ``$defs`` or internal paths, typed arrays, and multi-type lists)
    while deliberately NOT forwarding other schema keywords to ``Field()``,
    since extra fields were what broke Pydantic's schema generation
    (GitHub issue #4312).

    Args:
        schema: The JSON schema definition.
        model_name: Name for the created model, used when the schema has no
            ``title``.

    Returns:
        A Pydantic BaseModel class.

    Raises:
        RuntimeError: If the optional MCP dependencies are not installed.
    """
    if not MCP_AVAILABLE:
        raise RuntimeError("MCP dependencies not available")

    # Cache of models already built for named sub-schemas so a $ref that is
    # referenced several times maps to a single class.
    created_models: dict[str, type[BaseModel]] = {}

    def process_schema(name: str, schema_def: dict[str, Any]) -> type[BaseModel]:
        # Build (or fetch from cache) the model for one object schema.
        if name in created_models:
            return created_models[name]
        fields: dict[str, Any] = {}
        properties = schema_def.get("properties", {})
        required = set(schema_def.get("required", []))
        for field_name, field_schema in properties.items():
            field_type, default = get_field_type(field_name, field_schema, required)
            # Only `default` and `description` are passed through; other
            # keywords (minItems, pattern, ...) are intentionally dropped.
            fields[field_name] = (
                field_type,
                Field(
                    default=default,
                    description=field_schema.get("description", ""),
                ),
            )
        model: type[BaseModel] = create_model(
            schema_def.get("title", name),
            __doc__=schema_def.get("description", ""),
            **fields,
        )
        created_models[name] = model
        return model

    def resolve_ref(ref: str) -> type[BaseModel]:
        # Resolve a local $ref (e.g. '#/$defs/Point') to a model class,
        # building it on demand. Shared by both $ref call sites below.
        ref_parts = ref.lstrip("#/").split("/")
        ref_name = ref_parts[-1]
        if ref_name not in created_models:
            ref_schema: Any = schema
            for part in ref_parts:
                ref_schema = ref_schema.get(part, {})
            process_schema(ref_name, ref_schema)
        return created_models[ref_name]

    def get_field_type(
        field_name: str, field_schema: dict[str, Any], required: set[str]
    ) -> tuple[Any, Any]:
        # Return the (annotation, default) pair for one property schema.
        if "$ref" in field_schema:
            field_type = resolve_ref(field_schema["$ref"])
            is_required = field_name in required
            return (
                field_type | None if not is_required else field_type,
                None if not is_required else ...,
            )
        if "anyOf" in field_schema:
            is_nullable = any(
                opt.get("type") == "null" for opt in field_schema["anyOf"]
            )
            types: list[type[Any]] = []
            for option in field_schema["anyOf"]:
                if "type" in option and option["type"] != "null":
                    types.append(JSON_TYPE_MAPPING.get(option["type"], Any))
                elif "enum" in option:
                    # Enum options are modeled as plain strings; enum values
                    # are not validated here.
                    types.append(str)
                elif "$ref" in option:
                    types.append(resolve_ref(option["$ref"]))
            if len(types) == 0:
                field_type = Any
            elif len(types) == 1:
                field_type = types[0]
            else:
                field_type = Union[tuple(types)]  # noqa: UP007
            default = field_schema.get("default")
            is_required = field_name in required and default is None
            if is_nullable and not is_required:
                field_type = field_type | None
            return field_type, ... if is_required else default
        if field_schema.get("type") == "array" and "items" in field_schema:
            item_type, _ = get_field_type("item", field_schema["items"], set())
            field_type = list[item_type]
        else:
            json_type = field_schema.get("type", "string")
            if isinstance(json_type, list):
                # Multi-type field, e.g. ["string", "null"]; "null" entries
                # only affect optionality below, not the annotation itself.
                types = [
                    JSON_TYPE_MAPPING.get(t, Any) for t in json_type if t != "null"
                ]
                if len(types) == 0:
                    field_type = Any
                elif len(types) == 1:
                    field_type = types[0]
                else:
                    field_type = Union[tuple(types)]  # noqa: UP007
            else:
                field_type = JSON_TYPE_MAPPING.get(json_type, Any)
        default = field_schema.get("default")
        is_required = field_name in required and default is None
        if not is_required:
            field_type = field_type | None
        else:
            default = ...
        return field_type, default

    # Pre-build models for every named definition so later $refs resolve
    # straight from the cache.
    if "$defs" in schema:
        for def_name, def_schema in schema["$defs"].items():
            process_schema(def_name, def_schema)
    return process_schema(model_name, schema)
class CrewAIAdapterWithSchemaFix(ToolAdapter):
    """Custom CrewAI adapter that properly handles complex JSON schemas.

    This adapter extends mcpadapt's ToolAdapter to fix issues with complex
    JSON schemas that contain internal $ref references (e.g., Mapbox MCP
    server). It fully resolves all $ref references before creating Pydantic
    models, preventing KeyError exceptions during JSON schema generation.
    """

    def adapt(
        self,
        func: Callable[[dict[str, Any] | None], mcp.types.CallToolResult],
        mcp_tool: mcp.types.Tool,
    ) -> BaseTool:
        """Adapt a MCP tool to a CrewAI tool with proper schema handling.

        Args:
            func: The (synchronous) callable that invokes the MCP tool with a
                dict of arguments and returns a CallToolResult.
            mcp_tool: The MCP tool to adapt.

        Returns:
            A CrewAI tool wrapping ``func``.
        """
        # Inline every $ref first, then build the args model from the
        # fully-resolved schema; both are closed over by the tool class below.
        resolved_schema = _resolve_all_refs(mcp_tool.inputSchema)
        tool_input_model = _create_model_from_schema(resolved_schema)

        class CrewAIMCPTool(BaseTool):
            # Tool identity/signature taken straight from the MCP tool.
            name: str = mcp_tool.name
            description: str = mcp_tool.description or ""
            args_schema: type[BaseModel] = tool_input_model

            def _run(self, *args: Any, **kwargs: Any) -> Any:
                # Forward only meaningful kwargs: a None value is kept only
                # when the schema explicitly allows null for that property
                # (via a "null" entry in a type list or an anyOf option);
                # otherwise None values for schema properties are dropped.
                filtered_kwargs: dict[str, Any] = {}
                schema_properties = resolved_schema.get("properties", {})
                for key, value in kwargs.items():
                    if value is None and key in schema_properties:
                        prop_schema = schema_properties[key]
                        if isinstance(prop_schema.get("type"), list):
                            if "null" in prop_schema["type"]:
                                filtered_kwargs[key] = value
                        elif "anyOf" in prop_schema:
                            if any(
                                opt.get("type") == "null"
                                for opt in prop_schema["anyOf"]
                            ):
                                filtered_kwargs[key] = value
                    else:
                        filtered_kwargs[key] = value
                result = func(filtered_kwargs)
                # Single content item -> return its text directly; otherwise
                # return the stringified list of all text-bearing items.
                return (
                    result.content[0].text
                    if len(result.content) == 1
                    else str(
                        [
                            content.text
                            for content in result.content
                            if hasattr(content, "text")
                        ]
                    )
                )

            def _generate_description(self) -> None:
                # Prefer the ref-free JSON schema of the generated model for
                # the argument listing; fall back to the resolved input
                # schema if schema generation fails for any reason.
                try:
                    args_schema = {
                        k: v
                        for k, v in jsonref.replace_refs(
                            self.args_schema.model_json_schema()
                        ).items()
                        if k != "$defs"
                    }
                except Exception:
                    args_schema = resolved_schema
                self.description = f"Tool Name: {self.name}\nTool Arguments: {args_schema}\nTool Description: {self.description}"

        return CrewAIMCPTool()

    async def async_adapt(
        self,
        afunc: Callable[
            [dict[str, Any] | None], Coroutine[Any, Any, mcp.types.CallToolResult]
        ],
        mcp_tool: mcp.types.Tool,
    ) -> Any:
        """Async adaptation is unsupported; CrewAI tools are synchronous."""
        raise NotImplementedError("async is not supported by the CrewAI framework.")
class MCPServerAdapter:
"""Manages the lifecycle of an MCP server and make its tools available to CrewAI.
@@ -112,7 +396,7 @@ class MCPServerAdapter:
try:
self._serverparams = serverparams
self._adapter = MCPAdapt(
self._serverparams, CrewAIAdapter(), connect_timeout
self._serverparams, CrewAIAdapterWithSchemaFix(), connect_timeout
)
self.start()

View File

@@ -2,6 +2,11 @@ from textwrap import dedent
from unittest.mock import MagicMock, patch
from crewai_tools import MCPServerAdapter
from crewai_tools.adapters.mcp_adapter import (
_resolve_all_refs,
_create_model_from_schema,
CrewAIAdapterWithSchemaFix,
)
from crewai_tools.adapters.tool_collection import ToolCollection
from mcp import StdioServerParameters
import pytest
@@ -237,3 +242,195 @@ def test_connect_timeout_passed_to_mcpadapt(mock_mcpadapt):
MCPServerAdapter(serverparams, connect_timeout=5)
mock_mcpadapt.assert_called_once()
assert mock_mcpadapt.call_args[0][2] == 5
class TestResolveAllRefs:
    """Exercises _resolve_all_refs against representative schema shapes."""

    def test_resolve_simple_defs_ref(self):
        """A $ref pointing into $defs is inlined and $defs is removed."""
        schema = {
            "type": "object",
            "$defs": {"Point": {"type": "array", "items": {"type": "number"}}},
            "properties": {"location": {"$ref": "#/$defs/Point"}},
        }
        result = _resolve_all_refs(schema)
        assert "$defs" not in result
        location = result["properties"]["location"]
        assert location["type"] == "array"
        assert location["items"]["type"] == "number"

    def test_resolve_nested_anyof_refs(self):
        """A $ref nested inside anyOf (Mapbox-style geometry) is inlined."""
        coordinate_def = {
            "type": "array",
            "items": {"type": "number"},
            "minItems": 2,
        }
        schema = {
            "type": "object",
            "$defs": {"Coordinate": coordinate_def},
            "properties": {
                "geometry": {
                    "anyOf": [
                        {"type": "array", "items": {"$ref": "#/$defs/Coordinate"}},
                        {"type": "null"},
                    ]
                }
            },
        }
        result = _resolve_all_refs(schema)
        assert "$defs" not in result
        first_option = result["properties"]["geometry"]["anyOf"][0]
        assert first_option["type"] == "array"
        assert first_option["items"]["type"] == "array"
        assert first_option["items"]["items"]["type"] == "number"

    def test_resolve_internal_property_refs(self):
        """A $ref targeting an internal property path is inlined too."""
        schema = {
            "type": "object",
            "properties": {
                "coordinates": {"type": "array", "items": {"type": "number"}},
                "geometry": {
                    "anyOf": [
                        {
                            "type": "array",
                            "items": {"$ref": "#/properties/coordinates"},
                        },
                        {"type": "null"},
                    ]
                },
            },
        }
        result = _resolve_all_refs(schema)
        first_option = result["properties"]["geometry"]["anyOf"][0]
        assert first_option["items"]["type"] == "array"
        assert first_option["items"]["items"]["type"] == "number"

    def test_no_refs_returns_same_schema(self):
        """A ref-free schema passes through intact (modulo $defs removal)."""
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
        }
        result = _resolve_all_refs(schema)
        assert result["properties"]["name"]["type"] == "string"
        assert result["properties"]["age"]["type"] == "integer"
class TestCreateModelFromSchema:
    """Exercises _create_model_from_schema across common schema shapes."""

    def test_create_simple_model(self):
        """A flat object schema yields required and optional fields."""
        model = _create_model_from_schema(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "description": "The name"},
                    "count": {"type": "integer"},
                },
                "required": ["name"],
            }
        )
        assert model.__name__ == "DynamicModel"
        created = model(name="test")
        assert created.name == "test"
        assert created.count is None

    def test_create_model_with_anyof_nullable(self):
        """anyOf containing a null option makes the field accept None."""
        nullable_schema = {
            "type": "object",
            "properties": {
                "value": {"anyOf": [{"type": "string"}, {"type": "null"}]}
            },
        }
        model = _create_model_from_schema(nullable_schema)
        assert model(value="test").value == "test"
        assert model(value=None).value is None

    def test_create_model_with_array(self):
        """Typed arrays map onto list-valued fields."""
        array_schema = {
            "type": "object",
            "properties": {
                "items": {"type": "array", "items": {"type": "string"}}
            },
        }
        model = _create_model_from_schema(array_schema)
        assert model(items=["a", "b", "c"]).items == ["a", "b", "c"]

    def test_model_json_schema_does_not_raise(self):
        """Regression test for GitHub issue #4312.

        Complex schemas with internal $ref references used to cause a
        KeyError when calling model_json_schema() on the generated model.
        """
        raw_schema = {
            "type": "object",
            "properties": {
                "geometry": {
                    "anyOf": [
                        {"type": "array", "items": {"type": "number"}},
                        {"type": "null"},
                    ]
                }
            },
        }
        model = _create_model_from_schema(_resolve_all_refs(raw_schema))
        generated = model.model_json_schema()
        assert "properties" in generated
        assert "geometry" in generated["properties"]
class TestCrewAIAdapterWithSchemaFix:
    """Tests for the CrewAIAdapterWithSchemaFix class."""

    def test_adapter_is_tool_adapter(self):
        """Test that the adapter is a valid ToolAdapter."""
        from mcpadapt.core import ToolAdapter

        adapter = CrewAIAdapterWithSchemaFix()
        assert isinstance(adapter, ToolAdapter)

    def test_async_adapt_raises_not_implemented(self):
        """Test that async_adapt raises NotImplementedError."""
        import asyncio

        adapter = CrewAIAdapterWithSchemaFix()
        # asyncio.run() replaces the deprecated
        # asyncio.get_event_loop().run_until_complete(...) pattern, which
        # emits a DeprecationWarning when no event loop is running.
        with pytest.raises(NotImplementedError):
            asyncio.run(adapter.async_adapt(None, None))