mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
refactor: fetch & execute enterprise tool actions from platform (#437)
* refactor: fetch enterprise tool actions from platform * chore: logging legacy token detected
This commit is contained in:
@@ -1,15 +1,19 @@
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
from typing import List, Any, Dict, Literal, Optional, Union, get_origin
|
||||
import logging
|
||||
from typing import List, Any, Dict, Literal, Optional, Union, get_origin, Type, cast
|
||||
from pydantic import Field, create_model
|
||||
from crewai.tools import BaseTool
|
||||
import re
|
||||
|
||||
|
||||
# DEFAULTS
|
||||
ENTERPRISE_ACTION_KIT_PROJECT_ID = "dd525517-df22-49d2-a69e-6a0eed211166"
|
||||
ENTERPRISE_ACTION_KIT_PROJECT_URL = "https://worker-actionkit.tools.crewai.com/projects"
|
||||
def get_enterprise_api_base_url() -> str:
|
||||
"""Get the enterprise API base URL from environment or use default."""
|
||||
base_url = os.getenv("CREWAI_PLUS_URL", "https://app.crewai.com")
|
||||
return f"{base_url}/crewai_plus/api/v1/integrations"
|
||||
|
||||
ENTERPRISE_API_BASE_URL = get_enterprise_api_base_url()
|
||||
|
||||
|
||||
class EnterpriseActionTool(BaseTool):
|
||||
@@ -22,11 +26,8 @@ class EnterpriseActionTool(BaseTool):
|
||||
action_schema: Dict[str, Any] = Field(
|
||||
default={}, description="The schema of the action"
|
||||
)
|
||||
enterprise_action_kit_project_id: str = Field(
|
||||
default=ENTERPRISE_ACTION_KIT_PROJECT_ID, description="The project id"
|
||||
)
|
||||
enterprise_action_kit_project_url: str = Field(
|
||||
default=ENTERPRISE_ACTION_KIT_PROJECT_URL, description="The project url"
|
||||
enterprise_api_base_url: str = Field(
|
||||
default=ENTERPRISE_API_BASE_URL, description="The base API URL"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -36,8 +37,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
enterprise_action_token: str,
|
||||
action_name: str,
|
||||
action_schema: Dict[str, Any],
|
||||
enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL,
|
||||
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
|
||||
enterprise_api_base_url: Optional[str] = None,
|
||||
):
|
||||
self._model_registry = {}
|
||||
self._base_name = self._sanitize_name(name)
|
||||
@@ -86,11 +86,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
self.enterprise_action_token = enterprise_action_token
|
||||
self.action_name = action_name
|
||||
self.action_schema = action_schema
|
||||
|
||||
if enterprise_action_kit_project_id is not None:
|
||||
self.enterprise_action_kit_project_id = enterprise_action_kit_project_id
|
||||
if enterprise_action_kit_project_url is not None:
|
||||
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
|
||||
self.enterprise_api_base_url = enterprise_api_base_url or get_enterprise_api_base_url()
|
||||
|
||||
def _sanitize_name(self, name: str) -> str:
|
||||
"""Sanitize names to create proper Python class names."""
|
||||
@@ -112,7 +108,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
)
|
||||
return schema_props, required
|
||||
|
||||
def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> type:
|
||||
def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
|
||||
"""Process a JSON schema and return appropriate Python type."""
|
||||
if "anyOf" in schema:
|
||||
any_of_types = schema["anyOf"]
|
||||
@@ -122,7 +118,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
if non_null_types:
|
||||
base_type = self._process_schema_type(non_null_types[0], type_name)
|
||||
return Optional[base_type] if is_nullable else base_type
|
||||
return Optional[str]
|
||||
return cast(Type[Any], Optional[str])
|
||||
|
||||
if "oneOf" in schema:
|
||||
return self._process_schema_type(schema["oneOf"][0], type_name)
|
||||
@@ -136,7 +132,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
enum_values = schema["enum"]
|
||||
if not enum_values:
|
||||
return self._map_json_type_to_python(json_type)
|
||||
return Literal[tuple(enum_values)] # type: ignore
|
||||
return Literal[tuple(enum_values)] # type: ignore[return-value]
|
||||
|
||||
if json_type == "array":
|
||||
items_schema = schema.get("items", {"type": "string"})
|
||||
@@ -148,7 +144,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
|
||||
return self._map_json_type_to_python(json_type)
|
||||
|
||||
def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> type:
|
||||
def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> Type[Any]:
|
||||
"""Create a nested Pydantic model for complex objects."""
|
||||
full_model_name = f"{self._base_name}{model_name}"
|
||||
|
||||
@@ -187,7 +183,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
return dict
|
||||
|
||||
def _create_field_definition(
|
||||
self, field_type: type, is_required: bool, description: str
|
||||
self, field_type: Type[Any], is_required: bool, description: str
|
||||
) -> tuple:
|
||||
"""Create Pydantic field definition based on type and requirement."""
|
||||
if is_required:
|
||||
@@ -201,7 +197,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
Field(default=None, description=description),
|
||||
)
|
||||
|
||||
def _map_json_type_to_python(self, json_type: str) -> type:
|
||||
def _map_json_type_to_python(self, json_type: str) -> Type[Any]:
|
||||
"""Map basic JSON schema types to Python types."""
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
@@ -246,12 +242,13 @@ class EnterpriseActionTool(BaseTool):
|
||||
if field_name not in cleaned_kwargs:
|
||||
cleaned_kwargs[field_name] = None
|
||||
|
||||
api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"
|
||||
|
||||
api_url = f"{self.enterprise_api_base_url}/actions/{self.action_name}/execute"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.enterprise_action_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {"action": self.action_name, "parameters": cleaned_kwargs}
|
||||
payload = cleaned_kwargs
|
||||
|
||||
response = requests.post(
|
||||
url=api_url, headers=headers, json=payload, timeout=60
|
||||
@@ -274,40 +271,30 @@ class EnterpriseActionKitToolAdapter:
|
||||
def __init__(
|
||||
self,
|
||||
enterprise_action_token: str,
|
||||
enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL,
|
||||
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
|
||||
enterprise_api_base_url: Optional[str] = None,
|
||||
):
|
||||
"""Initialize the adapter with an enterprise action token."""
|
||||
self.enterprise_action_token = enterprise_action_token
|
||||
self._set_enterprise_action_token(enterprise_action_token)
|
||||
self._actions_schema = {}
|
||||
self._tools = None
|
||||
self.enterprise_action_kit_project_id = enterprise_action_kit_project_id
|
||||
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
|
||||
self.enterprise_api_base_url = enterprise_api_base_url or get_enterprise_api_base_url()
|
||||
|
||||
def tools(self) -> List[BaseTool]:
|
||||
"""Get the list of tools created from enterprise actions."""
|
||||
if self._tools is None:
|
||||
self._fetch_actions()
|
||||
self._create_tools()
|
||||
return self._tools
|
||||
return self._tools or []
|
||||
|
||||
def _fetch_actions(self):
|
||||
"""Fetch available actions from the API."""
|
||||
try:
|
||||
if (
|
||||
self.enterprise_action_token is None
|
||||
or self.enterprise_action_token == ""
|
||||
):
|
||||
self.enterprise_action_token = os.environ.get(
|
||||
"CREWAI_ENTERPRISE_TOOLS_TOKEN"
|
||||
)
|
||||
|
||||
actions_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"
|
||||
actions_url = f"{self.enterprise_api_base_url}/actions"
|
||||
headers = {"Authorization": f"Bearer {self.enterprise_action_token}"}
|
||||
params = {"format": "json_schema"}
|
||||
|
||||
response = requests.get(
|
||||
actions_url, headers=headers, params=params, timeout=30
|
||||
actions_url, headers=headers, timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -316,17 +303,22 @@ class EnterpriseActionKitToolAdapter:
|
||||
print(f"Unexpected API response structure: {raw_data}")
|
||||
return
|
||||
|
||||
# Parse the actions schema
|
||||
parsed_schema = {}
|
||||
action_categories = raw_data["actions"]
|
||||
|
||||
for category, action_list in action_categories.items():
|
||||
for integration_type, action_list in action_categories.items():
|
||||
if isinstance(action_list, list):
|
||||
for action in action_list:
|
||||
func_details = action.get("function")
|
||||
if func_details and "name" in func_details:
|
||||
action_name = func_details["name"]
|
||||
parsed_schema[action_name] = action
|
||||
action_name = action.get("name")
|
||||
if action_name:
|
||||
action_schema = {
|
||||
"function": {
|
||||
"name": action_name,
|
||||
"description": action.get("description", f"Execute {action_name}"),
|
||||
"parameters": action.get("parameters", {})
|
||||
}
|
||||
}
|
||||
parsed_schema[action_name] = action_schema
|
||||
|
||||
self._actions_schema = parsed_schema
|
||||
|
||||
@@ -408,14 +400,23 @@ class EnterpriseActionKitToolAdapter:
|
||||
action_name=action_name,
|
||||
action_schema=action_schema,
|
||||
enterprise_action_token=self.enterprise_action_token,
|
||||
enterprise_action_kit_project_id=self.enterprise_action_kit_project_id,
|
||||
enterprise_action_kit_project_url=self.enterprise_action_kit_project_url,
|
||||
enterprise_api_base_url=self.enterprise_api_base_url,
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
self._tools = tools
|
||||
|
||||
def _set_enterprise_action_token(self, enterprise_action_token: Optional[str]):
|
||||
if enterprise_action_token and not enterprise_action_token.startswith("PK_"):
|
||||
logging.warning(
|
||||
"Legacy token detected, please consider using the new Enterprise Action Auth token. Check out our docs for more information https://docs.crewai.com/en/enterprise/features/integrations."
|
||||
)
|
||||
|
||||
token = enterprise_action_token or os.environ.get("CREWAI_ENTERPRISE_TOOLS_TOKEN")
|
||||
|
||||
self.enterprise_action_token = token
|
||||
|
||||
def __enter__(self):
|
||||
return self.tools()
|
||||
|
||||
|
||||
@@ -281,10 +281,9 @@ class TestEnterpriseActionToolSchemaConversion(unittest.TestCase):
|
||||
call_args = mock_post.call_args
|
||||
payload = call_args[1]["json"]
|
||||
|
||||
self.assertEqual(payload["action"], "GMAIL_SEARCH_FOR_EMAIL")
|
||||
self.assertIn("filterCriteria", payload["parameters"])
|
||||
self.assertIn("options", payload["parameters"])
|
||||
self.assertEqual(payload["parameters"]["filterCriteria"]["operation"], "OR")
|
||||
self.assertIn("filterCriteria", payload)
|
||||
self.assertIn("options", payload)
|
||||
self.assertEqual(payload["filterCriteria"]["operation"], "OR")
|
||||
|
||||
def test_model_naming_convention(self):
|
||||
"""Test that generated model names follow proper conventions."""
|
||||
|
||||
Reference in New Issue
Block a user