mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-10 04:52:40 +00:00
fix: handle complex JSON schemas in MCP adapter to prevent KeyError
Fixes #4312 The issue was that MCP servers with complex JSON schemas containing internal $ref references (like Mapbox MCP server) would cause a KeyError when Pydantic tried to generate JSON schemas from dynamically created models. This fix: - Creates a custom CrewAI adapter (CrewAIAdapterWithSchemaFix) that properly resolves all $ref references using jsonref before creating Pydantic models - Implements _resolve_all_refs() to fully resolve JSON Schema references - Implements _create_model_from_schema() that creates Pydantic models without passing problematic extra fields that cause issues during schema generation - Adds comprehensive tests for the new schema handling functionality Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -14,20 +15,303 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp import StdioServerParameters
|
||||
from mcpadapt.core import MCPAdapt
|
||||
from mcpadapt.crewai_adapter import CrewAIAdapter
|
||||
import mcp.types
|
||||
from mcpadapt.core import MCPAdapt, ToolAdapter
|
||||
|
||||
|
||||
try:
|
||||
from typing import ForwardRef, Union
|
||||
|
||||
import jsonref
|
||||
from mcp import StdioServerParameters
|
||||
from mcpadapt.core import MCPAdapt
|
||||
from mcpadapt.crewai_adapter import CrewAIAdapter
|
||||
import mcp.types
|
||||
from mcpadapt.core import MCPAdapt, ToolAdapter
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
MCP_AVAILABLE = True
|
||||
except ImportError:
|
||||
MCP_AVAILABLE = False
|
||||
|
||||
|
||||
# Maps JSON Schema primitive type names to the Python types used when
# dynamically building Pydantic model fields from an MCP tool's input schema.
JSON_TYPE_MAPPING: dict[str, type] = {
    "string": str,
    "number": float,
    "integer": int,
    "boolean": bool,
    "object": dict,
    "array": list,
}
|
||||
|
||||
|
||||
def _resolve_refs_to_dict(obj: Any) -> Any:
|
||||
"""Recursively convert JsonRef objects to regular dicts."""
|
||||
if isinstance(obj, dict):
|
||||
return {k: _resolve_refs_to_dict(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_resolve_refs_to_dict(item) for item in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def _resolve_all_refs(schema: dict[str, Any]) -> dict[str, Any]:
    """Resolve all $ref references in a JSON schema using jsonref.

    This function fully resolves all JSON Schema $ref references, including
    internal references that point to paths within the schema itself
    (e.g., '#/properties/geometry/anyOf/0/items').

    Args:
        schema: The JSON schema with potential $ref references.

    Returns:
        A new schema dict with all $ref references resolved and inlined.
    """
    # Without the optional MCP stack there is no jsonref; hand back the
    # schema untouched rather than failing.
    if not MCP_AVAILABLE:
        return schema

    # Eagerly resolve every $ref, then deep-copy the proxy objects jsonref
    # produces into plain dicts/lists.
    inlined = _resolve_refs_to_dict(jsonref.replace_refs(schema, lazy_load=False))
    # All definitions are inlined now, so the $defs section is dead weight.
    inlined.pop("$defs", None)
    return inlined
|
||||
|
||||
|
||||
def _create_model_from_schema(
    schema: dict[str, Any], model_name: str = "DynamicModel"
) -> type[BaseModel]:
    """Create a Pydantic model from a JSON schema definition.

    This is a simplified version that handles common JSON schema patterns
    without passing problematic extra fields to Pydantic's Field().

    Args:
        schema: The JSON schema definition.
        model_name: The name for the created model.

    Returns:
        A Pydantic BaseModel class.

    Raises:
        RuntimeError: If the optional MCP dependencies are not installed.
    """
    if not MCP_AVAILABLE:
        raise RuntimeError("MCP dependencies not available")

    # Memoizes models built for named sub-schemas so a $ref seen twice
    # reuses the same class instead of being rebuilt.
    created_models: dict[str, type[BaseModel]] = {}
    # NOTE(review): populated below but never read back in this function —
    # looks like a leftover; candidate for removal.
    forward_refs: dict[str, ForwardRef] = {}

    def process_schema(name: str, schema_def: dict[str, Any]) -> type[BaseModel]:
        # Build (or return the cached) Pydantic model for one object schema.
        if name in created_models:
            return created_models[name]

        if name not in forward_refs:
            forward_refs[name] = ForwardRef(name)

        fields: dict[str, Any] = {}
        properties = schema_def.get("properties", {})
        required = set(schema_def.get("required", []))

        for field_name, field_schema in properties.items():
            field_type, default = get_field_type(field_name, field_schema, required)
            # Only default and description are passed to Field() — extra
            # schema keys are deliberately dropped (they caused the original
            # KeyError during model_json_schema generation).
            fields[field_name] = (
                field_type,
                Field(
                    default=default,
                    description=field_schema.get("description", ""),
                ),
            )

        model: type[BaseModel] = create_model(
            schema_def.get("title", name),
            __doc__=schema_def.get("description", ""),
            **fields,
        )

        created_models[name] = model
        return model

    def get_field_type(
        field_name: str, field_schema: dict[str, Any], required: set[str]
    ) -> tuple[Any, Any]:
        # Translate one property schema into an (annotation, default) pair,
        # where default is Ellipsis (...) for required fields.
        if "$ref" in field_schema:
            # e.g. '#/$defs/Point' -> ['$defs', 'Point']; the last segment
            # names the referenced definition.
            ref_parts = field_schema["$ref"].lstrip("#/").split("/")
            ref_name = ref_parts[-1]

            if ref_name not in created_models:
                # Walk the outer `schema` (closed over) down the ref path,
                # falling back to {} for any missing segment.
                ref_schema = schema
                for part in ref_parts:
                    ref_schema = ref_schema.get(part, {})
                process_schema(ref_name, ref_schema)

            field_type = created_models[ref_name]
            is_required = field_name in required
            return (
                field_type | None if not is_required else field_type,
                None if not is_required else ...,
            )

        if "anyOf" in field_schema:
            # A {"type": "null"} alternative marks the field as nullable.
            is_nullable = any(
                opt.get("type") == "null" for opt in field_schema["anyOf"]
            )
            types: list[type[Any]] = []

            for option in field_schema["anyOf"]:
                if "type" in option and option["type"] != "null":
                    types.append(JSON_TYPE_MAPPING.get(option["type"], Any))
                elif "enum" in option:
                    # Enums are approximated as str rather than Literal.
                    types.append(str)
                elif "$ref" in option:
                    ref_parts = option["$ref"].lstrip("#/").split("/")
                    ref_name = ref_parts[-1]

                    if ref_name not in created_models:
                        ref_schema = schema
                        for part in ref_parts:
                            ref_schema = ref_schema.get(part, {})
                        process_schema(ref_name, ref_schema)

                    types.append(created_models[ref_name])

            if len(types) == 0:
                field_type = Any
            elif len(types) == 1:
                field_type = types[0]
            else:
                # Runtime-constructed Union over all non-null alternatives.
                field_type = Union[tuple(types)]  # noqa: UP007

            default = field_schema.get("default")
            # A field with an explicit default is treated as optional even
            # if listed in "required".
            is_required = field_name in required and default is None

            if is_nullable and not is_required:
                field_type = field_type | None

            return field_type, ... if is_required else default

        if field_schema.get("type") == "array" and "items" in field_schema:
            # Recurse on the item schema; required-ness of items is N/A.
            item_type, _ = get_field_type("item", field_schema["items"], set())
            field_type = list[item_type]
        else:
            # Missing "type" defaults to string.
            json_type = field_schema.get("type", "string")

            if isinstance(json_type, list):
                # JSON Schema allows a list of types, e.g. ["string", "null"].
                types = []
                for t in json_type:
                    if t != "null":
                        mapped_type = JSON_TYPE_MAPPING.get(t, Any)
                        types.append(mapped_type)

                if len(types) == 0:
                    field_type = Any
                elif len(types) == 1:
                    field_type = types[0]
                else:
                    field_type = Union[tuple(types)]  # noqa: UP007
            else:
                field_type = JSON_TYPE_MAPPING.get(json_type, Any)

        default = field_schema.get("default")
        is_required = field_name in required and default is None

        if not is_required:
            # Optional fields are made nullable so omitted kwargs validate.
            field_type = field_type | None
            default = default if default is not None else None
        else:
            default = ...

        return field_type, default

    # Pre-build models for every named definition so $refs resolve from cache.
    if "$defs" in schema:
        for def_name, def_schema in schema["$defs"].items():
            process_schema(def_name, def_schema)

    return process_schema(model_name, schema)
|
||||
|
||||
|
||||
class CrewAIAdapterWithSchemaFix(ToolAdapter):
    """Custom CrewAI adapter that properly handles complex JSON schemas.

    This adapter extends mcpadapt's ToolAdapter to fix issues with complex
    JSON schemas that contain internal $ref references (e.g., Mapbox MCP server).
    It fully resolves all $ref references before creating Pydantic models,
    preventing KeyError exceptions during JSON schema generation.
    """

    def adapt(
        self,
        func: Callable[[dict[str, Any] | None], mcp.types.CallToolResult],
        mcp_tool: mcp.types.Tool,
    ) -> BaseTool:
        """Adapt a MCP tool to a CrewAI tool with proper schema handling.

        Args:
            func: The function to adapt.
            mcp_tool: The MCP tool to adapt.

        Returns:
            A CrewAI tool.
        """
        # Resolve every $ref up front so the dynamic Pydantic model never
        # sees unresolved references (the cause of the original KeyError).
        resolved_schema = _resolve_all_refs(mcp_tool.inputSchema)
        tool_input_model = _create_model_from_schema(resolved_schema)

        # Nested class closes over func/resolved_schema/tool_input_model;
        # each adapt() call produces an independent tool class.
        class CrewAIMCPTool(BaseTool):
            name: str = mcp_tool.name
            description: str = mcp_tool.description or ""
            args_schema: type[BaseModel] = tool_input_model

            def _run(self, *args: Any, **kwargs: Any) -> Any:
                # Forward None values only for properties whose schema
                # explicitly allows null; drop None for all other keys.
                filtered_kwargs: dict[str, Any] = {}
                schema_properties = resolved_schema.get("properties", {})

                for key, value in kwargs.items():
                    if value is None and key in schema_properties:
                        prop_schema = schema_properties[key]
                        if isinstance(prop_schema.get("type"), list):
                            # "type": ["string", "null"] style nullability.
                            if "null" in prop_schema["type"]:
                                filtered_kwargs[key] = value
                        elif "anyOf" in prop_schema:
                            # anyOf containing {"type": "null"} style.
                            if any(
                                opt.get("type") == "null"
                                for opt in prop_schema["anyOf"]
                            ):
                                filtered_kwargs[key] = value
                    else:
                        filtered_kwargs[key] = value

                result = func(filtered_kwargs)
                # Single content item: return its text directly; otherwise
                # stringify the list of text parts (non-text parts skipped).
                # NOTE(review): assumes a lone content item has a .text
                # attribute — TODO confirm for non-text MCP content types.
                return (
                    result.content[0].text
                    if len(result.content) == 1
                    else str(
                        [
                            content.text
                            for content in result.content
                            if hasattr(content, "text")
                        ]
                    )
                )

            def _generate_description(self) -> None:
                # Prefer the schema Pydantic emits (with refs re-inlined and
                # $defs stripped); fall back to the resolved input schema if
                # generation fails for any reason.
                try:
                    args_schema = {
                        k: v
                        for k, v in jsonref.replace_refs(
                            self.args_schema.model_json_schema()
                        ).items()
                        if k != "$defs"
                    }
                except Exception:
                    args_schema = resolved_schema
                self.description = f"Tool Name: {self.name}\nTool Arguments: {args_schema}\nTool Description: {self.description}"

        return CrewAIMCPTool()

    async def async_adapt(
        self,
        afunc: Callable[
            [dict[str, Any] | None], Coroutine[Any, Any, mcp.types.CallToolResult]
        ],
        mcp_tool: mcp.types.Tool,
    ) -> Any:
        # CrewAI tools are synchronous; async adaptation is unsupported.
        raise NotImplementedError("async is not supported by the CrewAI framework.")
|
||||
|
||||
|
||||
class MCPServerAdapter:
|
||||
"""Manages the lifecycle of an MCP server and make its tools available to CrewAI.
|
||||
|
||||
@@ -112,7 +396,7 @@ class MCPServerAdapter:
|
||||
try:
|
||||
self._serverparams = serverparams
|
||||
self._adapter = MCPAdapt(
|
||||
self._serverparams, CrewAIAdapter(), connect_timeout
|
||||
self._serverparams, CrewAIAdapterWithSchemaFix(), connect_timeout
|
||||
)
|
||||
self.start()
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@ from textwrap import dedent
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai_tools import MCPServerAdapter
|
||||
from crewai_tools.adapters.mcp_adapter import (
|
||||
_resolve_all_refs,
|
||||
_create_model_from_schema,
|
||||
CrewAIAdapterWithSchemaFix,
|
||||
)
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
from mcp import StdioServerParameters
|
||||
import pytest
|
||||
@@ -237,3 +242,195 @@ def test_connect_timeout_passed_to_mcpadapt(mock_mcpadapt):
|
||||
MCPServerAdapter(serverparams, connect_timeout=5)
|
||||
mock_mcpadapt.assert_called_once()
|
||||
assert mock_mcpadapt.call_args[0][2] == 5
|
||||
|
||||
|
||||
class TestResolveAllRefs:
    """Tests for the _resolve_all_refs function that handles complex JSON schemas."""

    def test_resolve_simple_defs_ref(self):
        """Test resolving $ref that points to $defs."""
        raw = {
            "type": "object",
            "$defs": {"Point": {"type": "array", "items": {"type": "number"}}},
            "properties": {"location": {"$ref": "#/$defs/Point"}},
        }
        flattened = _resolve_all_refs(raw)
        assert "$defs" not in flattened
        location = flattened["properties"]["location"]
        assert location["type"] == "array"
        assert location["items"]["type"] == "number"

    def test_resolve_nested_anyof_refs(self):
        """Test resolving $ref inside anyOf (like Mapbox geometry schema)."""
        raw = {
            "type": "object",
            "$defs": {
                "Coordinate": {
                    "type": "array",
                    "items": {"type": "number"},
                    "minItems": 2,
                }
            },
            "properties": {
                "geometry": {
                    "anyOf": [
                        {"type": "array", "items": {"$ref": "#/$defs/Coordinate"}},
                        {"type": "null"},
                    ]
                }
            },
        }
        flattened = _resolve_all_refs(raw)
        assert "$defs" not in flattened
        first_option = flattened["properties"]["geometry"]["anyOf"][0]
        assert first_option["type"] == "array"
        assert first_option["items"]["type"] == "array"
        assert first_option["items"]["items"]["type"] == "number"

    def test_resolve_internal_property_refs(self):
        """Test resolving $ref that points to internal properties."""
        raw = {
            "type": "object",
            "properties": {
                "coordinates": {"type": "array", "items": {"type": "number"}},
                "geometry": {
                    "anyOf": [
                        {
                            "type": "array",
                            "items": {"$ref": "#/properties/coordinates"},
                        },
                        {"type": "null"},
                    ]
                },
            },
        }
        flattened = _resolve_all_refs(raw)
        first_option = flattened["properties"]["geometry"]["anyOf"][0]
        assert first_option["items"]["type"] == "array"
        assert first_option["items"]["items"]["type"] == "number"

    def test_no_refs_returns_same_schema(self):
        """Test that schema without refs is returned unchanged (except $defs removal)."""
        raw = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
        }
        flattened = _resolve_all_refs(raw)
        assert flattened["properties"]["name"]["type"] == "string"
        assert flattened["properties"]["age"]["type"] == "integer"
|
||||
|
||||
|
||||
class TestCreateModelFromSchema:
    """Tests for the _create_model_from_schema function."""

    def test_create_simple_model(self):
        """Test creating a model from a simple schema."""
        simple_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "The name"},
                "count": {"type": "integer"},
            },
            "required": ["name"],
        }
        built = _create_model_from_schema(simple_schema)
        assert built.__name__ == "DynamicModel"
        obj = built(name="test")
        assert obj.name == "test"
        assert obj.count is None

    def test_create_model_with_anyof_nullable(self):
        """Test creating a model with anyOf that includes null."""
        nullable_schema = {
            "type": "object",
            "properties": {
                "value": {"anyOf": [{"type": "string"}, {"type": "null"}]}
            },
        }
        built = _create_model_from_schema(nullable_schema)
        assert built(value="test").value == "test"
        assert built(value=None).value is None

    def test_create_model_with_array(self):
        """Test creating a model with array type."""
        array_schema = {
            "type": "object",
            "properties": {
                "items": {"type": "array", "items": {"type": "string"}}
            },
        }
        built = _create_model_from_schema(array_schema)
        assert built(items=["a", "b", "c"]).items == ["a", "b", "c"]

    def test_model_json_schema_does_not_raise(self):
        """Test that generated model's model_json_schema() doesn't raise KeyError.

        This is the core issue from GitHub issue #4312 - complex schemas with
        internal $ref references would cause KeyError when calling model_json_schema().
        """
        geometry_schema = {
            "type": "object",
            "properties": {
                "geometry": {
                    "anyOf": [
                        {"type": "array", "items": {"type": "number"}},
                        {"type": "null"},
                    ]
                }
            },
        }
        built = _create_model_from_schema(_resolve_all_refs(geometry_schema))
        generated = built.model_json_schema()
        assert "properties" in generated
        assert "geometry" in generated["properties"]
|
||||
|
||||
|
||||
class TestCrewAIAdapterWithSchemaFix:
    """Tests for the CrewAIAdapterWithSchemaFix class."""

    def test_adapter_is_tool_adapter(self):
        """Test that the adapter is a valid ToolAdapter."""
        from mcpadapt.core import ToolAdapter

        adapter = CrewAIAdapterWithSchemaFix()
        assert isinstance(adapter, ToolAdapter)

    def test_async_adapt_raises_not_implemented(self):
        """Test that async_adapt raises NotImplementedError."""
        import asyncio

        adapter = CrewAIAdapterWithSchemaFix()
        with pytest.raises(NotImplementedError):
            # asyncio.run replaces the deprecated
            # get_event_loop().run_until_complete pattern, which emits a
            # DeprecationWarning when no event loop is running (Python 3.10+).
            asyncio.run(adapter.async_adapt(None, None))
|
||||
|
||||
Reference in New Issue
Block a user