diff --git a/src/crewai_tools/adapters/mcp_adapter.py b/src/crewai_tools/adapters/mcp_adapter.py index bcb38818d..bfff480eb 100644 --- a/src/crewai_tools/adapters/mcp_adapter.py +++ b/src/crewai_tools/adapters/mcp_adapter.py @@ -4,7 +4,7 @@ import logging from typing import TYPE_CHECKING, Any from crewai.tools import BaseTool - +from crewai_tools.adapters.tool_collection import ToolCollection """ MCPServer for CrewAI. @@ -114,7 +114,7 @@ class MCPServerAdapter: self._adapter.__exit__(None, None, None) @property - def tools(self) -> list[BaseTool]: + def tools(self) -> ToolCollection[BaseTool]: """The CrewAI tools available from the MCP server. Raises: @@ -127,7 +127,7 @@ class MCPServerAdapter: raise ValueError( "MCP server not started, run `mcp_server.start()` first before accessing `tools`" ) - return self._tools + return ToolCollection(self._tools) def __enter__(self): """ diff --git a/src/crewai_tools/adapters/tool_collection.py b/src/crewai_tools/adapters/tool_collection.py new file mode 100644 index 000000000..f0ec9a288 --- /dev/null +++ b/src/crewai_tools/adapters/tool_collection.py @@ -0,0 +1,59 @@ +from typing import List, Optional, Union, TypeVar, Generic, Dict +from crewai.tools import BaseTool + +T = TypeVar('T', bound=BaseTool) + +class ToolCollection(list, Generic[T]): + """ + A collection of tools that can be accessed by index or name + + This class extends the built-in list to provide dictionary-like + access to tools based on their name property. + + Usage: + tools = ToolCollection(list_of_tools) + # Access by index (regular list behavior) + first_tool = tools[0] + # Access by name (new functionality) + search_tool = tools["search"] + """ + + def __init__(self, tools: Optional[List[T]] = None): + super().__init__(tools or []) + self._name_cache: Dict[str, T] = {} + self._build_name_cache() + + def _build_name_cache(self) -> None: + self._name_cache = {tool.name: tool for tool in self} + + def __getitem__(self, key: Union[int, str]) -> T: + if isinstance(key, str): + return self._name_cache[key] + return super().__getitem__(key) + + def append(self, tool: T) -> None: + super().append(tool) + self._name_cache[tool.name] = tool + + def extend(self, tools: List[T]) -> None: + super().extend(tools) + self._build_name_cache() + + def insert(self, index: int, tool: T) -> None: + super().insert(index, tool) + self._name_cache[tool.name] = tool + + def remove(self, tool: T) -> None: + super().remove(tool) + if tool.name in self._name_cache: + del self._name_cache[tool.name] + + def pop(self, index: int = -1) -> T: + tool = super().pop(index) + if tool.name in self._name_cache: + del self._name_cache[tool.name] + return tool + + def clear(self) -> None: + super().clear() + self._name_cache.clear() diff --git a/src/crewai_tools/tools/crewai_enterprise_tools/crewai_enterprise_tools.py b/src/crewai_tools/tools/crewai_enterprise_tools/crewai_enterprise_tools.py index 7fc97d179..8e7275e69 100644 --- a/src/crewai_tools/tools/crewai_enterprise_tools/crewai_enterprise_tools.py +++ b/src/crewai_tools/tools/crewai_enterprise_tools/crewai_enterprise_tools.py @@ -7,6 +7,7 @@ import typing as t import logging from crewai.tools import BaseTool from crewai_tools.adapters.enterprise_adapter import EnterpriseActionKitToolAdapter +from crewai_tools.adapters.tool_collection import ToolCollection logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ def CrewaiEnterpriseTools( actions_list: t.Optional[t.List[str]] = None, enterprise_action_kit_project_id: t.Optional[str] = None, enterprise_action_kit_project_url: t.Optional[str] = None, -) -> t.List[BaseTool]: +) -> ToolCollection[BaseTool]: """Factory function that returns crewai enterprise tools. Args: @@ -24,9 +25,11 @@ def CrewaiEnterpriseTools( If not provided, will try to use CREWAI_ENTERPRISE_TOOLS_TOKEN env var. actions_list: Optional list of specific tool names to include. If provided, only tools with these names will be returned. + enterprise_action_kit_project_id: Optional ID of the Enterprise Action Kit project. + enterprise_action_kit_project_url: Optional URL of the Enterprise Action Kit project. Returns: - A list of BaseTool instances for enterprise actions + A ToolCollection of BaseTool instances for enterprise actions """ if enterprise_token is None: enterprise_token = os.environ.get("CREWAI_ENTERPRISE_TOOLS_TOKEN") @@ -47,7 +50,8 @@ def CrewaiEnterpriseTools( all_tools = adapter.tools() if actions_list is None: - return all_tools + return ToolCollection(all_tools) # Filter tools based on the provided list - return [tool for tool in all_tools if tool.name in actions_list] + filtered_tools = [tool for tool in all_tools if tool.name in actions_list] + return ToolCollection(filtered_tools) diff --git a/tests/adapters/mcp_adapter.py b/tests/adapters/mcp_adapter_test.py similarity index 96% rename from tests/adapters/mcp_adapter.py rename to tests/adapters/mcp_adapter_test.py index 569a10ae6..f2b08bc16 100644 --- a/tests/adapters/mcp_adapter.py +++ b/tests/adapters/mcp_adapter_test.py @@ -4,7 +4,7 @@ import pytest from mcp import StdioServerParameters from crewai_tools import MCPServerAdapter - +from crewai_tools.adapters.tool_collection import ToolCollection @pytest.fixture def echo_server_script(): @@ -18,7 +18,7 @@ def echo_server_script(): def echo_tool(text: str) -> str: """Echo the input text""" return f"Echo: {text}" - + mcp.run() ''' ) @@ -68,6 +68,7 @@ def test_context_manager_syntax(echo_server_script): command="uv", args=["run", "python", "-c", echo_server_script] ) with MCPServerAdapter(serverparams) as tools: + assert isinstance(tools, ToolCollection) assert len(tools) == 1 assert tools[0].name == "echo_tool" assert tools[0].run(text="hello") == "Echo: hello" @@ -91,7 +92,7 @@ def test_try_finally_syntax(echo_server_script): assert tools[0].run(text="hello") == "Echo: hello" finally: mcp_server_adapter.stop() - + def test_try_finally_syntax_sse(echo_sse_server): sse_serverparams = echo_sse_server mcp_server_adapter = MCPServerAdapter(sse_serverparams) diff --git a/tests/tools/crewai_enterprise_tools_test.py b/tests/tools/crewai_enterprise_tools_test.py new file mode 100644 index 000000000..384093e0f --- /dev/null +++ b/tests/tools/crewai_enterprise_tools_test.py @@ -0,0 +1,70 @@ +import os +import unittest +from unittest.mock import patch, MagicMock + +from crewai.tools import BaseTool +from crewai_tools.tools import CrewaiEnterpriseTools +from crewai_tools.adapters.tool_collection import ToolCollection + + +class TestCrewaiEnterpriseTools(unittest.TestCase): + def setUp(self): + self.mock_tools = [ + self._create_mock_tool("tool1", "Tool 1 Description"), + self._create_mock_tool("tool2", "Tool 2 Description"), + self._create_mock_tool("tool3", "Tool 3 Description"), + ] + self.adapter_patcher = patch( + 'crewai_tools.tools.crewai_enterprise_tools.crewai_enterprise_tools.EnterpriseActionKitToolAdapter' + ) + self.MockAdapter = self.adapter_patcher.start() + + mock_adapter_instance = self.MockAdapter.return_value + mock_adapter_instance.tools.return_value = self.mock_tools + + def tearDown(self): + self.adapter_patcher.stop() + + def _create_mock_tool(self, name, description): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.name = name + mock_tool.description = description + return mock_tool + + @patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"}) + def test_returns_tool_collection(self): + tools = CrewaiEnterpriseTools() + self.assertIsInstance(tools, ToolCollection) + + @patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"}) + def test_returns_all_tools_when_no_actions_list(self): + tools = CrewaiEnterpriseTools() + self.assertEqual(len(tools), 3) + self.assertEqual(tools[0].name, "tool1") + self.assertEqual(tools[1].name, "tool2") + self.assertEqual(tools[2].name, "tool3") + + @patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"}) + def test_filters_tools_by_actions_list(self): + tools = CrewaiEnterpriseTools(actions_list=["tool1", "tool3"]) + self.assertEqual(len(tools), 2) + self.assertEqual(tools[0].name, "tool1") + self.assertEqual(tools[1].name, "tool3") + + def test_uses_provided_parameters(self): + CrewaiEnterpriseTools( + enterprise_token="test-token", + enterprise_action_kit_project_id="project-id", + enterprise_action_kit_project_url="project-url" + ) + + self.MockAdapter.assert_called_once_with( + enterprise_action_token="test-token", + enterprise_action_kit_project_id="project-id", + enterprise_action_kit_project_url="project-url" + ) + + @patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"}) + def test_uses_environment_token(self): + CrewaiEnterpriseTools() + self.MockAdapter.assert_called_once_with(enterprise_action_token="env-token") \ No newline at end of file diff --git a/tests/tools/tool_collection_test.py b/tests/tools/tool_collection_test.py new file mode 100644 index 000000000..fb4f35c95 --- /dev/null +++ b/tests/tools/tool_collection_test.py @@ -0,0 +1,172 @@ +import unittest +from unittest.mock import MagicMock + +from crewai.tools import BaseTool +from crewai_tools.adapters.tool_collection import ToolCollection + + +class TestToolCollection(unittest.TestCase): + def setUp(self): + + self.search_tool = self._create_mock_tool("search", "Search Tool") + self.calculator_tool = self._create_mock_tool("calculator", "Calculator Tool") + self.translator_tool = self._create_mock_tool("translator", "Translator Tool") + + self.tools = ToolCollection([ + self.search_tool, + self.calculator_tool, + self.translator_tool + ]) + + def _create_mock_tool(self, name, description): + mock_tool = MagicMock(spec=BaseTool) + mock_tool.name = name + mock_tool.description = description + return mock_tool + + def test_initialization(self): + self.assertEqual(len(self.tools), 3) + self.assertEqual(self.tools[0].name, "search") + self.assertEqual(self.tools[1].name, "calculator") + self.assertEqual(self.tools[2].name, "translator") + + def test_empty_initialization(self): + empty_collection = ToolCollection() + self.assertEqual(len(empty_collection), 0) + self.assertEqual(empty_collection._name_cache, {}) + + def test_initialization_with_none(self): + collection = ToolCollection(None) + self.assertEqual(len(collection), 0) + self.assertEqual(collection._name_cache, {}) + + def test_access_by_index(self): + self.assertEqual(self.tools[0], self.search_tool) + self.assertEqual(self.tools[1], self.calculator_tool) + self.assertEqual(self.tools[2], self.translator_tool) + + def test_access_by_name(self): + self.assertEqual(self.tools["search"], self.search_tool) + self.assertEqual(self.tools["calculator"], self.calculator_tool) + self.assertEqual(self.tools["translator"], self.translator_tool) + + def test_key_error_for_invalid_name(self): + with self.assertRaises(KeyError): + _ = self.tools["nonexistent"] + + def test_index_error_for_invalid_index(self): + with self.assertRaises(IndexError): + _ = self.tools[10] + + def test_negative_index(self): + self.assertEqual(self.tools[-1], self.translator_tool) + self.assertEqual(self.tools[-2], self.calculator_tool) + self.assertEqual(self.tools[-3], self.search_tool) + + def test_append(self): + new_tool = self._create_mock_tool("new", "New Tool") + self.tools.append(new_tool) + + self.assertEqual(len(self.tools), 4) + self.assertEqual(self.tools[3], new_tool) + self.assertEqual(self.tools["new"], new_tool) + self.assertIn("new", self.tools._name_cache) + + def test_append_duplicate_name(self): + duplicate_tool = self._create_mock_tool("search", "Duplicate Search Tool") + self.tools.append(duplicate_tool) + + self.assertEqual(len(self.tools), 4) + self.assertEqual(self.tools["search"], duplicate_tool) + + def test_extend(self): + new_tools = [ + self._create_mock_tool("tool4", "Tool 4"), + self._create_mock_tool("tool5", "Tool 5"), + ] + self.tools.extend(new_tools) + + self.assertEqual(len(self.tools), 5) + self.assertEqual(self.tools["tool4"], new_tools[0]) + self.assertEqual(self.tools["tool5"], new_tools[1]) + self.assertIn("tool4", self.tools._name_cache) + self.assertIn("tool5", self.tools._name_cache) + + def test_insert(self): + new_tool = self._create_mock_tool("inserted", "Inserted Tool") + self.tools.insert(1, new_tool) + + self.assertEqual(len(self.tools), 4) + self.assertEqual(self.tools[1], new_tool) + self.assertEqual(self.tools["inserted"], new_tool) + self.assertIn("inserted", self.tools._name_cache) + + def test_remove(self): + self.tools.remove(self.calculator_tool) + + self.assertEqual(len(self.tools), 2) + with self.assertRaises(KeyError): + _ = self.tools["calculator"] + self.assertNotIn("calculator", self.tools._name_cache) + + def test_remove_nonexistent_tool(self): + nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool") + + with self.assertRaises(ValueError): + self.tools.remove(nonexistent_tool) + + def test_pop(self): + popped = self.tools.pop(1) + + self.assertEqual(popped, self.calculator_tool) + self.assertEqual(len(self.tools), 2) + with self.assertRaises(KeyError): + _ = self.tools["calculator"] + self.assertNotIn("calculator", self.tools._name_cache) + + def test_pop_last(self): + popped = self.tools.pop() + + self.assertEqual(popped, self.translator_tool) + self.assertEqual(len(self.tools), 2) + with self.assertRaises(KeyError): + _ = self.tools["translator"] + self.assertNotIn("translator", self.tools._name_cache) + + def test_clear(self): + self.tools.clear() + + self.assertEqual(len(self.tools), 0) + self.assertEqual(self.tools._name_cache, {}) + with self.assertRaises(KeyError): + _ = self.tools["search"] + + def test_iteration(self): + tools_list = list(self.tools) + self.assertEqual(tools_list, [self.search_tool, self.calculator_tool, self.translator_tool]) + + def test_contains(self): + self.assertIn(self.search_tool, self.tools) + self.assertIn(self.calculator_tool, self.tools) + self.assertIn(self.translator_tool, self.tools) + + nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool") + self.assertNotIn(nonexistent_tool, self.tools) + + def test_slicing(self): + slice_result = self.tools[1:3] + self.assertEqual(len(slice_result), 2) + self.assertEqual(slice_result[0], self.calculator_tool) + self.assertEqual(slice_result[1], self.translator_tool) + + self.assertIsInstance(slice_result, list) + self.assertNotIsInstance(slice_result, ToolCollection) + + def test_getitem_with_tool_name_as_int(self): + numeric_name_tool = self._create_mock_tool("123", "Numeric Name Tool") + self.tools.append(numeric_name_tool) + + self.assertEqual(self.tools["123"], numeric_name_tool) + + with self.assertRaises(IndexError): + _ = self.tools[123] \ No newline at end of file