fix: sanitize tool schemas for strict mode

Pydantic schemas intermittently fail strict tool-use on openai, anthropic,
and bedrock. All three reject nested objects missing additionalProperties:
false, and anthropic also rejects keywords like minLength and top-level
anyOf. Adds per-provider sanitizers that inline refs, close objects, mark
every property required, preserve nullable unions, and strip keywords each
grammar compiler rejects. Verified against real bedrock, anthropic, and
openai.
This commit is contained in:
Greyson LaLonde
2026-04-11 04:55:09 +08:00
parent 298fc7b9c0
commit 40f0b3754b
4 changed files with 151 additions and 23 deletions

View File

@@ -11,10 +11,14 @@ from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.llms.providers.utils.common import safe_tool_conversion
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.pydantic_schema_utils import (
sanitize_tool_params_for_anthropic_strict,
)
from crewai.utilities.types import LLMMessage
@@ -473,10 +477,8 @@ class AnthropicCompletion(BaseLLM):
continue
try:
from crewai.llms.providers.utils.common import safe_tool_conversion
name, description, parameters = safe_tool_conversion(tool, "Anthropic")
except (ImportError, KeyError, ValueError) as e:
except (KeyError, ValueError) as e:
logging.error(f"Error converting tool to Anthropic format: {e}")
raise e
@@ -485,8 +487,15 @@ class AnthropicCompletion(BaseLLM):
"description": description,
}
func_info = tool.get("function", {})
strict_enabled = bool(func_info.get("strict"))
if parameters and isinstance(parameters, dict):
anthropic_tool["input_schema"] = parameters
anthropic_tool["input_schema"] = (
sanitize_tool_params_for_anthropic_strict(parameters)
if strict_enabled
else parameters
)
else:
anthropic_tool["input_schema"] = {
"type": "object",
@@ -494,8 +503,7 @@ class AnthropicCompletion(BaseLLM):
"required": [],
}
func_info = tool.get("function", {})
if func_info.get("strict"):
if strict_enabled:
anthropic_tool["strict"] = True
anthropic_tools.append(anthropic_tool)

View File

@@ -12,11 +12,15 @@ from typing_extensions import Required
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM, llm_call_context
from crewai.llms.providers.utils.common import safe_tool_conversion
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.pydantic_schema_utils import (
generate_model_description,
sanitize_tool_params_for_bedrock_strict,
)
from crewai.utilities.types import LLMMessage
@@ -1949,8 +1953,6 @@ class BedrockCompletion(BaseLLM):
tools: list[dict[str, Any]],
) -> list[ConverseToolTypeDef]:
"""Convert CrewAI tools to Converse API format following AWS specification."""
from crewai.llms.providers.utils.common import safe_tool_conversion
converse_tools: list[ConverseToolTypeDef] = []
for tool in tools:
@@ -1962,12 +1964,19 @@ class BedrockCompletion(BaseLLM):
"description": description,
}
func_info = tool.get("function", {})
strict_enabled = bool(func_info.get("strict"))
if parameters and isinstance(parameters, dict):
input_schema: ToolInputSchema = {"json": parameters}
schema_params = (
sanitize_tool_params_for_bedrock_strict(parameters)
if strict_enabled
else parameters
)
input_schema: ToolInputSchema = {"json": schema_params}
tool_spec["inputSchema"] = input_schema
func_info = tool.get("function", {})
if func_info.get("strict"):
if strict_enabled:
tool_spec["strict"] = True
converse_tool: ConverseToolTypeDef = {"toolSpec": tool_spec}

View File

