From 03917411b4df05fe2462deb78723ff90b97cb362 Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Wed, 25 Jun 2025 13:32:22 -0300 Subject: [PATCH] Support to filter available MCP Tools (#345) * feat: support to complex filter on ToolCollection * refactor: use proper tool collection methot to filter tool in CrewAiEnterpriseTools * feat: allow to filter available MCP tools --- src/crewai_tools/adapters/mcp_adapter.py | 20 ++++- src/crewai_tools/adapters/tool_collection.py | 35 +++++--- .../crewai_enterprise_tools.py | 6 +- tests/adapters/mcp_adapter_test.py | 90 ++++++++++++++++++- tests/tools/tool_collection_test.py | 65 +++++++++++++- 5 files changed, 193 insertions(+), 23 deletions(-) diff --git a/src/crewai_tools/adapters/mcp_adapter.py b/src/crewai_tools/adapters/mcp_adapter.py index bfff480eb..db4c15a24 100644 --- a/src/crewai_tools/adapters/mcp_adapter.py +++ b/src/crewai_tools/adapters/mcp_adapter.py @@ -46,10 +46,18 @@ class MCPServerAdapter: with MCPServerAdapter({"url": "http://localhost:8000/sse"}) as tools: # tools is now available + # context manager with filtered tools + with MCPServerAdapter(..., "tool1", "tool2") as filtered_tools: + # only tool1 and tool2 are available + # manually stop mcp server try: mcp_server = MCPServerAdapter(...) - tools = mcp_server.tools + tools = mcp_server.tools # all tools + + # or with filtered tools + mcp_server = MCPServerAdapter(..., "tool1", "tool2") + filtered_tools = mcp_server.tools # only tool1 and tool2 ... finally: mcp_server.stop() @@ -61,18 +69,22 @@ class MCPServerAdapter: def __init__( self, serverparams: StdioServerParameters | dict[str, Any], + *tool_names: str, ): """Initialize the MCP Server Args: serverparams: The parameters for the MCP server it supports either a `StdioServerParameters` or a `dict` respectively for STDIO and SSE. + *tool_names: Optional names of tools to filter. If provided, only tools with + matching names will be available. """ super().__init__() self._adapter = None self._tools = None + self._tool_names = list(tool_names) if tool_names else None if not MCP_AVAILABLE: import click @@ -127,7 +139,11 @@ class MCPServerAdapter: raise ValueError( "MCP server not started, run `mcp_server.start()` first before accessing `tools`" ) - return ToolCollection(self._tools) + + tools_collection = ToolCollection(self._tools) + if self._tool_names: + return tools_collection.filter_by_names(self._tool_names) + return tools_collection def __enter__(self): """ diff --git a/src/crewai_tools/adapters/tool_collection.py b/src/crewai_tools/adapters/tool_collection.py index f0ec9a288..291fa8f82 100644 --- a/src/crewai_tools/adapters/tool_collection.py +++ b/src/crewai_tools/adapters/tool_collection.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union, TypeVar, Generic, Dict +from typing import List, Optional, Union, TypeVar, Generic, Dict, Callable from crewai.tools import BaseTool T = TypeVar('T', bound=BaseTool) @@ -24,16 +24,16 @@ class ToolCollection(list, Generic[T]): self._build_name_cache() def _build_name_cache(self) -> None: - self._name_cache = {tool.name: tool for tool in self} + self._name_cache = {tool.name.lower(): tool for tool in self} def __getitem__(self, key: Union[int, str]) -> T: if isinstance(key, str): - return self._name_cache[key] + return self._name_cache[key.lower()] return super().__getitem__(key) def append(self, tool: T) -> None: super().append(tool) - self._name_cache[tool.name] = tool + self._name_cache[tool.name.lower()] = tool def extend(self, tools: List[T]) -> None: super().extend(tools) @@ -41,19 +41,34 @@ class ToolCollection(list, Generic[T]): def insert(self, index: int, tool: T) -> None: super().insert(index, tool) - self._name_cache[tool.name] = tool + self._name_cache[tool.name.lower()] = tool def remove(self, tool: T) -> None: super().remove(tool) - if tool.name in self._name_cache: - del self._name_cache[tool.name] + if tool.name.lower() in self._name_cache: + del self._name_cache[tool.name.lower()] def pop(self, index: int = -1) -> T: tool = super().pop(index) - if tool.name in self._name_cache: - del self._name_cache[tool.name] + if tool.name.lower() in self._name_cache: + del self._name_cache[tool.name.lower()] return tool + def filter_by_names(self, names: Optional[List[str]] = None) -> "ToolCollection[T]": + if names is None: + return self + + return ToolCollection( + [ + tool + for name in names + if (tool := self._name_cache.get(name.lower())) is not None + ] + ) + + def filter_where(self, func: Callable[[T], bool]) -> "ToolCollection[T]": + return ToolCollection([tool for tool in self if func(tool)]) + def clear(self) -> None: super().clear() - self._name_cache.clear() + self._name_cache.clear() \ No newline at end of file diff --git a/src/crewai_tools/tools/crewai_enterprise_tools/crewai_enterprise_tools.py b/src/crewai_tools/tools/crewai_enterprise_tools/crewai_enterprise_tools.py index 871cf7c94..e531afeed 100644 --- a/src/crewai_tools/tools/crewai_enterprise_tools/crewai_enterprise_tools.py +++ b/src/crewai_tools/tools/crewai_enterprise_tools/crewai_enterprise_tools.py @@ -49,9 +49,5 @@ def CrewaiEnterpriseTools( adapter = EnterpriseActionKitToolAdapter(**adapter_kwargs) all_tools = adapter.tools() - if actions_list is None: - return ToolCollection(all_tools) - # Filter tools based on the provided list - filtered_tools = [tool for tool in all_tools if tool.name.lower() in [action.lower() for action in actions_list]] - return ToolCollection(filtered_tools) + return ToolCollection(all_tools).filter_by_names(actions_list) diff --git a/tests/adapters/mcp_adapter_test.py b/tests/adapters/mcp_adapter_test.py index f2b08bc16..d0dc88680 100644 --- a/tests/adapters/mcp_adapter_test.py +++ b/tests/adapters/mcp_adapter_test.py @@ -19,6 +19,11 @@ def echo_server_script(): """Echo the input text""" return f"Echo: {text}" + @mcp.tool() + def calc_tool(a: int, b: int) -> int: + """Calculate a + b""" + return a + b + mcp.run() ''' ) @@ -37,6 +42,11 @@ def echo_server_sse_script(): """Echo the input text""" return f"Echo: {text}" + @mcp.tool() + def calc_tool(a: int, b: int) -> int: + """Calculate a + b""" + return a + b + mcp.run("sse") ''' ) @@ -69,16 +79,20 @@ def test_context_manager_syntax(echo_server_script): ) with MCPServerAdapter(serverparams) as tools: assert isinstance(tools, ToolCollection) - assert len(tools) == 1 + assert len(tools) == 2 assert tools[0].name == "echo_tool" + assert tools[1].name == "calc_tool" assert tools[0].run(text="hello") == "Echo: hello" + assert tools[1].run(a=5, b=3) == '8' def test_context_manager_syntax_sse(echo_sse_server): sse_serverparams = echo_sse_server with MCPServerAdapter(sse_serverparams) as tools: - assert len(tools) == 1 + assert len(tools) == 2 assert tools[0].name == "echo_tool" + assert tools[1].name == "calc_tool" assert tools[0].run(text="hello") == "Echo: hello" + assert tools[1].run(a=5, b=3) == '8' def test_try_finally_syntax(echo_server_script): serverparams = StdioServerParameters( @@ -87,9 +101,11 @@ def test_try_finally_syntax(echo_server_script): try: mcp_server_adapter = MCPServerAdapter(serverparams) tools = mcp_server_adapter.tools - assert len(tools) == 1 + assert len(tools) == 2 assert tools[0].name == "echo_tool" + assert tools[1].name == "calc_tool" assert tools[0].run(text="hello") == "Echo: hello" + assert tools[1].run(a=5, b=3) == '8' finally: mcp_server_adapter.stop() @@ -98,8 +114,76 @@ def test_try_finally_syntax_sse(echo_sse_server): mcp_server_adapter = MCPServerAdapter(sse_serverparams) try: tools = mcp_server_adapter.tools + assert len(tools) == 2 + assert tools[0].name == "echo_tool" + assert tools[1].name == "calc_tool" + assert tools[0].run(text="hello") == "Echo: hello" + assert tools[1].run(a=5, b=3) == '8' + finally: + mcp_server_adapter.stop() + +def test_context_manager_with_filtered_tools(echo_server_script): + serverparams = StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ) + # Only select the echo_tool + with MCPServerAdapter(serverparams, "echo_tool") as tools: + assert isinstance(tools, ToolCollection) assert len(tools) == 1 assert tools[0].name == "echo_tool" assert tools[0].run(text="hello") == "Echo: hello" + # Check that calc_tool is not present + with pytest.raises(IndexError): + _ = tools[1] + with pytest.raises(KeyError): + _ = tools["calc_tool"] + +def test_context_manager_sse_with_filtered_tools(echo_sse_server): + sse_serverparams = echo_sse_server + # Only select the calc_tool + with MCPServerAdapter(sse_serverparams, "calc_tool") as tools: + assert isinstance(tools, ToolCollection) + assert len(tools) == 1 + assert tools[0].name == "calc_tool" + assert tools[0].run(a=10, b=5) == '15' + # Check that echo_tool is not present + with pytest.raises(IndexError): + _ = tools[1] + with pytest.raises(KeyError): + _ = tools["echo_tool"] + +def test_try_finally_with_filtered_tools(echo_server_script): + serverparams = StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ) + try: + # Select both tools but in reverse order + mcp_server_adapter = MCPServerAdapter(serverparams, "calc_tool", "echo_tool") + tools = mcp_server_adapter.tools + assert len(tools) == 2 + # The order of tools is based on filter_by_names which preserves + # the original order from the collection + assert tools[0].name == "calc_tool" + assert tools[1].name == "echo_tool" finally: mcp_server_adapter.stop() + +def test_filter_with_nonexistent_tool(echo_server_script): + serverparams = StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ) + # Include a tool that doesn't exist + with MCPServerAdapter(serverparams, "echo_tool", "nonexistent_tool") as tools: + # Only echo_tool should be in the result + assert len(tools) == 1 + assert tools[0].name == "echo_tool" + +def test_filter_with_only_nonexistent_tools(echo_server_script): + serverparams = StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ) + # All requested tools don't exist + with MCPServerAdapter(serverparams, "nonexistent1", "nonexistent2") as tools: + # Should return an empty tool collection + assert isinstance(tools, ToolCollection) + assert len(tools) == 0 diff --git a/tests/tools/tool_collection_test.py b/tests/tools/tool_collection_test.py index fb4f35c95..e409a4e76 100644 --- a/tests/tools/tool_collection_test.py +++ b/tests/tools/tool_collection_test.py @@ -8,7 +8,7 @@ from crewai_tools.adapters.tool_collection import ToolCollection class TestToolCollection(unittest.TestCase): def setUp(self): - self.search_tool = self._create_mock_tool("search", "Search Tool") + self.search_tool = self._create_mock_tool("SearcH", "Search Tool") # Tool name is case sensitive self.calculator_tool = self._create_mock_tool("calculator", "Calculator Tool") self.translator_tool = self._create_mock_tool("translator", "Translator Tool") @@ -26,7 +26,7 @@ class TestToolCollection(unittest.TestCase): def test_initialization(self): self.assertEqual(len(self.tools), 3) - self.assertEqual(self.tools[0].name, "search") + self.assertEqual(self.tools[0].name, "SearcH") self.assertEqual(self.tools[1].name, "calculator") self.assertEqual(self.tools[2].name, "translator") @@ -169,4 +169,63 @@ class TestToolCollection(unittest.TestCase): self.assertEqual(self.tools["123"], numeric_name_tool) with self.assertRaises(IndexError): - _ = self.tools[123] \ No newline at end of file + _ = self.tools[123] + + def test_filter_by_names(self): + + filtered = self.tools.filter_by_names(None) + + self.assertIsInstance(filtered, ToolCollection) + self.assertEqual(len(filtered), 3) + + filtered = self.tools.filter_by_names(["search", "translator"]) + + self.assertIsInstance(filtered, ToolCollection) + self.assertEqual(len(filtered), 2) + self.assertEqual(filtered[0], self.search_tool) + self.assertEqual(filtered[1], self.translator_tool) + self.assertEqual(filtered["search"], self.search_tool) + self.assertEqual(filtered["translator"], self.translator_tool) + + filtered = self.tools.filter_by_names(["search", "nonexistent"]) + + self.assertIsInstance(filtered, ToolCollection) + self.assertEqual(len(filtered), 1) + self.assertEqual(filtered[0], self.search_tool) + + filtered = self.tools.filter_by_names(["nonexistent1", "nonexistent2"]) + + self.assertIsInstance(filtered, ToolCollection) + self.assertEqual(len(filtered), 0) + + filtered = self.tools.filter_by_names([]) + + self.assertIsInstance(filtered, ToolCollection) + self.assertEqual(len(filtered), 0) + + def test_filter_where(self): + filtered = self.tools.filter_where(lambda tool: tool.name.startswith("S")) + + self.assertIsInstance(filtered, ToolCollection) + self.assertEqual(len(filtered), 1) + self.assertEqual(filtered[0], self.search_tool) + self.assertEqual(filtered["search"], self.search_tool) + + filtered = self.tools.filter_where(lambda tool: True) + + self.assertIsInstance(filtered, ToolCollection) + self.assertEqual(len(filtered), 3) + self.assertEqual(filtered[0], self.search_tool) + self.assertEqual(filtered[1], self.calculator_tool) + self.assertEqual(filtered[2], self.translator_tool) + + filtered = self.tools.filter_where(lambda tool: False) + + self.assertIsInstance(filtered, ToolCollection) + self.assertEqual(len(filtered), 0) + filtered = self.tools.filter_where(lambda tool: len(tool.name) > 8) + + self.assertIsInstance(filtered, ToolCollection) + self.assertEqual(len(filtered), 2) + self.assertEqual(filtered[0], self.calculator_tool) + self.assertEqual(filtered[1], self.translator_tool)