From e72a80be6e0450ed4ac9da763b0ebce3db9100b8 Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Tue, 10 Mar 2026 15:00:40 -0300 Subject: [PATCH] Addressing MCP tools resolutions & eliminates all shared mutable connection (#4792) * fix: allow hyphenated tool names in MCP references like notion#get-page The _SLUG_RE regex on BaseAgent rejected MCP tool references containing hyphens (e.g. "notion#get-page") because the fragment pattern only matched \w (word chars) * fix: create fresh MCP client per tool invocation to prevent parallel call races When the LLM dispatches parallel calls to MCP tools on the same server, the executor runs them concurrently via ThreadPoolExecutor. Previously, all tools from a server shared a single MCPClient instance, and even the same tool called twice would reuse one client. Since each thread creates its own asyncio event loop via asyncio.run(), concurrent connect/disconnect calls on the shared client caused anyio cancel-scope errors ("Attempted to exit cancel scope in a different task than it was entered in"). The fix introduces a client_factory pattern: MCPNativeTool now receives a zero-arg callable that produces a fresh MCPClient + transport on every _run_async() invocation. This eliminates all shared mutable connection state between concurrent calls, whether to the same tool or different tools from the same server. * test: ensure we can filter hyphenated MCP tool --- .../crewai/agents/agent_builder/base_agent.py | 2 +- lib/crewai/src/crewai/mcp/tool_resolver.py | 90 ++++++--- .../src/crewai/tools/mcp_native_tool.py | 74 ++----- lib/crewai/tests/agents/test_agent.py | 65 ++++++ lib/crewai/tests/mcp/test_amp_mcp.py | 132 +++++++++++++ lib/crewai/tests/mcp/test_mcp_config.py | 187 ++++++++++++++---- 6 files changed, 420 insertions(+), 130 deletions(-) diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index 8b2b9737c..da32d9c1c 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -38,7 +38,7 @@ from crewai.utilities.string_utils import interpolate_only _SLUG_RE: Final[re.Pattern[str]] = re.compile( - r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$" + r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$" ) diff --git a/lib/crewai/src/crewai/mcp/tool_resolver.py b/lib/crewai/src/crewai/mcp/tool_resolver.py index 34af189f2..c0428f82d 100644 --- a/lib/crewai/src/crewai/mcp/tool_resolver.py +++ b/lib/crewai/src/crewai/mcp/tool_resolver.py @@ -22,6 +22,7 @@ from crewai.mcp.config import ( MCPServerSSE, MCPServerStdio, ) +from crewai.utilities.string_utils import sanitize_tool_name from crewai.mcp.transports.http import HTTPTransport from crewai.mcp.transports.sse import SSETransport from crewai.mcp.transports.stdio import StdioTransport @@ -74,10 +75,9 @@ class MCPToolResolver: elif isinstance(mcp_config, str): amp_refs.append(self._parse_amp_ref(mcp_config)) else: - tools, client = self._resolve_native(mcp_config) + tools, clients = self._resolve_native(mcp_config) all_tools.extend(tools) - if client: - self._clients.append(client) + self._clients.extend(clients) if amp_refs: tools, clients = self._resolve_amp(amp_refs) @@ -131,7 +131,7 @@ class MCPToolResolver: all_tools: list[BaseTool] = [] all_clients: list[Any] = [] - resolved_cache: dict[str, tuple[list[BaseTool], Any | None]] = {} + resolved_cache: dict[str, tuple[list[BaseTool], list[Any]]] = {} for slug in unique_slugs: config_dict = amp_configs_map.get(slug) @@ -149,10 +149,9 @@ class MCPToolResolver: mcp_server_config = self._build_mcp_config_from_dict(config_dict) try: - tools, client = self._resolve_native(mcp_server_config) - resolved_cache[slug] = (tools, client) - if client: - all_clients.append(client) + tools, clients = self._resolve_native(mcp_server_config) + resolved_cache[slug] = (tools, clients) + all_clients.extend(clients) except Exception as e: crewai_event_bus.emit( self, @@ -170,8 +169,9 @@ class MCPToolResolver: slug_tools, _ = cached if specific_tool: + sanitized = sanitize_tool_name(specific_tool) all_tools.extend( - t for t in slug_tools if t.name.endswith(f"_{specific_tool}") + t for t in slug_tools if t.name.endswith(f"_{sanitized}") ) else: all_tools.extend(slug_tools) @@ -198,7 +198,6 @@ class MCPToolResolver: plus_api = PlusAPI(api_key=get_platform_integration_token()) response = plus_api.get_mcp_configs(slugs) - if response.status_code == 200: configs: dict[str, dict[str, Any]] = response.json().get("configs", {}) return configs @@ -218,6 +217,7 @@ class MCPToolResolver: def _resolve_external(self, mcp_ref: str) -> list[BaseTool]: """Resolve an HTTPS MCP server URL into tools.""" + from crewai.tools.base_tool import BaseTool from crewai.tools.mcp_tool_wrapper import MCPToolWrapper if "#" in mcp_ref: @@ -227,6 +227,7 @@ class MCPToolResolver: server_params = {"url": server_url} server_name = self._extract_server_name(server_url) + sanitized_specific_tool = sanitize_tool_name(specific_tool) if specific_tool else None try: tool_schemas = self._get_mcp_tool_schemas(server_params) @@ -239,7 +240,7 @@ class MCPToolResolver: tools = [] for tool_name, schema in tool_schemas.items(): - if specific_tool and tool_name != specific_tool: + if sanitized_specific_tool and tool_name != sanitized_specific_tool: continue try: @@ -271,14 +272,16 @@ class MCPToolResolver: ) return [] - def _resolve_native( - self, mcp_config: MCPServerConfig - ) -> tuple[list[BaseTool], Any | None]: - """Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup.""" - from crewai.tools.base_tool import BaseTool - from crewai.tools.mcp_native_tool import MCPNativeTool + @staticmethod + def _create_transport( + mcp_config: MCPServerConfig, + ) -> tuple[StdioTransport | HTTPTransport | SSETransport, str]: + """Create a fresh transport instance from an MCP server config. - transport: StdioTransport | HTTPTransport | SSETransport + Returns a ``(transport, server_name)`` tuple. Each call produces an + independent transport so that parallel tool executions never share + state. + """ if isinstance(mcp_config, MCPServerStdio): transport = StdioTransport( command=mcp_config.command, @@ -292,38 +295,54 @@ class MCPToolResolver: headers=mcp_config.headers, streamable=mcp_config.streamable, ) - server_name = self._extract_server_name(mcp_config.url) + server_name = MCPToolResolver._extract_server_name(mcp_config.url) elif isinstance(mcp_config, MCPServerSSE): transport = SSETransport( url=mcp_config.url, headers=mcp_config.headers, ) - server_name = self._extract_server_name(mcp_config.url) + server_name = MCPToolResolver._extract_server_name(mcp_config.url) else: raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}") + return transport, server_name - client = MCPClient( - transport=transport, + def _resolve_native( + self, mcp_config: MCPServerConfig + ) -> tuple[list[BaseTool], list[Any]]: + """Resolve an ``MCPServerConfig`` into tools. + + Returns ``(tools, clients)`` where *clients* is always empty for + native tools (clients are now created on-demand per invocation). + A ``client_factory`` closure is passed to each ``MCPNativeTool`` so + every call -- even concurrent calls to the *same* tool -- gets its + own ``MCPClient`` + transport with no shared mutable state. + """ + from crewai.tools.base_tool import BaseTool + from crewai.tools.mcp_native_tool import MCPNativeTool + + discovery_transport, server_name = self._create_transport(mcp_config) + discovery_client = MCPClient( + transport=discovery_transport, cache_tools_list=mcp_config.cache_tools_list, ) async def _setup_client_and_list_tools() -> list[dict[str, Any]]: try: - if not client.connected: - await client.connect() + if not discovery_client.connected: + await discovery_client.connect() - tools_list = await client.list_tools() + tools_list = await discovery_client.list_tools() try: - await client.disconnect() + await discovery_client.disconnect() await asyncio.sleep(0.1) except Exception as e: self._logger.log("error", f"Error during disconnect: {e}") return tools_list except Exception as e: - if client.connected: - await client.disconnect() + if discovery_client.connected: + await discovery_client.disconnect() await asyncio.sleep(0.1) raise RuntimeError( f"Error during setup client and list tools: {e}" @@ -376,6 +395,13 @@ class MCPToolResolver: filtered_tools.append(tool) tools_list = filtered_tools + def _client_factory() -> MCPClient: + transport, _ = self._create_transport(mcp_config) + return MCPClient( + transport=transport, + cache_tools_list=mcp_config.cache_tools_list, + ) + tools = [] for tool_def in tools_list: tool_name = tool_def.get("name", "") @@ -396,7 +422,7 @@ class MCPToolResolver: try: native_tool = MCPNativeTool( - mcp_client=client, + client_factory=_client_factory, tool_name=tool_name, tool_schema=tool_schema, server_name=server_name, @@ -407,10 +433,10 @@ class MCPToolResolver: self._logger.log("error", f"Failed to create native MCP tool: {e}") continue - return cast(list[BaseTool], tools), client + return cast(list[BaseTool], tools), [] except Exception as e: - if client.connected: - asyncio.run(client.disconnect()) + if discovery_client.connected: + asyncio.run(discovery_client.disconnect()) raise RuntimeError(f"Failed to get native MCP tools: {e}") from e diff --git a/lib/crewai/src/crewai/tools/mcp_native_tool.py b/lib/crewai/src/crewai/tools/mcp_native_tool.py index d14c26a5a..dec365d58 100644 --- a/lib/crewai/src/crewai/tools/mcp_native_tool.py +++ b/lib/crewai/src/crewai/tools/mcp_native_tool.py @@ -1,29 +1,30 @@ """Native MCP tool wrapper for CrewAI agents. -This module provides a tool wrapper that reuses existing MCP client sessions -for better performance and connection management. +This module provides a tool wrapper that creates a fresh MCP client for every +invocation, ensuring safe parallel execution even when the same tool is called +concurrently by the executor. """ import asyncio +from collections.abc import Callable from typing import Any from crewai.tools import BaseTool class MCPNativeTool(BaseTool): - """Native MCP tool that reuses client sessions. + """Native MCP tool that creates a fresh client per invocation. - This tool wrapper is used when agents connect to MCP servers using - structured configurations. It reuses existing client sessions for - better performance and proper connection lifecycle management. - - Unlike MCPToolWrapper which connects on-demand, this tool uses - a shared MCP client instance that maintains a persistent connection. + A ``client_factory`` callable produces an independent ``MCPClient`` + + transport for every ``_run_async`` call. This guarantees that parallel + invocations -- whether of the *same* tool or *different* tools from the + same server -- never share mutable connection state (which would cause + anyio cancel-scope errors). """ def __init__( self, - mcp_client: Any, + client_factory: Callable[[], Any], tool_name: str, tool_schema: dict[str, Any], server_name: str, @@ -32,19 +33,16 @@ class MCPNativeTool(BaseTool): """Initialize native MCP tool. Args: - mcp_client: MCPClient instance with active session. + client_factory: Zero-arg callable that returns a new MCPClient. tool_name: Name of the tool (may be prefixed). tool_schema: Schema information for the tool. server_name: Name of the MCP server for prefixing. original_tool_name: Original name of the tool on the MCP server. """ - # Create tool name with server prefix to avoid conflicts prefixed_name = f"{server_name}_{tool_name}" - # Handle args_schema properly - BaseTool expects a BaseModel subclass args_schema = tool_schema.get("args_schema") - # Only pass args_schema if it's provided kwargs = { "name": prefixed_name, "description": tool_schema.get( @@ -57,16 +55,9 @@ class MCPNativeTool(BaseTool): super().__init__(**kwargs) - # Set instance attributes after super().__init__ - self._mcp_client = mcp_client + self._client_factory = client_factory self._original_tool_name = original_tool_name or tool_name self._server_name = server_name - # self._logger = logging.getLogger(__name__) - - @property - def mcp_client(self) -> Any: - """Get the MCP client instance.""" - return self._mcp_client @property def original_tool_name(self) -> str: @@ -108,51 +99,26 @@ class MCPNativeTool(BaseTool): async def _run_async(self, **kwargs) -> str: """Async implementation of tool execution. + A fresh ``MCPClient`` is created for every invocation so that + concurrent calls never share transport or session state. + Args: **kwargs: Arguments to pass to the MCP tool. Returns: Result from the MCP tool execution. """ - # Note: Since we use asyncio.run() which creates a new event loop each time, - # Always reconnect on-demand because asyncio.run() creates new event loops per call - # All MCP transport context managers (stdio, streamablehttp_client, sse_client) - # use anyio.create_task_group() which can't span different event loops - if self._mcp_client.connected: - await self._mcp_client.disconnect() - - await self._mcp_client.connect() + client = self._client_factory() + await client.connect() try: - result = await self._mcp_client.call_tool(self.original_tool_name, kwargs) - - except Exception as e: - error_str = str(e).lower() - if ( - "not connected" in error_str - or "connection" in error_str - or "send" in error_str - ): - await self._mcp_client.disconnect() - await self._mcp_client.connect() - # Retry the call - result = await self._mcp_client.call_tool( - self.original_tool_name, kwargs - ) - else: - raise - + result = await client.call_tool(self.original_tool_name, kwargs) finally: - # Always disconnect after tool call to ensure clean context manager lifecycle - # This prevents "exit cancel scope in different task" errors - # All transport context managers must be exited in the same event loop they were entered - await self._mcp_client.disconnect() + await client.disconnect() - # Extract result content if isinstance(result, str): return result - # Handle various result formats if hasattr(result, "content") and result.content: if isinstance(result.content, list) and len(result.content) > 0: content_item = result.content[0] diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index 025bfd334..4f6a84602 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -2353,3 +2353,68 @@ def test_agent_without_apps_no_platform_tools(): tools = crew._prepare_tools(agent, task, []) assert tools == [] + + +def test_agent_mcps_accepts_slug_with_specific_tool(): + """Agent(mcps=["notion#get_page"]) must pass validation (_SLUG_RE).""" + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["notion#get_page"], + ) + assert agent.mcps == ["notion#get_page"] + + +def test_agent_mcps_accepts_slug_with_hyphenated_tool(): + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["notion#get-page"], + ) + assert agent.mcps == ["notion#get-page"] + + +def test_agent_mcps_accepts_multiple_hash_refs(): + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["notion#get_page", "notion#search", "github#list_repos"], + ) + assert len(agent.mcps) == 3 + + +def test_agent_mcps_accepts_mixed_ref_types(): + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=[ + "notion#get_page", + "notion", + "https://mcp.example.com/api", + ], + ) + assert len(agent.mcps) == 3 + + +def test_agent_mcps_rejects_hash_without_slug(): + with pytest.raises(ValueError, match="Invalid MCP reference"): + Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["#get_page"], + ) + + +def test_agent_mcps_accepts_legacy_prefix_with_tool(): + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["crewai-amp:notion#get_page"], + ) + assert agent.mcps == ["crewai-amp:notion#get_page"] diff --git a/lib/crewai/tests/mcp/test_amp_mcp.py b/lib/crewai/tests/mcp/test_amp_mcp.py index 3c4001f3c..f13484a8d 100644 --- a/lib/crewai/tests/mcp/test_amp_mcp.py +++ b/lib/crewai/tests/mcp/test_amp_mcp.py @@ -268,6 +268,54 @@ class TestGetMCPToolsAmpIntegration: assert len(tools) == 1 assert tools[0].name == "mcp_notion_so_sse_search" + @patch("crewai.mcp.tool_resolver.MCPClient") + @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") + def test_tool_filter_with_hyphenated_hash_syntax( + self, mock_fetch, mock_client_class, agent + ): + """notion#get-page must match the tool whose sanitized name is get_page.""" + mock_fetch.return_value = { + "notion": { + "type": "sse", + "url": "https://mcp.notion.so/sse", + "headers": {"Authorization": "Bearer token"}, + }, + } + + hyphenated_tool_definitions = [ + { + "name": "get_page", + "original_name": "get-page", + "description": "Get a page", + "inputSchema": {}, + }, + { + "name": "search", + "original_name": "search", + "description": "Search tool", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"], + }, + }, + ] + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=hyphenated_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools(["notion#get-page"]) + + mock_fetch.assert_called_once_with(["notion"]) + assert len(tools) == 1 + assert tools[0].name.endswith("_get_page") + @patch("crewai.mcp.tool_resolver.MCPClient") @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") def test_deduplicates_slugs( @@ -371,3 +419,87 @@ class TestGetMCPToolsAmpIntegration: mock_external.assert_called_once_with("https://external.mcp.com/api") # 2 from notion + 1 from external + 2 from http_config assert len(tools) == 5 + + +class TestResolveExternalToolFilter: + """Tests for _resolve_external with #tool-name filtering.""" + + @pytest.fixture + def agent(self): + return Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + ) + + @pytest.fixture + def resolver(self, agent): + return MCPToolResolver(agent=agent, logger=agent._logger) + + @patch.object(MCPToolResolver, "_get_mcp_tool_schemas") + def test_filters_hyphenated_tool_name(self, mock_schemas, resolver): + """https://...#get-page must match the sanitized key get_page in schemas.""" + mock_schemas.return_value = { + "get_page": { + "description": "Get a page", + "args_schema": None, + }, + "search": { + "description": "Search tool", + "args_schema": None, + }, + } + + tools = resolver._resolve_external("https://mcp.example.com/api#get-page") + + assert len(tools) == 1 + assert "get_page" in tools[0].name + + @patch.object(MCPToolResolver, "_get_mcp_tool_schemas") + def test_filters_underscored_tool_name(self, mock_schemas, resolver): + """https://...#get_page must also match the sanitized key get_page.""" + mock_schemas.return_value = { + "get_page": { + "description": "Get a page", + "args_schema": None, + }, + "search": { + "description": "Search tool", + "args_schema": None, + }, + } + + tools = resolver._resolve_external("https://mcp.example.com/api#get_page") + + assert len(tools) == 1 + assert "get_page" in tools[0].name + + @patch.object(MCPToolResolver, "_get_mcp_tool_schemas") + def test_returns_all_tools_without_hash(self, mock_schemas, resolver): + mock_schemas.return_value = { + "get_page": { + "description": "Get a page", + "args_schema": None, + }, + "search": { + "description": "Search tool", + "args_schema": None, + }, + } + + tools = resolver._resolve_external("https://mcp.example.com/api") + + assert len(tools) == 2 + + @patch.object(MCPToolResolver, "_get_mcp_tool_schemas") + def test_returns_empty_for_nonexistent_tool(self, mock_schemas, resolver): + mock_schemas.return_value = { + "search": { + "description": "Search tool", + "args_schema": None, + }, + } + + tools = resolver._resolve_external("https://mcp.example.com/api#nonexistent") + + assert len(tools) == 0 diff --git a/lib/crewai/tests/mcp/test_mcp_config.py b/lib/crewai/tests/mcp/test_mcp_config.py index 24fc59769..ce123be6b 100644 --- a/lib/crewai/tests/mcp/test_mcp_config.py +++ b/lib/crewai/tests/mcp/test_mcp_config.py @@ -1,4 +1,5 @@ import asyncio +import concurrent.futures from unittest.mock import AsyncMock, patch import pytest @@ -30,6 +31,17 @@ def mock_tool_definitions(): ] +def _make_mock_client(tool_definitions): + """Create a mock MCPClient that returns *tool_definitions*.""" + client = AsyncMock() + client.list_tools = AsyncMock(return_value=tool_definitions) + client.connected = False + client.connect = AsyncMock() + client.disconnect = AsyncMock() + client.call_tool = AsyncMock(return_value="test result") + return client + + def test_agent_with_stdio_mcp_config(mock_tool_definitions): """Test agent setup with MCPServerStdio configuration.""" stdio_config = MCPServerStdio( @@ -45,14 +57,8 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions): mcps=[stdio_config], ) - with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False # Will trigger connect - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) tools = agent.get_mcp_tools([stdio_config]) @@ -60,8 +66,7 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions): assert all(isinstance(tool, BaseTool) for tool in tools) mock_client_class.assert_called_once() - call_args = mock_client_class.call_args - transport = call_args.kwargs["transport"] + transport = mock_client_class.call_args.kwargs["transport"] assert transport.command == "python" assert transport.args == ["server.py"] assert transport.env == {"API_KEY": "test_key"} @@ -83,12 +88,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions): ) with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False # Will trigger connect - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) tools = agent.get_mcp_tools([http_config]) @@ -96,8 +96,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions): assert all(isinstance(tool, BaseTool) for tool in tools) mock_client_class.assert_called_once() - call_args = mock_client_class.call_args - transport = call_args.kwargs["transport"] + transport = mock_client_class.call_args.kwargs["transport"] assert transport.url == "https://api.example.com/mcp" assert transport.headers == {"Authorization": "Bearer test_token"} assert transport.streamable is True @@ -118,12 +117,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions): ) with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) tools = agent.get_mcp_tools([sse_config]) @@ -131,8 +125,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions): assert all(isinstance(tool, BaseTool) for tool in tools) mock_client_class.assert_called_once() - call_args = mock_client_class.call_args - transport = call_args.kwargs["transport"] + transport = mock_client_class.call_args.kwargs["transport"] assert transport.url == "https://api.example.com/mcp/sse" assert transport.headers == {"Authorization": "Bearer test_token"} @@ -142,13 +135,7 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions): http_config = MCPServerHTTP(url="https://api.example.com/mcp") with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client.call_tool = AsyncMock(return_value="test result") - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) agent = Agent( role="Test Agent", @@ -160,12 +147,12 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions): tools = agent.get_mcp_tools([http_config]) assert len(tools) == 2 - tool = tools[0] result = tool.run(query="test query") assert result == "test result" - mock_client.call_tool.assert_called() + # 1 discovery + 1 for the run() invocation + assert mock_client_class.call_count == 2 @pytest.mark.asyncio @@ -174,13 +161,7 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions): http_config = MCPServerHTTP(url="https://api.example.com/mcp") with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client.call_tool = AsyncMock(return_value="test result") - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) agent = Agent( role="Test Agent", @@ -192,9 +173,129 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions): tools = agent.get_mcp_tools([http_config]) assert len(tools) == 2 - tool = tools[0] result = tool.run(query="test query") assert result == "test result" - mock_client.call_tool.assert_called() + assert mock_client_class.call_count == 2 + + +def test_each_invocation_gets_fresh_client(mock_tool_definitions): + """Every tool.run() must create its own MCPClient (no shared state).""" + http_config = MCPServerHTTP(url="https://api.example.com/mcp") + + clients_created: list = [] + + def _make_client(**kwargs): + client = _make_mock_client(mock_tool_definitions) + clients_created.append(client) + return client + + with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client): + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + tools = agent.get_mcp_tools([http_config]) + assert len(tools) == 2 + # 1 discovery client so far + assert len(clients_created) == 1 + + # Two sequential calls to the same tool must create 2 new clients + tools[0].run(query="q1") + tools[0].run(query="q2") + assert len(clients_created) == 3 + assert clients_created[1] is not clients_created[2] + + +def test_parallel_mcp_tool_execution_same_tool(mock_tool_definitions): + """Parallel calls to the *same* tool must not interfere.""" + http_config = MCPServerHTTP(url="https://api.example.com/mcp") + + call_log: list[str] = [] + + def _make_client(**kwargs): + client = AsyncMock() + client.list_tools = AsyncMock(return_value=mock_tool_definitions) + client.connected = False + client.connect = AsyncMock() + client.disconnect = AsyncMock() + + async def _call_tool(name, args): + call_log.append(name) + await asyncio.sleep(0.05) + return f"result-{name}" + + client.call_tool = AsyncMock(side_effect=_call_tool) + return client + + with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client): + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + tools = agent.get_mcp_tools([http_config]) + assert len(tools) >= 1 + tool = tools[0] + + # Call the SAME tool concurrently -- the exact scenario from the bug + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [ + pool.submit(tool.run, query="q1"), + pool.submit(tool.run, query="q2"), + ] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + assert len(results) == 2 + assert all("result-" in r for r in results) + assert len(call_log) == 2 + + +def test_parallel_mcp_tool_execution_different_tools(mock_tool_definitions): + """Parallel calls to different tools from the same server must not interfere.""" + http_config = MCPServerHTTP(url="https://api.example.com/mcp") + + call_log: list[str] = [] + + def _make_client(**kwargs): + client = AsyncMock() + client.list_tools = AsyncMock(return_value=mock_tool_definitions) + client.connected = False + client.connect = AsyncMock() + client.disconnect = AsyncMock() + + async def _call_tool(name, args): + call_log.append(name) + await asyncio.sleep(0.05) + return f"result-{name}" + + client.call_tool = AsyncMock(side_effect=_call_tool) + return client + + with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client): + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + tools = agent.get_mcp_tools([http_config]) + assert len(tools) == 2 + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [ + pool.submit(tools[0].run, query="q1"), + pool.submit(tools[1].run, query="q2"), + ] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + assert len(results) == 2 + assert all("result-" in r for r in results) + assert len(call_log) == 2