refactor: enhance schema handling in EnterpriseActionTool (#355)

* refactor: enhance schema handling in EnterpriseActionTool

- Extracted schema property and required field extraction into separate methods for better readability and maintainability.
- Introduced methods to analyze field types and create Pydantic field definitions based on nullability and requirement status.
- Updated the _run method to handle required nullable fields, ensuring they are set to None if not provided in kwargs.

* refactor: streamline nullable field handling in EnterpriseActionTool

- Removed commented-out code related to handling required nullable fields for clarity.
- Simplified the logic in the _run method to focus on processing parameters without unnecessary comments.
This commit is contained in:
Lorenze Jay
2025-07-02 12:54:09 -07:00
committed by GitHub
parent d53e96fcd7
commit b4786d86b0

View File

@@ -37,34 +37,18 @@ class EnterpriseActionTool(BaseTool):
enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL, enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL,
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID, enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
): ):
schema_props = ( schema_props, required = self._extract_schema_info(action_schema)
action_schema.get("function", {})
.get("parameters", {})
.get("properties", {})
)
required = (
action_schema.get("function", {}).get("parameters", {}).get("required", [])
)
# Define field definitions for the model # Define field definitions for the model
field_definitions = {} field_definitions = {}
for param_name, param_details in schema_props.items(): for param_name, param_details in schema_props.items():
param_type = str # Default to string type
param_desc = param_details.get("description", "") param_desc = param_details.get("description", "")
is_required = param_name in required is_required = param_name in required
is_nullable, param_type = self._analyze_field_type(param_details)
# Basic type mapping (can be extended) # Create field definition based on nullable and required status
if param_details.get("type") == "integer": field_definitions[param_name] = self._create_field_definition(
param_type = int param_type, is_required, is_nullable, param_desc
elif param_details.get("type") == "number":
param_type = float
elif param_details.get("type") == "boolean":
param_type = bool
# Create field with appropriate type and config
field_definitions[param_name] = (
param_type if is_required else Optional[param_type],
Field(description=param_desc),
) )
# Create the model # Create the model
@@ -89,9 +73,97 @@ class EnterpriseActionTool(BaseTool):
if enterprise_action_kit_project_url is not None: if enterprise_action_kit_project_url is not None:
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
def _extract_schema_info(
self, action_schema: Dict[str, Any]
) -> tuple[Dict[str, Any], List[str]]:
"""Extract schema properties and required fields from action schema."""
schema_props = (
action_schema.get("function", {})
.get("parameters", {})
.get("properties", {})
)
required = (
action_schema.get("function", {}).get("parameters", {}).get("required", [])
)
return schema_props, required
def _analyze_field_type(self, param_details: Dict[str, Any]) -> tuple[bool, type]:
"""Analyze field type and nullability from parameter details."""
is_nullable = False
param_type = str # Default type
if "anyOf" in param_details:
any_of_types = param_details["anyOf"]
is_nullable = any(t.get("type") == "null" for t in any_of_types)
non_null_types = [t for t in any_of_types if t.get("type") != "null"]
if non_null_types:
first_type = non_null_types[0].get("type", "string")
param_type = self._map_json_type_to_python(
first_type, non_null_types[0]
)
else:
json_type = param_details.get("type", "string")
param_type = self._map_json_type_to_python(json_type, param_details)
is_nullable = json_type == "null"
return is_nullable, param_type
def _create_field_definition(
self, param_type: type, is_required: bool, is_nullable: bool, param_desc: str
) -> tuple:
"""Create Pydantic field definition based on type, requirement, and nullability."""
if is_nullable:
return (
Optional[param_type],
Field(default=None, description=param_desc),
)
elif is_required:
return (
param_type,
Field(description=param_desc),
)
else:
return (
Optional[param_type],
Field(default=None, description=param_desc),
)
def _map_json_type_to_python(
self, json_type: str, param_details: Dict[str, Any]
) -> type:
"""Map JSON schema types to Python types."""
type_mapping = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
}
return type_mapping.get(json_type, str)
def _get_required_nullable_fields(self) -> List[str]:
"""Get a list of required nullable fields from the action schema."""
schema_props, required = self._extract_schema_info(self.action_schema)
required_nullable_fields = []
for param_name in required:
param_details = schema_props.get(param_name, {})
is_nullable, _ = self._analyze_field_type(param_details)
if is_nullable:
required_nullable_fields.append(param_name)
return required_nullable_fields
def _run(self, **kwargs) -> str: def _run(self, **kwargs) -> str:
"""Execute the specific enterprise action with validated parameters.""" """Execute the specific enterprise action with validated parameters."""
try: try:
required_nullable_fields = self._get_required_nullable_fields()
for field_name in required_nullable_fields:
if field_name not in kwargs:
kwargs[field_name] = None
params = {k: v for k, v in kwargs.items() if v is not None} params = {k: v for k, v in kwargs.items() if v is not None}
api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions" api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"