Address MCP tool-resolution issues & eliminate all shared mutable connection state (#4792)

* fix: allow hyphenated tool names in MCP references like notion#get-page

The _SLUG_RE regex on BaseAgent rejected MCP tool references containing
hyphens (e.g. "notion#get-page") because the fragment pattern only
matched \w (word characters).

* fix: create fresh MCP client per tool invocation to prevent parallel call races

When the LLM dispatches parallel calls to MCP tools on the same server, the executor runs them concurrently via ThreadPoolExecutor. Previously, all tools from a server shared a single MCPClient instance, and even the same tool called twice would reuse one client. Since each thread creates its own asyncio event loop via asyncio.run(), concurrent connect/disconnect calls on the shared client caused anyio cancel-scope errors ("Attempted to exit cancel scope in a different task than it was entered in").

The fix introduces a client_factory pattern: MCPNativeTool now receives a zero-arg callable that produces a fresh MCPClient + transport on every
_run_async() invocation. This eliminates all shared mutable connection state between concurrent calls, whether to the same tool or different tools from the same server.

* test: ensure we can filter a hyphenated MCP tool
This commit is contained in:
Lucas Gomide
2026-03-10 15:00:40 -03:00
committed by GitHub
parent 7cffcab84a
commit e72a80be6e
6 changed files with 420 additions and 130 deletions

View File

@@ -38,7 +38,7 @@ from crewai.utilities.string_utils import interpolate_only
_SLUG_RE: Final[re.Pattern[str]] = re.compile(
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#\w+)?$"
r"^(?:crewai-amp:)?[a-zA-Z0-9][a-zA-Z0-9_-]*(?:#[\w-]+)?$"
)

View File

@@ -22,6 +22,7 @@ from crewai.mcp.config import (
MCPServerSSE,
MCPServerStdio,
)
from crewai.utilities.string_utils import sanitize_tool_name
from crewai.mcp.transports.http import HTTPTransport
from crewai.mcp.transports.sse import SSETransport
from crewai.mcp.transports.stdio import StdioTransport
@@ -74,10 +75,9 @@ class MCPToolResolver:
elif isinstance(mcp_config, str):
amp_refs.append(self._parse_amp_ref(mcp_config))
else:
tools, client = self._resolve_native(mcp_config)
tools, clients = self._resolve_native(mcp_config)
all_tools.extend(tools)
if client:
self._clients.append(client)
self._clients.extend(clients)
if amp_refs:
tools, clients = self._resolve_amp(amp_refs)
@@ -131,7 +131,7 @@ class MCPToolResolver:
all_tools: list[BaseTool] = []
all_clients: list[Any] = []
resolved_cache: dict[str, tuple[list[BaseTool], Any | None]] = {}
resolved_cache: dict[str, tuple[list[BaseTool], list[Any]]] = {}
for slug in unique_slugs:
config_dict = amp_configs_map.get(slug)
@@ -149,10 +149,9 @@ class MCPToolResolver:
mcp_server_config = self._build_mcp_config_from_dict(config_dict)
try:
tools, client = self._resolve_native(mcp_server_config)
resolved_cache[slug] = (tools, client)
if client:
all_clients.append(client)
tools, clients = self._resolve_native(mcp_server_config)
resolved_cache[slug] = (tools, clients)
all_clients.extend(clients)
except Exception as e:
crewai_event_bus.emit(
self,
@@ -170,8 +169,9 @@ class MCPToolResolver:
slug_tools, _ = cached
if specific_tool:
sanitized = sanitize_tool_name(specific_tool)
all_tools.extend(
t for t in slug_tools if t.name.endswith(f"_{specific_tool}")
t for t in slug_tools if t.name.endswith(f"_{sanitized}")
)
else:
all_tools.extend(slug_tools)
@@ -198,7 +198,6 @@ class MCPToolResolver:
plus_api = PlusAPI(api_key=get_platform_integration_token())
response = plus_api.get_mcp_configs(slugs)
if response.status_code == 200:
configs: dict[str, dict[str, Any]] = response.json().get("configs", {})
return configs
@@ -218,6 +217,7 @@ class MCPToolResolver:
def _resolve_external(self, mcp_ref: str) -> list[BaseTool]:
"""Resolve an HTTPS MCP server URL into tools."""
from crewai.tools.base_tool import BaseTool
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
if "#" in mcp_ref:
@@ -227,6 +227,7 @@ class MCPToolResolver:
server_params = {"url": server_url}
server_name = self._extract_server_name(server_url)
sanitized_specific_tool = sanitize_tool_name(specific_tool) if specific_tool else None
try:
tool_schemas = self._get_mcp_tool_schemas(server_params)
@@ -239,7 +240,7 @@ class MCPToolResolver:
tools = []
for tool_name, schema in tool_schemas.items():
if specific_tool and tool_name != specific_tool:
if sanitized_specific_tool and tool_name != sanitized_specific_tool:
continue
try:
@@ -271,14 +272,16 @@ class MCPToolResolver:
)
return []
def _resolve_native(
self, mcp_config: MCPServerConfig
) -> tuple[list[BaseTool], Any | None]:
"""Resolve an ``MCPServerConfig`` into tools, returning the client for cleanup."""
from crewai.tools.base_tool import BaseTool
from crewai.tools.mcp_native_tool import MCPNativeTool
@staticmethod
def _create_transport(
mcp_config: MCPServerConfig,
) -> tuple[StdioTransport | HTTPTransport | SSETransport, str]:
"""Create a fresh transport instance from an MCP server config.
transport: StdioTransport | HTTPTransport | SSETransport
Returns a ``(transport, server_name)`` tuple. Each call produces an
independent transport so that parallel tool executions never share
state.
"""
if isinstance(mcp_config, MCPServerStdio):
transport = StdioTransport(
command=mcp_config.command,
@@ -292,38 +295,54 @@ class MCPToolResolver:
headers=mcp_config.headers,
streamable=mcp_config.streamable,
)
server_name = self._extract_server_name(mcp_config.url)
server_name = MCPToolResolver._extract_server_name(mcp_config.url)
elif isinstance(mcp_config, MCPServerSSE):
transport = SSETransport(
url=mcp_config.url,
headers=mcp_config.headers,
)
server_name = self._extract_server_name(mcp_config.url)
server_name = MCPToolResolver._extract_server_name(mcp_config.url)
else:
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
return transport, server_name
client = MCPClient(
transport=transport,
def _resolve_native(
self, mcp_config: MCPServerConfig
) -> tuple[list[BaseTool], list[Any]]:
"""Resolve an ``MCPServerConfig`` into tools.
Returns ``(tools, clients)`` where *clients* is always empty for
native tools (clients are now created on-demand per invocation).
A ``client_factory`` closure is passed to each ``MCPNativeTool`` so
every call -- even concurrent calls to the *same* tool -- gets its
own ``MCPClient`` + transport with no shared mutable state.
"""
from crewai.tools.base_tool import BaseTool
from crewai.tools.mcp_native_tool import MCPNativeTool
discovery_transport, server_name = self._create_transport(mcp_config)
discovery_client = MCPClient(
transport=discovery_transport,
cache_tools_list=mcp_config.cache_tools_list,
)
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
try:
if not client.connected:
await client.connect()
if not discovery_client.connected:
await discovery_client.connect()
tools_list = await client.list_tools()
tools_list = await discovery_client.list_tools()
try:
await client.disconnect()
await discovery_client.disconnect()
await asyncio.sleep(0.1)
except Exception as e:
self._logger.log("error", f"Error during disconnect: {e}")
return tools_list
except Exception as e:
if client.connected:
await client.disconnect()
if discovery_client.connected:
await discovery_client.disconnect()
await asyncio.sleep(0.1)
raise RuntimeError(
f"Error during setup client and list tools: {e}"
@@ -376,6 +395,13 @@ class MCPToolResolver:
filtered_tools.append(tool)
tools_list = filtered_tools
def _client_factory() -> MCPClient:
transport, _ = self._create_transport(mcp_config)
return MCPClient(
transport=transport,
cache_tools_list=mcp_config.cache_tools_list,
)
tools = []
for tool_def in tools_list:
tool_name = tool_def.get("name", "")
@@ -396,7 +422,7 @@ class MCPToolResolver:
try:
native_tool = MCPNativeTool(
mcp_client=client,
client_factory=_client_factory,
tool_name=tool_name,
tool_schema=tool_schema,
server_name=server_name,
@@ -407,10 +433,10 @@ class MCPToolResolver:
self._logger.log("error", f"Failed to create native MCP tool: {e}")
continue
return cast(list[BaseTool], tools), client
return cast(list[BaseTool], tools), []
except Exception as e:
if client.connected:
asyncio.run(client.disconnect())
if discovery_client.connected:
asyncio.run(discovery_client.disconnect())
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e

View File

@@ -1,29 +1,30 @@
"""Native MCP tool wrapper for CrewAI agents.
This module provides a tool wrapper that reuses existing MCP client sessions
for better performance and connection management.
This module provides a tool wrapper that creates a fresh MCP client for every
invocation, ensuring safe parallel execution even when the same tool is called
concurrently by the executor.
"""
import asyncio
from collections.abc import Callable
from typing import Any
from crewai.tools import BaseTool
class MCPNativeTool(BaseTool):
"""Native MCP tool that reuses client sessions.
"""Native MCP tool that creates a fresh client per invocation.
This tool wrapper is used when agents connect to MCP servers using
structured configurations. It reuses existing client sessions for
better performance and proper connection lifecycle management.
Unlike MCPToolWrapper which connects on-demand, this tool uses
a shared MCP client instance that maintains a persistent connection.
A ``client_factory`` callable produces an independent ``MCPClient`` +
transport for every ``_run_async`` call. This guarantees that parallel
invocations -- whether of the *same* tool or *different* tools from the
same server -- never share mutable connection state (which would cause
anyio cancel-scope errors).
"""
def __init__(
self,
mcp_client: Any,
client_factory: Callable[[], Any],
tool_name: str,
tool_schema: dict[str, Any],
server_name: str,
@@ -32,19 +33,16 @@ class MCPNativeTool(BaseTool):
"""Initialize native MCP tool.
Args:
mcp_client: MCPClient instance with active session.
client_factory: Zero-arg callable that returns a new MCPClient.
tool_name: Name of the tool (may be prefixed).
tool_schema: Schema information for the tool.
server_name: Name of the MCP server for prefixing.
original_tool_name: Original name of the tool on the MCP server.
"""
# Create tool name with server prefix to avoid conflicts
prefixed_name = f"{server_name}_{tool_name}"
# Handle args_schema properly - BaseTool expects a BaseModel subclass
args_schema = tool_schema.get("args_schema")
# Only pass args_schema if it's provided
kwargs = {
"name": prefixed_name,
"description": tool_schema.get(
@@ -57,16 +55,9 @@ class MCPNativeTool(BaseTool):
super().__init__(**kwargs)
# Set instance attributes after super().__init__
self._mcp_client = mcp_client
self._client_factory = client_factory
self._original_tool_name = original_tool_name or tool_name
self._server_name = server_name
# self._logger = logging.getLogger(__name__)
@property
def mcp_client(self) -> Any:
"""Get the MCP client instance."""
return self._mcp_client
@property
def original_tool_name(self) -> str:
@@ -108,51 +99,26 @@ class MCPNativeTool(BaseTool):
async def _run_async(self, **kwargs) -> str:
"""Async implementation of tool execution.
A fresh ``MCPClient`` is created for every invocation so that
concurrent calls never share transport or session state.
Args:
**kwargs: Arguments to pass to the MCP tool.
Returns:
Result from the MCP tool execution.
"""
# Note: Since we use asyncio.run() which creates a new event loop each time,
# Always reconnect on-demand because asyncio.run() creates new event loops per call
# All MCP transport context managers (stdio, streamablehttp_client, sse_client)
# use anyio.create_task_group() which can't span different event loops
if self._mcp_client.connected:
await self._mcp_client.disconnect()
await self._mcp_client.connect()
client = self._client_factory()
await client.connect()
try:
result = await self._mcp_client.call_tool(self.original_tool_name, kwargs)
except Exception as e:
error_str = str(e).lower()
if (
"not connected" in error_str
or "connection" in error_str
or "send" in error_str
):
await self._mcp_client.disconnect()
await self._mcp_client.connect()
# Retry the call
result = await self._mcp_client.call_tool(
self.original_tool_name, kwargs
)
else:
raise
result = await client.call_tool(self.original_tool_name, kwargs)
finally:
# Always disconnect after tool call to ensure clean context manager lifecycle
# This prevents "exit cancel scope in different task" errors
# All transport context managers must be exited in the same event loop they were entered
await self._mcp_client.disconnect()
await client.disconnect()
# Extract result content
if isinstance(result, str):
return result
# Handle various result formats
if hasattr(result, "content") and result.content:
if isinstance(result.content, list) and len(result.content) > 0:
content_item = result.content[0]

View File

@@ -2353,3 +2353,68 @@ def test_agent_without_apps_no_platform_tools():
tools = crew._prepare_tools(agent, task, [])
assert tools == []
def test_agent_mcps_accepts_slug_with_specific_tool():
    """A slug#tool reference such as notion#get_page must pass _SLUG_RE validation."""
    ref = "notion#get_page"
    mcp_agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=[ref],
    )
    assert mcp_agent.mcps == [ref]
def test_agent_mcps_accepts_slug_with_hyphenated_tool():
    """A hyphenated tool fragment (notion#get-page) must pass _SLUG_RE validation."""
    ref = "notion#get-page"
    mcp_agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=[ref],
    )
    assert mcp_agent.mcps == [ref]
def test_agent_mcps_accepts_multiple_hash_refs():
    """Several slug#tool references may be supplied in a single mcps list."""
    refs = ["notion#get_page", "notion#search", "github#list_repos"]
    mcp_agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=refs,
    )
    assert len(mcp_agent.mcps) == 3
def test_agent_mcps_accepts_mixed_ref_types():
    """slug#tool, bare-slug, and HTTPS URL references may be mixed freely."""
    refs = [
        "notion#get_page",
        "notion",
        "https://mcp.example.com/api",
    ]
    mcp_agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=refs,
    )
    assert len(mcp_agent.mcps) == 3
