Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-08 15:48:29 +00:00
Enhance EnterpriseActionTool with improved schema processing and error handling (#371)
* Enhance EnterpriseActionTool with improved schema processing and error handling
  - Added methods for sanitizing names and processing schema types, including support for nested models and nullable types.
  - Improved error handling during schema creation and processing, with warnings for failures.
  - Updated parameter handling in the `_run` method to clean up `kwargs` before sending requests.
  - Introduced detailed description generation for nested schema structures to enhance tool documentation.

* Add tests for EnterpriseActionTool schema conversion and validation
  - Introduced a new test class for validating complex nested schemas in EnterpriseActionTool.
  - Added tests for schema conversion, optional fields, enum validation, and required nested fields.
  - Implemented execution tests to ensure the tool can handle complex validated input correctly.
  - Verified model naming conventions and added tests for simpler schemas with basic enum validation.
  - Enhanced overall test coverage for the EnterpriseActionTool functionality.

* Update chromadb dependency version in pyproject.toml and uv.lock
  - Changed the chromadb version from >=0.4.22 to ==0.5.23 in both pyproject.toml and uv.lock to ensure compatibility and stability.

* Update test workflow configuration
  - Changed EMBEDCHAIN_DB_URI to point to a temporary test database location.
  - Added CHROMA_PERSIST_PATH for specifying the path to the Chroma test database.
  - Cleaned up the test run command in the workflow file.

* reverted
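As a rough sketch of the conversion these changes implement (the `status`/`tags` fields and the `SampleSchema` name below are invented for illustration and do not appear in the diff), a JSON-schema fragment maps onto a dynamically created Pydantic model roughly like this:

```python
from typing import List, Literal, Optional

from pydantic import Field, create_model

# Hypothetical schema fragments, in the style of the action schemas the tool receives.
status_schema = {"type": "string", "enum": ["active", "inactive"]}
tags_schema = {"type": "array", "items": {"type": "string"}}

# enum values become a Literal type; an array of strings becomes List[str].
StatusType = Literal[tuple(status_schema["enum"])]  # Literal["active", "inactive"]
TagsType = List[str]

# Build the args schema dynamically, as the tool does with create_model.
SampleSchema = create_model(
    "SampleSchema",
    status=(StatusType, Field(description="Status filter")),             # required field
    tags=(Optional[TagsType], Field(default=None, description="Tags")),  # optional field
)

print(SampleSchema(status="active"))  # validates: status='active' tags=None
# SampleSchema(status="archived")     # would raise a ValidationError (not in the enum)
```

Required fields keep their bare type, while optional ones are wrapped in `Optional[...]` with a `None` default, which is the behaviour the new tests exercise.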
@@ -1,9 +1,11 @@
 import os
 import json
 import requests
-from typing import List, Any, Dict, Optional
+from typing import List, Any, Dict, Literal, Optional, Union, get_origin
 from pydantic import Field, create_model
 from crewai.tools import BaseTool
+import re
+

 # DEFAULTS
 ENTERPRISE_ACTION_KIT_PROJECT_ID = "dd525517-df22-49d2-a69e-6a0eed211166"
@@ -37,6 +39,9 @@ class EnterpriseActionTool(BaseTool):
         enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL,
         enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
     ):
+        self._model_registry = {}
+        self._base_name = self._sanitize_name(name)
+
         schema_props, required = self._extract_schema_info(action_schema)

         # Define field definitions for the model
@@ -44,22 +49,36 @@ class EnterpriseActionTool(BaseTool):
         for param_name, param_details in schema_props.items():
             param_desc = param_details.get("description", "")
             is_required = param_name in required
-            is_nullable, param_type = self._analyze_field_type(param_details)

-            # Create field definition based on nullable and required status
+            try:
+                field_type = self._process_schema_type(
+                    param_details, self._sanitize_name(param_name).title()
+                )
+            except Exception as e:
+                print(f"Warning: Could not process schema for {param_name}: {e}")
+                field_type = str
+
+            # Create field definition based on requirement
             field_definitions[param_name] = self._create_field_definition(
-                param_type, is_required, is_nullable, param_desc
+                field_type, is_required, param_desc
             )

         # Create the model
         if field_definitions:
-            args_schema = create_model(
-                f"{name.capitalize()}Schema", **field_definitions
-            )
+            try:
+                args_schema = create_model(
+                    f"{self._base_name}Schema", **field_definitions
+                )
+            except Exception as e:
+                print(f"Warning: Could not create main schema model: {e}")
+                args_schema = create_model(
+                    f"{self._base_name}Schema",
+                    input_text=(str, Field(description="Input for the action")),
+                )
         else:
             # Fallback for empty schema
             args_schema = create_model(
-                f"{name.capitalize()}Schema",
+                f"{self._base_name}Schema",
                 input_text=(str, Field(description="Input for the action")),
             )

@@ -73,6 +92,12 @@ class EnterpriseActionTool(BaseTool):
         if enterprise_action_kit_project_url is not None:
             self.enterprise_action_kit_project_url = enterprise_action_kit_project_url

+    def _sanitize_name(self, name: str) -> str:
+        """Sanitize names to create proper Python class names."""
+        sanitized = re.sub(r"[^a-zA-Z0-9_]", "", name)
+        parts = sanitized.split("_")
+        return "".join(word.capitalize() for word in parts if word)
+
     def _extract_schema_info(
         self, action_schema: Dict[str, Any]
     ) -> tuple[Dict[str, Any], List[str]]:
@@ -87,51 +112,97 @@ class EnterpriseActionTool(BaseTool):
         )
         return schema_props, required

-    def _analyze_field_type(self, param_details: Dict[str, Any]) -> tuple[bool, type]:
-        """Analyze field type and nullability from parameter details."""
-        is_nullable = False
-        param_type = str  # Default type
-
-        if "anyOf" in param_details:
-            any_of_types = param_details["anyOf"]
-            is_nullable = any(t.get("type") == "null" for t in any_of_types)
-            non_null_types = [t for t in any_of_types if t.get("type") != "null"]
-            if non_null_types:
-                first_type = non_null_types[0].get("type", "string")
-                param_type = self._map_json_type_to_python(
-                    first_type, non_null_types[0]
-                )
-        else:
-            json_type = param_details.get("type", "string")
-            param_type = self._map_json_type_to_python(json_type, param_details)
-            is_nullable = json_type == "null"
-
-        return is_nullable, param_type
+    def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> type:
+        """Process a JSON schema and return appropriate Python type."""
+        if "anyOf" in schema:
+            any_of_types = schema["anyOf"]
+            is_nullable = any(t.get("type") == "null" for t in any_of_types)
+            non_null_types = [t for t in any_of_types if t.get("type") != "null"]
+            if non_null_types:
+                base_type = self._process_schema_type(non_null_types[0], type_name)
+                return Optional[base_type] if is_nullable else base_type
+            return Optional[str]
+
+        if "oneOf" in schema:
+            return self._process_schema_type(schema["oneOf"][0], type_name)
+
+        if "allOf" in schema:
+            return self._process_schema_type(schema["allOf"][0], type_name)
+
+        json_type = schema.get("type", "string")
+
+        if "enum" in schema:
+            enum_values = schema["enum"]
+            if not enum_values:
+                return self._map_json_type_to_python(json_type)
+            return Literal[tuple(enum_values)]  # type: ignore
+
+        if json_type == "array":
+            items_schema = schema.get("items", {"type": "string"})
+            item_type = self._process_schema_type(items_schema, f"{type_name}Item")
+            return List[item_type]
+
+        if json_type == "object":
+            return self._create_nested_model(schema, type_name)
+
+        return self._map_json_type_to_python(json_type)
+
+    def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> type:
+        """Create a nested Pydantic model for complex objects."""
+        full_model_name = f"{self._base_name}{model_name}"
+
+        if full_model_name in self._model_registry:
+            return self._model_registry[full_model_name]
+
+        properties = schema.get("properties", {})
+        required_fields = schema.get("required", [])
+
+        if not properties:
+            return dict
+
+        field_definitions = {}
+        for prop_name, prop_schema in properties.items():
+            prop_desc = prop_schema.get("description", "")
+            is_required = prop_name in required_fields
+
+            try:
+                prop_type = self._process_schema_type(
+                    prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
+                )
+            except Exception as e:
+                print(f"Warning: Could not process schema for {prop_name}: {e}")
+                prop_type = str
+
+            field_definitions[prop_name] = self._create_field_definition(
+                prop_type, is_required, prop_desc
+            )
+
+        try:
+            nested_model = create_model(full_model_name, **field_definitions)
+            self._model_registry[full_model_name] = nested_model
+            return nested_model
+        except Exception as e:
+            print(f"Warning: Could not create nested model {full_model_name}: {e}")
+            return dict

     def _create_field_definition(
-        self, param_type: type, is_required: bool, is_nullable: bool, param_desc: str
+        self, field_type: type, is_required: bool, description: str
     ) -> tuple:
-        """Create Pydantic field definition based on type, requirement, and nullability."""
-        if is_nullable:
-            return (
-                Optional[param_type],
-                Field(default=None, description=param_desc),
-            )
-        elif is_required:
-            return (
-                param_type,
-                Field(description=param_desc),
-            )
+        """Create Pydantic field definition based on type and requirement."""
+        if is_required:
+            return (field_type, Field(description=description))
         else:
-            return (
-                Optional[param_type],
-                Field(default=None, description=param_desc),
-            )
+            if get_origin(field_type) is Union:
+                return (field_type, Field(default=None, description=description))
+            else:
+                return (
+                    Optional[field_type],
+                    Field(default=None, description=description),
+                )

-    def _map_json_type_to_python(
-        self, json_type: str, param_details: Dict[str, Any]
-    ) -> type:
-        """Map JSON schema types to Python types."""
+    def _map_json_type_to_python(self, json_type: str) -> type:
+        """Map basic JSON schema types to Python types."""
         type_mapping = {
             "string": str,
             "integer": int,
@@ -139,6 +210,7 @@ class EnterpriseActionTool(BaseTool):
             "boolean": bool,
             "array": list,
             "object": dict,
+            "null": type(None),
         }
         return type_mapping.get(json_type, str)

@@ -149,29 +221,37 @@ class EnterpriseActionTool(BaseTool):
         required_nullable_fields = []
         for param_name in required:
             param_details = schema_props.get(param_name, {})
-            is_nullable, _ = self._analyze_field_type(param_details)
-            if is_nullable:
+            if self._is_nullable_type(param_details):
                 required_nullable_fields.append(param_name)

         return required_nullable_fields

+    def _is_nullable_type(self, schema: Dict[str, Any]) -> bool:
+        """Check if a schema represents a nullable type."""
+        if "anyOf" in schema:
+            return any(t.get("type") == "null" for t in schema["anyOf"])
+        return schema.get("type") == "null"
+
     def _run(self, **kwargs) -> str:
         """Execute the specific enterprise action with validated parameters."""
         try:
+            cleaned_kwargs = {}
+            for key, value in kwargs.items():
+                if value is not None:
+                    cleaned_kwargs[key] = value
+
             required_nullable_fields = self._get_required_nullable_fields()

             for field_name in required_nullable_fields:
-                if field_name not in kwargs:
-                    kwargs[field_name] = None
+                if field_name not in cleaned_kwargs:
+                    cleaned_kwargs[field_name] = None

-            params = {k: v for k, v in kwargs.items() if v is not None}

             api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"
             headers = {
                 "Authorization": f"Bearer {self.enterprise_action_token}",
                 "Content-Type": "application/json",
             }
-            payload = {"action": self.action_name, "parameters": params}
+            payload = {"action": self.action_name, "parameters": cleaned_kwargs}

             response = requests.post(
                 url=api_url, headers=headers, json=payload, timeout=60
@@ -198,7 +278,6 @@ class EnterpriseActionKitToolAdapter:
         enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
     ):
         """Initialize the adapter with an enterprise action token."""
-
        self.enterprise_action_token = enterprise_action_token
        self._actions_schema = {}
        self._tools = None
@@ -206,11 +285,7 @@ class EnterpriseActionKitToolAdapter:
        self.enterprise_action_kit_project_url = enterprise_action_kit_project_url

    def tools(self) -> List[BaseTool]:
-        """Get the list of tools created from enterprise actions.
-
-        Returns:
-            List of BaseTool instances, one for each enterprise action.
-        """
+        """Get the list of tools created from enterprise actions."""
        if self._tools is None:
            self._fetch_actions()
            self._create_tools()
@@ -261,6 +336,53 @@ class EnterpriseActionKitToolAdapter:

            traceback.print_exc()

+    def _generate_detailed_description(
+        self, schema: Dict[str, Any], indent: int = 0
+    ) -> List[str]:
+        """Generate detailed description for nested schema structures."""
+        descriptions = []
+        indent_str = " " * indent
+
+        schema_type = schema.get("type", "string")
+
+        if schema_type == "object":
+            properties = schema.get("properties", {})
+            required_fields = schema.get("required", [])
+
+            if properties:
+                descriptions.append(f"{indent_str}Object with properties:")
+                for prop_name, prop_schema in properties.items():
+                    prop_desc = prop_schema.get("description", "")
+                    is_required = prop_name in required_fields
+                    req_str = " (required)" if is_required else " (optional)"
+                    descriptions.append(
+                        f"{indent_str} - {prop_name}: {prop_desc}{req_str}"
+                    )
+
+                    if prop_schema.get("type") == "object":
+                        descriptions.extend(
+                            self._generate_detailed_description(prop_schema, indent + 2)
+                        )
+                    elif prop_schema.get("type") == "array":
+                        items_schema = prop_schema.get("items", {})
+                        if items_schema.get("type") == "object":
+                            descriptions.append(f"{indent_str} Array of objects:")
+                            descriptions.extend(
+                                self._generate_detailed_description(
+                                    items_schema, indent + 3
+                                )
+                            )
+                        elif "enum" in items_schema:
+                            descriptions.append(
+                                f"{indent_str} Array of enum values: {items_schema['enum']}"
+                            )
+                    elif "enum" in prop_schema:
+                        descriptions.append(
+                            f"{indent_str} Enum values: {prop_schema['enum']}"
+                        )
+
+        return descriptions
+
    def _create_tools(self):
        """Create BaseTool instances for each action."""
        tools = []
@@ -269,19 +391,16 @@ class EnterpriseActionKitToolAdapter:
            function_details = action_schema.get("function", {})
            description = function_details.get("description", f"Execute {action_name}")

-            # Get parameter info for a better description
-            parameters = function_details.get("parameters", {}).get("properties", {})
-            param_info = []
-            for param_name, param_details in parameters.items():
-                param_desc = param_details.get("description", "")
-                required = param_name in function_details.get("parameters", {}).get(
-                    "required", []
-                )
-                param_info.append(
-                    f"- {param_name}: {param_desc} {'(required)' if required else '(optional)'}"
-                )
+            parameters = function_details.get("parameters", {})
+            param_descriptions = []
+            if parameters.get("properties"):
+                param_descriptions.append("\nDetailed Parameter Structure:")
+                param_descriptions.extend(
+                    self._generate_detailed_description(parameters)
+                )

-            full_description = f"{description}\n\nParameters:\n" + "\n".join(param_info)
+            full_description = description + "\n".join(param_descriptions)

            tool = EnterpriseActionTool(
                name=action_name.lower().replace(" ", "_"),
@@ -297,7 +416,6 @@ class EnterpriseActionKitToolAdapter:

        self._tools = tools

-    # Adding context manager support for convenience, but direct usage is also supported
    def __enter__(self):
        return self.tools()

@@ -2,9 +2,11 @@ import os
 import unittest
 from unittest.mock import patch, MagicMock

 from crewai.tools import BaseTool
 from crewai_tools.tools import CrewaiEnterpriseTools
 from crewai_tools.adapters.tool_collection import ToolCollection
+from crewai_tools.adapters.enterprise_adapter import EnterpriseActionTool
+

 class TestCrewaiEnterpriseTools(unittest.TestCase):
@@ -86,3 +88,269 @@ class TestCrewaiEnterpriseTools(unittest.TestCase):
        self.assertEqual(len(tools), 2)
        self.assertEqual(tools[0].name, "tool1")
        self.assertEqual(tools[1].name, "tool3")
+
+
+class TestEnterpriseActionToolSchemaConversion(unittest.TestCase):
+    """Test the enterprise action tool schema conversion and validation."""
+
+    def setUp(self):
+        self.test_schema = {
+            "type": "function",
+            "function": {
+                "name": "TEST_COMPLEX_ACTION",
+                "description": "Test action with complex nested structure",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "filterCriteria": {
+                            "type": "object",
+                            "description": "Filter criteria object",
+                            "properties": {
+                                "operation": {"type": "string", "enum": ["AND", "OR"]},
+                                "rules": {
+                                    "type": "array",
+                                    "items": {
+                                        "type": "object",
+                                        "properties": {
+                                            "field": {
+                                                "type": "string",
+                                                "enum": ["name", "email", "status"],
+                                            },
+                                            "operator": {
+                                                "type": "string",
+                                                "enum": ["equals", "contains"],
+                                            },
+                                            "value": {"type": "string"},
+                                        },
+                                        "required": ["field", "operator", "value"],
+                                    },
+                                },
+                            },
+                            "required": ["operation", "rules"],
+                        },
+                        "options": {
+                            "type": "object",
+                            "properties": {
+                                "limit": {"type": "integer"},
+                                "offset": {"type": "integer"},
+                            },
+                            "required": [],
+                        },
+                    },
+                    "required": [],
+                },
+            },
+        }
+
+    def test_complex_schema_conversion(self):
+        """Test that complex nested schemas are properly converted to Pydantic models."""
+        tool = EnterpriseActionTool(
+            name="gmail_search_for_email",
+            description="Test tool",
+            enterprise_action_token="test_token",
+            action_name="GMAIL_SEARCH_FOR_EMAIL",
+            action_schema=self.test_schema,
+        )
+
+        self.assertEqual(tool.name, "gmail_search_for_email")
+        self.assertEqual(tool.action_name, "GMAIL_SEARCH_FOR_EMAIL")
+
+        schema_class = tool.args_schema
+        self.assertIsNotNone(schema_class)
+
+        schema_fields = schema_class.model_fields
+        self.assertIn("filterCriteria", schema_fields)
+        self.assertIn("options", schema_fields)
+
+        # Test valid input structure
+        valid_input = {
+            "filterCriteria": {
+                "operation": "AND",
+                "rules": [
+                    {"field": "name", "operator": "contains", "value": "test"},
+                    {"field": "status", "operator": "equals", "value": "active"},
+                ],
+            },
+            "options": {"limit": 10},
+        }
+
+        # This should not raise an exception
+        validated_input = schema_class(**valid_input)
+        self.assertIsNotNone(validated_input.filterCriteria)
+        self.assertIsNotNone(validated_input.options)
+
+    def test_optional_fields_validation(self):
+        """Test that optional fields work correctly."""
+        tool = EnterpriseActionTool(
+            name="gmail_search_for_email",
+            description="Test tool",
+            enterprise_action_token="test_token",
+            action_name="GMAIL_SEARCH_FOR_EMAIL",
+            action_schema=self.test_schema,
+        )
+
+        schema_class = tool.args_schema
+
+        minimal_input = {}
+        validated_input = schema_class(**minimal_input)
+        self.assertIsNone(validated_input.filterCriteria)
+        self.assertIsNone(validated_input.options)
+
+        partial_input = {"options": {"limit": 10}}
+        validated_input = schema_class(**partial_input)
+        self.assertIsNone(validated_input.filterCriteria)
+        self.assertIsNotNone(validated_input.options)
+
+    def test_enum_validation(self):
+        """Test that enum values are properly validated."""
+        tool = EnterpriseActionTool(
+            name="gmail_search_for_email",
+            description="Test tool",
+            enterprise_action_token="test_token",
+            action_name="GMAIL_SEARCH_FOR_EMAIL",
+            action_schema=self.test_schema,
+        )
+
+        schema_class = tool.args_schema
+
+        invalid_input = {
+            "filterCriteria": {
+                "operation": "INVALID_OPERATOR",
+                "rules": [],
+            }
+        }
+
+        with self.assertRaises(Exception):
+            schema_class(**invalid_input)
+
+    def test_required_nested_fields(self):
+        """Test that required fields in nested objects are validated."""
+        tool = EnterpriseActionTool(
+            name="gmail_search_for_email",
+            description="Test tool",
+            enterprise_action_token="test_token",
+            action_name="GMAIL_SEARCH_FOR_EMAIL",
+            action_schema=self.test_schema,
+        )
+
+        schema_class = tool.args_schema
+
+        incomplete_input = {
+            "filterCriteria": {
+                "operation": "OR",
+                "rules": [
+                    {
+                        "field": "name",
+                        "operator": "contains",
+                    }
+                ],
+            }
+        }
+
+        with self.assertRaises(Exception):
+            schema_class(**incomplete_input)
+
+    @patch("requests.post")
+    def test_tool_execution_with_complex_input(self, mock_post):
+        """Test that the tool can execute with complex validated input."""
+        mock_response = MagicMock()
+        mock_response.ok = True
+        mock_response.json.return_value = {"success": True, "results": []}
+        mock_post.return_value = mock_response
+
+        tool = EnterpriseActionTool(
+            name="gmail_search_for_email",
+            description="Test tool",
+            enterprise_action_token="test_token",
+            action_name="GMAIL_SEARCH_FOR_EMAIL",
+            action_schema=self.test_schema,
+        )
+
+        tool._run(
+            filterCriteria={
+                "operation": "OR",
+                "rules": [
+                    {"field": "name", "operator": "contains", "value": "test"},
+                    {"field": "status", "operator": "equals", "value": "active"},
+                ],
+            },
+            options={"limit": 10},
+        )
+
+        mock_post.assert_called_once()
+        call_args = mock_post.call_args
+        payload = call_args[1]["json"]
+
+        self.assertEqual(payload["action"], "GMAIL_SEARCH_FOR_EMAIL")
+        self.assertIn("filterCriteria", payload["parameters"])
+        self.assertIn("options", payload["parameters"])
+        self.assertEqual(payload["parameters"]["filterCriteria"]["operation"], "OR")
+
+    def test_model_naming_convention(self):
+        """Test that generated model names follow proper conventions."""
+        tool = EnterpriseActionTool(
+            name="gmail_search_for_email",
+            description="Test tool",
+            enterprise_action_token="test_token",
+            action_name="GMAIL_SEARCH_FOR_EMAIL",
+            action_schema=self.test_schema,
+        )
+
+        schema_class = tool.args_schema
+        self.assertIsNotNone(schema_class)
+
+        self.assertTrue(schema_class.__name__.endswith("Schema"))
+        self.assertTrue(schema_class.__name__[0].isupper())
+
+        complex_input = {
+            "filterCriteria": {
+                "operation": "OR",
+                "rules": [
+                    {"field": "name", "operator": "contains", "value": "test"},
+                    {"field": "status", "operator": "equals", "value": "active"},
+                ],
+            },
+            "options": {"limit": 10},
+        }
+
+        validated = schema_class(**complex_input)
+        self.assertIsNotNone(validated.filterCriteria)
+
+    def test_simple_schema_with_enums(self):
+        """Test a simpler schema with basic enum validation."""
+        simple_schema = {
+            "type": "function",
+            "function": {
+                "name": "SIMPLE_TEST",
+                "description": "Simple test function",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "status": {
+                            "type": "string",
+                            "enum": ["active", "inactive", "pending"],
+                        },
+                        "priority": {"type": "integer", "enum": [1, 2, 3]},
+                    },
+                    "required": ["status"],
+                },
+            },
+        }
+
+        tool = EnterpriseActionTool(
+            name="simple_test",
+            description="Simple test tool",
+            enterprise_action_token="test_token",
+            action_name="SIMPLE_TEST",
+            action_schema=simple_schema,
+        )
+
+        schema_class = tool.args_schema
+
+        valid_input = {"status": "active", "priority": 2}
+        validated = schema_class(**valid_input)
+        self.assertEqual(validated.status, "active")
+        self.assertEqual(validated.priority, 2)
+
+        with self.assertRaises(Exception):
+            schema_class(status="invalid_status")