mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: use persistent event loop for MCP operations to prevent cancel scope errors
Replace per-call asyncio.run() with a single persistent background event loop for all MCP operations. The MCP SDK's streamable HTTP transport uses anyio task groups whose cancel scopes must be entered and exited on the same event loop and task. Creating a throwaway loop per tool call caused "Attempted to exit cancel scope in a different task" RuntimeErrors during cleanup, preventing MCP tools from working reliably.
This commit is contained in:
@@ -95,7 +95,7 @@ class MCPClient:
|
|||||||
self.discovery_timeout = discovery_timeout
|
self.discovery_timeout = discovery_timeout
|
||||||
self.max_retries = max_retries
|
self.max_retries = max_retries
|
||||||
self.cache_tools_list = cache_tools_list
|
self.cache_tools_list = cache_tools_list
|
||||||
# self._logger = logger or logging.getLogger(__name__)
|
self._logger = logger or logging.getLogger(__name__)
|
||||||
self._session: Any = None
|
self._session: Any = None
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
self._exit_stack = AsyncExitStack()
|
self._exit_stack = AsyncExitStack()
|
||||||
@@ -358,10 +358,12 @@ class MCPClient:
|
|||||||
"""Cleanup resources when an error occurs during connection."""
|
"""Cleanup resources when an error occurs during connection."""
|
||||||
try:
|
try:
|
||||||
await self._exit_stack.aclose()
|
await self._exit_stack.aclose()
|
||||||
|
except (RuntimeError, BaseExceptionGroup) as e:
|
||||||
except Exception as e:
|
error_msg = str(e).lower()
|
||||||
# Best effort cleanup - ignore all other errors
|
if "cancel scope" not in error_msg and "task" not in error_msg:
|
||||||
raise RuntimeError(f"Error during MCP client cleanup: {e}") from e
|
raise RuntimeError(f"Error during MCP client cleanup: {e}") from e
|
||||||
|
except Exception:
|
||||||
|
self._logger.debug("Suppressed error during MCP cleanup", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
self._session = None
|
self._session = None
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
@@ -374,8 +376,12 @@ class MCPClient:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self._exit_stack.aclose()
|
await self._exit_stack.aclose()
|
||||||
except Exception as e:
|
except (RuntimeError, BaseExceptionGroup) as e:
|
||||||
raise RuntimeError(f"Error during MCP client disconnect: {e}") from e
|
error_msg = str(e).lower()
|
||||||
|
if "cancel scope" not in error_msg and "task" not in error_msg:
|
||||||
|
raise RuntimeError(f"Error during MCP client disconnect: {e}") from e
|
||||||
|
except Exception:
|
||||||
|
self._logger.debug("Suppressed error during MCP disconnect", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
self._session = None
|
self._session = None
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|||||||
@@ -87,7 +87,12 @@ class MCPToolResolver:
|
|||||||
return all_tools
|
return all_tools
|
||||||
|
|
||||||
def cleanup(self) -> None:
|
def cleanup(self) -> None:
|
||||||
"""Disconnect all MCP client connections."""
|
"""Disconnect all MCP client connections.
|
||||||
|
|
||||||
|
Submits the disconnect coroutines to the persistent MCP event loop
|
||||||
|
so that transport context managers are exited on the same loop they
|
||||||
|
were entered on.
|
||||||
|
"""
|
||||||
if not self._clients:
|
if not self._clients:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -97,7 +102,11 @@ class MCPToolResolver:
|
|||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
asyncio.run(_disconnect_all())
|
from crewai.tools.mcp_native_tool import _get_mcp_event_loop
|
||||||
|
|
||||||
|
loop = _get_mcp_event_loop()
|
||||||
|
future = asyncio.run_coroutine_threadsafe(_disconnect_all(), loop)
|
||||||
|
future.result(timeout=30)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._logger.log("error", f"Error during MCP client cleanup: {e}")
|
self._logger.log("error", f"Error during MCP client cleanup: {e}")
|
||||||
finally:
|
finally:
|
||||||
@@ -330,30 +339,27 @@ class MCPToolResolver:
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
from crewai.tools.mcp_native_tool import _get_mcp_event_loop
|
||||||
asyncio.get_running_loop()
|
|
||||||
import concurrent.futures
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
loop = _get_mcp_event_loop()
|
||||||
future = executor.submit(
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
asyncio.run, _setup_client_and_list_tools()
|
_setup_client_and_list_tools(), loop
|
||||||
)
|
)
|
||||||
tools_list = future.result()
|
try:
|
||||||
except RuntimeError:
|
tools_list = future.result(timeout=60)
|
||||||
try:
|
except RuntimeError as e:
|
||||||
tools_list = asyncio.run(_setup_client_and_list_tools())
|
error_msg = str(e).lower()
|
||||||
except RuntimeError as e:
|
if "cancel scope" in error_msg or "task" in error_msg:
|
||||||
error_msg = str(e).lower()
|
|
||||||
if "cancel scope" in error_msg or "task" in error_msg:
|
|
||||||
raise ConnectionError(
|
|
||||||
"MCP connection failed due to event loop cleanup issues. "
|
|
||||||
"This may be due to authentication errors or server unavailability."
|
|
||||||
) from e
|
|
||||||
except asyncio.CancelledError as e:
|
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
"MCP connection was cancelled. This may indicate an authentication "
|
"MCP connection failed due to event loop cleanup issues. "
|
||||||
"error or server unavailability."
|
"This may be due to authentication errors or server unavailability."
|
||||||
) from e
|
) from e
|
||||||
|
raise
|
||||||
|
except asyncio.CancelledError as e:
|
||||||
|
raise ConnectionError(
|
||||||
|
"MCP connection was cancelled. This may indicate an authentication "
|
||||||
|
"error or server unavailability."
|
||||||
|
) from e
|
||||||
|
|
||||||
if mcp_config.tool_filter:
|
if mcp_config.tool_filter:
|
||||||
filtered_tools = []
|
filtered_tools = []
|
||||||
@@ -410,7 +416,13 @@ class MCPToolResolver:
|
|||||||
return cast(list[BaseTool], tools), client
|
return cast(list[BaseTool], tools), client
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if client.connected:
|
if client.connected:
|
||||||
asyncio.run(client.disconnect())
|
try:
|
||||||
|
fut = asyncio.run_coroutine_threadsafe(
|
||||||
|
client.disconnect(), loop
|
||||||
|
)
|
||||||
|
fut.result(timeout=10)
|
||||||
|
except Exception:
|
||||||
|
self._logger.log("debug", "Suppressed error during MCP client disconnect on cleanup")
|
||||||
|
|
||||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,37 @@ for better performance and connection management.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
_mcp_loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
_mcp_loop_thread: threading.Thread | None = None
|
||||||
|
_mcp_loop_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mcp_event_loop() -> asyncio.AbstractEventLoop:
|
||||||
|
"""Return (and lazily start) a persistent event loop for MCP operations.
|
||||||
|
|
||||||
|
All MCP SDK transports use anyio task groups whose cancel scopes must be
|
||||||
|
entered and exited on the same event loop / task. By funnelling every
|
||||||
|
MCP call through one long-lived loop we avoid the "exit cancel scope in
|
||||||
|
a different task" errors that happen when asyncio.run() creates a
|
||||||
|
throwaway loop per call.
|
||||||
|
"""
|
||||||
|
global _mcp_loop, _mcp_loop_thread
|
||||||
|
with _mcp_loop_lock:
|
||||||
|
if _mcp_loop is None or _mcp_loop.is_closed():
|
||||||
|
_mcp_loop = asyncio.new_event_loop()
|
||||||
|
_mcp_loop_thread = threading.Thread(
|
||||||
|
target=_mcp_loop.run_forever, daemon=True, name="mcp-event-loop"
|
||||||
|
)
|
||||||
|
_mcp_loop_thread.start()
|
||||||
|
return _mcp_loop
|
||||||
|
|
||||||
|
|
||||||
class MCPNativeTool(BaseTool):
|
class MCPNativeTool(BaseTool):
|
||||||
"""Native MCP tool that reuses client sessions.
|
"""Native MCP tool that reuses client sessions.
|
||||||
|
|
||||||
@@ -38,13 +64,10 @@ class MCPNativeTool(BaseTool):
|
|||||||
server_name: Name of the MCP server for prefixing.
|
server_name: Name of the MCP server for prefixing.
|
||||||
original_tool_name: Original name of the tool on the MCP server.
|
original_tool_name: Original name of the tool on the MCP server.
|
||||||
"""
|
"""
|
||||||
# Create tool name with server prefix to avoid conflicts
|
|
||||||
prefixed_name = f"{server_name}_{tool_name}"
|
prefixed_name = f"{server_name}_{tool_name}"
|
||||||
|
|
||||||
# Handle args_schema properly - BaseTool expects a BaseModel subclass
|
|
||||||
args_schema = tool_schema.get("args_schema")
|
args_schema = tool_schema.get("args_schema")
|
||||||
|
|
||||||
# Only pass args_schema if it's provided
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"name": prefixed_name,
|
"name": prefixed_name,
|
||||||
"description": tool_schema.get(
|
"description": tool_schema.get(
|
||||||
@@ -57,11 +80,9 @@ class MCPNativeTool(BaseTool):
|
|||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
# Set instance attributes after super().__init__
|
|
||||||
self._mcp_client = mcp_client
|
self._mcp_client = mcp_client
|
||||||
self._original_tool_name = original_tool_name or tool_name
|
self._original_tool_name = original_tool_name or tool_name
|
||||||
self._server_name = server_name
|
self._server_name = server_name
|
||||||
# self._logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mcp_client(self) -> Any:
|
def mcp_client(self) -> Any:
|
||||||
@@ -81,25 +102,21 @@ class MCPNativeTool(BaseTool):
|
|||||||
def _run(self, **kwargs) -> str:
|
def _run(self, **kwargs) -> str:
|
||||||
"""Execute tool using the MCP client session.
|
"""Execute tool using the MCP client session.
|
||||||
|
|
||||||
|
Submits work to a persistent background event loop so that all MCP
|
||||||
|
transport context managers (which rely on anyio cancel scopes) stay
|
||||||
|
on the same loop and task throughout their lifecycle.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: Arguments to pass to the MCP tool.
|
**kwargs: Arguments to pass to the MCP tool.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Result from the MCP tool execution.
|
Result from the MCP tool execution.
|
||||||
"""
|
"""
|
||||||
|
loop = _get_mcp_event_loop()
|
||||||
|
timeout = self._mcp_client.connect_timeout + self._mcp_client.execution_timeout
|
||||||
try:
|
try:
|
||||||
try:
|
future = asyncio.run_coroutine_threadsafe(self._run_async(**kwargs), loop)
|
||||||
asyncio.get_running_loop()
|
return future.result(timeout=timeout)
|
||||||
|
|
||||||
import concurrent.futures
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
||||||
coro = self._run_async(**kwargs)
|
|
||||||
future = executor.submit(asyncio.run, coro)
|
|
||||||
return future.result()
|
|
||||||
except RuntimeError:
|
|
||||||
return asyncio.run(self._run_async(**kwargs))
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||||
@@ -114,18 +131,11 @@ class MCPNativeTool(BaseTool):
|
|||||||
Returns:
|
Returns:
|
||||||
Result from the MCP tool execution.
|
Result from the MCP tool execution.
|
||||||
"""
|
"""
|
||||||
# Note: Since we use asyncio.run() which creates a new event loop each time,
|
if not self._mcp_client.connected:
|
||||||
# Always reconnect on-demand because asyncio.run() creates new event loops per call
|
await self._mcp_client.connect()
|
||||||
# All MCP transport context managers (stdio, streamablehttp_client, sse_client)
|
|
||||||
# use anyio.create_task_group() which can't span different event loops
|
|
||||||
if self._mcp_client.connected:
|
|
||||||
await self._mcp_client.disconnect()
|
|
||||||
|
|
||||||
await self._mcp_client.connect()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await self._mcp_client.call_tool(self.original_tool_name, kwargs)
|
result = await self._mcp_client.call_tool(self.original_tool_name, kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_str = str(e).lower()
|
error_str = str(e).lower()
|
||||||
if (
|
if (
|
||||||
@@ -135,24 +145,15 @@ class MCPNativeTool(BaseTool):
|
|||||||
):
|
):
|
||||||
await self._mcp_client.disconnect()
|
await self._mcp_client.disconnect()
|
||||||
await self._mcp_client.connect()
|
await self._mcp_client.connect()
|
||||||
# Retry the call
|
|
||||||
result = await self._mcp_client.call_tool(
|
result = await self._mcp_client.call_tool(
|
||||||
self.original_tool_name, kwargs
|
self.original_tool_name, kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
|
||||||
# Always disconnect after tool call to ensure clean context manager lifecycle
|
|
||||||
# This prevents "exit cancel scope in different task" errors
|
|
||||||
# All transport context managers must be exited in the same event loop they were entered
|
|
||||||
await self._mcp_client.disconnect()
|
|
||||||
|
|
||||||
# Extract result content
|
|
||||||
if isinstance(result, str):
|
if isinstance(result, str):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Handle various result formats
|
|
||||||
if hasattr(result, "content") and result.content:
|
if hasattr(result, "content") and result.content:
|
||||||
if isinstance(result.content, list) and len(result.content) > 0:
|
if isinstance(result.content, list) and len(result.content) > 0:
|
||||||
content_item = result.content[0]
|
content_item = result.content[0]
|
||||||
|
|||||||
@@ -148,6 +148,8 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
|
|||||||
mock_client.connect = AsyncMock()
|
mock_client.connect = AsyncMock()
|
||||||
mock_client.disconnect = AsyncMock()
|
mock_client.disconnect = AsyncMock()
|
||||||
mock_client.call_tool = AsyncMock(return_value="test result")
|
mock_client.call_tool = AsyncMock(return_value="test result")
|
||||||
|
mock_client.connect_timeout = 30
|
||||||
|
mock_client.execution_timeout = 30
|
||||||
mock_client_class.return_value = mock_client
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
@@ -180,6 +182,8 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
|
|||||||
mock_client.connect = AsyncMock()
|
mock_client.connect = AsyncMock()
|
||||||
mock_client.disconnect = AsyncMock()
|
mock_client.disconnect = AsyncMock()
|
||||||
mock_client.call_tool = AsyncMock(return_value="test result")
|
mock_client.call_tool = AsyncMock(return_value="test result")
|
||||||
|
mock_client.connect_timeout = 30
|
||||||
|
mock_client.execution_timeout = 30
|
||||||
mock_client_class.return_value = mock_client
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
|
|||||||
Reference in New Issue
Block a user