def test_agent_mcps_rejects_hash_without_slug():
    """A '#tool' reference with no leading server slug must be rejected."""
    agent_kwargs = dict(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=["#get_page"],
    )
    with pytest.raises(ValueError, match="Invalid MCP reference"):
        Agent(**agent_kwargs)
def test_agent_mcps_accepts_legacy_prefix_with_tool():
    """The legacy crewai-amp: prefix combined with #tool must still validate."""
    ref = "crewai-amp:notion#get_page"
    mcp_agent = Agent(
        role="MCP Agent",
        goal="Test MCP validation",
        backstory="Test agent",
        mcps=[ref],
    )
    assert mcp_agent.mcps == [ref]

View File

@@ -268,6 +268,54 @@ class TestGetMCPToolsAmpIntegration:
assert len(tools) == 1
assert tools[0].name == "mcp_notion_so_sse_search"
@patch("crewai.mcp.tool_resolver.MCPClient")
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
def test_tool_filter_with_hyphenated_hash_syntax(
    self, mock_fetch, mock_client_class, agent
):
    """notion#get-page must match the tool whose sanitized name is get_page."""
    # AMP config lookup returns a single SSE server for the "notion" slug.
    mock_fetch.return_value = {
        "notion": {
            "type": "sse",
            "url": "https://mcp.notion.so/sse",
            "headers": {"Authorization": "Bearer token"},
        },
    }
    # Server exposes two tools; get_page is the sanitized form of "get-page".
    tool_defs = [
        {
            "name": "get_page",
            "original_name": "get-page",
            "description": "Get a page",
            "inputSchema": {},
        },
        {
            "name": "search",
            "original_name": "search",
            "description": "Search tool",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "Search query"}
                },
                "required": ["query"],
            },
        },
    ]
    fake_client = AsyncMock()
    fake_client.list_tools = AsyncMock(return_value=tool_defs)
    fake_client.connected = False
    fake_client.connect = AsyncMock()
    fake_client.disconnect = AsyncMock()
    mock_client_class.return_value = fake_client

    tools = agent.get_mcp_tools(["notion#get-page"])

    mock_fetch.assert_called_once_with(["notion"])
    # Only the hyphen-referenced tool survives the filter.
    assert len(tools) == 1
    assert tools[0].name.endswith("_get_page")
