Adding support to multiple mcp servers on @CrewBase

This commit is contained in:
Thiago Moretto
2025-07-16 21:34:44 -04:00
parent 2490e8cd46
commit db90371c22
2 changed files with 70 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
import inspect import inspect
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, TypeVar, cast, List from typing import Any, Callable, Dict, TypeVar, cast, List, Union
from crewai.tools import BaseTool from crewai.tools import BaseTool
import yaml import yaml
@@ -28,7 +28,8 @@ def CrewBase(cls: T) -> T:
) )
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml") original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
mcp_server_params: Any = getattr(cls, "mcp_server_params", None) mcp_server_params: Union[list[str | dict[str, str]], dict[str, str], None] = getattr(cls, "mcp_server_params", None)
_mcp_server_adapter: Union[dict[str, Any], Any, None] = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -67,36 +68,57 @@ def CrewBase(cls: T) -> T:
self._original_functions, "is_kickoff" self._original_functions, "is_kickoff"
) )
# Add close mcp server method to after kickoff # Add close mcp servers method to after kickoff
bound_method = self._create_close_mcp_server_method() bound_method = self._create_close_mcp_servers_method()
self._after_kickoff['_close_mcp_server'] = bound_method self._after_kickoff['_close_mcp_servers'] = bound_method
def _create_close_mcp_server_method(self): def _create_close_mcp_servers_method(self):
def _close_mcp_server(self, instance, outputs): def _close_mcp_servers(self, instance, outputs):
adapter = getattr(self, '_mcp_server_adapter', None) if self._mcp_server_adapter is None:
if adapter is not None: return outputs
for adapter in self._mcp_server_adapter.values():
try: try:
adapter.stop() adapter.stop()
except Exception as e: except Exception as e:
logging.warning(f"Error stopping MCP server: {e}") logging.warning(f"Error stopping MCP server: {e}")
return outputs return outputs
_close_mcp_server.is_after_kickoff = True _close_mcp_servers.is_after_kickoff = True
import types import types
return types.MethodType(_close_mcp_server, self) return types.MethodType(_close_mcp_servers, self)
def get_mcp_tools(self, *tool_names: list[str]) -> List[BaseTool]: def get_mcp_tools(self, *tool_names: list[str], server: str | None = None) -> List[BaseTool]:
if not self.mcp_server_params: if not self.mcp_server_params:
return [] return []
from crewai_tools import MCPServerAdapter from crewai_tools import MCPServerAdapter
adapter = getattr(self, '_mcp_server_adapter', None) if isinstance(self.mcp_server_params, list):
if not adapter: if self._mcp_server_adapter is None:
self._mcp_server_adapter = MCPServerAdapter(self.mcp_server_params) self._mcp_server_adapter = MCPServerAdapter(self.mcp_server_params)
if server is not None and len(self.mcp_server_params) > 1:
logging.warning("Using list of MCP server parameters. To use server parameter, please use a dictionary of MCP server parameters.")
# Type assertion: when mcp_server_params is a list, _mcp_server_adapter is a single MCPServerAdapter
adapter = cast(Any, self._mcp_server_adapter)
return adapter.tools.filter_by_names(tool_names or None)
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None) # Separated MCP adapters for each server.
elif isinstance(self.mcp_server_params, dict):
if self._mcp_server_adapter is None:
self._mcp_server_adapter = {}
aggregated_tools = []
for server_name, params in self.mcp_server_params.items():
if server is not None and server_name != server:
continue
adapter = self._mcp_server_adapter.get(server_name, None)
if not adapter:
self._mcp_server_adapter[server_name] = MCPServerAdapter(params)
aggregated_tools.extend(
self._mcp_server_adapter[server_name].tools.filter_by_names(tool_names or None))
return aggregated_tools
def load_configurations(self): def load_configurations(self):

View File

@@ -87,7 +87,7 @@ class InternalCrew:
@CrewBase @CrewBase
class InternalCrewWithMCP(InternalCrew): class InternalCrewWithMCP(InternalCrew):
mcp_server_params = {"host": "localhost", "port": 8000} mcp_server_params = [{"url": "localhost", "port": 8000}]
@agent @agent
def reporting_analyst(self): def reporting_analyst(self):
@@ -97,6 +97,19 @@ class InternalCrewWithMCP(InternalCrew):
def researcher(self): def researcher(self):
return Agent(config=self.agents_config["researcher"], tools=self.get_mcp_tools("simple_tool")) # type: ignore[index] return Agent(config=self.agents_config["researcher"], tools=self.get_mcp_tools("simple_tool")) # type: ignore[index]
@CrewBase
class InternalCrewWithMultipleMCP(InternalCrew):
mcp_server_params = {"mcp1": {"url": "localhost", "port": 8000}, "mcp2": {"url": "localhost", "port": 8001}}
@agent
def reporting_analyst(self):
return Agent(config=self.agents_config["reporting_analyst"], tools=self.get_mcp_tools(server="mcp1")) # type: ignore[index]
@agent
def researcher(self):
return Agent(config=self.agents_config["researcher"], tools=self.get_mcp_tools("simple_tool", server="mcp2")) # type: ignore[index]
def test_agent_memoization(): def test_agent_memoization():
crew = SimpleCrew() crew = SimpleCrew()
first_call_result = crew.simple_agent() first_call_result = crew.simple_agent()
@@ -270,4 +283,21 @@ def test_internal_crew_with_mcp():
assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool] assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
assert crew.researcher().tools == [simple_tool] assert crew.researcher().tools == [simple_tool]
adapter_mock.assert_called_once_with({"host": "localhost", "port": 8000}) adapter_mock.assert_called_once_with([{"url": "localhost", "port": 8000}])
def test_internal_crew_with_multiple_mcp():
from crewai_tools import MCPServerAdapter
from crewai_tools.adapters.mcp_adapter import ToolCollection
from unittest.mock import call
mock = Mock(spec=MCPServerAdapter)
mock.tools = ToolCollection([simple_tool, another_simple_tool])
with patch("crewai_tools.MCPServerAdapter", return_value=mock) as adapter_mock:
crew = InternalCrewWithMultipleMCP()
assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
assert crew.researcher().tools == [simple_tool]
adapter_mock.assert_has_calls([
call({"url": "localhost", "port": 8000}),
call({"url": "localhost", "port": 8001})
], any_order=True)