mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Implements progress reporting and HTTP headers support for MCP tool integration to address issue #3797. Changes: - Add MCPToolProgressEvent to event system for real-time progress tracking - Extend MCPToolWrapper to support progress callbacks and event emission - Add mcp_progress_enabled flag to Agent for opt-in progress notifications - Add mcp_server_headers to Agent for middleware authentication/tracking - Thread progress and headers configuration through Agent._get_external_mcp_tools - Add comprehensive test coverage for progress and headers features - Update MCP DSL documentation with progress and headers examples Features: - Progress notifications emitted as MCPToolProgressEvent via event bus - Optional progress callback for custom progress handling - HTTP headers passthrough for authentication and middleware integration - Agent and task context included in progress events - Opt-in design ensures backward compatibility Tests: - Unit tests for MCPToolWrapper progress and headers functionality - Integration tests for Agent MCP configuration - Mock-based tests to avoid network dependencies Documentation: - Added Progress Notifications section with examples - Added Middleware Support with Headers section - Included complete examples for common use cases Fixes #3797 Co-Authored-By: João <joao@crewai.com>
351 lines
13 KiB
Python
351 lines
13 KiB
Python
"""Tests for MCPToolWrapper progress and headers support."""
|
|
|
|
import asyncio
import threading
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import MCPToolProgressEvent
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
|
|
|
|
|
@pytest.fixture
def mock_mcp_session():
    """Provide an ``AsyncMock`` standing in for an MCP ``ClientSession``.

    ``initialize`` and ``call_tool`` are set to their own ``AsyncMock``
    instances so tests can await them and inspect their calls.
    """
    fake_session = AsyncMock()
    for coroutine_name in ("initialize", "call_tool"):
        setattr(fake_session, coroutine_name, AsyncMock())
    return fake_session
|
|
|
|
|
|
@pytest.fixture
def mock_streamable_client(mock_mcp_session):
    """Provide a stand-in for ``streamablehttp_client``.

    The returned callable accepts any arguments and synchronously returns an
    async context manager. Its ``__aenter__`` yields a ``(read, write, close)``
    tuple of ``AsyncMock`` objects, so it can be used directly as
    ``async with mock_client(url) as (read, write, close): ...``.
    """
    # NOTE(review): ``mock_mcp_session`` is requested but not used here; it is
    # kept so the fixture's signature and pytest dependency stay unchanged.

    class _StreamContext:
        """Async context manager yielding a fresh read/write/close mock triple."""

        def __init__(self):
            self._streams = (AsyncMock(), AsyncMock(), AsyncMock())

        async def __aenter__(self):
            return self._streams

        async def __aexit__(self, *exc_info):
            # Returning None (falsy) lets exceptions raised in the body propagate.
            return None

    def mock_client(*args, **kwargs):
        # Deliberately a plain function, not ``async def``. ``async with``
        # needs the call itself to return the context manager. Calling an
        # ``async def`` returns a coroutine, which has no ``__aenter__``.
        return _StreamContext()

    return mock_client
|
|
|
|
|
|
@pytest.fixture
def mock_agent():
    """Provide a ``Mock`` agent exposing the ``id`` and ``role`` attributes tests read."""
    return Mock(id="test-agent-id", role="Test Agent")
|
|
|
|
|
|
@pytest.fixture
def mock_task():
    """Provide a ``Mock`` task with ``id``, ``description`` and a ``None`` ``name``."""
    fake_task = Mock(id="test-task-id", description="Test Task Description")
    # ``name`` cannot go through the Mock constructor because there it names
    # the mock itself. Assigning it afterwards makes ``task.name`` really None.
    fake_task.name = None
    return fake_task
