mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
refactor: enhance schema handling in EnterpriseActionTool (#355)
* refactor: enhance schema handling in EnterpriseActionTool - Extracted schema property and required field extraction into separate methods for better readability and maintainability. - Introduced methods to analyze field types and create Pydantic field definitions based on nullability and requirement status. - Updated the _run method to handle required nullable fields, ensuring they are set to None if not provided in kwargs. * refactor: streamline nullable field handling in EnterpriseActionTool - Removed commented-out code related to handling required nullable fields for clarity. - Simplified the logic in the _run method to focus on processing parameters without unnecessary comments.
This commit is contained in:
@@ -37,34 +37,18 @@ class EnterpriseActionTool(BaseTool):
|
|||||||
enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL,
|
enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL,
|
||||||
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
|
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
|
||||||
):
|
):
|
||||||
schema_props = (
|
schema_props, required = self._extract_schema_info(action_schema)
|
||||||
action_schema.get("function", {})
|
|
||||||
.get("parameters", {})
|
|
||||||
.get("properties", {})
|
|
||||||
)
|
|
||||||
required = (
|
|
||||||
action_schema.get("function", {}).get("parameters", {}).get("required", [])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Define field definitions for the model
|
# Define field definitions for the model
|
||||||
field_definitions = {}
|
field_definitions = {}
|
||||||
for param_name, param_details in schema_props.items():
|
for param_name, param_details in schema_props.items():
|
||||||
param_type = str # Default to string type
|
|
||||||
param_desc = param_details.get("description", "")
|
param_desc = param_details.get("description", "")
|
||||||
is_required = param_name in required
|
is_required = param_name in required
|
||||||
|
is_nullable, param_type = self._analyze_field_type(param_details)
|
||||||
|
|
||||||
# Basic type mapping (can be extended)
|
# Create field definition based on nullable and required status
|
||||||
if param_details.get("type") == "integer":
|
field_definitions[param_name] = self._create_field_definition(
|
||||||
param_type = int
|
param_type, is_required, is_nullable, param_desc
|
||||||
elif param_details.get("type") == "number":
|
|
||||||
param_type = float
|
|
||||||
elif param_details.get("type") == "boolean":
|
|
||||||
param_type = bool
|
|
||||||
|
|
||||||
# Create field with appropriate type and config
|
|
||||||
field_definitions[param_name] = (
|
|
||||||
param_type if is_required else Optional[param_type],
|
|
||||||
Field(description=param_desc),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the model
|
# Create the model
|
||||||
@@ -89,9 +73,97 @@ class EnterpriseActionTool(BaseTool):
|
|||||||
if enterprise_action_kit_project_url is not None:
|
if enterprise_action_kit_project_url is not None:
|
||||||
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
|
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
|
||||||
|
|
||||||
|
def _extract_schema_info(
|
||||||
|
self, action_schema: Dict[str, Any]
|
||||||
|
) -> tuple[Dict[str, Any], List[str]]:
|
||||||
|
"""Extract schema properties and required fields from action schema."""
|
||||||
|
schema_props = (
|
||||||
|
action_schema.get("function", {})
|
||||||
|
.get("parameters", {})
|
||||||
|
.get("properties", {})
|
||||||
|
)
|
||||||
|
required = (
|
||||||
|
action_schema.get("function", {}).get("parameters", {}).get("required", [])
|
||||||
|
)
|
||||||
|
return schema_props, required
|
||||||
|
|
||||||
|
def _analyze_field_type(self, param_details: Dict[str, Any]) -> tuple[bool, type]:
|
||||||
|
"""Analyze field type and nullability from parameter details."""
|
||||||
|
is_nullable = False
|
||||||
|
param_type = str # Default type
|
||||||
|
|
||||||
|
if "anyOf" in param_details:
|
||||||
|
any_of_types = param_details["anyOf"]
|
||||||
|
is_nullable = any(t.get("type") == "null" for t in any_of_types)
|
||||||
|
non_null_types = [t for t in any_of_types if t.get("type") != "null"]
|
||||||
|
if non_null_types:
|
||||||
|
first_type = non_null_types[0].get("type", "string")
|
||||||
|
param_type = self._map_json_type_to_python(
|
||||||
|
first_type, non_null_types[0]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
json_type = param_details.get("type", "string")
|
||||||
|
param_type = self._map_json_type_to_python(json_type, param_details)
|
||||||
|
is_nullable = json_type == "null"
|
||||||
|
|
||||||
|
return is_nullable, param_type
|
||||||
|
|
||||||
|
def _create_field_definition(
|
||||||
|
self, param_type: type, is_required: bool, is_nullable: bool, param_desc: str
|
||||||
|
) -> tuple:
|
||||||
|
"""Create Pydantic field definition based on type, requirement, and nullability."""
|
||||||
|
if is_nullable:
|
||||||
|
return (
|
||||||
|
Optional[param_type],
|
||||||
|
Field(default=None, description=param_desc),
|
||||||
|
)
|
||||||
|
elif is_required:
|
||||||
|
return (
|
||||||
|
param_type,
|
||||||
|
Field(description=param_desc),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
Optional[param_type],
|
||||||
|
Field(default=None, description=param_desc),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _map_json_type_to_python(
|
||||||
|
self, json_type: str, param_details: Dict[str, Any]
|
||||||
|
) -> type:
|
||||||
|
"""Map JSON schema types to Python types."""
|
||||||
|
type_mapping = {
|
||||||
|
"string": str,
|
||||||
|
"integer": int,
|
||||||
|
"number": float,
|
||||||
|
"boolean": bool,
|
||||||
|
"array": list,
|
||||||
|
"object": dict,
|
||||||
|
}
|
||||||
|
return type_mapping.get(json_type, str)
|
||||||
|
|
||||||
|
def _get_required_nullable_fields(self) -> List[str]:
|
||||||
|
"""Get a list of required nullable fields from the action schema."""
|
||||||
|
schema_props, required = self._extract_schema_info(self.action_schema)
|
||||||
|
|
||||||
|
required_nullable_fields = []
|
||||||
|
for param_name in required:
|
||||||
|
param_details = schema_props.get(param_name, {})
|
||||||
|
is_nullable, _ = self._analyze_field_type(param_details)
|
||||||
|
if is_nullable:
|
||||||
|
required_nullable_fields.append(param_name)
|
||||||
|
|
||||||
|
return required_nullable_fields
|
||||||
|
|
||||||
def _run(self, **kwargs) -> str:
|
def _run(self, **kwargs) -> str:
|
||||||
"""Execute the specific enterprise action with validated parameters."""
|
"""Execute the specific enterprise action with validated parameters."""
|
||||||
try:
|
try:
|
||||||
|
required_nullable_fields = self._get_required_nullable_fields()
|
||||||
|
|
||||||
|
for field_name in required_nullable_fields:
|
||||||
|
if field_name not in kwargs:
|
||||||
|
kwargs[field_name] = None
|
||||||
|
|
||||||
params = {k: v for k, v in kwargs.items() if v is not None}
|
params = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"
|
api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"
|
||||||
|
|||||||
Reference in New Issue
Block a user