From 873d501401c5389e4d195452e329b21105d0175d Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 26 Oct 2025 09:45:50 +0000 Subject: [PATCH] feat: Add MCP progress notifications and middleware headers support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements progress reporting and HTTP headers support for MCP tool integration to address issue #3797. Changes: - Add MCPToolProgressEvent to event system for real-time progress tracking - Extend MCPToolWrapper to support progress callbacks and event emission - Add mcp_progress_enabled flag to Agent for opt-in progress notifications - Add mcp_server_headers to Agent for middleware authentication/tracking - Thread progress and headers configuration through Agent._get_external_mcp_tools - Add comprehensive test coverage for progress and headers features - Update MCP DSL documentation with progress and headers examples Features: - Progress notifications emitted as MCPToolProgressEvent via event bus - Optional progress callback for custom progress handling - HTTP headers passthrough for authentication and middleware integration - Agent and task context included in progress events - Opt-in design ensures backward compatibility Tests: - Unit tests for MCPToolWrapper progress and headers functionality - Integration tests for Agent MCP configuration - Mock-based tests to avoid network dependencies Documentation: - Added Progress Notifications section with examples - Added Middleware Support with Headers section - Included complete examples for common use cases Fixes #3797 Co-Authored-By: João --- docs/en/mcp/dsl-integration.mdx | 219 +++++++++++ lib/crewai/src/crewai/agent.py | 20 +- .../crewai/agents/agent_builder/base_agent.py | 8 + lib/crewai/src/crewai/events/__init__.py | 2 + .../crewai/events/types/tool_usage_events.py | 15 + .../src/crewai/tools/mcp_tool_wrapper.py | 61 ++- .../tests/agents/test_agent_mcp_progress.py | 324 ++++++++++++++++ .../tests/tools/test_mcp_tool_wrapper.py | 350 ++++++++++++++++++ 8 files changed, 996 insertions(+), 3 deletions(-) create mode 100644 lib/crewai/tests/agents/test_agent_mcp_progress.py create mode 100644 lib/crewai/tests/tools/test_mcp_tool_wrapper.py diff --git a/docs/en/mcp/dsl-integration.mdx b/docs/en/mcp/dsl-integration.mdx index 78f1e884d..39853620d 100644 --- a/docs/en/mcp/dsl-integration.mdx +++ b/docs/en/mcp/dsl-integration.mdx @@ -339,6 +339,225 @@ mcps=["https://mcp.example.com/mcp?api_key=valid_key"] # Ensure query parameters are properly URL encoded ``` +## Progress Notifications + +CrewAI supports progress notifications from MCP servers during long-running tool executions. This provides real-time visibility into tool execution status and enables precise monitoring of complex operations. + +### Enabling Progress Notifications + +Enable progress tracking by setting `mcp_progress_enabled=True` on your agent: + +```python +from crewai import Agent +from crewai.events import crewai_event_bus, MCPToolProgressEvent + +agent = Agent( + role="Data Processing Specialist", + goal="Process large datasets efficiently", + backstory="Expert at handling long-running data operations with real-time monitoring", + mcps=["https://data-processor.example.com/mcp"], + mcp_progress_enabled=True +) +``` + +### Listening to Progress Events + +Progress notifications are emitted as `MCPToolProgressEvent` through the CrewAI event bus: + +```python +def handle_progress(source, event: MCPToolProgressEvent): + print(f"Tool: {event.tool_name}") + print(f"Progress: {event.progress}/{event.total or '?'}") + print(f"Message: {event.message}") + print(f"Agent: {event.agent_role}") + +crewai_event_bus.register(MCPToolProgressEvent, handle_progress) + +result = crew.kickoff() +``` + +### Progress Event Fields + +The `MCPToolProgressEvent` provides detailed progress information: + +- `tool_name`: Name of the MCP tool being executed +- `server_name`: Name of the MCP server +- `progress`: Current progress value +- `total`: Total progress value (optional) +- `message`: Progress message from the server (optional) +- `agent_id`: ID of the agent executing the tool +- `agent_role`: Role of the agent +- `task_id`: ID of the task being executed (if available) +- `task_name`: Name of the task (if available) + +### Complete Progress Monitoring Example + +```python +from crewai import Agent, Task, Crew, Process +from crewai.events import crewai_event_bus, MCPToolProgressEvent + +progress_updates = [] + +def track_progress(source, event: MCPToolProgressEvent): + progress_updates.append({ + "tool": event.tool_name, + "progress": event.progress, + "total": event.total, + "message": event.message, + "timestamp": event.timestamp + }) + + if event.total: + percentage = (event.progress / event.total) * 100 + print(f"[{event.agent_role}] {event.tool_name}: {percentage:.1f}% - {event.message}") + else: + print(f"[{event.agent_role}] {event.tool_name}: {event.progress} - {event.message}") + +crewai_event_bus.register(MCPToolProgressEvent, track_progress) + +agent = Agent( + role="Large-Scale Data Analyst", + goal="Analyze massive datasets with progress tracking", + backstory="Specialist in processing large-scale data operations with real-time monitoring", + mcps=["https://analytics.example.com/mcp"], + mcp_progress_enabled=True +) + +task = Task( + description="Process and analyze the complete customer dataset", + expected_output="Comprehensive analysis report with insights", + agent=agent +) + +crew = Crew( + agents=[agent], + tasks=[task], + process=Process.sequential, + verbose=True +) + +result = crew.kickoff() + +print(f"Total progress updates received: {len(progress_updates)}") +``` + +## Middleware Support with Headers + +CrewAI provides precise control over MCP server communication through custom HTTP headers. This enables authentication, request tracking, and integration with server-side middleware for enhanced security and monitoring. + +### Configuring Headers + +Pass custom headers to MCP servers using `mcp_server_headers`: + +```python +from crewai import Agent + +agent = Agent( + role="Secure API Consumer", + goal="Access protected MCP services securely", + backstory="Security-conscious agent with proper authentication credentials", + mcps=["https://secure-api.example.com/mcp"], + mcp_server_headers={ + "Authorization": "Bearer your_access_token", + "X-Client-ID": "crewai-client-123", + "X-Request-Source": "production-crew" + } +) +``` + +### Common Header Use Cases + +#### Authentication + +```python +import os + +agent = Agent( + role="Authenticated Researcher", + goal="Access premium research tools", + backstory="Researcher with authenticated access to premium data sources", + mcps=["https://premium-research.example.com/mcp"], + mcp_server_headers={ + "Authorization": f"Bearer {os.getenv('RESEARCH_API_TOKEN')}", + "X-API-Key": os.getenv("RESEARCH_API_KEY") + } +) +``` + +#### Request Tracking + +```python +import uuid + +request_id = str(uuid.uuid4()) + +agent = Agent( + role="Tracked Operations Agent", + goal="Execute operations with full traceability", + backstory="Agent designed for auditable operations with request tracking", + mcps=["https://tracked-service.example.com/mcp"], + mcp_server_headers={ + "X-Request-ID": request_id, + "X-Client-Version": "crewai-2.0", + "X-Environment": "production" + } +) +``` + +#### Rate Limiting and Quotas + +```python +agent = Agent( + role="Quota-Managed Agent", + goal="Operate within API quotas and rate limits", + backstory="Agent configured for efficient API usage within quota constraints", + mcps=["https://rate-limited-api.example.com/mcp"], + mcp_server_headers={ + "X-Client-ID": "crew-client-001", + "X-Priority": "high", + "X-Quota-Group": "premium-tier" + } +) +``` + +### Combining Progress and Headers + +For complex use cases requiring both progress monitoring and middleware integration: + +```python +from crewai import Agent, Task, Crew +from crewai.events import crewai_event_bus, MCPToolProgressEvent +import os + +def monitor_progress(source, event: MCPToolProgressEvent): + print(f"Progress: {event.tool_name} - {event.progress}/{event.total}") + +crewai_event_bus.register(MCPToolProgressEvent, monitor_progress) + +agent = Agent( + role="Enterprise Data Processor", + goal="Process enterprise data with full monitoring and security", + backstory="Enterprise-grade agent with authenticated access and progress tracking", + mcps=["https://enterprise-api.example.com/mcp"], + mcp_progress_enabled=True, + mcp_server_headers={ + "Authorization": f"Bearer {os.getenv('ENTERPRISE_TOKEN')}", + "X-Client-ID": "enterprise-crew-001", + "X-Request-Source": "production", + "X-Enable-Progress": "true" + } +) + +task = Task( + description="Process quarterly financial data with real-time progress updates", + expected_output="Complete financial analysis with processing metrics", + agent=agent +) + +crew = Crew(agents=[agent], tasks=[task]) +result = crew.kickoff() +``` + ## Advanced: MCPServerAdapter For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server. diff --git a/lib/crewai/src/crewai/agent.py b/lib/crewai/src/crewai/agent.py index 75593a1d4..048765996 100644 --- a/lib/crewai/src/crewai/agent.py +++ b/lib/crewai/src/crewai/agent.py @@ -659,7 +659,7 @@ class Agent(BaseAgent): return all_tools - def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]: + def _get_external_mcp_tools(self, mcp_ref: str, task: Task | None = None) -> list[BaseTool]: """Get tools from external HTTPS MCP server with graceful error handling.""" from crewai.tools.mcp_tool_wrapper import MCPToolWrapper @@ -670,6 +670,10 @@ class Agent(BaseAgent): server_url, specific_tool = mcp_ref, None server_params = {"url": server_url} + + if self.mcp_server_headers: + server_params["headers"] = self.mcp_server_headers + server_name = self._extract_server_name(server_url) try: @@ -689,11 +693,25 @@ class Agent(BaseAgent): continue try: + progress_callback = None + if self.mcp_progress_enabled: + def make_progress_callback(tool_name_ref: str): + def callback(progress: float, total: float | None, message: str | None): + self._logger.log( + "debug", + f"MCP tool {tool_name_ref} progress: {progress}/{total or '?'} - {message or 'no message'}" + ) + return callback + progress_callback = make_progress_callback(tool_name) + wrapper = MCPToolWrapper( mcp_server_params=server_params, tool_name=tool_name, tool_schema=schema, server_name=server_name, + progress_callback=progress_callback, + agent=self, + task=task, ) tools.append(wrapper) except Exception as e: diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index dd7d7d7f7..31cb6b03b 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -197,6 +197,14 @@ class BaseAgent(BaseModel, ABC): default=None, description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.", ) + mcp_progress_enabled: bool = Field( + default=False, + description="Enable progress notifications from MCP servers. When enabled, MCPToolProgressEvent will be emitted to the event bus during long-running MCP tool executions.", + ) + mcp_server_headers: dict[str, str] | None = Field( + default=None, + description="HTTP headers to pass to MCP servers for middleware support (e.g., authentication tokens, client-id). Applied to all MCP server connections.", + ) @model_validator(mode="before") @classmethod diff --git a/lib/crewai/src/crewai/events/__init__.py b/lib/crewai/src/crewai/events/__init__.py index e3eb4920f..5b5ad81ca 100644 --- a/lib/crewai/src/crewai/events/__init__.py +++ b/lib/crewai/src/crewai/events/__init__.py @@ -90,6 +90,7 @@ from crewai.events.types.task_events import ( TaskStartedEvent, ) from crewai.events.types.tool_usage_events import ( + MCPToolProgressEvent, ToolExecutionErrorEvent, ToolSelectionErrorEvent, ToolUsageErrorEvent, @@ -156,6 +157,7 @@ __all__ = [ "MethodExecutionFailedEvent", "MethodExecutionFinishedEvent", "MethodExecutionStartedEvent", + "MCPToolProgressEvent", "ReasoningEvent", "TaskCompletedEvent", "TaskEvaluationEvent", diff --git a/lib/crewai/src/crewai/events/types/tool_usage_events.py b/lib/crewai/src/crewai/events/types/tool_usage_events.py index 7fe9b897f..b170c7c98 100644 --- a/lib/crewai/src/crewai/events/types/tool_usage_events.py +++ b/lib/crewai/src/crewai/events/types/tool_usage_events.py @@ -110,3 +110,18 @@ class ToolExecutionErrorEvent(BaseEvent): and self.agent.fingerprint.metadata ): self.fingerprint_metadata = self.agent.fingerprint.metadata + + +class MCPToolProgressEvent(BaseEvent): + """Event emitted when an MCP tool reports progress during execution""" + + type: str = "mcp_tool_progress" + tool_name: str + server_name: str + progress: float + total: float | None = None + message: str | None = None + agent_id: str | None = None + agent_role: str | None = None + task_id: str | None = None + task_name: str | None = None diff --git a/lib/crewai/src/crewai/tools/mcp_tool_wrapper.py b/lib/crewai/src/crewai/tools/mcp_tool_wrapper.py index 7845d0c85..de9d0fb8e 100644 --- a/lib/crewai/src/crewai/tools/mcp_tool_wrapper.py +++ b/lib/crewai/src/crewai/tools/mcp_tool_wrapper.py @@ -1,6 +1,8 @@ """MCP Tool Wrapper for on-demand MCP server connections.""" import asyncio +from collections.abc import Callable +from typing import Any from crewai.tools import BaseTool @@ -20,6 +22,9 @@ class MCPToolWrapper(BaseTool): tool_name: str, tool_schema: dict, server_name: str, + progress_callback: Callable[[float, float | None, str | None], None] | None = None, + agent: Any | None = None, + task: Any | None = None, ): """Initialize the MCP tool wrapper. @@ -28,6 +33,9 @@ class MCPToolWrapper(BaseTool): tool_name: Original name of the tool on the MCP server tool_schema: Schema information for the tool server_name: Name of the MCP server for prefixing + progress_callback: Optional callback for progress notifications (progress, total, message) + agent: Optional agent context for event emission + task: Optional task context for event emission """ # Create tool name with server prefix to avoid conflicts prefixed_name = f"{server_name}_{tool_name}" @@ -52,6 +60,9 @@ class MCPToolWrapper(BaseTool): self._mcp_server_params = mcp_server_params self._original_tool_name = tool_name self._server_name = server_name + self._progress_callback = progress_callback + self._agent = agent + self._task = task @property def mcp_server_params(self) -> dict: @@ -165,20 +176,40 @@ class MCPToolWrapper(BaseTool): ) async def _execute_tool(self, **kwargs) -> str: - """Execute the actual MCP tool call.""" + """Execute the actual MCP tool call with progress support.""" from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client server_url = self.mcp_server_params["url"] + headers = self.mcp_server_params.get("headers") try: # Wrap entire operation with single timeout async def _do_mcp_call(): + client_kwargs = {"terminate_on_close": True} + if headers: + client_kwargs["headers"] = headers + async with streamablehttp_client( - server_url, terminate_on_close=True + server_url, **client_kwargs ) as (read, write, _): async with ClientSession(read, write) as session: await session.initialize() + + # Register progress handler if callback is provided + if self._progress_callback: + def progress_handler(progress_notification): + """Handle progress notifications from MCP server.""" + progress = progress_notification.progress + total = getattr(progress_notification, "total", None) + message = getattr(progress_notification, "message", None) + + self._progress_callback(progress, total, message) + + self._emit_progress_event(progress, total, message) + + session.on_progress = progress_handler + result = await session.call_tool( self.original_tool_name, kwargs ) @@ -211,3 +242,29 @@ class MCPToolWrapper(BaseTool): if "TaskGroup" in str(e) or "unhandled errors" in str(e): raise asyncio.TimeoutError(f"MCP connection error: {e}") from e raise + + def _emit_progress_event( + self, progress: float, total: float | None, message: str | None + ) -> None: + """Emit MCPToolProgressEvent to CrewAI event bus.""" + from crewai.events.event_bus import crewai_event_bus + from crewai.events.types.tool_usage_events import MCPToolProgressEvent + + event_data = { + "tool_name": self.original_tool_name, + "server_name": self.server_name, + "progress": progress, + "total": total, + "message": message, + } + + if self._agent: + event_data["agent_id"] = str(self._agent.id) if hasattr(self._agent, "id") else None + event_data["agent_role"] = getattr(self._agent, "role", None) + + if self._task: + event_data["task_id"] = str(self._task.id) if hasattr(self._task, "id") else None + event_data["task_name"] = getattr(self._task, "name", None) or getattr(self._task, "description", None) + + event = MCPToolProgressEvent(**event_data) + crewai_event_bus.emit(self, event) diff --git a/lib/crewai/tests/agents/test_agent_mcp_progress.py b/lib/crewai/tests/agents/test_agent_mcp_progress.py new file mode 100644 index 000000000..0a893d6fc --- /dev/null +++ b/lib/crewai/tests/agents/test_agent_mcp_progress.py @@ -0,0 +1,324 @@ +"""Tests for Agent MCP progress and headers configuration.""" + +from unittest.mock import Mock, patch + +import pytest + +from crewai.agent import Agent + + +class TestAgentMCPProgressConfiguration: + """Test suite for Agent MCP progress configuration.""" + + def test_agent_initialization_with_mcp_progress_enabled(self): + """Test that Agent can be initialized with mcp_progress_enabled.""" + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=["https://example.com/mcp"], + mcp_progress_enabled=True, + ) + + assert agent.mcp_progress_enabled is True + + def test_agent_initialization_with_mcp_progress_disabled(self): + """Test that Agent defaults to mcp_progress_enabled=False.""" + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=["https://example.com/mcp"], + ) + + assert agent.mcp_progress_enabled is False + + def test_agent_initialization_with_mcp_server_headers(self): + """Test that Agent can be initialized with mcp_server_headers.""" + headers = {"Authorization": "Bearer token123", "X-Client-ID": "test-client"} + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=["https://example.com/mcp"], + mcp_server_headers=headers, + ) + + assert agent.mcp_server_headers == headers + + def test_agent_initialization_without_mcp_server_headers(self): + """Test that Agent defaults to None for mcp_server_headers.""" + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=["https://example.com/mcp"], + ) + + assert agent.mcp_server_headers is None + + def test_agent_with_both_progress_and_headers(self): + """Test that Agent can be initialized with both progress and headers.""" + headers = {"Authorization": "Bearer token123"} + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=["https://example.com/mcp"], + mcp_progress_enabled=True, + mcp_server_headers=headers, + ) + + assert agent.mcp_progress_enabled is True + assert agent.mcp_server_headers == headers + + +class TestAgentMCPToolCreation: + """Test suite for Agent MCP tool creation with progress and headers.""" + + @patch("crewai.agent.Agent._get_mcp_tool_schemas") + @patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper") + def test_get_external_mcp_tools_passes_headers( + self, mock_wrapper_class, mock_get_schemas + ): + """Test that _get_external_mcp_tools passes headers to server_params.""" + headers = {"Authorization": "Bearer token123"} + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcp_server_headers=headers, + ) + + mock_get_schemas.return_value = { + "test_tool": {"description": "Test tool"} + } + + mock_wrapper_instance = Mock() + mock_wrapper_class.return_value = mock_wrapper_instance + + tools = agent._get_external_mcp_tools("https://example.com/mcp") + + assert mock_wrapper_class.called + call_args = mock_wrapper_class.call_args + server_params = call_args[1]["mcp_server_params"] + assert "headers" in server_params + assert server_params["headers"] == headers + + @patch("crewai.agent.Agent._get_mcp_tool_schemas") + @patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper") + def test_get_external_mcp_tools_no_headers_when_not_configured( + self, mock_wrapper_class, mock_get_schemas + ): + """Test that _get_external_mcp_tools doesn't pass headers when not configured.""" + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + ) + + mock_get_schemas.return_value = { + "test_tool": {"description": "Test tool"} + } + + mock_wrapper_instance = Mock() + mock_wrapper_class.return_value = mock_wrapper_instance + + tools = agent._get_external_mcp_tools("https://example.com/mcp") + + assert mock_wrapper_class.called + call_args = mock_wrapper_class.call_args + server_params = call_args[1]["mcp_server_params"] + assert "headers" not in server_params + + @patch("crewai.agent.Agent._get_mcp_tool_schemas") + @patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper") + def test_get_external_mcp_tools_passes_progress_callback_when_enabled( + self, mock_wrapper_class, mock_get_schemas + ): + """Test that _get_external_mcp_tools passes progress callback when enabled.""" + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcp_progress_enabled=True, + ) + + mock_get_schemas.return_value = { + "test_tool": {"description": "Test tool"} + } + + mock_wrapper_instance = Mock() + mock_wrapper_class.return_value = mock_wrapper_instance + + tools = agent._get_external_mcp_tools("https://example.com/mcp") + + assert mock_wrapper_class.called + call_args = mock_wrapper_class.call_args + assert "progress_callback" in call_args[1] + assert call_args[1]["progress_callback"] is not None + + @patch("crewai.agent.Agent._get_mcp_tool_schemas") + @patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper") + def test_get_external_mcp_tools_no_progress_callback_when_disabled( + self, mock_wrapper_class, mock_get_schemas + ): + """Test that _get_external_mcp_tools doesn't pass progress callback when disabled.""" + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcp_progress_enabled=False, + ) + + mock_get_schemas.return_value = { + "test_tool": {"description": "Test tool"} + } + + mock_wrapper_instance = Mock() + mock_wrapper_class.return_value = mock_wrapper_instance + + tools = agent._get_external_mcp_tools("https://example.com/mcp") + + assert mock_wrapper_class.called + call_args = mock_wrapper_class.call_args + assert call_args[1]["progress_callback"] is None + + @patch("crewai.agent.Agent._get_mcp_tool_schemas") + @patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper") + def test_get_external_mcp_tools_passes_agent_context( + self, mock_wrapper_class, mock_get_schemas + ): + """Test that _get_external_mcp_tools passes agent context to wrapper.""" + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcp_progress_enabled=True, + ) + + mock_get_schemas.return_value = { + "test_tool": {"description": "Test tool"} + } + + mock_wrapper_instance = Mock() + mock_wrapper_class.return_value = mock_wrapper_instance + + tools = agent._get_external_mcp_tools("https://example.com/mcp") + + assert mock_wrapper_class.called + call_args = mock_wrapper_class.call_args + assert "agent" in call_args[1] + assert call_args[1]["agent"] == agent + + @patch("crewai.agent.Agent._get_mcp_tool_schemas") + @patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper") + def test_get_external_mcp_tools_passes_task_context( + self, mock_wrapper_class, mock_get_schemas + ): + """Test that _get_external_mcp_tools passes task context to wrapper.""" + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcp_progress_enabled=True, + ) + + mock_get_schemas.return_value = { + "test_tool": {"description": "Test tool"} + } + + mock_wrapper_instance = Mock() + mock_wrapper_class.return_value = mock_wrapper_instance + + mock_task = Mock() + mock_task.id = "test-task-id" + + tools = agent._get_external_mcp_tools("https://example.com/mcp", task=mock_task) + + assert mock_wrapper_class.called + call_args = mock_wrapper_class.call_args + assert "task" in call_args[1] + assert call_args[1]["task"] == mock_task + + @patch("crewai.agent.Agent._get_mcp_tool_schemas") + @patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper") + def test_get_external_mcp_tools_with_all_features( + self, mock_wrapper_class, mock_get_schemas + ): + """Test _get_external_mcp_tools with progress, headers, and context.""" + headers = {"Authorization": "Bearer token123"} + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcp_progress_enabled=True, + mcp_server_headers=headers, + ) + + mock_get_schemas.return_value = { + "test_tool": {"description": "Test tool"} + } + + mock_wrapper_instance = Mock() + mock_wrapper_class.return_value = mock_wrapper_instance + + mock_task = Mock() + mock_task.id = "test-task-id" + + tools = agent._get_external_mcp_tools("https://example.com/mcp", task=mock_task) + + assert mock_wrapper_class.called + call_args = mock_wrapper_class.call_args + + server_params = call_args[1]["mcp_server_params"] + assert server_params["headers"] == headers + + assert call_args[1]["progress_callback"] is not None + + assert call_args[1]["agent"] == agent + assert call_args[1]["task"] == mock_task + + +class TestAgentMCPProgressCallback: + """Test suite for Agent MCP progress callback behavior.""" + + @patch("crewai.agent.Agent._get_mcp_tool_schemas") + @patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper") + def test_progress_callback_logs_progress( + self, mock_wrapper_class, mock_get_schemas + ): + """Test that progress callback logs progress information.""" + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcp_progress_enabled=True, + ) + + mock_get_schemas.return_value = { + "test_tool": {"description": "Test tool"} + } + + mock_wrapper_instance = Mock() + mock_wrapper_class.return_value = mock_wrapper_instance + + with patch.object(agent._logger, "log") as mock_log: + tools = agent._get_external_mcp_tools("https://example.com/mcp") + + call_args = mock_wrapper_class.call_args + progress_callback = call_args[1]["progress_callback"] + + progress_callback(50.0, 100.0, "Processing...") + + mock_log.assert_called_once() + log_call = mock_log.call_args + assert log_call[0][0] == "debug" + assert "test_tool" in log_call[0][1] + assert "50.0" in log_call[0][1] + assert "100.0" in log_call[0][1] + assert "Processing..." in log_call[0][1] diff --git a/lib/crewai/tests/tools/test_mcp_tool_wrapper.py b/lib/crewai/tests/tools/test_mcp_tool_wrapper.py new file mode 100644 index 000000000..90da4e6d2 --- /dev/null +++ b/lib/crewai/tests/tools/test_mcp_tool_wrapper.py @@ -0,0 +1,350 @@ +"""Tests for MCPToolWrapper progress and headers support.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.tool_usage_events import MCPToolProgressEvent +from crewai.tools.mcp_tool_wrapper import MCPToolWrapper + + +@pytest.fixture +def mock_mcp_session(): + """Create a mock MCP ClientSession.""" + session = AsyncMock() + session.initialize = AsyncMock() + session.call_tool = AsyncMock() + return session + + +@pytest.fixture +def mock_streamable_client(mock_mcp_session): + """Create a mock streamablehttp_client context manager.""" + async def mock_client(*args, **kwargs): + read = AsyncMock() + write = AsyncMock() + close = AsyncMock() + + class MockContextManager: + async def __aenter__(self): + return (read, write, close) + + async def __aexit__(self, *args): + pass + + return MockContextManager() + + return mock_client + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with id and role.""" + agent = Mock() + agent.id = "test-agent-id" + agent.role = "Test Agent" + return agent + + +@pytest.fixture +def mock_task(): + """Create a mock task with id and description.""" + task = Mock() + task.id = "test-task-id" + task.description = "Test Task Description" + task.name = None + return task + + +class TestMCPToolWrapperProgress: + """Test suite for MCP tool wrapper progress notifications.""" + + def test_wrapper_initialization_with_progress_callback(self): + """Test that MCPToolWrapper can be initialized with progress callback.""" + callback = Mock() + + wrapper = MCPToolWrapper( + mcp_server_params={"url": "https://example.com/mcp"}, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + progress_callback=callback, + ) + + assert wrapper._progress_callback == callback + assert wrapper.name == "test_server_test_tool" + + def test_wrapper_initialization_without_progress_callback(self): + """Test that MCPToolWrapper works without progress callback.""" + wrapper = MCPToolWrapper( + mcp_server_params={"url": "https://example.com/mcp"}, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + ) + + assert wrapper._progress_callback is None + + def test_wrapper_initialization_with_agent_and_task(self, mock_agent, mock_task): + """Test that MCPToolWrapper can be initialized with agent and task context.""" + wrapper = MCPToolWrapper( + mcp_server_params={"url": "https://example.com/mcp"}, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + agent=mock_agent, + task=mock_task, + ) + + assert wrapper._agent == mock_agent + assert wrapper._task == mock_task + + @pytest.mark.asyncio + async def test_progress_handler_called_during_execution(self, mock_agent, mock_task): + """Test that progress callback is invoked when MCP server sends progress.""" + progress_callback = Mock() + + wrapper = MCPToolWrapper( + mcp_server_params={"url": "https://example.com/mcp"}, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + progress_callback=progress_callback, + agent=mock_agent, + task=mock_task, + ) + + mock_result = Mock() + mock_result.content = [Mock(text="Test result")] + + with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \ + patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class: + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_result) + mock_session.on_progress = None + + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock())) + mock_client.return_value.__aexit__ = AsyncMock() + + result = await wrapper._execute_tool(test_arg="test_value") + + assert result == "Test result" + + assert mock_session.on_progress is not None + + @pytest.mark.asyncio + async def test_progress_event_emission(self, mock_agent, mock_task): + """Test that MCPToolProgressEvent is emitted when progress is reported.""" + events_received = [] + + def event_handler(source, event): + if isinstance(event, MCPToolProgressEvent): + events_received.append(event) + + crewai_event_bus.register(MCPToolProgressEvent, event_handler) + + try: + wrapper = MCPToolWrapper( + mcp_server_params={"url": "https://example.com/mcp"}, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + progress_callback=Mock(), + agent=mock_agent, + task=mock_task, + ) + + wrapper._emit_progress_event(50.0, 100.0, "Processing...") + + await asyncio.sleep(0.1) + + assert len(events_received) == 1 + event = events_received[0] + assert event.tool_name == "test_tool" + assert event.server_name == "test_server" + assert event.progress == 50.0 + assert event.total == 100.0 + assert event.message == "Processing..." + assert event.agent_id == "test-agent-id" + assert event.agent_role == "Test Agent" + assert event.task_id == "test-task-id" + assert event.task_name == "Test Task Description" + + finally: + crewai_event_bus._sync_handlers.pop(MCPToolProgressEvent, None) + + def test_progress_event_without_agent_context(self): + """Test that progress events work without agent context.""" + wrapper = MCPToolWrapper( + mcp_server_params={"url": "https://example.com/mcp"}, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + progress_callback=Mock(), + ) + + wrapper._emit_progress_event(25.0, None, "Starting...") + + def test_progress_event_without_task_context(self, mock_agent): + """Test that progress events work without task context.""" + wrapper = MCPToolWrapper( + mcp_server_params={"url": "https://example.com/mcp"}, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + progress_callback=Mock(), + agent=mock_agent, + ) + + wrapper._emit_progress_event(75.0, 100.0, None) + + +class TestMCPToolWrapperHeaders: + """Test suite for MCP tool wrapper headers support.""" + + def test_wrapper_initialization_with_headers(self): + """Test that MCPToolWrapper accepts headers in server params.""" + headers = {"Authorization": "Bearer token123", "X-Client-ID": "test-client"} + + wrapper = MCPToolWrapper( + mcp_server_params={ + "url": "https://example.com/mcp", + "headers": headers, + }, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + ) + + assert wrapper.mcp_server_params["headers"] == headers + + @pytest.mark.asyncio + async def test_headers_passed_to_transport(self): + """Test that headers are passed to streamablehttp_client.""" + headers = {"Authorization": "Bearer token123"} + + wrapper = MCPToolWrapper( + mcp_server_params={ + "url": "https://example.com/mcp", + "headers": headers, + }, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + ) + + mock_result = Mock() + mock_result.content = [Mock(text="Test result")] + + with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \ + patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class: + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_result) + + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock())) + mock_client.return_value.__aexit__ = AsyncMock() + + await wrapper._execute_tool(test_arg="test_value") + + mock_client.assert_called_once() + call_args = mock_client.call_args + assert call_args[0][0] == "https://example.com/mcp" + assert "headers" in call_args[1] + assert call_args[1]["headers"] == headers + + @pytest.mark.asyncio + async def test_no_headers_when_not_configured(self): + """Test that headers are not passed when not configured.""" + wrapper = MCPToolWrapper( + mcp_server_params={"url": "https://example.com/mcp"}, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + ) + + mock_result = Mock() + mock_result.content = [Mock(text="Test result")] + + with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \ + patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class: + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_result) + + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock())) + mock_client.return_value.__aexit__ = AsyncMock() + + await wrapper._execute_tool(test_arg="test_value") + + mock_client.assert_called_once() + call_args = mock_client.call_args + assert "headers" not in call_args[1] or call_args[1].get("headers") is None + + +class TestMCPToolWrapperIntegration: + """Integration tests for MCP tool wrapper with progress and headers.""" + + @pytest.mark.asyncio + async def test_full_execution_with_progress_and_headers(self, mock_agent, mock_task): + """Test complete execution flow with both progress and headers.""" + progress_calls = [] + + def progress_callback(progress, total, message): + progress_calls.append((progress, total, message)) + + headers = {"Authorization": "Bearer test-token"} + + wrapper = MCPToolWrapper( + mcp_server_params={ + "url": "https://example.com/mcp", + "headers": headers, + }, + tool_name="test_tool", + tool_schema={"description": "Test tool"}, + server_name="test_server", + progress_callback=progress_callback, + agent=mock_agent, + task=mock_task, + ) + + mock_result = Mock() + mock_result.content = [Mock(text="Test result")] + + with patch("crewai.tools.mcp_tool_wrapper.streamablehttp_client") as mock_client, \ + patch("crewai.tools.mcp_tool_wrapper.ClientSession") as mock_session_class: + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_result) + mock_session.on_progress = None + + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock() + + mock_client.return_value.__aenter__ = AsyncMock(return_value=(Mock(), Mock(), Mock())) + mock_client.return_value.__aexit__ = AsyncMock() + + result = await wrapper._execute_tool(test_arg="test_value") + + assert result == "Test result" + + call_args = mock_client.call_args + assert call_args[1]["headers"] == headers + + assert mock_session.on_progress is not None