mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-02-28 00:38:13 +00:00
Compare commits
2 Commits
gl/ci/pr-c
...
lg-mcp-eve
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a96f114a36 | ||
|
|
ca220cdc23 |
@@ -95,7 +95,7 @@ class MCPClient:
|
||||
self.discovery_timeout = discovery_timeout
|
||||
self.max_retries = max_retries
|
||||
self.cache_tools_list = cache_tools_list
|
||||
# self._logger = logger or logging.getLogger(__name__)
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
self._session: Any = None
|
||||
self._initialized = False
|
||||
self._exit_stack = AsyncExitStack()
|
||||
@@ -358,10 +358,12 @@ class MCPClient:
|
||||
"""Cleanup resources when an error occurs during connection."""
|
||||
try:
|
||||
await self._exit_stack.aclose()
|
||||
|
||||
except Exception as e:
|
||||
# Best effort cleanup - ignore all other errors
|
||||
raise RuntimeError(f"Error during MCP client cleanup: {e}") from e
|
||||
except (RuntimeError, BaseExceptionGroup) as e:
|
||||
error_msg = str(e).lower()
|
||||
if "cancel scope" not in error_msg and "task" not in error_msg:
|
||||
raise RuntimeError(f"Error during MCP client cleanup: {e}") from e
|
||||
except Exception:
|
||||
self._logger.debug("Suppressed error during MCP cleanup", exc_info=True)
|
||||
finally:
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
@@ -374,8 +376,12 @@ class MCPClient:
|
||||
|
||||
try:
|
||||
await self._exit_stack.aclose()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error during MCP client disconnect: {e}") from e
|
||||
except (RuntimeError, BaseExceptionGroup) as e:
|
||||
error_msg = str(e).lower()
|
||||
if "cancel scope" not in error_msg and "task" not in error_msg:
|
||||
raise RuntimeError(f"Error during MCP client disconnect: {e}") from e
|
||||
except Exception:
|
||||
self._logger.debug("Suppressed error during MCP disconnect", exc_info=True)
|
||||
finally:
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
|
||||
@@ -87,7 +87,12 @@ class MCPToolResolver:
|
||||
return all_tools
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Disconnect all MCP client connections."""
|
||||
"""Disconnect all MCP client connections.
|
||||
|
||||
Submits the disconnect coroutines to the persistent MCP event loop
|
||||
so that transport context managers are exited on the same loop they
|
||||
were entered on.
|
||||
"""
|
||||
if not self._clients:
|
||||
return
|
||||
|
||||
@@ -97,7 +102,11 @@ class MCPToolResolver:
|
||||
await client.disconnect()
|
||||
|
||||
try:
|
||||
asyncio.run(_disconnect_all())
|
||||
from crewai.tools.mcp_native_tool import _get_mcp_event_loop
|
||||
|
||||
loop = _get_mcp_event_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(_disconnect_all(), loop)
|
||||
future.result(timeout=30)
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during MCP client cleanup: {e}")
|
||||
finally:
|
||||
@@ -330,30 +339,27 @@ class MCPToolResolver:
|
||||
) from e
|
||||
|
||||
try:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
from crewai.tools.mcp_native_tool import _get_mcp_event_loop
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run, _setup_client_and_list_tools()
|
||||
)
|
||||
tools_list = future.result()
|
||||
except RuntimeError:
|
||||
try:
|
||||
tools_list = asyncio.run(_setup_client_and_list_tools())
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e).lower()
|
||||
if "cancel scope" in error_msg or "task" in error_msg:
|
||||
raise ConnectionError(
|
||||
"MCP connection failed due to event loop cleanup issues. "
|
||||
"This may be due to authentication errors or server unavailability."
|
||||
) from e
|
||||
except asyncio.CancelledError as e:
|
||||
loop = _get_mcp_event_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
_setup_client_and_list_tools(), loop
|
||||
)
|
||||
try:
|
||||
tools_list = future.result(timeout=60)
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e).lower()
|
||||
if "cancel scope" in error_msg or "task" in error_msg:
|
||||
raise ConnectionError(
|
||||
"MCP connection was cancelled. This may indicate an authentication "
|
||||
"error or server unavailability."
|
||||
"MCP connection failed due to event loop cleanup issues. "
|
||||
"This may be due to authentication errors or server unavailability."
|
||||
) from e
|
||||
raise
|
||||
except asyncio.CancelledError as e:
|
||||
raise ConnectionError(
|
||||
"MCP connection was cancelled. This may indicate an authentication "
|
||||
"error or server unavailability."
|
||||
) from e
|
||||
|
||||
if mcp_config.tool_filter:
|
||||
filtered_tools = []
|
||||
@@ -410,7 +416,13 @@ class MCPToolResolver:
|
||||
return cast(list[BaseTool], tools), client
|
||||
except Exception as e:
|
||||
if client.connected:
|
||||
asyncio.run(client.disconnect())
|
||||
try:
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
client.disconnect(), loop
|
||||
)
|
||||
fut.result(timeout=10)
|
||||
except Exception:
|
||||
self._logger.log("debug", "Suppressed error during MCP client disconnect on cleanup")
|
||||
|
||||
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||
|
||||
|
||||
@@ -5,11 +5,37 @@ for better performance and connection management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
_mcp_loop: asyncio.AbstractEventLoop | None = None
|
||||
_mcp_loop_thread: threading.Thread | None = None
|
||||
_mcp_loop_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_mcp_event_loop() -> asyncio.AbstractEventLoop:
|
||||
"""Return (and lazily start) a persistent event loop for MCP operations.
|
||||
|
||||
All MCP SDK transports use anyio task groups whose cancel scopes must be
|
||||
entered and exited on the same event loop / task. By funnelling every
|
||||
MCP call through one long-lived loop we avoid the "exit cancel scope in
|
||||
a different task" errors that happen when asyncio.run() creates a
|
||||
throwaway loop per call.
|
||||
"""
|
||||
global _mcp_loop, _mcp_loop_thread
|
||||
with _mcp_loop_lock:
|
||||
if _mcp_loop is None or _mcp_loop.is_closed():
|
||||
_mcp_loop = asyncio.new_event_loop()
|
||||
_mcp_loop_thread = threading.Thread(
|
||||
target=_mcp_loop.run_forever, daemon=True, name="mcp-event-loop"
|
||||
)
|
||||
_mcp_loop_thread.start()
|
||||
return _mcp_loop
|
||||
|
||||
|
||||
class MCPNativeTool(BaseTool):
|
||||
"""Native MCP tool that reuses client sessions.
|
||||
|
||||
@@ -38,13 +64,10 @@ class MCPNativeTool(BaseTool):
|
||||
server_name: Name of the MCP server for prefixing.
|
||||
original_tool_name: Original name of the tool on the MCP server.
|
||||
"""
|
||||
# Create tool name with server prefix to avoid conflicts
|
||||
prefixed_name = f"{server_name}_{tool_name}"
|
||||
|
||||
# Handle args_schema properly - BaseTool expects a BaseModel subclass
|
||||
args_schema = tool_schema.get("args_schema")
|
||||
|
||||
# Only pass args_schema if it's provided
|
||||
kwargs = {
|
||||
"name": prefixed_name,
|
||||
"description": tool_schema.get(
|
||||
@@ -57,11 +80,9 @@ class MCPNativeTool(BaseTool):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set instance attributes after super().__init__
|
||||
self._mcp_client = mcp_client
|
||||
self._original_tool_name = original_tool_name or tool_name
|
||||
self._server_name = server_name
|
||||
# self._logger = logging.getLogger(__name__)
|
||||
|
||||
@property
|
||||
def mcp_client(self) -> Any:
|
||||
@@ -81,25 +102,21 @@ class MCPNativeTool(BaseTool):
|
||||
def _run(self, **kwargs) -> str:
|
||||
"""Execute tool using the MCP client session.
|
||||
|
||||
Submits work to a persistent background event loop so that all MCP
|
||||
transport context managers (which rely on anyio cancel scopes) stay
|
||||
on the same loop and task throughout their lifecycle.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments to pass to the MCP tool.
|
||||
|
||||
Returns:
|
||||
Result from the MCP tool execution.
|
||||
"""
|
||||
loop = _get_mcp_event_loop()
|
||||
timeout = self._mcp_client.connect_timeout + self._mcp_client.execution_timeout
|
||||
try:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
coro = self._run_async(**kwargs)
|
||||
future = executor.submit(asyncio.run, coro)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(self._run_async(**kwargs))
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(self._run_async(**kwargs), loop)
|
||||
return future.result(timeout=timeout)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||
@@ -114,18 +131,11 @@ class MCPNativeTool(BaseTool):
|
||||
Returns:
|
||||
Result from the MCP tool execution.
|
||||
"""
|
||||
# Note: Since we use asyncio.run() which creates a new event loop each time,
|
||||
# Always reconnect on-demand because asyncio.run() creates new event loops per call
|
||||
# All MCP transport context managers (stdio, streamablehttp_client, sse_client)
|
||||
# use anyio.create_task_group() which can't span different event loops
|
||||
if self._mcp_client.connected:
|
||||
await self._mcp_client.disconnect()
|
||||
|
||||
await self._mcp_client.connect()
|
||||
if not self._mcp_client.connected:
|
||||
await self._mcp_client.connect()
|
||||
|
||||
try:
|
||||
result = await self._mcp_client.call_tool(self.original_tool_name, kwargs)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
@@ -135,24 +145,15 @@ class MCPNativeTool(BaseTool):
|
||||
):
|
||||
await self._mcp_client.disconnect()
|
||||
await self._mcp_client.connect()
|
||||
# Retry the call
|
||||
result = await self._mcp_client.call_tool(
|
||||
self.original_tool_name, kwargs
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Always disconnect after tool call to ensure clean context manager lifecycle
|
||||
# This prevents "exit cancel scope in different task" errors
|
||||
# All transport context managers must be exited in the same event loop they were entered
|
||||
await self._mcp_client.disconnect()
|
||||
|
||||
# Extract result content
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
|
||||
# Handle various result formats
|
||||
if hasattr(result, "content") and result.content:
|
||||
if isinstance(result.content, list) and len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
|
||||
@@ -148,6 +148,8 @@ def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value="test result")
|
||||
mock_client.connect_timeout = 30
|
||||
mock_client.execution_timeout = 30
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
agent = Agent(
|
||||
@@ -180,6 +182,8 @@ async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
|
||||
mock_client.connect = AsyncMock()
|
||||
mock_client.disconnect = AsyncMock()
|
||||
mock_client.call_tool = AsyncMock(return_value="test result")
|
||||
mock_client.connect_timeout = 30
|
||||
mock_client.execution_timeout = 30
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
agent = Agent(
|
||||
|
||||
Reference in New Issue
Block a user