@patch("crewai.mcp.tool_resolver.MCPClient")
@patch.object(MCPToolResolver, "_fetch_amp_mcp_configs")
def test_deduplicates_slugs(
@@ -371,3 +419,87 @@ class TestGetMCPToolsAmpIntegration:
mock_external.assert_called_once_with("https://external.mcp.com/api")
# 2 from notion + 1 from external + 2 from http_config
assert len(tools) == 5
class TestResolveExternalToolFilter:
    """Tests for _resolve_external with #tool-name filtering."""

    @pytest.fixture
    def agent(self):
        return Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
        )

    @pytest.fixture
    def resolver(self, agent):
        return MCPToolResolver(agent=agent, logger=agent._logger)

    @staticmethod
    def _two_tool_schemas():
        """Schema map exposing a get_page tool and a search tool."""
        return {
            "get_page": {"description": "Get a page", "args_schema": None},
            "search": {"description": "Search tool", "args_schema": None},
        }

    @patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
    def test_filters_hyphenated_tool_name(self, mock_schemas, resolver):
        """https://...#get-page must match the sanitized key get_page in schemas."""
        mock_schemas.return_value = self._two_tool_schemas()

        tools = resolver._resolve_external("https://mcp.example.com/api#get-page")

        assert len(tools) == 1
        assert "get_page" in tools[0].name

    @patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
    def test_filters_underscored_tool_name(self, mock_schemas, resolver):
        """https://...#get_page must also match the sanitized key get_page."""
        mock_schemas.return_value = self._two_tool_schemas()

        tools = resolver._resolve_external("https://mcp.example.com/api#get_page")

        assert len(tools) == 1
        assert "get_page" in tools[0].name

    @patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
    def test_returns_all_tools_without_hash(self, mock_schemas, resolver):
        """Without a #fragment, every tool the server exposes is returned."""
        mock_schemas.return_value = self._two_tool_schemas()

        tools = resolver._resolve_external("https://mcp.example.com/api")

        assert len(tools) == 2

    @patch.object(MCPToolResolver, "_get_mcp_tool_schemas")
    def test_returns_empty_for_nonexistent_tool(self, mock_schemas, resolver):
        """Filtering on a tool the server does not expose yields no tools."""
        mock_schemas.return_value = {
            "search": {"description": "Search tool", "args_schema": None},
        }

        tools = resolver._resolve_external("https://mcp.example.com/api#nonexistent")

        assert len(tools) == 0

