From 09e3b81ca39ba113187080f5a0fb8a2560fe5448 Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Thu, 26 Feb 2026 13:51:34 -0300 Subject: [PATCH] fix: preserve null types in tool parameter schemas for LLM (#4579) * fix: preserve null types in tool parameter schemas for LLM Tool parameter schemas were stripping null from optional fields via generate_model_description, forcing the LLM to provide non-null values for fields. Adds strip_null_types parameter to generate_model_description and passes False when generating tool schemas, so optional fields keep anyOf: [{type: T}, {type: null}] * Update lib/crewai/src/crewai/utilities/pydantic_schema_utils.py Co-authored-by: Gabe Milani --------- Co-authored-by: Greyson LaLonde Co-authored-by: Gabe Milani Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> --- .../src/crewai/utilities/agent_utils.py | 4 +- .../crewai/utilities/pydantic_schema_utils.py | 13 +++- .../tests/utilities/test_agent_utils.py | 75 ++++++++++++++++++- 3 files changed, 88 insertions(+), 4 deletions(-) diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py index 7cad2ad67..a1e33168d 100644 --- a/lib/crewai/src/crewai/utilities/agent_utils.py +++ b/lib/crewai/src/crewai/utilities/agent_utils.py @@ -168,7 +168,9 @@ def convert_tools_to_openai_schema( parameters: dict[str, Any] = {} if hasattr(tool, "args_schema") and tool.args_schema is not None: try: - schema_output = generate_model_description(tool.args_schema) + schema_output = generate_model_description( + tool.args_schema, strip_null_types=False + ) parameters = schema_output.get("json_schema", {}).get("schema", {}) # Remove title and description from schema root as they're redundant parameters.pop("title", None) diff --git a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py index 191f38c35..4548ab9ce 100644 --- a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py +++ b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py @@ -417,7 +417,11 @@ def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]: return schema -def generate_model_description(model: type[BaseModel]) -> ModelDescription: +def generate_model_description( + model: type[BaseModel], + *, + strip_null_types: bool = True, +) -> ModelDescription: """Generate JSON schema description of a Pydantic model. This function takes a Pydantic model class and returns its JSON schema, @@ -426,6 +430,9 @@ def generate_model_description(model: type[BaseModel]) -> ModelDescription: Args: model: A Pydantic model class. + strip_null_types: When ``True`` (default), remove ``null`` from + ``anyOf`` / ``type`` arrays. Set to ``False`` to allow sending ``null`` for + optional fields. Returns: A ModelDescription with JSON schema representation of the model. @@ -442,7 +449,9 @@ def generate_model_description(model: type[BaseModel]) -> ModelDescription: json_schema = fix_discriminator_mappings(json_schema) json_schema = convert_oneof_to_anyof(json_schema) json_schema = ensure_all_properties_required(json_schema) - json_schema = strip_null_from_types(json_schema) + + if strip_null_types: + json_schema = strip_null_from_types(json_schema) return { "type": "json_schema", diff --git a/lib/crewai/tests/utilities/test_agent_utils.py b/lib/crewai/tests/utilities/test_agent_utils.py index 8e3093219..43477c25e 100644 --- a/lib/crewai/tests/utilities/test_agent_utils.py +++ b/lib/crewai/tests/utilities/test_agent_utils.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import Any +from typing import Any, Literal, Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -235,6 +235,79 @@ def _make_mock_i18n() -> MagicMock: }.get(key, "") return mock_i18n +class MCPStyleInput(BaseModel): + """Input schema mimicking an MCP tool with optional fields.""" + + query: str = Field(description="Search query") + filter_type: Optional[Literal["internal", "user"]] = Field( + default=None, description="Filter type" + ) + page_id: Optional[str] = Field( + default=None, description="Page UUID" + ) + + +class MCPStyleTool(BaseTool): + """A tool mimicking MCP tool schemas with optional fields.""" + + name: str = "mcp_search" + description: str = "Search with optional filters" + args_schema: type[BaseModel] = MCPStyleInput + + def _run(self, **kwargs: Any) -> str: + return "result" + + +class TestOptionalFieldsPreserveNull: + """Tests that optional tool fields preserve null in the schema.""" + + def test_optional_string_allows_null(self) -> None: + """Optional[str] fields should include null in the schema so the LLM + can send null instead of being forced to guess a value.""" + tools = [MCPStyleTool()] + schemas, _ = convert_tools_to_openai_schema(tools) + + params = schemas[0]["function"]["parameters"] + page_id_prop = params["properties"]["page_id"] + + assert "anyOf" in page_id_prop + type_options = [opt.get("type") for opt in page_id_prop["anyOf"]] + assert "string" in type_options + assert "null" in type_options + + def test_optional_literal_allows_null(self) -> None: + """Optional[Literal[...]] fields should include null.""" + tools = [MCPStyleTool()] + schemas, _ = convert_tools_to_openai_schema(tools) + + params = schemas[0]["function"]["parameters"] + filter_prop = params["properties"]["filter_type"] + + assert "anyOf" in filter_prop + has_null = any(opt.get("type") == "null" for opt in filter_prop["anyOf"]) + assert has_null + + def test_required_field_stays_non_null(self) -> None: + """Required fields without Optional should NOT have null.""" + tools = [MCPStyleTool()] + schemas, _ = convert_tools_to_openai_schema(tools) + + params = schemas[0]["function"]["parameters"] + query_prop = params["properties"]["query"] + + assert query_prop.get("type") == "string" + assert "anyOf" not in query_prop + + def test_all_fields_in_required_for_strict_mode(self) -> None: + """All fields (including optional) must be in required for strict mode.""" + tools = [MCPStyleTool()] + schemas, _ = convert_tools_to_openai_schema(tools) + + params = schemas[0]["function"]["parameters"] + assert "query" in params["required"] + assert "filter_type" in params["required"] + assert "page_id" in params["required"] + class TestSummarizeMessages: """Tests for summarize_messages function."""