@@ -32,11 +32,15 @@ from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.llms.providers.utils.common import safe_tool_conversion
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.pydantic_schema_utils import (
generate_model_description,
sanitize_tool_params_for_openai_strict,
)
from crewai.utilities.types import LLMMessage
@@ -764,8 +768,6 @@ class OpenAICompletion(BaseLLM):
"function": {"name": "...", "description": "...", "parameters": {...}}
}
"""
from crewai.llms.providers.utils.common import safe_tool_conversion
responses_tools = []
for tool in tools:
@@ -1548,11 +1550,6 @@ class OpenAICompletion(BaseLLM):
self, tools: list[dict[str, BaseTool]]
) -> list[dict[str, Any]]:
"""Convert CrewAI tool format to OpenAI function calling format."""
from crewai.llms.providers.utils.common import safe_tool_conversion
from crewai.utilities.pydantic_schema_utils import (
force_additional_properties_false,
)
openai_tools = []
for tool in tools:
@@ -1571,8 +1568,9 @@ class OpenAICompletion(BaseLLM):
params_dict = (
parameters if isinstance(parameters, dict) else dict(parameters)
)
params_dict = force_additional_properties_false(params_dict)
openai_tool["function"]["parameters"] = params_dict
openai_tool["function"]["parameters"] = (
sanitize_tool_params_for_openai_strict(params_dict)
)
openai_tools.append(openai_tool)
return openai_tools

View File

@@ -19,7 +19,7 @@ from collections.abc import Callable
from copy import deepcopy
import datetime
import logging
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypedDict, Union
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypedDict, Union, cast
import uuid
import jsonref # type: ignore[import-untyped]
@@ -417,6 +417,119 @@ def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]:
return schema
_STRICT_METADATA_KEYS: Final[tuple[str, ...]] = (
"title",
"default",
"examples",
"example",
"$comment",
"readOnly",
"writeOnly",
"deprecated",
)
_ANTHROPIC_UNSUPPORTED_CONSTRAINTS: Final[tuple[str, ...]] = (
"minimum",
"maximum",
"exclusiveMinimum",
"exclusiveMaximum",
"multipleOf",
"minLength",
"maxLength",
"pattern",
"minItems",
"maxItems",
"uniqueItems",
"minContains",
"maxContains",
"minProperties",
"maxProperties",
"patternProperties",
"propertyNames",
"dependentRequired",
"dependentSchemas",
)
def _strip_keys_recursive(d: Any, keys: tuple[str, ...]) -> Any:
"""Recursively delete a fixed set of keys from a schema."""
if isinstance(d, dict):
for key in keys:
d.pop(key, None)
for v in d.values():
_strip_keys_recursive(v, keys)
elif isinstance(d, list):
for i in d:
_strip_keys_recursive(i, keys)
return d
def lift_top_level_anyof(schema: dict[str, Any]) -> dict[str, Any]:
"""Unwrap a top-level anyOf/oneOf/allOf wrapping a single object variant.
Anthropic's strict ``input_schema`` rejects top-level union keywords. When
exactly one variant is an object schema, lift it so the root is a plain
object; otherwise leave the schema alone.
"""
for key in ("anyOf", "oneOf", "allOf"):
variants = schema.get(key)
if not isinstance(variants, list):
continue
object_variants = [
v for v in variants if isinstance(v, dict) and v.get("type") == "object"
]
if len(object_variants) == 1:
lifted = deepcopy(object_variants[0])
schema.pop(key)
schema.update(lifted)
break
return schema
def _common_strict_pipeline(params: dict[str, Any]) -> dict[str, Any]:
"""Shared strict sanitization: inline refs, close objects, require all properties."""
sanitized = resolve_refs(deepcopy(params))
sanitized.pop("$defs", None)
sanitized = convert_oneof_to_anyof(sanitized)
sanitized = ensure_type_in_schemas(sanitized)
sanitized = force_additional_properties_false(sanitized)
sanitized = ensure_all_properties_required(sanitized)
return cast(dict[str, Any], _strip_keys_recursive(sanitized, _STRICT_METADATA_KEYS))
def sanitize_tool_params_for_openai_strict(
params: dict[str, Any],
) -> dict[str, Any]:
"""Sanitize a JSON schema for OpenAI strict function calling."""
if not isinstance(params, dict):
return params
return cast(
dict[str, Any], strip_unsupported_formats(_common_strict_pipeline(params))
)
def sanitize_tool_params_for_anthropic_strict(
params: dict[str, Any],
) -> dict[str, Any]:
"""Sanitize a JSON schema for Anthropic strict tool use."""
if not isinstance(params, dict):
return params
sanitized = lift_top_level_anyof(_common_strict_pipeline(params))
sanitized = _strip_keys_recursive(sanitized, _ANTHROPIC_UNSUPPORTED_CONSTRAINTS)
return cast(dict[str, Any], strip_unsupported_formats(sanitized))
def sanitize_tool_params_for_bedrock_strict(
params: dict[str, Any],
) -> dict[str, Any]:
"""Sanitize a JSON schema for Bedrock Converse strict tool use."""
if not isinstance(params, dict):
return params
sanitized = lift_top_level_anyof(_common_strict_pipeline(params))
sanitized = _strip_keys_recursive(sanitized, _ANTHROPIC_UNSUPPORTED_CONSTRAINTS)
return cast(dict[str, Any], strip_unsupported_formats(sanitized))
def generate_model_description(
model: type[BaseModel],
*,