mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-06 17:52:35 +00:00
Fixes #4472 — MCP tools (and any tool with args_schema) produced JSON schemas incompatible with Gemini and Bedrock/Anthropic because extract_tool_info() called generate_model_description() which applies OpenAI-specific transformations (additionalProperties: false, forcing all properties required, title fields). Added generate_tool_parameters_schema() that produces clean schemas without those provider-specific metadata fields. Updated extract_tool_info() to use it for the args_schema path. Co-Authored-By: João <joao@crewai.com>
288 lines
11 KiB
Python
288 lines
11 KiB
Python
"""Tests for tool schema compatibility across LLM providers.
|
|
|
|
Covers issue #4472: MCP tools producing JSON schemas that are incompatible
|
|
with Bedrock (Claude) and Gemini when using ``generate_model_description``
|
|
(OpenAI-specific) instead of the provider-agnostic
|
|
``generate_tool_parameters_schema``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Optional
|
|
|
|
import pytest
|
|
from pydantic import BaseModel, Field
|
|
|
|
from crewai.llms.providers.utils.common import extract_tool_info, safe_tool_conversion
|
|
from crewai.tools.base_tool import BaseTool
|
|
from crewai.utilities.pydantic_schema_utils import (
|
|
_strip_schema_metadata,
|
|
generate_model_description,
|
|
generate_tool_parameters_schema,
|
|
)
|
|
|
|
|
|
class FirewallToolInput(BaseModel):
|
|
"""Simulates an MCP tool input schema similar to the one in issue #4472."""
|
|
|
|
hostname: str = Field(description="The firewall IP address")
|
|
command: Optional[str] = Field(default=None, description="The CLI command string")
|
|
|
|
|
|
class SimpleToolInput(BaseModel):
|
|
"""A simple tool input with only required fields."""
|
|
|
|
text: str = Field(description="Input text")
|
|
|
|
|
|
class ComplexToolInput(BaseModel):
|
|
"""Tool input with nested objects and arrays."""
|
|
|
|
query: str = Field(description="Search query")
|
|
max_results: int = Field(default=10, description="Maximum number of results")
|
|
tags: list[str] = Field(default_factory=list, description="Filter tags")
|
|
|
|
|
|
class NestedModel(BaseModel):
|
|
"""A nested model for testing ref resolution."""
|
|
|
|
name: str
|
|
value: int
|
|
|
|
|
|
class ToolWithNestedInput(BaseModel):
|
|
"""Tool input containing a nested model for ref resolution."""
|
|
|
|
config: NestedModel = Field(description="Configuration object")
|
|
enabled: bool = Field(default=True, description="Whether enabled")
|
|
|
|
|
|
def _has_key_recursive(schema: Any, key: str) -> bool:
|
|
"""Check if a key exists anywhere in a nested dict/list structure."""
|
|
if isinstance(schema, dict):
|
|
if key in schema:
|
|
return True
|
|
return any(_has_key_recursive(v, key) for v in schema.values())
|
|
if isinstance(schema, list):
|
|
return any(_has_key_recursive(item, key) for item in schema)
|
|
return False
|
|
|
|
|
|
class TestGenerateToolParametersSchema:
|
|
"""Tests for generate_tool_parameters_schema — the provider-agnostic path."""
|
|
|
|
def test_no_title_field(self) -> None:
|
|
schema = generate_tool_parameters_schema(FirewallToolInput)
|
|
assert not _has_key_recursive(schema, "title")
|
|
|
|
def test_no_additional_properties(self) -> None:
|
|
schema = generate_tool_parameters_schema(FirewallToolInput)
|
|
assert not _has_key_recursive(schema, "additionalProperties")
|
|
|
|
def test_no_default_field(self) -> None:
|
|
schema = generate_tool_parameters_schema(FirewallToolInput)
|
|
assert not _has_key_recursive(schema, "default")
|
|
|
|
def test_preserves_type(self) -> None:
|
|
schema = generate_tool_parameters_schema(FirewallToolInput)
|
|
assert schema["type"] == "object"
|
|
|
|
def test_preserves_properties(self) -> None:
|
|
schema = generate_tool_parameters_schema(FirewallToolInput)
|
|
assert "hostname" in schema["properties"]
|
|
assert "command" in schema["properties"]
|
|
|
|
def test_preserves_required_only_required_fields(self) -> None:
|
|
schema = generate_tool_parameters_schema(FirewallToolInput)
|
|
assert "hostname" in schema.get("required", [])
|
|
assert "command" not in schema.get("required", [])
|
|
|
|
def test_preserves_description(self) -> None:
|
|
schema = generate_tool_parameters_schema(FirewallToolInput)
|
|
assert schema["properties"]["hostname"]["description"] == "The firewall IP address"
|
|
|
|
def test_optional_field_type_is_string(self) -> None:
|
|
"""Optional fields should have null stripped and resolve to their base type."""
|
|
schema = generate_tool_parameters_schema(FirewallToolInput)
|
|
cmd = schema["properties"]["command"]
|
|
assert cmd.get("type") == "string"
|
|
assert "anyOf" not in cmd
|
|
|
|
def test_simple_model(self) -> None:
|
|
schema = generate_tool_parameters_schema(SimpleToolInput)
|
|
assert schema["type"] == "object"
|
|
assert "text" in schema["properties"]
|
|
assert schema["properties"]["text"]["type"] == "string"
|
|
assert "text" in schema.get("required", [])
|
|
assert not _has_key_recursive(schema, "title")
|
|
assert not _has_key_recursive(schema, "additionalProperties")
|
|
|
|
def test_complex_model_with_array(self) -> None:
|
|
schema = generate_tool_parameters_schema(ComplexToolInput)
|
|
assert "tags" in schema["properties"]
|
|
tags = schema["properties"]["tags"]
|
|
assert tags["type"] == "array"
|
|
assert tags["items"]["type"] == "string"
|
|
assert not _has_key_recursive(schema, "title")
|
|
assert not _has_key_recursive(schema, "default")
|
|
|
|
def test_nested_model_refs_resolved(self) -> None:
|
|
schema = generate_tool_parameters_schema(ToolWithNestedInput)
|
|
assert not _has_key_recursive(schema, "$ref")
|
|
assert "$defs" not in schema
|
|
config_props = schema["properties"]["config"]
|
|
assert "properties" in config_props
|
|
assert "name" in config_props["properties"]
|
|
assert "value" in config_props["properties"]
|
|
|
|
|
|
class TestStripSchemaMetadata:
|
|
"""Unit tests for the _strip_schema_metadata helper."""
|
|
|
|
def test_strips_title(self) -> None:
|
|
schema: dict[str, Any] = {"type": "object", "title": "Foo", "properties": {}}
|
|
result = _strip_schema_metadata(schema)
|
|
assert "title" not in result
|
|
|
|
def test_strips_default(self) -> None:
|
|
schema: dict[str, Any] = {"type": "string", "default": "bar"}
|
|
result = _strip_schema_metadata(schema)
|
|
assert "default" not in result
|
|
|
|
def test_strips_additional_properties(self) -> None:
|
|
schema: dict[str, Any] = {"type": "object", "additionalProperties": False}
|
|
result = _strip_schema_metadata(schema)
|
|
assert "additionalProperties" not in result
|
|
|
|
def test_strips_nested(self) -> None:
|
|
schema: dict[str, Any] = {
|
|
"type": "object",
|
|
"title": "Root",
|
|
"properties": {
|
|
"x": {"type": "string", "title": "X", "default": "a"},
|
|
"nested": {
|
|
"type": "object",
|
|
"title": "Inner",
|
|
"additionalProperties": False,
|
|
"properties": {"y": {"type": "integer", "title": "Y"}},
|
|
},
|
|
},
|
|
}
|
|
result = _strip_schema_metadata(schema)
|
|
assert not _has_key_recursive(result, "title")
|
|
assert not _has_key_recursive(result, "default")
|
|
assert not _has_key_recursive(result, "additionalProperties")
|
|
|
|
def test_preserves_type_and_description(self) -> None:
|
|
schema: dict[str, Any] = {
|
|
"type": "string",
|
|
"description": "hello",
|
|
"title": "Foo",
|
|
}
|
|
result = _strip_schema_metadata(schema)
|
|
assert result["type"] == "string"
|
|
assert result["description"] == "hello"
|
|
|
|
|
|
class TestExtractToolInfoUsesCleanSchema:
|
|
"""Verify that extract_tool_info uses generate_tool_parameters_schema."""
|
|
|
|
def test_args_schema_produces_clean_schema(self) -> None:
|
|
tool_dict: dict[str, Any] = {
|
|
"name": "firewall_tool",
|
|
"description": "Run firewall commands",
|
|
"args_schema": FirewallToolInput,
|
|
}
|
|
name, description, parameters = extract_tool_info(tool_dict)
|
|
assert name == "firewall_tool"
|
|
assert not _has_key_recursive(parameters, "title")
|
|
assert not _has_key_recursive(parameters, "additionalProperties")
|
|
assert not _has_key_recursive(parameters, "default")
|
|
assert "hostname" in parameters.get("required", [])
|
|
assert "command" not in parameters.get("required", [])
|
|
|
|
def test_openai_format_uses_inline_parameters(self) -> None:
|
|
tool_dict: dict[str, Any] = {
|
|
"type": "function",
|
|
"function": {
|
|
"name": "my_tool",
|
|
"description": "desc",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {"x": {"type": "string"}},
|
|
},
|
|
},
|
|
}
|
|
name, description, parameters = extract_tool_info(tool_dict)
|
|
assert name == "my_tool"
|
|
assert parameters == {
|
|
"type": "object",
|
|
"properties": {"x": {"type": "string"}},
|
|
}
|
|
|
|
|
|
class TestSafeToolConversionForProviders:
|
|
"""Integration tests verifying safe_tool_conversion produces provider-compatible schemas."""
|
|
|
|
@staticmethod
|
|
def _make_tool_dict(args_model: type[BaseModel]) -> dict[str, Any]:
|
|
return {
|
|
"name": "test_tool",
|
|
"description": "A test tool",
|
|
"args_schema": args_model,
|
|
}
|
|
|
|
def test_gemini_compatible_schema(self) -> None:
|
|
_, _, params = safe_tool_conversion(
|
|
self._make_tool_dict(FirewallToolInput), "Gemini"
|
|
)
|
|
assert not _has_key_recursive(params, "title")
|
|
assert not _has_key_recursive(params, "additionalProperties")
|
|
assert not _has_key_recursive(params, "default")
|
|
assert params["properties"]["hostname"]["type"] == "string"
|
|
assert params["properties"]["command"]["type"] == "string"
|
|
|
|
def test_bedrock_compatible_schema(self) -> None:
|
|
_, _, params = safe_tool_conversion(
|
|
self._make_tool_dict(FirewallToolInput), "Bedrock"
|
|
)
|
|
assert not _has_key_recursive(params, "title")
|
|
assert not _has_key_recursive(params, "additionalProperties")
|
|
assert not _has_key_recursive(params, "default")
|
|
assert "hostname" in params.get("required", [])
|
|
assert "command" not in params.get("required", [])
|
|
|
|
def test_anthropic_compatible_schema(self) -> None:
|
|
_, _, params = safe_tool_conversion(
|
|
self._make_tool_dict(ComplexToolInput), "Anthropic"
|
|
)
|
|
assert not _has_key_recursive(params, "title")
|
|
assert not _has_key_recursive(params, "additionalProperties")
|
|
assert params["properties"]["tags"]["type"] == "array"
|
|
|
|
def test_nested_model_clean_for_all_providers(self) -> None:
|
|
for provider in ("Gemini", "Bedrock", "Anthropic"):
|
|
_, _, params = safe_tool_conversion(
|
|
self._make_tool_dict(ToolWithNestedInput), provider
|
|
)
|
|
assert not _has_key_recursive(params, "title"), f"Failed for {provider}"
|
|
assert not _has_key_recursive(
|
|
params, "additionalProperties"
|
|
), f"Failed for {provider}"
|
|
assert not _has_key_recursive(params, "$ref"), f"Failed for {provider}"
|
|
|
|
|
|
class TestOpenAISchemaUnchanged:
|
|
"""Ensure generate_model_description still produces OpenAI-specific schemas."""
|
|
|
|
def test_has_additional_properties_false(self) -> None:
|
|
result = generate_model_description(FirewallToolInput)
|
|
schema = result["json_schema"]["schema"]
|
|
assert schema.get("additionalProperties") is False
|
|
|
|
def test_all_properties_required(self) -> None:
|
|
result = generate_model_description(FirewallToolInput)
|
|
schema = result["json_schema"]["schema"]
|
|
assert "hostname" in schema["required"]
|
|
assert "command" in schema["required"]
|