View File

@@ -1,4 +1,5 @@
import asyncio
import concurrent.futures
from unittest.mock import AsyncMock, patch
import pytest
@@ -30,6 +31,17 @@ def mock_tool_definitions():
]
def _make_mock_client(tool_definitions):
"""Create a mock MCPClient that returns *tool_definitions*."""
client = AsyncMock()
client.list_tools = AsyncMock(return_value=tool_definitions)
client.connected = False
client.connect = AsyncMock()
client.disconnect = AsyncMock()
client.call_tool = AsyncMock(return_value="test result")
return client
def test_agent_with_stdio_mcp_config(mock_tool_definitions):
"""Test agent setup with MCPServerStdio configuration."""
stdio_config = MCPServerStdio(
@@ -45,14 +57,8 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions):
mcps=[stdio_config],
)
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False # Will trigger connect
mock_client.connect = AsyncMock()
mock_client.disconnect = AsyncMock()
mock_client_class.return_value = mock_client
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
tools = agent.get_mcp_tools([stdio_config])
@@ -60,8 +66,7 @@ def test_agent_with_stdio_mcp_config(mock_tool_definitions):
assert all(isinstance(tool, BaseTool) for tool in tools)
mock_client_class.assert_called_once()
call_args = mock_client_class.call_args
transport = call_args.kwargs["transport"]
transport = mock_client_class.call_args.kwargs["transport"]
assert transport.command == "python"
assert transport.args == ["server.py"]
assert transport.env == {"API_KEY": "test_key"}
@@ -83,12 +88,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions):
)
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False # Will trigger connect
mock_client.connect = AsyncMock()
mock_client.disconnect = AsyncMock()
mock_client_class.return_value = mock_client
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
tools = agent.get_mcp_tools([http_config])
@@ -96,8 +96,7 @@ def test_agent_with_http_mcp_config(mock_tool_definitions):
assert all(isinstance(tool, BaseTool) for tool in tools)
mock_client_class.assert_called_once()
call_args = mock_client_class.call_args
transport = call_args.kwargs["transport"]
transport = mock_client_class.call_args.kwargs["transport"]
assert transport.url == "https://api.example.com/mcp"
assert transport.headers == {"Authorization": "Bearer test_token"}
assert transport.streamable is True
@@ -118,12 +117,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions):
)
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False
mock_client.connect = AsyncMock()
mock_client.disconnect = AsyncMock()
mock_client_class.return_value = mock_client
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
tools = agent.get_mcp_tools([sse_config])
@@ -131,8 +125,7 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions):
assert all(isinstance(tool, BaseTool) for tool in tools)
mock_client_class.assert_called_once()
call_args = mock_client_class.call_args
transport = call_args.kwargs["transport"]
transport = mock_client_class.call_args.kwargs["transport"]
assert transport.url == "https://api.example.com/mcp/sse"
assert transport.headers == {"Authorization": "Bearer test_token"}
@@ -142,13 +135,7 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False
mock_client.connect = AsyncMock()
mock_client.disconnect = AsyncMock()
mock_client.call_tool = AsyncMock(return_value="test result")
mock_client_class.return_value = mock_client
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
agent = Agent(
role="Test Agent",
@@ -160,12 +147,12 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
tools = agent.get_mcp_tools([http_config])
assert len(tools) == 2
tool = tools[0]
result = tool.run(query="test query")
assert result == "test result"
mock_client.call_tool.assert_called()
# 1 discovery + 1 for the run() invocation
assert mock_client_class.call_count == 2
@pytest.mark.asyncio
@@ -174,13 +161,7 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
mock_client.connected = False
mock_client.connect = AsyncMock()
mock_client.disconnect = AsyncMock()
mock_client.call_tool = AsyncMock(return_value="test result")
mock_client_class.return_value = mock_client
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
agent = Agent(
role="Test Agent",
@@ -192,9 +173,129 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
tools = agent.get_mcp_tools([http_config])
assert len(tools) == 2
tool = tools[0]
result = tool.run(query="test query")
assert result == "test result"
mock_client.call_tool.assert_called()
assert mock_client_class.call_count == 2
def test_each_invocation_gets_fresh_client(mock_tool_definitions):
    """Every tool.run() must create its own MCPClient (no shared state)."""
    http_config = MCPServerHTTP(url="https://api.example.com/mcp")
    created: list = []

    def _factory(**kwargs):
        stub = _make_mock_client(mock_tool_definitions)
        created.append(stub)
        return stub

    with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_factory):
        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            mcps=[http_config],
        )
        tools = agent.get_mcp_tools([http_config])
        assert len(tools) == 2
        # Discovery alone should have consumed exactly one client.
        assert len(created) == 1

        # Two sequential runs of the same tool must mint two more clients.
        tools[0].run(query="q1")
        tools[0].run(query="q2")
        assert len(created) == 3
        assert created[1] is not created[2]
