diff --git a/lib/crewai-tools/src/crewai_tools/tools/crewai_platform_tools/crewai_platform_action_tool.py b/lib/crewai-tools/src/crewai_tools/tools/crewai_platform_tools/crewai_platform_action_tool.py index 098256940..3a3ae3be9 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/crewai_platform_tools/crewai_platform_action_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/crewai_platform_tools/crewai_platform_action_tool.py @@ -1,10 +1,11 @@ """Crewai Enterprise Tools.""" -import os + import json -import re -from typing import Any, Optional, Union, cast, get_origin +import os +from typing import Any from crewai.tools import BaseTool +from crewai.utilities.pydantic_schema_utils import create_model_from_schema from pydantic import Field, create_model import requests @@ -14,77 +15,6 @@ from crewai_tools.tools.crewai_platform_tools.misc import ( ) -class AllOfSchemaAnalyzer: - """Helper class to analyze and merge allOf schemas.""" - - def __init__(self, schemas: list[dict[str, Any]]): - self.schemas = schemas - self._explicit_types: list[str] = [] - self._merged_properties: dict[str, Any] = {} - self._merged_required: list[str] = [] - self._analyze_schemas() - - def _analyze_schemas(self) -> None: - """Analyze all schemas and extract relevant information.""" - for schema in self.schemas: - if "type" in schema: - self._explicit_types.append(schema["type"]) - - # Merge object properties - if schema.get("type") == "object" and "properties" in schema: - self._merged_properties.update(schema["properties"]) - if "required" in schema: - self._merged_required.extend(schema["required"]) - - def has_consistent_type(self) -> bool: - """Check if all schemas have the same explicit type.""" - return len(set(self._explicit_types)) == 1 if self._explicit_types else False - - def get_consistent_type(self) -> type[Any]: - """Get the consistent type if all schemas agree.""" - if not self.has_consistent_type(): - raise ValueError("No consistent type found") - - type_mapping = { - "string": str, - "integer": int, - "number": float, - "boolean": bool, - "array": list, - "object": dict, - "null": type(None), - } - return type_mapping.get(self._explicit_types[0], str) - - def has_object_schemas(self) -> bool: - """Check if any schemas are object types with properties.""" - return bool(self._merged_properties) - - def get_merged_properties(self) -> dict[str, Any]: - """Get merged properties from all object schemas.""" - return self._merged_properties - - def get_merged_required_fields(self) -> list[str]: - """Get merged required fields from all object schemas.""" - return list(set(self._merged_required)) # Remove duplicates - - def get_fallback_type(self) -> type[Any]: - """Get a fallback type when merging fails.""" - if self._explicit_types: - # Use the first explicit type - type_mapping = { - "string": str, - "integer": int, - "number": float, - "boolean": bool, - "array": list, - "object": dict, - "null": type(None), - } - return type_mapping.get(self._explicit_types[0], str) - return str - - class CrewAIPlatformActionTool(BaseTool): action_name: str = Field(default="", description="The name of the action") action_schema: dict[str, Any] = Field( @@ -97,42 +27,19 @@ class CrewAIPlatformActionTool(BaseTool): action_name: str, action_schema: dict[str, Any], ): - self._model_registry: dict[str, type[Any]] = {} - self._base_name = self._sanitize_name(action_name) - - schema_props, required = self._extract_schema_info(action_schema) - - field_definitions: dict[str, Any] = {} - for param_name, param_details in schema_props.items(): - param_desc = param_details.get("description", "") - is_required = param_name in required + parameters = action_schema.get("function", {}).get("parameters", {}) + if parameters and parameters.get("properties"): try: - field_type = self._process_schema_type( - param_details, self._sanitize_name(param_name).title() - ) + if "title" not in parameters: + parameters = {**parameters, "title": f"{action_name}Schema"} + if "type" not in parameters: + parameters = {**parameters, "type": "object"} + args_schema = create_model_from_schema(parameters) except Exception: - field_type = str - - field_definitions[param_name] = self._create_field_definition( - field_type, is_required, param_desc - ) - - if field_definitions: - try: - args_schema = create_model( - f"{self._base_name}Schema", **field_definitions - ) - except Exception: - args_schema = create_model( - f"{self._base_name}Schema", - input_text=(str, Field(description="Input for the action")), - ) + args_schema = create_model(f"{action_name}Schema") else: - args_schema = create_model( - f"{self._base_name}Schema", - input_text=(str, Field(description="Input for the action")), - ) + args_schema = create_model(f"{action_name}Schema") super().__init__( name=action_name.lower().replace(" ", "_"), @@ -142,285 +49,12 @@ class CrewAIPlatformActionTool(BaseTool): self.action_name = action_name self.action_schema = action_schema - @staticmethod - def _sanitize_name(name: str) -> str: - name = name.lower().replace(" ", "_") - sanitized = re.sub(r"[^a-zA-Z0-9_]", "", name) - parts = sanitized.split("_") - return "".join(word.capitalize() for word in parts if word) - - @staticmethod - def _extract_schema_info( - action_schema: dict[str, Any], - ) -> tuple[dict[str, Any], list[str]]: - schema_props = ( - action_schema.get("function", {}) - .get("parameters", {}) - .get("properties", {}) - ) - required = ( - action_schema.get("function", {}).get("parameters", {}).get("required", []) - ) - return schema_props, required - - def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]: - """ - Process a JSON Schema type definition into a Python type. - - Handles complex schema constructs like anyOf, oneOf, allOf, enums, arrays, and objects. - """ - # Handle composite schema types (anyOf, oneOf, allOf) - if composite_type := self._process_composite_schema(schema, type_name): - return composite_type - - # Handle primitive types and simple constructs - return self._process_primitive_schema(schema, type_name) - - def _process_composite_schema( - self, schema: dict[str, Any], type_name: str - ) -> type[Any] | None: - """Process composite schema types: anyOf, oneOf, allOf.""" - if "anyOf" in schema: - return self._process_any_of_schema(schema["anyOf"], type_name) - if "oneOf" in schema: - return self._process_one_of_schema(schema["oneOf"], type_name) - if "allOf" in schema: - return self._process_all_of_schema(schema["allOf"], type_name) - return None - - def _process_any_of_schema( - self, any_of_types: list[dict[str, Any]], type_name: str - ) -> type[Any]: - """Process anyOf schema - creates Union of possible types.""" - is_nullable = any(t.get("type") == "null" for t in any_of_types) - non_null_types = [t for t in any_of_types if t.get("type") != "null"] - - if not non_null_types: - return cast( - type[Any], cast(object, str | None) - ) # fallback for only-null case - - base_type = ( - self._process_schema_type(non_null_types[0], type_name) - if len(non_null_types) == 1 - else self._create_union_type(non_null_types, type_name, "AnyOf") - ) - return base_type | None if is_nullable else base_type # type: ignore[return-value] - - def _process_one_of_schema( - self, one_of_types: list[dict[str, Any]], type_name: str - ) -> type[Any]: - """Process oneOf schema - creates Union of mutually exclusive types.""" - return ( - self._process_schema_type(one_of_types[0], type_name) - if len(one_of_types) == 1 - else self._create_union_type(one_of_types, type_name, "OneOf") - ) - - def _process_all_of_schema( - self, all_of_schemas: list[dict[str, Any]], type_name: str - ) -> type[Any]: - """Process allOf schema - merges schemas that must all be satisfied.""" - if len(all_of_schemas) == 1: - return self._process_schema_type(all_of_schemas[0], type_name) - return self._merge_all_of_schemas(all_of_schemas, type_name) - - def _create_union_type( - self, schemas: list[dict[str, Any]], type_name: str, prefix: str - ) -> type[Any]: - """Create a Union type from multiple schemas.""" - return Union[ # type: ignore # noqa: UP007 - tuple( - self._process_schema_type(schema, f"{type_name}{prefix}{i}") - for i, schema in enumerate(schemas) - ) - ] - - def _process_primitive_schema( - self, schema: dict[str, Any], type_name: str - ) -> type[Any]: - """Process primitive schema types: string, number, array, object, etc.""" - json_type = schema.get("type", "string") - - if "enum" in schema: - return self._process_enum_schema(schema, json_type) - - if json_type == "array": - return self._process_array_schema(schema, type_name) - - if json_type == "object": - return self._create_nested_model(schema, type_name) - - return self._map_json_type_to_python(json_type) - - def _process_enum_schema(self, schema: dict[str, Any], json_type: str) -> type[Any]: - """Process enum schema - currently falls back to base type.""" - enum_values = schema["enum"] - if not enum_values: - return self._map_json_type_to_python(json_type) - - # For Literal types, we need to pass the values directly, not as a tuple - # This is a workaround since we can't dynamically create Literal types easily - # Fall back to the base JSON type for now - return self._map_json_type_to_python(json_type) - - def _process_array_schema( - self, schema: dict[str, Any], type_name: str - ) -> type[Any]: - items_schema = schema.get("items", {"type": "string"}) - item_type = self._process_schema_type(items_schema, f"{type_name}Item") - return list[item_type] # type: ignore - - def _merge_all_of_schemas( - self, schemas: list[dict[str, Any]], type_name: str - ) -> type[Any]: - schema_analyzer = AllOfSchemaAnalyzer(schemas) - - if schema_analyzer.has_consistent_type(): - return schema_analyzer.get_consistent_type() - - if schema_analyzer.has_object_schemas(): - return self._create_merged_object_model( - schema_analyzer.get_merged_properties(), - schema_analyzer.get_merged_required_fields(), - type_name, - ) - - return schema_analyzer.get_fallback_type() - - def _create_merged_object_model( - self, properties: dict[str, Any], required: list[str], model_name: str - ) -> type[Any]: - full_model_name = f"{self._base_name}{model_name}AllOf" - - if full_model_name in self._model_registry: - return self._model_registry[full_model_name] - - if not properties: - return dict - - field_definitions = self._build_field_definitions( - properties, required, model_name - ) - - try: - merged_model = create_model(full_model_name, **field_definitions) - self._model_registry[full_model_name] = merged_model - return merged_model - except Exception: - return dict - - def _build_field_definitions( - self, properties: dict[str, Any], required: list[str], model_name: str - ) -> dict[str, Any]: - field_definitions = {} - - for prop_name, prop_schema in properties.items(): - prop_desc = prop_schema.get("description", "") - is_required = prop_name in required - - try: - prop_type = self._process_schema_type( - prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}" - ) - except Exception: - prop_type = str - - field_definitions[prop_name] = self._create_field_definition( - prop_type, is_required, prop_desc - ) - - return field_definitions - - def _create_nested_model( - self, schema: dict[str, Any], model_name: str - ) -> type[Any]: - full_model_name = f"{self._base_name}{model_name}" - - if full_model_name in self._model_registry: - return self._model_registry[full_model_name] - - properties = schema.get("properties", {}) - required_fields = schema.get("required", []) - - if not properties: - return dict - - field_definitions = {} - for prop_name, prop_schema in properties.items(): - prop_desc = prop_schema.get("description", "") - is_required = prop_name in required_fields - - try: - prop_type = self._process_schema_type( - prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}" - ) - except Exception: - prop_type = str - - field_definitions[prop_name] = self._create_field_definition( - prop_type, is_required, prop_desc - ) - - try: - nested_model = create_model(full_model_name, **field_definitions) # type: ignore - self._model_registry[full_model_name] = nested_model - return nested_model - except Exception: - return dict - - def _create_field_definition( - self, field_type: type[Any], is_required: bool, description: str - ) -> tuple: - if is_required: - return (field_type, Field(description=description)) - if get_origin(field_type) is Union: - return (field_type, Field(default=None, description=description)) - return ( - Optional[field_type], # noqa: UP045 - Field(default=None, description=description), - ) - - def _map_json_type_to_python(self, json_type: str) -> type[Any]: - type_mapping = { - "string": str, - "integer": int, - "number": float, - "boolean": bool, - "array": list, - "object": dict, - "null": type(None), - } - return type_mapping.get(json_type, str) - - def _get_required_nullable_fields(self) -> list[str]: - schema_props, required = self._extract_schema_info(self.action_schema) - - required_nullable_fields = [] - for param_name in required: - param_details = schema_props.get(param_name, {}) - if self._is_nullable_type(param_details): - required_nullable_fields.append(param_name) - - return required_nullable_fields - - def _is_nullable_type(self, schema: dict[str, Any]) -> bool: - if "anyOf" in schema: - return any(t.get("type") == "null" for t in schema["anyOf"]) - return schema.get("type") == "null" - - def _run(self, **kwargs) -> str: + def _run(self, **kwargs: Any) -> str: try: cleaned_kwargs = { key: value for key, value in kwargs.items() if value is not None } - required_nullable_fields = self._get_required_nullable_fields() - - for field_name in required_nullable_fields: - if field_name not in cleaned_kwargs: - cleaned_kwargs[field_name] = None - api_url = ( f"{get_platform_api_base_url()}/actions/{self.action_name}/execute" ) @@ -429,7 +63,9 @@ class CrewAIPlatformActionTool(BaseTool): "Authorization": f"Bearer {token}", "Content-Type": "application/json", } - payload = cleaned_kwargs + payload = { + "integration": cleaned_kwargs if cleaned_kwargs else {"_noop": True} + } response = requests.post( url=api_url, @@ -441,7 +77,14 @@ class CrewAIPlatformActionTool(BaseTool): data = response.json() if not response.ok: - error_message = data.get("error", {}).get("message", json.dumps(data)) + if isinstance(data, dict): + error_info = data.get("error", {}) + if isinstance(error_info, dict): + error_message = error_info.get("message", json.dumps(data)) + else: + error_message = str(error_info) + else: + error_message = str(data) return f"API request failed: {error_message}" return json.dumps(data, indent=2) diff --git a/lib/crewai-tools/src/crewai_tools/tools/crewai_platform_tools/crewai_platform_tool_builder.py b/lib/crewai-tools/src/crewai_tools/tools/crewai_platform_tools/crewai_platform_tool_builder.py index 564637189..e9cc8c3e6 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/crewai_platform_tools/crewai_platform_tool_builder.py +++ b/lib/crewai-tools/src/crewai_tools/tools/crewai_platform_tools/crewai_platform_tool_builder.py @@ -1,5 +1,10 @@ -from typing import Any +"""CrewAI platform tool builder for fetching and creating action tools.""" + +import logging import os +from types import TracebackType +from typing import Any + from crewai.tools import BaseTool import requests @@ -12,22 +17,29 @@ from crewai_tools.tools.crewai_platform_tools.misc import ( ) +logger = logging.getLogger(__name__) + + class CrewaiPlatformToolBuilder: + """Builds platform tools from remote action schemas.""" + def __init__( self, apps: list[str], - ): + ) -> None: self._apps = apps - self._actions_schema = {} # type: ignore[var-annotated] - self._tools = None + self._actions_schema: dict[str, dict[str, Any]] = {} + self._tools: list[BaseTool] | None = None def tools(self) -> list[BaseTool]: + """Fetch actions and return built tools.""" if self._tools is None: self._fetch_actions() self._create_tools() return self._tools if self._tools is not None else [] - def _fetch_actions(self): + def _fetch_actions(self) -> None: + """Fetch action schemas from the platform API.""" actions_url = f"{get_platform_api_base_url()}/actions" headers = {"Authorization": f"Bearer {get_platform_integration_token()}"} @@ -40,7 +52,8 @@ class CrewaiPlatformToolBuilder: verify=os.environ.get("CREWAI_FACTORY", "false").lower() != "true", ) response.raise_for_status() - except Exception: + except Exception as e: + logger.error(f"Failed to fetch platform tools for apps {self._apps}: {e}") return raw_data = response.json() @@ -51,6 +64,8 @@ class CrewaiPlatformToolBuilder: for app, action_list in action_categories.items(): if isinstance(action_list, list): for action in action_list: + if not isinstance(action, dict): + continue if action_name := action.get("name"): action_schema = { "function": { @@ -64,72 +79,16 @@ class CrewaiPlatformToolBuilder: } self._actions_schema[action_name] = action_schema - def _generate_detailed_description( - self, schema: dict[str, Any], indent: int = 0 - ) -> list[str]: - descriptions = [] - indent_str = " " * indent - - schema_type = schema.get("type", "string") - - if schema_type == "object": - properties = schema.get("properties", {}) - required_fields = schema.get("required", []) - - if properties: - descriptions.append(f"{indent_str}Object with properties:") - for prop_name, prop_schema in properties.items(): - prop_desc = prop_schema.get("description", "") - is_required = prop_name in required_fields - req_str = " (required)" if is_required else " (optional)" - descriptions.append( - f"{indent_str} - {prop_name}: {prop_desc}{req_str}" - ) - - if prop_schema.get("type") == "object": - descriptions.extend( - self._generate_detailed_description(prop_schema, indent + 2) - ) - elif prop_schema.get("type") == "array": - items_schema = prop_schema.get("items", {}) - if items_schema.get("type") == "object": - descriptions.append(f"{indent_str} Array of objects:") - descriptions.extend( - self._generate_detailed_description( - items_schema, indent + 3 - ) - ) - elif "enum" in items_schema: - descriptions.append( - f"{indent_str} Array of enum values: {items_schema['enum']}" - ) - elif "enum" in prop_schema: - descriptions.append( - f"{indent_str} Enum values: {prop_schema['enum']}" - ) - - return descriptions - - def _create_tools(self): - tools = [] + def _create_tools(self) -> None: + """Create tool instances from fetched action schemas.""" + tools: list[BaseTool] = [] for action_name, action_schema in self._actions_schema.items(): function_details = action_schema.get("function", {}) description = function_details.get("description", f"Execute {action_name}") - parameters = function_details.get("parameters", {}) - param_descriptions = [] - - if parameters.get("properties"): - param_descriptions.append("\nDetailed Parameter Structure:") - param_descriptions.extend( - self._generate_detailed_description(parameters) - ) - - full_description = description + "\n".join(param_descriptions) - tool = CrewAIPlatformActionTool( - description=full_description, + description=description, action_name=action_name, action_schema=action_schema, ) @@ -138,8 +97,14 @@ class CrewaiPlatformToolBuilder: self._tools = tools - def __enter__(self): + def __enter__(self) -> list[BaseTool]: + """Enter context manager and return tools.""" return self.tools() - def __exit__(self, exc_type, exc_val, exc_tb): - pass + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit context manager.""" diff --git a/lib/crewai-tools/tests/tools/crewai_platform_tools/test_crewai_platform_action_tool.py b/lib/crewai-tools/tests/tools/crewai_platform_tools/test_crewai_platform_action_tool.py index 5bbbb3f91..92b7f19f0 100644 --- a/lib/crewai-tools/tests/tools/crewai_platform_tools/test_crewai_platform_action_tool.py +++ b/lib/crewai-tools/tests/tools/crewai_platform_tools/test_crewai_platform_action_tool.py @@ -1,4 +1,3 @@ -from typing import Union, get_args, get_origin from unittest.mock import patch, Mock import os @@ -7,251 +6,6 @@ from crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool import ) -class TestSchemaProcessing: - - def setup_method(self): - self.base_action_schema = { - "function": { - "parameters": { - "properties": {}, - "required": [] - } - } - } - - def create_test_tool(self, action_name="test_action"): - return CrewAIPlatformActionTool( - description="Test tool", - action_name=action_name, - action_schema=self.base_action_schema - ) - - def test_anyof_multiple_types(self): - tool = self.create_test_tool() - - test_schema = { - "anyOf": [ - {"type": "string"}, - {"type": "number"}, - {"type": "integer"} - ] - } - - result_type = tool._process_schema_type(test_schema, "TestField") - - assert get_origin(result_type) is Union - - args = get_args(result_type) - expected_types = (str, float, int) - - for expected_type in expected_types: - assert expected_type in args - - def test_anyof_with_null(self): - tool = self.create_test_tool() - - test_schema = { - "anyOf": [ - {"type": "string"}, - {"type": "number"}, - {"type": "null"} - ] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldNullable") - - assert get_origin(result_type) is Union - - args = get_args(result_type) - assert type(None) in args - assert str in args - assert float in args - - def test_anyof_single_type(self): - tool = self.create_test_tool() - - test_schema = { - "anyOf": [ - {"type": "string"} - ] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldSingle") - - assert result_type is str - - def test_oneof_multiple_types(self): - tool = self.create_test_tool() - - test_schema = { - "oneOf": [ - {"type": "string"}, - {"type": "boolean"} - ] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldOneOf") - - assert get_origin(result_type) is Union - - args = get_args(result_type) - expected_types = (str, bool) - - for expected_type in expected_types: - assert expected_type in args - - def test_oneof_single_type(self): - tool = self.create_test_tool() - - test_schema = { - "oneOf": [ - {"type": "integer"} - ] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldOneOfSingle") - - assert result_type is int - - def test_basic_types(self): - tool = self.create_test_tool() - - test_cases = [ - ({"type": "string"}, str), - ({"type": "integer"}, int), - ({"type": "number"}, float), - ({"type": "boolean"}, bool), - ({"type": "array", "items": {"type": "string"}}, list), - ] - - for schema, expected_type in test_cases: - result_type = tool._process_schema_type(schema, "TestField") - if schema["type"] == "array": - assert get_origin(result_type) is list - else: - assert result_type is expected_type - - def test_enum_handling(self): - tool = self.create_test_tool() - - test_schema = { - "type": "string", - "enum": ["option1", "option2", "option3"] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldEnum") - - assert result_type is str - - def test_nested_anyof(self): - tool = self.create_test_tool() - - test_schema = { - "anyOf": [ - {"type": "string"}, - { - "anyOf": [ - {"type": "integer"}, - {"type": "boolean"} - ] - } - ] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldNested") - - assert get_origin(result_type) is Union - args = get_args(result_type) - - assert str in args - - if len(args) == 3: - assert int in args - assert bool in args - else: - nested_union = next(arg for arg in args if get_origin(arg) is Union) - nested_args = get_args(nested_union) - assert int in nested_args - assert bool in nested_args - - def test_allof_same_types(self): - tool = self.create_test_tool() - - test_schema = { - "allOf": [ - {"type": "string"}, - {"type": "string", "maxLength": 100} - ] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldAllOfSame") - - assert result_type is str - - def test_allof_object_merge(self): - tool = self.create_test_tool() - - test_schema = { - "allOf": [ - { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"} - }, - "required": ["name"] - }, - { - "type": "object", - "properties": { - "email": {"type": "string"}, - "age": {"type": "integer"} - }, - "required": ["email"] - } - ] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldAllOfMerged") - - # Should create a merged model with all properties - # The implementation might fall back to dict if model creation fails - # Let's just verify it's not a basic scalar type - assert result_type is not str - assert result_type is not int - assert result_type is not bool - # It could be dict (fallback) or a proper model class - assert result_type in (dict, type) or hasattr(result_type, '__name__') - - def test_allof_single_schema(self): - """Test that allOf with single schema works correctly.""" - tool = self.create_test_tool() - - test_schema = { - "allOf": [ - {"type": "boolean"} - ] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldAllOfSingle") - - # Should be just bool - assert result_type is bool - - def test_allof_mixed_types(self): - tool = self.create_test_tool() - - test_schema = { - "allOf": [ - {"type": "string"}, - {"type": "integer"} - ] - } - - result_type = tool._process_schema_type(test_schema, "TestFieldAllOfMixed") - - assert result_type is str - class TestCrewAIPlatformActionToolVerify: """Test suite for SSL verification behavior based on CREWAI_FACTORY environment variable""" diff --git a/lib/crewai-tools/tests/tools/crewai_platform_tools/test_crewai_platform_tool_builder.py b/lib/crewai-tools/tests/tools/crewai_platform_tools/test_crewai_platform_tool_builder.py index 880312d44..7703f2104 100644 --- a/lib/crewai-tools/tests/tools/crewai_platform_tools/test_crewai_platform_tool_builder.py +++ b/lib/crewai-tools/tests/tools/crewai_platform_tools/test_crewai_platform_tool_builder.py @@ -224,43 +224,6 @@ class TestCrewaiPlatformToolBuilder(unittest.TestCase): _, kwargs = mock_get.call_args assert kwargs["params"]["apps"] == "" - def test_detailed_description_generation(self): - builder = CrewaiPlatformToolBuilder(apps=["test"]) - - complex_schema = { - "type": "object", - "properties": { - "simple_string": {"type": "string", "description": "A simple string"}, - "nested_object": { - "type": "object", - "properties": { - "inner_prop": { - "type": "integer", - "description": "Inner property", - } - }, - "description": "Nested object", - }, - "array_prop": { - "type": "array", - "items": {"type": "string"}, - "description": "Array of strings", - }, - }, - } - - descriptions = builder._generate_detailed_description(complex_schema) - - assert isinstance(descriptions, list) - assert len(descriptions) > 0 - - description_text = "\n".join(descriptions) - assert "simple_string" in description_text - assert "nested_object" in description_text - assert "array_prop" in description_text - - - class TestCrewaiPlatformToolBuilderVerify(unittest.TestCase): """Test suite for SSL verification behavior in CrewaiPlatformToolBuilder""" diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index 7ad7ad025..e8bcdf460 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -36,6 +36,7 @@ from crewai.hooks.llm_hooks import ( get_after_llm_call_hooks, get_before_llm_call_hooks, ) +from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType from crewai.utilities.agent_utils import ( convert_tools_to_openai_schema, enforce_rpm_limit, @@ -185,8 +186,8 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): self._instance_id = str(uuid4())[:8] - self.before_llm_call_hooks: list[Callable] = [] - self.after_llm_call_hooks: list[Callable] = [] + self.before_llm_call_hooks: list[BeforeLLMCallHookType] = [] + self.after_llm_call_hooks: list[AfterLLMCallHookType] = [] self.before_llm_call_hooks.extend(get_before_llm_call_hooks()) self.after_llm_call_hooks.extend(get_after_llm_call_hooks()) @@ -299,11 +300,21 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): """Compatibility property for mixin - returns state messages.""" return self._state.messages + @messages.setter + def messages(self, value: list[LLMMessage]) -> None: + """Set state messages.""" + self._state.messages = value + @property def iterations(self) -> int: """Compatibility property for mixin - returns state iterations.""" return self._state.iterations + @iterations.setter + def iterations(self, value: int) -> None: + """Set state iterations.""" + self._state.iterations = value + @start() def initialize_reasoning(self) -> Literal["initialized"]: """Initialize the reasoning flow and emit agent start logs.""" @@ -577,6 +588,12 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): "content": None, "tool_calls": tool_calls_to_report, } + if all( + type(tc).__qualname__ == "Part" for tc in self.state.pending_tool_calls + ): + assistant_message["raw_tool_call_parts"] = list( + self.state.pending_tool_calls + ) self.state.messages.append(assistant_message) # Now execute each tool @@ -611,14 +628,12 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): # Check if tool has reached max usage count max_usage_reached = False - if original_tool: - if ( - hasattr(original_tool, "max_usage_count") - and original_tool.max_usage_count is not None - and original_tool.current_usage_count - >= original_tool.max_usage_count - ): - max_usage_reached = True + if ( + original_tool + and original_tool.max_usage_count is not None + and original_tool.current_usage_count >= original_tool.max_usage_count + ): + max_usage_reached = True # Check cache before executing from_cache = False @@ -661,11 +676,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): # Add to cache after successful execution (before string conversion) if self.tools_handler and self.tools_handler.cache: should_cache = True - if ( - original_tool - and hasattr(original_tool, "cache_function") - and original_tool.cache_function - ): + if original_tool: should_cache = original_tool.cache_function( args_dict, raw_result ) @@ -696,7 +707,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): error=e, ), ) - elif max_usage_reached: + elif max_usage_reached and original_tool: # Return error message when max usage limit is reached result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore." @@ -833,6 +844,10 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): @listen("parser_error") def recover_from_parser_error(self) -> Literal["initialized"]: """Recover from output parser errors and retry.""" + if not self._last_parser_error: + self.state.iterations += 1 + return "initialized" + formatted_answer = handle_output_parser_exception( e=self._last_parser_error, messages=list(self.state.messages), diff --git a/lib/crewai/src/crewai/hooks/llm_hooks.py b/lib/crewai/src/crewai/hooks/llm_hooks.py index 2388396c9..2f5462fe0 100644 --- a/lib/crewai/src/crewai/hooks/llm_hooks.py +++ b/lib/crewai/src/crewai/hooks/llm_hooks.py @@ -9,6 +9,7 @@ from crewai.utilities.printer import Printer if TYPE_CHECKING: from crewai.agents.crew_agent_executor import CrewAgentExecutor + from crewai.experimental.agent_executor import AgentExecutor from crewai.lite_agent import LiteAgent from crewai.llms.base_llm import BaseLLM from crewai.utilities.types import LLMMessage @@ -41,7 +42,7 @@ class LLMCallHookContext: Can be modified by returning a new string from after_llm_call hook. """ - executor: CrewAgentExecutor | LiteAgent | None + executor: CrewAgentExecutor | AgentExecutor | LiteAgent | None messages: list[LLMMessage] agent: Any task: Any @@ -52,7 +53,7 @@ class LLMCallHookContext: def __init__( self, - executor: CrewAgentExecutor | LiteAgent | None = None, + executor: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None, response: str | None = None, messages: list[LLMMessage] | None = None, llm: BaseLLM | str | Any | None = None, # TODO: look into diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py index 2ab638f11..850835ff1 100644 --- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py @@ -16,6 +16,7 @@ from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, ) +from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.types import LLMMessage @@ -548,7 +549,11 @@ class BedrockCompletion(BaseLLM): "toolSpec": { "name": "structured_output", "description": "Returns structured data according to the schema", - "inputSchema": {"json": response_model.model_json_schema()}, + "inputSchema": { + "json": generate_model_description(response_model) + .get("json_schema", {}) + .get("schema", {}) + }, } } body["toolConfig"] = cast( @@ -779,7 +784,11 @@ class BedrockCompletion(BaseLLM): "toolSpec": { "name": "structured_output", "description": "Returns structured data according to the schema", - "inputSchema": {"json": response_model.model_json_schema()}, + "inputSchema": { + "json": generate_model_description(response_model) + .get("json_schema", {}) + .get("schema", {}) + }, } } body["toolConfig"] = cast( @@ -1011,7 +1020,11 @@ class BedrockCompletion(BaseLLM): "toolSpec": { "name": "structured_output", "description": "Returns structured data according to the schema", - "inputSchema": {"json": response_model.model_json_schema()}, + "inputSchema": { + "json": generate_model_description(response_model) + .get("json_schema", {}) + .get("schema", {}) + }, } } body["toolConfig"] = cast( @@ -1223,7 +1236,11 @@ class BedrockCompletion(BaseLLM): "toolSpec": { "name": "structured_output", "description": "Returns structured data according to the schema", - "inputSchema": {"json": response_model.model_json_schema()}, + "inputSchema": { + "json": generate_model_description(response_model) + .get("json_schema", {}) + .get("schema", {}) + }, } } body["toolConfig"] = cast( diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index d101ad0be..950cdab57 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -15,6 +15,7 @@ from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, ) +from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.types import LLMMessage @@ -464,7 +465,10 @@ class GeminiCompletion(BaseLLM): if response_model: config_params["response_mime_type"] = "application/json" - config_params["response_schema"] = response_model.model_json_schema() + schema_output = generate_model_description(response_model) + config_params["response_schema"] = schema_output.get("json_schema", {}).get( + "schema", {} + ) # Handle tools for supported models if tools and self.supports_tools: @@ -489,7 +493,7 @@ class GeminiCompletion(BaseLLM): function_declaration = types.FunctionDeclaration( name=name, description=description, - parameters=parameters if parameters else None, + parameters_json_schema=parameters if parameters else None, ) gemini_tool = types.Tool(function_declarations=[function_declaration]) @@ -543,11 +547,10 @@ class GeminiCompletion(BaseLLM): else: parts.append(types.Part.from_text(text=str(content) if content else "")) + text_content: str = " ".join(p.text for p in parts if p.text is not None) + if role == "system": # Extract system instruction - Gemini handles it separately - text_content = " ".join( - p.text for p in parts if hasattr(p, "text") and p.text - ) if system_instruction: system_instruction += f"\n\n{text_content}" else: @@ -576,31 +579,40 @@ class GeminiCompletion(BaseLLM): types.Content(role="user", parts=[function_response_part]) ) elif role == "assistant" and message.get("tool_calls"): - tool_parts: list[types.Part] = [] + raw_parts: list[Any] | None = message.get("raw_tool_call_parts") + if raw_parts and all(isinstance(p, types.Part) for p in raw_parts): + tool_parts: list[types.Part] = list(raw_parts) + if text_content: + tool_parts.insert(0, types.Part.from_text(text=text_content)) + else: + tool_parts = [] + if text_content: + tool_parts.append(types.Part.from_text(text=text_content)) - if text_content: - tool_parts.append(types.Part.from_text(text=text_content)) + tool_calls: list[dict[str, Any]] = message.get("tool_calls") or [] + for tool_call in tool_calls: + func: dict[str, Any] = tool_call.get("function") or {} + func_name: str = str(func.get("name") or "") + func_args_raw: str | dict[str, Any] = ( + func.get("arguments") or {} + ) - tool_calls: list[dict[str, Any]] = message.get("tool_calls") or [] - for tool_call in tool_calls: - func: dict[str, Any] = tool_call.get("function") or {} - func_name: str = str(func.get("name") or "") - func_args_raw: str | dict[str, Any] = func.get("arguments") or {} + func_args: dict[str, Any] + if isinstance(func_args_raw, str): + try: + func_args = ( + json.loads(func_args_raw) if func_args_raw else {} + ) + except (json.JSONDecodeError, TypeError): + func_args = {} + else: + func_args = func_args_raw - func_args: dict[str, Any] - if isinstance(func_args_raw, str): - try: - func_args = ( - json.loads(func_args_raw) if func_args_raw else {} + tool_parts.append( + types.Part.from_function_call( + name=func_name, args=func_args ) - except (json.JSONDecodeError, TypeError): - func_args = {} - else: - func_args = func_args_raw - - tool_parts.append( - types.Part.from_function_call(name=func_name, args=func_args) - ) + ) contents.append(types.Content(role="model", parts=tool_parts)) else: diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llms/providers/openai/completion.py index 56a6fa2cb..be7991970 100644 --- a/lib/crewai/src/crewai/llms/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llms/providers/openai/completion.py @@ -693,14 +693,14 @@ class OpenAICompletion(BaseLLM): if response_model or self.response_format: format_model = response_model or self.response_format if isinstance(format_model, type) and issubclass(format_model, BaseModel): - schema = format_model.model_json_schema() - schema["additionalProperties"] = False + schema_output = generate_model_description(format_model) + json_schema = schema_output.get("json_schema", {}) params["text"] = { "format": { "type": "json_schema", - "name": format_model.__name__, - "strict": True, - "schema": schema, + "name": json_schema.get("name", format_model.__name__), + "strict": json_schema.get("strict", True), + "schema": json_schema.get("schema", {}), } } elif isinstance(format_model, dict): @@ -1060,7 +1060,7 @@ class OpenAICompletion(BaseLLM): chunk=delta_text, from_task=from_task, from_agent=from_agent, - response_id=response_id_stream + response_id=response_id_stream, ) elif event.type == "response.function_call_arguments.delta": @@ -1709,7 +1709,7 @@ class OpenAICompletion(BaseLLM): **parse_params, response_format=response_model ) as stream: for chunk in stream: - response_id_stream=chunk.id if hasattr(chunk,"id") else None + response_id_stream = chunk.id if hasattr(chunk, "id") else None if chunk.type == "content.delta": delta_content = chunk.delta @@ -1718,7 +1718,7 @@ class OpenAICompletion(BaseLLM): chunk=delta_content, from_task=from_task, from_agent=from_agent, - response_id=response_id_stream + response_id=response_id_stream, ) final_completion = stream.get_final_completion() @@ -1748,7 +1748,9 @@ class OpenAICompletion(BaseLLM): usage_data = {"total_tokens": 0} for completion_chunk in completion_stream: - response_id_stream=completion_chunk.id if hasattr(completion_chunk,"id") else None + response_id_stream = ( + completion_chunk.id if hasattr(completion_chunk, "id") else None + ) if hasattr(completion_chunk, "usage") and completion_chunk.usage: usage_data = self._extract_openai_token_usage(completion_chunk) @@ -1766,7 +1768,7 @@ class OpenAICompletion(BaseLLM): chunk=chunk_delta.content, from_task=from_task, from_agent=from_agent, - response_id=response_id_stream + response_id=response_id_stream, ) if chunk_delta.tool_calls: @@ -1805,7 +1807,7 @@ class OpenAICompletion(BaseLLM): "index": tool_calls[tool_index]["index"], }, call_type=LLMCallType.TOOL_CALL, - response_id=response_id_stream + response_id=response_id_stream, ) self._track_token_usage_internal(usage_data) @@ -2017,7 +2019,7 @@ class OpenAICompletion(BaseLLM): accumulated_content = "" usage_data = {"total_tokens": 0} async for chunk in completion_stream: - response_id_stream=chunk.id if hasattr(chunk,"id") else None + response_id_stream = chunk.id if hasattr(chunk, "id") else None if hasattr(chunk, "usage") and chunk.usage: usage_data = self._extract_openai_token_usage(chunk) @@ -2035,7 +2037,7 @@ class OpenAICompletion(BaseLLM): chunk=delta.content, from_task=from_task, from_agent=from_agent, - response_id=response_id_stream + response_id=response_id_stream, ) self._track_token_usage_internal(usage_data) @@ -2071,7 +2073,7 @@ class OpenAICompletion(BaseLLM): usage_data = {"total_tokens": 0} async for chunk in stream: - response_id_stream=chunk.id if hasattr(chunk,"id") else None + response_id_stream = chunk.id if hasattr(chunk, "id") else None if hasattr(chunk, "usage") and chunk.usage: usage_data = self._extract_openai_token_usage(chunk) @@ -2089,7 +2091,7 @@ class OpenAICompletion(BaseLLM): chunk=chunk_delta.content, from_task=from_task, from_agent=from_agent, - response_id=response_id_stream + response_id=response_id_stream, ) if chunk_delta.tool_calls: @@ -2128,7 +2130,7 @@ class OpenAICompletion(BaseLLM): "index": tool_calls[tool_index]["index"], }, call_type=LLMCallType.TOOL_CALL, - response_id=response_id_stream + response_id=response_id_stream, ) self._track_token_usage_internal(usage_data) diff --git a/lib/crewai/src/crewai/llms/providers/utils/common.py b/lib/crewai/src/crewai/llms/providers/utils/common.py index 9f95c6ce8..f3bec9b2a 100644 --- a/lib/crewai/src/crewai/llms/providers/utils/common.py +++ b/lib/crewai/src/crewai/llms/providers/utils/common.py @@ -2,6 +2,7 @@ import logging import re from typing import Any +from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.string_utils import sanitize_tool_name @@ -77,7 +78,8 @@ def extract_tool_info(tool: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: # Also check for args_schema (Pydantic format) if not parameters and "args_schema" in tool: if hasattr(tool["args_schema"], "model_json_schema"): - parameters = tool["args_schema"].model_json_schema() + schema_output = generate_model_description(tool["args_schema"]) + parameters = schema_output.get("json_schema", {}).get("schema", {}) return name, description, parameters diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py index 4b927f726..af4f464d9 100644 --- a/lib/crewai/src/crewai/utilities/agent_utils.py +++ b/lib/crewai/src/crewai/utilities/agent_utils.py @@ -28,6 +28,7 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( ) from crewai.utilities.i18n import I18N from crewai.utilities.printer import ColoredText, Printer +from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.string_utils import sanitize_tool_name from crewai.utilities.token_counter_callback import TokenCalcHandler from crewai.utilities.types import LLMMessage @@ -36,6 +37,7 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: from crewai.agent import Agent from crewai.agents.crew_agent_executor import CrewAgentExecutor + from crewai.experimental.agent_executor import AgentExecutor from crewai.lite_agent import LiteAgent from crewai.llm import LLM from crewai.task import Task @@ -158,7 +160,8 @@ def convert_tools_to_openai_schema( parameters: dict[str, Any] = {} if hasattr(tool, "args_schema") and tool.args_schema is not None: try: - parameters = tool.args_schema.model_json_schema() + schema_output = generate_model_description(tool.args_schema) + parameters = schema_output.get("json_schema", {}).get("schema", {}) # Remove title and description from schema root as they're redundant parameters.pop("title", None) parameters.pop("description", None) @@ -318,7 +321,7 @@ def get_llm_response( from_task: Task | None = None, from_agent: Agent | LiteAgent | None = None, response_model: type[BaseModel] | None = None, - executor_context: CrewAgentExecutor | LiteAgent | None = None, + executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None = None, ) -> str | Any: """Call the LLM and return the response, handling any invalid responses. @@ -380,7 +383,7 @@ async def aget_llm_response( from_task: Task | None = None, from_agent: Agent | LiteAgent | None = None, response_model: type[BaseModel] | None = None, - executor_context: CrewAgentExecutor | None = None, + executor_context: CrewAgentExecutor | AgentExecutor | None = None, ) -> str | Any: """Call the LLM asynchronously and return the response. @@ -900,7 +903,8 @@ def extract_tool_call_info( def _setup_before_llm_call_hooks( - executor_context: CrewAgentExecutor | LiteAgent | None, printer: Printer + executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None, + printer: Printer, ) -> bool: """Setup and invoke before_llm_call hooks for the executor context. @@ -950,7 +954,7 @@ def _setup_before_llm_call_hooks( def _setup_after_llm_call_hooks( - executor_context: CrewAgentExecutor | LiteAgent | None, + executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None, answer: str, printer: Printer, ) -> str: diff --git a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py index 6df3d516d..69354742b 100644 --- a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py +++ b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py @@ -1,14 +1,72 @@ -"""Utilities for generating JSON schemas from Pydantic models. +"""Dynamic Pydantic model creation from JSON schemas. + +This module provides utilities for converting JSON schemas to Pydantic models at runtime. +The main function is `create_model_from_schema`, which takes a JSON schema and returns +a dynamically created Pydantic model class. + +This is used by the A2A server to honor response schemas sent by clients, allowing +structured output from agent tasks. + +Based on dydantic (https://github.com/zenbase-ai/dydantic). This module provides functions for converting Pydantic models to JSON schemas suitable for use with LLMs and tool definitions. """ +from __future__ import annotations + from collections.abc import Callable from copy import deepcopy -from typing import Any +import datetime +import logging +from typing import TYPE_CHECKING, Annotated, Any, Literal, Union +import uuid -from pydantic import BaseModel +from pydantic import ( + UUID1, + UUID3, + UUID4, + UUID5, + AnyUrl, + BaseModel, + ConfigDict, + DirectoryPath, + Field, + FilePath, + FileUrl, + HttpUrl, + Json, + MongoDsn, + NewPath, + PostgresDsn, + SecretBytes, + SecretStr, + StrictBytes, + create_model as create_model_base, +) +from pydantic.networks import ( # type: ignore[attr-defined] + IPv4Address, + IPv6Address, + IPvAnyAddress, + IPvAnyInterface, + IPvAnyNetwork, +) + + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from pydantic import EmailStr + from pydantic.main import AnyClassMethod +else: + try: + from pydantic import EmailStr + except ImportError: + logger.warning( + "EmailStr unavailable, using str fallback", + extra={"missing_package": "email_validator"}, + ) + EmailStr = str def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]: @@ -243,3 +301,319 @@ def generate_model_description(model: type[BaseModel]) -> dict[str, Any]: "schema": json_schema, }, } + + +FORMAT_TYPE_MAP: dict[str, type[Any]] = { + "base64": Annotated[bytes, Field(json_schema_extra={"format": "base64"})], # type: ignore[dict-item] + "binary": StrictBytes, + "date": datetime.date, + "time": datetime.time, + "date-time": datetime.datetime, + "duration": datetime.timedelta, + "directory-path": DirectoryPath, + "email": EmailStr, + "file-path": FilePath, + "ipv4": IPv4Address, + "ipv6": IPv6Address, + "ipvanyaddress": IPvAnyAddress, # type: ignore[dict-item] + "ipvanyinterface": IPvAnyInterface, # type: ignore[dict-item] + "ipvanynetwork": IPvAnyNetwork, # type: ignore[dict-item] + "json-string": Json, + "multi-host-uri": PostgresDsn | MongoDsn, # type: ignore[dict-item] + "password": SecretStr, + "path": NewPath, + "uri": AnyUrl, + "uuid": uuid.UUID, + "uuid1": UUID1, + "uuid3": UUID3, + "uuid4": UUID4, + "uuid5": UUID5, +} + + +def create_model_from_schema( # type: ignore[no-any-unimported] + json_schema: dict[str, Any], + *, + root_schema: dict[str, Any] | None = None, + __config__: ConfigDict | None = None, + __base__: type[BaseModel] | None = None, + __module__: str = __name__, + __validators__: dict[str, AnyClassMethod] | None = None, + __cls_kwargs__: dict[str, Any] | None = None, +) -> type[BaseModel]: + """Create a Pydantic model from a JSON schema. + + This function takes a JSON schema as input and dynamically creates a Pydantic + model class based on the schema. It supports various JSON schema features such + as nested objects, referenced definitions ($ref), arrays with typed items, + union types (anyOf/oneOf), and string formats. + + Args: + json_schema: A dictionary representing the JSON schema. + root_schema: The root schema containing $defs. If not provided, the + current schema is treated as the root schema. + __config__: Pydantic configuration for the generated model. + __base__: Base class for the generated model. Defaults to BaseModel. + __module__: Module name for the generated model class. + __validators__: A dictionary of custom validators for the generated model. + __cls_kwargs__: Additional keyword arguments for the generated model class. + + Returns: + A dynamically created Pydantic model class based on the provided JSON schema. + + Example: + >>> schema = { + ... "title": "Person", + ... "type": "object", + ... "properties": { + ... "name": {"type": "string"}, + ... "age": {"type": "integer"}, + ... }, + ... "required": ["name"], + ... } + >>> Person = create_model_from_schema(schema) + >>> person = Person(name="John", age=30) + >>> person.name + 'John' + """ + effective_root = root_schema or json_schema + + if "allOf" in json_schema: + json_schema = _merge_all_of_schemas(json_schema["allOf"], effective_root) + if "title" not in json_schema and "title" in (root_schema or {}): + json_schema["title"] = (root_schema or {}).get("title") + + model_name = json_schema.get("title", "DynamicModel") + field_definitions = { + name: _json_schema_to_pydantic_field( + name, prop, json_schema.get("required", []), effective_root + ) + for name, prop in (json_schema.get("properties", {}) or {}).items() + } + + return create_model_base( + model_name, + __config__=__config__, + __base__=__base__, + __module__=__module__, + __validators__=__validators__, + __cls_kwargs__=__cls_kwargs__, + **field_definitions, + ) + + +def _json_schema_to_pydantic_field( + name: str, + json_schema: dict[str, Any], + required: list[str], + root_schema: dict[str, Any], +) -> Any: + """Convert a JSON schema property to a Pydantic field definition. + + Args: + name: The field name. + json_schema: The JSON schema for this field. + required: List of required field names. + root_schema: The root schema for resolving $ref. + + Returns: + A tuple of (type, Field) for use with create_model. + """ + type_ = _json_schema_to_pydantic_type(json_schema, root_schema, name_=name.title()) + description = json_schema.get("description") + examples = json_schema.get("examples") + is_required = name in required + + field_params: dict[str, Any] = {} + schema_extra: dict[str, Any] = {} + + if description: + field_params["description"] = description + if examples: + schema_extra["examples"] = examples + + default = ... if is_required else None + + if isinstance(type_, type) and issubclass(type_, (int, float)): + if "minimum" in json_schema: + field_params["ge"] = json_schema["minimum"] + if "exclusiveMinimum" in json_schema: + field_params["gt"] = json_schema["exclusiveMinimum"] + if "maximum" in json_schema: + field_params["le"] = json_schema["maximum"] + if "exclusiveMaximum" in json_schema: + field_params["lt"] = json_schema["exclusiveMaximum"] + if "multipleOf" in json_schema: + field_params["multiple_of"] = json_schema["multipleOf"] + + format_ = json_schema.get("format") + if format_ in FORMAT_TYPE_MAP: + pydantic_type = FORMAT_TYPE_MAP[format_] + + if format_ == "password": + if json_schema.get("writeOnly"): + pydantic_type = SecretBytes + elif format_ == "uri": + allowed_schemes = json_schema.get("scheme") + if allowed_schemes: + if len(allowed_schemes) == 1 and allowed_schemes[0] == "http": + pydantic_type = HttpUrl + elif len(allowed_schemes) == 1 and allowed_schemes[0] == "file": + pydantic_type = FileUrl + + type_ = pydantic_type + + if isinstance(type_, type) and issubclass(type_, str): + if "minLength" in json_schema: + field_params["min_length"] = json_schema["minLength"] + if "maxLength" in json_schema: + field_params["max_length"] = json_schema["maxLength"] + if "pattern" in json_schema: + field_params["pattern"] = json_schema["pattern"] + + if not is_required: + type_ = type_ | None + + if schema_extra: + field_params["json_schema_extra"] = schema_extra + + return type_, Field(default, **field_params) + + +def _resolve_ref(ref: str, root_schema: dict[str, Any]) -> dict[str, Any]: + """Resolve a $ref to its actual schema. + + Args: + ref: The $ref string (e.g., "#/$defs/MyType"). + root_schema: The root schema containing $defs. + + Returns: + The resolved schema dict. + """ + from typing import cast + + ref_path = ref.split("/") + if ref.startswith("#/$defs/"): + ref_schema: dict[str, Any] = root_schema["$defs"] + start_idx = 2 + else: + ref_schema = root_schema + start_idx = 1 + for path in ref_path[start_idx:]: + ref_schema = cast(dict[str, Any], ref_schema[path]) + return ref_schema + + +def _merge_all_of_schemas( + schemas: list[dict[str, Any]], + root_schema: dict[str, Any], +) -> dict[str, Any]: + """Merge multiple allOf schemas into a single schema. + + Combines properties and required fields from all schemas. + + Args: + schemas: List of schemas to merge. + root_schema: The root schema for resolving $ref. + + Returns: + Merged schema with combined properties and required fields. + """ + merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + + for schema in schemas: + if "$ref" in schema: + schema = _resolve_ref(schema["$ref"], root_schema) + + if "properties" in schema: + merged["properties"].update(schema["properties"]) + + if "required" in schema: + for field in schema["required"]: + if field not in merged["required"]: + merged["required"].append(field) + + if "title" in schema and "title" not in merged: + merged["title"] = schema["title"] + + return merged + + +def _json_schema_to_pydantic_type( + json_schema: dict[str, Any], + root_schema: dict[str, Any], + *, + name_: str | None = None, +) -> Any: + """Convert a JSON schema to a Python/Pydantic type. + + Args: + json_schema: The JSON schema to convert. + root_schema: The root schema for resolving $ref. + name_: Optional name for nested models. + + Returns: + A Python type corresponding to the JSON schema. + """ + ref = json_schema.get("$ref") + if ref: + ref_schema = _resolve_ref(ref, root_schema) + return _json_schema_to_pydantic_type(ref_schema, root_schema, name_=name_) + + enum_values = json_schema.get("enum") + if enum_values: + return Literal[tuple(enum_values)] + + if "const" in json_schema: + return Literal[json_schema["const"]] + + any_of_schemas = [] + if "anyOf" in json_schema or "oneOf" in json_schema: + any_of_schemas = json_schema.get("anyOf", []) + json_schema.get("oneOf", []) + if any_of_schemas: + any_of_types = [ + _json_schema_to_pydantic_type(schema, root_schema) + for schema in any_of_schemas + ] + return Union[tuple(any_of_types)] # noqa: UP007 + + all_of_schemas = json_schema.get("allOf") + if all_of_schemas: + if len(all_of_schemas) == 1: + return _json_schema_to_pydantic_type( + all_of_schemas[0], root_schema, name_=name_ + ) + merged = _merge_all_of_schemas(all_of_schemas, root_schema) + return _json_schema_to_pydantic_type(merged, root_schema, name_=name_) + + type_ = json_schema.get("type") + + if type_ == "string": + return str + if type_ == "integer": + return int + if type_ == "number": + return float + if type_ == "boolean": + return bool + if type_ == "array": + items_schema = json_schema.get("items") + if items_schema: + item_type = _json_schema_to_pydantic_type( + items_schema, root_schema, name_=name_ + ) + return list[item_type] # type: ignore[valid-type] + return list + if type_ == "object": + properties = json_schema.get("properties") + if properties: + json_schema_ = json_schema.copy() + if json_schema_.get("title") is None: + json_schema_["title"] = name_ + return create_model_from_schema(json_schema_, root_schema=root_schema) + return dict + if type_ == "null": + return None + if type_ is None: + return Any + raise ValueError(f"Unsupported JSON schema type: {type_} from {json_schema}") diff --git a/lib/crewai/src/crewai/utilities/types.py b/lib/crewai/src/crewai/utilities/types.py index 98ff0877b..340f6f751 100644 --- a/lib/crewai/src/crewai/utilities/types.py +++ b/lib/crewai/src/crewai/utilities/types.py @@ -26,4 +26,5 @@ class LLMMessage(TypedDict): tool_call_id: NotRequired[str] name: NotRequired[str] tool_calls: NotRequired[list[dict[str, Any]]] + raw_tool_call_parts: NotRequired[list[Any]] files: NotRequired[dict[str, FileInput]]