Support to filter available MCP Tools (#345)

* feat: support to complex filter on ToolCollection

* refactor: use proper tool collection methot to filter tool in CrewAiEnterpriseTools

* feat: allow to filter available MCP tools
This commit is contained in:
Lucas Gomide
2025-06-25 13:32:22 -03:00
committed by GitHub
parent e8825d071a
commit 03917411b4
5 changed files with 193 additions and 23 deletions

View File

@@ -46,10 +46,18 @@ class MCPServerAdapter:
with MCPServerAdapter({"url": "http://localhost:8000/sse"}) as tools:
# tools is now available
# context manager with filtered tools
with MCPServerAdapter(..., "tool1", "tool2") as filtered_tools:
# only tool1 and tool2 are available
# manually stop mcp server
try:
mcp_server = MCPServerAdapter(...)
tools = mcp_server.tools
tools = mcp_server.tools # all tools
# or with filtered tools
mcp_server = MCPServerAdapter(..., "tool1", "tool2")
filtered_tools = mcp_server.tools # only tool1 and tool2
...
finally:
mcp_server.stop()
@@ -61,18 +69,22 @@ class MCPServerAdapter:
def __init__(
self,
serverparams: StdioServerParameters | dict[str, Any],
*tool_names: str,
):
"""Initialize the MCP Server
Args:
serverparams: The parameters for the MCP server it supports either a
`StdioServerParameters` or a `dict` respectively for STDIO and SSE.
*tool_names: Optional names of tools to filter. If provided, only tools with
matching names will be available.
"""
super().__init__()
self._adapter = None
self._tools = None
self._tool_names = list(tool_names) if tool_names else None
if not MCP_AVAILABLE:
import click
@@ -127,7 +139,11 @@ class MCPServerAdapter:
raise ValueError(
"MCP server not started, run `mcp_server.start()` first before accessing `tools`"
)
return ToolCollection(self._tools)
tools_collection = ToolCollection(self._tools)
if self._tool_names:
return tools_collection.filter_by_names(self._tool_names)
return tools_collection
def __enter__(self):
"""

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Union, TypeVar, Generic, Dict
from typing import List, Optional, Union, TypeVar, Generic, Dict, Callable
from crewai.tools import BaseTool
T = TypeVar('T', bound=BaseTool)
@@ -24,16 +24,16 @@ class ToolCollection(list, Generic[T]):
self._build_name_cache()
def _build_name_cache(self) -> None:
self._name_cache = {tool.name: tool for tool in self}
self._name_cache = {tool.name.lower(): tool for tool in self}
def __getitem__(self, key: Union[int, str]) -> T:
if isinstance(key, str):
return self._name_cache[key]
return self._name_cache[key.lower()]
return super().__getitem__(key)
def append(self, tool: T) -> None:
super().append(tool)
self._name_cache[tool.name] = tool
self._name_cache[tool.name.lower()] = tool
def extend(self, tools: List[T]) -> None:
super().extend(tools)
@@ -41,19 +41,34 @@ class ToolCollection(list, Generic[T]):
def insert(self, index: int, tool: T) -> None:
super().insert(index, tool)
self._name_cache[tool.name] = tool
self._name_cache[tool.name.lower()] = tool
def remove(self, tool: T) -> None:
super().remove(tool)
if tool.name in self._name_cache:
del self._name_cache[tool.name]
if tool.name.lower() in self._name_cache:
del self._name_cache[tool.name.lower()]
def pop(self, index: int = -1) -> T:
tool = super().pop(index)
if tool.name in self._name_cache:
del self._name_cache[tool.name]
if tool.name.lower() in self._name_cache:
del self._name_cache[tool.name.lower()]
return tool
def filter_by_names(self, names: Optional[List[str]] = None) -> "ToolCollection[T]":
if names is None:
return self
return ToolCollection(
[
tool
for name in names
if (tool := self._name_cache.get(name.lower())) is not None
]
)
def filter_where(self, func: Callable[[T], bool]) -> "ToolCollection[T]":
return ToolCollection([tool for tool in self if func(tool)])
def clear(self) -> None:
super().clear()
self._name_cache.clear()
self._name_cache.clear()

View File

@@ -49,9 +49,5 @@ def CrewaiEnterpriseTools(
adapter = EnterpriseActionKitToolAdapter(**adapter_kwargs)
all_tools = adapter.tools()
if actions_list is None:
return ToolCollection(all_tools)
# Filter tools based on the provided list
filtered_tools = [tool for tool in all_tools if tool.name.lower() in [action.lower() for action in actions_list]]
return ToolCollection(filtered_tools)
return ToolCollection(all_tools).filter_by_names(actions_list)