def test_parallel_mcp_tool_execution_same_tool(mock_tool_definitions):
    """Parallel calls to the *same* tool must not interfere."""
    http_config = MCPServerHTTP(url="https://api.example.com/mcp")
    invocations: list[str] = []

    def _factory(**kwargs):
        stub = AsyncMock()
        stub.list_tools = AsyncMock(return_value=mock_tool_definitions)
        stub.connected = False
        stub.connect = AsyncMock()
        stub.disconnect = AsyncMock()

        async def _slow_call(name, args):
            # Record the call, then yield long enough for the other
            # thread's call to overlap with this one.
            invocations.append(name)
            await asyncio.sleep(0.05)
            return f"result-{name}"

        stub.call_tool = AsyncMock(side_effect=_slow_call)
        return stub

    with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_factory):
        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            mcps=[http_config],
        )
        tools = agent.get_mcp_tools([http_config])
        assert len(tools) >= 1
        tool = tools[0]

        # Fire the SAME tool concurrently -- the exact scenario from the bug
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            pending = [
                pool.submit(tool.run, query="q1"),
                pool.submit(tool.run, query="q2"),
            ]
            outcomes = [
                f.result() for f in concurrent.futures.as_completed(pending)
            ]

        assert len(outcomes) == 2
        assert all("result-" in r for r in outcomes)
        assert len(invocations) == 2
