mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Fix SSE transport invalid parameter issue
Remove terminate_on_close parameter from sse_client call in SSETransport. This parameter is only valid for streamable HTTP transport, not for SSE transport. Fixes #3938 Changes: - Remove terminate_on_close=True from sse_client() call in SSETransport.connect() - Add comprehensive unit tests for SSETransport that mock sse_client and verify correct parameters - Add error-path tests to ensure Agent._get_native_mcp_tools raises proper RuntimeError instead of UnboundLocalError when connection fails Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -66,7 +66,6 @@ class SSETransport(BaseTransport):
|
||||
self._transport_context = sse_client(
|
||||
self.url,
|
||||
headers=self.headers if self.headers else None,
|
||||
terminate_on_close=True,
|
||||
)
|
||||
|
||||
read, write = await self._transport_context.__aenter__()
|
||||
|
||||
78
lib/crewai/tests/mcp/test_agent_mcp_error_handling.py
Normal file
78
lib/crewai/tests/mcp/test_agent_mcp_error_handling.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.mcp import MCPServerSSE
|
||||
|
||||
|
||||
class FakeSSEClientError:
|
||||
def __init__(self, url, headers=None):
|
||||
self.url = url
|
||||
self.headers = headers
|
||||
|
||||
async def __aenter__(self):
|
||||
raise Exception("SSE connection failed")
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_sse_error():
|
||||
fake_mcp = ModuleType("mcp")
|
||||
fake_mcp_client = ModuleType("mcp.client")
|
||||
fake_mcp_client_sse = ModuleType("mcp.client.sse")
|
||||
|
||||
sys.modules["mcp"] = fake_mcp
|
||||
sys.modules["mcp.client"] = fake_mcp_client
|
||||
sys.modules["mcp.client.sse"] = fake_mcp_client_sse
|
||||
|
||||
mock_sse_client = MagicMock(side_effect=FakeSSEClientError)
|
||||
fake_mcp_client_sse.sse_client = mock_sse_client
|
||||
|
||||
yield mock_sse_client
|
||||
|
||||
del sys.modules["mcp.client.sse"]
|
||||
del sys.modules["mcp.client"]
|
||||
del sys.modules["mcp"]
|
||||
|
||||
|
||||
def test_agent_get_native_mcp_tools_raises_runtime_error_not_unbound_local_error(
|
||||
mock_mcp_sse_error,
|
||||
):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
)
|
||||
|
||||
mcp_config = MCPServerSSE(
|
||||
url="https://example.com/sse",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to get native MCP tools"):
|
||||
agent._get_native_mcp_tools(mcp_config)
|
||||
|
||||
|
||||
def test_agent_get_native_mcp_tools_error_message_contains_original_error(
|
||||
mock_mcp_sse_error,
|
||||
):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
)
|
||||
|
||||
mcp_config = MCPServerSSE(
|
||||
url="https://example.com/sse",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
agent._get_native_mcp_tools(mcp_config)
|
||||
|
||||
assert "Failed to get native MCP tools" in str(exc_info.value)
|
||||
assert exc_info.value.__cause__ is not None
|
||||
1
lib/crewai/tests/mcp/transports/__init__.py
Normal file
1
lib/crewai/tests/mcp/transports/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
112
lib/crewai/tests/mcp/transports/test_sse_transport.py
Normal file
112
lib/crewai/tests/mcp/transports/test_sse_transport.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
|
||||
|
||||
class FakeSSEClient:
|
||||
def __init__(self, url, headers=None):
|
||||
self.url = url
|
||||
self.headers = headers
|
||||
self._read = AsyncMock()
|
||||
self._write = AsyncMock()
|
||||
|
||||
async def __aenter__(self):
|
||||
return (self._read, self._write)
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_sse():
|
||||
fake_mcp = ModuleType("mcp")
|
||||
fake_mcp_client = ModuleType("mcp.client")
|
||||
fake_mcp_client_sse = ModuleType("mcp.client.sse")
|
||||
|
||||
sys.modules["mcp"] = fake_mcp
|
||||
sys.modules["mcp.client"] = fake_mcp_client
|
||||
sys.modules["mcp.client.sse"] = fake_mcp_client_sse
|
||||
|
||||
mock_sse_client = MagicMock(side_effect=FakeSSEClient)
|
||||
fake_mcp_client_sse.sse_client = mock_sse_client
|
||||
|
||||
yield mock_sse_client
|
||||
|
||||
del sys.modules["mcp.client.sse"]
|
||||
del sys.modules["mcp.client"]
|
||||
del sys.modules["mcp"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_transport_connect_without_terminate_on_close(mock_mcp_sse):
|
||||
transport = SSETransport(
|
||||
url="https://example.com/sse",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
await transport.connect()
|
||||
|
||||
mock_mcp_sse.assert_called_once_with(
|
||||
"https://example.com/sse",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
call_kwargs = mock_mcp_sse.call_args[1]
|
||||
assert "terminate_on_close" not in call_kwargs
|
||||
|
||||
assert transport._connected is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_transport_connect_without_headers(mock_mcp_sse):
|
||||
transport = SSETransport(url="https://example.com/sse")
|
||||
|
||||
await transport.connect()
|
||||
|
||||
mock_mcp_sse.assert_called_once_with(
|
||||
"https://example.com/sse",
|
||||
headers=None,
|
||||
)
|
||||
|
||||
call_kwargs = mock_mcp_sse.call_args[1]
|
||||
assert "terminate_on_close" not in call_kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_transport_connect_sets_streams(mock_mcp_sse):
|
||||
transport = SSETransport(url="https://example.com/sse")
|
||||
|
||||
await transport.connect()
|
||||
|
||||
assert transport._read_stream is not None
|
||||
assert transport._write_stream is not None
|
||||
assert transport._connected is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_transport_context_manager(mock_mcp_sse):
|
||||
async with SSETransport(url="https://example.com/sse") as transport:
|
||||
assert transport._connected is True
|
||||
|
||||
assert transport._connected is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_transport_connect_failure_raises_connection_error(mock_mcp_sse):
|
||||
mock_sse_client_error = MagicMock(
|
||||
side_effect=Exception("Connection failed")
|
||||
)
|
||||
|
||||
fake_mcp_client_sse = sys.modules["mcp.client.sse"]
|
||||
fake_mcp_client_sse.sse_client = mock_sse_client_error
|
||||
|
||||
transport = SSETransport(url="https://example.com/sse")
|
||||
|
||||
with pytest.raises(ConnectionError, match="Failed to connect to SSE MCP server"):
|
||||
await transport.connect()
|
||||
|
||||
assert transport._connected is False
|
||||
Reference in New Issue
Block a user