mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-21 13:58:15 +00:00
Compare commits
5 Commits
devin/1768
...
devin/1761
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
88c0950a6f | ||
|
|
9dfad32efe | ||
|
|
3b77dd57d8 | ||
|
|
99418b1160 | ||
|
|
873d501401 |
@@ -339,6 +339,225 @@ mcps=["https://mcp.example.com/mcp?api_key=valid_key"]
|
|||||||
# Ensure query parameters are properly URL encoded
|
# Ensure query parameters are properly URL encoded
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Progress Notifications
|
||||||
|
|
||||||
|
CrewAI supports progress notifications from MCP servers during long-running tool executions. This provides real-time visibility into tool execution status and enables precise monitoring of complex operations.
|
||||||
|
|
||||||
|
### Enabling Progress Notifications
|
||||||
|
|
||||||
|
Enable progress tracking by setting `mcp_progress_enabled=True` on your agent:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent
|
||||||
|
from crewai.events import crewai_event_bus, MCPToolProgressEvent
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Data Processing Specialist",
|
||||||
|
goal="Process large datasets efficiently",
|
||||||
|
backstory="Expert at handling long-running data operations with real-time monitoring",
|
||||||
|
mcps=["https://data-processor.example.com/mcp"],
|
||||||
|
mcp_progress_enabled=True
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Listening to Progress Events
|
||||||
|
|
||||||
|
Progress notifications are emitted as `MCPToolProgressEvent` through the CrewAI event bus:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def handle_progress(source, event: MCPToolProgressEvent):
|
||||||
|
print(f"Tool: {event.tool_name}")
|
||||||
|
print(f"Progress: {event.progress}/{event.total or '?'}")
|
||||||
|
print(f"Message: {event.message}")
|
||||||
|
print(f"Agent: {event.agent_role}")
|
||||||
|
|
||||||
|
crewai_event_bus.register(MCPToolProgressEvent, handle_progress)
|
||||||
|
|
||||||
|
result = crew.kickoff()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Progress Event Fields
|
||||||
|
|
||||||
|
The `MCPToolProgressEvent` provides detailed progress information:
|
||||||
|
|
||||||
|
- `tool_name`: Name of the MCP tool being executed
|
||||||
|
- `server_name`: Name of the MCP server
|
||||||
|
- `progress`: Current progress value
|
||||||
|
- `total`: Total progress value (optional)
|
||||||
|
- `message`: Progress message from the server (optional)
|
||||||
|
- `agent_id`: ID of the agent executing the tool
|
||||||
|
- `agent_role`: Role of the agent
|
||||||
|
- `task_id`: ID of the task being executed (if available)
|
||||||
|
- `task_name`: Name of the task (if available)
|
||||||
|
|
||||||
|
### Complete Progress Monitoring Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent, Task, Crew, Process
|
||||||
|
from crewai.events import crewai_event_bus, MCPToolProgressEvent
|
||||||
|
|
||||||
|
progress_updates = []
|
||||||
|
|
||||||
|
def track_progress(source, event: MCPToolProgressEvent):
|
||||||
|
progress_updates.append({
|
||||||
|
"tool": event.tool_name,
|
||||||
|
"progress": event.progress,
|
||||||
|
"total": event.total,
|
||||||
|
"message": event.message,
|
||||||
|
"timestamp": event.timestamp
|
||||||
|
})
|
||||||
|
|
||||||
|
if event.total:
|
||||||
|
percentage = (event.progress / event.total) * 100
|
||||||
|
print(f"[{event.agent_role}] {event.tool_name}: {percentage:.1f}% - {event.message}")
|
||||||
|
else:
|
||||||
|
print(f"[{event.agent_role}] {event.tool_name}: {event.progress} - {event.message}")
|
||||||
|
|
||||||
|
crewai_event_bus.register(MCPToolProgressEvent, track_progress)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Large-Scale Data Analyst",
|
||||||
|
goal="Analyze massive datasets with progress tracking",
|
||||||
|
backstory="Specialist in processing large-scale data operations with real-time monitoring",
|
||||||
|
mcps=["https://analytics.example.com/mcp"],
|
||||||
|
mcp_progress_enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Process and analyze the complete customer dataset",
|
||||||
|
expected_output="Comprehensive analysis report with insights",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = Crew(
|
||||||
|
agents=[agent],
|
||||||
|
tasks=[task],
|
||||||
|
process=Process.sequential,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
result = crew.kickoff()
|
||||||
|
|
||||||
|
print(f"Total progress updates received: {len(progress_updates)}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Middleware Support with Headers
|
||||||
|
|
||||||
|
CrewAI provides precise control over MCP server communication through custom HTTP headers. This enables authentication, request tracking, and integration with server-side middleware for enhanced security and monitoring.
|
||||||
|
|
||||||
|
### Configuring Headers
|
||||||
|
|
||||||
|
Pass custom headers to MCP servers using `mcp_server_headers`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Secure API Consumer",
|
||||||
|
goal="Access protected MCP services securely",
|
||||||
|
backstory="Security-conscious agent with proper authentication credentials",
|
||||||
|
mcps=["https://secure-api.example.com/mcp"],
|
||||||
|
mcp_server_headers={
|
||||||
|
"Authorization": "Bearer your_access_token",
|
||||||
|
"X-Client-ID": "crewai-client-123",
|
||||||
|
"X-Request-Source": "production-crew"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Common Header Use Cases
|
||||||
|
|
||||||
|
#### Authentication
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Authenticated Researcher",
|
||||||
|
goal="Access premium research tools",
|
||||||
|
backstory="Researcher with authenticated access to premium data sources",
|
||||||
|
mcps=["https://premium-research.example.com/mcp"],
|
||||||
|
mcp_server_headers={
|
||||||
|
"Authorization": f"Bearer {os.getenv('RESEARCH_API_TOKEN')}",
|
||||||
|
"X-API-Key": os.getenv("RESEARCH_API_KEY")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Request Tracking
|
||||||
|
|
||||||
|
```python
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Tracked Operations Agent",
|
||||||
|
goal="Execute operations with full traceability",
|
||||||
|
backstory="Agent designed for auditable operations with request tracking",
|
||||||
|
mcps=["https://tracked-service.example.com/mcp"],
|
||||||
|
mcp_server_headers={
|
||||||
|
"X-Request-ID": request_id,
|
||||||
|
"X-Client-Version": "crewai-2.0",
|
||||||
|
"X-Environment": "production"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Rate Limiting and Quotas
|
||||||
|
|
||||||
|
```python
|
||||||
|
agent = Agent(
|
||||||
|
role="Quota-Managed Agent",
|
||||||
|
goal="Operate within API quotas and rate limits",
|
||||||
|
backstory="Agent configured for efficient API usage within quota constraints",
|
||||||
|
mcps=["https://rate-limited-api.example.com/mcp"],
|
||||||
|
mcp_server_headers={
|
||||||
|
"X-Client-ID": "crew-client-001",
|
||||||
|
"X-Priority": "high",
|
||||||
|
"X-Quota-Group": "premium-tier"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Combining Progress and Headers
|
||||||
|
|
||||||
|
For complex use cases requiring both progress monitoring and middleware integration:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent, Task, Crew
|
||||||
|
from crewai.events import crewai_event_bus, MCPToolProgressEvent
|
||||||
|
import os
|
||||||
|
|
||||||
|
def monitor_progress(source, event: MCPToolProgressEvent):
|
||||||
|
print(f"Progress: {event.tool_name} - {event.progress}/{event.total}")
|
||||||
|
|
||||||
|
crewai_event_bus.register(MCPToolProgressEvent, monitor_progress)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Enterprise Data Processor",
|
||||||
|
goal="Process enterprise data with full monitoring and security",
|
||||||
|
backstory="Enterprise-grade agent with authenticated access and progress tracking",
|
||||||
|
mcps=["https://enterprise-api.example.com/mcp"],
|
||||||
|
mcp_progress_enabled=True,
|
||||||
|
mcp_server_headers={
|
||||||
|
"Authorization": f"Bearer {os.getenv('ENTERPRISE_TOKEN')}",
|
||||||
|
"X-Client-ID": "enterprise-crew-001",
|
||||||
|
"X-Request-Source": "production",
|
||||||
|
"X-Enable-Progress": "true"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Process quarterly financial data with real-time progress updates",
|
||||||
|
expected_output="Complete financial analysis with processing metrics",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = Crew(agents=[agent], tasks=[task])
|
||||||
|
result = crew.kickoff()
|
||||||
|
```
|
||||||
|
|
||||||
## Advanced: MCPServerAdapter
|
## Advanced: MCPServerAdapter
|
||||||
|
|
||||||
For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server.
|
For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server.
|
||||||
|
|||||||
@@ -659,7 +659,7 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
return all_tools
|
return all_tools
|
||||||
|
|
||||||
def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
|
def _get_external_mcp_tools(self, mcp_ref: str, task: Task | None = None) -> list[BaseTool]:
|
||||||
"""Get tools from external HTTPS MCP server with graceful error handling."""
|
"""Get tools from external HTTPS MCP server with graceful error handling."""
|
||||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||||
|
|
||||||
@@ -670,6 +670,10 @@ class Agent(BaseAgent):
|
|||||||
server_url, specific_tool = mcp_ref, None
|
server_url, specific_tool = mcp_ref, None
|
||||||
|
|
||||||
server_params = {"url": server_url}
|
server_params = {"url": server_url}
|
||||||
|
|
||||||
|
if self.mcp_server_headers:
|
||||||
|
server_params["headers"] = self.mcp_server_headers
|
||||||
|
|
||||||
server_name = self._extract_server_name(server_url)
|
server_name = self._extract_server_name(server_url)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -689,11 +693,25 @@ class Agent(BaseAgent):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
progress_callback = None
|
||||||
|
if self.mcp_progress_enabled:
|
||||||
|
def make_progress_callback(tool_name_ref: str):
|
||||||
|
def callback(progress: float, total: float | None, message: str | None):
|
||||||
|
self._logger.log(
|
||||||
|
"debug",
|
||||||
|
f"MCP tool {tool_name_ref} progress: {progress}/{total or '?'} - {message or 'no message'}"
|
||||||
|
)
|
||||||
|
return callback
|
||||||
|
progress_callback = make_progress_callback(tool_name)
|
||||||
|
|
||||||
wrapper = MCPToolWrapper(
|
wrapper = MCPToolWrapper(
|
||||||
mcp_server_params=server_params,
|
mcp_server_params=server_params,
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
tool_schema=schema,
|
tool_schema=schema,
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
agent=self,
|
||||||
|
task=task,
|
||||||
)
|
)
|
||||||
tools.append(wrapper)
|
tools.append(wrapper)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -197,6 +197,14 @@ class BaseAgent(BaseModel, ABC):
|
|||||||
default=None,
|
default=None,
|
||||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
|
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
|
||||||
)
|
)
|
||||||
|
mcp_progress_enabled: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Enable progress notifications from MCP servers. When enabled, MCPToolProgressEvent will be emitted to the event bus during long-running MCP tool executions.",
|
||||||
|
)
|
||||||
|
mcp_server_headers: dict[str, str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="HTTP headers to pass to MCP servers for middleware support (e.g., authentication tokens, client-id). Applied to all MCP server connections.",
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ from crewai.events.types.task_events import (
|
|||||||
TaskStartedEvent,
|
TaskStartedEvent,
|
||||||
)
|
)
|
||||||
from crewai.events.types.tool_usage_events import (
|
from crewai.events.types.tool_usage_events import (
|
||||||
|
MCPToolProgressEvent,
|
||||||
ToolExecutionErrorEvent,
|
ToolExecutionErrorEvent,
|
||||||
ToolSelectionErrorEvent,
|
ToolSelectionErrorEvent,
|
||||||
ToolUsageErrorEvent,
|
ToolUsageErrorEvent,
|
||||||
@@ -156,6 +157,7 @@ __all__ = [
|
|||||||
"MethodExecutionFailedEvent",
|
"MethodExecutionFailedEvent",
|
||||||
"MethodExecutionFinishedEvent",
|
"MethodExecutionFinishedEvent",
|
||||||
"MethodExecutionStartedEvent",
|
"MethodExecutionStartedEvent",
|
||||||
|
"MCPToolProgressEvent",
|
||||||
"ReasoningEvent",
|
"ReasoningEvent",
|
||||||
"TaskCompletedEvent",
|
"TaskCompletedEvent",
|
||||||
"TaskEvaluationEvent",
|
"TaskEvaluationEvent",
|
||||||
|
|||||||
@@ -110,3 +110,18 @@ class ToolExecutionErrorEvent(BaseEvent):
|
|||||||
and self.agent.fingerprint.metadata
|
and self.agent.fingerprint.metadata
|
||||||
):
|
):
|
||||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolProgressEvent(BaseEvent):
|
||||||
|
"""Event emitted when an MCP tool reports progress during execution"""
|
||||||
|
|
||||||
|
type: str = "mcp_tool_progress"
|
||||||
|
tool_name: str
|
||||||
|
server_name: str
|
||||||
|
progress: float
|
||||||
|
total: float | None = None
|
||||||
|
message: str | None = None
|
||||||
|
agent_id: str | None = None
|
||||||
|
agent_role: str | None = None
|
||||||
|
task_id: str | None = None
|
||||||
|
task_name: str | None = None
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""MCP Tool Wrapper for on-demand MCP server connections."""
|
"""MCP Tool Wrapper for on-demand MCP server connections."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
@@ -20,6 +22,9 @@ class MCPToolWrapper(BaseTool):
|
|||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_schema: dict,
|
tool_schema: dict,
|
||||||
server_name: str,
|
server_name: str,
|
||||||
|
progress_callback: Callable[[float, float | None, str | None], None] | None = None,
|
||||||
|
agent: Any | None = None,
|
||||||
|
task: Any | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize the MCP tool wrapper.
|
"""Initialize the MCP tool wrapper.
|
||||||
|
|
||||||
@@ -28,6 +33,9 @@ class MCPToolWrapper(BaseTool):
|
|||||||
tool_name: Original name of the tool on the MCP server
|
tool_name: Original name of the tool on the MCP server
|
||||||
tool_schema: Schema information for the tool
|
tool_schema: Schema information for the tool
|
||||||
server_name: Name of the MCP server for prefixing
|
server_name: Name of the MCP server for prefixing
|
||||||
|
progress_callback: Optional callback for progress notifications (progress, total, message)
|
||||||
|
agent: Optional agent context for event emission
|
||||||
|
task: Optional task context for event emission
|
||||||
"""
|
"""
|
||||||
# Create tool name with server prefix to avoid conflicts
|
# Create tool name with server prefix to avoid conflicts
|
||||||
prefixed_name = f"{server_name}_{tool_name}"
|
prefixed_name = f"{server_name}_{tool_name}"
|
||||||
@@ -52,6 +60,9 @@ class MCPToolWrapper(BaseTool):
|
|||||||
self._mcp_server_params = mcp_server_params
|
self._mcp_server_params = mcp_server_params
|
||||||
self._original_tool_name = tool_name
|
self._original_tool_name = tool_name
|
||||||
self._server_name = server_name
|
self._server_name = server_name
|
||||||
|
self._progress_callback = progress_callback
|
||||||
|
self._agent = agent
|
||||||
|
self._task = task
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mcp_server_params(self) -> dict:
|
def mcp_server_params(self) -> dict:
|
||||||
@@ -165,20 +176,40 @@ class MCPToolWrapper(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _execute_tool(self, **kwargs) -> str:
|
async def _execute_tool(self, **kwargs) -> str:
|
||||||
"""Execute the actual MCP tool call."""
|
"""Execute the actual MCP tool call with progress support."""
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
|
||||||
server_url = self.mcp_server_params["url"]
|
server_url = self.mcp_server_params["url"]
|
||||||
|
headers = self.mcp_server_params.get("headers")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Wrap entire operation with single timeout
|
# Wrap entire operation with single timeout
|
||||||
async def _do_mcp_call():
|
async def _do_mcp_call():
|
||||||
|
client_kwargs = {"terminate_on_close": True}
|
||||||
|
if headers:
|
||||||
|
client_kwargs["headers"] = headers
|
||||||
|
|
||||||
async with streamablehttp_client(
|
async with streamablehttp_client(
|
||||||
server_url, terminate_on_close=True
|
server_url, **client_kwargs
|
||||||
) as (read, write, _):
|
) as (read, write, _):
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write) as session:
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
|
# Register progress handler if callback is provided
|
||||||
|
if self._progress_callback:
|
||||||
|
def progress_handler(progress_notification):
|
||||||
|
"""Handle progress notifications from MCP server."""
|
||||||
|
progress = progress_notification.progress
|
||||||
|
total = getattr(progress_notification, "total", None)
|
||||||
|
message = getattr(progress_notification, "message", None)
|
||||||
|
|
||||||
|
self._progress_callback(progress, total, message)
|
||||||
|
|
||||||
|
self._emit_progress_event(progress, total, message)
|
||||||
|
|
||||||
|
session.on_progress = progress_handler
|
||||||
|
|
||||||
result = await session.call_tool(
|
result = await session.call_tool(
|
||||||
self.original_tool_name, kwargs
|
self.original_tool_name, kwargs
|
||||||
)
|
)
|
||||||
@@ -211,3 +242,29 @@ class MCPToolWrapper(BaseTool):
|
|||||||
if "TaskGroup" in str(e) or "unhandled errors" in str(e):
|
if "TaskGroup" in str(e) or "unhandled errors" in str(e):
|
||||||
raise asyncio.TimeoutError(f"MCP connection error: {e}") from e
|
raise asyncio.TimeoutError(f"MCP connection error: {e}") from e
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _emit_progress_event(
|
||||||
|
self, progress: float, total: float | None, message: str | None
|
||||||
|
) -> None:
|
||||||
|
"""Emit MCPToolProgressEvent to CrewAI event bus."""
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.tool_usage_events import MCPToolProgressEvent
|
||||||
|
|
||||||
|
event_data = {
|
||||||
|
"tool_name": self.original_tool_name,
|
||||||
|
"server_name": self.server_name,
|
||||||
|
"progress": progress,
|
||||||
|
"total": total,
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._agent:
|
||||||
|
event_data["agent_id"] = str(self._agent.id) if hasattr(self._agent, "id") else None
|
||||||
|
event_data["agent_role"] = getattr(self._agent, "role", None)
|
||||||
|
|
||||||
|
if self._task:
|
||||||
|
event_data["task_id"] = str(self._task.id) if hasattr(self._task, "id") else None
|
||||||
|
event_data["task_name"] = getattr(self._task, "name", None) or getattr(self._task, "description", None)
|
||||||
|
|
||||||
|
event = MCPToolProgressEvent(**event_data)
|
||||||
|
crewai_event_bus.emit(self, event)
|
||||||
|
|||||||
324
lib/crewai/tests/agents/test_agent_mcp_progress.py
Normal file
324
lib/crewai/tests/agents/test_agent_mcp_progress.py
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
"""Tests for Agent MCP progress and headers configuration."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.agent import Agent
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentMCPProgressConfiguration:
|
||||||
|
"""Test suite for Agent MCP progress configuration."""
|
||||||
|
|
||||||
|
def test_agent_initialization_with_mcp_progress_enabled(self):
|
||||||
|
"""Test that Agent can be initialized with mcp_progress_enabled."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcps=["https://example.com/mcp"],
|
||||||
|
mcp_progress_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.mcp_progress_enabled is True
|
||||||
|
|
||||||
|
def test_agent_initialization_with_mcp_progress_disabled(self):
|
||||||
|
"""Test that Agent defaults to mcp_progress_enabled=False."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcps=["https://example.com/mcp"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.mcp_progress_enabled is False
|
||||||
|
|
||||||
|
def test_agent_initialization_with_mcp_server_headers(self):
|
||||||
|
"""Test that Agent can be initialized with mcp_server_headers."""
|
||||||
|
headers = {"Authorization": "Bearer token123", "X-Client-ID": "test-client"}
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcps=["https://example.com/mcp"],
|
||||||
|
mcp_server_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.mcp_server_headers == headers
|
||||||
|
|
||||||
|
def test_agent_initialization_without_mcp_server_headers(self):
|
||||||
|
"""Test that Agent defaults to None for mcp_server_headers."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcps=["https://example.com/mcp"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.mcp_server_headers is None
|
||||||
|
|
||||||
|
def test_agent_with_both_progress_and_headers(self):
|
||||||
|
"""Test that Agent can be initialized with both progress and headers."""
|
||||||
|
headers = {"Authorization": "Bearer token123"}
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcps=["https://example.com/mcp"],
|
||||||
|
mcp_progress_enabled=True,
|
||||||
|
mcp_server_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.mcp_progress_enabled is True
|
||||||
|
assert agent.mcp_server_headers == headers
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentMCPToolCreation:
|
||||||
|
"""Test suite for Agent MCP tool creation with progress and headers."""
|
||||||
|
|
||||||
|
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
|
||||||
|
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
|
||||||
|
def test_get_external_mcp_tools_passes_headers(
|
||||||
|
self, mock_wrapper_class, mock_get_schemas
|
||||||
|
):
|
||||||
|
"""Test that _get_external_mcp_tools passes headers to server_params."""
|
||||||
|
headers = {"Authorization": "Bearer token123"}
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcp_server_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_schemas.return_value = {
|
||||||
|
"test_tool": {"description": "Test tool"}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_wrapper_instance = Mock()
|
||||||
|
mock_wrapper_class.return_value = mock_wrapper_instance
|
||||||
|
|
||||||
|
tools = agent._get_external_mcp_tools("https://example.com/mcp")
|
||||||
|
|
||||||
|
assert mock_wrapper_class.called
|
||||||
|
call_args = mock_wrapper_class.call_args
|
||||||
|
server_params = call_args[1]["mcp_server_params"]
|
||||||
|
assert "headers" in server_params
|
||||||
|
assert server_params["headers"] == headers
|
||||||
|
|
||||||
|
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
|
||||||
|
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
|
||||||
|
def test_get_external_mcp_tools_no_headers_when_not_configured(
|
||||||
|
self, mock_wrapper_class, mock_get_schemas
|
||||||
|
):
|
||||||
|
"""Test that _get_external_mcp_tools doesn't pass headers when not configured."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_schemas.return_value = {
|
||||||
|
"test_tool": {"description": "Test tool"}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_wrapper_instance = Mock()
|
||||||
|
mock_wrapper_class.return_value = mock_wrapper_instance
|
||||||
|
|
||||||
|
tools = agent._get_external_mcp_tools("https://example.com/mcp")
|
||||||
|
|
||||||
|
assert mock_wrapper_class.called
|
||||||
|
call_args = mock_wrapper_class.call_args
|
||||||
|
server_params = call_args[1]["mcp_server_params"]
|
||||||
|
assert "headers" not in server_params
|
||||||
|
|
||||||
|
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
|
||||||
|
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
|
||||||
|
def test_get_external_mcp_tools_passes_progress_callback_when_enabled(
|
||||||
|
self, mock_wrapper_class, mock_get_schemas
|
||||||
|
):
|
||||||
|
"""Test that _get_external_mcp_tools passes progress callback when enabled."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcp_progress_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_schemas.return_value = {
|
||||||
|
"test_tool": {"description": "Test tool"}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_wrapper_instance = Mock()
|
||||||
|
mock_wrapper_class.return_value = mock_wrapper_instance
|
||||||
|
|
||||||
|
tools = agent._get_external_mcp_tools("https://example.com/mcp")
|
||||||
|
|
||||||
|
assert mock_wrapper_class.called
|
||||||
|
call_args = mock_wrapper_class.call_args
|
||||||
|
assert "progress_callback" in call_args[1]
|
||||||
|
assert call_args[1]["progress_callback"] is not None
|
||||||
|
|
||||||
|
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
|
||||||
|
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
|
||||||
|
def test_get_external_mcp_tools_no_progress_callback_when_disabled(
|
||||||
|
self, mock_wrapper_class, mock_get_schemas
|
||||||
|
):
|
||||||
|
"""Test that _get_external_mcp_tools doesn't pass progress callback when disabled."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcp_progress_enabled=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_schemas.return_value = {
|
||||||
|
"test_tool": {"description": "Test tool"}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_wrapper_instance = Mock()
|
||||||
|
mock_wrapper_class.return_value = mock_wrapper_instance
|
||||||
|
|
||||||
|
tools = agent._get_external_mcp_tools("https://example.com/mcp")
|
||||||
|
|
||||||
|
assert mock_wrapper_class.called
|
||||||
|
call_args = mock_wrapper_class.call_args
|
||||||
|
assert call_args[1]["progress_callback"] is None
|
||||||
|
|
||||||
|
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
|
||||||
|
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
|
||||||
|
def test_get_external_mcp_tools_passes_agent_context(
|
||||||
|
self, mock_wrapper_class, mock_get_schemas
|
||||||
|
):
|
||||||
|
"""Test that _get_external_mcp_tools passes agent context to wrapper."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcp_progress_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_schemas.return_value = {
|
||||||
|
"test_tool": {"description": "Test tool"}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_wrapper_instance = Mock()
|
||||||
|
mock_wrapper_class.return_value = mock_wrapper_instance
|
||||||
|
|
||||||
|
tools = agent._get_external_mcp_tools("https://example.com/mcp")
|
||||||
|
|
||||||
|
assert mock_wrapper_class.called
|
||||||
|
call_args = mock_wrapper_class.call_args
|
||||||
|
assert "agent" in call_args[1]
|
||||||
|
assert call_args[1]["agent"] == agent
|
||||||
|
|
||||||
|
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
|
||||||
|
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
|
||||||
|
def test_get_external_mcp_tools_passes_task_context(
|
||||||
|
self, mock_wrapper_class, mock_get_schemas
|
||||||
|
):
|
||||||
|
"""Test that _get_external_mcp_tools passes task context to wrapper."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcp_progress_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_schemas.return_value = {
|
||||||
|
"test_tool": {"description": "Test tool"}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_wrapper_instance = Mock()
|
||||||
|
mock_wrapper_class.return_value = mock_wrapper_instance
|
||||||
|
|
||||||
|
mock_task = Mock()
|
||||||
|
mock_task.id = "test-task-id"
|
||||||
|
|
||||||
|
tools = agent._get_external_mcp_tools("https://example.com/mcp", task=mock_task)
|
||||||
|
|
||||||
|
assert mock_wrapper_class.called
|
||||||
|
call_args = mock_wrapper_class.call_args
|
||||||
|
assert "task" in call_args[1]
|
||||||
|
assert call_args[1]["task"] == mock_task
|
||||||
|
|
||||||
|
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
|
||||||
|
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
|
||||||
|
def test_get_external_mcp_tools_with_all_features(
|
||||||
|
self, mock_wrapper_class, mock_get_schemas
|
||||||
|
):
|
||||||
|
"""Test _get_external_mcp_tools with progress, headers, and context."""
|
||||||
|
headers = {"Authorization": "Bearer token123"}
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcp_progress_enabled=True,
|
||||||
|
mcp_server_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_schemas.return_value = {
|
||||||
|
"test_tool": {"description": "Test tool"}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_wrapper_instance = Mock()
|
||||||
|
mock_wrapper_class.return_value = mock_wrapper_instance
|
||||||
|
|
||||||
|
mock_task = Mock()
|
||||||
|
mock_task.id = "test-task-id"
|
||||||
|
|
||||||
|
tools = agent._get_external_mcp_tools("https://example.com/mcp", task=mock_task)
|
||||||
|
|
||||||
|
assert mock_wrapper_class.called
|
||||||
|
call_args = mock_wrapper_class.call_args
|
||||||
|
|
||||||
|
server_params = call_args[1]["mcp_server_params"]
|
||||||
|
assert server_params["headers"] == headers
|
||||||
|
|
||||||
|
assert call_args[1]["progress_callback"] is not None
|
||||||
|
|
||||||
|
assert call_args[1]["agent"] == agent
|
||||||
|
assert call_args[1]["task"] == mock_task
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentMCPProgressCallback:
|
||||||
|
"""Test suite for Agent MCP progress callback behavior."""
|
||||||
|
|
||||||
|
@patch("crewai.agent.Agent._get_mcp_tool_schemas")
|
||||||
|
@patch("crewai.tools.mcp_tool_wrapper.MCPToolWrapper")
|
||||||
|
def test_progress_callback_logs_progress(
|
||||||
|
self, mock_wrapper_class, mock_get_schemas
|
||||||
|
):
|
||||||
|
"""Test that progress callback logs progress information."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcp_progress_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_schemas.return_value = {
|
||||||
|
"test_tool": {"description": "Test tool"}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_wrapper_instance = Mock()
|
||||||
|
mock_wrapper_class.return_value = mock_wrapper_instance
|
||||||
|
|
||||||
|
with patch.object(agent._logger, "log") as mock_log:
|
||||||
|
tools = agent._get_external_mcp_tools("https://example.com/mcp")
|
||||||
|
|
||||||
|
call_args = mock_wrapper_class.call_args
|
||||||
|
progress_callback = call_args[1]["progress_callback"]
|
||||||
|
|
||||||
|
progress_callback(50.0, 100.0, "Processing...")
|
||||||
|
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
log_call = mock_log.call_args
|
||||||
|
assert log_call[0][0] == "debug"
|
||||||
|
assert "test_tool" in log_call[0][1]
|
||||||
|
assert "50.0" in log_call[0][1]
|
||||||
|
assert "100.0" in log_call[0][1]
|
||||||
|
assert "Processing..." in log_call[0][1]
|
||||||
392
lib/crewai/tests/tools/test_mcp_tool_wrapper.py
Normal file
392
lib/crewai/tests/tools/test_mcp_tool_wrapper.py
Normal file
@@ -0,0 +1,392 @@
|
|||||||
|
"""Tests for MCPToolWrapper progress and headers support."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.tool_usage_events import MCPToolProgressEvent
|
||||||
|
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def stub_mcp_modules(monkeypatch):
|
||||||
|
"""Stub the mcp modules in sys.modules to avoid import errors in CI."""
|
||||||
|
mcp = types.ModuleType("mcp")
|
||||||
|
mcp_client = types.ModuleType("mcp.client")
|
||||||
|
mcp_streamable_http = types.ModuleType("mcp.client.streamable_http")
|
||||||
|
|
||||||
|
mcp.__path__ = []
|
||||||
|
mcp_client.__path__ = []
|
||||||
|
mcp.client = mcp_client
|
||||||
|
mcp_client.streamable_http = mcp_streamable_http
|
||||||
|
|
||||||
|
class MockClientSession:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.initialize = AsyncMock()
|
||||||
|
self.call_tool = AsyncMock()
|
||||||
|
self.on_progress = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
mcp.ClientSession = MockClientSession
|
||||||
|
|
||||||
|
last_kwargs = {}
|
||||||
|
|
||||||
|
def fake_streamablehttp_client(*args, **kwargs):
|
||||||
|
"""Mock streamablehttp_client context manager (NOT async def)."""
|
||||||
|
last_kwargs.clear()
|
||||||
|
last_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
class MockContextManager:
|
||||||
|
async def __aenter__(self):
|
||||||
|
return (AsyncMock(), AsyncMock(), AsyncMock())
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return MockContextManager()
|
||||||
|
|
||||||
|
fake_streamablehttp_client.last_kwargs = last_kwargs
|
||||||
|
|
||||||
|
mcp_streamable_http.streamablehttp_client = fake_streamablehttp_client
|
||||||
|
|
||||||
|
monkeypatch.setitem(sys.modules, "mcp", mcp)
|
||||||
|
monkeypatch.setitem(sys.modules, "mcp.client", mcp_client)
|
||||||
|
monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", mcp_streamable_http)
|
||||||
|
|
||||||
|
return mcp_streamable_http
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_mcp_session():
|
||||||
|
"""Create a mock MCP ClientSession."""
|
||||||
|
session = AsyncMock()
|
||||||
|
session.initialize = AsyncMock()
|
||||||
|
session.call_tool = AsyncMock()
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_streamable_client(mock_mcp_session):
|
||||||
|
"""Create a mock streamablehttp_client context manager."""
|
||||||
|
async def mock_client(*args, **kwargs):
|
||||||
|
read = AsyncMock()
|
||||||
|
write = AsyncMock()
|
||||||
|
close = AsyncMock()
|
||||||
|
|
||||||
|
class MockContextManager:
|
||||||
|
async def __aenter__(self):
|
||||||
|
return (read, write, close)
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return MockContextManager()
|
||||||
|
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_agent():
|
||||||
|
"""Create a mock agent with id and role."""
|
||||||
|
agent = Mock()
|
||||||
|
agent.id = "test-agent-id"
|
||||||
|
agent.role = "Test Agent"
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_task():
|
||||||
|
"""Create a mock task with id and description."""
|
||||||
|
task = Mock()
|
||||||
|
task.id = "test-task-id"
|
||||||
|
task.description = "Test Task Description"
|
||||||
|
task.name = None
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolWrapperProgress:
|
||||||
|
"""Test suite for MCP tool wrapper progress notifications."""
|
||||||
|
|
||||||
|
def test_wrapper_initialization_with_progress_callback(self):
|
||||||
|
"""Test that MCPToolWrapper can be initialized with progress callback."""
|
||||||
|
callback = Mock()
|
||||||
|
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={"url": "https://example.com/mcp"},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
progress_callback=callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert wrapper._progress_callback == callback
|
||||||
|
assert wrapper.name == "test_server_test_tool"
|
||||||
|
|
||||||
|
def test_wrapper_initialization_without_progress_callback(self):
|
||||||
|
"""Test that MCPToolWrapper works without progress callback."""
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={"url": "https://example.com/mcp"},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert wrapper._progress_callback is None
|
||||||
|
|
||||||
|
def test_wrapper_initialization_with_agent_and_task(self, mock_agent, mock_task):
|
||||||
|
"""Test that MCPToolWrapper can be initialized with agent and task context."""
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={"url": "https://example.com/mcp"},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
agent=mock_agent,
|
||||||
|
task=mock_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert wrapper._agent == mock_agent
|
||||||
|
assert wrapper._task == mock_task
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_progress_handler_called_during_execution(self, mock_agent, mock_task, stub_mcp_modules):
|
||||||
|
"""Test that progress callback is invoked when MCP server sends progress."""
|
||||||
|
import sys
|
||||||
|
from mcp import ClientSession
|
||||||
|
|
||||||
|
progress_callback = Mock()
|
||||||
|
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={"url": "https://example.com/mcp"},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
agent=mock_agent,
|
||||||
|
task=mock_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up the mock result on the stubbed ClientSession
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.content = [Mock(text="Test result")]
|
||||||
|
|
||||||
|
original_init = ClientSession.__init__
|
||||||
|
def patched_init(self, *args, **kwargs):
|
||||||
|
original_init(self, *args, **kwargs)
|
||||||
|
self.call_tool = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
ClientSession.__init__ = patched_init
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await wrapper._execute_tool(test_arg="test_value")
|
||||||
|
|
||||||
|
assert result == "Test result"
|
||||||
|
finally:
|
||||||
|
ClientSession.__init__ = original_init
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_progress_event_emission(self, mock_agent, mock_task):
|
||||||
|
"""Test that MCPToolProgressEvent is emitted when progress is reported."""
|
||||||
|
events_received = []
|
||||||
|
|
||||||
|
def event_handler(source, event):
|
||||||
|
if isinstance(event, MCPToolProgressEvent):
|
||||||
|
events_received.append(event)
|
||||||
|
|
||||||
|
crewai_event_bus.register_handler(MCPToolProgressEvent, event_handler)
|
||||||
|
|
||||||
|
try:
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={"url": "https://example.com/mcp"},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
progress_callback=Mock(),
|
||||||
|
agent=mock_agent,
|
||||||
|
task=mock_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper._emit_progress_event(50.0, 100.0, "Processing...")
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
assert len(events_received) == 1
|
||||||
|
event = events_received[0]
|
||||||
|
assert event.tool_name == "test_tool"
|
||||||
|
assert event.server_name == "test_server"
|
||||||
|
assert event.progress == 50.0
|
||||||
|
assert event.total == 100.0
|
||||||
|
assert event.message == "Processing..."
|
||||||
|
assert event.agent_id == "test-agent-id"
|
||||||
|
assert event.agent_role == "Test Agent"
|
||||||
|
assert event.task_id == "test-task-id"
|
||||||
|
assert event.task_name == "Test Task Description"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
crewai_event_bus._sync_handlers.pop(MCPToolProgressEvent, None)
|
||||||
|
|
||||||
|
def test_progress_event_without_agent_context(self):
|
||||||
|
"""Test that progress events work without agent context."""
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={"url": "https://example.com/mcp"},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
progress_callback=Mock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper._emit_progress_event(25.0, None, "Starting...")
|
||||||
|
|
||||||
|
def test_progress_event_without_task_context(self, mock_agent):
|
||||||
|
"""Test that progress events work without task context."""
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={"url": "https://example.com/mcp"},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
progress_callback=Mock(),
|
||||||
|
agent=mock_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper._emit_progress_event(75.0, 100.0, None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolWrapperHeaders:
|
||||||
|
"""Test suite for MCP tool wrapper headers support."""
|
||||||
|
|
||||||
|
def test_wrapper_initialization_with_headers(self):
|
||||||
|
"""Test that MCPToolWrapper accepts headers in server params."""
|
||||||
|
headers = {"Authorization": "Bearer token123", "X-Client-ID": "test-client"}
|
||||||
|
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={
|
||||||
|
"url": "https://example.com/mcp",
|
||||||
|
"headers": headers,
|
||||||
|
},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert wrapper.mcp_server_params["headers"] == headers
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_headers_passed_to_transport(self, stub_mcp_modules):
|
||||||
|
"""Test that headers are passed to streamablehttp_client."""
|
||||||
|
from mcp import ClientSession
|
||||||
|
|
||||||
|
headers = {"Authorization": "Bearer token123"}
|
||||||
|
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={
|
||||||
|
"url": "https://example.com/mcp",
|
||||||
|
"headers": headers,
|
||||||
|
},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.content = [Mock(text="Test result")]
|
||||||
|
|
||||||
|
original_init = ClientSession.__init__
|
||||||
|
def patched_init(self, *args, **kwargs):
|
||||||
|
original_init(self, *args, **kwargs)
|
||||||
|
self.call_tool = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
ClientSession.__init__ = patched_init
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await wrapper._execute_tool(test_arg="test_value")
|
||||||
|
assert result == "Test result"
|
||||||
|
|
||||||
|
# Verify headers were passed to streamablehttp_client
|
||||||
|
assert "headers" in stub_mcp_modules.streamablehttp_client.last_kwargs
|
||||||
|
assert stub_mcp_modules.streamablehttp_client.last_kwargs["headers"] == headers
|
||||||
|
finally:
|
||||||
|
ClientSession.__init__ = original_init
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_headers_when_not_configured(self, stub_mcp_modules):
|
||||||
|
"""Test that headers are not passed when not configured."""
|
||||||
|
from mcp import ClientSession
|
||||||
|
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={"url": "https://example.com/mcp"},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.content = [Mock(text="Test result")]
|
||||||
|
|
||||||
|
original_init = ClientSession.__init__
|
||||||
|
def patched_init(self, *args, **kwargs):
|
||||||
|
original_init(self, *args, **kwargs)
|
||||||
|
self.call_tool = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
ClientSession.__init__ = patched_init
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await wrapper._execute_tool(test_arg="test_value")
|
||||||
|
assert result == "Test result"
|
||||||
|
|
||||||
|
# Verify headers were NOT passed to streamablehttp_client
|
||||||
|
assert "headers" not in stub_mcp_modules.streamablehttp_client.last_kwargs
|
||||||
|
finally:
|
||||||
|
ClientSession.__init__ = original_init
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolWrapperIntegration:
|
||||||
|
"""Integration tests for MCP tool wrapper with progress and headers."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_execution_with_progress_and_headers(self, mock_agent, mock_task):
|
||||||
|
"""Test complete execution flow with both progress and headers."""
|
||||||
|
from mcp import ClientSession
|
||||||
|
|
||||||
|
progress_calls = []
|
||||||
|
|
||||||
|
def progress_callback(progress, total, message):
|
||||||
|
progress_calls.append((progress, total, message))
|
||||||
|
|
||||||
|
headers = {"Authorization": "Bearer test-token"}
|
||||||
|
|
||||||
|
wrapper = MCPToolWrapper(
|
||||||
|
mcp_server_params={
|
||||||
|
"url": "https://example.com/mcp",
|
||||||
|
"headers": headers,
|
||||||
|
},
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_schema={"description": "Test tool"},
|
||||||
|
server_name="test_server",
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
agent=mock_agent,
|
||||||
|
task=mock_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.content = [Mock(text="Test result")]
|
||||||
|
|
||||||
|
original_init = ClientSession.__init__
|
||||||
|
def patched_init(self, *args, **kwargs):
|
||||||
|
original_init(self, *args, **kwargs)
|
||||||
|
self.call_tool = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
ClientSession.__init__ = patched_init
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await wrapper._execute_tool(test_arg="test_value")
|
||||||
|
assert result == "Test result"
|
||||||
|
finally:
|
||||||
|
ClientSession.__init__ = original_init
|
||||||
Reference in New Issue
Block a user