refactor: fetch & execute enterprise tool actions from platform (#437)

* refactor: fetch enterprise tool actions from platform

* chore: logging legacy token detected
This commit is contained in:
Lucas Gomide
2025-09-02 16:41:00 -03:00
committed by GitHub
parent 93b841fc86
commit 33241ef363
2 changed files with 52 additions and 52 deletions

View File

@@ -1,15 +1,19 @@
import os import os
import json import json
import requests import requests
from typing import List, Any, Dict, Literal, Optional, Union, get_origin import logging
from typing import List, Any, Dict, Literal, Optional, Union, get_origin, Type, cast
from pydantic import Field, create_model from pydantic import Field, create_model
from crewai.tools import BaseTool from crewai.tools import BaseTool
import re import re
# DEFAULTS def get_enterprise_api_base_url() -> str:
ENTERPRISE_ACTION_KIT_PROJECT_ID = "dd525517-df22-49d2-a69e-6a0eed211166" """Get the enterprise API base URL from environment or use default."""
ENTERPRISE_ACTION_KIT_PROJECT_URL = "https://worker-actionkit.tools.crewai.com/projects" base_url = os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com")
return f"{base_url}/crewai_plus/api/v1/integrations"
ENTERPRISE_API_BASE_URL = get_enterprise_api_base_url()
class EnterpriseActionTool(BaseTool): class EnterpriseActionTool(BaseTool):
@@ -22,11 +26,8 @@ class EnterpriseActionTool(BaseTool):
action_schema: Dict[str, Any] = Field( action_schema: Dict[str, Any] = Field(
default={}, description="The schema of the action" default={}, description="The schema of the action"
) )
enterprise_action_kit_project_id: str = Field( enterprise_api_base_url: str = Field(
default=ENTERPRISE_ACTION_KIT_PROJECT_ID, description="The project id" default=ENTERPRISE_API_BASE_URL, description="The base API URL"
)
enterprise_action_kit_project_url: str = Field(
default=ENTERPRISE_ACTION_KIT_PROJECT_URL, description="The project url"
) )
def __init__( def __init__(
@@ -36,8 +37,7 @@ class EnterpriseActionTool(BaseTool):
enterprise_action_token: str, enterprise_action_token: str,
action_name: str, action_name: str,
action_schema: Dict[str, Any], action_schema: Dict[str, Any],
enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL, enterprise_api_base_url: Optional[str] = None,
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
): ):
self._model_registry = {} self._model_registry = {}
self._base_name = self._sanitize_name(name) self._base_name = self._sanitize_name(name)
@@ -86,11 +86,7 @@ class EnterpriseActionTool(BaseTool):
self.enterprise_action_token = enterprise_action_token self.enterprise_action_token = enterprise_action_token
self.action_name = action_name self.action_name = action_name
self.action_schema = action_schema self.action_schema = action_schema
self.enterprise_api_base_url = enterprise_api_base_url or get_enterprise_api_base_url()
if enterprise_action_kit_project_id is not None:
self.enterprise_action_kit_project_id = enterprise_action_kit_project_id
if enterprise_action_kit_project_url is not None:
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
def _sanitize_name(self, name: str) -> str: def _sanitize_name(self, name: str) -> str:
"""Sanitize names to create proper Python class names.""" """Sanitize names to create proper Python class names."""
@@ -112,7 +108,7 @@ class EnterpriseActionTool(BaseTool):
) )
return schema_props, required return schema_props, required
def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> type: def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
"""Process a JSON schema and return appropriate Python type.""" """Process a JSON schema and return appropriate Python type."""
if "anyOf" in schema: if "anyOf" in schema:
any_of_types = schema["anyOf"] any_of_types = schema["anyOf"]
@@ -122,7 +118,7 @@ class EnterpriseActionTool(BaseTool):
if non_null_types: if non_null_types:
base_type = self._process_schema_type(non_null_types[0], type_name) base_type = self._process_schema_type(non_null_types[0], type_name)
return Optional[base_type] if is_nullable else base_type return Optional[base_type] if is_nullable else base_type
return Optional[str] return cast(Type[Any], Optional[str])
if "oneOf" in schema: if "oneOf" in schema:
return self._process_schema_type(schema["oneOf"][0], type_name) return self._process_schema_type(schema["oneOf"][0], type_name)
@@ -136,7 +132,7 @@ class EnterpriseActionTool(BaseTool):
enum_values = schema["enum"] enum_values = schema["enum"]
if not enum_values: if not enum_values:
return self._map_json_type_to_python(json_type) return self._map_json_type_to_python(json_type)
return Literal[tuple(enum_values)] # type: ignore return Literal[tuple(enum_values)] # type: ignore[return-value]
if json_type == "array": if json_type == "array":
items_schema = schema.get("items", {"type": "string"}) items_schema = schema.get("items", {"type": "string"})
@@ -148,7 +144,7 @@ class EnterpriseActionTool(BaseTool):
return self._map_json_type_to_python(json_type) return self._map_json_type_to_python(json_type)
def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> type: def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> Type[Any]:
"""Create a nested Pydantic model for complex objects.""" """Create a nested Pydantic model for complex objects."""
full_model_name = f"{self._base_name}{model_name}" full_model_name = f"{self._base_name}{model_name}"
@@ -187,7 +183,7 @@ class EnterpriseActionTool(BaseTool):
return dict return dict
def _create_field_definition( def _create_field_definition(
self, field_type: type, is_required: bool, description: str self, field_type: Type[Any], is_required: bool, description: str
) -> tuple: ) -> tuple:
"""Create Pydantic field definition based on type and requirement.""" """Create Pydantic field definition based on type and requirement."""
if is_required: if is_required:
@@ -201,7 +197,7 @@ class EnterpriseActionTool(BaseTool):
Field(default=None, description=description), Field(default=None, description=description),
) )
def _map_json_type_to_python(self, json_type: str) -> type: def _map_json_type_to_python(self, json_type: str) -> Type[Any]:
"""Map basic JSON schema types to Python types.""" """Map basic JSON schema types to Python types."""
type_mapping = { type_mapping = {
"string": str, "string": str,
@@ -246,12 +242,13 @@ class EnterpriseActionTool(BaseTool):
if field_name not in cleaned_kwargs: if field_name not in cleaned_kwargs:
cleaned_kwargs[field_name] = None cleaned_kwargs[field_name] = None
api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"
api_url = f"{self.enterprise_api_base_url}/actions/{self.action_name}/execute"
headers = { headers = {
"Authorization": f"Bearer {self.enterprise_action_token}", "Authorization": f"Bearer {self.enterprise_action_token}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
payload = {"action": self.action_name, "parameters": cleaned_kwargs} payload = cleaned_kwargs
response = requests.post( response = requests.post(
url=api_url, headers=headers, json=payload, timeout=60 url=api_url, headers=headers, json=payload, timeout=60
@@ -274,40 +271,30 @@ class EnterpriseActionKitToolAdapter:
def __init__( def __init__(
self, self,
enterprise_action_token: str, enterprise_action_token: str,
enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL, enterprise_api_base_url: Optional[str] = None,
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
): ):
"""Initialize the adapter with an enterprise action token.""" """Initialize the adapter with an enterprise action token."""
self.enterprise_action_token = enterprise_action_token self._set_enterprise_action_token(enterprise_action_token)
self._actions_schema = {} self._actions_schema = {}
self._tools = None self._tools = None
self.enterprise_action_kit_project_id = enterprise_action_kit_project_id self.enterprise_api_base_url = enterprise_api_base_url or get_enterprise_api_base_url()
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
def tools(self) -> List[BaseTool]: def tools(self) -> List[BaseTool]:
"""Get the list of tools created from enterprise actions.""" """Get the list of tools created from enterprise actions."""
if self._tools is None: if self._tools is None:
self._fetch_actions() self._fetch_actions()
self._create_tools() self._create_tools()
return self._tools return self._tools or []
def _fetch_actions(self): def _fetch_actions(self):
"""Fetch available actions from the API.""" """Fetch available actions from the API."""
try: try:
if (
self.enterprise_action_token is None
or self.enterprise_action_token == ""
):
self.enterprise_action_token = os.environ.get(
"CREWAI_ENTERPRISE_TOOLS_TOKEN"
)
actions_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions" actions_url = f"{self.enterprise_api_base_url}/actions"
headers = {"Authorization": f"Bearer {self.enterprise_action_token}"} headers = {"Authorization": f"Bearer {self.enterprise_action_token}"}
params = {"format": "json_schema"}
response = requests.get( response = requests.get(
actions_url, headers=headers, params=params, timeout=30 actions_url, headers=headers, timeout=30
) )
response.raise_for_status() response.raise_for_status()
@@ -316,17 +303,22 @@ class EnterpriseActionKitToolAdapter:
print(f"Unexpected API response structure: {raw_data}") print(f"Unexpected API response structure: {raw_data}")
return return
# Parse the actions schema
parsed_schema = {} parsed_schema = {}
action_categories = raw_data["actions"] action_categories = raw_data["actions"]
for category, action_list in action_categories.items(): for integration_type, action_list in action_categories.items():
if isinstance(action_list, list): if isinstance(action_list, list):
for action in action_list: for action in action_list:
func_details = action.get("function") action_name = action.get("name")
if func_details and "name" in func_details: if action_name:
action_name = func_details["name"] action_schema = {
parsed_schema[action_name] = action "function": {
"name": action_name,
"description": action.get("description", f"Execute {action_name}"),
"parameters": action.get("parameters", {})
}
}
parsed_schema[action_name] = action_schema
self._actions_schema = parsed_schema self._actions_schema = parsed_schema
@@ -408,14 +400,23 @@ class EnterpriseActionKitToolAdapter:
action_name=action_name, action_name=action_name,
action_schema=action_schema, action_schema=action_schema,
enterprise_action_token=self.enterprise_action_token, enterprise_action_token=self.enterprise_action_token,
enterprise_action_kit_project_id=self.enterprise_action_kit_project_id, enterprise_api_base_url=self.enterprise_api_base_url,
enterprise_action_kit_project_url=self.enterprise_action_kit_project_url,
) )
tools.append(tool) tools.append(tool)
self._tools = tools self._tools = tools
def _set_enterprise_action_token(self, enterprise_action_token: Optional[str]):
if enterprise_action_token and not enterprise_action_token.startswith("PK_"):
logging.warning(
"Legacy token detected, please consider using the new Enterprise Action Auth token. Check out our docs for more information https://docs.crewai.com/en/enterprise/features/integrations."
)
token = enterprise_action_token or os.environ.get("CREWAI_ENTERPRISE_TOOLS_TOKEN")
self.enterprise_action_token = token
def __enter__(self): def __enter__(self):
return self.tools() return self.tools()

View File

@@ -281,10 +281,9 @@ class TestEnterpriseActionToolSchemaConversion(unittest.TestCase):
call_args = mock_post.call_args call_args = mock_post.call_args
payload = call_args[1]["json"] payload = call_args[1]["json"]
self.assertEqual(payload["action"], "GMAIL_SEARCH_FOR_EMAIL") self.assertIn("filterCriteria", payload)
self.assertIn("filterCriteria", payload["parameters"]) self.assertIn("options", payload)
self.assertIn("options", payload["parameters"]) self.assertEqual(payload["filterCriteria"]["operation"], "OR")
self.assertEqual(payload["parameters"]["filterCriteria"]["operation"], "OR")
def test_model_naming_convention(self): def test_model_naming_convention(self):
"""Test that generated model names follow proper conventions.""" """Test that generated model names follow proper conventions."""