|
|
|
|
|
|
class TestMCPToolWrapperProgress:
    """Test suite for MCP tool wrapper progress notifications."""

    # Upper bound, in seconds, that a test waits for the event bus to deliver
    # a progress event. Delivery may happen off the calling thread, so a
    # bounded wait replaces a fixed sleep.
    EVENT_TIMEOUT_SECONDS = 2.0

    @staticmethod
    def _make_wrapper(**overrides):
        """Return an MCPToolWrapper for the shared test server and tool.

        Keyword arguments override or extend the default constructor
        arguments (e.g. ``progress_callback``, ``agent``, ``task``).
        """
        kwargs = {
            "mcp_server_params": {"url": "https://example.com/mcp"},
            "tool_name": "test_tool",
            "tool_schema": {"description": "Test tool"},
            "server_name": "test_server",
        }
        kwargs.update(overrides)
        return MCPToolWrapper(**kwargs)

    @classmethod
    def _emit_and_capture(cls, wrapper, progress, total, message):
        """Emit one progress event through *wrapper* and return the events received.

        The handler is registered inside ``crewai_event_bus.scoped_handlers()``.
        Handlers registered outside the test are therefore restored on exit
        instead of being discarded.
        """
        received = []
        delivered = threading.Event()

        with crewai_event_bus.scoped_handlers():

            @crewai_event_bus.on(MCPToolProgressEvent)
            def _record(source, event):
                received.append(event)
                delivered.set()

            wrapper._emit_progress_event(progress, total, message)
            # Wait inside the scope so the handler is still registered when
            # the bus dispatches. On timeout the caller's length check fails.
            delivered.wait(cls.EVENT_TIMEOUT_SECONDS)

        return received

    def test_wrapper_initialization_with_progress_callback(self):
        """Test that MCPToolWrapper can be initialized with progress callback."""
        callback = Mock()

        wrapper = self._make_wrapper(progress_callback=callback)

        assert wrapper._progress_callback == callback
        assert wrapper.name == "test_server_test_tool"

    def test_wrapper_initialization_without_progress_callback(self):
        """Test that MCPToolWrapper works without progress callback."""
        wrapper = self._make_wrapper()

        assert wrapper._progress_callback is None

    def test_wrapper_initialization_with_agent_and_task(self, mock_agent, mock_task):
        """Test that MCPToolWrapper can be initialized with agent and task context."""
        wrapper = self._make_wrapper(agent=mock_agent, task=mock_task)

        assert wrapper._agent == mock_agent
        assert wrapper._task == mock_task

    @pytest.mark.asyncio
    async def test_progress_handler_called_during_execution(self, mock_agent, mock_task):
        """Test that executing the tool installs a progress handler on the MCP session."""
        # NOTE(review): this only checks that a handler gets installed. No
        # server progress notification is simulated, so the callback itself
        # is not exercised.
        progress_callback = Mock()

        wrapper = self._make_wrapper(
            progress_callback=progress_callback,
            agent=mock_agent,
            task=mock_task,
        )

        mock_result = Mock()
        mock_result.content = [Mock(text="Test result")]

        with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
             patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:

            mock_session = AsyncMock()
            mock_session.initialize = AsyncMock()
            mock_session.call_tool = AsyncMock(return_value=mock_result)
            # Start from None so the final assertion proves the wrapper set it.
            mock_session.on_progress = None

            mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
            mock_session_class.return_value.__aexit__ = AsyncMock()

            mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
            mock_client.return_value.__aexit__ = AsyncMock()

            result = await wrapper._execute_tool(test_arg="test_value")

            assert result == "Test result"
            assert mock_session.on_progress is not None

    @pytest.mark.asyncio
    async def test_progress_event_emission(self, mock_agent, mock_task):
        """Test that MCPToolProgressEvent is emitted when progress is reported."""
        wrapper = self._make_wrapper(
            progress_callback=Mock(),
            agent=mock_agent,
            task=mock_task,
        )

        events = self._emit_and_capture(wrapper, 50.0, 100.0, "Processing...")

        assert len(events) == 1
        event = events[0]
        assert event.tool_name == "test_tool"
        assert event.server_name == "test_server"
        assert event.progress == 50.0
        assert event.total == 100.0
        assert event.message == "Processing..."
        assert event.agent_id == "test-agent-id"
        assert event.agent_role == "Test Agent"
        assert event.task_id == "test-task-id"
        assert event.task_name == "Test Task Description"

    def test_progress_event_without_agent_context(self):
        """Test that progress events are emitted without agent context."""
        wrapper = self._make_wrapper(progress_callback=Mock())

        events = self._emit_and_capture(wrapper, 25.0, None, "Starting...")

        assert len(events) == 1
        event = events[0]
        assert event.tool_name == "test_tool"
        assert event.server_name == "test_server"
        assert event.progress == 25.0
        assert event.message == "Starting..."

    def test_progress_event_without_task_context(self, mock_agent):
        """Test that progress events are emitted without task context."""
        wrapper = self._make_wrapper(progress_callback=Mock(), agent=mock_agent)

        events = self._emit_and_capture(wrapper, 75.0, 100.0, None)

        assert len(events) == 1
        event = events[0]
        assert event.tool_name == "test_tool"
        assert event.progress == 75.0
        assert event.total == 100.0
        assert event.agent_id == "test-agent-id"
        assert event.agent_role == "Test Agent"
