diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index 8b2b9737c..da32d9c1c 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -38,7 +38,7 @@ from crewai.utilities.string_utils import interpolate_only _SLUG_RE: Final[re.Pattern[str]] = re.compile( - r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$" + r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$" ) diff --git a/lib/crewai/src/crewai/mcp/tool_resolver.py b/lib/crewai/src/crewai/mcp/tool_resolver.py index 34af189f2..c0428f82d 100644 --- a/lib/crewai/src/crewai/mcp/tool_resolver.py +++ b/lib/crewai/src/crewai/mcp/tool_resolver.py @@ -22,6 +22,7 @@ from crewai.mcp.config import ( MCPServerSSE, MCPServerStdio, ) +from crewai.utilities.string_utils import sanitize_tool_name from crewai.mcp.transports.http import HTTPTransport from crewai.mcp.transports.sse import SSETransport from crewai.mcp.transports.stdio import StdioTransport @@ -74,10 +75,9 @@ class MCPToolResolver: elif isinstance(mcp_config, str): amp_refs.append(self._parse_amp_ref(mcp_config)) else: - tools, client = self._resolve_native(mcp_config) + tools, clients = self._resolve_native(mcp_config) all_tools.extend(tools) - if client: - self._clients.append(client) + self._clients.extend(clients) if amp_refs: tools, clients = self._resolve_amp(amp_refs) @@ -131,7 +131,7 @@ class MCPToolResolver: all_tools: list[BaseTool] = [] all_clients: list[Any] = [] - resolved_cache: dict[str, tuple[list[BaseTool], Any | None]] = {} + resolved_cache: dict[str, tuple[list[BaseTool], list[Any]]] = {} for slug in unique_slugs: config_dict = amp_configs_map.get(slug) @@ -149,10 +149,9 @@ class MCPToolResolver: mcp_server_config = self._build_mcp_config_from_dict(config_dict) try: - tools, client = self._resolve_native(mcp_server_config) - resolved_cache[slug] = (tools, client) - if client: - all_clients.append(client) + tools, clients = self._resolve_native(mcp_server_config) + resolved_cache[slug] = (tools, clients) + all_clients.extend(clients) except Exception as e: crewai_event_bus.emit( self, @@ -170,8 +169,9 @@ class MCPToolResolver: slug_tools, _ = cached if specific_tool: + sanitized = sanitize_tool_name(specific_tool) all_tools.extend( - t for t in slug_tools if t.name.endswith(f"_{specific_tool}") + t for t in slug_tools if t.name.endswith(f"_{sanitized}") ) else: all_tools.extend(slug_tools) @@ -198,7 +198,6 @@ class MCPToolResolver: plus_api = PlusAPI(api_key=get_platform_integration_token()) response = plus_api.get_mcp_configs(slugs) - if response.status_code == 200: configs: dict[str, dict[str, Any]] = response.json().get("configs", {}) return configs @@ -218,6 +217,7 @@ class MCPToolResolver: def _resolve_external(self, mcp_ref: str) -> list[BaseTool]: """Resolve an HTTPS MCP server URL into tools.""" + from crewai.tools.base_tool import BaseTool from crewai.tools.mcp_tool_wrapper import MCPToolWrapper if "#" in mcp_ref: @@ -227,6 +227,7 @@ class MCPToolResolver: server_params = {"url": server_url} server_name = self._extract_server_name(server_url) + sanitized_specific_tool = sanitize_tool_name(specific_tool) if specific_tool else None try: tool_schemas = self._get_mcp_tool_schemas(server_params) @@ -239,7 +240,7 @@ class MCPToolResolver: tools = [] for tool_name, schema in tool_schemas.items(): - if specific_tool and tool_name != specific_tool: + if sanitized_specific_tool and tool_name != sanitized_specific_tool: continue try: @@ -271,14 +272,16 @@ class MCPToolResolver: ) return [] - def _resolve_native( - self, mcp_config: MCPServerConfig - ) -> tuple[list[BaseTool], Any | None]: - """Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup.""" - from crewai.tools.base_tool import BaseTool - from crewai.tools.mcp_native_tool import MCPNativeTool + @staticmethod + def _create_transport( + mcp_config: MCPServerConfig, + ) -> tuple[StdioTransport | HTTPTransport | SSETransport, str]: + """Create a fresh transport instance from an MCP server config. - transport: StdioTransport | HTTPTransport | SSETransport + Returns a ``(transport, server_name)`` tuple. Each call produces an + independent transport so that parallel tool executions never share + state. + """ if isinstance(mcp_config, MCPServerStdio): transport = StdioTransport( command=mcp_config.command, @@ -292,38 +295,54 @@ class MCPToolResolver: headers=mcp_config.headers, streamable=mcp_config.streamable, ) - server_name = self._extract_server_name(mcp_config.url) + server_name = MCPToolResolver._extract_server_name(mcp_config.url) elif isinstance(mcp_config, MCPServerSSE): transport = SSETransport( url=mcp_config.url, headers=mcp_config.headers, ) - server_name = self._extract_server_name(mcp_config.url) + server_name = MCPToolResolver._extract_server_name(mcp_config.url) else: raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}") + return transport, server_name - client = MCPClient( - transport=transport, + def _resolve_native( + self, mcp_config: MCPServerConfig + ) -> tuple[list[BaseTool], list[Any]]: + """Resolve an ``MCPServerConfig`` into tools. + + Returns ``(tools, clients)`` where *clients* is always empty for + native tools (clients are now created on-demand per invocation). + A ``client_factory`` closure is passed to each ``MCPNativeTool`` so + every call -- even concurrent calls to the *same* tool -- gets its + own ``MCPClient`` + transport with no shared mutable state. + """ + from crewai.tools.base_tool import BaseTool + from crewai.tools.mcp_native_tool import MCPNativeTool + + discovery_transport, server_name = self._create_transport(mcp_config) + discovery_client = MCPClient( + transport=discovery_transport, cache_tools_list=mcp_config.cache_tools_list, ) async def _setup_client_and_list_tools() -> list[dict[str, Any]]: try: - if not client.connected: - await client.connect() + if not discovery_client.connected: + await discovery_client.connect() - tools_list = await client.list_tools() + tools_list = await discovery_client.list_tools() try: - await client.disconnect() + await discovery_client.disconnect() await asyncio.sleep(0.1) except Exception as e: self._logger.log("error", f"Error during disconnect: {e}") return tools_list except Exception as e: - if client.connected: - await client.disconnect() + if discovery_client.connected: + await discovery_client.disconnect() await asyncio.sleep(0.1) raise RuntimeError( f"Error during setup client and list tools: {e}" @@ -376,6 +395,13 @@ class MCPToolResolver: filtered_tools.append(tool) tools_list = filtered_tools + def _client_factory() -> MCPClient: + transport, _ = self._create_transport(mcp_config) + return MCPClient( + transport=transport, + cache_tools_list=mcp_config.cache_tools_list, + ) + tools = [] for tool_def in tools_list: tool_name = tool_def.get("name", "") @@ -396,7 +422,7 @@ class MCPToolResolver: try: native_tool = MCPNativeTool( - mcp_client=client, + client_factory=_client_factory, tool_name=tool_name, tool_schema=tool_schema, server_name=server_name, @@ -407,10 +433,10 @@ class MCPToolResolver: self._logger.log("error", f"Failed to create native MCP tool: {e}") continue - return cast(list[BaseTool], tools), client + return cast(list[BaseTool], tools), [] except Exception as e: - if client.connected: - asyncio.run(client.disconnect()) + if discovery_client.connected: + asyncio.run(discovery_client.disconnect()) raise RuntimeError(f"Failed to get native MCP tools: {e}") from e diff --git a/lib/crewai/src/crewai/tools/mcp_native_tool.py b/lib/crewai/src/crewai/tools/mcp_native_tool.py index d14c26a5a..dec365d58 100644 --- a/lib/crewai/src/crewai/tools/mcp_native_tool.py +++ b/lib/crewai/src/crewai/tools/mcp_native_tool.py @@ -1,29 +1,30 @@ """Native MCP tool wrapper for CrewAI agents. -This module provides a tool wrapper that reuses existing MCP client sessions -for better performance and connection management. +This module provides a tool wrapper that creates a fresh MCP client for every +invocation, ensuring safe parallel execution even when the same tool is called +concurrently by the executor. """ import asyncio +from collections.abc import Callable from typing import Any from crewai.tools import BaseTool class MCPNativeTool(BaseTool): - """Native MCP tool that reuses client sessions. + """Native MCP tool that creates a fresh client per invocation. - This tool wrapper is used when agents connect to MCP servers using - structured configurations. It reuses existing client sessions for - better performance and proper connection lifecycle management. - - Unlike MCPToolWrapper which connects on-demand, this tool uses - a shared MCP client instance that maintains a persistent connection. + A ``client_factory`` callable produces an independent ``MCPClient`` + + transport for every ``_run_async`` call. This guarantees that parallel + invocations -- whether of the *same* tool or *different* tools from the + same server -- never share mutable connection state (which would cause + anyio cancel-scope errors). """ def __init__( self, - mcp_client: Any, + client_factory: Callable[[], Any], tool_name: str, tool_schema: dict[str, Any], server_name: str, @@ -32,19 +33,16 @@ class MCPNativeTool(BaseTool): """Initialize native MCP tool. Args: - mcp_client: MCPClient instance with active session. + client_factory: Zero-arg callable that returns a new MCPClient. tool_name: Name of the tool (may be prefixed). tool_schema: Schema information for the tool. server_name: Name of the MCP server for prefixing. original_tool_name: Original name of the tool on the MCP server. """ - # Create tool name with server prefix to avoid conflicts prefixed_name = f"{server_name}_{tool_name}" - # Handle args_schema properly - BaseTool expects a BaseModel subclass args_schema = tool_schema.get("args_schema") - # Only pass args_schema if it's provided kwargs = { "name": prefixed_name, "description": tool_schema.get( @@ -57,16 +55,9 @@ class MCPNativeTool(BaseTool): super().__init__(**kwargs) - # Set instance attributes after super().__init__ - self._mcp_client = mcp_client + self._client_factory = client_factory self._original_tool_name = original_tool_name or tool_name self._server_name = server_name - # self._logger = logging.getLogger(__name__) - - @property - def mcp_client(self) -> Any: - """Get the MCP client instance.""" - return self._mcp_client @property def original_tool_name(self) -> str: @@ -108,51 +99,26 @@ class MCPNativeTool(BaseTool): async def _run_async(self, **kwargs) -> str: """Async implementation of tool execution. + A fresh ``MCPClient`` is created for every invocation so that + concurrent calls never share transport or session state. + Args: **kwargs: Arguments to pass to the MCP tool. Returns: Result from the MCP tool execution. """ - # Note: Since we use asyncio.run() which creates a new event loop each time, - # Always reconnect on-demand because asyncio.run() creates new event loops per call - # All MCP transport context managers (stdio, streamablehttp_client, sse_client) - # use anyio.create_task_group() which can't span different event loops - if self._mcp_client.connected: - await self._mcp_client.disconnect() - - await self._mcp_client.connect() + client = self._client_factory() + await client.connect() try: - result = await self._mcp_client.call_tool(self.original_tool_name, kwargs) - - except Exception as e: - error_str = str(e).lower() - if ( - "not connected" in error_str - or "connection" in error_str - or "send" in error_str - ): - await self._mcp_client.disconnect() - await self._mcp_client.connect() - # Retry the call - result = await self._mcp_client.call_tool( - self.original_tool_name, kwargs - ) - else: - raise - + result = await client.call_tool(self.original_tool_name, kwargs) finally: - # Always disconnect after tool call to ensure clean context manager lifecycle - # This prevents "exit cancel scope in different task" errors - # All transport context managers must be exited in the same event loop they were entered - await self._mcp_client.disconnect() + await client.disconnect() - # Extract result content if isinstance(result, str): return result - # Handle various result formats if hasattr(result, "content") and result.content: if isinstance(result.content, list) and len(result.content) > 0: content_item = result.content[0] diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index 025bfd334..4f6a84602 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -2353,3 +2353,68 @@ def test_agent_without_apps_no_platform_tools(): tools = crew._prepare_tools(agent, task, []) assert tools == [] + + +def test_agent_mcps_accepts_slug_with_specific_tool(): + """Agent(mcps=["notion#get_page"]) must pass validation (_SLUG_RE).""" + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["notion#get_page"], + ) + assert agent.mcps == ["notion#get_page"] + + +def test_agent_mcps_accepts_slug_with_hyphenated_tool(): + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["notion#get-page"], + ) + assert agent.mcps == ["notion#get-page"] + + +def test_agent_mcps_accepts_multiple_hash_refs(): + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["notion#get_page", "notion#search", "github#list_repos"], + ) + assert len(agent.mcps) == 3 + + +def test_agent_mcps_accepts_mixed_ref_types(): + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=[ + "notion#get_page", + "notion", + "https://mcp.example.com/api", + ], + ) + assert len(agent.mcps) == 3 + + +def test_agent_mcps_rejects_hash_without_slug(): + with pytest.raises(ValueError, match="Invalid MCP reference"): + Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["#get_page"], + ) + + +def test_agent_mcps_accepts_legacy_prefix_with_tool(): + agent = Agent( + role="MCP Agent", + goal="Test MCP validation", + backstory="Test agent", + mcps=["crewai-amp:notion#get_page"], + ) + assert agent.mcps == ["crewai-amp:notion#get_page"] diff --git a/lib/crewai/tests/mcp/test_amp_mcp.py b/lib/crewai/tests/mcp/test_amp_mcp.py index 3c4001f3c..f13484a8d 100644 --- a/lib/crewai/tests/mcp/test_amp_mcp.py +++ b/lib/crewai/tests/mcp/test_amp_mcp.py @@ -268,6 +268,54 @@ class TestGetMCPToolsAmpIntegration: assert len(tools) == 1 assert tools[0].name == "mcp_notion_so_sse_search" + @patch("crewai.mcp.tool_resolver.MCPClient") + @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") + def test_tool_filter_with_hyphenated_hash_syntax( + self, mock_fetch, mock_client_class, agent + ): + """notion#get-page must match the tool whose sanitized name is get_page.""" + mock_fetch.return_value = { + "notion": { + "type": "sse", + "url": "https://mcp.notion.so/sse", + "headers": {"Authorization": "Bearer token"}, + }, + } + + hyphenated_tool_definitions = [ + { + "name": "get_page", + "original_name": "get-page", + "description": "Get a page", + "inputSchema": {}, + }, + { + "name": "search", + "original_name": "search", + "description": "Search tool", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"], + }, + }, + ] + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=hyphenated_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools(["notion#get-page"]) + + mock_fetch.assert_called_once_with(["notion"]) + assert len(tools) == 1 + assert tools[0].name.endswith("_get_page") + @patch("crewai.mcp.tool_resolver.MCPClient") @patch.object(MCPToolResolver, "_fetch_amp_mcp_configs") def test_deduplicates_slugs( @@ -371,3 +419,87 @@ class TestGetMCPToolsAmpIntegration: mock_external.assert_called_once_with("https://external.mcp.com/api") # 2 from notion + 1 from external + 2 from http_config assert len(tools) == 5 + + +class TestResolveExternalToolFilter: + """Tests for _resolve_external with #tool-name filtering.""" + + @pytest.fixture + def agent(self): + return Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + ) + + @pytest.fixture + def resolver(self, agent): + return MCPToolResolver(agent=agent, logger=agent._logger) + + @patch.object(MCPToolResolver, "_get_mcp_tool_schemas") + def test_filters_hyphenated_tool_name(self, mock_schemas, resolver): + """https://...#get-page must match the sanitized key get_page in schemas.""" + mock_schemas.return_value = { + "get_page": { + "description": "Get a page", + "args_schema": None, + }, + "search": { + "description": "Search tool", + "args_schema": None, + }, + } + + tools = resolver._resolve_external("https://mcp.example.com/api#get-page") + + assert len(tools) == 1 + assert "get_page" in tools[0].name + + @patch.object(MCPToolResolver, "_get_mcp_tool_schemas") + def test_filters_underscored_tool_name(self, mock_schemas, resolver): + """https://...#get_page must also match the sanitized key get_page.""" + mock_schemas.return_value = { + "get_page": { + "description": "Get a page", + "args_schema": None, + }, + "search": { + "description": "Search tool", + "args_schema": None, + }, + } + + tools = resolver._resolve_external("https://mcp.example.com/api#get_page") + + assert len(tools) == 1 + assert "get_page" in tools[0].name + + @patch.object(MCPToolResolver, "_get_mcp_tool_schemas") + def test_returns_all_tools_without_hash(self, mock_schemas, resolver): + mock_schemas.return_value = { + "get_page": { + "description": "Get a page", + "args_schema": None, + }, + "search": { + "description": "Search tool", + "args_schema": None, + }, + } + + tools = resolver._resolve_external("https://mcp.example.com/api") + + assert len(tools) == 2 + + @patch.object(MCPToolResolver, "_get_mcp_tool_schemas") + def test_returns_empty_for_nonexistent_tool(self, mock_schemas, resolver): + mock_schemas.return_value = { + "search": { + "description": "Search tool", + "args_schema": None, + }, + } + + tools = resolver._resolve_external("https://mcp.example.com/api#nonexistent") + + assert len(tools) == 0 diff --git a/lib/crewai/tests/mcp/test_mcp_config.py b/lib/crewai/tests/mcp/test_mcp_config.py index 24fc59769..ce123be6b 100644 --- a/lib/crewai/tests/mcp/test_mcp_config.py +++ b/lib/crewai/tests/mcp/test_mcp_config.py @@ -1,4 +1,5 @@ import asyncio +import concurrent.futures from unittest.mock import AsyncMock, patch import pytest @@ -30,6 +31,17 @@ def mock_tool_definitions(): ] +def _make_mock_client(tool_definitions): + """Create a mock MCPClient that returns *tool_definitions*.""" + client = AsyncMock() + client.list_tools = AsyncMock(return_value=tool_definitions) + client.connected = False + client.connect = AsyncMock() + client.disconnect = AsyncMock() + client.call_tool = AsyncMock(return_value="test result") + return client + + def test_agent_with_stdio_mcp_config(mock_tool_definitions): """Test agent setup with MCPServerStdio configuration.""" stdio_config = MCPServerStdio( @@ -45,14 +57,8 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions): mcps=[stdio_config], ) - with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False # Will trigger connect - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) tools = agent.get_mcp_tools([stdio_config]) @@ -60,8 +66,7 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions): assert all(isinstance(tool, BaseTool) for tool in tools) mock_client_class.assert_called_once() - call_args = mock_client_class.call_args - transport = call_args.kwargs["transport"] + transport = mock_client_class.call_args.kwargs["transport"] assert transport.command == "python" assert transport.args == ["server.py"] assert transport.env == {"API_KEY": "test_key"} @@ -83,12 +88,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions): ) with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False # Will trigger connect - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) tools = agent.get_mcp_tools([http_config]) @@ -96,8 +96,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions): assert all(isinstance(tool, BaseTool) for tool in tools) mock_client_class.assert_called_once() - call_args = mock_client_class.call_args - transport = call_args.kwargs["transport"] + transport = mock_client_class.call_args.kwargs["transport"] assert transport.url == "https://api.example.com/mcp" assert transport.headers == {"Authorization": "Bearer test_token"} assert transport.streamable is True @@ -118,12 +117,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions): ) with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) tools = agent.get_mcp_tools([sse_config]) @@ -131,8 +125,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions): assert all(isinstance(tool, BaseTool) for tool in tools) mock_client_class.assert_called_once() - call_args = mock_client_class.call_args - transport = call_args.kwargs["transport"] + transport = mock_client_class.call_args.kwargs["transport"] assert transport.url == "https://api.example.com/mcp/sse" assert transport.headers == {"Authorization": "Bearer test_token"} @@ -142,13 +135,7 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions): http_config = MCPServerHTTP(url="https://api.example.com/mcp") with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client.call_tool = AsyncMock(return_value="test result") - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) agent = Agent( role="Test Agent", @@ -160,12 +147,12 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions): tools = agent.get_mcp_tools([http_config]) assert len(tools) == 2 - tool = tools[0] result = tool.run(query="test query") assert result == "test result" - mock_client.call_tool.assert_called() + # 1 discovery + 1 for the run() invocation + assert mock_client_class.call_count == 2 @pytest.mark.asyncio @@ -174,13 +161,7 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions): http_config = MCPServerHTTP(url="https://api.example.com/mcp") with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) - mock_client.connected = False - mock_client.connect = AsyncMock() - mock_client.disconnect = AsyncMock() - mock_client.call_tool = AsyncMock(return_value="test result") - mock_client_class.return_value = mock_client + mock_client_class.return_value = _make_mock_client(mock_tool_definitions) agent = Agent( role="Test Agent", @@ -192,9 +173,129 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions): tools = agent.get_mcp_tools([http_config]) assert len(tools) == 2 - tool = tools[0] result = tool.run(query="test query") assert result == "test result" - mock_client.call_tool.assert_called() + assert mock_client_class.call_count == 2 + + +def test_each_invocation_gets_fresh_client(mock_tool_definitions): + """Every tool.run() must create its own MCPClient (no shared state).""" + http_config = MCPServerHTTP(url="https://api.example.com/mcp") + + clients_created: list = [] + + def _make_client(**kwargs): + client = _make_mock_client(mock_tool_definitions) + clients_created.append(client) + return client + + with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client): + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + tools = agent.get_mcp_tools([http_config]) + assert len(tools) == 2 + # 1 discovery client so far + assert len(clients_created) == 1 + + # Two sequential calls to the same tool must create 2 new clients + tools[0].run(query="q1") + tools[0].run(query="q2") + assert len(clients_created) == 3 + assert clients_created[1] is not clients_created[2] + + +def test_parallel_mcp_tool_execution_same_tool(mock_tool_definitions): + """Parallel calls to the *same* tool must not interfere.""" + http_config = MCPServerHTTP(url="https://api.example.com/mcp") + + call_log: list[str] = [] + + def _make_client(**kwargs): + client = AsyncMock() + client.list_tools = AsyncMock(return_value=mock_tool_definitions) + client.connected = False + client.connect = AsyncMock() + client.disconnect = AsyncMock() + + async def _call_tool(name, args): + call_log.append(name) + await asyncio.sleep(0.05) + return f"result-{name}" + + client.call_tool = AsyncMock(side_effect=_call_tool) + return client + + with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client): + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + tools = agent.get_mcp_tools([http_config]) + assert len(tools) >= 1 + tool = tools[0] + + # Call the SAME tool concurrently -- the exact scenario from the bug + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [ + pool.submit(tool.run, query="q1"), + pool.submit(tool.run, query="q2"), + ] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + assert len(results) == 2 + assert all("result-" in r for r in results) + assert len(call_log) == 2 + + +def test_parallel_mcp_tool_execution_different_tools(mock_tool_definitions): + """Parallel calls to different tools from the same server must not interfere.""" + http_config = MCPServerHTTP(url="https://api.example.com/mcp") + + call_log: list[str] = [] + + def _make_client(**kwargs): + client = AsyncMock() + client.list_tools = AsyncMock(return_value=mock_tool_definitions) + client.connected = False + client.connect = AsyncMock() + client.disconnect = AsyncMock() + + async def _call_tool(name, args): + call_log.append(name) + await asyncio.sleep(0.05) + return f"result-{name}" + + client.call_tool = AsyncMock(side_effect=_call_tool) + return client + + with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client): + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + tools = agent.get_mcp_tools([http_config]) + assert len(tools) == 2 + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futures = [ + pool.submit(tools[0].run, query="q1"), + pool.submit(tools[1].run, query="q2"), + ] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + assert len(results) == 2 + assert all("result-" in r for r in results) + assert len(call_log) == 2