Adding support to multiple mcp servers on @CrewBase

This commit is contained in:
Thiago Moretto
2025-07-16 21:34:44 -04:00
parent 2490e8cd46
commit db90371c22
2 changed files with 70 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
import inspect
import logging
from pathlib import Path
from typing import Any, Callable, Dict, TypeVar, cast, List
from typing import Any, Callable, Dict, TypeVar, cast, List, Union
from crewai.tools import BaseTool
import yaml
@@ -28,7 +28,8 @@ def CrewBase(cls: T) -> T:
)
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
mcp_server_params: Any = getattr(cls, "mcp_server_params", None)
mcp_server_params: Union[list[str | dict[str, str]], dict[str, str], None] = getattr(cls, "mcp_server_params", None)
_mcp_server_adapter: Union[dict[str, Any], Any, None] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -67,36 +68,57 @@ def CrewBase(cls: T) -> T:
self._original_functions, "is_kickoff"
)
# Add close mcp server method to after kickoff
bound_method = self._create_close_mcp_server_method()
self._after_kickoff['_close_mcp_server'] = bound_method
# Add close mcp servers method to after kickoff
bound_method = self._create_close_mcp_servers_method()
self._after_kickoff['_close_mcp_servers'] = bound_method
def _create_close_mcp_server_method(self):
def _close_mcp_server(self, instance, outputs):
adapter = getattr(self, '_mcp_server_adapter', None)
if adapter is not None:
def _create_close_mcp_servers_method(self):
def _close_mcp_servers(self, instance, outputs):
if self._mcp_server_adapter is None:
return outputs
for adapter in self._mcp_server_adapter.values():
try:
adapter.stop()
except Exception as e:
logging.warning(f"Error stopping MCP server: {e}")
return outputs
_close_mcp_server.is_after_kickoff = True
_close_mcp_servers.is_after_kickoff = True
import types
return types.MethodType(_close_mcp_server, self)
return types.MethodType(_close_mcp_servers, self)
def get_mcp_tools(self, *tool_names: list[str]) -> List[BaseTool]:
def get_mcp_tools(self, *tool_names: list[str], server: str | None = None) -> List[BaseTool]:
if not self.mcp_server_params:
return []
from crewai_tools import MCPServerAdapter
adapter = getattr(self, '_mcp_server_adapter', None)
if not adapter:
self._mcp_server_adapter = MCPServerAdapter(self.mcp_server_params)
if isinstance(self.mcp_server_params, list):
if self._mcp_server_adapter is None:
self._mcp_server_adapter = MCPServerAdapter(self.mcp_server_params)
if server is not None and len(self.mcp_server_params) > 1:
logging.warning("Using list of MCP server parameters. To use server parameter, please use a dictionary of MCP server parameters.")
# Type assertion: when mcp_server_params is a list, _mcp_server_adapter is a single MCPServerAdapter
adapter = cast(Any, self._mcp_server_adapter)
return adapter.tools.filter_by_names(tool_names or None)
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
# Separated MCP adapters for each server.
elif isinstance(self.mcp_server_params, dict):
if self._mcp_server_adapter is None:
self._mcp_server_adapter = {}
aggregated_tools = []
for server_name, params in self.mcp_server_params.items():
if server is not None and server_name != server:
continue
adapter = self._mcp_server_adapter.get(server_name, None)
if not adapter:
self._mcp_server_adapter[server_name] = MCPServerAdapter(params)
aggregated_tools.extend(
self._mcp_server_adapter[server_name].tools.filter_by_names(tool_names or None))
return aggregated_tools
def load_configurations(self):

View File

@@ -87,7 +87,7 @@ class InternalCrew:
@CrewBase
class InternalCrewWithMCP(InternalCrew):
mcp_server_params = {"host": "localhost", "port": 8000}
mcp_server_params = [{"url": "localhost", "port": 8000}]
@agent
def reporting_analyst(self):
@@ -97,6 +97,19 @@ class InternalCrewWithMCP(InternalCrew):
def researcher(self):
return Agent(config=self.agents_config["researcher"], tools=self.get_mcp_tools("simple_tool")) # type: ignore[index]
@CrewBase
class InternalCrewWithMultipleMCP(InternalCrew):
mcp_server_params = {"mcp1": {"url": "localhost", "port": 8000}, "mcp2": {"url": "localhost", "port": 8001}}
@agent
def reporting_analyst(self):
return Agent(config=self.agents_config["reporting_analyst"], tools=self.get_mcp_tools(server="mcp1")) # type: ignore[index]
@agent
def researcher(self):
return Agent(config=self.agents_config["researcher"], tools=self.get_mcp_tools("simple_tool", server="mcp2")) # type: ignore[index]
def test_agent_memoization():
crew = SimpleCrew()
first_call_result = crew.simple_agent()
@@ -270,4 +283,21 @@ def test_internal_crew_with_mcp():
assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
assert crew.researcher().tools == [simple_tool]
adapter_mock.assert_called_once_with({"host": "localhost", "port": 8000})
adapter_mock.assert_called_once_with([{"url": "localhost", "port": 8000}])
def test_internal_crew_with_multiple_mcp():
from crewai_tools import MCPServerAdapter
from crewai_tools.adapters.mcp_adapter import ToolCollection
from unittest.mock import call
mock = Mock(spec=MCPServerAdapter)
mock.tools = ToolCollection([simple_tool, another_simple_tool])
with patch("crewai_tools.MCPServerAdapter", return_value=mock) as adapter_mock:
crew = InternalCrewWithMultipleMCP()
assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
assert crew.researcher().tools == [simple_tool]
adapter_mock.assert_has_calls([
call({"url": "localhost", "port": 8000}),
call({"url": "localhost", "port": 8001})
], any_order=True)