diff --git a/lib/crewai/src/crewai/mcp/transports/sse.py b/lib/crewai/src/crewai/mcp/transports/sse.py index ce418c51f..c2184e7d0 100644 --- a/lib/crewai/src/crewai/mcp/transports/sse.py +++ b/lib/crewai/src/crewai/mcp/transports/sse.py @@ -66,7 +66,6 @@ class SSETransport(BaseTransport): self._transport_context = sse_client( self.url, headers=self.headers if self.headers else None, - terminate_on_close=True, ) read, write = await self._transport_context.__aenter__() diff --git a/lib/crewai/tests/mcp/test_agent_mcp_error_handling.py b/lib/crewai/tests/mcp/test_agent_mcp_error_handling.py new file mode 100644 index 000000000..04b973d00 --- /dev/null +++ b/lib/crewai/tests/mcp/test_agent_mcp_error_handling.py @@ -0,0 +1,78 @@ +import sys +from types import ModuleType +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from crewai.agent.core import Agent +from crewai.mcp import MCPServerSSE + + +class FakeSSEClientError: + def __init__(self, url, headers=None): + self.url = url + self.headers = headers + + async def __aenter__(self): + raise Exception("SSE connection failed") + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +@pytest.fixture +def mock_mcp_sse_error(): + fake_mcp = ModuleType("mcp") + fake_mcp_client = ModuleType("mcp.client") + fake_mcp_client_sse = ModuleType("mcp.client.sse") + + sys.modules["mcp"] = fake_mcp + sys.modules["mcp.client"] = fake_mcp_client + sys.modules["mcp.client.sse"] = fake_mcp_client_sse + + mock_sse_client = MagicMock(side_effect=FakeSSEClientError) + fake_mcp_client_sse.sse_client = mock_sse_client + + yield mock_sse_client + + del sys.modules["mcp.client.sse"] + del sys.modules["mcp.client"] + del sys.modules["mcp"] + + +def test_agent_get_native_mcp_tools_raises_runtime_error_not_unbound_local_error( + mock_mcp_sse_error, +): + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + ) + + mcp_config = MCPServerSSE( + url="https://example.com/sse", + headers={"Authorization": "Bearer token"}, + ) + + with pytest.raises(RuntimeError, match="Failed to get native MCP tools"): + agent._get_native_mcp_tools(mcp_config) + + +def test_agent_get_native_mcp_tools_error_message_contains_original_error( + mock_mcp_sse_error, +): + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + ) + + mcp_config = MCPServerSSE( + url="https://example.com/sse", + ) + + with pytest.raises(RuntimeError) as exc_info: + agent._get_native_mcp_tools(mcp_config) + + assert "Failed to get native MCP tools" in str(exc_info.value) + assert exc_info.value.__cause__ is not None diff --git a/lib/crewai/tests/mcp/transports/__init__.py b/lib/crewai/tests/mcp/transports/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/lib/crewai/tests/mcp/transports/__init__.py @@ -0,0 +1 @@ + diff --git a/lib/crewai/tests/mcp/transports/test_sse_transport.py b/lib/crewai/tests/mcp/transports/test_sse_transport.py new file mode 100644 index 000000000..57f373795 --- /dev/null +++ b/lib/crewai/tests/mcp/transports/test_sse_transport.py @@ -0,0 +1,112 @@ +import sys +from types import ModuleType +from unittest.mock import AsyncMock, MagicMock, call + +import pytest + +from crewai.mcp.transports.sse import SSETransport + + +class FakeSSEClient: + def __init__(self, url, headers=None): + self.url = url + self.headers = headers + self._read = AsyncMock() + self._write = AsyncMock() + + async def __aenter__(self): + return (self._read, self._write) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +@pytest.fixture +def mock_mcp_sse(): + fake_mcp = ModuleType("mcp") + fake_mcp_client = ModuleType("mcp.client") + fake_mcp_client_sse = ModuleType("mcp.client.sse") + + sys.modules["mcp"] = fake_mcp + sys.modules["mcp.client"] = fake_mcp_client + sys.modules["mcp.client.sse"] = fake_mcp_client_sse + + mock_sse_client = MagicMock(side_effect=FakeSSEClient) + fake_mcp_client_sse.sse_client = mock_sse_client + + yield mock_sse_client + + del sys.modules["mcp.client.sse"] + del sys.modules["mcp.client"] + del sys.modules["mcp"] + + +@pytest.mark.asyncio +async def test_sse_transport_connect_without_terminate_on_close(mock_mcp_sse): + transport = SSETransport( + url="https://example.com/sse", + headers={"Authorization": "Bearer token"}, + ) + + await transport.connect() + + mock_mcp_sse.assert_called_once_with( + "https://example.com/sse", + headers={"Authorization": "Bearer token"}, + ) + + call_kwargs = mock_mcp_sse.call_args[1] + assert "terminate_on_close" not in call_kwargs + + assert transport._connected is True + + +@pytest.mark.asyncio +async def test_sse_transport_connect_without_headers(mock_mcp_sse): + transport = SSETransport(url="https://example.com/sse") + + await transport.connect() + + mock_mcp_sse.assert_called_once_with( + "https://example.com/sse", + headers=None, + ) + + call_kwargs = mock_mcp_sse.call_args[1] + assert "terminate_on_close" not in call_kwargs + + +@pytest.mark.asyncio +async def test_sse_transport_connect_sets_streams(mock_mcp_sse): + transport = SSETransport(url="https://example.com/sse") + + await transport.connect() + + assert transport._read_stream is not None + assert transport._write_stream is not None + assert transport._connected is True + + +@pytest.mark.asyncio +async def test_sse_transport_context_manager(mock_mcp_sse): + async with SSETransport(url="https://example.com/sse") as transport: + assert transport._connected is True + + assert transport._connected is False + + +@pytest.mark.asyncio +async def test_sse_transport_connect_failure_raises_connection_error(mock_mcp_sse): + mock_sse_client_error = MagicMock( + side_effect=Exception("Connection failed") + ) + + fake_mcp_client_sse = sys.modules["mcp.client.sse"] + fake_mcp_client_sse.sse_client = mock_sse_client_error + + transport = SSETransport(url="https://example.com/sse") + + with pytest.raises(ConnectionError, match="Failed to connect to SSE MCP server"): + await transport.connect() + + assert transport._connected is False