mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-10 21:12:37 +00:00
fix: sanitize tool schemas for strict mode
Pydantic schemas intermittently fail strict tool-use on openai, anthropic, and bedrock. All three reject nested objects missing additionalProperties: false, and anthropic also rejects keywords like minLength and top-level anyOf. Adds per-provider sanitizers that inline refs, close objects, mark every property required, preserve nullable unions, and strip keywords each grammar compiler rejects. Verified against real bedrock, anthropic, and openai.
This commit is contained in:
@@ -11,10 +11,14 @@ from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
sanitize_tool_params_for_anthropic_strict,
|
||||
)
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
@@ -473,10 +477,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
continue
|
||||
|
||||
try:
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
name, description, parameters = safe_tool_conversion(tool, "Anthropic")
|
||||
except (ImportError, KeyError, ValueError) as e:
|
||||
except (KeyError, ValueError) as e:
|
||||
logging.error(f"Error converting tool to Anthropic format: {e}")
|
||||
raise e
|
||||
|
||||
@@ -485,8 +487,15 @@ class AnthropicCompletion(BaseLLM):
|
||||
"description": description,
|
||||
}
|
||||
|
||||
func_info = tool.get("function", {})
|
||||
strict_enabled = bool(func_info.get("strict"))
|
||||
|
||||
if parameters and isinstance(parameters, dict):
|
||||
anthropic_tool["input_schema"] = parameters
|
||||
anthropic_tool["input_schema"] = (
|
||||
sanitize_tool_params_for_anthropic_strict(parameters)
|
||||
if strict_enabled
|
||||
else parameters
|
||||
)
|
||||
else:
|
||||
anthropic_tool["input_schema"] = {
|
||||
"type": "object",
|
||||
@@ -494,8 +503,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
"required": [],
|
||||
}
|
||||
|
||||
func_info = tool.get("function", {})
|
||||
if func_info.get("strict"):
|
||||
if strict_enabled:
|
||||
anthropic_tool["strict"] = True
|
||||
|
||||
anthropic_tools.append(anthropic_tool)
|
||||
|
||||
@@ -12,11 +12,15 @@ from typing_extensions import Required
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
generate_model_description,
|
||||
sanitize_tool_params_for_bedrock_strict,
|
||||
)
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
@@ -1949,8 +1953,6 @@ class BedrockCompletion(BaseLLM):
|
||||
tools: list[dict[str, Any]],
|
||||
) -> list[ConverseToolTypeDef]:
|
||||
"""Convert CrewAI tools to Converse API format following AWS specification."""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
converse_tools: list[ConverseToolTypeDef] = []
|
||||
|
||||
for tool in tools:
|
||||
@@ -1962,12 +1964,19 @@ class BedrockCompletion(BaseLLM):
|
||||
"description": description,
|
||||
}
|
||||
|
||||
func_info = tool.get("function", {})
|
||||
strict_enabled = bool(func_info.get("strict"))
|
||||
|
||||
if parameters and isinstance(parameters, dict):
|
||||
input_schema: ToolInputSchema = {"json": parameters}
|
||||
schema_params = (
|
||||
sanitize_tool_params_for_bedrock_strict(parameters)
|
||||
if strict_enabled
|
||||
else parameters
|
||||
)
|
||||
input_schema: ToolInputSchema = {"json": schema_params}
|
||||
tool_spec["inputSchema"] = input_schema
|
||||
|
||||
func_info = tool.get("function", {})
|
||||
if func_info.get("strict"):
|
||||
if strict_enabled:
|
||||
tool_spec["strict"] = True
|
||||
|
||||
converse_tool: ConverseToolTypeDef = {"toolSpec": tool_spec}
|
||||
|
||||
@@ -32,11 +32,15 @@ from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
generate_model_description,
|
||||
sanitize_tool_params_for_openai_strict,
|
||||
)
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
@@ -764,8 +768,6 @@ class OpenAICompletion(BaseLLM):
|
||||
"function": {"name": "...", "description": "...", "parameters": {...}}
|
||||
}
|
||||
"""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
responses_tools = []
|
||||
|
||||
for tool in tools:
|
||||
@@ -1548,11 +1550,6 @@ class OpenAICompletion(BaseLLM):
|
||||
self, tools: list[dict[str, BaseTool]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert CrewAI tool format to OpenAI function calling format."""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
force_additional_properties_false,
|
||||
)
|
||||
|
||||
openai_tools = []
|
||||
|
||||
for tool in tools:
|
||||
@@ -1571,8 +1568,9 @@ class OpenAICompletion(BaseLLM):
|
||||
params_dict = (
|
||||
parameters if isinstance(parameters, dict) else dict(parameters)
|
||||
)
|
||||
params_dict = force_additional_properties_false(params_dict)
|
||||
openai_tool["function"]["parameters"] = params_dict
|
||||
openai_tool["function"]["parameters"] = (
|
||||
sanitize_tool_params_for_openai_strict(params_dict)
|
||||
)
|
||||
|
||||
openai_tools.append(openai_tool)
|
||||
return openai_tools
|
||||
|
||||
@@ -19,7 +19,7 @@ from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
import datetime
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypedDict, Union, cast
|
||||
import uuid
|
||||
|
||||
import jsonref # type: ignore[import-untyped]
|
||||
@@ -417,6 +417,119 @@ def strip_null_from_types(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
return schema
|
||||
|
||||
|
||||
_STRICT_METADATA_KEYS: Final[tuple[str, ...]] = (
|
||||
"title",
|
||||
"default",
|
||||
"examples",
|
||||
"example",
|
||||
"$comment",
|
||||
"readOnly",
|
||||
"writeOnly",
|
||||
"deprecated",
|
||||
)
|
||||
|
||||
_ANTHROPIC_UNSUPPORTED_CONSTRAINTS: Final[tuple[str, ...]] = (
|
||||
"minimum",
|
||||
"maximum",
|
||||
"exclusiveMinimum",
|
||||
"exclusiveMaximum",
|
||||
"multipleOf",
|
||||
"minLength",
|
||||
"maxLength",
|
||||
"pattern",
|
||||
"minItems",
|
||||
"maxItems",
|
||||
"uniqueItems",
|
||||
"minContains",
|
||||
"maxContains",
|
||||
"minProperties",
|
||||
"maxProperties",
|
||||
"patternProperties",
|
||||
"propertyNames",
|
||||
"dependentRequired",
|
||||
"dependentSchemas",
|
||||
)
|
||||
|
||||
|
||||
def _strip_keys_recursive(d: Any, keys: tuple[str, ...]) -> Any:
|
||||
"""Recursively delete a fixed set of keys from a schema."""
|
||||
if isinstance(d, dict):
|
||||
for key in keys:
|
||||
d.pop(key, None)
|
||||
for v in d.values():
|
||||
_strip_keys_recursive(v, keys)
|
||||
elif isinstance(d, list):
|
||||
for i in d:
|
||||
_strip_keys_recursive(i, keys)
|
||||
return d
|
||||
|
||||
|
||||
def lift_top_level_anyof(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Unwrap a top-level anyOf/oneOf/allOf wrapping a single object variant.
|
||||
|
||||
Anthropic's strict ``input_schema`` rejects top-level union keywords. When
|
||||
exactly one variant is an object schema, lift it so the root is a plain
|
||||
object; otherwise leave the schema alone.
|
||||
"""
|
||||
for key in ("anyOf", "oneOf", "allOf"):
|
||||
variants = schema.get(key)
|
||||
if not isinstance(variants, list):
|
||||
continue
|
||||
object_variants = [
|
||||
v for v in variants if isinstance(v, dict) and v.get("type") == "object"
|
||||
]
|
||||
if len(object_variants) == 1:
|
||||
lifted = deepcopy(object_variants[0])
|
||||
schema.pop(key)
|
||||
schema.update(lifted)
|
||||
break
|
||||
return schema
|
||||
|
||||
|
||||
def _common_strict_pipeline(params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Shared strict sanitization: inline refs, close objects, require all properties."""
|
||||
sanitized = resolve_refs(deepcopy(params))
|
||||
sanitized.pop("$defs", None)
|
||||
sanitized = convert_oneof_to_anyof(sanitized)
|
||||
sanitized = ensure_type_in_schemas(sanitized)
|
||||
sanitized = force_additional_properties_false(sanitized)
|
||||
sanitized = ensure_all_properties_required(sanitized)
|
||||
return cast(dict[str, Any], _strip_keys_recursive(sanitized, _STRICT_METADATA_KEYS))
|
||||
|
||||
|
||||
def sanitize_tool_params_for_openai_strict(
|
||||
params: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Sanitize a JSON schema for OpenAI strict function calling."""
|
||||
if not isinstance(params, dict):
|
||||
return params
|
||||
return cast(
|
||||
dict[str, Any], strip_unsupported_formats(_common_strict_pipeline(params))
|
||||
)
|
||||
|
||||
|
||||
def sanitize_tool_params_for_anthropic_strict(
|
||||
params: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Sanitize a JSON schema for Anthropic strict tool use."""
|
||||
if not isinstance(params, dict):
|
||||
return params
|
||||
sanitized = lift_top_level_anyof(_common_strict_pipeline(params))
|
||||
sanitized = _strip_keys_recursive(sanitized, _ANTHROPIC_UNSUPPORTED_CONSTRAINTS)
|
||||
return cast(dict[str, Any], strip_unsupported_formats(sanitized))
|
||||
|
||||
|
||||
def sanitize_tool_params_for_bedrock_strict(
|
||||
params: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Sanitize a JSON schema for Bedrock Converse strict tool use."""
|
||||
if not isinstance(params, dict):
|
||||
return params
|
||||
sanitized = lift_top_level_anyof(_common_strict_pipeline(params))
|
||||
sanitized = _strip_keys_recursive(sanitized, _ANTHROPIC_UNSUPPORTED_CONSTRAINTS)
|
||||
return cast(dict[str, Any], strip_unsupported_formats(sanitized))
|
||||
|
||||
|
||||
def generate_model_description(
|
||||
model: type[BaseModel],
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user