mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
* fix: allow hyphenated tool names in MCP references like notion#get-page
The _SLUG_RE regex on BaseAgent rejected MCP tool references containing
hyphens (e.g. "notion#get-page") because the fragment pattern only
matched \w (word chars)
* fix: create fresh MCP client per tool invocation to prevent parallel call races
When the LLM dispatches parallel calls to MCP tools on the same server, the executor runs them concurrently via ThreadPoolExecutor. Previously, all tools from a server shared a single MCPClient instance, and even the same tool called twice would reuse one client. Since each thread creates its own asyncio event loop via asyncio.run(), concurrent connect/disconnect calls on the shared client caused anyio cancel-scope errors ("Attempted to exit cancel scope in a different task than it was entered in").
The fix introduces a client_factory pattern: MCPNativeTool now receives a zero-arg callable that produces a fresh MCPClient + transport on every
_run_async() invocation. This eliminates all shared mutable connection state between concurrent calls, whether to the same tool or different tools from the same server.
* test: ensure we can filter hyphenated MCP tool
302 lines
9.7 KiB
Python
302 lines
9.7 KiB
Python
import asyncio
|
|
import concurrent.futures
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from crewai.agent.core import Agent
|
|
from crewai.mcp.config import MCPServerHTTP, MCPServerSSE, MCPServerStdio
|
|
from crewai.tools.base_tool import BaseTool
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_tool_definitions():
|
|
"""Create mock MCP tool definitions (as returned by list_tools)."""
|
|
return [
|
|
{
|
|
"name": "test_tool_1",
|
|
"description": "Test tool 1 description",
|
|
"inputSchema": {
|
|
"type": "object",
|
|
"properties": {
|
|
"query": {"type": "string", "description": "Search query"}
|
|
},
|
|
"required": ["query"]
|
|
}
|
|
},
|
|
{
|
|
"name": "test_tool_2",
|
|
"description": "Test tool 2 description",
|
|
"inputSchema": {}
|
|
}
|
|
]
|
|
|
|
|
|
def _make_mock_client(tool_definitions):
|
|
"""Create a mock MCPClient that returns *tool_definitions*."""
|
|
client = AsyncMock()
|
|
client.list_tools = AsyncMock(return_value=tool_definitions)
|
|
client.connected = False
|
|
client.connect = AsyncMock()
|
|
client.disconnect = AsyncMock()
|
|
client.call_tool = AsyncMock(return_value="test result")
|
|
return client
|
|
|
|
|
|
def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
|
"""Test agent setup with MCPServerStdio configuration."""
|
|
stdio_config = MCPServerStdio(
|
|
command="python",
|
|
args=["server.py"],
|
|
env={"API_KEY": "test_key"},
|
|
)
|
|
|
|
agent = Agent(
|
|
role="Test Agent",
|
|
goal="Test goal",
|
|
backstory="Test backstory",
|
|
mcps=[stdio_config],
|
|
)
|
|
|
|
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
|
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
|
|
|
tools = agent.get_mcp_tools([stdio_config])
|
|
|
|
assert len(tools) == 2
|
|
assert all(isinstance(tool, BaseTool) for tool in tools)
|
|
|
|
mock_client_class.assert_called_once()
|
|
transport = mock_client_class.call_args.kwargs["transport"]
|
|
assert transport.command == "python"
|
|
assert transport.args == ["server.py"]
|
|
assert transport.env == {"API_KEY": "test_key"}
|
|
|
|
|
|
def test_agent_with_http_mcp_config(mock_tool_definitions):
|
|
"""Test agent setup with MCPServerHTTP configuration."""
|
|
http_config = MCPServerHTTP(
|
|
url="https://api.example.com/mcp",
|
|
headers={"Authorization": "Bearer test_token"},
|
|
streamable=True,
|
|
)
|
|
|
|
agent = Agent(
|
|
role="Test Agent",
|
|
goal="Test goal",
|
|
backstory="Test backstory",
|
|
mcps=[http_config],
|
|
)
|
|
|
|
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
|
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
|
|
|
tools = agent.get_mcp_tools([http_config])
|
|
|
|
assert len(tools) == 2
|
|
assert all(isinstance(tool, BaseTool) for tool in tools)
|
|
|
|
mock_client_class.assert_called_once()
|
|
transport = mock_client_class.call_args.kwargs["transport"]
|
|
assert transport.url == "https://api.example.com/mcp"
|
|
assert transport.headers == {"Authorization": "Bearer test_token"}
|
|
assert transport.streamable is True
|
|
|
|
|
|
def test_agent_with_sse_mcp_config(mock_tool_definitions):
|
|
"""Test agent setup with MCPServerSSE configuration."""
|
|
sse_config = MCPServerSSE(
|
|
url="https://api.example.com/mcp/sse",
|
|
headers={"Authorization": "Bearer test_token"},
|
|
)
|
|
|
|
agent = Agent(
|
|
role="Test Agent",
|
|
goal="Test goal",
|
|
backstory="Test backstory",
|
|
mcps=[sse_config],
|
|
)
|
|
|
|
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
|
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
|
|
|
tools = agent.get_mcp_tools([sse_config])
|
|
|
|
assert len(tools) == 2
|
|
assert all(isinstance(tool, BaseTool) for tool in tools)
|
|
|
|
mock_client_class.assert_called_once()
|
|
transport = mock_client_class.call_args.kwargs["transport"]
|
|
assert transport.url == "https://api.example.com/mcp/sse"
|
|
assert transport.headers == {"Authorization": "Bearer test_token"}
|
|
|
|
|
|
def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
|
|
"""Test MCPNativeTool execution in synchronous context (normal crew execution)."""
|
|
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
|
|
|
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
|
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
|
|
|
agent = Agent(
|
|
role="Test Agent",
|
|
goal="Test goal",
|
|
backstory="Test backstory",
|
|
mcps=[http_config],
|
|
)
|
|
|
|
tools = agent.get_mcp_tools([http_config])
|
|
assert len(tools) == 2
|
|
|
|
tool = tools[0]
|
|
result = tool.run(query="test query")
|
|
|
|
assert result == "test result"
|
|
# 1 discovery + 1 for the run() invocation
|
|
assert mock_client_class.call_count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
|
|
"""Test MCPNativeTool execution in async context (e.g., from a Flow)."""
|
|
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
|
|
|
with patch("crewai.mcp.tool_resolver.MCPClient") as mock_client_class:
|
|
mock_client_class.return_value = _make_mock_client(mock_tool_definitions)
|
|
|
|
agent = Agent(
|
|
role="Test Agent",
|
|
goal="Test goal",
|
|
backstory="Test backstory",
|
|
mcps=[http_config],
|
|
)
|
|
|
|
tools = agent.get_mcp_tools([http_config])
|
|
assert len(tools) == 2
|
|
|
|
tool = tools[0]
|
|
result = tool.run(query="test query")
|
|
|
|
assert result == "test result"
|
|
assert mock_client_class.call_count == 2
|
|
|
|
|
|
def test_each_invocation_gets_fresh_client(mock_tool_definitions):
|
|
"""Every tool.run() must create its own MCPClient (no shared state)."""
|
|
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
|
|
|
clients_created: list = []
|
|
|
|
def _make_client(**kwargs):
|
|
client = _make_mock_client(mock_tool_definitions)
|
|
clients_created.append(client)
|
|
return client
|
|
|
|
with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client):
|
|
agent = Agent(
|
|
role="Test Agent",
|
|
goal="Test goal",
|
|
backstory="Test backstory",
|
|
mcps=[http_config],
|
|
)
|
|
|
|
tools = agent.get_mcp_tools([http_config])
|
|
assert len(tools) == 2
|
|
# 1 discovery client so far
|
|
assert len(clients_created) == 1
|
|
|
|
# Two sequential calls to the same tool must create 2 new clients
|
|
tools[0].run(query="q1")
|
|
tools[0].run(query="q2")
|
|
assert len(clients_created) == 3
|
|
assert clients_created[1] is not clients_created[2]
|
|
|
|
|
|
def test_parallel_mcp_tool_execution_same_tool(mock_tool_definitions):
|
|
"""Parallel calls to the *same* tool must not interfere."""
|
|
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
|
|
|
call_log: list[str] = []
|
|
|
|
def _make_client(**kwargs):
|
|
client = AsyncMock()
|
|
client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
|
client.connected = False
|
|
client.connect = AsyncMock()
|
|
client.disconnect = AsyncMock()
|
|
|
|
async def _call_tool(name, args):
|
|
call_log.append(name)
|
|
await asyncio.sleep(0.05)
|
|
return f"result-{name}"
|
|
|
|
client.call_tool = AsyncMock(side_effect=_call_tool)
|
|
return client
|
|
|
|
with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client):
|
|
agent = Agent(
|
|
role="Test Agent",
|
|
goal="Test goal",
|
|
backstory="Test backstory",
|
|
mcps=[http_config],
|
|
)
|
|
|
|
tools = agent.get_mcp_tools([http_config])
|
|
assert len(tools) >= 1
|
|
tool = tools[0]
|
|
|
|
# Call the SAME tool concurrently -- the exact scenario from the bug
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
|
|
futures = [
|
|
pool.submit(tool.run, query="q1"),
|
|
pool.submit(tool.run, query="q2"),
|
|
]
|
|
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
|
|
|
assert len(results) == 2
|
|
assert all("result-" in r for r in results)
|
|
assert len(call_log) == 2
|
|
|
|
|
|
def test_parallel_mcp_tool_execution_different_tools(mock_tool_definitions):
|
|
"""Parallel calls to different tools from the same server must not interfere."""
|
|
http_config = MCPServerHTTP(url="https://api.example.com/mcp")
|
|
|
|
call_log: list[str] = []
|
|
|
|
def _make_client(**kwargs):
|
|
client = AsyncMock()
|
|
client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
|
client.connected = False
|
|
client.connect = AsyncMock()
|
|
client.disconnect = AsyncMock()
|
|
|
|
async def _call_tool(name, args):
|
|
call_log.append(name)
|
|
await asyncio.sleep(0.05)
|
|
return f"result-{name}"
|
|
|
|
client.call_tool = AsyncMock(side_effect=_call_tool)
|
|
return client
|
|
|
|
with patch("crewai.mcp.tool_resolver.MCPClient", side_effect=_make_client):
|
|
agent = Agent(
|
|
role="Test Agent",
|
|
goal="Test goal",
|
|
backstory="Test backstory",
|
|
mcps=[http_config],
|
|
)
|
|
|
|
tools = agent.get_mcp_tools([http_config])
|
|
assert len(tools) == 2
|
|
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
|
|
futures = [
|
|
pool.submit(tools[0].run, query="q1"),
|
|
pool.submit(tools[1].run, query="q2"),
|
|
]
|
|
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
|
|
|
assert len(results) == 2
|
|
assert all("result-" in r for r in results)
|
|
assert len(call_log) == 2
|