mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 14:09:24 +00:00
fix: strip nulls from json schemas and simplify mcp args
This commit is contained in:
@@ -7,9 +7,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
create_model_from_schema,
|
||||
)
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -53,9 +51,7 @@ try:
|
||||
"""
|
||||
tool_name = sanitize_tool_name(mcp_tool.name)
|
||||
tool_description = mcp_tool.description or ""
|
||||
input_schema = mcp_tool.inputSchema
|
||||
|
||||
args_model = create_model_from_schema(input_schema)
|
||||
args_model = create_model_from_schema(mcp_tool.inputSchema)
|
||||
|
||||
class CrewAIMCPTool(BaseTool):
|
||||
name: str = tool_name
|
||||
@@ -63,25 +59,7 @@ try:
|
||||
args_schema: type[BaseModel] = args_model
|
||||
|
||||
def _run(self, **kwargs: Any) -> Any:
|
||||
filtered_kwargs: dict[str, Any] = {}
|
||||
schema_properties = input_schema.get("properties", {})
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if value is None and key in schema_properties:
|
||||
prop_schema = schema_properties[key]
|
||||
if isinstance(prop_schema.get("type"), list):
|
||||
if "null" in prop_schema["type"]:
|
||||
filtered_kwargs[key] = value
|
||||
elif "anyOf" in prop_schema:
|
||||
if any(
|
||||
opt.get("type") == "null"
|
||||
for opt in prop_schema["anyOf"]
|
||||
):
|
||||
filtered_kwargs[key] = value
|
||||
else:
|
||||
filtered_kwargs[key] = value
|
||||
|
||||
result = func(filtered_kwargs)
|
||||
result = func(kwargs)
|
||||
if len(result.content) == 1:
|
||||
first_content = result.content[0]
|
||||
if isinstance(first_content, TextContent):
|
||||
|
||||
@@ -19,7 +19,7 @@ from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
import datetime
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, Union
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypedDict, Union
|
||||
import uuid
|
||||
|
||||
import jsonref # type: ignore[import-untyped]
|
||||
@@ -70,6 +70,21 @@ else:
|
||||
EmailStr = str
|
||||
|
||||
|
||||
class JsonSchemaInfo(TypedDict):
|
||||
"""Inner structure for JSON schema metadata."""
|
||||
|
||||
name: str
|
||||
strict: Literal[True]
|
||||
schema: dict[str, Any]
|
||||
|
||||
|
||||
class ModelDescription(TypedDict):
|
||||
"""Return type for generate_model_description."""
|
||||
|
||||
type: Literal["json_schema"]
|
||||
json_schema: JsonSchemaInfo
|
||||
|
||||
|
||||
def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively resolve all local $refs in the given JSON Schema using $defs as the source.
|
||||
|
||||
@@ -360,7 +375,49 @@ def ensure_all_properties_required(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
return schema
|
||||
|
||||
|
||||
def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
|
||||
def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Remove null type from anyOf/type arrays.
|
||||
|
||||
Pydantic generates `T | None` for optional fields, which creates schemas with
|
||||
null in the type. However, for MCP tools, optional fields should be omitted
|
||||
entirely rather than sent as null. This function strips null from types.
|
||||
|
||||
Args:
|
||||
schema: JSON schema dictionary.
|
||||
|
||||
Returns:
|
||||
Modified schema with null types removed.
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
if "anyOf" in schema:
|
||||
any_of = schema["anyOf"]
|
||||
non_null = [opt for opt in any_of if opt.get("type") != "null"]
|
||||
if len(non_null) == 1:
|
||||
schema.pop("anyOf")
|
||||
schema.update(non_null[0])
|
||||
elif len(non_null) > 1:
|
||||
schema["anyOf"] = non_null
|
||||
|
||||
type_value = schema.get("type")
|
||||
if isinstance(type_value, list) and "null" in type_value:
|
||||
non_null_types = [t for t in type_value if t != "null"]
|
||||
if len(non_null_types) == 1:
|
||||
schema["type"] = non_null_types[0]
|
||||
elif len(non_null_types) > 1:
|
||||
schema["type"] = non_null_types
|
||||
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
strip_null_from_types(value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
strip_null_from_types(item)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def generate_model_description(model: type[BaseModel]) -> ModelDescription:
|
||||
"""Generate JSON schema description of a Pydantic model.
|
||||
|
||||
This function takes a Pydantic model class and returns its JSON schema,
|
||||
@@ -371,7 +428,7 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
|
||||
model: A Pydantic model class.
|
||||
|
||||
Returns:
|
||||
A JSON schema dictionary representation of the model.
|
||||
A ModelDescription with JSON schema representation of the model.
|
||||
"""
|
||||
json_schema = model.model_json_schema(ref_template="#/$defs/{model}")
|
||||
|
||||
@@ -385,6 +442,7 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
|
||||
json_schema = fix_discriminator_mappings(json_schema)
|
||||
json_schema = convert_oneof_to_anyof(json_schema)
|
||||
json_schema = ensure_all_properties_required(json_schema)
|
||||
json_schema = strip_null_from_types(json_schema)
|
||||
|
||||
return {
|
||||
"type": "json_schema",
|
||||
|
||||
Reference in New Issue
Block a user