fix: correct tool-calling content handling and schema serialization

- fix(gemini): prevent tool calls from using stale text content; correct key refs
- fix(agent-executor): resolve type errors
- refactor(schema): extract Pydantic schema utilities from platform tools
- fix(schema): properly serialize schemas and ensure Responses API uses a separate structure
- fix: preserve list identity to avoid mutation/aliasing issues
- chore(tests): update assumptions to match new behavior
Greyson LaLonde
2026-01-27 15:47:29 -05:00
committed by GitHub
parent d52dbc1f4b
commit 3b17026082
13 changed files with 560 additions and 807 deletions
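The common thread in these changes is that every provider path now consumes the pre-serialized output of generate_model_description instead of calling model_json_schema() directly. A minimal sketch of that path, assuming the return shape inferred from the call sites in this diff (the WeatherReport model is illustrative, not part of the commit):

from pydantic import BaseModel, Field

from crewai.utilities.pydantic_schema_utils import generate_model_description


class WeatherReport(BaseModel):
    """Illustrative response model, not part of this commit."""

    city: str = Field(description="City the report covers")
    temperature_c: float


described = generate_model_description(WeatherReport)
json_schema = described.get("json_schema", {})

# The Responses API keeps its own wrapper around the shared payload ...
text_format = {
    "format": {
        "type": "json_schema",
        "name": json_schema.get("name", WeatherReport.__name__),
        "strict": json_schema.get("strict", True),
        "schema": json_schema.get("schema", {}),
    }
}

# ... while Bedrock and Gemini only need the bare schema dict.
bare_schema = json_schema.get("schema", {})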

View File

@@ -1,10 +1,11 @@
"""Crewai Enterprise Tools."""
import os
import json
import re
from typing import Any, Optional, Union, cast, get_origin
import os
from typing import Any
from crewai.tools import BaseTool
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
from pydantic import Field, create_model
import requests
@@ -14,77 +15,6 @@ from crewai_tools.tools.crewai_platform_tools.misc import (
)
class AllOfSchemaAnalyzer:
"""Helper class to analyze and merge allOf schemas."""
def __init__(self, schemas: list[dict[str, Any]]):
self.schemas = schemas
self._explicit_types: list[str] = []
self._merged_properties: dict[str, Any] = {}
self._merged_required: list[str] = []
self._analyze_schemas()
def _analyze_schemas(self) -> None:
"""Analyze all schemas and extract relevant information."""
for schema in self.schemas:
if "type" in schema:
self._explicit_types.append(schema["type"])
# Merge object properties
if schema.get("type") == "object" and "properties" in schema:
self._merged_properties.update(schema["properties"])
if "required" in schema:
self._merged_required.extend(schema["required"])
def has_consistent_type(self) -> bool:
"""Check if all schemas have the same explicit type."""
return len(set(self._explicit_types)) == 1 if self._explicit_types else False
def get_consistent_type(self) -> type[Any]:
"""Get the consistent type if all schemas agree."""
if not self.has_consistent_type():
raise ValueError("No consistent type found")
type_mapping = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
"null": type(None),
}
return type_mapping.get(self._explicit_types[0], str)
def has_object_schemas(self) -> bool:
"""Check if any schemas are object types with properties."""
return bool(self._merged_properties)
def get_merged_properties(self) -> dict[str, Any]:
"""Get merged properties from all object schemas."""
return self._merged_properties
def get_merged_required_fields(self) -> list[str]:
"""Get merged required fields from all object schemas."""
return list(set(self._merged_required)) # Remove duplicates
def get_fallback_type(self) -> type[Any]:
"""Get a fallback type when merging fails."""
if self._explicit_types:
# Use the first explicit type
type_mapping = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
"null": type(None),
}
return type_mapping.get(self._explicit_types[0], str)
return str
class CrewAIPlatformActionTool(BaseTool):
action_name: str = Field(default="", description="The name of the action")
action_schema: dict[str, Any] = Field(
@@ -97,42 +27,19 @@ class CrewAIPlatformActionTool(BaseTool):
action_name: str,
action_schema: dict[str, Any],
):
self._model_registry: dict[str, type[Any]] = {}
self._base_name = self._sanitize_name(action_name)
schema_props, required = self._extract_schema_info(action_schema)
field_definitions: dict[str, Any] = {}
for param_name, param_details in schema_props.items():
param_desc = param_details.get("description", "")
is_required = param_name in required
parameters = action_schema.get("function", {}).get("parameters", {})
if parameters and parameters.get("properties"):
try:
field_type = self._process_schema_type(
param_details, self._sanitize_name(param_name).title()
)
if "title" not in parameters:
parameters = {**parameters, "title": f"{action_name}Schema"}
if "type" not in parameters:
parameters = {**parameters, "type": "object"}
args_schema = create_model_from_schema(parameters)
except Exception:
field_type = str
field_definitions[param_name] = self._create_field_definition(
field_type, is_required, param_desc
)
if field_definitions:
try:
args_schema = create_model(
f"{self._base_name}Schema", **field_definitions
)
except Exception:
args_schema = create_model(
f"{self._base_name}Schema",
input_text=(str, Field(description="Input for the action")),
)
args_schema = create_model(f"{action_name}Schema")
else:
args_schema = create_model(
f"{self._base_name}Schema",
input_text=(str, Field(description="Input for the action")),
)
args_schema = create_model(f"{action_name}Schema")
super().__init__(
name=action_name.lower().replace(" ", "_"),
@@ -142,285 +49,12 @@ class CrewAIPlatformActionTool(BaseTool):
self.action_name = action_name
self.action_schema = action_schema
@staticmethod
def _sanitize_name(name: str) -> str:
name = name.lower().replace(" ", "_")
sanitized = re.sub(r"[^a-zA-Z0-9_]", "", name)
parts = sanitized.split("_")
return "".join(word.capitalize() for word in parts if word)
@staticmethod
def _extract_schema_info(
action_schema: dict[str, Any],
) -> tuple[dict[str, Any], list[str]]:
schema_props = (
action_schema.get("function", {})
.get("parameters", {})
.get("properties", {})
)
required = (
action_schema.get("function", {}).get("parameters", {}).get("required", [])
)
return schema_props, required
def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]:
"""
Process a JSON Schema type definition into a Python type.
Handles complex schema constructs like anyOf, oneOf, allOf, enums, arrays, and objects.
"""
# Handle composite schema types (anyOf, oneOf, allOf)
if composite_type := self._process_composite_schema(schema, type_name):
return composite_type
# Handle primitive types and simple constructs
return self._process_primitive_schema(schema, type_name)
def _process_composite_schema(
self, schema: dict[str, Any], type_name: str
) -> type[Any] | None:
"""Process composite schema types: anyOf, oneOf, allOf."""
if "anyOf" in schema:
return self._process_any_of_schema(schema["anyOf"], type_name)
if "oneOf" in schema:
return self._process_one_of_schema(schema["oneOf"], type_name)
if "allOf" in schema:
return self._process_all_of_schema(schema["allOf"], type_name)
return None
def _process_any_of_schema(
self, any_of_types: list[dict[str, Any]], type_name: str
) -> type[Any]:
"""Process anyOf schema - creates Union of possible types."""
is_nullable = any(t.get("type") == "null" for t in any_of_types)
non_null_types = [t for t in any_of_types if t.get("type") != "null"]
if not non_null_types:
return cast(
type[Any], cast(object, str | None)
) # fallback for only-null case
base_type = (
self._process_schema_type(non_null_types[0], type_name)
if len(non_null_types) == 1
else self._create_union_type(non_null_types, type_name, "AnyOf")
)
return base_type | None if is_nullable else base_type # type: ignore[return-value]
def _process_one_of_schema(
self, one_of_types: list[dict[str, Any]], type_name: str
) -> type[Any]:
"""Process oneOf schema - creates Union of mutually exclusive types."""
return (
self._process_schema_type(one_of_types[0], type_name)
if len(one_of_types) == 1
else self._create_union_type(one_of_types, type_name, "OneOf")
)
def _process_all_of_schema(
self, all_of_schemas: list[dict[str, Any]], type_name: str
) -> type[Any]:
"""Process allOf schema - merges schemas that must all be satisfied."""
if len(all_of_schemas) == 1:
return self._process_schema_type(all_of_schemas[0], type_name)
return self._merge_all_of_schemas(all_of_schemas, type_name)
def _create_union_type(
self, schemas: list[dict[str, Any]], type_name: str, prefix: str
) -> type[Any]:
"""Create a Union type from multiple schemas."""
return Union[ # type: ignore # noqa: UP007
tuple(
self._process_schema_type(schema, f"{type_name}{prefix}{i}")
for i, schema in enumerate(schemas)
)
]
def _process_primitive_schema(
self, schema: dict[str, Any], type_name: str
) -> type[Any]:
"""Process primitive schema types: string, number, array, object, etc."""
json_type = schema.get("type", "string")
if "enum" in schema:
return self._process_enum_schema(schema, json_type)
if json_type == "array":
return self._process_array_schema(schema, type_name)
if json_type == "object":
return self._create_nested_model(schema, type_name)
return self._map_json_type_to_python(json_type)
def _process_enum_schema(self, schema: dict[str, Any], json_type: str) -> type[Any]:
"""Process enum schema - currently falls back to base type."""
enum_values = schema["enum"]
if not enum_values:
return self._map_json_type_to_python(json_type)
# For Literal types, we need to pass the values directly, not as a tuple
# This is a workaround since we can't dynamically create Literal types easily
# Fall back to the base JSON type for now
return self._map_json_type_to_python(json_type)
def _process_array_schema(
self, schema: dict[str, Any], type_name: str
) -> type[Any]:
items_schema = schema.get("items", {"type": "string"})
item_type = self._process_schema_type(items_schema, f"{type_name}Item")
return list[item_type] # type: ignore
def _merge_all_of_schemas(
self, schemas: list[dict[str, Any]], type_name: str
) -> type[Any]:
schema_analyzer = AllOfSchemaAnalyzer(schemas)
if schema_analyzer.has_consistent_type():
return schema_analyzer.get_consistent_type()
if schema_analyzer.has_object_schemas():
return self._create_merged_object_model(
schema_analyzer.get_merged_properties(),
schema_analyzer.get_merged_required_fields(),
type_name,
)
return schema_analyzer.get_fallback_type()
def _create_merged_object_model(
self, properties: dict[str, Any], required: list[str], model_name: str
) -> type[Any]:
full_model_name = f"{self._base_name}{model_name}AllOf"
if full_model_name in self._model_registry:
return self._model_registry[full_model_name]
if not properties:
return dict
field_definitions = self._build_field_definitions(
properties, required, model_name
)
try:
merged_model = create_model(full_model_name, **field_definitions)
self._model_registry[full_model_name] = merged_model
return merged_model
except Exception:
return dict
def _build_field_definitions(
self, properties: dict[str, Any], required: list[str], model_name: str
) -> dict[str, Any]:
field_definitions = {}
for prop_name, prop_schema in properties.items():
prop_desc = prop_schema.get("description", "")
is_required = prop_name in required
try:
prop_type = self._process_schema_type(
prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
)
except Exception:
prop_type = str
field_definitions[prop_name] = self._create_field_definition(
prop_type, is_required, prop_desc
)
return field_definitions
def _create_nested_model(
self, schema: dict[str, Any], model_name: str
) -> type[Any]:
full_model_name = f"{self._base_name}{model_name}"
if full_model_name in self._model_registry:
return self._model_registry[full_model_name]
properties = schema.get("properties", {})
required_fields = schema.get("required", [])
if not properties:
return dict
field_definitions = {}
for prop_name, prop_schema in properties.items():
prop_desc = prop_schema.get("description", "")
is_required = prop_name in required_fields
try:
prop_type = self._process_schema_type(
prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
)
except Exception:
prop_type = str
field_definitions[prop_name] = self._create_field_definition(
prop_type, is_required, prop_desc
)
try:
nested_model = create_model(full_model_name, **field_definitions) # type: ignore
self._model_registry[full_model_name] = nested_model
return nested_model
except Exception:
return dict
def _create_field_definition(
self, field_type: type[Any], is_required: bool, description: str
) -> tuple:
if is_required:
return (field_type, Field(description=description))
if get_origin(field_type) is Union:
return (field_type, Field(default=None, description=description))
return (
Optional[field_type], # noqa: UP045
Field(default=None, description=description),
)
def _map_json_type_to_python(self, json_type: str) -> type[Any]:
type_mapping = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
"null": type(None),
}
return type_mapping.get(json_type, str)
def _get_required_nullable_fields(self) -> list[str]:
schema_props, required = self._extract_schema_info(self.action_schema)
required_nullable_fields = []
for param_name in required:
param_details = schema_props.get(param_name, {})
if self._is_nullable_type(param_details):
required_nullable_fields.append(param_name)
return required_nullable_fields
def _is_nullable_type(self, schema: dict[str, Any]) -> bool:
if "anyOf" in schema:
return any(t.get("type") == "null" for t in schema["anyOf"])
return schema.get("type") == "null"
def _run(self, **kwargs) -> str:
def _run(self, **kwargs: Any) -> str:
try:
cleaned_kwargs = {
key: value for key, value in kwargs.items() if value is not None
}
required_nullable_fields = self._get_required_nullable_fields()
for field_name in required_nullable_fields:
if field_name not in cleaned_kwargs:
cleaned_kwargs[field_name] = None
api_url = (
f"{get_platform_api_base_url()}/actions/{self.action_name}/execute"
)
@@ -429,7 +63,9 @@ class CrewAIPlatformActionTool(BaseTool):
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = cleaned_kwargs
payload = {
"integration": cleaned_kwargs if cleaned_kwargs else {"_noop": True}
}
response = requests.post(
url=api_url,
@@ -441,7 +77,14 @@ class CrewAIPlatformActionTool(BaseTool):
data = response.json()
if not response.ok:
error_message = data.get("error", {}).get("message", json.dumps(data))
if isinstance(data, dict):
error_info = data.get("error", {})
if isinstance(error_info, dict):
error_message = error_info.get("message", json.dumps(data))
else:
error_message = str(error_info)
else:
error_message = str(data)
return f"API request failed: {error_message}"
return json.dumps(data, indent=2)

View File

@@ -1,5 +1,10 @@
from typing import Any
"""CrewAI platform tool builder for fetching and creating action tools."""
import logging
import os
from types import TracebackType
from typing import Any
from crewai.tools import BaseTool
import requests
@@ -12,22 +17,29 @@ from crewai_tools.tools.crewai_platform_tools.misc import (
)
logger = logging.getLogger(__name__)
class CrewaiPlatformToolBuilder:
"""Builds platform tools from remote action schemas."""
def __init__(
self,
apps: list[str],
):
) -> None:
self._apps = apps
self._actions_schema = {} # type: ignore[var-annotated]
self._tools = None
self._actions_schema: dict[str, dict[str, Any]] = {}
self._tools: list[BaseTool] | None = None
def tools(self) -> list[BaseTool]:
"""Fetch actions and return built tools."""
if self._tools is None:
self._fetch_actions()
self._create_tools()
return self._tools if self._tools is not None else []
def _fetch_actions(self):
def _fetch_actions(self) -> None:
"""Fetch action schemas from the platform API."""
actions_url = f"{get_platform_api_base_url()}/actions"
headers = {"Authorization": f"Bearer {get_platform_integration_token()}"}
@@ -40,7 +52,8 @@ class CrewaiPlatformToolBuilder:
verify=os.environ.get("CREWAI_FACTORY", "false").lower() != "true",
)
response.raise_for_status()
except Exception:
except Exception as e:
logger.error(f"Failed to fetch platform tools for apps {self._apps}: {e}")
return
raw_data = response.json()
@@ -51,6 +64,8 @@ class CrewaiPlatformToolBuilder:
for app, action_list in action_categories.items():
if isinstance(action_list, list):
for action in action_list:
if not isinstance(action, dict):
continue
if action_name := action.get("name"):
action_schema = {
"function": {
@@ -64,72 +79,16 @@ class CrewaiPlatformToolBuilder:
}
self._actions_schema[action_name] = action_schema
def _generate_detailed_description(
self, schema: dict[str, Any], indent: int = 0
) -> list[str]:
descriptions = []
indent_str = " " * indent
schema_type = schema.get("type", "string")
if schema_type == "object":
properties = schema.get("properties", {})
required_fields = schema.get("required", [])
if properties:
descriptions.append(f"{indent_str}Object with properties:")
for prop_name, prop_schema in properties.items():
prop_desc = prop_schema.get("description", "")
is_required = prop_name in required_fields
req_str = " (required)" if is_required else " (optional)"
descriptions.append(
f"{indent_str} - {prop_name}: {prop_desc}{req_str}"
)
if prop_schema.get("type") == "object":
descriptions.extend(
self._generate_detailed_description(prop_schema, indent + 2)
)
elif prop_schema.get("type") == "array":
items_schema = prop_schema.get("items", {})
if items_schema.get("type") == "object":
descriptions.append(f"{indent_str} Array of objects:")
descriptions.extend(
self._generate_detailed_description(
items_schema, indent + 3
)
)
elif "enum" in items_schema:
descriptions.append(
f"{indent_str} Array of enum values: {items_schema['enum']}"
)
elif "enum" in prop_schema:
descriptions.append(
f"{indent_str} Enum values: {prop_schema['enum']}"
)
return descriptions
def _create_tools(self):
tools = []
def _create_tools(self) -> None:
"""Create tool instances from fetched action schemas."""
tools: list[BaseTool] = []
for action_name, action_schema in self._actions_schema.items():
function_details = action_schema.get("function", {})
description = function_details.get("description", f"Execute {action_name}")
parameters = function_details.get("parameters", {})
param_descriptions = []
if parameters.get("properties"):
param_descriptions.append("\nDetailed Parameter Structure:")
param_descriptions.extend(
self._generate_detailed_description(parameters)
)
full_description = description + "\n".join(param_descriptions)
tool = CrewAIPlatformActionTool(
description=full_description,
description=description,
action_name=action_name,
action_schema=action_schema,
)
@@ -138,8 +97,14 @@ class CrewaiPlatformToolBuilder:
self._tools = tools
def __enter__(self):
def __enter__(self) -> list[BaseTool]:
"""Enter context manager and return tools."""
return self.tools()
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit context manager."""

View File

@@ -1,4 +1,3 @@
from typing import Union, get_args, get_origin
from unittest.mock import patch, Mock
import os
@@ -7,251 +6,6 @@ from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import
)
class TestSchemaProcessing:
def setup_method(self):
self.base_action_schema = {
"function": {
"parameters": {
"properties": {},
"required": []
}
}
}
def create_test_tool(self, action_name="test_action"):
return CrewAIPlatformActionTool(
description="Test tool",
action_name=action_name,
action_schema=self.base_action_schema
)
def test_anyof_multiple_types(self):
tool = self.create_test_tool()
test_schema = {
"anyOf": [
{"type": "string"},
{"type": "number"},
{"type": "integer"}
]
}
result_type = tool._process_schema_type(test_schema, "TestField")
assert get_origin(result_type) is Union
args = get_args(result_type)
expected_types = (str, float, int)
for expected_type in expected_types:
assert expected_type in args
def test_anyof_with_null(self):
tool = self.create_test_tool()
test_schema = {
"anyOf": [
{"type": "string"},
{"type": "number"},
{"type": "null"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldNullable")
assert get_origin(result_type) is Union
args = get_args(result_type)
assert type(None) in args
assert str in args
assert float in args
def test_anyof_single_type(self):
tool = self.create_test_tool()
test_schema = {
"anyOf": [
{"type": "string"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldSingle")
assert result_type is str
def test_oneof_multiple_types(self):
tool = self.create_test_tool()
test_schema = {
"oneOf": [
{"type": "string"},
{"type": "boolean"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldOneOf")
assert get_origin(result_type) is Union
args = get_args(result_type)
expected_types = (str, bool)
for expected_type in expected_types:
assert expected_type in args
def test_oneof_single_type(self):
tool = self.create_test_tool()
test_schema = {
"oneOf": [
{"type": "integer"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldOneOfSingle")
assert result_type is int
def test_basic_types(self):
tool = self.create_test_tool()
test_cases = [
({"type": "string"}, str),
({"type": "integer"}, int),
({"type": "number"}, float),
({"type": "boolean"}, bool),
({"type": "array", "items": {"type": "string"}}, list),
]
for schema, expected_type in test_cases:
result_type = tool._process_schema_type(schema, "TestField")
if schema["type"] == "array":
assert get_origin(result_type) is list
else:
assert result_type is expected_type
def test_enum_handling(self):
tool = self.create_test_tool()
test_schema = {
"type": "string",
"enum": ["option1", "option2", "option3"]
}
result_type = tool._process_schema_type(test_schema, "TestFieldEnum")
assert result_type is str
def test_nested_anyof(self):
tool = self.create_test_tool()
test_schema = {
"anyOf": [
{"type": "string"},
{
"anyOf": [
{"type": "integer"},
{"type": "boolean"}
]
}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldNested")
assert get_origin(result_type) is Union
args = get_args(result_type)
assert str in args
if len(args) == 3:
assert int in args
assert bool in args
else:
nested_union = next(arg for arg in args if get_origin(arg) is Union)
nested_args = get_args(nested_union)
assert int in nested_args
assert bool in nested_args
def test_allof_same_types(self):
tool = self.create_test_tool()
test_schema = {
"allOf": [
{"type": "string"},
{"type": "string", "maxLength": 100}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldAllOfSame")
assert result_type is str
def test_allof_object_merge(self):
tool = self.create_test_tool()
test_schema = {
"allOf": [
{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name"]
},
{
"type": "object",
"properties": {
"email": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["email"]
}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldAllOfMerged")
# Should create a merged model with all properties
# The implementation might fall back to dict if model creation fails
# Let's just verify it's not a basic scalar type
assert result_type is not str
assert result_type is not int
assert result_type is not bool
# It could be dict (fallback) or a proper model class
assert result_type in (dict, type) or hasattr(result_type, '__name__')
def test_allof_single_schema(self):
"""Test that allOf with single schema works correctly."""
tool = self.create_test_tool()
test_schema = {
"allOf": [
{"type": "boolean"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldAllOfSingle")
# Should be just bool
assert result_type is bool
def test_allof_mixed_types(self):
tool = self.create_test_tool()
test_schema = {
"allOf": [
{"type": "string"},
{"type": "integer"}
]
}
result_type = tool._process_schema_type(test_schema, "TestFieldAllOfMixed")
assert result_type is str
class TestCrewAIPlatformActionToolVerify:
"""Test suite for SSL verification behavior based on CREWAI_FACTORY environment variable"""

View File

@@ -224,43 +224,6 @@ class TestCrewaiPlatformToolBuilder(unittest.TestCase):
_, kwargs = mock_get.call_args
assert kwargs["params"]["apps"] == ""
def test_detailed_description_generation(self):
builder = CrewaiPlatformToolBuilder(apps=["test"])
complex_schema = {
"type": "object",
"properties": {
"simple_string": {"type": "string", "description": "A simple string"},
"nested_object": {
"type": "object",
"properties": {
"inner_prop": {
"type": "integer",
"description": "Inner property",
}
},
"description": "Nested object",
},
"array_prop": {
"type": "array",
"items": {"type": "string"},
"description": "Array of strings",
},
},
}
descriptions = builder._generate_detailed_description(complex_schema)
assert isinstance(descriptions, list)
assert len(descriptions) > 0
description_text = "\n".join(descriptions)
assert "simple_string" in description_text
assert "nested_object" in description_text
assert "array_prop" in description_text
class TestCrewaiPlatformToolBuilderVerify(unittest.TestCase):
"""Test suite for SSL verification behavior in CrewaiPlatformToolBuilder"""

View File

@@ -36,6 +36,7 @@ from crewai.hooks.llm_hooks import (
get_after_llm_call_hooks,
get_before_llm_call_hooks,
)
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
from crewai.utilities.agent_utils import (
convert_tools_to_openai_schema,
enforce_rpm_limit,
@@ -185,8 +186,8 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
self._instance_id = str(uuid4())[:8]
self.before_llm_call_hooks: list[Callable] = []
self.after_llm_call_hooks: list[Callable] = []
self.before_llm_call_hooks: list[BeforeLLMCallHookType] = []
self.after_llm_call_hooks: list[AfterLLMCallHookType] = []
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
@@ -299,11 +300,21 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
"""Compatibility property for mixin - returns state messages."""
return self._state.messages
@messages.setter
def messages(self, value: list[LLMMessage]) -> None:
"""Set state messages."""
self._state.messages = value
@property
def iterations(self) -> int:
"""Compatibility property for mixin - returns state iterations."""
return self._state.iterations
@iterations.setter
def iterations(self, value: int) -> None:
"""Set state iterations."""
self._state.iterations = value
@start()
def initialize_reasoning(self) -> Literal["initialized"]:
"""Initialize the reasoning flow and emit agent start logs."""
@@ -577,6 +588,12 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
"content": None,
"tool_calls": tool_calls_to_report,
}
if all(
type(tc).__qualname__ == "Part" for tc in self.state.pending_tool_calls
):
assistant_message["raw_tool_call_parts"] = list(
self.state.pending_tool_calls
)
self.state.messages.append(assistant_message)
# Now execute each tool
@@ -611,14 +628,12 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
# Check if tool has reached max usage count
max_usage_reached = False
if original_tool:
if (
hasattr(original_tool, "max_usage_count")
and original_tool.max_usage_count is not None
and original_tool.current_usage_count
>= original_tool.max_usage_count
):
max_usage_reached = True
if (
original_tool
and original_tool.max_usage_count is not None
and original_tool.current_usage_count >= original_tool.max_usage_count
):
max_usage_reached = True
# Check cache before executing
from_cache = False
@@ -661,11 +676,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
# Add to cache after successful execution (before string conversion)
if self.tools_handler and self.tools_handler.cache:
should_cache = True
if (
original_tool
and hasattr(original_tool, "cache_function")
and original_tool.cache_function
):
if original_tool:
should_cache = original_tool.cache_function(
args_dict, raw_result
)
@@ -696,7 +707,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
error=e,
),
)
elif max_usage_reached:
elif max_usage_reached and original_tool:
# Return error message when max usage limit is reached
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
@@ -833,6 +844,10 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
@listen("parser_error")
def recover_from_parser_error(self) -> Literal["initialized"]:
"""Recover from output parser errors and retry."""
if not self._last_parser_error:
self.state.iterations += 1
return "initialized"
formatted_answer = handle_output_parser_exception(
e=self._last_parser_error,
messages=list(self.state.messages),

View File

@@ -9,6 +9,7 @@ from crewai.utilities.printer import Printer
if TYPE_CHECKING:
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.experimental.agent_executor import AgentExecutor
from crewai.lite_agent import LiteAgent
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.types import LLMMessage
@@ -41,7 +42,7 @@ class LLMCallHookContext:
Can be modified by returning a new string from after_llm_call hook.
"""
executor: CrewAgentExecutor | LiteAgent | None
executor: CrewAgentExecutor | AgentExecutor | LiteAgent | None
messages: list[LLMMessage]
agent: Any
task: Any
@@ -52,7 +53,7 @@ class LLMCallHookContext:
def __init__(
self,
executor: CrewAgentExecutor | LiteAgent | None = None,
executor: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
response: str | None = None,
messages: list[LLMMessage] | None = None,
llm: BaseLLM | str | Any | None = None, # TODO: look into

View File

@@ -16,6 +16,7 @@ from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.types import LLMMessage
@@ -548,7 +549,11 @@ class BedrockCompletion(BaseLLM):
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
"inputSchema": {
"json": generate_model_description(response_model)
.get("json_schema", {})
.get("schema", {})
},
}
}
body["toolConfig"] = cast(
@@ -779,7 +784,11 @@ class BedrockCompletion(BaseLLM):
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
"inputSchema": {
"json": generate_model_description(response_model)
.get("json_schema", {})
.get("schema", {})
},
}
}
body["toolConfig"] = cast(
@@ -1011,7 +1020,11 @@ class BedrockCompletion(BaseLLM):
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
"inputSchema": {
"json": generate_model_description(response_model)
.get("json_schema", {})
.get("schema", {})
},
}
}
body["toolConfig"] = cast(
@@ -1223,7 +1236,11 @@ class BedrockCompletion(BaseLLM):
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
"inputSchema": {
"json": generate_model_description(response_model)
.get("json_schema", {})
.get("schema", {})
},
}
}
body["toolConfig"] = cast(

View File

@@ -15,6 +15,7 @@ from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.types import LLMMessage
@@ -464,7 +465,10 @@ class GeminiCompletion(BaseLLM):
if response_model:
config_params["response_mime_type"] = "application/json"
config_params["response_schema"] = response_model.model_json_schema()
schema_output = generate_model_description(response_model)
config_params["response_schema"] = schema_output.get("json_schema", {}).get(
"schema", {}
)
# Handle tools for supported models
if tools and self.supports_tools:
@@ -489,7 +493,7 @@ class GeminiCompletion(BaseLLM):
function_declaration = types.FunctionDeclaration(
name=name,
description=description,
parameters=parameters if parameters else None,
parameters_json_schema=parameters if parameters else None,
)
gemini_tool = types.Tool(function_declarations=[function_declaration])
@@ -543,11 +547,10 @@ class GeminiCompletion(BaseLLM):
else:
parts.append(types.Part.from_text(text=str(content) if content else ""))
text_content: str = " ".join(p.text for p in parts if p.text is not None)
if role == "system":
# Extract system instruction - Gemini handles it separately
text_content = " ".join(
p.text for p in parts if hasattr(p, "text") and p.text
)
if system_instruction:
system_instruction += f"\n\n{text_content}"
else:
@@ -576,31 +579,40 @@ class GeminiCompletion(BaseLLM):
types.Content(role="user", parts=[function_response_part])
)
elif role == "assistant" and message.get("tool_calls"):
tool_parts: list[types.Part] = []
if text_content:
    tool_parts.append(types.Part.from_text(text=text_content))
tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
for tool_call in tool_calls:
    func: dict[str, Any] = tool_call.get("function") or {}
    func_name: str = str(func.get("name") or "")
    func_args_raw: str | dict[str, Any] = func.get("arguments") or {}
    func_args: dict[str, Any]
    if isinstance(func_args_raw, str):
        try:
            func_args = (
                json.loads(func_args_raw) if func_args_raw else {}
            )
        except (json.JSONDecodeError, TypeError):
            func_args = {}
    else:
        func_args = func_args_raw
    tool_parts.append(
        types.Part.from_function_call(name=func_name, args=func_args)
    )
raw_parts: list[Any] | None = message.get("raw_tool_call_parts")
if raw_parts and all(isinstance(p, types.Part) for p in raw_parts):
    tool_parts: list[types.Part] = list(raw_parts)
    if text_content:
        tool_parts.insert(0, types.Part.from_text(text=text_content))
else:
    tool_parts = []
    if text_content:
        tool_parts.append(types.Part.from_text(text=text_content))
    tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
    for tool_call in tool_calls:
        func: dict[str, Any] = tool_call.get("function") or {}
        func_name: str = str(func.get("name") or "")
        func_args_raw: str | dict[str, Any] = (
            func.get("arguments") or {}
        )
        func_args: dict[str, Any]
        if isinstance(func_args_raw, str):
            try:
                func_args = (
                    json.loads(func_args_raw) if func_args_raw else {}
                )
            except (json.JSONDecodeError, TypeError):
                func_args = {}
        else:
            func_args = func_args_raw
        tool_parts.append(
            types.Part.from_function_call(
                name=func_name, args=func_args
            )
        )
contents.append(types.Content(role="model", parts=tool_parts))
else:
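The effect of the raw-part handling above, as a minimal sketch (the tool name and arguments are illustrative; the Part and Content calls are the same ones used in this file):

from typing import Any

from google.genai import types

# The executor stores provider-native Part objects next to the normalized
# tool_calls list (see the agent-executor hunk above and the LLMMessage change below).
assistant_message: dict[str, Any] = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
    ],
    "raw_tool_call_parts": [
        types.Part.from_function_call(name="get_weather", args={"city": "Paris"})
    ],
}

# On the next request the converter replays the exact Parts instead of
# re-serializing text content that may be stale.
raw_parts = assistant_message.get("raw_tool_call_parts")
if raw_parts and all(isinstance(p, types.Part) for p in raw_parts):
    model_turn = types.Content(role="model", parts=list(raw_parts))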

View File

@@ -693,14 +693,14 @@ class OpenAICompletion(BaseLLM):
if response_model or self.response_format:
format_model = response_model or self.response_format
if isinstance(format_model, type) and issubclass(format_model, BaseModel):
schema = format_model.model_json_schema()
schema["additionalProperties"] = False
schema_output = generate_model_description(format_model)
json_schema = schema_output.get("json_schema", {})
params["text"] = {
"format": {
"type": "json_schema",
"name": format_model.__name__,
"strict": True,
"schema": schema,
"name": json_schema.get("name", format_model.__name__),
"strict": json_schema.get("strict", True),
"schema": json_schema.get("schema", {}),
}
}
elif isinstance(format_model, dict):
@@ -1060,7 +1060,7 @@ class OpenAICompletion(BaseLLM):
chunk=delta_text,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
response_id=response_id_stream,
)
elif event.type == "response.function_call_arguments.delta":
@@ -1709,7 +1709,7 @@ class OpenAICompletion(BaseLLM):
**parse_params, response_format=response_model
) as stream:
for chunk in stream:
response_id_stream=chunk.id if hasattr(chunk,"id") else None
response_id_stream = chunk.id if hasattr(chunk, "id") else None
if chunk.type == "content.delta":
delta_content = chunk.delta
@@ -1718,7 +1718,7 @@ class OpenAICompletion(BaseLLM):
chunk=delta_content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
response_id=response_id_stream,
)
final_completion = stream.get_final_completion()
@@ -1748,7 +1748,9 @@ class OpenAICompletion(BaseLLM):
usage_data = {"total_tokens": 0}
for completion_chunk in completion_stream:
response_id_stream=completion_chunk.id if hasattr(completion_chunk,"id") else None
response_id_stream = (
completion_chunk.id if hasattr(completion_chunk, "id") else None
)
if hasattr(completion_chunk, "usage") and completion_chunk.usage:
usage_data = self._extract_openai_token_usage(completion_chunk)
@@ -1766,7 +1768,7 @@ class OpenAICompletion(BaseLLM):
chunk=chunk_delta.content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
response_id=response_id_stream,
)
if chunk_delta.tool_calls:
@@ -1805,7 +1807,7 @@ class OpenAICompletion(BaseLLM):
"index": tool_calls[tool_index]["index"],
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id_stream
response_id=response_id_stream,
)
self._track_token_usage_internal(usage_data)
@@ -2017,7 +2019,7 @@ class OpenAICompletion(BaseLLM):
accumulated_content = ""
usage_data = {"total_tokens": 0}
async for chunk in completion_stream:
response_id_stream=chunk.id if hasattr(chunk,"id") else None
response_id_stream = chunk.id if hasattr(chunk, "id") else None
if hasattr(chunk, "usage") and chunk.usage:
usage_data = self._extract_openai_token_usage(chunk)
@@ -2035,7 +2037,7 @@ class OpenAICompletion(BaseLLM):
chunk=delta.content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
response_id=response_id_stream,
)
self._track_token_usage_internal(usage_data)
@@ -2071,7 +2073,7 @@ class OpenAICompletion(BaseLLM):
usage_data = {"total_tokens": 0}
async for chunk in stream:
response_id_stream=chunk.id if hasattr(chunk,"id") else None
response_id_stream = chunk.id if hasattr(chunk, "id") else None
if hasattr(chunk, "usage") and chunk.usage:
usage_data = self._extract_openai_token_usage(chunk)
@@ -2089,7 +2091,7 @@ class OpenAICompletion(BaseLLM):
chunk=chunk_delta.content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream
response_id=response_id_stream,
)
if chunk_delta.tool_calls:
@@ -2128,7 +2130,7 @@ class OpenAICompletion(BaseLLM):
"index": tool_calls[tool_index]["index"],
},
call_type=LLMCallType.TOOL_CALL,
response_id=response_id_stream
response_id=response_id_stream,
)
self._track_token_usage_internal(usage_data)

View File

@@ -2,6 +2,7 @@ import logging
import re
from typing import Any
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.string_utils import sanitize_tool_name
@@ -77,7 +78,8 @@ def extract_tool_info(tool: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
# Also check for args_schema (Pydantic format)
if not parameters and "args_schema" in tool:
if hasattr(tool["args_schema"], "model_json_schema"):
parameters = tool["args_schema"].model_json_schema()
schema_output = generate_model_description(tool["args_schema"])
parameters = schema_output.get("json_schema", {}).get("schema", {})
return name, description, parameters

View File

@@ -28,6 +28,7 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
)
from crewai.utilities.i18n import I18N
from crewai.utilities.printer import ColoredText, Printer
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.string_utils import sanitize_tool_name
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.types import LLMMessage
@@ -36,6 +37,7 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.experimental.agent_executor import AgentExecutor
from crewai.lite_agent import LiteAgent
from crewai.llm import LLM
from crewai.task import Task
@@ -158,7 +160,8 @@ def convert_tools_to_openai_schema(
parameters: dict[str, Any] = {}
if hasattr(tool, "args_schema") and tool.args_schema is not None:
try:
parameters = tool.args_schema.model_json_schema()
schema_output = generate_model_description(tool.args_schema)
parameters = schema_output.get("json_schema", {}).get("schema", {})
# Remove title and description from schema root as they're redundant
parameters.pop("title", None)
parameters.pop("description", None)
@@ -318,7 +321,7 @@ def get_llm_response(
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | LiteAgent | None = None,
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
) -> str | Any:
"""Call the LLM and return the response, handling any invalid responses.
@@ -380,7 +383,7 @@ async def aget_llm_response(
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | None = None,
executor_context: CrewAgentExecutor | AgentExecutor | None = None,
) -> str | Any:
"""Call the LLM asynchronously and return the response.
@@ -900,7 +903,8 @@ def extract_tool_call_info(
def _setup_before_llm_call_hooks(
executor_context: CrewAgentExecutor | LiteAgent | None, printer: Printer
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
printer: Printer,
) -> bool:
"""Setup and invoke before_llm_call hooks for the executor context.
@@ -950,7 +954,7 @@ def _setup_before_llm_call_hooks(
def _setup_after_llm_call_hooks(
executor_context: CrewAgentExecutor | LiteAgent | None,
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
answer: str,
printer: Printer,
) -> str:

View File

@@ -1,14 +1,72 @@
"""Utilities for generating JSON schemas from Pydantic models.
"""Dynamic Pydantic model creation from JSON schemas.
This module provides utilities for converting JSON schemas to Pydantic models at runtime.
The main function is `create_model_from_schema`, which takes a JSON schema and returns
a dynamically created Pydantic model class.
This is used by the A2A server to honor response schemas sent by clients, allowing
structured output from agent tasks.
Based on dydantic (https://github.com/zenbase-ai/dydantic).
This module provides functions for converting Pydantic models to JSON schemas
suitable for use with LLMs and tool definitions.
"""
from __future__ import annotations
from collections.abc import Callable
from copy import deepcopy
from typing import Any
import datetime
import logging
from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
import uuid
from pydantic import BaseModel
from pydantic import (
UUID1,
UUID3,
UUID4,
UUID5,
AnyUrl,
BaseModel,
ConfigDict,
DirectoryPath,
Field,
FilePath,
FileUrl,
HttpUrl,
Json,
MongoDsn,
NewPath,
PostgresDsn,
SecretBytes,
SecretStr,
StrictBytes,
create_model as create_model_base,
)
from pydantic.networks import ( # type: ignore[attr-defined]
IPv4Address,
IPv6Address,
IPvAnyAddress,
IPvAnyInterface,
IPvAnyNetwork,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from pydantic import EmailStr
from pydantic.main import AnyClassMethod
else:
try:
from pydantic import EmailStr
except ImportError:
logger.warning(
"EmailStr unavailable, using str fallback",
extra={"missing_package": "email_validator"},
)
EmailStr = str
def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
@@ -243,3 +301,319 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
"schema": json_schema,
},
}
FORMAT_TYPE_MAP: dict[str, type[Any]] = {
"base64": Annotated[bytes, Field(json_schema_extra={"format": "base64"})], # type: ignore[dict-item]
"binary": StrictBytes,
"date": datetime.date,
"time": datetime.time,
"date-time": datetime.datetime,
"duration": datetime.timedelta,
"directory-path": DirectoryPath,
"email": EmailStr,
"file-path": FilePath,
"ipv4": IPv4Address,
"ipv6": IPv6Address,
"ipvanyaddress": IPvAnyAddress, # type: ignore[dict-item]
"ipvanyinterface": IPvAnyInterface, # type: ignore[dict-item]
"ipvanynetwork": IPvAnyNetwork, # type: ignore[dict-item]
"json-string": Json,
"multi-host-uri": PostgresDsn | MongoDsn, # type: ignore[dict-item]
"password": SecretStr,
"path": NewPath,
"uri": AnyUrl,
"uuid": uuid.UUID,
"uuid1": UUID1,
"uuid3": UUID3,
"uuid4": UUID4,
"uuid5": UUID5,
}
def create_model_from_schema( # type: ignore[no-any-unimported]
json_schema: dict[str, Any],
*,
root_schema: dict[str, Any] | None = None,
__config__: ConfigDict | None = None,
__base__: type[BaseModel] | None = None,
__module__: str = __name__,
__validators__: dict[str, AnyClassMethod] | None = None,
__cls_kwargs__: dict[str, Any] | None = None,
) -> type[BaseModel]:
"""Create a Pydantic model from a JSON schema.
This function takes a JSON schema as input and dynamically creates a Pydantic
model class based on the schema. It supports various JSON schema features such
as nested objects, referenced definitions ($ref), arrays with typed items,
union types (anyOf/oneOf), and string formats.
Args:
json_schema: A dictionary representing the JSON schema.
root_schema: The root schema containing $defs. If not provided, the
current schema is treated as the root schema.
__config__: Pydantic configuration for the generated model.
__base__: Base class for the generated model. Defaults to BaseModel.
__module__: Module name for the generated model class.
__validators__: A dictionary of custom validators for the generated model.
__cls_kwargs__: Additional keyword arguments for the generated model class.
Returns:
A dynamically created Pydantic model class based on the provided JSON schema.
Example:
>>> schema = {
... "title": "Person",
... "type": "object",
... "properties": {
... "name": {"type": "string"},
... "age": {"type": "integer"},
... },
... "required": ["name"],
... }
>>> Person = create_model_from_schema(schema)
>>> person = Person(name="John", age=30)
>>> person.name
'John'
"""
effective_root = root_schema or json_schema
if "allOf" in json_schema:
json_schema = _merge_all_of_schemas(json_schema["allOf"], effective_root)
if "title" not in json_schema and "title" in (root_schema or {}):
json_schema["title"] = (root_schema or {}).get("title")
model_name = json_schema.get("title", "DynamicModel")
field_definitions = {
name: _json_schema_to_pydantic_field(
name, prop, json_schema.get("required", []), effective_root
)
for name, prop in (json_schema.get("properties", {}) or {}).items()
}
return create_model_base(
model_name,
__config__=__config__,
__base__=__base__,
__module__=__module__,
__validators__=__validators__,
__cls_kwargs__=__cls_kwargs__,
**field_definitions,
)
def _json_schema_to_pydantic_field(
name: str,
json_schema: dict[str, Any],
required: list[str],
root_schema: dict[str, Any],
) -> Any:
"""Convert a JSON schema property to a Pydantic field definition.
Args:
name: The field name.
json_schema: The JSON schema for this field.
required: List of required field names.
root_schema: The root schema for resolving $ref.
Returns:
A tuple of (type, Field) for use with create_model.
"""
type_ = _json_schema_to_pydantic_type(json_schema, root_schema, name_=name.title())
description = json_schema.get("description")
examples = json_schema.get("examples")
is_required = name in required
field_params: dict[str, Any] = {}
schema_extra: dict[str, Any] = {}
if description:
field_params["description"] = description
if examples:
schema_extra["examples"] = examples
default = ... if is_required else None
if isinstance(type_, type) and issubclass(type_, (int, float)):
if "minimum" in json_schema:
field_params["ge"] = json_schema["minimum"]
if "exclusiveMinimum" in json_schema:
field_params["gt"] = json_schema["exclusiveMinimum"]
if "maximum" in json_schema:
field_params["le"] = json_schema["maximum"]
if "exclusiveMaximum" in json_schema:
field_params["lt"] = json_schema["exclusiveMaximum"]
if "multipleOf" in json_schema:
field_params["multiple_of"] = json_schema["multipleOf"]
format_ = json_schema.get("format")
if format_ in FORMAT_TYPE_MAP:
pydantic_type = FORMAT_TYPE_MAP[format_]
if format_ == "password":
if json_schema.get("writeOnly"):
pydantic_type = SecretBytes
elif format_ == "uri":
allowed_schemes = json_schema.get("scheme")
if allowed_schemes:
if len(allowed_schemes) == 1 and allowed_schemes[0] == "http":
pydantic_type = HttpUrl
elif len(allowed_schemes) == 1 and allowed_schemes[0] == "file":
pydantic_type = FileUrl
type_ = pydantic_type
if isinstance(type_, type) and issubclass(type_, str):
if "minLength" in json_schema:
field_params["min_length"] = json_schema["minLength"]
if "maxLength" in json_schema:
field_params["max_length"] = json_schema["maxLength"]
if "pattern" in json_schema:
field_params["pattern"] = json_schema["pattern"]
if not is_required:
type_ = type_ | None
if schema_extra:
field_params["json_schema_extra"] = schema_extra
return type_, Field(default, **field_params)
def _resolve_ref(ref: str, root_schema: dict[str, Any]) -> dict[str, Any]:
"""Resolve a $ref to its actual schema.
Args:
ref: The $ref string (e.g., "#/$defs/MyType").
root_schema: The root schema containing $defs.
Returns:
The resolved schema dict.
"""
from typing import cast
ref_path = ref.split("/")
if ref.startswith("#/$defs/"):
ref_schema: dict[str, Any] = root_schema["$defs"]
start_idx = 2
else:
ref_schema = root_schema
start_idx = 1
for path in ref_path[start_idx:]:
ref_schema = cast(dict[str, Any], ref_schema[path])
return ref_schema
def _merge_all_of_schemas(
schemas: list[dict[str, Any]],
root_schema: dict[str, Any],
) -> dict[str, Any]:
"""Merge multiple allOf schemas into a single schema.
Combines properties and required fields from all schemas.
Args:
schemas: List of schemas to merge.
root_schema: The root schema for resolving $ref.
Returns:
Merged schema with combined properties and required fields.
"""
merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for schema in schemas:
if "$ref" in schema:
schema = _resolve_ref(schema["$ref"], root_schema)
if "properties" in schema:
merged["properties"].update(schema["properties"])
if "required" in schema:
for field in schema["required"]:
if field not in merged["required"]:
merged["required"].append(field)
if "title" in schema and "title" not in merged:
merged["title"] = schema["title"]
return merged
def _json_schema_to_pydantic_type(
json_schema: dict[str, Any],
root_schema: dict[str, Any],
*,
name_: str | None = None,
) -> Any:
"""Convert a JSON schema to a Python/Pydantic type.
Args:
json_schema: The JSON schema to convert.
root_schema: The root schema for resolving $ref.
name_: Optional name for nested models.
Returns:
A Python type corresponding to the JSON schema.
"""
ref = json_schema.get("$ref")
if ref:
ref_schema = _resolve_ref(ref, root_schema)
return _json_schema_to_pydantic_type(ref_schema, root_schema, name_=name_)
enum_values = json_schema.get("enum")
if enum_values:
return Literal[tuple(enum_values)]
if "const" in json_schema:
return Literal[json_schema["const"]]
any_of_schemas = []
if "anyOf" in json_schema or "oneOf" in json_schema:
any_of_schemas = json_schema.get("anyOf", []) + json_schema.get("oneOf", [])
if any_of_schemas:
any_of_types = [
_json_schema_to_pydantic_type(schema, root_schema)
for schema in any_of_schemas
]
return Union[tuple(any_of_types)] # noqa: UP007
all_of_schemas = json_schema.get("allOf")
if all_of_schemas:
if len(all_of_schemas) == 1:
return _json_schema_to_pydantic_type(
all_of_schemas[0], root_schema, name_=name_
)
merged = _merge_all_of_schemas(all_of_schemas, root_schema)
return _json_schema_to_pydantic_type(merged, root_schema, name_=name_)
type_ = json_schema.get("type")
if type_ == "string":
return str
if type_ == "integer":
return int
if type_ == "number":
return float
if type_ == "boolean":
return bool
if type_ == "array":
items_schema = json_schema.get("items")
if items_schema:
item_type = _json_schema_to_pydantic_type(
items_schema, root_schema, name_=name_
)
return list[item_type] # type: ignore[valid-type]
return list
if type_ == "object":
properties = json_schema.get("properties")
if properties:
json_schema_ = json_schema.copy()
if json_schema_.get("title") is None:
json_schema_["title"] = name_
return create_model_from_schema(json_schema_, root_schema=root_schema)
return dict
if type_ == "null":
return None
if type_ is None:
return Any
raise ValueError(f"Unsupported JSON schema type: {type_} from {json_schema}")
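A short usage sketch of the new helper with a $defs reference, a nullable anyOf field, and a string format (the schema contents are illustrative):

from crewai.utilities.pydantic_schema_utils import create_model_from_schema

schema = {
    "title": "Ticket",
    "type": "object",
    "properties": {
        "id": {"type": "string", "format": "uuid"},
        "assignee": {"anyOf": [{"$ref": "#/$defs/User"}, {"type": "null"}]},
        "tags": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["id"],
    "$defs": {
        "User": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
    },
}

Ticket = create_model_from_schema(schema)
# "id" is coerced to uuid.UUID via FORMAT_TYPE_MAP; "assignee" and "tags" stay optional.
ticket = Ticket(id="123e4567-e89b-12d3-a456-426614174000", tags=["bug"])
print(ticket.id, ticket.assignee)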

View File

@@ -26,4 +26,5 @@ class LLMMessage(TypedDict):
tool_call_id: NotRequired[str]
name: NotRequired[str]
tool_calls: NotRequired[list[dict[str, Any]]]
raw_tool_call_parts: NotRequired[list[Any]]
files: NotRequired[dict[str, FileInput]]