|
|
|
|
|
|
class TestMCPToolWrapperHeaders:
    """Check that HTTP headers in ``mcp_server_params`` are stored and forwarded to the transport."""

    @staticmethod
    def _build_wrapper(server_params):
        """Return an MCPToolWrapper for the shared test tool using *server_params*."""
        return MCPToolWrapper(
            mcp_server_params=server_params,
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
        )

    @staticmethod
    def _wire_transport(client_mock, session_class_mock):
        """Configure the patched transport and session mocks; return the fake session.

        The fake session's ``call_tool`` resolves to a result whose first
        content item has ``text == "Test result"``.
        """
        tool_output = Mock()
        tool_output.content = [Mock(text="Test result")]

        fake_session = AsyncMock()
        fake_session.initialize = AsyncMock()
        fake_session.call_tool = AsyncMock(return_value=tool_output)

        session_context = session_class_mock.return_value
        session_context.__aenter__ = AsyncMock(return_value=fake_session)
        session_context.__aexit__ = AsyncMock()

        transport_context = client_mock.return_value
        transport_context.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
        transport_context.__aexit__ = AsyncMock()

        return fake_session

    def test_wrapper_initialization_with_headers(self):
        """Test that MCPToolWrapper accepts headers in server params."""
        configured_headers = {"Authorization": "Bearer token123", "X-Client-ID": "test-client"}

        wrapper = self._build_wrapper(
            {"url": "https://example.com/mcp", "headers": configured_headers}
        )

        assert wrapper.mcp_server_params["headers"] == configured_headers

    @pytest.mark.asyncio
    async def test_headers_passed_to_transport(self):
        """Test that headers are passed to streamablehttp_client."""
        configured_headers = {"Authorization": "Bearer token123"}
        wrapper = self._build_wrapper(
            {"url": "https://example.com/mcp", "headers": configured_headers}
        )

        transport_patch = patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client")
        session_patch = patch("crewai.tools.mcp_tool_wrapper.ClientSession")
        with transport_patch as client_mock, session_patch as session_class_mock:
            self._wire_transport(client_mock, session_class_mock)

            await wrapper._execute_tool(test_arg="test_value")

            client_mock.assert_called_once()
            transport_call = client_mock.call_args
            assert transport_call.args[0] == "https://example.com/mcp"
            assert "headers" in transport_call.kwargs
            assert transport_call.kwargs["headers"] == configured_headers

    @pytest.mark.asyncio
    async def test_no_headers_when_not_configured(self):
        """Test that headers are not passed when not configured."""
        wrapper = self._build_wrapper({"url": "https://example.com/mcp"})

        transport_patch = patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client")
        session_patch = patch("crewai.tools.mcp_tool_wrapper.ClientSession")
        with transport_patch as client_mock, session_patch as session_class_mock:
            self._wire_transport(client_mock, session_class_mock)

            await wrapper._execute_tool(test_arg="test_value")

            client_mock.assert_called_once()
            # An absent key and an explicit ``headers=None`` are both accepted.
            assert client_mock.call_args.kwargs.get("headers") is None
|
|
|
|
|
|
class TestMCPToolWrapperIntegration:
    """Integration tests for MCP tool wrapper with progress and headers."""

    @pytest.mark.asyncio
    async def test_full_execution_with_progress_and_headers(self, mock_agent, mock_task):
        """Test complete execution flow with both progress and headers configured.

        Checks the tool result, that the configured headers reach the
        transport, and that a progress handler is installed on the session.
        """
        progress_calls = []

        def progress_callback(progress, total, message):
            progress_calls.append((progress, total, message))

        # NOTE(review): the mocked server never sends progress notifications,
        # so ``progress_calls`` is not asserted on. Simulating one would
        # require knowing the signature of the handler the wrapper installs.

        headers = {"Authorization": "Bearer test-token"}

        wrapper = MCPToolWrapper(
            mcp_server_params={
                "url": "https://example.com/mcp",
                "headers": headers,
            },
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
            progress_callback=progress_callback,
            agent=mock_agent,
            task=mock_task,
        )

        mock_result = Mock()
        mock_result.content = [Mock(text="Test result")]

        with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
             patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:

            mock_session = AsyncMock()
            mock_session.initialize = AsyncMock()
            mock_session.call_tool = AsyncMock(return_value=mock_result)
            # Start from None so the final assertion proves the wrapper set it.
            mock_session.on_progress = None

            mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
            mock_session_class.return_value.__aexit__ = AsyncMock()

            mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
            mock_client.return_value.__aexit__ = AsyncMock()

            result = await wrapper._execute_tool(test_arg="test_value")

            assert result == "Test result"

            # Assert the call first. ``call_args`` is None when the mock was
            # never called, which would otherwise surface as a TypeError.
            mock_client.assert_called_once()
            assert mock_client.call_args.kwargs["headers"] == headers

            assert mock_session.on_progress is not None
|