mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: preserve null types in tool parameter schemas for LLM (#4579)
* fix: preserve null types in tool parameter schemas for LLM
Tool parameter schemas were stripping null from optional fields via
generate_model_description, forcing the LLM to provide non-null values
for fields.
Adds strip_null_types parameter to generate_model_description and passes False when generating tool
schemas, so optional fields keep anyOf: [{type: T}, {type: null}]
* Update lib/crewai/src/crewai/utilities/pydantic_schema_utils.py
Co-authored-by: Gabe Milani <gabriel@crewai.com>
---------
Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
Co-authored-by: Gabe Milani <gabriel@crewai.com>
Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
This commit is contained in:
@@ -168,7 +168,9 @@ def convert_tools_to_openai_schema(
|
||||
parameters: dict[str, Any] = {}
|
||||
if hasattr(tool, "args_schema") and tool.args_schema is not None:
|
||||
try:
|
||||
schema_output = generate_model_description(tool.args_schema)
|
||||
schema_output = generate_model_description(
|
||||
tool.args_schema, strip_null_types=False
|
||||
)
|
||||
parameters = schema_output.get("json_schema", {}).get("schema", {})
|
||||
# Remove title and description from schema root as they're redundant
|
||||
parameters.pop("title", None)
|
||||
|
||||
@@ -417,7 +417,11 @@ def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
return schema
|
||||
|
||||
|
||||
def generate_model_description(model: type[BaseModel]) -> ModelDescription:
|
||||
def generate_model_description(
|
||||
model: type[BaseModel],
|
||||
*,
|
||||
strip_null_types: bool = True,
|
||||
) -> ModelDescription:
|
||||
"""Generate JSON schema description of a Pydantic model.
|
||||
|
||||
This function takes a Pydantic model class and returns its JSON schema,
|
||||
@@ -426,6 +430,9 @@ def generate_model_description(model: type[BaseModel]) -> ModelDescription:
|
||||
|
||||
Args:
|
||||
model: A Pydantic model class.
|
||||
strip_null_types: When ``True`` (default), remove ``null`` from
|
||||
``anyOf`` / ``type`` arrays. Set to ``False`` to allow sending ``null`` for
|
||||
optional fields.
|
||||
|
||||
Returns:
|
||||
A ModelDescription with JSON schema representation of the model.
|
||||
@@ -442,7 +449,9 @@ def generate_model_description(model: type[BaseModel]) -> ModelDescription:
|
||||
json_schema = fix_discriminator_mappings(json_schema)
|
||||
json_schema = convert_oneof_to_anyof(json_schema)
|
||||
json_schema = ensure_all_properties_required(json_schema)
|
||||
json_schema = strip_null_from_types(json_schema)
|
||||
|
||||
if strip_null_types:
|
||||
json_schema = strip_null_from_types(json_schema)
|
||||
|
||||
return {
|
||||
"type": "json_schema",
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from typing import Any, Literal, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -235,6 +235,79 @@ def _make_mock_i18n() -> MagicMock:
|
||||
}.get(key, "")
|
||||
return mock_i18n
|
||||
|
||||
class MCPStyleInput(BaseModel):
|
||||
"""Input schema mimicking an MCP tool with optional fields."""
|
||||
|
||||
query: str = Field(description="Search query")
|
||||
filter_type: Optional[Literal["internal", "user"]] = Field(
|
||||
default=None, description="Filter type"
|
||||
)
|
||||
page_id: Optional[str] = Field(
|
||||
default=None, description="Page UUID"
|
||||
)
|
||||
|
||||
|
||||
class MCPStyleTool(BaseTool):
|
||||
"""A tool mimicking MCP tool schemas with optional fields."""
|
||||
|
||||
name: str = "mcp_search"
|
||||
description: str = "Search with optional filters"
|
||||
args_schema: type[BaseModel] = MCPStyleInput
|
||||
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
return "result"
|
||||
|
||||
|
||||
class TestOptionalFieldsPreserveNull:
|
||||
"""Tests that optional tool fields preserve null in the schema."""
|
||||
|
||||
def test_optional_string_allows_null(self) -> None:
|
||||
"""Optional[str] fields should include null in the schema so the LLM
|
||||
can send null instead of being forced to guess a value."""
|
||||
tools = [MCPStyleTool()]
|
||||
schemas, _ = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
page_id_prop = params["properties"]["page_id"]
|
||||
|
||||
assert "anyOf" in page_id_prop
|
||||
type_options = [opt.get("type") for opt in page_id_prop["anyOf"]]
|
||||
assert "string" in type_options
|
||||
assert "null" in type_options
|
||||
|
||||
def test_optional_literal_allows_null(self) -> None:
|
||||
"""Optional[Literal[...]] fields should include null."""
|
||||
tools = [MCPStyleTool()]
|
||||
schemas, _ = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
filter_prop = params["properties"]["filter_type"]
|
||||
|
||||
assert "anyOf" in filter_prop
|
||||
has_null = any(opt.get("type") == "null" for opt in filter_prop["anyOf"])
|
||||
assert has_null
|
||||
|
||||
def test_required_field_stays_non_null(self) -> None:
|
||||
"""Required fields without Optional should NOT have null."""
|
||||
tools = [MCPStyleTool()]
|
||||
schemas, _ = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
query_prop = params["properties"]["query"]
|
||||
|
||||
assert query_prop.get("type") == "string"
|
||||
assert "anyOf" not in query_prop
|
||||
|
||||
def test_all_fields_in_required_for_strict_mode(self) -> None:
|
||||
"""All fields (including optional) must be in required for strict mode."""
|
||||
tools = [MCPStyleTool()]
|
||||
schemas, _ = convert_tools_to_openai_schema(tools)
|
||||
|
||||
params = schemas[0]["function"]["parameters"]
|
||||
assert "query" in params["required"]
|
||||
assert "filter_type" in params["required"]
|
||||
assert "page_id" in params["required"]
|
||||
|
||||
|
||||
class TestSummarizeMessages:
|
||||
"""Tests for summarize_messages function."""
|
||||
|
||||
Reference in New Issue
Block a user