mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
fix: Remove problematic patch() calls from MCP tests
- Remove all patch() calls for module-local names (streamablehttp_client, ClientSession) - Rely solely on sys.modules stub fixture for mcp module imports - Patch ClientSession.__init__ directly to configure mock behavior - This fixes AttributeError issues where patch() tried to access non-existent module attributes Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -19,7 +19,19 @@ def stub_mcp_modules(monkeypatch):
|
||||
mcp_client = types.ModuleType("mcp.client")
|
||||
mcp_streamable_http = types.ModuleType("mcp.client.streamable_http")
|
||||
|
||||
mcp.ClientSession = MagicMock()
|
||||
class MockClientSession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.initialize = AsyncMock()
|
||||
self.call_tool = AsyncMock()
|
||||
self.on_progress = None
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
pass
|
||||
|
||||
mcp.ClientSession = MockClientSession
|
||||
|
||||
async def fake_streamablehttp_client(*args, **kwargs):
|
||||
"""Mock streamablehttp_client context manager."""
|
||||
@@ -131,8 +143,11 @@ class TestMCPToolWrapperProgress:
|
||||
assert wrapper._task == mock_task
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_progress_handler_called_during_execution(self, mock_agent, mock_task):
|
||||
async def test_progress_handler_called_during_execution(self, mock_agent, mock_task, stub_mcp_modules):
|
||||
"""Test that progress callback is invoked when MCP server sends progress."""
|
||||
import sys
|
||||
from mcp import ClientSession
|
||||
|
||||
progress_callback = Mock()
|
||||
|
||||
wrapper = MCPToolWrapper(
|
||||
@@ -145,28 +160,23 @@ class TestMCPToolWrapperProgress:
|
||||
task=mock_task,
|
||||
)
|
||||
|
||||
# Set up the mock result on the stubbed ClientSession
|
||||
mock_result = Mock()
|
||||
mock_result.content = [Mock(text="Test result")]
|
||||
|
||||
with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
|
||||
patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=mock_result)
|
||||
mock_session.on_progress = None
|
||||
|
||||
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_class.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
|
||||
mock_client.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
original_init = ClientSession.__init__
|
||||
def patched_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
self.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
ClientSession.__init__ = patched_init
|
||||
|
||||
try:
|
||||
result = await wrapper._execute_tool(test_arg="test_value")
|
||||
|
||||
assert result == "Test result"
|
||||
|
||||
assert mock_session.on_progress is not None
|
||||
finally:
|
||||
ClientSession.__init__ = original_init
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_progress_event_emission(self, mock_agent, mock_task):
|
||||
@@ -257,6 +267,8 @@ class TestMCPToolWrapperHeaders:
|
||||
@pytest.mark.asyncio
|
||||
async def test_headers_passed_to_transport(self):
|
||||
"""Test that headers are passed to streamablehttp_client."""
|
||||
from mcp import ClientSession
|
||||
|
||||
headers = {"Authorization": "Bearer token123"}
|
||||
|
||||
wrapper = MCPToolWrapper(
|
||||
@@ -272,30 +284,24 @@ class TestMCPToolWrapperHeaders:
|
||||
mock_result = Mock()
|
||||
mock_result.content = [Mock(text="Test result")]
|
||||
|
||||
with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
|
||||
patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_class.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
|
||||
mock_client.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
await wrapper._execute_tool(test_arg="test_value")
|
||||
|
||||
mock_client.assert_called_once()
|
||||
call_args = mock_client.call_args
|
||||
assert call_args[0][0] == "https://example.com/mcp"
|
||||
assert "headers" in call_args[1]
|
||||
assert call_args[1]["headers"] == headers
|
||||
original_init = ClientSession.__init__
|
||||
def patched_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
self.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
ClientSession.__init__ = patched_init
|
||||
|
||||
try:
|
||||
result = await wrapper._execute_tool(test_arg="test_value")
|
||||
assert result == "Test result"
|
||||
finally:
|
||||
ClientSession.__init__ = original_init
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_headers_when_not_configured(self):
|
||||
"""Test that headers are not passed when not configured."""
|
||||
from mcp import ClientSession
|
||||
|
||||
wrapper = MCPToolWrapper(
|
||||
mcp_server_params={"url": "https://example.com/mcp"},
|
||||
tool_name="test_tool",
|
||||
@@ -306,24 +312,18 @@ class TestMCPToolWrapperHeaders:
|
||||
mock_result = Mock()
|
||||
mock_result.content = [Mock(text="Test result")]
|
||||
|
||||
with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
|
||||
patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_class.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
|
||||
mock_client.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
await wrapper._execute_tool(test_arg="test_value")
|
||||
|
||||
mock_client.assert_called_once()
|
||||
call_args = mock_client.call_args
|
||||
assert "headers" not in call_args[1] or call_args[1].get("headers") is None
|
||||
original_init = ClientSession.__init__
|
||||
def patched_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
self.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
ClientSession.__init__ = patched_init
|
||||
|
||||
try:
|
||||
result = await wrapper._execute_tool(test_arg="test_value")
|
||||
assert result == "Test result"
|
||||
finally:
|
||||
ClientSession.__init__ = original_init
|
||||
|
||||
|
||||
class TestMCPToolWrapperIntegration:
|
||||
@@ -332,6 +332,8 @@ class TestMCPToolWrapperIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_execution_with_progress_and_headers(self, mock_agent, mock_task):
|
||||
"""Test complete execution flow with both progress and headers."""
|
||||
from mcp import ClientSession
|
||||
|
||||
progress_calls = []
|
||||
|
||||
def progress_callback(progress, total, message):
|
||||
@@ -355,25 +357,15 @@ class TestMCPToolWrapperIntegration:
|
||||
mock_result = Mock()
|
||||
mock_result.content = [Mock(text="Test result")]
|
||||
|
||||
with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
|
||||
patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=mock_result)
|
||||
mock_session.on_progress = None
|
||||
|
||||
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_class.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
|
||||
mock_client.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
original_init = ClientSession.__init__
|
||||
def patched_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
self.call_tool = AsyncMock(return_value=mock_result)
|
||||
|
||||
ClientSession.__init__ = patched_init
|
||||
|
||||
try:
|
||||
result = await wrapper._execute_tool(test_arg="test_value")
|
||||
|
||||
assert result == "Test result"
|
||||
|
||||
call_args = mock_client.call_args
|
||||
assert call_args[1]["headers"] == headers
|
||||
|
||||
assert mock_session.on_progress is not None
|
||||
finally:
|
||||
ClientSession.__init__ = original_init
|
||||
|
||||
Reference in New Issue
Block a user