feat: Add MCP progress notifications and middleware headers support

Implements progress reporting and HTTP headers support for MCP tool integration
to address issue #3797.

Changes:
- Add MCPToolProgressEvent to event system for real-time progress tracking
- Extend MCPToolWrapper to support progress callbacks and event emission
- Add mcp_progress_enabled flag to Agent for opt-in progress notifications
- Add mcp_server_headers to Agent for middleware authentication/tracking
- Thread progress and headers configuration through Agent._get_external_mcp_tools
- Add comprehensive test coverage for progress and headers features
- Update MCP DSL documentation with progress and headers examples

Features:
- Progress notifications emitted as MCPToolProgressEvent via event bus
- Optional progress callback for custom progress handling
- HTTP headers passthrough for authentication and middleware integration
- Agent and task context included in progress events
- Opt-in design ensures backward compatibility

Tests:
- Unit tests for MCPToolWrapper progress and headers functionality
- Integration tests for Agent MCP configuration
- Mock-based tests to avoid network dependencies

Documentation:
- Added Progress Notifications section with examples
- Added Middleware Support with Headers section
- Included complete examples for common use cases

Fixes #3797

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-10-26 09:45:50 +00:00
parent 494ed7e671
commit 873d501401
8 changed files with 996 additions and 3 deletions

View File

