fix: strip nulls from json schemas and simplify mcp args

This commit is contained in:
Greyson LaLonde
2026-02-03 15:12:46 -05:00
parent 6dbbc5724c
commit 0e4393d7b0
2 changed files with 64 additions and 28 deletions

View File

@@ -7,9 +7,7 @@ import logging
from typing import TYPE_CHECKING, Any
from crewai.tools import BaseTool
from crewai.utilities.pydantic_schema_utils import (
create_model_from_schema,
)
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
from crewai.utilities.string_utils import sanitize_tool_name
from pydantic import BaseModel
@@ -53,9 +51,7 @@ try:
"""
tool_name = sanitize_tool_name(mcp_tool.name)
tool_description = mcp_tool.description or ""
input_schema = mcp_tool.inputSchema
args_model = create_model_from_schema(input_schema)
args_model = create_model_from_schema(mcp_tool.inputSchema)
class CrewAIMCPTool(BaseTool):
name: str = tool_name
@@ -63,25 +59,7 @@ try:
args_schema: type[BaseModel] = args_model
def _run(self, **kwargs: Any) -> Any:
filtered_kwargs: dict[str, Any] = {}
schema_properties = input_schema.get("properties", {})
for key, value in kwargs.items():
if value is None and key in schema_properties:
prop_schema = schema_properties[key]
if isinstance(prop_schema.get("type"), list):
if "null" in prop_schema["type"]:
filtered_kwargs[key] = value
elif "anyOf" in prop_schema:
if any(
opt.get("type") == "null"
for opt in prop_schema["anyOf"]
):
filtered_kwargs[key] = value
else:
filtered_kwargs[key] = value
result = func(filtered_kwargs)
result = func(kwargs)
if len(result.content) == 1:
first_content = result.content[0]
if isinstance(first_content, TextContent):

View File

@@ -19,7 +19,7 @@ from collections.abc import Callable
from copy import deepcopy
import datetime
import logging
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, Union
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypedDict, Union
import uuid
import jsonref # type: ignore[import-untyped]
@@ -70,6 +70,21 @@ else:
EmailStr = str
class JsonSchemaInfo(TypedDict):
"""Inner structure for JSON schema metadata."""
name: str
strict: Literal[True]
schema: dict[str, Any]
class ModelDescription(TypedDict):
"""Return type for generate_model_description."""
type: Literal["json_schema"]
json_schema: JsonSchemaInfo
def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
"""Recursively resolve all local $refs in the given JSON Schema using $defs as the source.
@@ -360,7 +375,49 @@ def ensure_all_properties_required(schema: dict[str, Any]) -> dict[str, Any]:
return schema
def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]:
"""Remove null type from anyOf/type arrays.
Pydantic generates `T | None` for optional fields, which creates schemas with
null in the type. However, for MCP tools, optional fields should be omitted
entirely rather than sent as null. This function strips null from types.
Args:
schema: JSON schema dictionary.
Returns:
Modified schema with null types removed.
"""
if isinstance(schema, dict):
if "anyOf" in schema:
any_of = schema["anyOf"]
non_null = [opt for opt in any_of if opt.get("type") != "null"]
if len(non_null) == 1:
schema.pop("anyOf")
schema.update(non_null[0])
elif len(non_null) > 1:
schema["anyOf"] = non_null
type_value = schema.get("type")
if isinstance(type_value, list) and "null" in type_value:
non_null_types = [t for t in type_value if t != "null"]
if len(non_null_types) == 1:
schema["type"] = non_null_types[0]
elif len(non_null_types) > 1:
schema["type"] = non_null_types
for value in schema.values():
if isinstance(value, dict):
strip_null_from_types(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
strip_null_from_types(item)
return schema
def generate_model_description(model: type[BaseModel]) -> ModelDescription:
"""Generate JSON schema description of a Pydantic model.
This function takes a Pydantic model class and returns its JSON schema,
@@ -371,7 +428,7 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
model: A Pydantic model class.
Returns:
A JSON schema dictionary representation of the model.
A ModelDescription with JSON schema representation of the model.
"""
json_schema = model.model_json_schema(ref_template="#/$defs/{model}")
@@ -385,6 +442,7 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
json_schema = fix_discriminator_mappings(json_schema)
json_schema = convert_oneof_to_anyof(json_schema)
json_schema = ensure_all_properties_required(json_schema)
json_schema = strip_null_from_types(json_schema)
return {
"type": "json_schema",