diff --git a/src/crewai/cli/plus_api.py b/src/crewai/cli/plus_api.py index 4d30e97b3..e34c26b1b 100644 --- a/src/crewai/cli/plus_api.py +++ b/src/crewai/cli/plus_api.py @@ -26,6 +26,9 @@ class PlusAPI: "User-Agent": f"CrewAI-CLI/{get_crewai_version()}", "X-Crewai-Version": get_crewai_version(), } + settings = Settings() + if settings.org_uuid: + self.headers["X-Crewai-Organization-Id"] = settings.org_uuid self.base_url = getenv("CREWAI_BASE_URL", "https://app.crewai.com") def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response: @@ -35,25 +38,13 @@ class PlusAPI: return session.request(method, url, headers=self.headers, **kwargs) def login_to_tool_repository(self): - settings = Settings() - payload = {} - if settings.org_uuid: - payload["organization_uuid"] = settings.org_uuid - return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login", json=payload) + return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login") def get_tool(self, handle: str): - settings = Settings() - params = {} - if settings.org_uuid: - params["organization_uuid"] = settings.org_uuid - return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}", params=params) + return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}") def get_agent(self, handle: str): - settings = Settings() - params = {} - if settings.org_uuid: - params["organization_uuid"] = settings.org_uuid - return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}", params=params) + return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}") def publish_tool( self, @@ -72,9 +63,6 @@ class PlusAPI: "description": description, "available_exports": available_exports, } - settings = Settings() - if settings.org_uuid: - params["organization_uuid"] = settings.org_uuid return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params) def deploy_by_name(self, project_name: str) -> requests.Response: diff --git a/tests/cli/test_plus_api.py b/tests/cli/test_plus_api.py index 7bab1ddf7..eff57e1a5 100644 --- a/tests/cli/test_plus_api.py +++ b/tests/cli/test_plus_api.py @@ -1,6 +1,6 @@ import os import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, ANY from crewai.cli.plus_api import PlusAPI @@ -9,6 +9,7 @@ class TestPlusAPI(unittest.TestCase): def setUp(self): self.api_key = "test_api_key" self.api = PlusAPI(self.api_key) + self.org_uuid = "test-org-uuid" def test_init(self): self.assertEqual(self.api.api_key, self.api_key) @@ -25,25 +26,33 @@ class TestPlusAPI(unittest.TestCase): response = self.api.login_to_tool_repository() mock_make_request.assert_called_once_with( - "POST", "/crewai_plus/api/v1/tools/login", json={} + "POST", "/crewai_plus/api/v1/tools/login" ) self.assertEqual(response, mock_response) - + + def assert_request_with_org_id(self, mock_make_request, method: str, endpoint: str, **kwargs): + mock_make_request.assert_called_once_with( + method, f"https://app.crewai.com{endpoint}", headers={'Authorization': ANY, 'Content-Type': ANY, 'User-Agent': ANY, 'X-Crewai-Version': ANY, 'X-Crewai-Organization-Id': self.org_uuid}, **kwargs + ) + @patch("crewai.cli.plus_api.Settings") - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("requests.Session.request") def test_login_to_tool_repository_with_org_uuid(self, mock_make_request, mock_settings_class): mock_settings = MagicMock() - mock_settings.org_uuid = "test-org-uuid" + mock_settings.org_uuid = self.org_uuid mock_settings_class.return_value = mock_settings + # re-initialize Client + self.api = PlusAPI(self.api_key) mock_response = MagicMock() mock_make_request.return_value = mock_response response = self.api.login_to_tool_repository() - mock_make_request.assert_called_once_with( - "POST", "/crewai_plus/api/v1/tools/login", - json={"organization_uuid": "test-org-uuid"} + self.assert_request_with_org_id( + mock_make_request, + 'POST', + '/crewai_plus/api/v1/tools/login' ) self.assertEqual(response, mock_response) @@ -54,28 +63,28 @@ class TestPlusAPI(unittest.TestCase): response = self.api.get_agent("test_agent_handle") mock_make_request.assert_called_once_with( - "GET", "/crewai_plus/api/v1/agents/test_agent_handle", params={} + "GET", "/crewai_plus/api/v1/agents/test_agent_handle" ) self.assertEqual(response, mock_response) @patch("crewai.cli.plus_api.Settings") - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("requests.Session.request") def test_get_agent_with_org_uuid(self, mock_make_request, mock_settings_class): - # Set up mock settings with org_uuid mock_settings = MagicMock() - mock_settings.org_uuid = "test-org-uuid" + mock_settings.org_uuid = self.org_uuid mock_settings_class.return_value = mock_settings + # re-initialize Client + self.api = PlusAPI(self.api_key) - # Set up mock response mock_response = MagicMock() mock_make_request.return_value = mock_response response = self.api.get_agent("test_agent_handle") - # Verify the params include the organization_uuid - mock_make_request.assert_called_once_with( - "GET", "/crewai_plus/api/v1/agents/test_agent_handle", - params={"organization_uuid": "test-org-uuid"} + self.assert_request_with_org_id( + mock_make_request, + "GET", + "/crewai_plus/api/v1/agents/test_agent_handle" ) self.assertEqual(response, mock_response) @@ -86,28 +95,29 @@ class TestPlusAPI(unittest.TestCase): response = self.api.get_tool("test_tool_handle") mock_make_request.assert_called_once_with( - "GET", "/crewai_plus/api/v1/tools/test_tool_handle", params={} + "GET", "/crewai_plus/api/v1/tools/test_tool_handle" ) self.assertEqual(response, mock_response) @patch("crewai.cli.plus_api.Settings") - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("requests.Session.request") def test_get_tool_with_org_uuid(self, mock_make_request, mock_settings_class): - # Set up mock settings with org_uuid mock_settings = MagicMock() - mock_settings.org_uuid = "test-org-uuid" + mock_settings.org_uuid = self.org_uuid mock_settings_class.return_value = mock_settings - + # re-initialize Client + self.api = PlusAPI(self.api_key) + # Set up mock response mock_response = MagicMock() mock_make_request.return_value = mock_response response = self.api.get_tool("test_tool_handle") - # Verify the params include the organization_uuid - mock_make_request.assert_called_once_with( - "GET", "/crewai_plus/api/v1/tools/test_tool_handle", - params={"organization_uuid": "test-org-uuid"} + self.assert_request_with_org_id( + mock_make_request, + "GET", + "/crewai_plus/api/v1/tools/test_tool_handle" ) self.assertEqual(response, mock_response) @@ -139,13 +149,14 @@ class TestPlusAPI(unittest.TestCase): self.assertEqual(response, mock_response) @patch("crewai.cli.plus_api.Settings") - @patch("crewai.cli.plus_api.PlusAPI._make_request") + @patch("requests.Session.request") def test_publish_tool_with_org_uuid(self, mock_make_request, mock_settings_class): - # Set up mock settings with org_uuid mock_settings = MagicMock() - mock_settings.org_uuid = "test-org-uuid" + mock_settings.org_uuid = self.org_uuid mock_settings_class.return_value = mock_settings - + # re-initialize Client + self.api = PlusAPI(self.api_key) + # Set up mock response mock_response = MagicMock() mock_make_request.return_value = mock_response @@ -168,11 +179,13 @@ class TestPlusAPI(unittest.TestCase): "file": encoded_file, "description": description, "available_exports": None, - "organization_uuid": "test-org-uuid", } - mock_make_request.assert_called_once_with( - "POST", "/crewai_plus/api/v1/tools", json=expected_params + self.assert_request_with_org_id( + mock_make_request, + "POST", + "/crewai_plus/api/v1/tools", + json=expected_params ) self.assertEqual(response, mock_response)