@@ -0,0 +1,324 @@
"""Tests for Agent MCP progress and headers configuration."""
from unittest.mock import Mock, patch
import pytest
from crewai.agent import Agent
class TestAgentMCPProgressConfiguration:
"""Test suite for Agent MCP progress configuration."""
def test_agent_initialization_with_mcp_progress_enabled(self):
"""Test that Agent can be initialized with mcp_progress_enabled."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcps=["https://example.com/mcp"],
mcp_progress_enabled=True,
)
assert agent.mcp_progress_enabled is True
def test_agent_initialization_with_mcp_progress_disabled(self):
"""Test that Agent defaults to mcp_progress_enabled=False."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcps=["https://example.com/mcp"],
)
assert agent.mcp_progress_enabled is False
def test_agent_initialization_with_mcp_server_headers(self):
"""Test that Agent can be initialized with mcp_server_headers."""
headers = {"Authorization": "Bearer token123", "X-Client-ID": "test-client"}
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcps=["https://example.com/mcp"],
mcp_server_headers=headers,
)
assert agent.mcp_server_headers == headers
def test_agent_initialization_without_mcp_server_headers(self):
"""Test that Agent defaults to None for mcp_server_headers."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcps=["https://example.com/mcp"],
)
assert agent.mcp_server_headers is None
def test_agent_with_both_progress_and_headers(self):
"""Test that Agent can be initialized with both progress and headers."""
headers = {"Authorization": "Bearer token123"}
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcps=["https://example.com/mcp"],
mcp_progress_enabled=True,
mcp_server_headers=headers,
)
assert agent.mcp_progress_enabled is True
assert agent.mcp_server_headers == headers
class TestAgentMCPToolCreation:
"""Test suite for Agent MCP tool creation with progress and headers."""
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
def test_get_external_mcp_tools_passes_headers(
self, mock_wrapper_class, mock_get_schemas
):
"""Test that _get_external_mcp_tools passes headers to server_params."""
headers = {"Authorization": "Bearer token123"}
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcp_server_headers=headers,
)
mock_get_schemas.return_value = {
"test_tool": {"description": "Test tool"}
}
mock_wrapper_instance = Mock()
mock_wrapper_class.return_value = mock_wrapper_instance
tools = agent._get_external_mcp_tools("https://example.com/mcp")
assert mock_wrapper_class.called
call_args = mock_wrapper_class.call_args
server_params = call_args[1]["mcp_server_params"]
assert "headers" in server_params
assert server_params["headers"] == headers
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
def test_get_external_mcp_tools_no_headers_when_not_configured(
self, mock_wrapper_class, mock_get_schemas
):
"""Test that _get_external_mcp_tools doesn't pass headers when not configured."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
)
mock_get_schemas.return_value = {
"test_tool": {"description": "Test tool"}
}
mock_wrapper_instance = Mock()
mock_wrapper_class.return_value = mock_wrapper_instance
tools = agent._get_external_mcp_tools("https://example.com/mcp")
assert mock_wrapper_class.called
call_args = mock_wrapper_class.call_args
server_params = call_args[1]["mcp_server_params"]
assert "headers" not in server_params
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
def test_get_external_mcp_tools_passes_progress_callback_when_enabled(
self, mock_wrapper_class, mock_get_schemas
):
"""Test that _get_external_mcp_tools passes progress callback when enabled."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcp_progress_enabled=True,
)
mock_get_schemas.return_value = {
"test_tool": {"description": "Test tool"}
}
mock_wrapper_instance = Mock()
mock_wrapper_class.return_value = mock_wrapper_instance
tools = agent._get_external_mcp_tools("https://example.com/mcp")
assert mock_wrapper_class.called
call_args = mock_wrapper_class.call_args
assert "progress_callback" in call_args[1]
assert call_args[1]["progress_callback"] is not None
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
def test_get_external_mcp_tools_no_progress_callback_when_disabled(
self, mock_wrapper_class, mock_get_schemas
):
"""Test that _get_external_mcp_tools doesn't pass progress callback when disabled."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcp_progress_enabled=False,
)
mock_get_schemas.return_value = {
"test_tool": {"description": "Test tool"}
}
mock_wrapper_instance = Mock()
mock_wrapper_class.return_value = mock_wrapper_instance
tools = agent._get_external_mcp_tools("https://example.com/mcp")
assert mock_wrapper_class.called
call_args = mock_wrapper_class.call_args
assert call_args[1]["progress_callback"] is None
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
def test_get_external_mcp_tools_passes_agent_context(
self, mock_wrapper_class, mock_get_schemas
):
"""Test that _get_external_mcp_tools passes agent context to wrapper."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcp_progress_enabled=True,
)
mock_get_schemas.return_value = {
"test_tool": {"description": "Test tool"}
}
mock_wrapper_instance = Mock()
mock_wrapper_class.return_value = mock_wrapper_instance
tools = agent._get_external_mcp_tools("https://example.com/mcp")
assert mock_wrapper_class.called
call_args = mock_wrapper_class.call_args
assert "agent" in call_args[1]
assert call_args[1]["agent"] == agent
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
def test_get_external_mcp_tools_passes_task_context(
self, mock_wrapper_class, mock_get_schemas
):
"""Test that _get_external_mcp_tools passes task context to wrapper."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcp_progress_enabled=True,
)
mock_get_schemas.return_value = {
"test_tool": {"description": "Test tool"}
}
mock_wrapper_instance = Mock()
mock_wrapper_class.return_value = mock_wrapper_instance
mock_task = Mock()
mock_task.id = "test-task-id"
tools = agent._get_external_mcp_tools("https://example.com/mcp", task=mock_task)
assert mock_wrapper_class.called
call_args = mock_wrapper_class.call_args
assert "task" in call_args[1]
assert call_args[1]["task"] == mock_task
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
def test_get_external_mcp_tools_with_all_features(
self, mock_wrapper_class, mock_get_schemas
):
"""Test _get_external_mcp_tools with progress, headers, and context."""
headers = {"Authorization": "Bearer token123"}
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcp_progress_enabled=True,
mcp_server_headers=headers,
)
mock_get_schemas.return_value = {
"test_tool": {"description": "Test tool"}
}
mock_wrapper_instance = Mock()
mock_wrapper_class.return_value = mock_wrapper_instance
mock_task = Mock()
mock_task.id = "test-task-id"
tools = agent._get_external_mcp_tools("https://example.com/mcp", task=mock_task)
assert mock_wrapper_class.called
call_args = mock_wrapper_class.call_args
server_params = call_args[1]["mcp_server_params"]
assert server_params["headers"] == headers
assert call_args[1]["progress_callback"] is not None
assert call_args[1]["agent"] == agent
assert call_args[1]["task"] == mock_task
class TestAgentMCPProgressCallback:
"""Test suite for Agent MCP progress callback behavior."""
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
def test_progress_callback_logs_progress(
self, mock_wrapper_class, mock_get_schemas
):
"""Test that progress callback logs progress information."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
mcp_progress_enabled=True,
)
mock_get_schemas.return_value = {
"test_tool": {"description": "Test tool"}
}
mock_wrapper_instance = Mock()
mock_wrapper_class.return_value = mock_wrapper_instance
with patch.object(agent._logger, "log") as mock_log:
tools = agent._get_external_mcp_tools("https://example.com/mcp")
call_args = mock_wrapper_class.call_args
progress_callback = call_args[1]["progress_callback"]
progress_callback(50.0, 100.0, "Processing...")
mock_log.assert_called_once()
log_call = mock_log.call_args
assert log_call[0][0] == "debug"
assert "test_tool" in log_call[0][1]
assert "50.0" in log_call[0][1]
assert "100.0" in log_call[0][1]
assert "Processing..." in log_call[0][1]

View File

@@ -0,0 +1,350 @@
"""Tests for MCPToolWrapper progress and headers support."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import MCPToolProgressEvent
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
@pytest.fixture
def mock_mcp_session():
"""Create a mock MCP ClientSession."""
session = AsyncMock()
session.initialize = AsyncMock()
session.call_tool = AsyncMock()
return session
@pytest.fixture
def mock_streamable_client(mock_mcp_session):
"""Create a mock streamablehttp_client context manager."""
async def mock_client(*args, **kwargs):
read = AsyncMock()
write = AsyncMock()
close = AsyncMock()
class MockContextManager:
async def __aenter__(self):
return (read, write, close)
async def __aexit__(self, *args):
pass
return MockContextManager()
return mock_client
@pytest.fixture
def mock_agent():
"""Create a mock agent with id and role."""
agent = Mock()
agent.id = "test-agent-id"
agent.role = "Test Agent"
return agent
@pytest.fixture
def mock_task():
"""Create a mock task with id and description."""
task = Mock()
task.id = "test-task-id"
task.description = "Test Task Description"
task.name = None
return task
class TestMCPToolWrapperProgress:
"""Test suite for MCP tool wrapper progress notifications."""
def test_wrapper_initialization_with_progress_callback(self):
"""Test that MCPToolWrapper can be initialized with progress callback."""
callback = Mock()
wrapper = MCPToolWrapper(
mcp_server_params={"url": "https://example.com/mcp"},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
progress_callback=callback,
)
assert wrapper._progress_callback == callback
assert wrapper.name == "test_server_test_tool"
def test_wrapper_initialization_without_progress_callback(self):
"""Test that MCPToolWrapper works without progress callback."""
wrapper = MCPToolWrapper(
mcp_server_params={"url": "https://example.com/mcp"},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
)
assert wrapper._progress_callback is None
def test_wrapper_initialization_with_agent_and_task(self, mock_agent, mock_task):
"""Test that MCPToolWrapper can be initialized with agent and task context."""
wrapper = MCPToolWrapper(
mcp_server_params={"url": "https://example.com/mcp"},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
agent=mock_agent,
task=mock_task,
)
assert wrapper._agent == mock_agent
assert wrapper._task == mock_task
@pytest.mark.asyncio
async def test_progress_handler_called_during_execution(self, mock_agent, mock_task):
"""Test that progress callback is invoked when MCP server sends progress."""
progress_callback = Mock()
wrapper = MCPToolWrapper(
mcp_server_params={"url": "https://example.com/mcp"},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
progress_callback=progress_callback,
agent=mock_agent,
task=mock_task,
)
mock_result = Mock()
mock_result.content = [Mock(text="Test result")]
with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=mock_result)
mock_session.on_progress = None
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_class.return_value.__aexit__ = AsyncMock()
mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
mock_client.return_value.__aexit__ = AsyncMock()
result = await wrapper._execute_tool(test_arg="test_value")
assert result == "Test result"
assert mock_session.on_progress is not None
@pytest.mark.asyncio
async def test_progress_event_emission(self, mock_agent, mock_task):
"""Test that MCPToolProgressEvent is emitted when progress is reported."""
events_received = []
def event_handler(source, event):
if isinstance(event, MCPToolProgressEvent):
events_received.append(event)
crewai_event_bus.register(MCPToolProgressEvent, event_handler)
try:
wrapper = MCPToolWrapper(
mcp_server_params={"url": "https://example.com/mcp"},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
progress_callback=Mock(),
agent=mock_agent,
task=mock_task,
)
wrapper._emit_progress_event(50.0, 100.0, "Processing...")
await asyncio.sleep(0.1)
assert len(events_received) == 1
event = events_received[0]
assert event.tool_name == "test_tool"
assert event.server_name == "test_server"
assert event.progress == 50.0
assert event.total == 100.0
assert event.message == "Processing..."
assert event.agent_id == "test-agent-id"
assert event.agent_role == "Test Agent"
assert event.task_id == "test-task-id"
assert event.task_name == "Test Task Description"
finally:
crewai_event_bus._sync_handlers.pop(MCPToolProgressEvent, None)
def test_progress_event_without_agent_context(self):
"""Test that progress events work without agent context."""
wrapper = MCPToolWrapper(
mcp_server_params={"url": "https://example.com/mcp"},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
progress_callback=Mock(),
)
wrapper._emit_progress_event(25.0, None, "Starting...")
def test_progress_event_without_task_context(self, mock_agent):
"""Test that progress events work without task context."""
wrapper = MCPToolWrapper(
mcp_server_params={"url": "https://example.com/mcp"},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
progress_callback=Mock(),
agent=mock_agent,
)
wrapper._emit_progress_event(75.0, 100.0, None)
class TestMCPToolWrapperHeaders:
"""Test suite for MCP tool wrapper headers support."""
def test_wrapper_initialization_with_headers(self):
"""Test that MCPToolWrapper accepts headers in server params."""
headers = {"Authorization": "Bearer token123", "X-Client-ID": "test-client"}
wrapper = MCPToolWrapper(
mcp_server_params={
"url": "https://example.com/mcp",
"headers": headers,
},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
)
assert wrapper.mcp_server_params["headers"] == headers
@pytest.mark.asyncio
async def test_headers_passed_to_transport(self):
"""Test that headers are passed to streamablehttp_client."""
headers = {"Authorization": "Bearer token123"}
wrapper = MCPToolWrapper(
mcp_server_params={
"url": "https://example.com/mcp",
"headers": headers,
},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
)
mock_result = Mock()
mock_result.content = [Mock(text="Test result")]
with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=mock_result)
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_class.return_value.__aexit__ = AsyncMock()
mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
mock_client.return_value.__aexit__ = AsyncMock()
await wrapper._execute_tool(test_arg="test_value")
mock_client.assert_called_once()
call_args = mock_client.call_args
assert call_args[0][0] == "https://example.com/mcp"
assert "headers" in call_args[1]
assert call_args[1]["headers"] == headers
@pytest.mark.asyncio
async def test_no_headers_when_not_configured(self):
"""Test that headers are not passed when not configured."""
wrapper = MCPToolWrapper(
mcp_server_params={"url": "https://example.com/mcp"},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
)
mock_result = Mock()
mock_result.content = [Mock(text="Test result")]
with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=mock_result)
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_class.return_value.__aexit__ = AsyncMock()
mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
mock_client.return_value.__aexit__ = AsyncMock()
await wrapper._execute_tool(test_arg="test_value")
mock_client.assert_called_once()
call_args = mock_client.call_args
assert "headers" not in call_args[1] or call_args[1].get("headers") is None
class TestMCPToolWrapperIntegration:
"""Integration tests for MCP tool wrapper with progress and headers."""
@pytest.mark.asyncio
async def test_full_execution_with_progress_and_headers(self, mock_agent, mock_task):
"""Test complete execution flow with both progress and headers."""
progress_calls = []
def progress_callback(progress, total, message):
progress_calls.append((progress, total, message))
headers = {"Authorization": "Bearer test-token"}
wrapper = MCPToolWrapper(
mcp_server_params={
"url": "https://example.com/mcp",
"headers": headers,
},
tool_name="test_tool",
tool_schema={"description": "Test tool"},
server_name="test_server",
progress_callback=progress_callback,
agent=mock_agent,
task=mock_task,
)
mock_result = Mock()
mock_result.content = [Mock(text="Test result")]
with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \
patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class:
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=mock_result)
mock_session.on_progress = None
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_class.return_value.__aexit__ = AsyncMock()
mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock()))
mock_client.return_value.__aexit__ = AsyncMock()
result = await wrapper._execute_tool(test_arg="test_value")
assert result == "Test result"
call_args = mock_client.call_args
assert call_args[1]["headers"] == headers
assert mock_session.on_progress is not None