mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Support to filter available MCP Tools (#345)
* feat: support to complex filter on ToolCollection * refactor: use proper tool collection methot to filter tool in CrewAiEnterpriseTools * feat: allow to filter available MCP tools
This commit is contained in:
@@ -46,10 +46,18 @@ class MCPServerAdapter:
|
||||
with MCPServerAdapter({"url": "http://localhost:8000/sse"}) as tools:
|
||||
# tools is now available
|
||||
|
||||
# context manager with filtered tools
|
||||
with MCPServerAdapter(..., "tool1", "tool2") as filtered_tools:
|
||||
# only tool1 and tool2 are available
|
||||
|
||||
# manually stop mcp server
|
||||
try:
|
||||
mcp_server = MCPServerAdapter(...)
|
||||
tools = mcp_server.tools
|
||||
tools = mcp_server.tools # all tools
|
||||
|
||||
# or with filtered tools
|
||||
mcp_server = MCPServerAdapter(..., "tool1", "tool2")
|
||||
filtered_tools = mcp_server.tools # only tool1 and tool2
|
||||
...
|
||||
finally:
|
||||
mcp_server.stop()
|
||||
@@ -61,18 +69,22 @@ class MCPServerAdapter:
|
||||
def __init__(
|
||||
self,
|
||||
serverparams: StdioServerParameters | dict[str, Any],
|
||||
*tool_names: str,
|
||||
):
|
||||
"""Initialize the MCP Server
|
||||
|
||||
Args:
|
||||
serverparams: The parameters for the MCP server it supports either a
|
||||
`StdioServerParameters` or a `dict` respectively for STDIO and SSE.
|
||||
*tool_names: Optional names of tools to filter. If provided, only tools with
|
||||
matching names will be available.
|
||||
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self._adapter = None
|
||||
self._tools = None
|
||||
self._tool_names = list(tool_names) if tool_names else None
|
||||
|
||||
if not MCP_AVAILABLE:
|
||||
import click
|
||||
@@ -127,7 +139,11 @@ class MCPServerAdapter:
|
||||
raise ValueError(
|
||||
"MCP server not started, run `mcp_server.start()` first before accessing `tools`"
|
||||
)
|
||||
return ToolCollection(self._tools)
|
||||
|
||||
tools_collection = ToolCollection(self._tools)
|
||||
if self._tool_names:
|
||||
return tools_collection.filter_by_names(self._tool_names)
|
||||
return tools_collection
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Union, TypeVar, Generic, Dict
|
||||
from typing import List, Optional, Union, TypeVar, Generic, Dict, Callable
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
T = TypeVar('T', bound=BaseTool)
|
||||
@@ -24,16 +24,16 @@ class ToolCollection(list, Generic[T]):
|
||||
self._build_name_cache()
|
||||
|
||||
def _build_name_cache(self) -> None:
|
||||
self._name_cache = {tool.name: tool for tool in self}
|
||||
self._name_cache = {tool.name.lower(): tool for tool in self}
|
||||
|
||||
def __getitem__(self, key: Union[int, str]) -> T:
|
||||
if isinstance(key, str):
|
||||
return self._name_cache[key]
|
||||
return self._name_cache[key.lower()]
|
||||
return super().__getitem__(key)
|
||||
|
||||
def append(self, tool: T) -> None:
|
||||
super().append(tool)
|
||||
self._name_cache[tool.name] = tool
|
||||
self._name_cache[tool.name.lower()] = tool
|
||||
|
||||
def extend(self, tools: List[T]) -> None:
|
||||
super().extend(tools)
|
||||
@@ -41,19 +41,34 @@ class ToolCollection(list, Generic[T]):
|
||||
|
||||
def insert(self, index: int, tool: T) -> None:
|
||||
super().insert(index, tool)
|
||||
self._name_cache[tool.name] = tool
|
||||
self._name_cache[tool.name.lower()] = tool
|
||||
|
||||
def remove(self, tool: T) -> None:
|
||||
super().remove(tool)
|
||||
if tool.name in self._name_cache:
|
||||
del self._name_cache[tool.name]
|
||||
if tool.name.lower() in self._name_cache:
|
||||
del self._name_cache[tool.name.lower()]
|
||||
|
||||
def pop(self, index: int = -1) -> T:
|
||||
tool = super().pop(index)
|
||||
if tool.name in self._name_cache:
|
||||
del self._name_cache[tool.name]
|
||||
if tool.name.lower() in self._name_cache:
|
||||
del self._name_cache[tool.name.lower()]
|
||||
return tool
|
||||
|
||||
def filter_by_names(self, names: Optional[List[str]] = None) -> "ToolCollection[T]":
|
||||
if names is None:
|
||||
return self
|
||||
|
||||
return ToolCollection(
|
||||
[
|
||||
tool
|
||||
for name in names
|
||||
if (tool := self._name_cache.get(name.lower())) is not None
|
||||
]
|
||||
)
|
||||
|
||||
def filter_where(self, func: Callable[[T], bool]) -> "ToolCollection[T]":
|
||||
return ToolCollection([tool for tool in self if func(tool)])
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self._name_cache.clear()
|
||||
@@ -49,9 +49,5 @@ def CrewaiEnterpriseTools(
|
||||
adapter = EnterpriseActionKitToolAdapter(**adapter_kwargs)
|
||||
all_tools = adapter.tools()
|
||||
|
||||
if actions_list is None:
|
||||
return ToolCollection(all_tools)
|
||||
|
||||
# Filter tools based on the provided list
|
||||
filtered_tools = [tool for tool in all_tools if tool.name.lower() in [action.lower() for action in actions_list]]
|
||||
return ToolCollection(filtered_tools)
|
||||
return ToolCollection(all_tools).filter_by_names(actions_list)
|
||||
|
||||
@@ -19,6 +19,11 @@ def echo_server_script():
|
||||
"""Echo the input text"""
|
||||
return f"Echo: {text}"
|
||||
|
||||
@mcp.tool()
|
||||
def calc_tool(a: int, b: int) -> int:
|
||||
"""Calculate a + b"""
|
||||
return a + b
|
||||
|
||||
mcp.run()
|
||||
'''
|
||||
)
|
||||
@@ -37,6 +42,11 @@ def echo_server_sse_script():
|
||||
"""Echo the input text"""
|
||||
return f"Echo: {text}"
|
||||
|
||||
@mcp.tool()
|
||||
def calc_tool(a: int, b: int) -> int:
|
||||
"""Calculate a + b"""
|
||||
return a + b
|
||||
|
||||
mcp.run("sse")
|
||||
'''
|
||||
)
|
||||
@@ -69,16 +79,20 @@ def test_context_manager_syntax(echo_server_script):
|
||||
)
|
||||
with MCPServerAdapter(serverparams) as tools:
|
||||
assert isinstance(tools, ToolCollection)
|
||||
assert len(tools) == 1
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "echo_tool"
|
||||
assert tools[1].name == "calc_tool"
|
||||
assert tools[0].run(text="hello") == "Echo: hello"
|
||||
assert tools[1].run(a=5, b=3) == '8'
|
||||
|
||||
def test_context_manager_syntax_sse(echo_sse_server):
|
||||
sse_serverparams = echo_sse_server
|
||||
with MCPServerAdapter(sse_serverparams) as tools:
|
||||
assert len(tools) == 1
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "echo_tool"
|
||||
assert tools[1].name == "calc_tool"
|
||||
assert tools[0].run(text="hello") == "Echo: hello"
|
||||
assert tools[1].run(a=5, b=3) == '8'
|
||||
|
||||
def test_try_finally_syntax(echo_server_script):
|
||||
serverparams = StdioServerParameters(
|
||||
@@ -87,9 +101,11 @@ def test_try_finally_syntax(echo_server_script):
|
||||
try:
|
||||
mcp_server_adapter = MCPServerAdapter(serverparams)
|
||||
tools = mcp_server_adapter.tools
|
||||
assert len(tools) == 1
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "echo_tool"
|
||||
assert tools[1].name == "calc_tool"
|
||||
assert tools[0].run(text="hello") == "Echo: hello"
|
||||
assert tools[1].run(a=5, b=3) == '8'
|
||||
finally:
|
||||
mcp_server_adapter.stop()
|
||||
|
||||
@@ -98,8 +114,76 @@ def test_try_finally_syntax_sse(echo_sse_server):
|
||||
mcp_server_adapter = MCPServerAdapter(sse_serverparams)
|
||||
try:
|
||||
tools = mcp_server_adapter.tools
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "echo_tool"
|
||||
assert tools[1].name == "calc_tool"
|
||||
assert tools[0].run(text="hello") == "Echo: hello"
|
||||
assert tools[1].run(a=5, b=3) == '8'
|
||||
finally:
|
||||
mcp_server_adapter.stop()
|
||||
|
||||
def test_context_manager_with_filtered_tools(echo_server_script):
|
||||
serverparams = StdioServerParameters(
|
||||
command="uv", args=["run", "python", "-c", echo_server_script]
|
||||
)
|
||||
# Only select the echo_tool
|
||||
with MCPServerAdapter(serverparams, "echo_tool") as tools:
|
||||
assert isinstance(tools, ToolCollection)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "echo_tool"
|
||||
assert tools[0].run(text="hello") == "Echo: hello"
|
||||
# Check that calc_tool is not present
|
||||
with pytest.raises(IndexError):
|
||||
_ = tools[1]
|
||||
with pytest.raises(KeyError):
|
||||
_ = tools["calc_tool"]
|
||||
|
||||
def test_context_manager_sse_with_filtered_tools(echo_sse_server):
|
||||
sse_serverparams = echo_sse_server
|
||||
# Only select the calc_tool
|
||||
with MCPServerAdapter(sse_serverparams, "calc_tool") as tools:
|
||||
assert isinstance(tools, ToolCollection)
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "calc_tool"
|
||||
assert tools[0].run(a=10, b=5) == '15'
|
||||
# Check that echo_tool is not present
|
||||
with pytest.raises(IndexError):
|
||||
_ = tools[1]
|
||||
with pytest.raises(KeyError):
|
||||
_ = tools["echo_tool"]
|
||||
|
||||
def test_try_finally_with_filtered_tools(echo_server_script):
|
||||
serverparams = StdioServerParameters(
|
||||
command="uv", args=["run", "python", "-c", echo_server_script]
|
||||
)
|
||||
try:
|
||||
# Select both tools but in reverse order
|
||||
mcp_server_adapter = MCPServerAdapter(serverparams, "calc_tool", "echo_tool")
|
||||
tools = mcp_server_adapter.tools
|
||||
assert len(tools) == 2
|
||||
# The order of tools is based on filter_by_names which preserves
|
||||
# the original order from the collection
|
||||
assert tools[0].name == "calc_tool"
|
||||
assert tools[1].name == "echo_tool"
|
||||
finally:
|
||||
mcp_server_adapter.stop()
|
||||
|
||||
def test_filter_with_nonexistent_tool(echo_server_script):
|
||||
serverparams = StdioServerParameters(
|
||||
command="uv", args=["run", "python", "-c", echo_server_script]
|
||||
)
|
||||
# Include a tool that doesn't exist
|
||||
with MCPServerAdapter(serverparams, "echo_tool", "nonexistent_tool") as tools:
|
||||
# Only echo_tool should be in the result
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "echo_tool"
|
||||
|
||||
def test_filter_with_only_nonexistent_tools(echo_server_script):
|
||||
serverparams = StdioServerParameters(
|
||||
command="uv", args=["run", "python", "-c", echo_server_script]
|
||||
)
|
||||
# All requested tools don't exist
|
||||
with MCPServerAdapter(serverparams, "nonexistent1", "nonexistent2") as tools:
|
||||
# Should return an empty tool collection
|
||||
assert isinstance(tools, ToolCollection)
|
||||
assert len(tools) == 0
|
||||
|
||||
@@ -8,7 +8,7 @@ from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
class TestToolCollection(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
self.search_tool = self._create_mock_tool("search", "Search Tool")
|
||||
self.search_tool = self._create_mock_tool("SearcH", "Search Tool") # Tool name is case sensitive
|
||||
self.calculator_tool = self._create_mock_tool("calculator", "Calculator Tool")
|
||||
self.translator_tool = self._create_mock_tool("translator", "Translator Tool")
|
||||
|
||||
@@ -26,7 +26,7 @@ class TestToolCollection(unittest.TestCase):
|
||||
|
||||
def test_initialization(self):
|
||||
self.assertEqual(len(self.tools), 3)
|
||||
self.assertEqual(self.tools[0].name, "search")
|
||||
self.assertEqual(self.tools[0].name, "SearcH")
|
||||
self.assertEqual(self.tools[1].name, "calculator")
|
||||
self.assertEqual(self.tools[2].name, "translator")
|
||||
|
||||
@@ -170,3 +170,62 @@ class TestToolCollection(unittest.TestCase):
|
||||
|
||||
with self.assertRaises(IndexError):
|
||||
_ = self.tools[123]
|
||||
|
||||
def test_filter_by_names(self):
|
||||
|
||||
filtered = self.tools.filter_by_names(None)
|
||||
|
||||
self.assertIsInstance(filtered, ToolCollection)
|
||||
self.assertEqual(len(filtered), 3)
|
||||
|
||||
filtered = self.tools.filter_by_names(["search", "translator"])
|
||||
|
||||
self.assertIsInstance(filtered, ToolCollection)
|
||||
self.assertEqual(len(filtered), 2)
|
||||
self.assertEqual(filtered[0], self.search_tool)
|
||||
self.assertEqual(filtered[1], self.translator_tool)
|
||||
self.assertEqual(filtered["search"], self.search_tool)
|
||||
self.assertEqual(filtered["translator"], self.translator_tool)
|
||||
|
||||
filtered = self.tools.filter_by_names(["search", "nonexistent"])
|
||||
|
||||
self.assertIsInstance(filtered, ToolCollection)
|
||||
self.assertEqual(len(filtered), 1)
|
||||
self.assertEqual(filtered[0], self.search_tool)
|
||||
|
||||
filtered = self.tools.filter_by_names(["nonexistent1", "nonexistent2"])
|
||||
|
||||
self.assertIsInstance(filtered, ToolCollection)
|
||||
self.assertEqual(len(filtered), 0)
|
||||
|
||||
filtered = self.tools.filter_by_names([])
|
||||
|
||||
self.assertIsInstance(filtered, ToolCollection)
|
||||
self.assertEqual(len(filtered), 0)
|
||||
|
||||
def test_filter_where(self):
|
||||
filtered = self.tools.filter_where(lambda tool: tool.name.startswith("S"))
|
||||
|
||||
self.assertIsInstance(filtered, ToolCollection)
|
||||
self.assertEqual(len(filtered), 1)
|
||||
self.assertEqual(filtered[0], self.search_tool)
|
||||
self.assertEqual(filtered["search"], self.search_tool)
|
||||
|
||||
filtered = self.tools.filter_where(lambda tool: True)
|
||||
|
||||
self.assertIsInstance(filtered, ToolCollection)
|
||||
self.assertEqual(len(filtered), 3)
|
||||
self.assertEqual(filtered[0], self.search_tool)
|
||||
self.assertEqual(filtered[1], self.calculator_tool)
|
||||
self.assertEqual(filtered[2], self.translator_tool)
|
||||
|
||||
filtered = self.tools.filter_where(lambda tool: False)
|
||||
|
||||
self.assertIsInstance(filtered, ToolCollection)
|
||||
self.assertEqual(len(filtered), 0)
|
||||
filtered = self.tools.filter_where(lambda tool: len(tool.name) > 8)
|
||||
|
||||
self.assertIsInstance(filtered, ToolCollection)
|
||||
self.assertEqual(len(filtered), 2)
|
||||
self.assertEqual(filtered[0], self.calculator_tool)
|
||||
self.assertEqual(filtered[1], self.translator_tool)
|
||||
|
||||
Reference in New Issue
Block a user