def test_parallel_mcp_tool_execution_different_tools(mock_tool_definitions):
    """Parallel calls to different tools from the same server must not interfere."""
    http_config = MCPServerHTTP(url="https://api.example.com/mcp")
    invocations: list[str] = []

    def _factory(**kwargs):
        stub = AsyncMock()
        stub.list_tools = AsyncMock(return_value=mock_tool_definitions)
        stub.connected = False
        stub.connect = AsyncMock()
        stub.disconnect = AsyncMock()

        async def _slow_call(name, args):
            # Record the call, then pause so both threads overlap.
            invocations.append(name)
            await asyncio.sleep(0.05)
            return f"result-{name}"

        stub.call_tool = AsyncMock(side_effect=_slow_call)
        return stub

    with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_factory):
        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            mcps=[http_config],
        )
        tools = agent.get_mcp_tools([http_config])
        assert len(tools) == 2

        # Run two DIFFERENT tools from the same server concurrently.
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            pending = [
                pool.submit(tools[0].run, query="q1"),
                pool.submit(tools[1].run, query="q2"),
            ]
            outcomes = [
                f.result() for f in concurrent.futures.as_completed(pending)
            ]

        assert len(outcomes) == 2
        assert all("result-" in r for r in outcomes)
        assert len(invocations) == 2