fix: correct tool-calling content handling and schema serialization
- fix(gemini): prevent tool calls from using stale text content; correct key refs
- fix(agent-executor): resolve type errors
- refactor(schema): extract Pydantic schema utilities from platform tools
- fix(schema): properly serialize schemas and ensure Responses API uses a separate structure
- fix: preserve list identity to avoid mutation/aliasing issues
- chore(tests): update assumptions to match new behavior
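
For context, a minimal sketch of the extracted utility in use; the schema and field names below are illustrative, not taken from this commit:

from crewai.utilities.pydantic_schema_utils import create_model_from_schema

# A platform action's parameter schema, converted into a Pydantic args model.
schema = {
    "title": "SendEmailSchema",
    "type": "object",
    "properties": {
        "to": {"type": "string", "description": "Recipient address"},
        "subject": {"type": "string", "description": "Subject line"},
    },
    "required": ["to"],
}
SendEmailSchema = create_model_from_schema(schema)
args = SendEmailSchema(to="a@example.com")  # "subject" is optional -> None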
@@ -1,10 +1,11 @@
"""Crewai Enterprise Tools."""
import os

import json
import re
from typing import Any, Optional, Union, cast, get_origin
import os
from typing import Any

from crewai.tools import BaseTool
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
from pydantic import Field, create_model
import requests

@@ -14,77 +15,6 @@ from crewai_tools.tools.crewai_platform_tools.misc import (
)


class AllOfSchemaAnalyzer:
    """Helper class to analyze and merge allOf schemas."""

    def __init__(self, schemas: list[dict[str, Any]]):
        self.schemas = schemas
        self._explicit_types: list[str] = []
        self._merged_properties: dict[str, Any] = {}
        self._merged_required: list[str] = []
        self._analyze_schemas()

    def _analyze_schemas(self) -> None:
        """Analyze all schemas and extract relevant information."""
        for schema in self.schemas:
            if "type" in schema:
                self._explicit_types.append(schema["type"])

            # Merge object properties
            if schema.get("type") == "object" and "properties" in schema:
                self._merged_properties.update(schema["properties"])
                if "required" in schema:
                    self._merged_required.extend(schema["required"])

    def has_consistent_type(self) -> bool:
        """Check if all schemas have the same explicit type."""
        return len(set(self._explicit_types)) == 1 if self._explicit_types else False

    def get_consistent_type(self) -> type[Any]:
        """Get the consistent type if all schemas agree."""
        if not self.has_consistent_type():
            raise ValueError("No consistent type found")

        type_mapping = {
            "string": str,
            "integer": int,
            "number": float,
            "boolean": bool,
            "array": list,
            "object": dict,
            "null": type(None),
        }
        return type_mapping.get(self._explicit_types[0], str)

    def has_object_schemas(self) -> bool:
        """Check if any schemas are object types with properties."""
        return bool(self._merged_properties)

    def get_merged_properties(self) -> dict[str, Any]:
        """Get merged properties from all object schemas."""
        return self._merged_properties

    def get_merged_required_fields(self) -> list[str]:
        """Get merged required fields from all object schemas."""
        return list(set(self._merged_required))  # Remove duplicates

    def get_fallback_type(self) -> type[Any]:
        """Get a fallback type when merging fails."""
        if self._explicit_types:
            # Use the first explicit type
            type_mapping = {
                "string": str,
                "integer": int,
                "number": float,
                "boolean": bool,
                "array": list,
                "object": dict,
                "null": type(None),
            }
            return type_mapping.get(self._explicit_types[0], str)
        return str


class CrewAIPlatformActionTool(BaseTool):
    action_name: str = Field(default="", description="The name of the action")
    action_schema: dict[str, Any] = Field(
@@ -97,42 +27,19 @@ class CrewAIPlatformActionTool(BaseTool):
        action_name: str,
        action_schema: dict[str, Any],
    ):
        self._model_registry: dict[str, type[Any]] = {}
        self._base_name = self._sanitize_name(action_name)

        schema_props, required = self._extract_schema_info(action_schema)

        field_definitions: dict[str, Any] = {}
        for param_name, param_details in schema_props.items():
            param_desc = param_details.get("description", "")
            is_required = param_name in required
        parameters = action_schema.get("function", {}).get("parameters", {})

        if parameters and parameters.get("properties"):
            try:
                field_type = self._process_schema_type(
                    param_details, self._sanitize_name(param_name).title()
                )
                if "title" not in parameters:
                    parameters = {**parameters, "title": f"{action_name}Schema"}
                if "type" not in parameters:
                    parameters = {**parameters, "type": "object"}
                args_schema = create_model_from_schema(parameters)
            except Exception:
                field_type = str

            field_definitions[param_name] = self._create_field_definition(
                field_type, is_required, param_desc
            )

        if field_definitions:
            try:
                args_schema = create_model(
                    f"{self._base_name}Schema", **field_definitions
                )
            except Exception:
                args_schema = create_model(
                    f"{self._base_name}Schema",
                    input_text=(str, Field(description="Input for the action")),
                )
                args_schema = create_model(f"{action_name}Schema")
        else:
            args_schema = create_model(
                f"{self._base_name}Schema",
                input_text=(str, Field(description="Input for the action")),
            )
            args_schema = create_model(f"{action_name}Schema")

        super().__init__(
            name=action_name.lower().replace(" ", "_"),
@@ -142,285 +49,12 @@ class CrewAIPlatformActionTool(BaseTool):
        self.action_name = action_name
        self.action_schema = action_schema

    @staticmethod
    def _sanitize_name(name: str) -> str:
        name = name.lower().replace(" ", "_")
        sanitized = re.sub(r"[^a-zA-Z0-9_]", "", name)
        parts = sanitized.split("_")
        return "".join(word.capitalize() for word in parts if word)

    @staticmethod
    def _extract_schema_info(
        action_schema: dict[str, Any],
    ) -> tuple[dict[str, Any], list[str]]:
        schema_props = (
            action_schema.get("function", {})
            .get("parameters", {})
            .get("properties", {})
        )
        required = (
            action_schema.get("function", {}).get("parameters", {}).get("required", [])
        )
        return schema_props, required

    def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]:
        """
        Process a JSON Schema type definition into a Python type.

        Handles complex schema constructs like anyOf, oneOf, allOf, enums, arrays, and objects.
        """
        # Handle composite schema types (anyOf, oneOf, allOf)
        if composite_type := self._process_composite_schema(schema, type_name):
            return composite_type

        # Handle primitive types and simple constructs
        return self._process_primitive_schema(schema, type_name)

    def _process_composite_schema(
        self, schema: dict[str, Any], type_name: str
    ) -> type[Any] | None:
        """Process composite schema types: anyOf, oneOf, allOf."""
        if "anyOf" in schema:
            return self._process_any_of_schema(schema["anyOf"], type_name)
        if "oneOf" in schema:
            return self._process_one_of_schema(schema["oneOf"], type_name)
        if "allOf" in schema:
            return self._process_all_of_schema(schema["allOf"], type_name)
        return None

    def _process_any_of_schema(
        self, any_of_types: list[dict[str, Any]], type_name: str
    ) -> type[Any]:
        """Process anyOf schema - creates Union of possible types."""
        is_nullable = any(t.get("type") == "null" for t in any_of_types)
        non_null_types = [t for t in any_of_types if t.get("type") != "null"]

        if not non_null_types:
            return cast(
                type[Any], cast(object, str | None)
            )  # fallback for only-null case

        base_type = (
            self._process_schema_type(non_null_types[0], type_name)
            if len(non_null_types) == 1
            else self._create_union_type(non_null_types, type_name, "AnyOf")
        )
        return base_type | None if is_nullable else base_type  # type: ignore[return-value]

    def _process_one_of_schema(
        self, one_of_types: list[dict[str, Any]], type_name: str
    ) -> type[Any]:
        """Process oneOf schema - creates Union of mutually exclusive types."""
        return (
            self._process_schema_type(one_of_types[0], type_name)
            if len(one_of_types) == 1
            else self._create_union_type(one_of_types, type_name, "OneOf")
        )

    def _process_all_of_schema(
        self, all_of_schemas: list[dict[str, Any]], type_name: str
    ) -> type[Any]:
        """Process allOf schema - merges schemas that must all be satisfied."""
        if len(all_of_schemas) == 1:
            return self._process_schema_type(all_of_schemas[0], type_name)
        return self._merge_all_of_schemas(all_of_schemas, type_name)

    def _create_union_type(
        self, schemas: list[dict[str, Any]], type_name: str, prefix: str
    ) -> type[Any]:
        """Create a Union type from multiple schemas."""
        return Union[  # type: ignore # noqa: UP007
            tuple(
                self._process_schema_type(schema, f"{type_name}{prefix}{i}")
                for i, schema in enumerate(schemas)
            )
        ]

    def _process_primitive_schema(
        self, schema: dict[str, Any], type_name: str
    ) -> type[Any]:
        """Process primitive schema types: string, number, array, object, etc."""
        json_type = schema.get("type", "string")

        if "enum" in schema:
            return self._process_enum_schema(schema, json_type)

        if json_type == "array":
            return self._process_array_schema(schema, type_name)

        if json_type == "object":
            return self._create_nested_model(schema, type_name)

        return self._map_json_type_to_python(json_type)

    def _process_enum_schema(self, schema: dict[str, Any], json_type: str) -> type[Any]:
        """Process enum schema - currently falls back to base type."""
        enum_values = schema["enum"]
        if not enum_values:
            return self._map_json_type_to_python(json_type)

        # For Literal types, we need to pass the values directly, not as a tuple
        # This is a workaround since we can't dynamically create Literal types easily
        # Fall back to the base JSON type for now
        return self._map_json_type_to_python(json_type)

    def _process_array_schema(
        self, schema: dict[str, Any], type_name: str
    ) -> type[Any]:
        items_schema = schema.get("items", {"type": "string"})
        item_type = self._process_schema_type(items_schema, f"{type_name}Item")
        return list[item_type]  # type: ignore

    def _merge_all_of_schemas(
        self, schemas: list[dict[str, Any]], type_name: str
    ) -> type[Any]:
        schema_analyzer = AllOfSchemaAnalyzer(schemas)

        if schema_analyzer.has_consistent_type():
            return schema_analyzer.get_consistent_type()

        if schema_analyzer.has_object_schemas():
            return self._create_merged_object_model(
                schema_analyzer.get_merged_properties(),
                schema_analyzer.get_merged_required_fields(),
                type_name,
            )

        return schema_analyzer.get_fallback_type()

    def _create_merged_object_model(
        self, properties: dict[str, Any], required: list[str], model_name: str
    ) -> type[Any]:
        full_model_name = f"{self._base_name}{model_name}AllOf"

        if full_model_name in self._model_registry:
            return self._model_registry[full_model_name]

        if not properties:
            return dict

        field_definitions = self._build_field_definitions(
            properties, required, model_name
        )

        try:
            merged_model = create_model(full_model_name, **field_definitions)
            self._model_registry[full_model_name] = merged_model
            return merged_model
        except Exception:
            return dict

    def _build_field_definitions(
        self, properties: dict[str, Any], required: list[str], model_name: str
    ) -> dict[str, Any]:
        field_definitions = {}

        for prop_name, prop_schema in properties.items():
            prop_desc = prop_schema.get("description", "")
            is_required = prop_name in required

            try:
                prop_type = self._process_schema_type(
                    prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
                )
            except Exception:
                prop_type = str

            field_definitions[prop_name] = self._create_field_definition(
                prop_type, is_required, prop_desc
            )

        return field_definitions

    def _create_nested_model(
        self, schema: dict[str, Any], model_name: str
    ) -> type[Any]:
        full_model_name = f"{self._base_name}{model_name}"

        if full_model_name in self._model_registry:
            return self._model_registry[full_model_name]

        properties = schema.get("properties", {})
        required_fields = schema.get("required", [])

        if not properties:
            return dict

        field_definitions = {}
        for prop_name, prop_schema in properties.items():
            prop_desc = prop_schema.get("description", "")
            is_required = prop_name in required_fields

            try:
                prop_type = self._process_schema_type(
                    prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
                )
            except Exception:
                prop_type = str

            field_definitions[prop_name] = self._create_field_definition(
                prop_type, is_required, prop_desc
            )

        try:
            nested_model = create_model(full_model_name, **field_definitions)  # type: ignore
            self._model_registry[full_model_name] = nested_model
            return nested_model
        except Exception:
            return dict

    def _create_field_definition(
        self, field_type: type[Any], is_required: bool, description: str
    ) -> tuple:
        if is_required:
            return (field_type, Field(description=description))
        if get_origin(field_type) is Union:
            return (field_type, Field(default=None, description=description))
        return (
            Optional[field_type],  # noqa: UP045
            Field(default=None, description=description),
        )

    def _map_json_type_to_python(self, json_type: str) -> type[Any]:
        type_mapping = {
            "string": str,
            "integer": int,
            "number": float,
            "boolean": bool,
            "array": list,
            "object": dict,
            "null": type(None),
        }
        return type_mapping.get(json_type, str)

    def _get_required_nullable_fields(self) -> list[str]:
        schema_props, required = self._extract_schema_info(self.action_schema)

        required_nullable_fields = []
        for param_name in required:
            param_details = schema_props.get(param_name, {})
            if self._is_nullable_type(param_details):
                required_nullable_fields.append(param_name)

        return required_nullable_fields

    def _is_nullable_type(self, schema: dict[str, Any]) -> bool:
        if "anyOf" in schema:
            return any(t.get("type") == "null" for t in schema["anyOf"])
        return schema.get("type") == "null"

    def _run(self, **kwargs) -> str:
    def _run(self, **kwargs: Any) -> str:
        try:
            cleaned_kwargs = {
                key: value for key, value in kwargs.items() if value is not None
            }

            required_nullable_fields = self._get_required_nullable_fields()

            for field_name in required_nullable_fields:
                if field_name not in cleaned_kwargs:
                    cleaned_kwargs[field_name] = None

            api_url = (
                f"{get_platform_api_base_url()}/actions/{self.action_name}/execute"
            )
@@ -429,7 +63,9 @@ class CrewAIPlatformActionTool(BaseTool):
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            }
            payload = cleaned_kwargs
            payload = {
                "integration": cleaned_kwargs if cleaned_kwargs else {"_noop": True}
            }

            response = requests.post(
                url=api_url,
@@ -441,7 +77,14 @@ class CrewAIPlatformActionTool(BaseTool):

            data = response.json()
            if not response.ok:
                error_message = data.get("error", {}).get("message", json.dumps(data))
                if isinstance(data, dict):
                    error_info = data.get("error", {})
                    if isinstance(error_info, dict):
                        error_message = error_info.get("message", json.dumps(data))
                    else:
                        error_message = str(error_info)
                else:
                    error_message = str(data)
                return f"API request failed: {error_message}"

            return json.dumps(data, indent=2)
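
A minimal sketch of the request/error handling above, assuming only that the endpoint returns JSON which may be either an object or a bare value:

import json

def build_payload(cleaned_kwargs: dict) -> dict:
    # Arguments are now nested under "integration"; an empty argument set is
    # replaced by a {"_noop": True} marker so the body is never empty.
    return {"integration": cleaned_kwargs if cleaned_kwargs else {"_noop": True}}

def extract_error(data) -> str:
    # Mirrors the defensive parsing above: the response may not be a dict,
    # and "error" may be a plain string rather than an object.
    if isinstance(data, dict):
        error_info = data.get("error", {})
        if isinstance(error_info, dict):
            return error_info.get("message", json.dumps(data))
        return str(error_info)
    return str(data)

assert build_payload({}) == {"integration": {"_noop": True}}
assert extract_error({"error": "boom"}) == "boom"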

@@ -1,5 +1,10 @@
from typing import Any
"""CrewAI platform tool builder for fetching and creating action tools."""

import logging
import os
from types import TracebackType
from typing import Any

from crewai.tools import BaseTool
import requests

@@ -12,22 +17,29 @@ from crewai_tools.tools.crewai_platform_tools.misc import (
)


logger = logging.getLogger(__name__)


class CrewaiPlatformToolBuilder:
    """Builds platform tools from remote action schemas."""

    def __init__(
        self,
        apps: list[str],
    ):
    ) -> None:
        self._apps = apps
        self._actions_schema = {}  # type: ignore[var-annotated]
        self._tools = None
        self._actions_schema: dict[str, dict[str, Any]] = {}
        self._tools: list[BaseTool] | None = None

    def tools(self) -> list[BaseTool]:
        """Fetch actions and return built tools."""
        if self._tools is None:
            self._fetch_actions()
            self._create_tools()
        return self._tools if self._tools is not None else []

    def _fetch_actions(self):
    def _fetch_actions(self) -> None:
        """Fetch action schemas from the platform API."""
        actions_url = f"{get_platform_api_base_url()}/actions"
        headers = {"Authorization": f"Bearer {get_platform_integration_token()}"}

@@ -40,7 +52,8 @@ class CrewaiPlatformToolBuilder:
                verify=os.environ.get("CREWAI_FACTORY", "false").lower() != "true",
            )
            response.raise_for_status()
        except Exception:
        except Exception as e:
            logger.error(f"Failed to fetch platform tools for apps {self._apps}: {e}")
            return

        raw_data = response.json()
@@ -51,6 +64,8 @@ class CrewaiPlatformToolBuilder:
        for app, action_list in action_categories.items():
            if isinstance(action_list, list):
                for action in action_list:
                    if not isinstance(action, dict):
                        continue
                    if action_name := action.get("name"):
                        action_schema = {
                            "function": {
@@ -64,72 +79,16 @@ }
                            }
                        self._actions_schema[action_name] = action_schema

    def _generate_detailed_description(
        self, schema: dict[str, Any], indent: int = 0
    ) -> list[str]:
        descriptions = []
        indent_str = "  " * indent

        schema_type = schema.get("type", "string")

        if schema_type == "object":
            properties = schema.get("properties", {})
            required_fields = schema.get("required", [])

            if properties:
                descriptions.append(f"{indent_str}Object with properties:")
                for prop_name, prop_schema in properties.items():
                    prop_desc = prop_schema.get("description", "")
                    is_required = prop_name in required_fields
                    req_str = " (required)" if is_required else " (optional)"
                    descriptions.append(
                        f"{indent_str}  - {prop_name}: {prop_desc}{req_str}"
                    )

                    if prop_schema.get("type") == "object":
                        descriptions.extend(
                            self._generate_detailed_description(prop_schema, indent + 2)
                        )
                    elif prop_schema.get("type") == "array":
                        items_schema = prop_schema.get("items", {})
                        if items_schema.get("type") == "object":
                            descriptions.append(f"{indent_str}  Array of objects:")
                            descriptions.extend(
                                self._generate_detailed_description(
                                    items_schema, indent + 3
                                )
                            )
                        elif "enum" in items_schema:
                            descriptions.append(
                                f"{indent_str}  Array of enum values: {items_schema['enum']}"
                            )
                    elif "enum" in prop_schema:
                        descriptions.append(
                            f"{indent_str}  Enum values: {prop_schema['enum']}"
                        )

        return descriptions
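
    # Illustrative sketch, not part of the diff: hand-derived output of
    # _generate_detailed_description for a small schema, following the
    # branches above (exact spacing approximate):
    #
    #   {"type": "object",
    #    "properties": {"query": {"type": "string", "description": "Search text"},
    #                   "limit": {"type": "integer", "description": "Max results"}},
    #    "required": ["query"]}
    #
    # yields roughly:
    #
    #   ["Object with properties:",
    #    "  - query: Search text (required)",
    #    "  - limit: Max results (optional)"]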

    def _create_tools(self):
        tools = []
    def _create_tools(self) -> None:
        """Create tool instances from fetched action schemas."""
        tools: list[BaseTool] = []

        for action_name, action_schema in self._actions_schema.items():
            function_details = action_schema.get("function", {})
            description = function_details.get("description", f"Execute {action_name}")

            parameters = function_details.get("parameters", {})
            param_descriptions = []

            if parameters.get("properties"):
                param_descriptions.append("\nDetailed Parameter Structure:")
                param_descriptions.extend(
                    self._generate_detailed_description(parameters)
                )

            full_description = description + "\n".join(param_descriptions)

            tool = CrewAIPlatformActionTool(
                description=full_description,
                description=description,
                action_name=action_name,
                action_schema=action_schema,
            )
@@ -138,8 +97,14 @@ class CrewaiPlatformToolBuilder:

        self._tools = tools

    def __enter__(self):
    def __enter__(self) -> list[BaseTool]:
        """Enter context manager and return tools."""
        return self.tools()

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit context manager."""

@@ -1,4 +1,3 @@
from typing import Union, get_args, get_origin
from unittest.mock import patch, Mock
import os

@@ -7,251 +6,6 @@ from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import
)


class TestSchemaProcessing:

    def setup_method(self):
        self.base_action_schema = {
            "function": {
                "parameters": {
                    "properties": {},
                    "required": []
                }
            }
        }

    def create_test_tool(self, action_name="test_action"):
        return CrewAIPlatformActionTool(
            description="Test tool",
            action_name=action_name,
            action_schema=self.base_action_schema
        )

    def test_anyof_multiple_types(self):
        tool = self.create_test_tool()

        test_schema = {
            "anyOf": [
                {"type": "string"},
                {"type": "number"},
                {"type": "integer"}
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestField")

        assert get_origin(result_type) is Union

        args = get_args(result_type)
        expected_types = (str, float, int)

        for expected_type in expected_types:
            assert expected_type in args

    def test_anyof_with_null(self):
        tool = self.create_test_tool()

        test_schema = {
            "anyOf": [
                {"type": "string"},
                {"type": "number"},
                {"type": "null"}
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldNullable")

        assert get_origin(result_type) is Union

        args = get_args(result_type)
        assert type(None) in args
        assert str in args
        assert float in args

    def test_anyof_single_type(self):
        tool = self.create_test_tool()

        test_schema = {
            "anyOf": [
                {"type": "string"}
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldSingle")

        assert result_type is str

    def test_oneof_multiple_types(self):
        tool = self.create_test_tool()

        test_schema = {
            "oneOf": [
                {"type": "string"},
                {"type": "boolean"}
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldOneOf")

        assert get_origin(result_type) is Union

        args = get_args(result_type)
        expected_types = (str, bool)

        for expected_type in expected_types:
            assert expected_type in args

    def test_oneof_single_type(self):
        tool = self.create_test_tool()

        test_schema = {
            "oneOf": [
                {"type": "integer"}
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldOneOfSingle")

        assert result_type is int

    def test_basic_types(self):
        tool = self.create_test_tool()

        test_cases = [
            ({"type": "string"}, str),
            ({"type": "integer"}, int),
            ({"type": "number"}, float),
            ({"type": "boolean"}, bool),
            ({"type": "array", "items": {"type": "string"}}, list),
        ]

        for schema, expected_type in test_cases:
            result_type = tool._process_schema_type(schema, "TestField")
            if schema["type"] == "array":
                assert get_origin(result_type) is list
            else:
                assert result_type is expected_type

    def test_enum_handling(self):
        tool = self.create_test_tool()

        test_schema = {
            "type": "string",
            "enum": ["option1", "option2", "option3"]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldEnum")

        assert result_type is str

    def test_nested_anyof(self):
        tool = self.create_test_tool()

        test_schema = {
            "anyOf": [
                {"type": "string"},
                {
                    "anyOf": [
                        {"type": "integer"},
                        {"type": "boolean"}
                    ]
                }
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldNested")

        assert get_origin(result_type) is Union
        args = get_args(result_type)

        assert str in args

        if len(args) == 3:
            assert int in args
            assert bool in args
        else:
            nested_union = next(arg for arg in args if get_origin(arg) is Union)
            nested_args = get_args(nested_union)
            assert int in nested_args
            assert bool in nested_args

    def test_allof_same_types(self):
        tool = self.create_test_tool()

        test_schema = {
            "allOf": [
                {"type": "string"},
                {"type": "string", "maxLength": 100}
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldAllOfSame")

        assert result_type is str

    def test_allof_object_merge(self):
        tool = self.create_test_tool()

        test_schema = {
            "allOf": [
                {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "age": {"type": "integer"}
                    },
                    "required": ["name"]
                },
                {
                    "type": "object",
                    "properties": {
                        "email": {"type": "string"},
                        "age": {"type": "integer"}
                    },
                    "required": ["email"]
                }
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldAllOfMerged")

        # Should create a merged model with all properties
        # The implementation might fall back to dict if model creation fails
        # Let's just verify it's not a basic scalar type
        assert result_type is not str
        assert result_type is not int
        assert result_type is not bool
        # It could be dict (fallback) or a proper model class
        assert result_type in (dict, type) or hasattr(result_type, '__name__')

    def test_allof_single_schema(self):
        """Test that allOf with single schema works correctly."""
        tool = self.create_test_tool()

        test_schema = {
            "allOf": [
                {"type": "boolean"}
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldAllOfSingle")

        # Should be just bool
        assert result_type is bool

    def test_allof_mixed_types(self):
        tool = self.create_test_tool()

        test_schema = {
            "allOf": [
                {"type": "string"},
                {"type": "integer"}
            ]
        }

        result_type = tool._process_schema_type(test_schema, "TestFieldAllOfMixed")

        assert result_type is str

class TestCrewAIPlatformActionToolVerify:
    """Test suite for SSL verification behavior based on CREWAI_FACTORY environment variable"""


@@ -224,43 +224,6 @@ class TestCrewaiPlatformToolBuilder(unittest.TestCase):
        _, kwargs = mock_get.call_args
        assert kwargs["params"]["apps"] == ""

    def test_detailed_description_generation(self):
        builder = CrewaiPlatformToolBuilder(apps=["test"])

        complex_schema = {
            "type": "object",
            "properties": {
                "simple_string": {"type": "string", "description": "A simple string"},
                "nested_object": {
                    "type": "object",
                    "properties": {
                        "inner_prop": {
                            "type": "integer",
                            "description": "Inner property",
                        }
                    },
                    "description": "Nested object",
                },
                "array_prop": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Array of strings",
                },
            },
        }

        descriptions = builder._generate_detailed_description(complex_schema)

        assert isinstance(descriptions, list)
        assert len(descriptions) > 0

        description_text = "\n".join(descriptions)
        assert "simple_string" in description_text
        assert "nested_object" in description_text
        assert "array_prop" in description_text


class TestCrewaiPlatformToolBuilderVerify(unittest.TestCase):
    """Test suite for SSL verification behavior in CrewaiPlatformToolBuilder"""

@@ -36,6 +36,7 @@ from crewai.hooks.llm_hooks import (
    get_after_llm_call_hooks,
    get_before_llm_call_hooks,
)
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
from crewai.utilities.agent_utils import (
    convert_tools_to_openai_schema,
    enforce_rpm_limit,
@@ -185,8 +186,8 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):

        self._instance_id = str(uuid4())[:8]

        self.before_llm_call_hooks: list[Callable] = []
        self.after_llm_call_hooks: list[Callable] = []
        self.before_llm_call_hooks: list[BeforeLLMCallHookType] = []
        self.after_llm_call_hooks: list[AfterLLMCallHookType] = []
        self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
        self.after_llm_call_hooks.extend(get_after_llm_call_hooks())

@@ -299,11 +300,21 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
        """Compatibility property for mixin - returns state messages."""
        return self._state.messages

    @messages.setter
    def messages(self, value: list[LLMMessage]) -> None:
        """Set state messages."""
        self._state.messages = value

    @property
    def iterations(self) -> int:
        """Compatibility property for mixin - returns state iterations."""
        return self._state.iterations

    @iterations.setter
    def iterations(self, value: int) -> None:
        """Set state iterations."""
        self._state.iterations = value

    @start()
    def initialize_reasoning(self) -> Literal["initialized"]:
        """Initialize the reasoning flow and emit agent start logs."""
@@ -577,6 +588,12 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
                "content": None,
                "tool_calls": tool_calls_to_report,
            }
            if all(
                type(tc).__qualname__ == "Part" for tc in self.state.pending_tool_calls
            ):
                assistant_message["raw_tool_call_parts"] = list(
                    self.state.pending_tool_calls
                )
            self.state.messages.append(assistant_message)

            # Now execute each tool
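
# Illustrative aside, not part of the diff: why the raw Parts are copied with
# list(...) above. The copy breaks aliasing with state.pending_tool_calls, so
# clearing the pending list later cannot rewrite the stored message history.
pending = ["part_a", "part_b"]  # stand-ins for Gemini types.Part objects
assistant_message = {"raw_tool_call_parts": list(pending)}  # snapshot, not an alias
pending.clear()
assert assistant_message["raw_tool_call_parts"] == ["part_a", "part_b"]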
@@ -611,14 +628,12 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):

            # Check if tool has reached max usage count
            max_usage_reached = False
            if original_tool:
                if (
                    hasattr(original_tool, "max_usage_count")
                    and original_tool.max_usage_count is not None
                    and original_tool.current_usage_count
                    >= original_tool.max_usage_count
                ):
                    max_usage_reached = True
            if (
                original_tool
                and original_tool.max_usage_count is not None
                and original_tool.current_usage_count >= original_tool.max_usage_count
            ):
                max_usage_reached = True

            # Check cache before executing
            from_cache = False
@@ -661,11 +676,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
                # Add to cache after successful execution (before string conversion)
                if self.tools_handler and self.tools_handler.cache:
                    should_cache = True
                    if (
                        original_tool
                        and hasattr(original_tool, "cache_function")
                        and original_tool.cache_function
                    ):
                    if original_tool:
                        should_cache = original_tool.cache_function(
                            args_dict, raw_result
                        )
@@ -696,7 +707,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
                        error=e,
                    ),
                )
            elif max_usage_reached:
            elif max_usage_reached and original_tool:
                # Return error message when max usage limit is reached
                result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."

@@ -833,6 +844,10 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
    @listen("parser_error")
    def recover_from_parser_error(self) -> Literal["initialized"]:
        """Recover from output parser errors and retry."""
        if not self._last_parser_error:
            self.state.iterations += 1
            return "initialized"

        formatted_answer = handle_output_parser_exception(
            e=self._last_parser_error,
            messages=list(self.state.messages),

@@ -9,6 +9,7 @@ from crewai.utilities.printer import Printer

if TYPE_CHECKING:
    from crewai.agents.crew_agent_executor import CrewAgentExecutor
    from crewai.experimental.agent_executor import AgentExecutor
    from crewai.lite_agent import LiteAgent
    from crewai.llms.base_llm import BaseLLM
    from crewai.utilities.types import LLMMessage
@@ -41,7 +42,7 @@ class LLMCallHookContext:
        Can be modified by returning a new string from after_llm_call hook.
    """

    executor: CrewAgentExecutor | LiteAgent | None
    executor: CrewAgentExecutor | AgentExecutor | LiteAgent | None
    messages: list[LLMMessage]
    agent: Any
    task: Any
@@ -52,7 +53,7 @@ class LLMCallHookContext:

    def __init__(
        self,
        executor: CrewAgentExecutor | LiteAgent | None = None,
        executor: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
        response: str | None = None,
        messages: list[LLMMessage] | None = None,
        llm: BaseLLM | str | Any | None = None,  # TODO: look into

@@ -16,6 +16,7 @@ from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.types import LLMMessage


@@ -548,7 +549,11 @@ class BedrockCompletion(BaseLLM):
                "toolSpec": {
                    "name": "structured_output",
                    "description": "Returns structured data according to the schema",
                    "inputSchema": {"json": response_model.model_json_schema()},
                    "inputSchema": {
                        "json": generate_model_description(response_model)
                        .get("json_schema", {})
                        .get("schema", {})
                    },
                }
            }
            body["toolConfig"] = cast(
@@ -779,7 +784,11 @@ class BedrockCompletion(BaseLLM):
                "toolSpec": {
                    "name": "structured_output",
                    "description": "Returns structured data according to the schema",
                    "inputSchema": {"json": response_model.model_json_schema()},
                    "inputSchema": {
                        "json": generate_model_description(response_model)
                        .get("json_schema", {})
                        .get("schema", {})
                    },
                }
            }
            body["toolConfig"] = cast(
@@ -1011,7 +1020,11 @@ class BedrockCompletion(BaseLLM):
                "toolSpec": {
                    "name": "structured_output",
                    "description": "Returns structured data according to the schema",
                    "inputSchema": {"json": response_model.model_json_schema()},
                    "inputSchema": {
                        "json": generate_model_description(response_model)
                        .get("json_schema", {})
                        .get("schema", {})
                    },
                }
            }
            body["toolConfig"] = cast(
@@ -1223,7 +1236,11 @@ class BedrockCompletion(BaseLLM):
                "toolSpec": {
                    "name": "structured_output",
                    "description": "Returns structured data according to the schema",
                    "inputSchema": {"json": response_model.model_json_schema()},
                    "inputSchema": {
                        "json": generate_model_description(response_model)
                        .get("json_schema", {})
                        .get("schema", {})
                    },
                }
            }
            body["toolConfig"] = cast(

@@ -15,6 +15,7 @@ from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.types import LLMMessage


@@ -464,7 +465,10 @@ class GeminiCompletion(BaseLLM):

        if response_model:
            config_params["response_mime_type"] = "application/json"
            config_params["response_schema"] = response_model.model_json_schema()
            schema_output = generate_model_description(response_model)
            config_params["response_schema"] = schema_output.get("json_schema", {}).get(
                "schema", {}
            )

        # Handle tools for supported models
        if tools and self.supports_tools:
@@ -489,7 +493,7 @@ class GeminiCompletion(BaseLLM):
            function_declaration = types.FunctionDeclaration(
                name=name,
                description=description,
                parameters=parameters if parameters else None,
                parameters_json_schema=parameters if parameters else None,
            )

            gemini_tool = types.Tool(function_declarations=[function_declaration])
@@ -543,11 +547,10 @@ class GeminiCompletion(BaseLLM):
        else:
            parts.append(types.Part.from_text(text=str(content) if content else ""))

        text_content: str = " ".join(p.text for p in parts if p.text is not None)

        if role == "system":
            # Extract system instruction - Gemini handles it separately
            text_content = " ".join(
                p.text for p in parts if hasattr(p, "text") and p.text
            )
            if system_instruction:
                system_instruction += f"\n\n{text_content}"
            else:
@@ -576,31 +579,40 @@ class GeminiCompletion(BaseLLM):
                types.Content(role="user", parts=[function_response_part])
            )
        elif role == "assistant" and message.get("tool_calls"):
            tool_parts: list[types.Part] = []
            raw_parts: list[Any] | None = message.get("raw_tool_call_parts")
            if raw_parts and all(isinstance(p, types.Part) for p in raw_parts):
                tool_parts: list[types.Part] = list(raw_parts)
                if text_content:
                    tool_parts.insert(0, types.Part.from_text(text=text_content))
            else:
                tool_parts = []
                if text_content:
                    tool_parts.append(types.Part.from_text(text=text_content))

            if text_content:
                tool_parts.append(types.Part.from_text(text=text_content))
                tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
                for tool_call in tool_calls:
                    func: dict[str, Any] = tool_call.get("function") or {}
                    func_name: str = str(func.get("name") or "")
                    func_args_raw: str | dict[str, Any] = (
                        func.get("arguments") or {}
                    )

                tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
                for tool_call in tool_calls:
                    func: dict[str, Any] = tool_call.get("function") or {}
                    func_name: str = str(func.get("name") or "")
                    func_args_raw: str | dict[str, Any] = func.get("arguments") or {}
                    func_args: dict[str, Any]
                    if isinstance(func_args_raw, str):
                        try:
                            func_args = (
                                json.loads(func_args_raw) if func_args_raw else {}
                            )
                        except (json.JSONDecodeError, TypeError):
                            func_args = {}
                    else:
                        func_args = func_args_raw

                    func_args: dict[str, Any]
                    if isinstance(func_args_raw, str):
                        try:
                            func_args = (
                                json.loads(func_args_raw) if func_args_raw else {}
                    tool_parts.append(
                        types.Part.from_function_call(
                            name=func_name, args=func_args
                        )
                            )
                        except (json.JSONDecodeError, TypeError):
                            func_args = {}
                    else:
                        func_args = func_args_raw

                    tool_parts.append(
                        types.Part.from_function_call(name=func_name, args=func_args)
                    )
                    )

            contents.append(types.Content(role="model", parts=tool_parts))
        else:
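
# Illustrative aside, not part of the diff: the branch above replays the raw
# Parts captured by the agent executor instead of re-deriving tool calls from
# text, so a stale text body can no longer shadow the function calls.
# FakePart stands in for google.genai's types.Part.
class FakePart:
    def __init__(self, name: str) -> None:
        self.name = name

message = {"tool_calls": [{"function": {"name": "search"}}],
           "raw_tool_call_parts": [FakePart("search")]}
raw = message.get("raw_tool_call_parts")
tool_parts = list(raw) if raw and all(isinstance(p, FakePart) for p in raw) else []
assert tool_parts and tool_parts is not raw  # copy, not an alias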

@@ -693,14 +693,14 @@ class OpenAICompletion(BaseLLM):
        if response_model or self.response_format:
            format_model = response_model or self.response_format
            if isinstance(format_model, type) and issubclass(format_model, BaseModel):
                schema = format_model.model_json_schema()
                schema["additionalProperties"] = False
                schema_output = generate_model_description(format_model)
                json_schema = schema_output.get("json_schema", {})
                params["text"] = {
                    "format": {
                        "type": "json_schema",
                        "name": format_model.__name__,
                        "strict": True,
                        "schema": schema,
                        "name": json_schema.get("name", format_model.__name__),
                        "strict": json_schema.get("strict", True),
                        "schema": json_schema.get("schema", {}),
                    }
                }
            elif isinstance(format_model, dict):
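
# Illustrative aside, not part of the diff: the resulting Responses API
# "text.format" payload for a small model, using the envelope keys read above.
from pydantic import BaseModel

class Answer(BaseModel):
    text: str

json_schema = {"name": "Answer", "strict": True,
               "schema": Answer.model_json_schema()}
params_text = {"format": {"type": "json_schema",
                          "name": json_schema.get("name", Answer.__name__),
                          "strict": json_schema.get("strict", True),
                          "schema": json_schema.get("schema", {})}}
assert params_text["format"]["schema"]["required"] == ["text"]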
|
||||
@@ -1060,7 +1060,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=delta_text,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream
|
||||
response_id=response_id_stream,
|
||||
)
|
||||
|
||||
elif event.type == "response.function_call_arguments.delta":
|
||||
@@ -1709,7 +1709,7 @@ class OpenAICompletion(BaseLLM):
|
||||
**parse_params, response_format=response_model
|
||||
) as stream:
|
||||
for chunk in stream:
|
||||
response_id_stream=chunk.id if hasattr(chunk,"id") else None
|
||||
response_id_stream = chunk.id if hasattr(chunk, "id") else None
|
||||
|
||||
if chunk.type == "content.delta":
|
||||
delta_content = chunk.delta
|
||||
@@ -1718,7 +1718,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=delta_content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream
|
||||
response_id=response_id_stream,
|
||||
)
|
||||
|
||||
final_completion = stream.get_final_completion()
|
||||
@@ -1748,7 +1748,9 @@ class OpenAICompletion(BaseLLM):
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
for completion_chunk in completion_stream:
|
||||
response_id_stream=completion_chunk.id if hasattr(completion_chunk,"id") else None
|
||||
response_id_stream = (
|
||||
completion_chunk.id if hasattr(completion_chunk, "id") else None
|
||||
)
|
||||
|
||||
if hasattr(completion_chunk, "usage") and completion_chunk.usage:
|
||||
usage_data = self._extract_openai_token_usage(completion_chunk)
|
||||
@@ -1766,7 +1768,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=chunk_delta.content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream
|
||||
response_id=response_id_stream,
|
||||
)
|
||||
|
||||
if chunk_delta.tool_calls:
|
||||
@@ -1805,7 +1807,7 @@ class OpenAICompletion(BaseLLM):
|
||||
"index": tool_calls[tool_index]["index"],
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id_stream
|
||||
response_id=response_id_stream,
|
||||
)
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
@@ -2017,7 +2019,7 @@ class OpenAICompletion(BaseLLM):
|
||||
accumulated_content = ""
|
||||
usage_data = {"total_tokens": 0}
|
||||
async for chunk in completion_stream:
|
||||
response_id_stream=chunk.id if hasattr(chunk,"id") else None
|
||||
response_id_stream = chunk.id if hasattr(chunk, "id") else None
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_data = self._extract_openai_token_usage(chunk)
|
||||
@@ -2035,7 +2037,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=delta.content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream
|
||||
response_id=response_id_stream,
|
||||
)
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
@@ -2071,7 +2073,7 @@ class OpenAICompletion(BaseLLM):
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
async for chunk in stream:
|
||||
response_id_stream=chunk.id if hasattr(chunk,"id") else None
|
||||
response_id_stream = chunk.id if hasattr(chunk, "id") else None
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_data = self._extract_openai_token_usage(chunk)
|
||||
@@ -2089,7 +2091,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=chunk_delta.content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream
|
||||
response_id=response_id_stream,
|
||||
)
|
||||
|
||||
if chunk_delta.tool_calls:
|
||||
@@ -2128,7 +2130,7 @@ class OpenAICompletion(BaseLLM):
|
||||
"index": tool_calls[tool_index]["index"],
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id_stream
|
||||
response_id=response_id_stream,
|
||||
)
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
@@ -77,7 +78,8 @@ def extract_tool_info(tool: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
|
||||
# Also check for args_schema (Pydantic format)
|
||||
if not parameters and "args_schema" in tool:
|
||||
if hasattr(tool["args_schema"], "model_json_schema"):
|
||||
parameters = tool["args_schema"].model_json_schema()
|
||||
schema_output = generate_model_description(tool["args_schema"])
|
||||
parameters = schema_output.get("json_schema", {}).get("schema", {})
|
||||
|
||||
return name, description, parameters
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
)
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.printer import ColoredText, Printer
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.types import LLMMessage
|
||||
@@ -36,6 +37,7 @@ from crewai.utilities.types import LLMMessage
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.experimental.agent_executor import AgentExecutor
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
@@ -158,7 +160,8 @@ def convert_tools_to_openai_schema(
|
||||
parameters: dict[str, Any] = {}
|
||||
if hasattr(tool, "args_schema") and tool.args_schema is not None:
|
||||
try:
|
||||
parameters = tool.args_schema.model_json_schema()
|
||||
schema_output = generate_model_description(tool.args_schema)
|
||||
parameters = schema_output.get("json_schema", {}).get("schema", {})
|
||||
# Remove title and description from schema root as they're redundant
|
||||
parameters.pop("title", None)
|
||||
parameters.pop("description", None)
|
||||
@@ -318,7 +321,7 @@ def get_llm_response(
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None = None,
|
||||
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None,
|
||||
) -> str | Any:
|
||||
"""Call the LLM and return the response, handling any invalid responses.
|
||||
|
||||
@@ -380,7 +383,7 @@ async def aget_llm_response(
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | None = None,
|
||||
executor_context: CrewAgentExecutor | AgentExecutor | None = None,
|
||||
) -> str | Any:
|
||||
"""Call the LLM asynchronously and return the response.
|
||||
|
||||
@@ -900,7 +903,8 @@ def extract_tool_call_info(
|
||||
|
||||
|
||||
def _setup_before_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None, printer: Printer
|
||||
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
|
||||
printer: Printer,
|
||||
) -> bool:
|
||||
"""Setup and invoke before_llm_call hooks for the executor context.
|
||||
|
||||
@@ -950,7 +954,7 @@ def _setup_before_llm_call_hooks(
|
||||
|
||||
|
||||
def _setup_after_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None,
|
||||
executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
|
||||
answer: str,
|
||||
printer: Printer,
|
||||
) -> str:
|
||||
|
||||
@@ -1,14 +1,72 @@
|
||||
"""Utilities for generating JSON schemas from Pydantic models.
|
||||
"""Dynamic Pydantic model creation from JSON schemas.
|
||||
|
||||
This module provides utilities for converting JSON schemas to Pydantic models at runtime.
|
||||
The main function is `create_model_from_schema`, which takes a JSON schema and returns
|
||||
a dynamically created Pydantic model class.
|
||||
|
||||
This is used by the A2A server to honor response schemas sent by clients, allowing
|
||||
structured output from agent tasks.
|
||||
|
||||
Based on dydantic (https://github.com/zenbase-ai/dydantic).
|
||||
|
||||
This module provides functions for converting Pydantic models to JSON schemas
|
||||
suitable for use with LLMs and tool definitions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
import datetime
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import (
|
||||
UUID1,
|
||||
UUID3,
|
||||
UUID4,
|
||||
UUID5,
|
||||
AnyUrl,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
DirectoryPath,
|
||||
Field,
|
||||
FilePath,
|
||||
FileUrl,
|
||||
HttpUrl,
|
||||
Json,
|
||||
MongoDsn,
|
||||
NewPath,
|
||||
PostgresDsn,
|
||||
SecretBytes,
|
||||
SecretStr,
|
||||
StrictBytes,
|
||||
create_model as create_model_base,
|
||||
)
|
||||
from pydantic.networks import ( # type: ignore[attr-defined]
|
||||
IPv4Address,
|
||||
IPv6Address,
|
||||
IPvAnyAddress,
|
||||
IPvAnyInterface,
|
||||
IPvAnyNetwork,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import EmailStr
|
||||
from pydantic.main import AnyClassMethod
|
||||
else:
|
||||
try:
|
||||
from pydantic import EmailStr
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"EmailStr unavailable, using str fallback",
|
||||
extra={"missing_package": "email_validator"},
|
||||
)
|
||||
EmailStr = str
|
||||
|
||||
|
||||
def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -243,3 +301,319 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
|
||||
"schema": json_schema,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
FORMAT_TYPE_MAP: dict[str, type[Any]] = {
|
||||
"base64": Annotated[bytes, Field(json_schema_extra={"format": "base64"})], # type: ignore[dict-item]
|
||||
"binary": StrictBytes,
|
||||
"date": datetime.date,
|
||||
"time": datetime.time,
|
||||
"date-time": datetime.datetime,
|
||||
"duration": datetime.timedelta,
|
||||
"directory-path": DirectoryPath,
|
||||
"email": EmailStr,
|
||||
"file-path": FilePath,
|
||||
"ipv4": IPv4Address,
|
||||
"ipv6": IPv6Address,
|
||||
"ipvanyaddress": IPvAnyAddress, # type: ignore[dict-item]
|
||||
"ipvanyinterface": IPvAnyInterface, # type: ignore[dict-item]
|
||||
"ipvanynetwork": IPvAnyNetwork, # type: ignore[dict-item]
|
||||
"json-string": Json,
|
||||
"multi-host-uri": PostgresDsn | MongoDsn, # type: ignore[dict-item]
|
||||
"password": SecretStr,
|
||||
"path": NewPath,
|
||||
"uri": AnyUrl,
|
||||
"uuid": uuid.UUID,
|
||||
"uuid1": UUID1,
|
||||
"uuid3": UUID3,
|
||||
"uuid4": UUID4,
|
||||
"uuid5": UUID5,
|
||||
}
|
||||
|
||||
|
||||
def create_model_from_schema(  # type: ignore[no-any-unimported]
    json_schema: dict[str, Any],
    *,
    root_schema: dict[str, Any] | None = None,
    __config__: ConfigDict | None = None,
    __base__: type[BaseModel] | None = None,
    __module__: str = __name__,
    __validators__: dict[str, AnyClassMethod] | None = None,
    __cls_kwargs__: dict[str, Any] | None = None,
) -> type[BaseModel]:
    """Create a Pydantic model from a JSON schema.

    This function takes a JSON schema as input and dynamically creates a Pydantic
    model class based on the schema. It supports various JSON schema features such
    as nested objects, referenced definitions ($ref), arrays with typed items,
    union types (anyOf/oneOf), and string formats.

    Args:
        json_schema: A dictionary representing the JSON schema.
        root_schema: The root schema containing $defs. If not provided, the
            current schema is treated as the root schema.
        __config__: Pydantic configuration for the generated model.
        __base__: Base class for the generated model. Defaults to BaseModel.
        __module__: Module name for the generated model class.
        __validators__: A dictionary of custom validators for the generated model.
        __cls_kwargs__: Additional keyword arguments for the generated model class.

    Returns:
        A dynamically created Pydantic model class based on the provided JSON schema.

    Example:
        >>> schema = {
        ...     "title": "Person",
        ...     "type": "object",
        ...     "properties": {
        ...         "name": {"type": "string"},
        ...         "age": {"type": "integer"},
        ...     },
        ...     "required": ["name"],
        ... }
        >>> Person = create_model_from_schema(schema)
        >>> person = Person(name="John", age=30)
        >>> person.name
        'John'
    """
    effective_root = root_schema or json_schema

    if "allOf" in json_schema:
        json_schema = _merge_all_of_schemas(json_schema["allOf"], effective_root)
        if "title" not in json_schema and "title" in (root_schema or {}):
            json_schema["title"] = (root_schema or {}).get("title")

    model_name = json_schema.get("title", "DynamicModel")
    field_definitions = {
        name: _json_schema_to_pydantic_field(
            name, prop, json_schema.get("required", []), effective_root
        )
        for name, prop in (json_schema.get("properties", {}) or {}).items()
    }

    return create_model_base(
        model_name,
        __config__=__config__,
        __base__=__base__,
        __module__=__module__,
        __validators__=__validators__,
        __cls_kwargs__=__cls_kwargs__,
        **field_definitions,
    )


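The docstring example covers a flat object; as a complementary sketch, here is how a nested `$ref` resolves through the root schema's `$defs` (the `Order`/`Person` names are illustrative, not from the library):

# Hypothetical schema: "$ref": "#/$defs/Person" resolves against the root.
schema = {
    "title": "Order",
    "type": "object",
    "properties": {"buyer": {"$ref": "#/$defs/Person"}},
    "required": ["buyer"],
    "$defs": {
        "Person": {
            "title": "Person",
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
    },
}
Order = create_model_from_schema(schema)
order = Order(buyer={"name": "Ada"})  # the dict is validated into a nested Person model
assert order.buyer.name == "Ada"
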
def _json_schema_to_pydantic_field(
    name: str,
    json_schema: dict[str, Any],
    required: list[str],
    root_schema: dict[str, Any],
) -> Any:
    """Convert a JSON schema property to a Pydantic field definition.

    Args:
        name: The field name.
        json_schema: The JSON schema for this field.
        required: List of required field names.
        root_schema: The root schema for resolving $ref.

    Returns:
        A tuple of (type, Field) for use with create_model.
    """
    type_ = _json_schema_to_pydantic_type(json_schema, root_schema, name_=name.title())
    description = json_schema.get("description")
    examples = json_schema.get("examples")
    is_required = name in required

    field_params: dict[str, Any] = {}
    schema_extra: dict[str, Any] = {}

    if description:
        field_params["description"] = description
    if examples:
        schema_extra["examples"] = examples

    default = ... if is_required else None

    if isinstance(type_, type) and issubclass(type_, (int, float)):
        if "minimum" in json_schema:
            field_params["ge"] = json_schema["minimum"]
        if "exclusiveMinimum" in json_schema:
            field_params["gt"] = json_schema["exclusiveMinimum"]
        if "maximum" in json_schema:
            field_params["le"] = json_schema["maximum"]
        if "exclusiveMaximum" in json_schema:
            field_params["lt"] = json_schema["exclusiveMaximum"]
        if "multipleOf" in json_schema:
            field_params["multiple_of"] = json_schema["multipleOf"]

    format_ = json_schema.get("format")
    if format_ in FORMAT_TYPE_MAP:
        pydantic_type = FORMAT_TYPE_MAP[format_]

        if format_ == "password":
            if json_schema.get("writeOnly"):
                pydantic_type = SecretBytes
        elif format_ == "uri":
            allowed_schemes = json_schema.get("scheme")
            if allowed_schemes:
                if len(allowed_schemes) == 1 and allowed_schemes[0] == "http":
                    pydantic_type = HttpUrl
                elif len(allowed_schemes) == 1 and allowed_schemes[0] == "file":
                    pydantic_type = FileUrl

        type_ = pydantic_type

    if isinstance(type_, type) and issubclass(type_, str):
        if "minLength" in json_schema:
            field_params["min_length"] = json_schema["minLength"]
        if "maxLength" in json_schema:
            field_params["max_length"] = json_schema["maxLength"]
        if "pattern" in json_schema:
            field_params["pattern"] = json_schema["pattern"]

    if not is_required:
        type_ = type_ | None

    if schema_extra:
        field_params["json_schema_extra"] = schema_extra

    return type_, Field(default, **field_params)


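To make the constraint mapping concrete, a small illustrative call (the field name and bounds are made up):

type_, field_info = _json_schema_to_pydantic_field(
    "age",
    {"type": "integer", "minimum": 0, "maximum": 130, "description": "Age in years"},
    ["age"],
    {},
)
assert type_ is int  # numeric type preserved; not wrapped in Optional since required
assert field_info.description == "Age in years"
# minimum/maximum become ge=0 / le=130 on the Field, so a model built from this
# definition rejects out-of-range values at validation time.
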
def _resolve_ref(ref: str, root_schema: dict[str, Any]) -> dict[str, Any]:
    """Resolve a $ref to its actual schema.

    Args:
        ref: The $ref string (e.g., "#/$defs/MyType").
        root_schema: The root schema containing $defs.

    Returns:
        The resolved schema dict.
    """
    from typing import cast

    ref_path = ref.split("/")
    if ref.startswith("#/$defs/"):
        ref_schema: dict[str, Any] = root_schema["$defs"]
        start_idx = 2
    else:
        ref_schema = root_schema
        start_idx = 1
    for path in ref_path[start_idx:]:
        ref_schema = cast(dict[str, Any], ref_schema[path])
    return ref_schema


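A quick illustrative check of the resolver (the `Pet` definition is hypothetical):

root = {"$defs": {"Pet": {"type": "object", "properties": {"name": {"type": "string"}}}}}
# "#/$defs/Pet" walks into root["$defs"] and returns the referenced schema dict.
assert _resolve_ref("#/$defs/Pet", root) == root["$defs"]["Pet"]
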
def _merge_all_of_schemas(
    schemas: list[dict[str, Any]],
    root_schema: dict[str, Any],
) -> dict[str, Any]:
    """Merge multiple allOf schemas into a single schema.

    Combines properties and required fields from all schemas.

    Args:
        schemas: List of schemas to merge.
        root_schema: The root schema for resolving $ref.

    Returns:
        Merged schema with combined properties and required fields.
    """
    merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []}

    for schema in schemas:
        if "$ref" in schema:
            schema = _resolve_ref(schema["$ref"], root_schema)

        if "properties" in schema:
            merged["properties"].update(schema["properties"])

        if "required" in schema:
            for field in schema["required"]:
                if field not in merged["required"]:
                    merged["required"].append(field)

        if "title" in schema and "title" not in merged:
            merged["title"] = schema["title"]

    return merged


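Illustratively, two hypothetical allOf members collapse into a single object schema:

merged = _merge_all_of_schemas(
    [
        {"type": "object", "properties": {"id": {"type": "integer"}}, "required": ["id"]},
        {"type": "object", "properties": {"name": {"type": "string"}}},
    ],
    {},
)
# Properties are unioned; required fields are deduplicated in order.
assert merged["properties"].keys() == {"id", "name"}
assert merged["required"] == ["id"]
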
def _json_schema_to_pydantic_type(
    json_schema: dict[str, Any],
    root_schema: dict[str, Any],
    *,
    name_: str | None = None,
) -> Any:
    """Convert a JSON schema to a Python/Pydantic type.

    Args:
        json_schema: The JSON schema to convert.
        root_schema: The root schema for resolving $ref.
        name_: Optional name for nested models.

    Returns:
        A Python type corresponding to the JSON schema.
    """
    ref = json_schema.get("$ref")
    if ref:
        ref_schema = _resolve_ref(ref, root_schema)
        return _json_schema_to_pydantic_type(ref_schema, root_schema, name_=name_)

    enum_values = json_schema.get("enum")
    if enum_values:
        return Literal[tuple(enum_values)]

    if "const" in json_schema:
        return Literal[json_schema["const"]]

    any_of_schemas = []
    if "anyOf" in json_schema or "oneOf" in json_schema:
        any_of_schemas = json_schema.get("anyOf", []) + json_schema.get("oneOf", [])
    if any_of_schemas:
        any_of_types = [
            _json_schema_to_pydantic_type(schema, root_schema)
            for schema in any_of_schemas
        ]
        return Union[tuple(any_of_types)]  # noqa: UP007

    all_of_schemas = json_schema.get("allOf")
    if all_of_schemas:
        if len(all_of_schemas) == 1:
            return _json_schema_to_pydantic_type(
                all_of_schemas[0], root_schema, name_=name_
            )
        merged = _merge_all_of_schemas(all_of_schemas, root_schema)
        return _json_schema_to_pydantic_type(merged, root_schema, name_=name_)

    type_ = json_schema.get("type")

    if type_ == "string":
        return str
    if type_ == "integer":
        return int
    if type_ == "number":
        return float
    if type_ == "boolean":
        return bool
    if type_ == "array":
        items_schema = json_schema.get("items")
        if items_schema:
            item_type = _json_schema_to_pydantic_type(
                items_schema, root_schema, name_=name_
            )
            return list[item_type]  # type: ignore[valid-type]
        return list
    if type_ == "object":
        properties = json_schema.get("properties")
        if properties:
            json_schema_ = json_schema.copy()
            if json_schema_.get("title") is None:
                json_schema_["title"] = name_
            return create_model_from_schema(json_schema_, root_schema=root_schema)
        return dict
    if type_ == "null":
        return None
    if type_ is None:
        return Any
    raise ValueError(f"Unsupported JSON schema type: {type_} from {json_schema}")

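A few illustrative conversions, assuming the module's `typing` imports (`Literal`, `Union`) are in scope:

assert _json_schema_to_pydantic_type({"type": "string"}, {}) is str
assert _json_schema_to_pydantic_type({"enum": ["red", "green"]}, {}) == Literal["red", "green"]
assert _json_schema_to_pydantic_type(
    {"anyOf": [{"type": "string"}, {"type": "integer"}]}, {}
) == Union[str, int]
assert _json_schema_to_pydantic_type({"type": "array", "items": {"type": "integer"}}, {}) == list[int]
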
@@ -26,4 +26,5 @@ class LLMMessage(TypedDict):
    tool_call_id: NotRequired[str]
    name: NotRequired[str]
    tool_calls: NotRequired[list[dict[str, Any]]]
    raw_tool_call_parts: NotRequired[list[Any]]
    files: NotRequired[dict[str, FileInput]]
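The hunk above adds `raw_tool_call_parts` so provider-native tool-call objects can travel alongside the normalized `tool_calls` list. A hedged sketch of a message using the new key (the `role`/`content` keys are assumed from the unshown part of the TypedDict, and the payload shape is hypothetical):

raw_part: Any = object()  # stands in for an opaque provider-native part (e.g. a Gemini part)
msg: LLMMessage = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "id": "call_1",
            "type": "function",
            "function": {"name": "search", "arguments": "{}"},
        }
    ],
    "raw_tool_call_parts": [raw_part],  # kept verbatim for round-tripping to the provider
}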