diff --git a/src/crewai/project/crew_base.py b/src/crewai/project/crew_base.py index e1602acf0..a60391420 100644 --- a/src/crewai/project/crew_base.py +++ b/src/crewai/project/crew_base.py @@ -1,7 +1,7 @@ import inspect import logging from pathlib import Path -from typing import Any, Callable, Dict, TypeVar, cast, List +from typing import Any, Callable, Dict, TypeVar, cast, List, Union from crewai.tools import BaseTool import yaml @@ -28,7 +28,8 @@ def CrewBase(cls: T) -> T: ) original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml") - mcp_server_params: Any = getattr(cls, "mcp_server_params", None) + mcp_server_params: Union[list[str | dict[str, str]], dict[str, str], None] = getattr(cls, "mcp_server_params", None) + _mcp_server_adapter: Union[dict[str, Any], Any, None] = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -67,36 +68,57 @@ def CrewBase(cls: T) -> T: self._original_functions, "is_kickoff" ) - # Add close mcp server method to after kickoff - bound_method = self._create_close_mcp_server_method() - self._after_kickoff['_close_mcp_server'] = bound_method + # Add close mcp servers method to after kickoff + bound_method = self._create_close_mcp_servers_method() + self._after_kickoff['_close_mcp_servers'] = bound_method - def _create_close_mcp_server_method(self): - def _close_mcp_server(self, instance, outputs): - adapter = getattr(self, '_mcp_server_adapter', None) - if adapter is not None: + def _create_close_mcp_servers_method(self): + def _close_mcp_servers(self, instance, outputs): + if self._mcp_server_adapter is None: + return outputs + for adapter in self._mcp_server_adapter.values(): try: adapter.stop() except Exception as e: logging.warning(f"Error stopping MCP server: {e}") return outputs - _close_mcp_server.is_after_kickoff = True + _close_mcp_servers.is_after_kickoff = True import types - return types.MethodType(_close_mcp_server, self) + return types.MethodType(_close_mcp_servers, self) - def get_mcp_tools(self, *tool_names: list[str]) -> List[BaseTool]: + def get_mcp_tools(self, *tool_names: list[str], server: str | None = None) -> List[BaseTool]: if not self.mcp_server_params: return [] from crewai_tools import MCPServerAdapter - adapter = getattr(self, '_mcp_server_adapter', None) - if not adapter: - self._mcp_server_adapter = MCPServerAdapter(self.mcp_server_params) + if isinstance(self.mcp_server_params, list): + if self._mcp_server_adapter is None: + self._mcp_server_adapter = MCPServerAdapter(self.mcp_server_params) + if server is not None and len(self.mcp_server_params) > 1: + logging.warning("Using list of MCP server parameters. To use server parameter, please use a dictionary of MCP server parameters.") + # Type assertion: when mcp_server_params is a list, _mcp_server_adapter is a single MCPServerAdapter + adapter = cast(Any, self._mcp_server_adapter) + return adapter.tools.filter_by_names(tool_names or None) - return self._mcp_server_adapter.tools.filter_by_names(tool_names or None) + # Separated MCP adapters for each server. + elif isinstance(self.mcp_server_params, dict): + if self._mcp_server_adapter is None: + self._mcp_server_adapter = {} + aggregated_tools = [] + for server_name, params in self.mcp_server_params.items(): + if server is not None and server_name != server: + continue + + adapter = self._mcp_server_adapter.get(server_name, None) + if not adapter: + self._mcp_server_adapter[server_name] = MCPServerAdapter(params) + aggregated_tools.extend( + self._mcp_server_adapter[server_name].tools.filter_by_names(tool_names or None)) + + return aggregated_tools def load_configurations(self): diff --git a/tests/project_test.py b/tests/project_test.py index 708913d24..8865444c8 100644 --- a/tests/project_test.py +++ b/tests/project_test.py @@ -87,7 +87,7 @@ class InternalCrew: @CrewBase class InternalCrewWithMCP(InternalCrew): - mcp_server_params = {"host": "localhost", "port": 8000} + mcp_server_params = [{"url": "localhost", "port": 8000}] @agent def reporting_analyst(self): @@ -97,6 +97,19 @@ class InternalCrewWithMCP(InternalCrew): def researcher(self): return Agent(config=self.agents_config["researcher"], tools=self.get_mcp_tools("simple_tool")) # type: ignore[index] +@CrewBase +class InternalCrewWithMultipleMCP(InternalCrew): + mcp_server_params = {"mcp1": {"url": "localhost", "port": 8000}, "mcp2": {"url": "localhost", "port": 8001}} + + @agent + def reporting_analyst(self): + return Agent(config=self.agents_config["reporting_analyst"], tools=self.get_mcp_tools(server="mcp1")) # type: ignore[index] + + @agent + def researcher(self): + return Agent(config=self.agents_config["researcher"], tools=self.get_mcp_tools("simple_tool", server="mcp2")) # type: ignore[index] + + def test_agent_memoization(): crew = SimpleCrew() first_call_result = crew.simple_agent() @@ -270,4 +283,21 @@ def test_internal_crew_with_mcp(): assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool] assert crew.researcher().tools == [simple_tool] - adapter_mock.assert_called_once_with({"host": "localhost", "port": 8000}) \ No newline at end of file + adapter_mock.assert_called_once_with([{"url": "localhost", "port": 8000}]) + + +def test_internal_crew_with_multiple_mcp(): + from crewai_tools import MCPServerAdapter + from crewai_tools.adapters.mcp_adapter import ToolCollection + from unittest.mock import call + + mock = Mock(spec=MCPServerAdapter) + mock.tools = ToolCollection([simple_tool, another_simple_tool]) + with patch("crewai_tools.MCPServerAdapter", return_value=mock) as adapter_mock: + crew = InternalCrewWithMultipleMCP() + assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool] + assert crew.researcher().tools == [simple_tool] + adapter_mock.assert_has_calls([ + call({"url": "localhost", "port": 8000}), + call({"url": "localhost", "port": 8001}) + ], any_order=True)