Compare commits

...

4 Commits

Author        SHA1        Message                                                      Date
Lucas Gomide  2fccd5f8bc  wip                                                          2026-02-09 15:54:52 -03:00
Lucas Gomide  89556605cd  feat: improve JSON schema handling for MCP tools             2026-02-05 12:06:50 -03:00
                          - Convert enum constraints to Literal types
                          - Handle format constraints (date, date-time)
                          - Preserve original MCP tool names for server calls
Lucas Gomide  ff00055e2c  feat: use original tool name instead of the normalized one   2026-02-05 12:01:15 -03:00
Lucas Gomide  507aec7a48  wip                                                          2026-02-05 11:39:55 -03:00
4 changed files with 505 additions and 75 deletions
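
A minimal sketch, not part of the diff, of what the headline "enum constraints to Literal types" change buys. It assumes Pydantic v2; the SortArgs model and its field are illustrative.

from typing import Literal

from pydantic import ValidationError, create_model

# Before this change, {"enum": ["asc", "desc"]} mapped to plain str, so any
# string passed validation. With Literal, out-of-range values are rejected.
SortArgs = create_model("SortArgs", order=(Literal["asc", "desc"], ...))

SortArgs(order="asc")  # ok
try:
    SortArgs(order="ascending")
except ValidationError as e:
    print(e)  # Input should be 'asc' or 'desc'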

View File

@@ -24,6 +24,7 @@ from pydantic import (
 )
 from typing_extensions import Self
+from crewai.agent.json_schema_converter import JSONSchemaConverter
 from crewai.agent.utils import (
     ahandle_knowledge_retrieval,
     apply_training_data,
@@ -1178,6 +1179,7 @@ class Agent(BaseAgent):
         tools = []
         for tool_def in tools_list:
             tool_name = tool_def.get("name", "")
+            original_tool_name = tool_def.get("original_name", tool_name)
             if not tool_name:
                 continue
@@ -1199,6 +1201,7 @@ class Agent(BaseAgent):
                         tool_name=tool_name,
                         tool_schema=tool_schema,
                         server_name=server_name,
+                        original_tool_name=original_tool_name,
                     )
                     tools.append(native_tool)
                 except Exception as e:
@@ -1213,26 +1216,63 @@ class Agent(BaseAgent):
             raise RuntimeError(f"Failed to get native MCP tools: {e}") from e

     def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
-        """Get tools from CrewAI AMP MCP marketplace."""
-        # Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"
+        """Get tools from CrewAI AMP MCP via crewai-oauth service.
+
+        Fetches MCP server configuration with tokens injected from crewai-oauth,
+        then uses _get_native_mcp_tools to connect and discover tools.
+        """
+        # Parse: "crewai-amp:mcp-slug" or "crewai-amp:mcp-slug#tool_name"
         amp_part = amp_ref.replace("crewai-amp:", "")
         if "#" in amp_part:
-            mcp_name, specific_tool = amp_part.split("#", 1)
+            mcp_slug, specific_tool = amp_part.split("#", 1)
         else:
-            mcp_name, specific_tool = amp_part, None
+            mcp_slug, specific_tool = amp_part, None

-        # Call AMP API to get MCP server URLs
-        mcp_servers = self._fetch_amp_mcp_servers(mcp_name)
-
-        tools = []
-        for server_config in mcp_servers:
-            server_ref = server_config["url"]
-            if specific_tool:
-                server_ref += f"#{specific_tool}"
-            server_tools = self._get_external_mcp_tools(server_ref)
-            tools.extend(server_tools)
-
-        return tools
+        # Fetch MCP config from crewai-oauth (with tokens injected)
+        mcp_config_dict = self._fetch_amp_mcp_config(mcp_slug)
+        if not mcp_config_dict:
+            self._logger.log(
+                "warning", f"Failed to fetch MCP config for '{mcp_slug}' from crewai-oauth"
+            )
+            return []
+
+        # Convert dict to MCPServerConfig (MCPServerHTTP or MCPServerSSE)
+        config_type = mcp_config_dict.get("type", "http")
+        if config_type == "sse":
+            mcp_config = MCPServerSSE(
+                url=mcp_config_dict["url"],
+                headers=mcp_config_dict.get("headers"),
+                cache_tools_list=mcp_config_dict.get("cache_tools_list", False),
+            )
+        else:
+            mcp_config = MCPServerHTTP(
+                url=mcp_config_dict["url"],
+                headers=mcp_config_dict.get("headers"),
+                streamable=mcp_config_dict.get("streamable", True),
+                cache_tools_list=mcp_config_dict.get("cache_tools_list", False),
+            )
+
+        # Apply tool filter if specific tool requested
+        if specific_tool:
+            from crewai.mcp.filters import create_static_tool_filter
+            mcp_config.tool_filter = create_static_tool_filter(
+                allowed_tool_names=[specific_tool]
+            )
+
+        # Use native MCP tools to connect and discover tools
+        try:
+            tools, client = self._get_native_mcp_tools(mcp_config)
+            if client:
+                self._mcp_clients.append(client)
+            return tools
+        except Exception as e:
+            self._logger.log(
+                "warning", f"Failed to get MCP tools from '{mcp_slug}': {e}"
+            )
+            return []

     @staticmethod
     def _extract_server_name(server_url: str) -> str:
@@ -1389,6 +1429,9 @@ class Agent(BaseAgent):
         }
         return schemas

+    # Shared JSON Schema converter instance
+    _schema_converter: JSONSchemaConverter = JSONSchemaConverter()
+
     def _json_schema_to_pydantic(
         self, tool_name: str, json_schema: dict[str, Any]
     ) -> type:
@@ -1401,77 +1444,62 @@ class Agent(BaseAgent):
         Returns:
             Pydantic BaseModel class
         """
-        from pydantic import Field, create_model
-
-        properties = json_schema.get("properties", {})
-        required_fields = json_schema.get("required", [])
-
-        field_definitions: dict[str, Any] = {}
-
-        for field_name, field_schema in properties.items():
-            field_type = self._json_type_to_python(field_schema)
-            field_description = field_schema.get("description", "")
-            is_required = field_name in required_fields
-
-            if is_required:
-                field_definitions[field_name] = (
-                    field_type,
-                    Field(..., description=field_description),
-                )
-            else:
-                field_definitions[field_name] = (
-                    field_type | None,
-                    Field(default=None, description=field_description),
-                )
-
-        model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
-        return create_model(model_name, **field_definitions)  # type: ignore[no-any-return]
-
-    def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
-        """Convert JSON Schema type to Python type.
-
-        Args:
-            field_schema: JSON Schema field definition
-
-        Returns:
-            Python type
-        """
-        json_type = field_schema.get("type")
-
-        if "anyOf" in field_schema:
-            types: list[type] = []
-            for option in field_schema["anyOf"]:
-                if "const" in option:
-                    types.append(str)
-                else:
-                    types.append(self._json_type_to_python(option))
-            unique_types = list(set(types))
-            if len(unique_types) > 1:
-                result: Any = unique_types[0]
-                for t in unique_types[1:]:
-                    result = result | t
-                return result  # type: ignore[no-any-return]
-            return unique_types[0]
-
-        type_mapping: dict[str | None, type] = {
-            "string": str,
-            "number": float,
-            "integer": int,
-            "boolean": bool,
-            "array": list,
-            "object": dict,
-        }
-        return type_mapping.get(json_type, Any)
-
-    @staticmethod
-    def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
-        """Fetch MCP server configurations from CrewAI AMP API."""
-        # TODO: Implement AMP API call to "integrations/mcps" endpoint
-        # Should return list of server configs with URLs
-        return []
+        return self._schema_converter.json_schema_to_pydantic(tool_name, json_schema)
+
+    def _fetch_amp_mcp_config(self, mcp_slug: str) -> dict[str, Any] | None:
+        """Fetch MCP server configuration from crewai-oauth service.
+
+        Returns MCPServerConfig dict with tokens injected, ready for use with
+        _get_native_mcp_tools.
+
+        Environment variables:
+            CREWAI_OAUTH_URL: Base URL of crewai-oauth service
+            CREWAI_OAUTH_API_KEY: API key for authenticating with crewai-oauth
+
+        Args:
+            mcp_slug: The MCP server slug (e.g., "notion-mcp-abc123")
+
+        Returns:
+            Dict with type, url, headers, streamable, cache_tools_list, or None if failed.
+        """
+        import os
+
+        import requests
+
+        try:
+            base_url = os.environ.get("CREWAI_OAUTH_URL", "http://localhost:8787")
+            api_key = os.environ.get("CREWAI_OAUTH_API_KEY", "")
+            endpoint = f"{base_url}/mcps/{mcp_slug}/config"
+            response = requests.get(
+                endpoint,
+                headers={"Authorization": f"Bearer {api_key}"},
+                timeout=30,
+            )
+
+            if response.status_code == 200:
+                return response.json()
+            elif response.status_code == 400:
+                error_data = response.json()
+                self._logger.log(
+                    "warning",
+                    f"MCP '{mcp_slug}' is not connected: {error_data.get('error_description', 'Unknown error')}",
+                )
+                return None
+            elif response.status_code == 404:
+                self._logger.log(
+                    "warning", f"MCP server '{mcp_slug}' not found in crewai-oauth"
+                )
+                return None
+            else:
+                self._logger.log(
+                    "warning",
+                    f"Failed to fetch MCP config from crewai-oauth: HTTP {response.status_code}",
+                )
+                return None
+        except requests.exceptions.RequestException as e:
+            self._logger.log(
+                "warning", f"Failed to connect to crewai-oauth: {e}"
+            )
+            return None

     @staticmethod
     def get_multimodal_tools() -> Sequence[BaseTool]:
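
For quick reference, a self-contained sketch of the slug/tool parsing handled by _get_amp_mcp_tools above. parse_amp_ref is a hypothetical extraction for illustration only, not a function in this changeset.

def parse_amp_ref(amp_ref: str) -> tuple[str, str | None]:
    """Split 'crewai-amp:mcp-slug#tool' into (slug, tool_or_None)."""
    amp_part = amp_ref.replace("crewai-amp:", "")
    if "#" in amp_part:
        slug, tool = amp_part.split("#", 1)
        return slug, tool
    return amp_part, None

assert parse_amp_ref("crewai-amp:notion-mcp-abc123") == ("notion-mcp-abc123", None)
assert parse_amp_ref("crewai-amp:notion-mcp-abc123#search") == ("notion-mcp-abc123", "search")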

View File

@@ -0,0 +1,399 @@
import datetime
import uuid
from typing import Any, Literal, Type

from pydantic import Field, create_model

class JSONSchemaConverter:
"""Converts JSON Schema definitions to Python/Pydantic types."""
def json_schema_to_pydantic(
self, tool_name: str, json_schema: dict[str, Any]
) -> Type[Any]:
"""Convert JSON Schema to Pydantic model for tool arguments.
Args:
tool_name: Name of the tool (used for model naming)
json_schema: JSON Schema dict with 'properties', 'required', etc.
Returns:
Pydantic BaseModel class
"""
properties = json_schema.get("properties", {})
required_fields = json_schema.get("required", [])
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
return self._create_pydantic_model(model_name, properties, required_fields)
def _json_type_to_python(
self, field_schema: dict[str, Any], field_name: str = "Field"
) -> Type[Any]:
"""Convert JSON Schema type to Python type, handling nested structures.
Args:
field_schema: JSON Schema field definition
field_name: Name of the field (used for nested model naming)
Returns:
Python type (may be a dynamically created Pydantic model for objects/arrays)
"""
if not field_schema:
return Any
        # $ref references are not resolved yet; fall back to Any
        if "$ref" in field_schema:
            return Any
# Handle enum constraint - create Literal type
if "enum" in field_schema:
return self._handle_enum(field_schema)
# Handle different schema constructs in order of precedence
if "allOf" in field_schema:
return self._handle_allof(field_schema, field_name)
if "anyOf" in field_schema or "oneOf" in field_schema:
return self._handle_union_schemas(field_schema, field_name)
json_type = field_schema.get("type")
if isinstance(json_type, list):
return self._handle_type_union(json_type)
if json_type == "array":
return self._handle_array_type(field_schema, field_name)
if json_type == "object":
return self._handle_object_type(field_schema, field_name)
# Handle format for string types
if json_type == "string" and "format" in field_schema:
return self._get_formatted_type(field_schema["format"])
return self._get_simple_type(json_type)
def _get_formatted_type(self, format_type: str) -> Type[Any]:
"""Get Python type for JSON Schema format constraint.
Args:
format_type: JSON Schema format string (date, date-time, email, etc.)
Returns:
Appropriate Python type for the format
"""
format_mapping: dict[str, Type[Any]] = {
"date": datetime.date,
"date-time": datetime.datetime,
"time": datetime.time,
"email": str, # Could use EmailStr from pydantic
"uri": str,
"uuid": str, # Could use UUID
"hostname": str,
"ipv4": str,
"ipv6": str,
}
return format_mapping.get(format_type, str)
def _handle_enum(self, field_schema: dict[str, Any]) -> Type[Any]:
"""Handle enum constraint by creating a Literal type.
Args:
field_schema: Schema containing enum values
Returns:
Literal type with enum values
"""
enum_values = field_schema.get("enum", [])
if not enum_values:
return str
# Filter out None values for the Literal type
non_null_values = [v for v in enum_values if v is not None]
if not non_null_values:
return type(None)
# Create Literal type with enum values
# For strings, create Literal["value1", "value2", ...]
if all(isinstance(v, str) for v in non_null_values):
literal_type = Literal[tuple(non_null_values)] # type: ignore[valid-type]
# If null is in enum, make it optional
if None in enum_values:
return literal_type | None # type: ignore[return-value]
return literal_type # type: ignore[return-value]
# For mixed types or non-strings, fall back to the base type
json_type = field_schema.get("type", "string")
return self._get_simple_type(json_type)
def _handle_allof(
self, field_schema: dict[str, Any], field_name: str
) -> Type[Any]:
"""Handle allOf schema composition by merging all schemas.
Args:
field_schema: Schema containing allOf
field_name: Name for the generated model
Returns:
Merged Pydantic model or basic type
"""
merged_properties: dict[str, Any] = {}
merged_required: list[str] = []
found_type: str | None = None
for sub_schema in field_schema["allOf"]:
# Collect type information
if sub_schema.get("type"):
found_type = sub_schema.get("type")
# Merge properties
if sub_schema.get("properties"):
merged_properties.update(sub_schema["properties"])
# Merge required fields
if sub_schema.get("required"):
merged_required.extend(sub_schema["required"])
# Handle nested anyOf/oneOf - merge properties from all variants
for union_key in ("anyOf", "oneOf"):
if union_key in sub_schema:
for variant in sub_schema[union_key]:
if variant.get("properties"):
# Merge variant properties (will be optional)
for prop_name, prop_schema in variant["properties"].items():
if prop_name not in merged_properties:
merged_properties[prop_name] = prop_schema
# If we found properties, create a merged object model
if merged_properties:
return self._create_pydantic_model(
field_name, merged_properties, merged_required
)
# Fallback: return the found type or dict
if found_type == "object":
return dict
elif found_type == "array":
return list
return dict # Default for complex allOf
def _handle_union_schemas(
self, field_schema: dict[str, Any], field_name: str
) -> Type[Any]:
"""Handle anyOf/oneOf union schemas.
Args:
field_schema: Schema containing anyOf or oneOf
field_name: Name for nested types
Returns:
Union type combining all options
"""
key = "anyOf" if "anyOf" in field_schema else "oneOf"
types: list[Type[Any]] = []
for option in field_schema[key]:
if "const" in option:
# For const values, use string type
# Could use Literal[option["const"]] for more precision
types.append(str)
else:
types.append(self._json_type_to_python(option, field_name))
return self._build_union_type(types)
def _handle_type_union(self, json_types: list[str]) -> Type[Any]:
"""Handle union types from type arrays.
Args:
json_types: List of JSON Schema type strings
Returns:
Union of corresponding Python types
"""
type_mapping: dict[str, Type[Any]] = {
"string": str,
"number": float,
"integer": int,
"boolean": bool,
"null": type(None),
"array": list,
"object": dict,
}
types = [type_mapping.get(t, Any) for t in json_types]
return self._build_union_type(types)
def _handle_array_type(
self, field_schema: dict[str, Any], field_name: str
) -> Type[Any]:
"""Handle array type with typed items.
Args:
field_schema: Schema with type="array"
field_name: Name for item types
Returns:
list or list[ItemType]
"""
items_schema = field_schema.get("items")
if items_schema:
item_type = self._json_type_to_python(items_schema, f"{field_name}Item")
return list[item_type] # type: ignore[valid-type]
return list
def _handle_object_type(
self, field_schema: dict[str, Any], field_name: str
) -> Type[Any]:
"""Handle object type with properties.
Args:
field_schema: Schema with type="object"
field_name: Name for the generated model
Returns:
Pydantic model or dict
"""
properties = field_schema.get("properties")
if properties:
required_fields = field_schema.get("required", [])
return self._create_pydantic_model(field_name, properties, required_fields)
# Object without properties (e.g., additionalProperties only)
return dict
def _create_pydantic_model(
self,
field_name: str,
properties: dict[str, Any],
required_fields: list[str],
) -> Type[Any]:
"""Create a Pydantic model from properties.
Args:
field_name: Base name for the model
properties: Property schemas
required_fields: List of required property names
Returns:
Dynamically created Pydantic model
"""
model_name = f"Generated_{field_name}_{uuid.uuid4().hex[:8]}"
field_definitions: dict[str, Any] = {}
for prop_name, prop_schema in properties.items():
prop_type = self._json_type_to_python(prop_schema, prop_name.title())
prop_description = self._build_field_description(prop_schema)
is_required = prop_name in required_fields
if is_required:
field_definitions[prop_name] = (
prop_type,
Field(..., description=prop_description),
)
else:
field_definitions[prop_name] = (
prop_type | None,
Field(default=None, description=prop_description),
)
return create_model(model_name, **field_definitions) # type: ignore[return-value]
def _build_field_description(self, prop_schema: dict[str, Any]) -> str:
"""Build a comprehensive field description including constraints.
Args:
prop_schema: Property schema with description and constraints
Returns:
Enhanced description with format, enum, and other constraints
"""
parts: list[str] = []
# Start with the original description
description = prop_schema.get("description", "")
if description:
parts.append(description)
# Add format constraint
format_type = prop_schema.get("format")
if format_type:
parts.append(f"Format: {format_type}")
# Add enum constraint (if not already handled by Literal type)
enum_values = prop_schema.get("enum")
if enum_values:
enum_str = ", ".join(repr(v) for v in enum_values)
parts.append(f"Allowed values: [{enum_str}]")
# Add pattern constraint
pattern = prop_schema.get("pattern")
if pattern:
parts.append(f"Pattern: {pattern}")
# Add min/max constraints
minimum = prop_schema.get("minimum")
maximum = prop_schema.get("maximum")
if minimum is not None:
parts.append(f"Minimum: {minimum}")
if maximum is not None:
parts.append(f"Maximum: {maximum}")
min_length = prop_schema.get("minLength")
max_length = prop_schema.get("maxLength")
if min_length is not None:
parts.append(f"Min length: {min_length}")
if max_length is not None:
parts.append(f"Max length: {max_length}")
# Add examples if available
examples = prop_schema.get("examples")
if examples:
examples_str = ", ".join(repr(e) for e in examples[:3]) # Limit to 3
parts.append(f"Examples: {examples_str}")
return ". ".join(parts) if parts else ""
def _get_simple_type(self, json_type: str | None) -> Type[Any]:
"""Map simple JSON Schema types to Python types.
Args:
json_type: JSON Schema type string
Returns:
Corresponding Python type
"""
simple_type_mapping: dict[str | None, Type[Any]] = {
"string": str,
"number": float,
"integer": int,
"boolean": bool,
"null": type(None),
}
return simple_type_mapping.get(json_type, Any)
def _build_union_type(self, types: list[Type[Any]]) -> Type[Any]:
"""Build a union type from a list of types.
Args:
types: List of Python types to combine
Returns:
Union type or single type if only one unique type
"""
# Remove duplicates while preserving order
unique_types = list(dict.fromkeys(types))
if len(unique_types) == 1:
return unique_types[0]
# Build union using | operator
result = unique_types[0]
for t in unique_types[1:]:
result = result | t
return result # type: ignore[no-any-return]
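
A short usage sketch of the converter defined above; the tool name and schema are illustrative.

import datetime

converter = JSONSchemaConverter()
EventArgs = converter.json_schema_to_pydantic(
    "create-event",
    {
        "properties": {
            "title": {"type": "string", "description": "Event title"},
            "when": {"type": "string", "format": "date-time"},
            "visibility": {"type": "string", "enum": ["public", "private"]},
        },
        "required": ["title", "when"],
    },
)

event = EventArgs(title="Standup", when="2026-02-05T12:00:00")
assert isinstance(event.when, datetime.datetime)  # format -> datetime.datetime
assert event.visibility is None                   # optional enum defaults to None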

View File

@@ -420,6 +420,7 @@ class MCPClient:
         return [
             {
                 "name": sanitize_tool_name(tool.name),
+                "original_name": tool.name,
                 "description": getattr(tool, "description", ""),
                 "inputSchema": getattr(tool, "inputSchema", {}),
             }
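
Why both names are kept: sanitize_tool_name's exact rules live elsewhere in the codebase, so the stand-in below is hypothetical. It only illustrates that normalization is lossy, which is why the server-side name must travel separately in "original_name".

import re

def sanitize_tool_name_stub(name: str) -> str:  # hypothetical stand-in
    return re.sub(r"[^a-zA-Z0-9_]", "_", name)

original = "notion/search-pages"
assert sanitize_tool_name_stub(original) == "notion_search_pages"
# {"name": "notion_search_pages", "original_name": "notion/search-pages"}
# exposes a safe name to the LLM while calling the server with the original.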

View File

@@ -27,14 +27,16 @@ class MCPNativeTool(BaseTool):
         tool_name: str,
         tool_schema: dict[str, Any],
         server_name: str,
+        original_tool_name: str | None = None,
     ) -> None:
         """Initialize native MCP tool.

         Args:
             mcp_client: MCPClient instance with active session.
-            tool_name: Original name of the tool on the MCP server.
+            tool_name: Name of the tool (may be prefixed).
             tool_schema: Schema information for the tool.
             server_name: Name of the MCP server for prefixing.
+            original_tool_name: Original name of the tool on the MCP server.
         """
         # Create tool name with server prefix to avoid conflicts
         prefixed_name = f"{server_name}_{tool_name}"
@@ -57,7 +59,7 @@ class MCPNativeTool(BaseTool):
         # Set instance attributes after super().__init__
         self._mcp_client = mcp_client
-        self._original_tool_name = tool_name
+        self._original_tool_name = original_tool_name or tool_name
         self._server_name = server_name
         # self._logger = logging.getLogger(__name__)
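
How the names layer after this change, as a small sketch; the server and tool names are made up, and MCPNativeTool itself is not instantiated here.

server_name = "notion"
tool_name = "search_pages"            # sanitized name from MCPClient
original_tool_name = "search-pages"   # name as registered on the MCP server

prefixed_name = f"{server_name}_{tool_name}"   # name exposed to the agent
call_name = original_tool_name or tool_name    # name sent back to the server

assert prefixed_name == "notion_search_pages"
assert call_name == "search-pages"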