From 78f5144bde44fee525d2b68f39bd43c4912579f2 Mon Sep 17 00:00:00 2001
From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
Date: Fri, 11 Jul 2025 10:18:54 -0700
Subject: [PATCH] =?UTF-8?q?Enhance=20EnterpriseActionTool=20with=20improve?=
 =?UTF-8?q?d=20schema=20processing=20and=20erro=E2=80=A6=20(#371)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Enhance EnterpriseActionTool with improved schema processing and error handling

- Added methods for sanitizing names and processing schema types, including support for nested models and nullable types.
- Improved error handling during schema creation and processing, with warnings for failures.
- Updated parameter handling in the `_run` method to clean up `kwargs` before sending requests.
- Introduced a detailed description generation for nested schema structures to enhance tool documentation.

* Add tests for EnterpriseActionTool schema conversion and validation

- Introduced a new test class for validating complex nested schemas in EnterpriseActionTool.
- Added tests for schema conversion, optional fields, enum validation, and required nested fields.
- Implemented execution tests to ensure the tool can handle complex validated input correctly.
- Verified model naming conventions and added tests for simpler schemas with basic enum validation.
- Enhanced overall test coverage for the EnterpriseActionTool functionality.

* Update chromadb dependency version in pyproject.toml and uv.lock

- Changed chromadb version from >=0.4.22 to ==0.5.23 in both pyproject.toml and uv.lock to ensure compatibility and stability.

* Update test workflow configuration

- Changed EMBEDCHAIN_DB_URI to point to a temporary test database location.
- Added CHROMA_PERSIST_PATH for specifying the path to the Chroma test database.
- Cleaned up the test run command in the workflow file.

* reverted --- .../adapters/enterprise_adapter.py | 258 ++++++++++++----- tests/tools/crewai_enterprise_tools_test.py | 268 ++++++++++++++++++ 2 files changed, 456 insertions(+), 70 deletions(-) diff --git a/src/crewai_tools/adapters/enterprise_adapter.py b/src/crewai_tools/adapters/enterprise_adapter.py index 96d64af8b..bd442d98f 100644 --- a/src/crewai_tools/adapters/enterprise_adapter.py +++ b/src/crewai_tools/adapters/enterprise_adapter.py @@ -1,9 +1,11 @@ import os import json import requests -from typing import List, Any, Dict, Optional +from typing import List, Any, Dict, Literal, Optional, Union, get_origin from pydantic import Field, create_model from crewai.tools import BaseTool +import re + # DEFAULTS ENTERPRISE_ACTION_KIT_PROJECT_ID = "dd525517-df22-49d2-a69e-6a0eed211166" @@ -37,6 +39,9 @@ class EnterpriseActionTool(BaseTool): enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL, enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID, ): + self._model_registry = {} + self._base_name = self._sanitize_name(name) + schema_props, required = self._extract_schema_info(action_schema) # Define field definitions for the model @@ -44,22 +49,36 @@ class EnterpriseActionTool(BaseTool): for param_name, param_details in schema_props.items(): param_desc = param_details.get("description", "") is_required = param_name in required - is_nullable, param_type = self._analyze_field_type(param_details) - # Create field definition based on nullable and required status + try: + field_type = self._process_schema_type( + param_details, self._sanitize_name(param_name).title() + ) + except Exception as e: + print(f"Warning: Could not process schema for {param_name}: {e}") + field_type = str + + # Create field definition based on requirement field_definitions[param_name] = self._create_field_definition( - param_type, is_required, is_nullable, param_desc + field_type, is_required, param_desc ) # Create the model if field_definitions: - 
args_schema = create_model( - f"{name.capitalize()}Schema", **field_definitions - ) + try: + args_schema = create_model( + f"{self._base_name}Schema", **field_definitions + ) + except Exception as e: + print(f"Warning: Could not create main schema model: {e}") + args_schema = create_model( + f"{self._base_name}Schema", + input_text=(str, Field(description="Input for the action")), + ) else: # Fallback for empty schema args_schema = create_model( - f"{name.capitalize()}Schema", + f"{self._base_name}Schema", input_text=(str, Field(description="Input for the action")), ) @@ -73,6 +92,12 @@ class EnterpriseActionTool(BaseTool): if enterprise_action_kit_project_url is not None: self.enterprise_action_kit_project_url = enterprise_action_kit_project_url + def _sanitize_name(self, name: str) -> str: + """Sanitize names to create proper Python class names.""" + sanitized = re.sub(r"[^a-zA-Z0-9_]", "", name) + parts = sanitized.split("_") + return "".join(word.capitalize() for word in parts if word) + def _extract_schema_info( self, action_schema: Dict[str, Any] ) -> tuple[Dict[str, Any], List[str]]: @@ -87,51 +112,97 @@ class EnterpriseActionTool(BaseTool): ) return schema_props, required - def _analyze_field_type(self, param_details: Dict[str, Any]) -> tuple[bool, type]: - """Analyze field type and nullability from parameter details.""" - is_nullable = False - param_type = str # Default type - - if "anyOf" in param_details: - any_of_types = param_details["anyOf"] + def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> type: + """Process a JSON schema and return appropriate Python type.""" + if "anyOf" in schema: + any_of_types = schema["anyOf"] is_nullable = any(t.get("type") == "null" for t in any_of_types) non_null_types = [t for t in any_of_types if t.get("type") != "null"] - if non_null_types: - first_type = non_null_types[0].get("type", "string") - param_type = self._map_json_type_to_python( - first_type, non_null_types[0] - ) - else: - json_type = 
param_details.get("type", "string") - param_type = self._map_json_type_to_python(json_type, param_details) - is_nullable = json_type == "null" - return is_nullable, param_type + if non_null_types: + base_type = self._process_schema_type(non_null_types[0], type_name) + return Optional[base_type] if is_nullable else base_type + return Optional[str] + + if "oneOf" in schema: + return self._process_schema_type(schema["oneOf"][0], type_name) + + if "allOf" in schema: + return self._process_schema_type(schema["allOf"][0], type_name) + + json_type = schema.get("type", "string") + + if "enum" in schema: + enum_values = schema["enum"] + if not enum_values: + return self._map_json_type_to_python(json_type) + return Literal[tuple(enum_values)] # type: ignore + + if json_type == "array": + items_schema = schema.get("items", {"type": "string"}) + item_type = self._process_schema_type(items_schema, f"{type_name}Item") + return List[item_type] + + if json_type == "object": + return self._create_nested_model(schema, type_name) + + return self._map_json_type_to_python(json_type) + + def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> type: + """Create a nested Pydantic model for complex objects.""" + full_model_name = f"{self._base_name}{model_name}" + + if full_model_name in self._model_registry: + return self._model_registry[full_model_name] + + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + + if not properties: + return dict + + field_definitions = {} + for prop_name, prop_schema in properties.items(): + prop_desc = prop_schema.get("description", "") + is_required = prop_name in required_fields + + try: + prop_type = self._process_schema_type( + prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}" + ) + except Exception as e: + print(f"Warning: Could not process schema for {prop_name}: {e}") + prop_type = str + + field_definitions[prop_name] = self._create_field_definition( + prop_type, 
is_required, prop_desc + ) + + try: + nested_model = create_model(full_model_name, **field_definitions) + self._model_registry[full_model_name] = nested_model + return nested_model + except Exception as e: + print(f"Warning: Could not create nested model {full_model_name}: {e}") + return dict def _create_field_definition( - self, param_type: type, is_required: bool, is_nullable: bool, param_desc: str + self, field_type: type, is_required: bool, description: str ) -> tuple: - """Create Pydantic field definition based on type, requirement, and nullability.""" - if is_nullable: - return ( - Optional[param_type], - Field(default=None, description=param_desc), - ) - elif is_required: - return ( - param_type, - Field(description=param_desc), - ) + """Create Pydantic field definition based on type and requirement.""" + if is_required: + return (field_type, Field(description=description)) else: - return ( - Optional[param_type], - Field(default=None, description=param_desc), - ) + if get_origin(field_type) is Union: + return (field_type, Field(default=None, description=description)) + else: + return ( + Optional[field_type], + Field(default=None, description=description), + ) - def _map_json_type_to_python( - self, json_type: str, param_details: Dict[str, Any] - ) -> type: - """Map JSON schema types to Python types.""" + def _map_json_type_to_python(self, json_type: str) -> type: + """Map basic JSON schema types to Python types.""" type_mapping = { "string": str, "integer": int, @@ -139,6 +210,7 @@ class EnterpriseActionTool(BaseTool): "boolean": bool, "array": list, "object": dict, + "null": type(None), } return type_mapping.get(json_type, str) @@ -149,29 +221,37 @@ class EnterpriseActionTool(BaseTool): required_nullable_fields = [] for param_name in required: param_details = schema_props.get(param_name, {}) - is_nullable, _ = self._analyze_field_type(param_details) - if is_nullable: + if self._is_nullable_type(param_details): required_nullable_fields.append(param_name) 
return required_nullable_fields + def _is_nullable_type(self, schema: Dict[str, Any]) -> bool: + """Check if a schema represents a nullable type.""" + if "anyOf" in schema: + return any(t.get("type") == "null" for t in schema["anyOf"]) + return schema.get("type") == "null" + def _run(self, **kwargs) -> str: """Execute the specific enterprise action with validated parameters.""" try: + cleaned_kwargs = {} + for key, value in kwargs.items(): + if value is not None: + cleaned_kwargs[key] = value + required_nullable_fields = self._get_required_nullable_fields() for field_name in required_nullable_fields: - if field_name not in kwargs: - kwargs[field_name] = None - - params = {k: v for k, v in kwargs.items() if v is not None} + if field_name not in cleaned_kwargs: + cleaned_kwargs[field_name] = None api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions" headers = { "Authorization": f"Bearer {self.enterprise_action_token}", "Content-Type": "application/json", } - payload = {"action": self.action_name, "parameters": params} + payload = {"action": self.action_name, "parameters": cleaned_kwargs} response = requests.post( url=api_url, headers=headers, json=payload, timeout=60 @@ -198,7 +278,6 @@ class EnterpriseActionKitToolAdapter: enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID, ): """Initialize the adapter with an enterprise action token.""" - self.enterprise_action_token = enterprise_action_token self._actions_schema = {} self._tools = None @@ -206,11 +285,7 @@ class EnterpriseActionKitToolAdapter: self.enterprise_action_kit_project_url = enterprise_action_kit_project_url def tools(self) -> List[BaseTool]: - """Get the list of tools created from enterprise actions. - - Returns: - List of BaseTool instances, one for each enterprise action. 
- """ + """Get the list of tools created from enterprise actions.""" if self._tools is None: self._fetch_actions() self._create_tools() @@ -261,6 +336,53 @@ class EnterpriseActionKitToolAdapter: traceback.print_exc() + def _generate_detailed_description( + self, schema: Dict[str, Any], indent: int = 0 + ) -> List[str]: + """Generate detailed description for nested schema structures.""" + descriptions = [] + indent_str = " " * indent + + schema_type = schema.get("type", "string") + + if schema_type == "object": + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + + if properties: + descriptions.append(f"{indent_str}Object with properties:") + for prop_name, prop_schema in properties.items(): + prop_desc = prop_schema.get("description", "") + is_required = prop_name in required_fields + req_str = " (required)" if is_required else " (optional)" + descriptions.append( + f"{indent_str} - {prop_name}: {prop_desc}{req_str}" + ) + + if prop_schema.get("type") == "object": + descriptions.extend( + self._generate_detailed_description(prop_schema, indent + 2) + ) + elif prop_schema.get("type") == "array": + items_schema = prop_schema.get("items", {}) + if items_schema.get("type") == "object": + descriptions.append(f"{indent_str} Array of objects:") + descriptions.extend( + self._generate_detailed_description( + items_schema, indent + 3 + ) + ) + elif "enum" in items_schema: + descriptions.append( + f"{indent_str} Array of enum values: {items_schema['enum']}" + ) + elif "enum" in prop_schema: + descriptions.append( + f"{indent_str} Enum values: {prop_schema['enum']}" + ) + + return descriptions + def _create_tools(self): """Create BaseTool instances for each action.""" tools = [] @@ -269,19 +391,16 @@ class EnterpriseActionKitToolAdapter: function_details = action_schema.get("function", {}) description = function_details.get("description", f"Execute {action_name}") - # Get parameter info for a better description - parameters = 
function_details.get("parameters", {}).get("properties", {}) - param_info = [] - for param_name, param_details in parameters.items(): - param_desc = param_details.get("description", "") - required = param_name in function_details.get("parameters", {}).get( - "required", [] - ) - param_info.append( - f"- {param_name}: {param_desc} {'(required)' if required else '(optional)'}" + parameters = function_details.get("parameters", {}) + param_descriptions = [] + + if parameters.get("properties"): + param_descriptions.append("\nDetailed Parameter Structure:") + param_descriptions.extend( + self._generate_detailed_description(parameters) ) - full_description = f"{description}\n\nParameters:\n" + "\n".join(param_info) + full_description = description + "\n".join(param_descriptions) tool = EnterpriseActionTool( name=action_name.lower().replace(" ", "_"), @@ -297,7 +416,6 @@ class EnterpriseActionKitToolAdapter: self._tools = tools - # Adding context manager support for convenience, but direct usage is also supported def __enter__(self): return self.tools() diff --git a/tests/tools/crewai_enterprise_tools_test.py b/tests/tools/crewai_enterprise_tools_test.py index d7a868472..b043289dc 100644 --- a/tests/tools/crewai_enterprise_tools_test.py +++ b/tests/tools/crewai_enterprise_tools_test.py @@ -2,9 +2,11 @@ import os import unittest from unittest.mock import patch, MagicMock + from crewai.tools import BaseTool from crewai_tools.tools import CrewaiEnterpriseTools from crewai_tools.adapters.tool_collection import ToolCollection +from crewai_tools.adapters.enterprise_adapter import EnterpriseActionTool class TestCrewaiEnterpriseTools(unittest.TestCase): @@ -86,3 +88,269 @@ class TestCrewaiEnterpriseTools(unittest.TestCase): self.assertEqual(len(tools), 2) self.assertEqual(tools[0].name, "tool1") self.assertEqual(tools[1].name, "tool3") + + +class TestEnterpriseActionToolSchemaConversion(unittest.TestCase): + """Test the enterprise action tool schema conversion and validation.""" + 
+ def setUp(self): + self.test_schema = { + "type": "function", + "function": { + "name": "TEST_COMPLEX_ACTION", + "description": "Test action with complex nested structure", + "parameters": { + "type": "object", + "properties": { + "filterCriteria": { + "type": "object", + "description": "Filter criteria object", + "properties": { + "operation": {"type": "string", "enum": ["AND", "OR"]}, + "rules": { + "type": "array", + "items": { + "type": "object", + "properties": { + "field": { + "type": "string", + "enum": ["name", "email", "status"], + }, + "operator": { + "type": "string", + "enum": ["equals", "contains"], + }, + "value": {"type": "string"}, + }, + "required": ["field", "operator", "value"], + }, + }, + }, + "required": ["operation", "rules"], + }, + "options": { + "type": "object", + "properties": { + "limit": {"type": "integer"}, + "offset": {"type": "integer"}, + }, + "required": [], + }, + }, + "required": [], + }, + }, + } + + def test_complex_schema_conversion(self): + """Test that complex nested schemas are properly converted to Pydantic models.""" + tool = EnterpriseActionTool( + name="gmail_search_for_email", + description="Test tool", + enterprise_action_token="test_token", + action_name="GMAIL_SEARCH_FOR_EMAIL", + action_schema=self.test_schema, + ) + + self.assertEqual(tool.name, "gmail_search_for_email") + self.assertEqual(tool.action_name, "GMAIL_SEARCH_FOR_EMAIL") + + schema_class = tool.args_schema + self.assertIsNotNone(schema_class) + + schema_fields = schema_class.model_fields + self.assertIn("filterCriteria", schema_fields) + self.assertIn("options", schema_fields) + + # Test valid input structure + valid_input = { + "filterCriteria": { + "operation": "AND", + "rules": [ + {"field": "name", "operator": "contains", "value": "test"}, + {"field": "status", "operator": "equals", "value": "active"}, + ], + }, + "options": {"limit": 10}, + } + + # This should not raise an exception + validated_input = schema_class(**valid_input) + 
self.assertIsNotNone(validated_input.filterCriteria) + self.assertIsNotNone(validated_input.options) + + def test_optional_fields_validation(self): + """Test that optional fields work correctly.""" + tool = EnterpriseActionTool( + name="gmail_search_for_email", + description="Test tool", + enterprise_action_token="test_token", + action_name="GMAIL_SEARCH_FOR_EMAIL", + action_schema=self.test_schema, + ) + + schema_class = tool.args_schema + + minimal_input = {} + validated_input = schema_class(**minimal_input) + self.assertIsNone(validated_input.filterCriteria) + self.assertIsNone(validated_input.options) + + partial_input = {"options": {"limit": 10}} + validated_input = schema_class(**partial_input) + self.assertIsNone(validated_input.filterCriteria) + self.assertIsNotNone(validated_input.options) + + def test_enum_validation(self): + """Test that enum values are properly validated.""" + tool = EnterpriseActionTool( + name="gmail_search_for_email", + description="Test tool", + enterprise_action_token="test_token", + action_name="GMAIL_SEARCH_FOR_EMAIL", + action_schema=self.test_schema, + ) + + schema_class = tool.args_schema + + invalid_input = { + "filterCriteria": { + "operation": "INVALID_OPERATOR", + "rules": [], + } + } + + with self.assertRaises(Exception): + schema_class(**invalid_input) + + def test_required_nested_fields(self): + """Test that required fields in nested objects are validated.""" + tool = EnterpriseActionTool( + name="gmail_search_for_email", + description="Test tool", + enterprise_action_token="test_token", + action_name="GMAIL_SEARCH_FOR_EMAIL", + action_schema=self.test_schema, + ) + + schema_class = tool.args_schema + + incomplete_input = { + "filterCriteria": { + "operation": "OR", + "rules": [ + { + "field": "name", + "operator": "contains", + } + ], + } + } + + with self.assertRaises(Exception): + schema_class(**incomplete_input) + + @patch("requests.post") + def test_tool_execution_with_complex_input(self, mock_post): + """Test that 
the tool can execute with complex validated input.""" + mock_response = MagicMock() + mock_response.ok = True + mock_response.json.return_value = {"success": True, "results": []} + mock_post.return_value = mock_response + + tool = EnterpriseActionTool( + name="gmail_search_for_email", + description="Test tool", + enterprise_action_token="test_token", + action_name="GMAIL_SEARCH_FOR_EMAIL", + action_schema=self.test_schema, + ) + + tool._run( + filterCriteria={ + "operation": "OR", + "rules": [ + {"field": "name", "operator": "contains", "value": "test"}, + {"field": "status", "operator": "equals", "value": "active"}, + ], + }, + options={"limit": 10}, + ) + + mock_post.assert_called_once() + call_args = mock_post.call_args + payload = call_args[1]["json"] + + self.assertEqual(payload["action"], "GMAIL_SEARCH_FOR_EMAIL") + self.assertIn("filterCriteria", payload["parameters"]) + self.assertIn("options", payload["parameters"]) + self.assertEqual(payload["parameters"]["filterCriteria"]["operation"], "OR") + + def test_model_naming_convention(self): + """Test that generated model names follow proper conventions.""" + tool = EnterpriseActionTool( + name="gmail_search_for_email", + description="Test tool", + enterprise_action_token="test_token", + action_name="GMAIL_SEARCH_FOR_EMAIL", + action_schema=self.test_schema, + ) + + schema_class = tool.args_schema + self.assertIsNotNone(schema_class) + + self.assertTrue(schema_class.__name__.endswith("Schema")) + self.assertTrue(schema_class.__name__[0].isupper()) + + complex_input = { + "filterCriteria": { + "operation": "OR", + "rules": [ + {"field": "name", "operator": "contains", "value": "test"}, + {"field": "status", "operator": "equals", "value": "active"}, + ], + }, + "options": {"limit": 10}, + } + + validated = schema_class(**complex_input) + self.assertIsNotNone(validated.filterCriteria) + + def test_simple_schema_with_enums(self): + """Test a simpler schema with basic enum validation.""" + simple_schema = { + 
"type": "function", + "function": { + "name": "SIMPLE_TEST", + "description": "Simple test function", + "parameters": { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", "pending"], + }, + "priority": {"type": "integer", "enum": [1, 2, 3]}, + }, + "required": ["status"], + }, + }, + } + + tool = EnterpriseActionTool( + name="simple_test", + description="Simple test tool", + enterprise_action_token="test_token", + action_name="SIMPLE_TEST", + action_schema=simple_schema, + ) + + schema_class = tool.args_schema + + valid_input = {"status": "active", "priority": 2} + validated = schema_class(**valid_input) + self.assertEqual(validated.status, "active") + self.assertEqual(validated.priority, 2) + + with self.assertRaises(Exception): + schema_class(status="invalid_status")