diff --git a/src/crewai_tools/adapters/enterprise_adapter.py b/src/crewai_tools/adapters/enterprise_adapter.py index 6799d7ea8..96d64af8b 100644 --- a/src/crewai_tools/adapters/enterprise_adapter.py +++ b/src/crewai_tools/adapters/enterprise_adapter.py @@ -37,34 +37,18 @@ class EnterpriseActionTool(BaseTool): enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL, enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID, ): - schema_props = ( - action_schema.get("function", {}) - .get("parameters", {}) - .get("properties", {}) - ) - required = ( - action_schema.get("function", {}).get("parameters", {}).get("required", []) - ) + schema_props, required = self._extract_schema_info(action_schema) # Define field definitions for the model field_definitions = {} for param_name, param_details in schema_props.items(): - param_type = str # Default to string type param_desc = param_details.get("description", "") is_required = param_name in required + is_nullable, param_type = self._analyze_field_type(param_details) - # Basic type mapping (can be extended) - if param_details.get("type") == "integer": - param_type = int - elif param_details.get("type") == "number": - param_type = float - elif param_details.get("type") == "boolean": - param_type = bool - - # Create field with appropriate type and config - field_definitions[param_name] = ( - param_type if is_required else Optional[param_type], - Field(description=param_desc), + # Create field definition based on nullable and required status + field_definitions[param_name] = self._create_field_definition( + param_type, is_required, is_nullable, param_desc ) # Create the model @@ -89,9 +73,97 @@ class EnterpriseActionTool(BaseTool): if enterprise_action_kit_project_url is not None: self.enterprise_action_kit_project_url = enterprise_action_kit_project_url + def _extract_schema_info( + self, action_schema: Dict[str, Any] + ) -> tuple[Dict[str, Any], List[str]]: + """Extract schema properties and required fields from action schema.""" + schema_props = ( + action_schema.get("function", {}) + .get("parameters", {}) + .get("properties", {}) + ) + required = ( + action_schema.get("function", {}).get("parameters", {}).get("required", []) + ) + return schema_props, required + + def _analyze_field_type(self, param_details: Dict[str, Any]) -> tuple[bool, type]: + """Analyze field type and nullability from parameter details.""" + is_nullable = False + param_type = str # Default type + + if "anyOf" in param_details: + any_of_types = param_details["anyOf"] + is_nullable = any(t.get("type") == "null" for t in any_of_types) + non_null_types = [t for t in any_of_types if t.get("type") != "null"] + if non_null_types: + first_type = non_null_types[0].get("type", "string") + param_type = self._map_json_type_to_python( + first_type, non_null_types[0] + ) + else: + json_type = param_details.get("type", "string") + param_type = self._map_json_type_to_python(json_type, param_details) + is_nullable = json_type == "null" + + return is_nullable, param_type + + def _create_field_definition( + self, param_type: type, is_required: bool, is_nullable: bool, param_desc: str + ) -> tuple: + """Create Pydantic field definition based on type, requirement, and nullability.""" + if is_nullable: + return ( + Optional[param_type], + Field(default=None, description=param_desc), + ) + elif is_required: + return ( + param_type, + Field(description=param_desc), + ) + else: + return ( + Optional[param_type], + Field(default=None, description=param_desc), + ) + + def _map_json_type_to_python( + self, json_type: str, param_details: Dict[str, Any] + ) -> type: + """Map JSON schema types to Python types.""" + type_mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "array": list, + "object": dict, + } + return type_mapping.get(json_type, str) + + def _get_required_nullable_fields(self) -> List[str]: + """Get a list of required nullable fields from the action schema.""" + schema_props, required = self._extract_schema_info(self.action_schema) + + required_nullable_fields = [] + for param_name in required: + param_details = schema_props.get(param_name, {}) + is_nullable, _ = self._analyze_field_type(param_details) + if is_nullable: + required_nullable_fields.append(param_name) + + return required_nullable_fields + def _run(self, **kwargs) -> str: """Execute the specific enterprise action with validated parameters.""" try: + required_nullable_fields = self._get_required_nullable_fields() + + for field_name in required_nullable_fields: + if field_name not in kwargs: + kwargs[field_name] = None + params = {k: v for k, v in kwargs.items() if v is not None} api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"