Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2025-12-15 11:58:31 +00:00

Commit: New tests for MCP implementation
@@ -10,7 +10,7 @@ mode: "wide"
 In the CrewAI framework, an `Agent` is an autonomous unit that can:

 - Perform specific tasks
 - Make decisions based on its role and goal
-- Use tools to accomplish objectives
+- Use apps, tools and MCP Servers to accomplish objectives
 - Communicate and collaborate with other agents
 - Maintain memory of interactions
 - Delegate tasks when allowed
@@ -40,6 +40,7 @@ The Visual Agent Builder enables:
 | **Backstory** | `backstory` | `str` | Provides context and personality to the agent, enriching interactions. |
 | **LLM** _(optional)_ | `llm` | `Union[str, LLM, Any]` | Language model that powers the agent. Defaults to the model specified in `OPENAI_MODEL_NAME` or "gpt-4". |
 | **Tools** _(optional)_ | `tools` | `List[BaseTool]` | Capabilities or functions available to the agent. Defaults to an empty list. |
+| **MCP Servers** _(optional)_ | `mcps` | `Optional[List[str]]` | MCP server references for automatic tool integration. Supports HTTPS URLs and CrewAI AMP marketplace references. |
 | **Function Calling LLM** _(optional)_ | `function_calling_llm` | `Optional[Any]` | Language model for tool calling, overrides crew's LLM if specified. |
 | **Max Iterations** _(optional)_ | `max_iter` | `int` | Maximum iterations before the agent must provide its best answer. Default is 20. |
 | **Max RPM** _(optional)_ | `max_rpm` | `Optional[int]` | Maximum requests per minute to avoid rate limits. |
@@ -194,6 +195,25 @@ research_agent = Agent(
 )
 ```

+#### Research Agent with MCP Integration
+
+```python Code
+mcp_research_agent = Agent(
+    role="Advanced Research Analyst",
+    goal="Conduct comprehensive research using multiple data sources",
+    backstory="Expert researcher with access to web search, academic papers, and real-time data",
+    mcps=[
+        "https://mcp.exa.ai/mcp?api_key=your_key&profile=research",
+        "crewai-amp:academic-research#pubmed_search",  # CrewAI AMP reference to a single action in an MCP server
+        "crewai-amp:market-intelligence"  # CrewAI AMP reference to an entire MCP server
+    ],
+    verbose=True
+)
+```
+
+<Note>
+The `mcps` field automatically discovers and integrates tools from MCP servers. Tools are cached for performance and connections are made on-demand. See [MCP DSL Integration](/en/mcp/dsl-integration) for detailed usage.
+</Note>
+
 #### Code Development Agent

 ```python Code
 dev_agent = Agent(
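The tests added in this commit also exercise a `#tool_name` fragment on plain HTTPS references, so the narrowing shown above for AMP references appears to work for direct URLs as well. A hedged sketch of that usage (`web_search_exa` is an assumed tool name, not confirmed by the docs hunk):

```python Code
focused_agent = Agent(
    role="Focused Researcher",
    goal="Search the web using a single MCP tool",
    backstory="Researcher limited to one tool from one server",
    mcps=[
        # HTTPS reference narrowed to a single tool via the #fragment
        "https://mcp.exa.ai/mcp?api_key=your_key#web_search_exa",
    ],
)
```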
323  lib/crewai/tests/agents/test_base_agent_mcp.py  Normal file
@@ -0,0 +1,323 @@
"""Tests for BaseAgent MCP field validation and functionality."""

import pytest
from unittest.mock import Mock, patch
from pydantic import ValidationError

# Import from the source directory
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agent import Agent


class TestMCPAgent(BaseAgent):
    """Test implementation of BaseAgent for MCP testing."""

    def execute_task(self, task, context=None, tools=None):
        return "Test execution"

    def create_agent_executor(self, tools=None):
        pass

    def get_delegation_tools(self, agents):
        return []

    def get_platform_tools(self, apps):
        return []

    def get_mcp_tools(self, mcps):
        return []


class TestBaseAgentMCPField:
    """Test suite for BaseAgent MCP field validation and functionality."""

    def test_mcp_field_exists(self):
        """Test that mcps field exists on BaseAgent."""
        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field"
        )

        assert hasattr(agent, 'mcps')
        assert agent.mcps is None  # Default value

    def test_mcp_field_accepts_none(self):
        """Test that mcps field accepts None value."""
        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field",
            mcps=None
        )

        assert agent.mcps is None

    def test_mcp_field_accepts_empty_list(self):
        """Test that mcps field accepts empty list."""
        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field",
            mcps=[]
        )

        assert agent.mcps == []

    def test_mcp_field_accepts_valid_https_urls(self):
        """Test that mcps field accepts valid HTTPS URLs."""
        valid_urls = [
            "https://api.example.com/mcp",
            "https://mcp.server.org/endpoint",
            "https://localhost:8080/mcp"
        ]

        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field",
            mcps=valid_urls
        )

        # Field validator may reorder items due to set() deduplication
        assert len(agent.mcps) == len(valid_urls)
        assert all(url in agent.mcps for url in valid_urls)

    def test_mcp_field_accepts_valid_crewai_amp_references(self):
        """Test that mcps field accepts valid CrewAI AMP references."""
        valid_amp_refs = [
            "crewai-amp:weather-service",
            "crewai-amp:financial-data",
            "crewai-amp:research-tools"
        ]

        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field",
            mcps=valid_amp_refs
        )

        # Field validator may reorder items due to set() deduplication
        assert len(agent.mcps) == len(valid_amp_refs)
        assert all(ref in agent.mcps for ref in valid_amp_refs)

    def test_mcp_field_accepts_mixed_valid_references(self):
        """Test that mcps field accepts mixed valid references."""
        mixed_refs = [
            "https://api.example.com/mcp",
            "crewai-amp:weather-service",
            "https://mcp.exa.ai/mcp?api_key=test",
            "crewai-amp:financial-data"
        ]

        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field",
            mcps=mixed_refs
        )

        # Field validator may reorder items due to set() deduplication
        assert len(agent.mcps) == len(mixed_refs)
        assert all(ref in agent.mcps for ref in mixed_refs)

    def test_mcp_field_rejects_invalid_formats(self):
        """Test that mcps field rejects invalid reference formats."""
        invalid_refs = [
            "http://insecure.com/mcp",  # HTTP not allowed
            "invalid-format",  # No protocol
            "ftp://example.com/mcp",  # Wrong protocol
            "crewai:invalid",  # Wrong AMP format
            "",  # Empty string
        ]

        for invalid_ref in invalid_refs:
            with pytest.raises(ValidationError, match="Invalid MCP reference"):
                TestMCPAgent(
                    role="Test Agent",
                    goal="Test MCP field",
                    backstory="Testing BaseAgent MCP field",
                    mcps=[invalid_ref]
                )

    def test_mcp_field_removes_duplicates(self):
        """Test that mcps field removes duplicate references."""
        mcps_with_duplicates = [
            "https://api.example.com/mcp",
            "crewai-amp:weather-service",
            "https://api.example.com/mcp",  # Duplicate
            "crewai-amp:weather-service"  # Duplicate
        ]

        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field",
            mcps=mcps_with_duplicates
        )

        # Should contain only unique references
        assert len(agent.mcps) == 2
        assert "https://api.example.com/mcp" in agent.mcps
        assert "crewai-amp:weather-service" in agent.mcps

    def test_mcp_field_validates_list_type(self):
        """Test that mcps field validates list type."""
        with pytest.raises(ValidationError):
            TestMCPAgent(
                role="Test Agent",
                goal="Test MCP field",
                backstory="Testing BaseAgent MCP field",
                mcps="not-a-list"  # Should be list[str]
            )

    def test_abstract_get_mcp_tools_method_exists(self):
        """Test that get_mcp_tools abstract method exists."""
        assert hasattr(BaseAgent, 'get_mcp_tools')

        # Verify it's abstract by checking it's in __abstractmethods__
        assert 'get_mcp_tools' in BaseAgent.__abstractmethods__

    def test_concrete_implementation_must_implement_get_mcp_tools(self):
        """Test that concrete implementations must implement get_mcp_tools."""
        # This should work - TestMCPAgent implements get_mcp_tools
        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field"
        )

        assert hasattr(agent, 'get_mcp_tools')
        assert callable(agent.get_mcp_tools)

    def test_copy_method_excludes_mcps_field(self):
        """Test that copy method excludes mcps field from being copied."""
        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field",
            mcps=["https://api.example.com/mcp"]
        )

        copied_agent = agent.copy()

        # MCP field should be excluded from copy
        assert copied_agent.mcps is None or copied_agent.mcps == []

    def test_model_validation_pipeline_with_mcps(self):
        """Test model validation pipeline with mcps field."""
        # Test validation runs correctly through entire pipeline
        agent = TestMCPAgent(
            role="Test Agent",
            goal="Test MCP field",
            backstory="Testing BaseAgent MCP field",
            mcps=["https://api.example.com/mcp", "crewai-amp:test-service"]
        )

        # Verify all required fields are set
        assert agent.role == "Test Agent"
        assert agent.goal == "Test MCP field"
        assert agent.backstory == "Testing BaseAgent MCP field"
        assert len(agent.mcps) == 2

    def test_mcp_field_description_is_correct(self):
        """Test that mcps field has correct description."""
        # Get field info from model
        fields = BaseAgent.model_fields
        mcps_field = fields.get('mcps')

        assert mcps_field is not None
        assert "MCP server references" in mcps_field.description
        assert "https://" in mcps_field.description
        assert "crewai-amp:" in mcps_field.description
        assert "#tool_name" in mcps_field.description
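The validation behavior pinned down above — HTTPS-or-`crewai-amp:` format checks, "Invalid MCP reference" errors, and set-based deduplication that may reorder — could be expressed as a pydantic field validator roughly like this. A hedged sketch inferred from the assertions, not the actual crewAI implementation:

```python
from typing import Optional

from pydantic import BaseModel, field_validator


class MCPFieldSketch(BaseModel):
    """Minimal stand-in for BaseAgent's mcps field (hypothetical)."""
    mcps: Optional[list[str]] = None

    @field_validator("mcps")
    @classmethod
    def _validate_mcps(cls, value):
        if value is None:
            return value
        for ref in value:
            # Only HTTPS URLs and crewai-amp: references pass validation
            if not (ref.startswith("https://") or ref.startswith("crewai-amp:")):
                raise ValueError(f"Invalid MCP reference: {ref}")
        # set()-based dedup would explain why the tests tolerate reordering
        return list(set(value))
```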
class TestAgentMCPFieldIntegration:
    """Test MCP field integration with concrete Agent class."""

    def test_agent_class_has_mcp_field(self):
        """Test that concrete Agent class inherits MCP field."""
        agent = Agent(
            role="Test Agent",
            goal="Test MCP integration",
            backstory="Testing Agent MCP field",
            mcps=["https://api.example.com/mcp"]
        )

        assert hasattr(agent, 'mcps')
        assert agent.mcps == ["https://api.example.com/mcp"]

    def test_agent_class_implements_get_mcp_tools(self):
        """Test that concrete Agent class implements get_mcp_tools."""
        agent = Agent(
            role="Test Agent",
            goal="Test MCP integration",
            backstory="Testing Agent MCP field"
        )

        assert hasattr(agent, 'get_mcp_tools')
        assert callable(agent.get_mcp_tools)

        # Test it can be called
        result = agent.get_mcp_tools([])
        assert isinstance(result, list)

    def test_agent_mcp_field_validation_integration(self):
        """Test MCP field validation works with concrete Agent class."""
        # Valid case
        agent = Agent(
            role="Test Agent",
            goal="Test MCP integration",
            backstory="Testing Agent MCP field",
            mcps=["https://mcp.exa.ai/mcp", "crewai-amp:research-tools"]
        )

        assert len(agent.mcps) == 2

        # Invalid case
        with pytest.raises(ValidationError, match="Invalid MCP reference"):
            Agent(
                role="Test Agent",
                goal="Test MCP integration",
                backstory="Testing Agent MCP field",
                mcps=["invalid-format"]
            )

    def test_agent_docstring_mentions_mcps(self):
        """Test that Agent class docstring mentions mcps field."""
        docstring = Agent.__doc__

        assert docstring is not None
        assert "mcps" in docstring.lower()

    @patch('crewai.agent.create_llm')
    def test_agent_initialization_with_mcps_field(self, mock_create_llm):
        """Test complete Agent initialization with mcps field."""
        mock_create_llm.return_value = Mock()

        agent = Agent(
            role="MCP Test Agent",
            goal="Test complete MCP integration",
            backstory="Agent for testing MCP functionality",
            mcps=[
                "https://mcp.exa.ai/mcp?api_key=test",
                "crewai-amp:financial-data#get_stock_price"
            ],
            verbose=True
        )

        # Verify agent is properly initialized
        assert agent.role == "MCP Test Agent"
        assert len(agent.mcps) == 2
        assert agent.verbose is True

        # Verify MCP-specific functionality is available
        assert hasattr(agent, 'get_mcp_tools')
        assert hasattr(agent, '_get_external_mcp_tools')
        assert hasattr(agent, '_get_amp_mcp_tools')
470  lib/crewai/tests/agents/test_mcp_integration.py  Normal file
@@ -0,0 +1,470 @@
"""Tests for Agent MCP integration functionality."""

import asyncio
import pytest
import time
from unittest.mock import AsyncMock, Mock, patch, MagicMock

# Import from the source directory
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from crewai.agent import Agent
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper


class TestAgentMCPIntegration:
    """Test suite for Agent MCP integration functionality."""

    @pytest.fixture
    def sample_agent(self):
        """Create a sample agent for testing."""
        return Agent(
            role="Test Research Agent",
            goal="Test MCP integration capabilities",
            backstory="Agent designed for testing MCP functionality",
            mcps=["https://api.example.com/mcp"]
        )

    @pytest.fixture
    def mock_mcp_tools_response(self):
        """Mock MCP server tools response."""
        mock_tool1 = Mock()
        mock_tool1.name = "search_tool"
        mock_tool1.description = "Search for information"

        mock_tool2 = Mock()
        mock_tool2.name = "analysis_tool"
        mock_tool2.description = "Analyze data"

        mock_result = Mock()
        mock_result.tools = [mock_tool1, mock_tool2]

        return mock_result

    def test_get_mcp_tools_empty_list(self, sample_agent):
        """Test get_mcp_tools with empty list."""
        tools = sample_agent.get_mcp_tools([])
        assert tools == []

    def test_get_mcp_tools_with_https_url(self, sample_agent):
        """Test get_mcp_tools with HTTPS URL."""
        with patch.object(sample_agent, '_get_external_mcp_tools', return_value=[Mock()]) as mock_get:
            tools = sample_agent.get_mcp_tools(["https://api.example.com/mcp"])

            mock_get.assert_called_once_with("https://api.example.com/mcp")
            assert len(tools) == 1

    def test_get_mcp_tools_with_crewai_amp_reference(self, sample_agent):
        """Test get_mcp_tools with CrewAI AMP reference."""
        with patch.object(sample_agent, '_get_amp_mcp_tools', return_value=[Mock()]) as mock_get:
            tools = sample_agent.get_mcp_tools(["crewai-amp:financial-data"])

            mock_get.assert_called_once_with("crewai-amp:financial-data")
            assert len(tools) == 1

    def test_get_mcp_tools_mixed_references(self, sample_agent):
        """Test get_mcp_tools with mixed reference types."""
        mock_external_tools = [Mock(name="external_tool")]
        mock_amp_tools = [Mock(name="amp_tool")]

        with patch.object(sample_agent, '_get_external_mcp_tools', return_value=mock_external_tools), \
                patch.object(sample_agent, '_get_amp_mcp_tools', return_value=mock_amp_tools):

            tools = sample_agent.get_mcp_tools([
                "https://api.example.com/mcp",
                "crewai-amp:research-tools"
            ])

            assert len(tools) == 2

    def test_get_mcp_tools_error_handling(self, sample_agent):
        """Test get_mcp_tools error handling and graceful degradation."""
        with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=Exception("Connection failed")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent.get_mcp_tools(["https://api.example.com/mcp"])

            # Should return empty list and log warning
            assert tools == []
            mock_logger.log.assert_called_with("warning", "Skipping MCP https://api.example.com/mcp due to error: Connection failed")

    def test_extract_server_name_basic_url(self, sample_agent):
        """Test server name extraction from basic URLs."""
        server_name = sample_agent._extract_server_name("https://api.example.com/mcp")
        assert server_name == "api_example_com_mcp"

    def test_extract_server_name_with_path(self, sample_agent):
        """Test server name extraction from URLs with paths."""
        server_name = sample_agent._extract_server_name("https://mcp.exa.ai/api/v1/mcp")
        assert server_name == "mcp_exa_ai_api_v1_mcp"

    def test_extract_server_name_no_path(self, sample_agent):
        """Test server name extraction from URLs without path."""
        server_name = sample_agent._extract_server_name("https://example.com")
        assert server_name == "example_com"

    def test_extract_server_name_with_query_params(self, sample_agent):
        """Test server name extraction ignores query parameters."""
        server_name = sample_agent._extract_server_name("https://api.example.com/mcp?api_key=test")
        assert server_name == "api_example_com_mcp"
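The four extraction tests fully determine the naming rule: host plus path segments, joined and normalized to underscores, with the query string dropped. A sketch that satisfies all four assertions — a hypothetical free-function version of `Agent._extract_server_name`, not the actual code:

```python
from urllib.parse import urlparse


def extract_server_name(url: str) -> str:
    """Derive a server name from a URL: host + path segments, underscored."""
    parsed = urlparse(url)  # .query is parsed out separately and simply ignored
    parts = [parsed.netloc] + [seg for seg in parsed.path.split("/") if seg]
    return "_".join(parts).replace(".", "_").replace(":", "_")


assert extract_server_name("https://api.example.com/mcp") == "api_example_com_mcp"
assert extract_server_name("https://example.com") == "example_com"
assert extract_server_name("https://api.example.com/mcp?api_key=test") == "api_example_com_mcp"
```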
    @pytest.mark.asyncio
    async def test_get_mcp_tool_schemas_success(self, sample_agent, mock_mcp_tools_response):
        """Test successful MCP tool schema retrieval."""
        server_params = {"url": "https://api.example.com/mcp"}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value={
            "search_tool": {"description": "Search tool", "args_schema": None},
            "analysis_tool": {"description": "Analysis tool", "args_schema": None}
        }) as mock_async:

            schemas = sample_agent._get_mcp_tool_schemas(server_params)

            assert len(schemas) == 2
            assert "search_tool" in schemas
            assert "analysis_tool" in schemas
            mock_async.assert_called_once()

    def test_get_mcp_tool_schemas_caching(self, sample_agent):
        """Test MCP tool schema caching behavior."""
        from crewai.agent import _mcp_schema_cache

        # Clear cache to ensure clean test state
        _mcp_schema_cache.clear()

        server_params = {"url": "https://api.example.com/mcp"}
        mock_schemas = {"tool1": {"description": "Tool 1"}}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async:

            # First call at time 1000 - should hit server
            with patch('crewai.agent.time.time', return_value=1000):
                schemas1 = sample_agent._get_mcp_tool_schemas(server_params)
                assert mock_async.call_count == 1

            # Second call within TTL - should use cache
            with patch('crewai.agent.time.time', return_value=1100):  # 100 seconds later, within 300s TTL
                schemas2 = sample_agent._get_mcp_tool_schemas(server_params)
                assert mock_async.call_count == 1  # Not called again
                assert schemas1 == schemas2

        # Clean up cache after test
        _mcp_schema_cache.clear()

    def test_get_mcp_tool_schemas_cache_expiration(self, sample_agent):
        """Test MCP tool schema cache expiration."""
        from crewai.agent import _mcp_schema_cache

        # Clear cache to ensure clean test state
        _mcp_schema_cache.clear()

        server_params = {"url": "https://api.example.com/mcp"}
        mock_schemas = {"tool1": {"description": "Tool 1"}}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async:

            # First call at time 1000
            with patch('crewai.agent.time.time', return_value=1000):
                schemas1 = sample_agent._get_mcp_tool_schemas(server_params)
                assert mock_async.call_count == 1

            # Call after cache expiration (> 300s TTL)
            with patch('crewai.agent.time.time', return_value=1400):  # 400 seconds later, beyond 300s TTL
                schemas2 = sample_agent._get_mcp_tool_schemas(server_params)
                assert mock_async.call_count == 2  # Called again after cache expiration

        # Clean up cache after test
        _mcp_schema_cache.clear()
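The two caching tests fix the observable contract: a module-level `_mcp_schema_cache` in `crewai.agent`, keyed by server URL, with a 300-second TTL measured via `time.time()`. A minimal sketch of that contract (the real cache may store more and key differently):

```python
import time

_MCP_SCHEMA_TTL = 300  # seconds; inferred from the "within 300s TTL" comments
_mcp_schema_cache: dict[str, tuple[float, dict]] = {}


def get_tool_schemas_cached(url: str, fetch) -> dict:
    """Return cached schemas while fresh; refetch once the TTL has lapsed."""
    now = time.time()
    cached = _mcp_schema_cache.get(url)
    if cached is not None and now - cached[0] < _MCP_SCHEMA_TTL:
        return cached[1]  # cache hit: no round-trip to the server
    schemas = fetch(url)  # miss or expired entry: query the server again
    _mcp_schema_cache[url] = (now, schemas)
    return schemas
```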
    def test_get_mcp_tool_schemas_error_handling(self, sample_agent):
        """Test MCP tool schema retrieval error handling."""
        server_params = {"url": "https://api.example.com/mcp"}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', side_effect=Exception("Connection failed")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            schemas = sample_agent._get_mcp_tool_schemas(server_params)

            # Should return empty dict and log warning
            assert schemas == {}
            mock_logger.log.assert_called_with("warning", "Failed to get MCP tool schemas from https://api.example.com/mcp: Connection failed")

    def test_get_external_mcp_tools_full_server(self, sample_agent):
        """Test getting tools from external MCP server (full server)."""
        mcp_ref = "https://api.example.com/mcp"
        mock_schemas = {
            "tool1": {"description": "Tool 1"},
            "tool2": {"description": "Tool 2"}
        }

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=mock_schemas), \
                patch.object(sample_agent, '_extract_server_name', return_value="example_server"):

            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            assert len(tools) == 2
            assert all(isinstance(tool, MCPToolWrapper) for tool in tools)
            assert tools[0].server_name == "example_server"

    def test_get_external_mcp_tools_specific_tool(self, sample_agent):
        """Test getting specific tool from external MCP server."""
        mcp_ref = "https://api.example.com/mcp#tool1"
        mock_schemas = {
            "tool1": {"description": "Tool 1"},
            "tool2": {"description": "Tool 2"}
        }

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=mock_schemas), \
                patch.object(sample_agent, '_extract_server_name', return_value="example_server"):

            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            # Should only get tool1
            assert len(tools) == 1
            assert tools[0].original_tool_name == "tool1"

    def test_get_external_mcp_tools_specific_tool_not_found(self, sample_agent):
        """Test getting specific tool that doesn't exist on MCP server."""
        mcp_ref = "https://api.example.com/mcp#nonexistent_tool"
        mock_schemas = {
            "tool1": {"description": "Tool 1"},
            "tool2": {"description": "Tool 2"}
        }

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=mock_schemas), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            # Should return empty list and log warning
            assert tools == []
            mock_logger.log.assert_called_with("warning", "Specific tool 'nonexistent_tool' not found on MCP server: https://api.example.com/mcp")

    def test_get_external_mcp_tools_no_schemas(self, sample_agent):
        """Test getting tools when no schemas are discovered."""
        mcp_ref = "https://api.example.com/mcp"

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value={}), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            assert tools == []
            mock_logger.log.assert_called_with("warning", "No tools discovered from MCP server: https://api.example.com/mcp")

    def test_get_amp_mcp_tools_full_mcp(self, sample_agent):
        """Test getting tools from CrewAI AMP MCP marketplace (full MCP)."""
        amp_ref = "crewai-amp:financial-data"
        mock_servers = [{"url": "https://amp.crewai.com/mcp/financial"}]

        with patch.object(sample_agent, '_fetch_amp_mcp_servers', return_value=mock_servers), \
                patch.object(sample_agent, '_get_external_mcp_tools', return_value=[Mock()]) as mock_get_tools:

            tools = sample_agent._get_amp_mcp_tools(amp_ref)

            mock_get_tools.assert_called_once_with("https://amp.crewai.com/mcp/financial")
            assert len(tools) == 1

    def test_get_amp_mcp_tools_specific_tool(self, sample_agent):
        """Test getting specific tool from CrewAI AMP MCP marketplace."""
        amp_ref = "crewai-amp:financial-data#get_stock_price"
        mock_servers = [{"url": "https://amp.crewai.com/mcp/financial"}]

        with patch.object(sample_agent, '_fetch_amp_mcp_servers', return_value=mock_servers), \
                patch.object(sample_agent, '_get_external_mcp_tools', return_value=[Mock()]) as mock_get_tools:

            tools = sample_agent._get_amp_mcp_tools(amp_ref)

            mock_get_tools.assert_called_once_with("https://amp.crewai.com/mcp/financial#get_stock_price")
            assert len(tools) == 1

    def test_get_amp_mcp_tools_multiple_servers(self, sample_agent):
        """Test getting tools from multiple AMP MCP servers."""
        amp_ref = "crewai-amp:multi-server-mcp"
        mock_servers = [
            {"url": "https://amp.crewai.com/mcp/server1"},
            {"url": "https://amp.crewai.com/mcp/server2"}
        ]

        with patch.object(sample_agent, '_fetch_amp_mcp_servers', return_value=mock_servers), \
                patch.object(sample_agent, '_get_external_mcp_tools', return_value=[Mock()]) as mock_get_tools:

            tools = sample_agent._get_amp_mcp_tools(amp_ref)

            assert mock_get_tools.call_count == 2
            assert len(tools) == 2

    @pytest.mark.asyncio
    async def test_get_mcp_tool_schemas_async_success(self, sample_agent, mock_mcp_tools_response):
        """Test successful async MCP tool schema retrieval."""
        server_params = {"url": "https://api.example.com/mcp"}

        with patch('crewai.agent.streamablehttp_client') as mock_client, \
                patch('crewai.agent.ClientSession') as mock_session_class:

            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = AsyncMock()
            mock_session.list_tools = AsyncMock(return_value=mock_mcp_tools_response)

            mock_client.return_value.__aenter__.return_value = (None, None, None)

            schemas = await sample_agent._get_mcp_tool_schemas_async(server_params)

            assert len(schemas) == 2
            assert "search_tool" in schemas
            assert "analysis_tool" in schemas
            assert schemas["search_tool"]["description"] == "Search for information"

    @pytest.mark.asyncio
    async def test_get_mcp_tool_schemas_async_timeout(self, sample_agent):
        """Test async MCP tool schema retrieval timeout handling."""
        server_params = {"url": "https://api.example.com/mcp"}

        with patch('crewai.agent.asyncio.wait_for', side_effect=asyncio.TimeoutError):
            with pytest.raises(RuntimeError, match="Failed to discover MCP tools after 3 attempts"):
                await sample_agent._get_mcp_tool_schemas_async(server_params)

    @pytest.mark.asyncio
    async def test_get_mcp_tool_schemas_async_import_error(self, sample_agent):
        """Test async MCP tool schema retrieval with missing MCP library."""
        server_params = {"url": "https://api.example.com/mcp"}

        with patch('crewai.agent.ClientSession', side_effect=ImportError("No module named 'mcp'")):
            with pytest.raises(RuntimeError, match="MCP library not available"):
                await sample_agent._get_mcp_tool_schemas_async(server_params)

    @pytest.mark.asyncio
    async def test_get_mcp_tool_schemas_async_retry_logic(self, sample_agent):
        """Test retry logic with exponential backoff in async schema retrieval."""
        server_params = {"url": "https://api.example.com/mcp"}

        call_count = 0

        async def mock_discover_tools(url):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise Exception("Network connection failed")
            return {"tool1": {"description": "Tool 1"}}

        with patch.object(sample_agent, '_discover_mcp_tools', side_effect=mock_discover_tools), \
                patch('crewai.agent.asyncio.sleep') as mock_sleep:

            schemas = await sample_agent._get_mcp_tool_schemas_async(server_params)

            assert schemas == {"tool1": {"description": "Tool 1"}}
            assert call_count == 3
            # Verify exponential backoff
            assert mock_sleep.call_count == 2
            mock_sleep.assert_any_call(1)  # First retry: 2^0 = 1
            mock_sleep.assert_any_call(2)  # Second retry: 2^1 = 2
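The retry test fixes the policy precisely: three attempts, `asyncio.sleep(2 ** attempt)` between failures, and a `RuntimeError` naming the attempt count once retries are exhausted. A sketch of a loop satisfying those assertions (a hypothetical helper, not the actual method):

```python
import asyncio


async def discover_with_retry(discover, url: str, attempts: int = 3) -> dict:
    """Retry discovery with exponential backoff: sleeps of 1s, then 2s."""
    last_error = None
    for attempt in range(attempts):
        try:
            return await discover(url)
        except Exception as exc:  # broad on purpose: any failure triggers a retry
            last_error = exc
            if attempt < attempts - 1:
                await asyncio.sleep(2 ** attempt)  # 2^0 = 1s, then 2^1 = 2s
    raise RuntimeError(
        f"Failed to discover MCP tools after {attempts} attempts"
    ) from last_error
```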
    @pytest.mark.asyncio
    async def test_discover_mcp_tools_success(self, sample_agent, mock_mcp_tools_response):
        """Test successful MCP tool discovery."""
        server_url = "https://api.example.com/mcp"

        with patch('crewai.agent.streamablehttp_client') as mock_client, \
                patch('crewai.agent.ClientSession') as mock_session_class:

            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = AsyncMock()
            mock_session.list_tools = AsyncMock(return_value=mock_mcp_tools_response)

            mock_client.return_value.__aenter__.return_value = (None, None, None)

            schemas = await sample_agent._discover_mcp_tools(server_url)

            assert len(schemas) == 2
            assert schemas["search_tool"]["description"] == "Search for information"
            assert schemas["analysis_tool"]["description"] == "Analyze data"
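The patches above target `crewai.agent.streamablehttp_client` and `crewai.agent.ClientSession`, which suggests a discovery routine along these lines, assuming the official `mcp` Python SDK. A hedged sketch, not the actual crewAI code:

```python
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def discover_mcp_tools(server_url: str) -> dict[str, dict]:
    """Connect, initialize the session, and map tool names to schema stubs."""
    async with streamablehttp_client(server_url) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            result = await session.list_tools()
            return {
                tool.name: {"description": tool.description, "args_schema": None}
                for tool in result.tools
            }
```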
    def test_fetch_amp_mcp_servers_placeholder(self, sample_agent):
        """Test AMP MCP server fetching (currently returns empty list)."""
        result = sample_agent._fetch_amp_mcp_servers("test-mcp")

        # Currently returns empty list - placeholder implementation
        assert result == []

    def test_get_external_mcp_tools_error_handling(self, sample_agent):
        """Test external MCP tools error handling."""
        mcp_ref = "https://failing-server.com/mcp"

        with patch.object(sample_agent, '_get_mcp_tool_schemas', side_effect=Exception("Server unavailable")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            assert tools == []
            mock_logger.log.assert_called_with("warning", "Failed to connect to MCP server https://failing-server.com/mcp: Server unavailable")

    def test_get_external_mcp_tools_wrapper_creation_error(self, sample_agent):
        """Test handling of MCPToolWrapper creation errors."""
        mcp_ref = "https://api.example.com/mcp"
        mock_schemas = {"tool1": {"description": "Tool 1"}}

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=mock_schemas), \
                patch.object(sample_agent, '_extract_server_name', return_value="example_server"), \
                patch('crewai.tools.mcp_tool_wrapper.MCPToolWrapper', side_effect=Exception("Wrapper creation failed")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            assert tools == []
            mock_logger.log.assert_called_with("warning", "Failed to create MCP tool wrapper for tool1: Wrapper creation failed")

    def test_mcp_tools_integration_with_existing_agent_features(self, sample_agent):
        """Test MCP tools integration with existing agent features."""
        # Test that MCP field works alongside other agent features
        agent_with_all_features = Agent(
            role="Full Feature Agent",
            goal="Test all features together",
            backstory="Agent with all features enabled",
            mcps=["https://api.example.com/mcp", "crewai-amp:financial-data"],
            apps=["gmail", "slack"],  # Platform apps
            tools=[],  # Regular tools
            verbose=True,
            max_iter=15,
            allow_delegation=True
        )

        assert len(agent_with_all_features.mcps) == 2
        assert len(agent_with_all_features.apps) == 2
        assert agent_with_all_features.verbose is True
        assert agent_with_all_features.max_iter == 15
        assert agent_with_all_features.allow_delegation is True

    def test_mcp_reference_parsing_edge_cases(self, sample_agent):
        """Test MCP reference parsing with edge cases."""
        test_cases = [
            # URL with complex query parameters
            ("https://api.example.com/mcp?api_key=abc123&profile=test&version=1.0", "api.example.com", None),
            # URL with tool name and query params
            ("https://api.example.com/mcp?api_key=test#search_tool", "api.example.com", "search_tool"),
            # AMP reference with dashes and underscores
            ("crewai-amp:financial_data-v2", "financial_data-v2", None),
            # AMP reference with tool name
            ("crewai-amp:research-tools#pubmed_search", "research-tools", "pubmed_search"),
        ]

        for mcp_ref, expected_server_part, expected_tool in test_cases:
            if mcp_ref.startswith("https://"):
                if '#' in mcp_ref:
                    server_url, tool_name = mcp_ref.split('#', 1)
                    assert expected_server_part in server_url
                    assert tool_name == expected_tool
                else:
                    assert expected_server_part in mcp_ref
                    assert expected_tool is None
            else:  # AMP reference
                amp_part = mcp_ref.replace('crewai-amp:', '')
                if '#' in amp_part:
                    mcp_name, tool_name = amp_part.split('#', 1)
                    assert mcp_name == expected_server_part
                    assert tool_name == expected_tool
                else:
                    assert amp_part == expected_server_part
                    assert expected_tool is None
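The inline parsing logic above generalizes to a small helper. This sketch (hypothetical; the tests inline the logic instead) splits either reference form into a server part and an optional tool name:

```python
def parse_mcp_reference(ref: str):
    """Split an MCP reference into (server part, optional tool name)."""
    body = ref.removeprefix("crewai-amp:") if ref.startswith("crewai-amp:") else ref
    if "#" in body:
        server, tool = body.split("#", 1)
        return server, tool
    return body, None


assert parse_mcp_reference("crewai-amp:research-tools#pubmed_search") == ("research-tools", "pubmed_search")
assert parse_mcp_reference("https://api.example.com/mcp?api_key=test#search_tool") == (
    "https://api.example.com/mcp?api_key=test", "search_tool")
```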
502  lib/crewai/tests/crew/test_mcp_crew_integration.py  Normal file
@@ -0,0 +1,502 @@
"""Tests for Crew MCP integration functionality."""

import pytest
from unittest.mock import Mock, patch, MagicMock

# Import from the source directory
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper


class TestCrewMCPIntegration:
    """Test suite for Crew MCP integration functionality."""

    @pytest.fixture
    def mcp_agent(self):
        """Create an agent with MCP tools."""
        return Agent(
            role="MCP Research Agent",
            goal="Research using MCP tools",
            backstory="Agent with access to MCP tools for research",
            mcps=["https://api.example.com/mcp"]
        )

    @pytest.fixture
    def regular_agent(self):
        """Create a regular agent without MCP tools."""
        return Agent(
            role="Regular Agent",
            goal="Regular tasks without MCP",
            backstory="Standard agent without MCP access"
        )

    @pytest.fixture
    def sample_task(self, mcp_agent):
        """Create a sample task for testing."""
        return Task(
            description="Research AI frameworks using available tools",
            expected_output="Comprehensive research report",
            agent=mcp_agent
        )

    @pytest.fixture
    def mock_mcp_tools(self):
        """Create mock MCP tools."""
        tool1 = Mock(spec=MCPToolWrapper)
        tool1.name = "mcp_server_search_tool"
        tool1.description = "Search tool from MCP server"

        tool2 = Mock(spec=MCPToolWrapper)
        tool2.name = "mcp_server_analysis_tool"
        tool2.description = "Analysis tool from MCP server"

        return [tool1, tool2]

    def test_crew_creation_with_mcp_agent(self, mcp_agent, sample_task):
        """Test crew creation with MCP-enabled agent."""
        crew = Crew(
            agents=[mcp_agent],
            tasks=[sample_task],
            verbose=False
        )

        assert crew is not None
        assert len(crew.agents) == 1
        assert len(crew.tasks) == 1
        assert crew.agents[0] == mcp_agent

    def test_crew_add_mcp_tools_method_exists(self):
        """Test that Crew class has _add_mcp_tools method."""
        crew = Crew(agents=[], tasks=[])

        assert hasattr(crew, '_add_mcp_tools')
        assert callable(crew._add_mcp_tools)

    def test_crew_inject_mcp_tools_method_exists(self):
        """Test that Crew class has _inject_mcp_tools method."""
        crew = Crew(agents=[], tasks=[])

        assert hasattr(crew, '_inject_mcp_tools')
        assert callable(crew._inject_mcp_tools)

    def test_inject_mcp_tools_with_mcp_agent(self, mcp_agent, mock_mcp_tools):
        """Test MCP tools injection with MCP-enabled agent."""
        crew = Crew(agents=[mcp_agent], tasks=[])

        initial_tools = []

        with patch.object(mcp_agent, 'get_mcp_tools', return_value=mock_mcp_tools):
            result_tools = crew._inject_mcp_tools(initial_tools, mcp_agent)

            # Should merge MCP tools with existing tools
            assert len(result_tools) == len(mock_mcp_tools)

            # Verify get_mcp_tools was called with agent's mcps
            mcp_agent.get_mcp_tools.assert_called_once_with(mcps=mcp_agent.mcps)

    def test_inject_mcp_tools_with_regular_agent(self, regular_agent):
        """Test MCP tools injection with regular agent (no MCP tools)."""
        crew = Crew(agents=[regular_agent], tasks=[])

        initial_tools = [Mock(name="existing_tool")]

        # Regular agent has no mcps attribute
        result_tools = crew._inject_mcp_tools(initial_tools, regular_agent)

        # Should return original tools unchanged
        assert result_tools == initial_tools

    def test_inject_mcp_tools_empty_mcps_list(self, mcp_agent):
        """Test MCP tools injection with empty mcps list."""
        crew = Crew(agents=[mcp_agent], tasks=[])

        # Agent with empty mcps list
        mcp_agent.mcps = []
        initial_tools = [Mock(name="existing_tool")]

        result_tools = crew._inject_mcp_tools(initial_tools, mcp_agent)

        # Should return original tools unchanged
        assert result_tools == initial_tools

    def test_inject_mcp_tools_none_mcps(self, mcp_agent):
        """Test MCP tools injection with None mcps."""
        crew = Crew(agents=[mcp_agent], tasks=[])

        # Agent with None mcps
        mcp_agent.mcps = None
        initial_tools = [Mock(name="existing_tool")]

        result_tools = crew._inject_mcp_tools(initial_tools, mcp_agent)

        # Should return original tools unchanged
        assert result_tools == initial_tools
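Taken together, the injection tests pin down the guard logic: untouched tool list when `mcps` is absent, `None`, or empty, and a merge via `get_mcp_tools(mcps=...)` otherwise. `Crew._inject_mcp_tools` plausibly follows a shape like this sketch (a reconstruction from the assertions, not the actual implementation):

```python
def inject_mcp_tools(tools: list, agent) -> list:
    """Return tools unchanged unless the agent declares usable MCP references."""
    mcps = getattr(agent, "mcps", None)
    get_mcp_tools = getattr(agent, "get_mcp_tools", None)
    if not mcps or not callable(get_mcp_tools):
        # Missing attribute, None, or an empty list: leave the tool list as-is
        return tools
    # Merge discovered MCP tools with whatever the task already has
    return list(tools) + list(get_mcp_tools(mcps=mcps))
```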
    def test_add_mcp_tools_with_task_agent(self, mcp_agent, sample_task, mock_mcp_tools):
        """Test _add_mcp_tools method with task agent."""
        crew = Crew(agents=[mcp_agent], tasks=[sample_task])

        initial_tools = [Mock(name="task_tool")]

        with patch.object(crew, '_inject_mcp_tools', return_value=initial_tools + mock_mcp_tools) as mock_inject:

            result_tools = crew._add_mcp_tools(sample_task, initial_tools)

            # Should call _inject_mcp_tools with task agent
            mock_inject.assert_called_once_with(initial_tools, sample_task.agent)
            assert len(result_tools) == len(initial_tools) + len(mock_mcp_tools)

    def test_add_mcp_tools_with_no_agent_task(self):
        """Test _add_mcp_tools method with task that has no agent."""
        crew = Crew(agents=[], tasks=[])

        # Task without agent
        task_no_agent = Task(
            description="Task without agent",
            expected_output="Some output",
            agent=None
        )

        initial_tools = [Mock(name="task_tool")]

        result_tools = crew._add_mcp_tools(task_no_agent, initial_tools)

        # Should return original tools unchanged
        assert result_tools == initial_tools

    def test_mcp_tools_integration_in_task_preparation_flow(self, mcp_agent, sample_task, mock_mcp_tools):
        """Test MCP tools integration in the task preparation flow."""
        crew = Crew(agents=[mcp_agent], tasks=[sample_task])

        # Mock the crew's tool preparation methods
        with patch.object(crew, '_add_platform_tools', return_value=[Mock(name="platform_tool")]) as mock_platform, \
                patch.object(crew, '_add_mcp_tools', return_value=mock_mcp_tools) as mock_mcp, \
                patch.object(crew, '_add_multimodal_tools', return_value=mock_mcp_tools) as mock_multimodal:

            # This tests the integration point where MCP tools are added to task tools
            # We can't easily test the full _prepare_tools_for_task method due to complexity,
            # but we can verify our _add_mcp_tools integration point works

            result = crew._add_mcp_tools(sample_task, [])

            assert result == mock_mcp_tools
            mock_mcp.assert_called_once_with(sample_task, [])

    def test_mcp_tools_merge_with_existing_tools(self, mcp_agent, mock_mcp_tools):
        """Test that MCP tools merge correctly with existing tools."""
        from crewai.tools import BaseTool

        class ExistingTool(BaseTool):
            name: str = "existing_search"
            description: str = "Existing search tool"

            def _run(self, **kwargs):
                return "Existing search result"

        existing_tools = [ExistingTool()]
        crew = Crew(agents=[mcp_agent], tasks=[])

        with patch.object(mcp_agent, 'get_mcp_tools', return_value=mock_mcp_tools):
            merged_tools = crew._inject_mcp_tools(existing_tools, mcp_agent)

            # Should have both existing tools and MCP tools
            total_expected = len(existing_tools) + len(mock_mcp_tools)
            assert len(merged_tools) == total_expected

            # Verify existing tools are preserved
            existing_names = [tool.name for tool in existing_tools]
            merged_names = [tool.name for tool in merged_tools]

            for existing_name in existing_names:
                assert existing_name in merged_names

    def test_mcp_tools_available_in_crew_context(self, mcp_agent, sample_task, mock_mcp_tools):
        """Test that MCP tools are available in crew execution context."""
        crew = Crew(agents=[mcp_agent], tasks=[sample_task])

        with patch.object(mcp_agent, 'get_mcp_tools', return_value=mock_mcp_tools):

            # Test that crew can access MCP tools through agent
            agent_tools = crew._inject_mcp_tools([], mcp_agent)

            assert len(agent_tools) == len(mock_mcp_tools)
            assert all(tool in agent_tools for tool in mock_mcp_tools)

    def test_crew_with_mixed_agents_mcp_and_regular(self, mcp_agent, regular_agent, mock_mcp_tools):
        """Test crew with both MCP-enabled and regular agents."""
        task1 = Task(
            description="Task for MCP agent",
            expected_output="MCP-powered result",
            agent=mcp_agent
        )

        task2 = Task(
            description="Task for regular agent",
            expected_output="Regular result",
            agent=regular_agent
        )

        crew = Crew(
            agents=[mcp_agent, regular_agent],
            tasks=[task1, task2]
        )

        # Test MCP tools injection for MCP agent
        with patch.object(mcp_agent, 'get_mcp_tools', return_value=mock_mcp_tools):
            mcp_tools = crew._inject_mcp_tools([], mcp_agent)
            assert len(mcp_tools) == len(mock_mcp_tools)

        # Test MCP tools injection for regular agent
        regular_tools = crew._inject_mcp_tools([], regular_agent)
        assert len(regular_tools) == 0

    def test_crew_mcp_tools_error_handling_during_execution_prep(self, mcp_agent, sample_task):
        """Test crew error handling when MCP tools fail during execution preparation."""
        crew = Crew(agents=[mcp_agent], tasks=[sample_task])

        # Mock MCP tools failure during crew execution preparation
        with patch.object(mcp_agent, 'get_mcp_tools', side_effect=Exception("MCP tools failed")):

            # Crew operations should continue despite MCP failure
            try:
                crew._inject_mcp_tools([], mcp_agent)
                # If we get here, the error was handled gracefully by returning empty tools
            except Exception as e:
                # If exception propagates, it should be an expected one
                assert "MCP tools failed" in str(e)

    def test_crew_task_execution_flow_includes_mcp_tools(self, mcp_agent, sample_task):
        """Test that crew task execution flow includes MCP tools integration."""
        crew = Crew(agents=[mcp_agent], tasks=[sample_task])

        # Verify that crew has the necessary methods for MCP integration
        assert hasattr(crew, '_add_mcp_tools')
        assert hasattr(crew, '_inject_mcp_tools')

        # Test the task has an agent with MCP capabilities
        assert sample_task.agent == mcp_agent
        assert hasattr(sample_task.agent, 'mcps')
        assert hasattr(sample_task.agent, 'get_mcp_tools')

    def test_mcp_tools_do_not_interfere_with_platform_tools(self, mock_mcp_tools):
        """Test that MCP tools don't interfere with platform tools integration."""
        agent_with_both = Agent(
            role="Multi-Tool Agent",
            goal="Use both platform and MCP tools",
            backstory="Agent with access to multiple tool types",
            apps=["gmail", "slack"],
            mcps=["https://api.example.com/mcp"]
        )

        task = Task(
            description="Use both platform and MCP tools",
            expected_output="Combined tool usage result",
            agent=agent_with_both
        )

        crew = Crew(agents=[agent_with_both], tasks=[task])

        platform_tools = [Mock(name="gmail_tool"), Mock(name="slack_tool")]

        # Test platform tools injection
        with patch.object(crew, '_inject_platform_tools', return_value=platform_tools):
            result_platform = crew._inject_platform_tools([], agent_with_both)
            assert len(result_platform) == len(platform_tools)

        # Test MCP tools injection
        with patch.object(agent_with_both, 'get_mcp_tools', return_value=mock_mcp_tools):
            result_mcp = crew._inject_mcp_tools(platform_tools, agent_with_both)
            assert len(result_mcp) == len(platform_tools) + len(mock_mcp_tools)

    def test_crew_task_execution_order_includes_mcp_tools(self, mcp_agent, sample_task):
        """Test that crew task execution order includes MCP tools at the right point."""
        crew = Crew(agents=[mcp_agent], tasks=[sample_task])

        # Mock the various tool addition methods to verify call order
        call_order = []

        def track_platform_tools(*args, **kwargs):
            call_order.append("platform")
            return []

        def track_mcp_tools(*args, **kwargs):
            call_order.append("mcp")
            return []

        def track_multimodal_tools(*args, **kwargs):
            call_order.append("multimodal")
            return []

        with patch.object(crew, '_add_platform_tools', side_effect=track_platform_tools), \
                patch.object(crew, '_add_mcp_tools', side_effect=track_mcp_tools), \
                patch.object(crew, '_add_multimodal_tools', side_effect=track_multimodal_tools):

            # Test the crew's task preparation flow
            # We check that MCP tools are added in the right sequence

            # These methods are called in the task preparation flow
            crew._add_platform_tools(sample_task, [])
            crew._add_mcp_tools(sample_task, [])

            assert "platform" in call_order
            assert "mcp" in call_order

    def test_crew_handles_agent_without_get_mcp_tools_method(self):
        """Test crew handles agents that don't implement get_mcp_tools method."""
        # Create a mock agent that doesn't have get_mcp_tools
        mock_agent = Mock()
        mock_agent.mcps = ["https://api.example.com/mcp"]
        # Explicitly don't add get_mcp_tools method

        crew = Crew(agents=[], tasks=[])

        # Should handle gracefully when agent doesn't have get_mcp_tools
        result_tools = crew._inject_mcp_tools([], mock_agent)

        # Should return empty list since agent doesn't have get_mcp_tools
        assert result_tools == []

    def test_crew_handles_agent_get_mcp_tools_exception(self, mcp_agent):
        """Test crew handles exceptions from agent's get_mcp_tools method."""
        crew = Crew(agents=[mcp_agent], tasks=[])

        with patch.object(mcp_agent, 'get_mcp_tools', side_effect=Exception("MCP tools failed")):

            # Should handle exception gracefully
            result_tools = crew._inject_mcp_tools([], mcp_agent)

            # Depending on implementation, should either return empty list or re-raise
            # Since get_mcp_tools handles errors internally, this should return empty list
            assert isinstance(result_tools, list)

    def test_crew_mcp_tools_merge_functionality(self, mock_mcp_tools):
        """Test crew's tool merging functionality with MCP tools."""
        crew = Crew(agents=[], tasks=[])

        existing_tools = [Mock(name="existing_tool_1"), Mock(name="existing_tool_2")]

        # Test _merge_tools method with MCP tools
        merged_tools = crew._merge_tools(existing_tools, mock_mcp_tools)

        total_expected = len(existing_tools) + len(mock_mcp_tools)
        assert len(merged_tools) == total_expected

        # Verify all tools are present
        all_tool_names = [tool.name for tool in merged_tools]
        assert "existing_tool_1" in all_tool_names
        assert "existing_tool_2" in all_tool_names
        assert mock_mcp_tools[0].name in all_tool_names
        assert mock_mcp_tools[1].name in all_tool_names

    def test_crew_workflow_integration_conditions(self, mcp_agent, sample_task):
        """Test the conditions for MCP tools integration in crew workflows."""
        crew = Crew(agents=[mcp_agent], tasks=[sample_task])

        # Test condition: agent exists and has mcps attribute
        assert hasattr(sample_task.agent, 'mcps')
        assert sample_task.agent.mcps is not None

        # Test condition: agent has get_mcp_tools method
        assert hasattr(sample_task.agent, 'get_mcp_tools')

        # Test condition: mcps list is not empty
        sample_task.agent.mcps = ["https://api.example.com/mcp"]
        assert len(sample_task.agent.mcps) > 0

    def test_crew_mcp_integration_performance_impact(self, mcp_agent, sample_task, mock_mcp_tools):
        """Test that MCP integration doesn't significantly impact crew performance."""
        crew = Crew(agents=[mcp_agent], tasks=[sample_task])

        import time

        # Test tool injection performance
        start_time = time.time()

        with patch.object(mcp_agent, 'get_mcp_tools', return_value=mock_mcp_tools):
            # Multiple tool injection calls should be fast due to caching
            for _ in range(5):
                tools = crew._inject_mcp_tools([], mcp_agent)

        end_time = time.time()
        total_time = end_time - start_time

        # Should complete quickly (less than 1 second for 5 operations)
        assert total_time < 1.0
        assert len(tools) == len(mock_mcp_tools)

    def test_crew_task_tool_availability_with_mcp(self, mcp_agent, sample_task, mock_mcp_tools):
        """Test that MCP tools are available during task execution."""
        crew = Crew(agents=[mcp_agent], tasks=[sample_task])

        with patch.object(mcp_agent, 'get_mcp_tools', return_value=mock_mcp_tools):

            # Simulate task tool preparation
            prepared_tools = crew._inject_mcp_tools([], mcp_agent)

            # Verify tools are properly prepared for task execution
            assert len(prepared_tools) == len(mock_mcp_tools)

            # Each tool should be a valid BaseTool-compatible object
            for tool in prepared_tools:
                assert hasattr(tool, 'name')
                assert hasattr(tool, 'description')

    def test_crew_handles_mcp_tool_name_conflicts(self, mock_mcp_tools):
        """Test crew handling of potential tool name conflicts."""
        # Create MCP agent with tools that might conflict
        agent1 = Agent(
            role="Agent 1",
            goal="Test conflicts",
            backstory="First agent with MCP tools",
            mcps=["https://server1.com/mcp"]
        )

        agent2 = Agent(
            role="Agent 2",
            goal="Test conflicts",
            backstory="Second agent with MCP tools",
            mcps=["https://server2.com/mcp"]
        )

        # Mock tools with same original names but different server prefixes
        server1_tools = [Mock(name="server1_com_mcp_search_tool")]
        server2_tools = [Mock(name="server2_com_mcp_search_tool")]

        task1 = Task(description="Task 1", expected_output="Result 1", agent=agent1)
        task2 = Task(description="Task 2", expected_output="Result 2", agent=agent2)

        crew = Crew(agents=[agent1, agent2], tasks=[task1, task2])

        with patch.object(agent1, 'get_mcp_tools', return_value=server1_tools), \
                patch.object(agent2, 'get_mcp_tools', return_value=server2_tools):

            # Each agent should get its own prefixed tools
            tools1 = crew._inject_mcp_tools([], agent1)
            tools2 = crew._inject_mcp_tools([], agent2)

            assert len(tools1) == 1
            assert len(tools2) == 1
            assert tools1[0].name != tools2[0].name  # Names should be different due to prefixing

    def test_crew_mcp_integration_with_verbose_mode(self, mcp_agent, sample_task):
        """Test MCP integration works with crew verbose mode."""
        crew = Crew(
            agents=[mcp_agent],
            tasks=[sample_task],
            verbose=True  # Enable verbose mode
        )

        # Should work the same regardless of verbose mode
        assert crew.verbose is True
        assert hasattr(crew, '_inject_mcp_tools')

        # MCP integration should not be affected by verbose mode
        with patch.object(mcp_agent, 'get_mcp_tools', return_value=[Mock()]):
            tools = crew._inject_mcp_tools([], mcp_agent)
            assert len(tools) == 1
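The conflict test, together with the fixture names in the next file (`mcp_exa_ai_mcp_web_search_exa`, `weather_server_com_mcp_get_forecast`), suggests wrapped tools are named by prefixing the extracted server name onto the original tool name. A one-line sketch of that assumed convention:

```python
def prefixed_tool_name(server_name: str, original_tool_name: str) -> str:
    """Assumed naming convention: '<server_name>_<original_tool_name>'."""
    return f"{server_name}_{original_tool_name}"


assert prefixed_tool_name("server1_com_mcp", "search_tool") == "server1_com_mcp_search_tool"
```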
1  lib/crewai/tests/fixtures/__init__.py  vendored  Normal file
@@ -0,0 +1 @@
"""Test fixtures package."""
375  lib/crewai/tests/fixtures/mcp_fixtures.py  vendored  Normal file
@@ -0,0 +1,375 @@
"""Shared fixtures for MCP testing."""

import pytest
from unittest.mock import Mock

# Import from the source directory
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
from tests.mocks.mcp_server_mock import MockMCPServerFactory


@pytest.fixture
def sample_mcp_agent():
    """Create a sample agent with MCP configuration for testing."""
    return Agent(
        role="Test MCP Agent",
        goal="Test MCP functionality",
        backstory="Agent designed for MCP testing",
        mcps=["https://api.test.com/mcp"]
    )


@pytest.fixture
def multi_mcp_agent():
    """Create an agent with multiple MCP configurations."""
    return Agent(
        role="Multi-MCP Agent",
        goal="Test multiple MCP server integration",
        backstory="Agent with access to multiple MCP servers",
        mcps=[
            "https://search.server.com/mcp",
            "https://analysis.server.com/mcp#specific_tool",
            "crewai-amp:research-tools",
            "crewai-amp:financial-data#stock_prices"
        ]
    )


@pytest.fixture
def mcp_agent_no_tools():
    """Create an agent without MCP configuration."""
    return Agent(
        role="No MCP Agent",
        goal="Test without MCP tools",
        backstory="Standard agent without MCP access"
    )


@pytest.fixture
def sample_mcp_tool_wrapper():
    """Create a sample MCPToolWrapper for testing."""
    return MCPToolWrapper(
        mcp_server_params={"url": "https://test.server.com/mcp"},
        tool_name="test_tool",
        tool_schema={
            "description": "Test tool for MCP integration",
            "args_schema": None
        },
        server_name="test_server_com_mcp"
    )


@pytest.fixture
def mock_mcp_tool_schemas():
    """Provide mock MCP tool schemas for testing."""
    return {
        "search_web": {
            "description": "Search the web for information",
            "args_schema": None
        },
        "analyze_data": {
            "description": "Analyze provided data and generate insights",
            "args_schema": None
        },
        "get_weather": {
            "description": "Get weather information for a location",
            "args_schema": None
        }
    }


@pytest.fixture
def mock_exa_like_tools():
    """Provide mock tools similar to the Exa MCP server."""
    tools = []

    # Web search tool (`.name` is assigned after construction because the
    # `name` kwarg is reserved by the Mock constructor)
    web_search = Mock(spec=MCPToolWrapper)
    web_search.name = "mcp_exa_ai_mcp_web_search_exa"
    web_search.description = "Search the web using Exa AI"
    web_search.original_tool_name = "web_search_exa"
    web_search.server_name = "mcp_exa_ai_mcp"
    tools.append(web_search)

    # Code context tool
    code_context = Mock(spec=MCPToolWrapper)
    code_context.name = "mcp_exa_ai_mcp_get_code_context_exa"
    code_context.description = "Get code context using Exa"
    code_context.original_tool_name = "get_code_context_exa"
    code_context.server_name = "mcp_exa_ai_mcp"
    tools.append(code_context)

    return tools


@pytest.fixture
def mock_weather_like_tools():
    """Provide mock tools similar to a weather MCP server."""
    tools = []

    weather_tools = [
        ("get_current_weather", "Get current weather conditions"),
        ("get_forecast", "Get weather forecast for next 5 days"),
        ("get_alerts", "Get active weather alerts")
    ]

    for tool_name, description in weather_tools:
        tool = Mock(spec=MCPToolWrapper)
        tool.name = f"weather_server_com_mcp_{tool_name}"
        tool.description = description
        tool.original_tool_name = tool_name
        tool.server_name = "weather_server_com_mcp"
        tools.append(tool)

    return tools


@pytest.fixture
def mock_amp_mcp_responses():
    """Provide mock responses for CrewAI AMP MCP API calls."""
    return {
        "research-tools": [
            {"url": "https://amp.crewai.com/mcp/research/v1"},
            {"url": "https://amp.crewai.com/mcp/research/v2"}
        ],
        "financial-data": [
            {"url": "https://amp.crewai.com/mcp/financial/main"}
        ],
        "weather-service": [
            {"url": "https://amp.crewai.com/mcp/weather/api"}
        ]
    }
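

# A hedged sketch of how a `crewai-amp:` reference might be resolved against
# responses shaped like the fixture above: one AMP service name maps to one
# or more concrete MCP server URLs. `_resolve_amp_service` is a hypothetical
# helper for illustration, not part of the CrewAI API.
def _resolve_amp_service(responses, service):
    """Map an AMP service name to its concrete MCP server URLs."""
    return [entry["url"] for entry in responses.get(service, [])]
    # e.g. _resolve_amp_service(responses, "research-tools")
    #      -> ["https://amp.crewai.com/mcp/research/v1", ".../v2"]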


@pytest.fixture
def performance_test_mcps():
    """Provide MCP configurations for performance testing."""
    return [
        "https://fast-server.com/mcp",
        "https://medium-server.com/mcp",
        "https://reliable-server.com/mcp"
    ]


@pytest.fixture
def error_scenario_mcps():
    """Provide MCP configurations for error scenario testing."""
    return [
        "https://timeout-server.com/mcp",
        "https://auth-fail-server.com/mcp",
        "https://json-error-server.com/mcp",
        "https://not-found-server.com/mcp"
    ]


@pytest.fixture
def mixed_quality_mcps():
    """Provide mixed-quality MCP server configurations for resilience testing."""
    return [
        "https://excellent-server.com/mcp",          # Always works
        "https://intermittent-server.com/mcp",       # Sometimes works
        "https://slow-but-working-server.com/mcp",   # Slow but reliable
        "https://completely-broken-server.com/mcp"   # Never works
    ]


@pytest.fixture
def server_name_test_cases():
    """Provide test cases for server name extraction."""
    return [
        # (input_url, expected_server_name)
        ("https://api.example.com/mcp", "api_example_com_mcp"),
        ("https://mcp.exa.ai/api/v1", "mcp_exa_ai_api_v1"),
        ("https://simple.com", "simple_com"),
        ("https://complex-domain.co.uk/deep/path/mcp", "complex-domain_co_uk_deep_path_mcp"),
        ("https://localhost:8080/mcp", "localhost:8080_mcp"),
    ]
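

# A hedged sketch of a name-extraction rule consistent with the cases above:
# dots and path slashes become underscores, while hyphens and ports survive.
# `_server_name_from_url` is a hypothetical helper for illustration only.
from urllib.parse import urlparse

def _server_name_from_url(url: str) -> str:
    parsed = urlparse(url)
    parts = [parsed.netloc] + [p for p in parsed.path.split("/") if p]
    return "_".join(parts).replace(".", "_")
    # _server_name_from_url("https://api.example.com/mcp") -> "api_example_com_mcp"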


@pytest.fixture
def mcp_reference_parsing_cases():
    """Provide test cases for MCP reference parsing."""
    return [
        # (mcp_ref, expected_type, expected_server, expected_tool)
        ("https://api.example.com/mcp", "external", "https://api.example.com/mcp", None),
        ("https://api.example.com/mcp#search", "external", "https://api.example.com/mcp", "search"),
        ("crewai-amp:weather-service", "amp", "weather-service", None),
        ("crewai-amp:financial-data#stock_price", "amp", "financial-data", "stock_price"),
    ]
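

# A hedged sketch of reference parsing consistent with the cases above: an
# optional "#tool" suffix selects a single tool, and the "crewai-amp:" scheme
# marks AMP marketplace references. Hypothetical helper, illustration only.
def _parse_mcp_ref(ref: str):
    base, _, tool = ref.partition("#")
    if base.startswith("crewai-amp:"):
        return ("amp", base[len("crewai-amp:"):], tool or None)
    return ("external", base, tool or None)
    # _parse_mcp_ref("crewai-amp:financial-data#stock_price")
    # -> ("amp", "financial-data", "stock_price")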


@pytest.fixture
def cache_test_scenarios():
    """Provide scenarios for cache testing."""
    return {
        "cache_hit": {
            "initial_time": 1000,
            "subsequent_time": 1100,  # Within 300s TTL
            "expected_calls": 1
        },
        "cache_miss": {
            "initial_time": 1000,
            "subsequent_time": 1400,  # Beyond 300s TTL
            "expected_calls": 2
        },
        "cache_boundary": {
            "initial_time": 1000,
            "subsequent_time": 1300,  # Exactly at the 300s TTL boundary
            "expected_calls": 2
        }
    }
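

# A hedged sketch of the TTL rule these scenarios encode: an entry is a hit
# only while (now - stored_time) < ttl, so landing exactly on the 300s
# boundary counts as a miss. Hypothetical helper, illustration only.
def _cached_fetch(cache, url, fetch, now, ttl=300):
    entry = cache.get(url)
    if entry is not None and now - entry[1] < ttl:
        return entry[0]  # cache hit: reuse stored schemas
    schemas = fetch(url)  # cache miss (or expired): fetch and re-stamp
    cache[url] = (schemas, now)
    return schemas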


@pytest.fixture
def timeout_test_scenarios():
    """Provide scenarios for timeout testing."""
    return {
        "connection_timeout": {
            "timeout_type": "connection",
            "delay": 15,  # Exceeds the 10s connection timeout
            "expected_error": "timed out"
        },
        "execution_timeout": {
            "timeout_type": "execution",
            "delay": 35,  # Exceeds the 30s execution timeout
            "expected_error": "timed out"
        },
        "discovery_timeout": {
            "timeout_type": "discovery",
            "delay": 20,  # Exceeds the 15s discovery timeout
            "expected_error": "timed out"
        }
    }
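

# A hedged sketch of how the delays above exceed their budgets. These
# scenarios assume each MCP phase runs under its own timeout (10s connect,
# 15s discovery, 30s execution are this file's assumptions); one plausible
# shape is an asyncio.wait_for wrapper like the hypothetical helper below.
import asyncio

async def _with_timeout(coro, budget_seconds):
    try:
        return await asyncio.wait_for(coro, timeout=budget_seconds)
    except asyncio.TimeoutError as exc:
        raise Exception(f"MCP operation timed out after {budget_seconds}s") from exc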


@pytest.fixture
def mcp_error_scenarios():
    """Provide various MCP error scenarios for testing."""
    return {
        "connection_refused": {
            "error": ConnectionRefusedError("Connection refused"),
            "expected_msg": "network connection failed",
            "retryable": True
        },
        "auth_failed": {
            "error": Exception("Authentication failed"),
            "expected_msg": "authentication failed",
            "retryable": False
        },
        "json_parse_error": {
            "error": ValueError("JSON decode error"),
            "expected_msg": "server response parsing error",
            "retryable": True
        },
        "tool_not_found": {
            "error": Exception("Tool not found"),
            "expected_msg": "not found",
            "retryable": False
        },
        "server_error": {
            "error": Exception("Internal server error"),
            "expected_msg": "mcp execution error",
            "retryable": False
        }
    }
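

# A hedged sketch of the classification the scenarios above imply: an error
# is mapped to a friendlier message plus a retryable flag. Hypothetical
# helper, illustration only.
def _classify_mcp_error(error):
    text = str(error).lower()
    if "refused" in text or "network" in text:
        return "network connection failed", True
    if "authentication" in text:
        return "authentication failed", False
    if "json" in text:
        return "server response parsing error", True
    if "not found" in text:
        return "not found", False
    return "mcp execution error", False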


@pytest.fixture(autouse=True)
def clear_mcp_cache():
    """Automatically clear the MCP cache before and after each test."""
    from crewai.agent import _mcp_schema_cache
    _mcp_schema_cache.clear()
    yield
    _mcp_schema_cache.clear()


@pytest.fixture
def mock_successful_mcp_execution():
    """Provide a mock for successful MCP tool execution."""
    def _mock_execution(**kwargs):
        return f"Successful MCP execution with args: {kwargs}"
    return _mock_execution


@pytest.fixture
def performance_benchmarks():
    """Provide performance benchmarks for MCP operations."""
    return {
        "agent_creation_max_time": 0.5,     # 500ms
        "tool_discovery_max_time": 2.0,     # 2 seconds
        "cache_hit_max_time": 0.01,         # 10ms
        "tool_execution_max_time": 35.0,    # 35 seconds (includes timeout buffer)
        "crew_integration_max_time": 0.1    # 100ms
    }


# Convenience functions for common test setup

def setup_successful_mcp_environment():
    """Set up a complete, successful MCP test environment."""
    mock_server = MockMCPServerFactory.create_exa_like_server("https://mock-exa.com/mcp")

    agent = Agent(
        role="Success Test Agent",
        goal="Test successful MCP operations",
        backstory="Agent for testing successful scenarios",
        mcps=["https://mock-exa.com/mcp"]
    )

    return agent, mock_server


def setup_error_prone_mcp_environment():
    """Set up an MCP test environment with various error conditions."""
    agents = {}

    # Different agents for different error scenarios
    agents["timeout"] = Agent(
        role="Timeout Agent",
        goal="Test timeout scenarios",
        backstory="Agent for timeout testing",
        mcps=["https://slow-server.com/mcp"]
    )

    agents["auth_fail"] = Agent(
        role="Auth Fail Agent",
        goal="Test auth failures",
        backstory="Agent for auth testing",
        mcps=["https://secure-server.com/mcp"]
    )

    agents["mixed"] = Agent(
        role="Mixed Results Agent",
        goal="Test mixed success/failure",
        backstory="Agent for mixed scenario testing",
        mcps=[
            "https://working-server.com/mcp",
            "https://failing-server.com/mcp",
            "crewai-amp:working-service",
            "crewai-amp:failing-service"
        ]
    )

    return agents


def create_test_crew_with_mcp_agents(agents, task_descriptions=None):
    """Create a test crew with MCP-enabled agents."""
    if task_descriptions is None:
        task_descriptions = ["Generic test task" for _ in agents]

    tasks = []
    for i, agent in enumerate(agents):
        task = Task(
            description=task_descriptions[i] if i < len(task_descriptions) else f"Task for {agent.role}",
            expected_output=f"Output from {agent.role}",
            agent=agent
        )
        tasks.append(task)

    return Crew(agents=agents, tasks=tasks)
1
lib/crewai/tests/integration/__init__.py
Normal file
@@ -0,0 +1 @@
"""Integration tests package."""
770
lib/crewai/tests/integration/test_mcp_end_to_end.py
Normal file
@@ -0,0 +1,770 @@
"""End-to-end integration tests for MCP DSL functionality."""

import asyncio
import pytest
import time
from unittest.mock import Mock, patch, AsyncMock

# Import from the source directory
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
from tests.mocks.mcp_server_mock import MockMCPServerFactory


def _named_mock(name: str) -> Mock:
    """Create a Mock whose `name` attribute is a real string.

    `name` is a reserved keyword of the Mock constructor (it only sets the
    repr name), so it must be assigned after construction for assertions
    against `.name` to work.
    """
    mock = Mock()
    mock.name = name
    return mock


class TestMCPEndToEndIntegration:
    """End-to-end integration tests for MCP DSL functionality."""

    def test_complete_mcp_workflow_single_server(self):
        """Test the complete MCP workflow with a single server."""
        print("\n=== Testing Complete MCP Workflow ===")

        # Step 1: Create agent with MCP configuration
        agent = Agent(
            role="E2E Test Agent",
            goal="Test complete MCP workflow",
            backstory="Agent for end-to-end MCP testing",
            mcps=["https://api.example.com/mcp"]
        )

        assert agent.mcps == ["https://api.example.com/mcp"]
        print("✅ Step 1: Agent created with MCP configuration")

        # Step 2: Mock tool discovery
        mock_schemas = {
            "search_web": {"description": "Search the web for information"},
            "analyze_data": {"description": "Analyze provided data"}
        }

        with patch.object(agent, '_get_mcp_tool_schemas', return_value=mock_schemas):

            # Step 3: Discover MCP tools
            discovered_tools = agent.get_mcp_tools(agent.mcps)

            assert len(discovered_tools) == 2
            assert all(isinstance(tool, MCPToolWrapper) for tool in discovered_tools)
            print(f"✅ Step 3: Discovered {len(discovered_tools)} MCP tools")

            # Verify tool names are properly prefixed
            tool_names = [tool.name for tool in discovered_tools]
            assert "api_example_com_mcp_search_web" in tool_names
            assert "api_example_com_mcp_analyze_data" in tool_names

            # Step 4: Create task and crew
            task = Task(
                description="Research AI frameworks using MCP tools",
                expected_output="Research report using discovered tools",
                agent=agent
            )

            crew = Crew(agents=[agent], tasks=[task])
            print("✅ Step 4: Created task and crew")

            # Step 5: Test crew tool integration
            crew_tools = crew._inject_mcp_tools([], agent)
            assert len(crew_tools) == 2
            print("✅ Step 5: MCP tools integrated into crew")

            # Step 6: Test tool execution
            search_tool = next(tool for tool in discovered_tools if "search" in tool.name)

            # Mock successful tool execution
            with patch.object(search_tool, '_run_async', return_value="Search results: AI frameworks found"):
                result = search_tool._run(query="AI frameworks")

                assert "Search results" in result
                print("✅ Step 6: Tool execution successful")
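
    # --- illustrative sketch (assumption, not the shipped implementation) --
    # The name assertions above imply a "<server_name>_<original_tool_name>"
    # convention for exposed MCP tools, e.g.:
    #
    #     f"{'api_example_com_mcp'}_{'search_web'}"
    #     # -> "api_example_com_mcp_search_web"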

    def test_complete_mcp_workflow_multiple_servers(self):
        """Test the complete MCP workflow with multiple servers."""
        print("\n=== Testing Multi-Server MCP Workflow ===")

        # Create agent with multiple MCP servers
        agent = Agent(
            role="Multi-Server Agent",
            goal="Test multiple MCP server integration",
            backstory="Agent with access to multiple MCP servers",
            mcps=[
                "https://search-server.com/mcp",
                "https://analysis-server.com/mcp#specific_tool",
                "crewai-amp:weather-service",
                "crewai-amp:financial-data#stock_price_tool"
            ]
        )

        print(f"✅ Agent created with {len(agent.mcps)} MCP references")

        # Mock different server responses
        def mock_external_tools(mcp_ref):
            if "search-server" in mcp_ref:
                return [_named_mock("search_server_com_mcp_web_search")]
            elif "analysis-server" in mcp_ref:
                return [_named_mock("analysis_server_com_mcp_specific_tool")]
            return []

        def mock_amp_tools(amp_ref):
            if "weather-service" in amp_ref:
                return [_named_mock("weather_service_get_forecast")]
            elif "financial-data" in amp_ref:
                return [_named_mock("financial_data_stock_price_tool")]
            return []

        with patch.object(agent, '_get_external_mcp_tools', side_effect=mock_external_tools), \
             patch.object(agent, '_get_amp_mcp_tools', side_effect=mock_amp_tools):

            # Discover all tools
            all_tools = agent.get_mcp_tools(agent.mcps)

            # Should get tools from all servers
            expected_tools = 2 + 2  # 2 external + 2 AMP
            assert len(all_tools) == expected_tools
            print(f"✅ Discovered {len(all_tools)} tools from multiple servers")

            # Create a multi-task crew
            tasks = [
                Task(
                    description="Search for information",
                    expected_output="Search results",
                    agent=agent
                ),
                Task(
                    description="Analyze financial data",
                    expected_output="Analysis report",
                    agent=agent
                )
            ]

            crew = Crew(agents=[agent], tasks=tasks)

            # Test crew integration with multiple tools
            for task in tasks:
                task_tools = crew._inject_mcp_tools([], task.agent)
                assert len(task_tools) == expected_tools

            print("✅ Multi-server integration successful")

    def test_mcp_workflow_with_error_recovery(self):
        """Test the MCP workflow under error recovery scenarios."""
        print("\n=== Testing MCP Workflow with Error Recovery ===")

        # Create agent with a mix of working and failing servers
        agent = Agent(
            role="Error Recovery Agent",
            goal="Test error recovery capabilities",
            backstory="Agent designed to handle MCP server failures",
            mcps=[
                "https://failing-server.com/mcp",   # Will fail
                "https://working-server.com/mcp",   # Will work
                "https://timeout-server.com/mcp",   # Will time out
                "crewai-amp:nonexistent-service"    # Will fail
            ]
        )

        print(f"✅ Agent created with {len(agent.mcps)} MCP references (some will fail)")

        # Mock a mixed success/failure scenario
        def mock_mixed_external_tools(mcp_ref):
            if "failing-server" in mcp_ref:
                raise Exception("Server connection failed")
            elif "working-server" in mcp_ref:
                return [_named_mock("working_server_com_mcp_reliable_tool")]
            elif "timeout-server" in mcp_ref:
                raise Exception("Connection timed out")
            return []

        def mock_failing_amp_tools(amp_ref):
            raise Exception("AMP server unavailable")

        with patch.object(agent, '_get_external_mcp_tools', side_effect=mock_mixed_external_tools), \
             patch.object(agent, '_get_amp_mcp_tools', side_effect=mock_failing_amp_tools), \
             patch.object(agent, '_logger') as mock_logger:

            # Should handle failures gracefully and continue with working servers
            working_tools = agent.get_mcp_tools(agent.mcps)

            # Should get tools from the working server only
            assert len(working_tools) == 1
            assert working_tools[0].name == "working_server_com_mcp_reliable_tool"
            print("✅ Error recovery successful - got tools from working server")

            # Should log warnings for the failing servers
            warning_calls = [call for call in mock_logger.log.call_args_list if call[0][0] == "warning"]
            assert len(warning_calls) >= 3  # At least 3 failures logged

            print("✅ Error logging and recovery complete")

    def test_mcp_workflow_performance_benchmarks(self):
        """Test that the MCP workflow meets its performance benchmarks."""
        print("\n=== Testing MCP Performance Benchmarks ===")

        start_time = time.time()

        # Agent creation should be fast
        agent = Agent(
            role="Performance Benchmark Agent",
            goal="Establish performance benchmarks",
            backstory="Agent for performance testing",
            mcps=[
                "https://perf1.com/mcp",
                "https://perf2.com/mcp",
                "https://perf3.com/mcp"
            ]
        )

        agent_creation_time = time.time() - start_time
        assert agent_creation_time < 0.5  # Less than 500ms
        print(f"✅ Agent creation: {agent_creation_time:.3f}s")

        # Tool discovery should be efficient
        mock_schemas = {f"tool_{i}": {"description": f"Tool {i}"} for i in range(5)}

        with patch.object(agent, '_get_mcp_tool_schemas', return_value=mock_schemas):

            discovery_start = time.time()
            tools = agent.get_mcp_tools(agent.mcps)
            discovery_time = time.time() - discovery_start

            # Should discover tools from 3 servers with 5 tools each = 15 tools
            assert len(tools) == 15
            assert discovery_time < 2.0  # Less than 2 seconds
            print(f"✅ Tool discovery: {discovery_time:.3f}s for {len(tools)} tools")

        # Crew creation should be fast
        task = Task(
            description="Performance test task",
            expected_output="Performance results",
            agent=agent
        )

        crew_start = time.time()
        crew = Crew(agents=[agent], tasks=[task])
        crew_creation_time = time.time() - crew_start

        assert crew_creation_time < 0.1  # Less than 100ms
        print(f"✅ Crew creation: {crew_creation_time:.3f}s")

        total_time = time.time() - start_time
        print(f"✅ Total workflow: {total_time:.3f}s")

    @pytest.mark.asyncio
    async def test_mcp_workflow_with_real_async_patterns(self):
        """Test the MCP workflow with realistic async operation patterns."""
        print("\n=== Testing Async MCP Workflow Patterns ===")

        # Create agent
        agent = Agent(
            role="Async Test Agent",
            goal="Test async MCP operations",
            backstory="Agent for testing async patterns",
            mcps=["https://async-test.com/mcp"]
        )

        # Mock realistic async MCP server behavior
        mock_server = MockMCPServerFactory.create_exa_like_server("https://async-test.com/mcp")

        with patch('crewai.agent.streamablehttp_client') as mock_client, \
             patch('crewai.agent.ClientSession') as mock_session_class:

            # Set up the async mocks
            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = AsyncMock()
            mock_session.list_tools = AsyncMock(return_value=await mock_server.simulate_list_tools())

            mock_client.return_value.__aenter__.return_value = (None, None, None)

            # Test async tool discovery
            start_time = time.time()
            schemas = await agent._get_mcp_tool_schemas_async({"url": "https://async-test.com/mcp"})
            discovery_time = time.time() - start_time

            assert len(schemas) == 2  # The Exa-like server exposes 2 tools
            assert discovery_time < 1.0  # Should be fast with mocked operations
            print(f"✅ Async discovery: {discovery_time:.3f}s")

            # Verify the async methods were called
            mock_session.initialize.assert_called_once()
            mock_session.list_tools.assert_called_once()

    def test_mcp_workflow_scalability_test(self):
        """Test MCP workflow scalability with many agents and tools."""
        print("\n=== Testing MCP Workflow Scalability ===")

        # Create multiple agents with MCP configurations
        agents = []
        for i in range(10):
            agent = Agent(
                role=f"Scalability Agent {i}",
                goal=f"Test scalability scenario {i}",
                backstory=f"Agent {i} for scalability testing",
                mcps=[f"https://scale-server-{i}.com/mcp"]
            )
            agents.append(agent)

        print(f"✅ Created {len(agents)} agents with MCP configurations")

        # Mock tool discovery for all agents
        mock_schemas = {f"scale_tool_{i}": {"description": f"Scalability tool {i}"} for i in range(3)}

        with patch.object(Agent, '_get_mcp_tool_schemas', return_value=mock_schemas):

            start_time = time.time()

            # Discover tools for all agents
            all_agent_tools = []
            for agent in agents:
                tools = agent.get_mcp_tools(agent.mcps)
                all_agent_tools.extend(tools)

            scalability_time = time.time() - start_time

            # Should handle multiple agents efficiently
            total_tools = len(agents) * 3  # 3 tools per agent
            assert len(all_agent_tools) == total_tools
            assert scalability_time < 5.0  # Should complete within 5 seconds

            print(f"✅ Scalability test: {len(all_agent_tools)} tools from {len(agents)} agents in {scalability_time:.3f}s")

        # Test crew creation with multiple MCP agents
        tasks = [
            Task(
                description=f"Task for agent {i}",
                expected_output=f"Output from agent {i}",
                agent=agents[i]
            ) for i in range(len(agents))
        ]

        crew = Crew(agents=agents, tasks=tasks)

        assert len(crew.agents) == 10
        assert len(crew.tasks) == 10
        print("✅ Scalability crew creation successful")

    def test_mcp_workflow_with_specific_tool_selection(self):
        """Test the MCP workflow with specific tool selection using the # syntax."""
        print("\n=== Testing Specific Tool Selection ===")

        # Create agent with specific tool selections
        agent = Agent(
            role="Specific Tool Agent",
            goal="Test specific tool selection",
            backstory="Agent that uses specific MCP tools",
            mcps=[
                "https://multi-tool-server.com/mcp#search_tool",
                "https://another-server.com/mcp#analysis_tool",
                "crewai-amp:research-service#pubmed_search"
            ]
        )

        # The mocked servers expose multiple tools, but only the selected ones should come back
        def mock_external_tools_specific(mcp_ref):
            if "#search_tool" in mcp_ref:
                return [_named_mock("multi_tool_server_com_mcp_search_tool")]
            elif "#analysis_tool" in mcp_ref:
                return [_named_mock("another_server_com_mcp_analysis_tool")]
            return []

        def mock_amp_tools_specific(amp_ref):
            if "#pubmed_search" in amp_ref:
                return [_named_mock("research_service_pubmed_search")]
            return []

        with patch.object(agent, '_get_external_mcp_tools', side_effect=mock_external_tools_specific), \
             patch.object(agent, '_get_amp_mcp_tools', side_effect=mock_amp_tools_specific):

            specific_tools = agent.get_mcp_tools(agent.mcps)

            # Should get exactly 3 specific tools
            assert len(specific_tools) == 3
            print("✅ Specific tool selection working correctly")

            # Verify the correct tools were selected
            tool_names = [tool.name for tool in specific_tools]
            expected_names = [
                "multi_tool_server_com_mcp_search_tool",
                "another_server_com_mcp_analysis_tool",
                "research_service_pubmed_search"
            ]

            for expected_name in expected_names:
                assert expected_name in tool_names

    def test_mcp_workflow_resilience_under_stress(self):
        """Test MCP workflow resilience under stress conditions."""
        print("\n=== Testing MCP Workflow Resilience ===")

        # Create the stress-test scenario: a mix of different server types
        stress_mcps = []
        for i in range(20):
            if i % 4 == 0:
                stress_mcps.append(f"https://working-server-{i}.com/mcp")
            elif i % 4 == 1:
                stress_mcps.append(f"https://failing-server-{i}.com/mcp")
            elif i % 4 == 2:
                stress_mcps.append(f"crewai-amp:service-{i}")
            else:
                stress_mcps.append(f"https://slow-server-{i}.com/mcp#specific_tool")

        agent = Agent(
            role="Stress Test Agent",
            goal="Test MCP workflow under stress",
            backstory="Agent for stress testing MCP functionality",
            mcps=stress_mcps
        )

        # Mock the stress-test behaviors
        def mock_stress_external_tools(mcp_ref):
            if "failing" in mcp_ref:
                raise Exception("Simulated failure")
            elif "slow" in mcp_ref:
                # Simulate a slow response
                time.sleep(0.1)
                return [_named_mock(f"tool_from_{mcp_ref}")]
            elif "working" in mcp_ref:
                return [_named_mock(f"tool_from_{mcp_ref}")]
            return []

        def mock_stress_amp_tools(amp_ref):
            return [_named_mock(f"amp_tool_from_{amp_ref}")]

        with patch.object(agent, '_get_external_mcp_tools', side_effect=mock_stress_external_tools), \
             patch.object(agent, '_get_amp_mcp_tools', side_effect=mock_stress_amp_tools):

            start_time = time.time()

            # Should handle all servers (working, failing, slow, AMP)
            stress_tools = agent.get_mcp_tools(agent.mcps)

            stress_time = time.time() - start_time

            # Should get tools from the responsive servers (5 working + 5 slow + 5 AMP = 15)
            expected_working_tools = 15
            assert len(stress_tools) == expected_working_tools

            # Should complete within a reasonable time despite the stress
            assert stress_time < 10.0

            print(f"✅ Stress test: {len(stress_tools)} tools processed in {stress_time:.3f}s")

    def test_mcp_workflow_integration_with_existing_features(self):
        """Test MCP workflow integration with existing CrewAI features."""
        print("\n=== Testing Integration with Existing Features ===")

        from crewai.tools import BaseTool

        # Create a custom tool for testing integration
        class CustomTool(BaseTool):
            name: str = "custom_search_tool"
            description: str = "Custom search tool"

            def _run(self, **kwargs):
                return "Custom tool result"

        # Create agent with regular tools, platform apps, and MCP tools
        agent = Agent(
            role="Full Integration Agent",
            goal="Test integration with all CrewAI features",
            backstory="Agent with access to all tool types",
            tools=[CustomTool()],                           # Regular tools
            apps=["gmail", "slack"],                        # Platform apps
            mcps=["https://integration-server.com/mcp"],    # MCP tools
            verbose=True,
            max_iter=15,
            allow_delegation=True
        )

        print("✅ Agent created with all feature types")

        # Test that all features work together
        assert len(agent.tools) == 1   # Regular tools
        assert len(agent.apps) == 2    # Platform apps
        assert len(agent.mcps) == 1    # MCP tools
        assert agent.verbose is True
        assert agent.max_iter == 15
        assert agent.allow_delegation is True

        # Mock MCP tool discovery
        mock_mcp_tools = [_named_mock("integration_server_com_mcp_integration_tool")]

        with patch.object(agent, 'get_mcp_tools', return_value=mock_mcp_tools):

            # Create a crew with the integrated agent
            task = Task(
                description="Use all available tool types for comprehensive research",
                expected_output="Comprehensive research using all tools",
                agent=agent
            )

            crew = Crew(agents=[agent], tasks=[task])

            # Test crew tool integration
            crew_tools = crew._inject_mcp_tools([], agent)
            assert len(crew_tools) == len(mock_mcp_tools)

            print("✅ Full feature integration successful")

    def test_mcp_workflow_user_experience_simulation(self):
        """Simulate a typical user experience with the MCP DSL."""
        print("\n=== Simulating User Experience ===")

        # Simulate a user creating an agent for research
        research_agent = Agent(
            role="AI Research Specialist",
            goal="Research AI technologies and frameworks",
            backstory="Expert AI researcher with access to search and analysis tools",
            mcps=[
                "https://mcp.exa.ai/mcp?api_key=user_key&profile=research",
                "https://analysis.tools.com/mcp#analyze_trends",
                "crewai-amp:academic-research",
                "crewai-amp:market-analysis#competitor_analysis"
            ]
        )

        print("✅ User created research agent with 4 MCP references")

        # Mock realistic tool discovery
        mock_tools = [
            _named_mock("mcp_exa_ai_mcp_web_search_exa"),
            _named_mock("analysis_tools_com_mcp_analyze_trends"),
            _named_mock("academic_research_paper_search"),
            _named_mock("market_analysis_competitor_analysis")
        ]

        with patch.object(research_agent, 'get_mcp_tools', return_value=mock_tools):

            # The user creates a research task
            research_task = Task(
                description="Research the current state of multi-agent AI frameworks, focusing on CrewAI",
                expected_output="Comprehensive research report with market analysis and competitor comparison",
                agent=research_agent
            )

            # The user creates and configures a crew
            research_crew = Crew(
                agents=[research_agent],
                tasks=[research_task],
                verbose=True
            )

            print("✅ User created research task and crew")

            # Verify the user's MCP tools are available
            available_tools = research_crew._inject_mcp_tools([], research_agent)
            assert len(available_tools) == 4

            print("✅ User's MCP tools integrated successfully")
            print(f"   Available tools: {[tool.name for tool in available_tools]}")

            # Simulate tool execution
            search_tool = available_tools[0]
            with patch.object(search_tool, '_run', return_value="Research results about CrewAI framework"):
                result = search_tool._run(query="CrewAI multi-agent framework", num_results=5)

                assert "CrewAI framework" in result
                print("✅ User tool execution successful")

    def test_mcp_workflow_production_readiness_checklist(self):
        """Verify the MCP workflow meets the production readiness checklist."""
        print("\n=== Production Readiness Checklist ===")

        checklist_results = {}

        # ✅ Test 1: Agent creation without external dependencies
        try:
            agent = Agent(
                role="Production Test Agent",
                goal="Verify production readiness",
                backstory="Agent for production testing",
                mcps=["https://prod-test.com/mcp"]
            )
            checklist_results["agent_creation"] = "✅ PASS"
        except Exception as e:
            checklist_results["agent_creation"] = f"❌ FAIL: {e}"

        # ✅ Test 2: Graceful handling of unavailable servers
        with patch.object(agent, '_get_external_mcp_tools', side_effect=Exception("Server unavailable")):
            try:
                tools = agent.get_mcp_tools(agent.mcps)
                assert tools == []  # Should return an empty list, not crash
                checklist_results["error_handling"] = "✅ PASS"
            except Exception as e:
                checklist_results["error_handling"] = f"❌ FAIL: {e}"

        # ✅ Test 3: Performance within acceptable limits
        start_time = time.time()
        mock_tools = [Mock() for _ in range(10)]
        with patch.object(agent, 'get_mcp_tools', return_value=mock_tools):
            tools = agent.get_mcp_tools(agent.mcps)
        performance_time = time.time() - start_time

        if performance_time < 1.0:
            checklist_results["performance"] = "✅ PASS"
        else:
            checklist_results["performance"] = f"❌ FAIL: {performance_time:.3f}s"

        # ✅ Test 4: Integration with crew workflows
        try:
            task = Task(description="Test task", expected_output="Test output", agent=agent)
            crew = Crew(agents=[agent], tasks=[task])

            crew_tools = crew._inject_mcp_tools([], agent)
            checklist_results["crew_integration"] = "✅ PASS"
        except Exception as e:
            checklist_results["crew_integration"] = f"❌ FAIL: {e}"

        # ✅ Test 5: Input validation works correctly
        try:
            Agent(
                role="Validation Test",
                goal="Test validation",
                backstory="Testing validation",
                mcps=["invalid-format"]
            )
            checklist_results["input_validation"] = "❌ FAIL: Should reject invalid format"
        except Exception:
            checklist_results["input_validation"] = "✅ PASS"

        # Print the results
        print("\nProduction Readiness Results:")
        for test_name, result in checklist_results.items():
            print(f"  {test_name.replace('_', ' ').title()}: {result}")

        # All tests should pass
        passed_tests = sum(1 for result in checklist_results.values() if "✅ PASS" in result)
        total_tests = len(checklist_results)

        assert passed_tests == total_tests, f"Only {passed_tests}/{total_tests} production readiness tests passed"
        print(f"\n🎉 Production Readiness: {passed_tests}/{total_tests} tests PASSED")

    def test_complete_user_journey_simulation(self):
        """Simulate a complete user journey from setup to execution."""
        print("\n=== Complete User Journey Simulation ===")

        # Step 1: User installs CrewAI (already done)
        print("✅ Step 1: CrewAI installed")

        # Step 2: User creates an agent with MCP tools
        user_agent = Agent(
            role="Data Analyst",
            goal="Analyze market trends and competitor data",
            backstory="Experienced analyst with access to real-time data sources",
            mcps=[
                "https://api.marketdata.com/mcp",
                "https://competitor.intelligence.com/mcp#competitor_analysis",
                "crewai-amp:financial-insights",
                "crewai-amp:market-research#trend_analysis"
            ]
        )
        print("✅ Step 2: User created agent with 4 MCP tool sources")

        # Step 3: MCP tools are discovered automatically
        mock_discovered_tools = [
            _named_mock("api_marketdata_com_mcp_get_market_data"),
            _named_mock("competitor_intelligence_com_mcp_competitor_analysis"),
            _named_mock("financial_insights_stock_analysis"),
            _named_mock("market_research_trend_analysis")
        ]

        with patch.object(user_agent, 'get_mcp_tools', return_value=mock_discovered_tools):
            available_tools = user_agent.get_mcp_tools(user_agent.mcps)

            assert len(available_tools) == 4
            print("✅ Step 3: MCP tools discovered automatically")

        # Step 4: User creates an analysis task
        analysis_task = Task(
            description="Analyze current market trends in AI technology sector and identify top competitors",
            expected_output="Comprehensive market analysis report with competitor insights and trend predictions",
            agent=user_agent
        )
        print("✅ Step 4: User created analysis task")

        # Step 5: User sets up a crew for execution
        analysis_crew = Crew(
            agents=[user_agent],
            tasks=[analysis_task],
            verbose=True  # The user wants to see progress
        )
        print("✅ Step 5: User configured crew for execution")

        # Step 6: The crew integrates MCP tools automatically
        with patch.object(user_agent, 'get_mcp_tools', return_value=mock_discovered_tools):
            integrated_tools = analysis_crew._inject_mcp_tools([], user_agent)

            assert len(integrated_tools) == 4
            print("✅ Step 6: Crew integrated MCP tools automatically")

            # Step 7: Tools are ready for execution. The capability phrases
            # below (spaces swapped for underscores) must appear as substrings
            # of the mocked tool names.
            tool_names = [tool.name for tool in integrated_tools]
            expected_capabilities = [
                "market data",
                "competitor analysis",
                "financial insights",
                "trend analysis"
            ]

            # Verify the tools provide the expected capabilities
            for capability in expected_capabilities:
                capability_found = any(
                    capability.replace(" ", "_") in tool_name.lower()
                    for tool_name in tool_names
                )
                assert capability_found, f"Expected capability '{capability}' not found in tools"

            print("✅ Step 7: All expected capabilities available")
            print("\n🚀 Complete User Journey: SUCCESS!")
            print("   User can now execute crew.kickoff() with full MCP integration")

    def test_mcp_workflow_backwards_compatibility(self):
        """Test that MCP integration doesn't break existing functionality."""
        print("\n=== Testing Backwards Compatibility ===")

        # Test 1: An agent without the MCP field works normally
        classic_agent = Agent(
            role="Classic Agent",
            goal="Test backwards compatibility",
            backstory="Agent without MCP configuration"
            # No mcps field specified
        )

        assert classic_agent.mcps is None
        assert hasattr(classic_agent, 'get_mcp_tools')  # Method exists even though mcps is None
        print("✅ Classic agent creation works")

        # Test 2: Existing crew workflows are unchanged
        classic_task = Task(
            description="Classic task without MCP",
            expected_output="Classic result",
            agent=classic_agent
        )

        classic_crew = Crew(agents=[classic_agent], tasks=[classic_task])

        # MCP integration should not affect classic workflows
        tools_result = classic_crew._inject_mcp_tools([], classic_agent)
        assert tools_result == []  # No MCP tools, so an empty list is returned

        print("✅ Existing crew workflows unchanged")

        # Test 3: An agent with an empty mcps list works normally
        empty_mcps_agent = Agent(
            role="Empty MCP Agent",
            goal="Test empty mcps list",
            backstory="Agent with empty mcps list",
            mcps=[]
        )

        assert empty_mcps_agent.mcps == []
        empty_tools = empty_mcps_agent.get_mcp_tools(empty_mcps_agent.mcps)
        assert empty_tools == []

        print("✅ Empty mcps list handling works")

        print("\n✅ Backwards Compatibility: CONFIRMED")
        print("   Existing CrewAI functionality remains unchanged")
1
lib/crewai/tests/mcp/__init__.py
Normal file
@@ -0,0 +1 @@
"""MCP integration tests package."""
281
lib/crewai/tests/mcp/test_mcp_caching.py
Normal file
@@ -0,0 +1,281 @@
"""Tests for MCP caching functionality."""

import asyncio
import time
import pytest
from unittest.mock import Mock, patch

# Import from the source directory
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from crewai.agent import Agent, _mcp_schema_cache, _cache_ttl


class TestMCPCaching:
    """Test suite for MCP caching functionality."""

    def setup_method(self):
        """Clear the cache before each test."""
        _mcp_schema_cache.clear()

    def teardown_method(self):
        """Clear the cache after each test."""
        _mcp_schema_cache.clear()

    @pytest.fixture
    def caching_agent(self):
        """Create an agent for caching tests."""
        return Agent(
            role="Caching Test Agent",
            goal="Test MCP caching behavior",
            backstory="Agent designed for testing cache functionality",
            mcps=["https://cache-test.com/mcp"]
        )

    def test_cache_initially_empty(self):
        """Test that the MCP schema cache starts empty."""
        assert len(_mcp_schema_cache) == 0

    def test_cache_ttl_constant(self):
        """Test that the cache TTL is set to the expected value."""
        assert _cache_ttl == 300  # 5 minutes

    def test_cache_population_on_first_access(self, caching_agent):
        """Test that the cache gets populated on first schema access."""
        server_params = {"url": "https://cache-test.com/mcp"}
        mock_schemas = {"tool1": {"description": "Cached tool"}}

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas), \
             patch('crewai.agent.time.time', return_value=1000):

            # Cache should be empty initially
            assert len(_mcp_schema_cache) == 0

            # First call should populate the cache
            schemas = caching_agent._get_mcp_tool_schemas(server_params)

            assert schemas == mock_schemas
            assert len(_mcp_schema_cache) == 1
            assert "https://cache-test.com/mcp" in _mcp_schema_cache

    def test_cache_hit_returns_cached_data(self, caching_agent):
        """Test that a cache hit returns previously cached data."""
        server_params = {"url": "https://cache-hit-test.com/mcp"}
        mock_schemas = {"tool1": {"description": "Cache hit tool"}}

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async:

            # First call - populates the cache
            with patch('crewai.agent.time.time', return_value=1000):
                schemas1 = caching_agent._get_mcp_tool_schemas(server_params)

            # Second call - should use the cache
            with patch('crewai.agent.time.time', return_value=1150):  # 150s later, within TTL
                schemas2 = caching_agent._get_mcp_tool_schemas(server_params)

            assert schemas1 == schemas2 == mock_schemas
            assert mock_async.call_count == 1  # Only called once

    def test_cache_miss_after_ttl_expiration(self, caching_agent):
        """Test that a cache miss occurs after TTL expiration."""
        server_params = {"url": "https://cache-expiry-test.com/mcp"}
        mock_schemas = {"tool1": {"description": "Expiry test tool"}}

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async:

            # First call at time 1000
            with patch('crewai.agent.time.time', return_value=1000):
                schemas1 = caching_agent._get_mcp_tool_schemas(server_params)

            # Call after TTL expiration (300s + buffer)
            with patch('crewai.agent.time.time', return_value=1400):  # 400s later, beyond TTL
                schemas2 = caching_agent._get_mcp_tool_schemas(server_params)

            assert schemas1 == schemas2 == mock_schemas
            assert mock_async.call_count == 2  # Called twice due to expiration

    def test_cache_key_generation(self, caching_agent):
        """Test that cache keys are generated correctly."""
        different_urls = [
            "https://server1.com/mcp",
            "https://server2.com/mcp",
            "https://server1.com/mcp?api_key=different"
        ]

        mock_schemas = {"tool1": {"description": "Key test tool"}}

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas), \
             patch('crewai.agent.time.time', return_value=1000):

            # Call with different URLs
            for url in different_urls:
                caching_agent._get_mcp_tool_schemas({"url": url})

            # Should create a separate cache entry for each URL
            assert len(_mcp_schema_cache) == len(different_urls)

            for url in different_urls:
                assert url in _mcp_schema_cache

    def test_cache_handles_identical_concurrent_requests(self, caching_agent):
        """Test cache behavior with identical concurrent requests."""
        server_params = {"url": "https://concurrent-test.com/mcp"}
        mock_schemas = {"tool1": {"description": "Concurrent tool"}}

        call_count = 0

        async def counted_async_call(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            # Add a small delay to simulate a network call
            await asyncio.sleep(0.1)
            return mock_schemas

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', side_effect=counted_async_call), \
             patch('crewai.agent.time.time', return_value=1000):

            # First call populates the cache
            schemas1 = caching_agent._get_mcp_tool_schemas(server_params)

            # Subsequent calls should use the cache
            with patch('crewai.agent.time.time', return_value=1100):
                schemas2 = caching_agent._get_mcp_tool_schemas(server_params)
                schemas3 = caching_agent._get_mcp_tool_schemas(server_params)

            assert schemas1 == schemas2 == schemas3 == mock_schemas
            assert call_count == 1  # Only the first call should hit the server

    def test_cache_isolation_between_different_servers(self, caching_agent):
        """Test that cache entries are isolated between different servers."""
        server1_params = {"url": "https://server1.com/mcp"}
        server2_params = {"url": "https://server2.com/mcp"}

        server1_schemas = {"tool1": {"description": "Server 1 tool"}}
        server2_schemas = {"tool2": {"description": "Server 2 tool"}}

        def mock_async_by_url(server_params):
            url = server_params["url"]
            if "server1" in url:
                return server1_schemas
            elif "server2" in url:
                return server2_schemas
            return {}

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', side_effect=mock_async_by_url), \
             patch('crewai.agent.time.time', return_value=1000):

            # Call both servers
            schemas1 = caching_agent._get_mcp_tool_schemas(server1_params)
            schemas2 = caching_agent._get_mcp_tool_schemas(server2_params)

            assert schemas1 == server1_schemas
            assert schemas2 == server2_schemas
            assert len(_mcp_schema_cache) == 2

    def test_cache_handles_failed_operations_correctly(self, caching_agent):
        """Test that the cache doesn't store failed operations."""
        server_params = {"url": "https://failing-cache-test.com/mcp"}

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', side_effect=Exception("Server failed")), \
             patch('crewai.agent.time.time', return_value=1000):

            # A failed operation should not populate the cache
            schemas = caching_agent._get_mcp_tool_schemas(server_params)

            assert schemas == {}  # Empty dict returned on failure
            assert len(_mcp_schema_cache) == 0  # Cache should remain empty

    def test_cache_debug_logging(self, caching_agent):
        """Test the cache debug logging functionality."""
        server_params = {"url": "https://debug-log-test.com/mcp"}
        mock_schemas = {"tool1": {"description": "Debug log tool"}}

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas), \
             patch.object(caching_agent, '_logger') as mock_logger:

            # First call - populates the cache
            with patch('crewai.agent.time.time', return_value=1000):
                caching_agent._get_mcp_tool_schemas(server_params)

            # Second call - should log a cache hit
            with patch('crewai.agent.time.time', return_value=1100):  # Within TTL
                caching_agent._get_mcp_tool_schemas(server_params)

            # Should log a debug message about cache usage
            debug_calls = [call for call in mock_logger.log.call_args_list if call[0][0] == "debug"]
            assert len(debug_calls) > 0
            assert "cached mcp tool schemas" in debug_calls[0][0][1].lower()

    def test_cache_thread_safety_simulation(self, caching_agent):
        """Simulate thread-safety scenarios for cache access."""
        server_params = {"url": "https://thread-safety-test.com/mcp"}
        mock_schemas = {"tool1": {"description": "Thread safety tool"}}

        # Simulate multiple "threads" accessing the cache simultaneously
        # (Note: this is a simplified simulation in a single-threaded test)

        results = []

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async, \
             patch('crewai.agent.time.time', return_value=1000):

            # First call populates the cache
            result1 = caching_agent._get_mcp_tool_schemas(server_params)
            results.append(result1)

            # Multiple rapid subsequent calls (simulating concurrent access)
            with patch('crewai.agent.time.time', return_value=1001):
                for _ in range(5):
                    result = caching_agent._get_mcp_tool_schemas(server_params)
                    results.append(result)

            # All results should be identical (served from the cache)
            assert all(result == mock_schemas for result in results)
            assert len(results) == 6
            # The async method should only be called once
            assert mock_async.call_count == 1

    def test_cache_size_management_with_many_servers(self, caching_agent):
        """Test cache behavior with many different servers."""
        mock_schemas = {"tool1": {"description": "Size management tool"}}

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas), \
             patch('crewai.agent.time.time', return_value=1000):

            # Add many server entries to the cache
            for i in range(50):
                server_url = f"https://server{i:03d}.com/mcp"
                caching_agent._get_mcp_tool_schemas({"url": server_url})

            # The cache should contain all entries
            assert len(_mcp_schema_cache) == 50

            # Verify each entry has the correct structure
            for server_url, (cached_schemas, cache_time) in _mcp_schema_cache.items():
                assert cached_schemas == mock_schemas
                assert cache_time == 1000

    def test_cache_performance_comparison_with_without_cache(self, caching_agent):
        """Compare performance with and without caching."""
        server_params = {"url": "https://performance-comparison.com/mcp"}
        mock_schemas = {"tool1": {"description": "Performance comparison tool"}}

        with patch.object(caching_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async:

            # Cold call (no cache entry yet)
            start_time = time.time()
            with patch('crewai.agent.time.time', return_value=1000):
                schemas1 = caching_agent._get_mcp_tool_schemas(server_params)
            cold_call_time = time.time() - start_time

            # Warm call (served from the cache)
            start_time = time.time()
            with patch('crewai.agent.time.time', return_value=1100):  # Within TTL
                schemas2 = caching_agent._get_mcp_tool_schemas(server_params)
            warm_call_time = time.time() - start_time

            assert schemas1 == schemas2 == mock_schemas
            assert mock_async.call_count == 1
            # The warm call should be significantly faster
            assert warm_call_time < cold_call_time / 2
559
lib/crewai/tests/mcp/test_mcp_error_handling.py
Normal file
@@ -0,0 +1,559 @@
|
||||
"""Tests for MCP error handling scenarios."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||
|
||||
# Import from the source directory
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||
|
||||
|
||||
class TestMCPErrorHandling:
|
||||
"""Test suite for MCP error handling scenarios."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent(self):
|
||||
"""Create a sample agent for error testing."""
|
||||
return Agent(
|
||||
role="Error Test Agent",
|
||||
goal="Test error handling capabilities",
|
||||
backstory="Agent designed for testing error scenarios",
|
||||
mcps=["https://api.example.com/mcp"]
|
||||
)

    def test_connection_timeout_graceful_handling(self, sample_agent):
        """Test graceful handling of connection timeouts."""
        with patch.object(sample_agent, '_get_mcp_tool_schemas', side_effect=Exception("Connection timed out")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent.get_mcp_tools(["https://slow-server.com/mcp"])

            # Should return empty list and log warning
            assert tools == []
            mock_logger.log.assert_called_with("warning", "Skipping MCP https://slow-server.com/mcp due to error: Connection timed out")

    def test_authentication_failure_handling(self, sample_agent):
        """Test handling of authentication failures."""
        with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=Exception("Authentication failed")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent.get_mcp_tools(["https://secure-server.com/mcp"])

            assert tools == []
            mock_logger.log.assert_called_with("warning", "Skipping MCP https://secure-server.com/mcp due to error: Authentication failed")

    def test_json_parsing_error_handling(self, sample_agent):
        """Test handling of JSON parsing errors."""
        with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=Exception("JSON parsing failed")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent.get_mcp_tools(["https://malformed-server.com/mcp"])

            assert tools == []
            mock_logger.log.assert_called_with("warning", "Skipping MCP https://malformed-server.com/mcp due to error: JSON parsing failed")

    def test_network_connectivity_issues(self, sample_agent):
        """Test handling of network connectivity issues."""
        network_errors = [
            "Network unreachable",
            "Connection refused",
            "DNS resolution failed",
            "Timeout occurred"
        ]

        for error_msg in network_errors:
            with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=Exception(error_msg)), \
                    patch.object(sample_agent, '_logger') as mock_logger:

                tools = sample_agent.get_mcp_tools(["https://unreachable-server.com/mcp"])

                assert tools == []
                mock_logger.log.assert_called_with("warning", f"Skipping MCP https://unreachable-server.com/mcp due to error: {error_msg}")

    def test_malformed_mcp_server_responses(self, sample_agent):
        """Test handling of malformed MCP server responses."""
        malformed_errors = [
            "Invalid JSON response",
            "Unexpected response format",
            "Missing required fields",
            "Protocol version mismatch"
        ]

        for error_msg in malformed_errors:
            with patch.object(sample_agent, '_get_mcp_tool_schemas', side_effect=Exception(error_msg)):

                tools = sample_agent._get_external_mcp_tools("https://malformed-server.com/mcp")

                # Should handle error gracefully
                assert tools == []

    def test_server_unavailability_scenarios(self, sample_agent):
        """Test various server unavailability scenarios."""
        unavailability_scenarios = [
            "Server returned 404",
            "Server returned 500",
            "Service unavailable",
            "Server maintenance mode"
        ]

        for scenario in unavailability_scenarios:
            with patch.object(sample_agent, '_get_mcp_tool_schemas', side_effect=Exception(scenario)):

                # Should not raise exception, should return empty list
                tools = sample_agent._get_external_mcp_tools("https://unavailable-server.com/mcp")
                assert tools == []

    def test_tool_not_found_errors(self):
        """Test handling when specific tool is not found."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="nonexistent_tool",
            tool_schema={"description": "Tool that doesn't exist"},
            server_name="test_server"
        )

        # Mock scenario where tool is not found on server
        with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client') as mock_client, \
                patch('crewai.tools.mcp_tool_wrapper.ClientSession') as mock_session_class:

            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = AsyncMock()

            # Mock empty tools list (tool not found)
            mock_tools = []

            with patch('crewai.tools.mcp_tool_wrapper.MCPServerAdapter') as mock_adapter:
                mock_adapter.return_value.__enter__.return_value = mock_tools

                result = wrapper._run(query="test")

                assert "not found on MCP server" in result

    def test_mixed_server_success_and_failure(self, sample_agent):
        """Test handling mixed scenarios with both successful and failing servers."""
        mcps = [
            "https://failing-server.com/mcp",   # Will fail
            "https://working-server.com/mcp",   # Will succeed
            "https://another-failing.com/mcp",  # Will fail
        ]

        def mock_get_external_tools(mcp_ref):
            if "failing" in mcp_ref:
                raise Exception("Server failed")
            else:
                # Return mock tool for working server
                return [Mock(name=f"tool_from_{mcp_ref}")]

        with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=mock_get_external_tools), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent.get_mcp_tools(mcps)

            # Should get tools from working server only
            assert len(tools) == 1

            # Should log warnings for failing servers
            assert mock_logger.log.call_count >= 2  # At least 2 warning calls
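
The loop these tests pin down is simple; here is a minimal sketch of the graceful-degradation behavior of `get_mcp_tools`, under the assumption (consistent with the assertions above) that each reference is resolved independently and failures are logged and skipped. Routing between HTTPS references and `crewai-amp:` references (handled by `_get_amp_mcp_tools` in a later test) is elided:

```python Code
from typing import Any, List

def get_mcp_tools_sketch(agent: Any, mcps: List[str]) -> list:
    """Collect tools per MCP reference, skipping (and logging) any that error."""
    tools = []
    for mcp_ref in mcps:
        try:
            tools.extend(agent._get_external_mcp_tools(mcp_ref))
        except Exception as e:
            # One bad server must not take down the rest
            agent._logger.log("warning", f"Skipping MCP {mcp_ref} due to error: {e}")
    return tools
```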

    @pytest.mark.asyncio
    async def test_concurrent_mcp_operations_error_isolation(self, sample_agent):
        """Test that errors in concurrent MCP operations are properly isolated."""
        async def mock_operation_with_random_failures(server_params):
            url = server_params["url"]
            if "fail" in url:
                raise Exception(f"Simulated failure for {url}")
            return {"tool1": {"description": "Success tool"}}

        server_params_list = [
            {"url": "https://server1-fail.com/mcp"},
            {"url": "https://server2-success.com/mcp"},
            {"url": "https://server3-fail.com/mcp"},
            {"url": "https://server4-success.com/mcp"}
        ]

        # Run the operations, isolating each failure from the others
        results = []
        for params in server_params_list:
            try:
                result = await mock_operation_with_random_failures(params)
                results.append(result)
            except Exception:
                results.append({})  # Empty dict for failures

        # Should have 2 successful results and 2 empty results
        successful_results = [r for r in results if r]
        assert len(successful_results) == 2

    @pytest.mark.asyncio
    async def test_mcp_library_import_error_handling(self):
        """Test handling when MCP library is not available."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        # Mock ImportError for MCP library
        with patch('builtins.__import__', side_effect=ImportError("No module named 'mcp'")):
            result = await wrapper._run_async(query="test")

            assert "mcp library not available" in result.lower()
            assert "pip install mcp" in result

    def test_mcp_tools_graceful_degradation_in_agent_creation(self):
        """Test that agent creation continues even with failing MCP servers."""
        with patch('crewai.agent.Agent._get_external_mcp_tools', side_effect=Exception("All MCP servers failed")):

            # Agent creation should succeed even if MCP discovery fails
            agent = Agent(
                role="Resilient Agent",
                goal="Continue working despite MCP failures",
                backstory="Agent that handles MCP failures gracefully",
                mcps=["https://failing-server.com/mcp"]
            )

            assert agent is not None
            assert agent.role == "Resilient Agent"
            assert len(agent.mcps) == 1

    def test_partial_mcp_server_failure_recovery(self, sample_agent):
        """Test recovery when some but not all MCP servers fail."""
        mcps = [
            "https://server1.com/mcp",  # Will succeed
            "https://server2.com/mcp",  # Will fail
            "https://server3.com/mcp"   # Will succeed
        ]

        def mock_external_tools(mcp_ref):
            if "server2" in mcp_ref:
                raise Exception("Server 2 is down")
            return [Mock(name=f"tool_from_{mcp_ref.split('//')[-1].split('.')[0]}")]

        with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=mock_external_tools):
            tools = sample_agent.get_mcp_tools(mcps)

            # Should get tools from server1 and server3, skip server2
            assert len(tools) == 2

    @pytest.mark.asyncio
    async def test_tool_execution_error_messages_are_informative(self):
        """Test that tool execution error messages provide useful information."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="failing_tool",
            tool_schema={"description": "Tool that fails"},
            server_name="test_server"
        )

        error_scenarios = [
            (asyncio.TimeoutError(), "timed out"),
            (ConnectionError("Connection failed"), "network connection failed"),
            (Exception("Authentication failed"), "authentication failed"),
            (ValueError("JSON parsing error"), "server response parsing error"),
            (Exception("Tool not found"), "mcp execution error")
        ]

        for error, expected_msg in error_scenarios:
            with patch.object(wrapper, '_execute_tool', side_effect=error):
                result = await wrapper._run_async(query="test")

                assert expected_msg.lower() in result.lower()
                assert "failing_tool" in result

    def test_mcp_server_connection_resilience(self, sample_agent):
        """Test MCP server connection resilience across multiple operations."""
        # Simulate intermittent connection issues
        call_count = 0

        def intermittent_connection_mock(server_params):
            nonlocal call_count
            call_count += 1

            # Fail every other call to simulate intermittent issues
            if call_count % 2 == 0:
                raise Exception("Intermittent connection failure")

            return {"stable_tool": {"description": "Tool from stable connection"}}

        with patch.object(sample_agent, '_get_mcp_tool_schemas', side_effect=intermittent_connection_mock):

            # Multiple calls should handle intermittent failures
            results = []
            for _ in range(4):
                tools = sample_agent._get_external_mcp_tools("https://intermittent-server.com/mcp")
                results.append(len(tools))

            # Should have some successes and some failures
            successes = [r for r in results if r > 0]
            failures = [r for r in results if r == 0]

            assert len(successes) >= 1  # At least one success
            assert len(failures) >= 1   # At least one failure

    @pytest.mark.asyncio
    async def test_mcp_tool_schema_discovery_timeout_handling(self, sample_agent):
        """Test timeout handling in MCP tool schema discovery."""
        server_params = {"url": "https://slow-server.com/mcp"}

        # Mock timeout during discovery
        with patch.object(sample_agent, '_discover_mcp_tools', side_effect=asyncio.TimeoutError):
            with pytest.raises(RuntimeError, match="Failed to discover MCP tools after 3 attempts"):
                await sample_agent._get_mcp_tool_schemas_async(server_params)

    @pytest.mark.asyncio
    async def test_mcp_session_initialization_timeout(self, sample_agent):
        """Test timeout during MCP session initialization."""
        server_url = "https://slow-init-server.com/mcp"

        with patch('crewai.agent.streamablehttp_client') as mock_client, \
                patch('crewai.agent.ClientSession') as mock_session_class:

            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            # Mock timeout during initialization
            mock_session.initialize = AsyncMock(side_effect=asyncio.TimeoutError)

            mock_client.return_value.__aenter__.return_value = (None, None, None)

            with pytest.raises(asyncio.TimeoutError):
                await sample_agent._discover_mcp_tools(server_url)

    @pytest.mark.asyncio
    async def test_mcp_tool_listing_timeout(self, sample_agent):
        """Test timeout during MCP tool listing."""
        server_url = "https://slow-list-server.com/mcp"

        with patch('crewai.agent.streamablehttp_client') as mock_client, \
                patch('crewai.agent.ClientSession') as mock_session_class:

            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = AsyncMock()
            # Mock timeout during tool listing
            mock_session.list_tools = AsyncMock(side_effect=asyncio.TimeoutError)

            mock_client.return_value.__aenter__.return_value = (None, None, None)

            with pytest.raises(asyncio.TimeoutError):
                await sample_agent._discover_mcp_tools(server_url)

    def test_mcp_server_response_format_errors(self, sample_agent):
        """Test handling of various MCP server response format errors."""
        response_format_errors = [
            "Invalid response structure",
            "Missing required fields",
            "Unexpected response type",
            "Protocol version incompatible"
        ]

        for error_msg in response_format_errors:
            with patch.object(sample_agent, '_get_mcp_tool_schemas', side_effect=Exception(error_msg)):

                tools = sample_agent._get_external_mcp_tools("https://bad-format-server.com/mcp")
                assert tools == []

    def test_mcp_multiple_concurrent_failures(self, sample_agent):
        """Test handling multiple concurrent MCP server failures."""
        failing_mcps = [
            "https://fail1.com/mcp",
            "https://fail2.com/mcp",
            "https://fail3.com/mcp",
            "https://fail4.com/mcp",
            "https://fail5.com/mcp"
        ]

        with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=Exception("Server failure")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent.get_mcp_tools(failing_mcps)

            # Should handle all failures gracefully
            assert tools == []
            # Should log warning for each failed server
            assert mock_logger.log.call_count == len(failing_mcps)

    def test_mcp_crewai_amp_server_failures(self, sample_agent):
        """Test handling of CrewAI AMP server failures."""
        amp_refs = [
            "crewai-amp:nonexistent-mcp",
            "crewai-amp:failing-mcp#tool_name"
        ]

        with patch.object(sample_agent, '_get_amp_mcp_tools', side_effect=Exception("AMP server unavailable")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent.get_mcp_tools(amp_refs)

            assert tools == []
            assert mock_logger.log.call_count == len(amp_refs)

    @pytest.mark.asyncio
    async def test_mcp_tool_execution_various_failure_modes(self):
        """Test various MCP tool execution failure modes."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        failure_scenarios = [
            # Connection failures
            (ConnectionError("Connection reset by peer"), "network connection failed"),
            (ConnectionRefusedError("Connection refused"), "network connection failed"),

            # Timeout failures
            (asyncio.TimeoutError(), "timed out"),

            # Authentication failures
            (PermissionError("Access denied"), "authentication failed"),
            (Exception("401 Unauthorized"), "authentication failed"),

            # Parsing failures
            (ValueError("JSON decode error"), "server response parsing error"),
            (Exception("Invalid JSON"), "server response parsing error"),

            # Generic failures
            (Exception("Unknown error"), "mcp execution error"),
        ]

        for error, expected_msg_part in failure_scenarios:
            with patch.object(wrapper, '_execute_tool', side_effect=error):
                result = await wrapper._run_async(query="test")

                assert expected_msg_part in result.lower()
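
Taken together, the two scenario tables above pin down an error-classification step. A minimal sketch of that mapping, reconstructed purely from the asserted message families (the exact wording and helper name in the wrapper may differ):

```python Code
import asyncio

def classify_mcp_error(error: Exception, tool_name: str) -> str:
    """Map an exception to the user-facing message family the tests assert on."""
    text = str(error).lower()
    if isinstance(error, asyncio.TimeoutError):
        return f"Tool '{tool_name}' timed out"
    if isinstance(error, ConnectionError) or "connection" in text:
        return f"Tool '{tool_name}': network connection failed"
    if isinstance(error, PermissionError) or "auth" in text or "401" in text:
        return f"Tool '{tool_name}': authentication failed"
    if isinstance(error, ValueError) or "json" in text:
        return f"Tool '{tool_name}': server response parsing error"
    return f"Tool '{tool_name}': MCP execution error: {error}"
```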

    def test_mcp_error_logging_provides_context(self, sample_agent):
        """Test that MCP error logging provides sufficient context for debugging."""
        problematic_mcp = "https://problematic-server.com/mcp#specific_tool"

        with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=Exception("Detailed error message with context")), \
                patch.object(sample_agent, '_logger') as mock_logger:

            tools = sample_agent.get_mcp_tools([problematic_mcp])

            # Verify logging call includes full MCP reference
            mock_logger.log.assert_called_with("warning", f"Skipping MCP {problematic_mcp} due to error: Detailed error message with context")

    def test_mcp_error_recovery_preserves_agent_functionality(self, sample_agent):
        """Test that MCP errors don't break core agent functionality."""
        # Even with all MCP servers failing, agent should still work
        with patch.object(sample_agent, 'get_mcp_tools', return_value=[]):

            # Agent should still have core functionality
            assert sample_agent.role is not None
            assert sample_agent.goal is not None
            assert sample_agent.backstory is not None
            assert hasattr(sample_agent, 'execute_task')
            assert hasattr(sample_agent, 'create_agent_executor')

    def test_mcp_error_handling_with_existing_tools(self, sample_agent):
        """Test MCP error handling when agent has existing tools."""
        from crewai.tools import BaseTool

        class TestTool(BaseTool):
            name: str = "existing_tool"
            description: str = "Existing agent tool"

            def _run(self, **kwargs):
                return "Existing tool result"

        agent_with_tools = Agent(
            role="Agent with Tools",
            goal="Test MCP errors with existing tools",
            backstory="Agent that has both regular and MCP tools",
            tools=[TestTool()],
            mcps=["https://failing-mcp.com/mcp"]
        )

        # MCP failures should not affect existing tools
        with patch.object(agent_with_tools, 'get_mcp_tools', return_value=[]):
            assert len(agent_with_tools.tools) == 1
            assert agent_with_tools.tools[0].name == "existing_tool"


class TestMCPErrorRecoveryPatterns:
    """Test specific error recovery patterns for MCP integration."""

    @pytest.fixture
    def sample_agent(self):
        """Recreate the sample agent here; fixtures defined on another class are not visible."""
        return Agent(
            role="Error Test Agent",
            goal="Test error handling capabilities",
            backstory="Agent designed for testing error scenarios",
            mcps=["https://api.example.com/mcp"]
        )

    def test_exponential_backoff_calculation(self):
        """Test exponential backoff timing calculation."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        # Test backoff timing
        with patch('crewai.tools.mcp_tool_wrapper.asyncio.sleep') as mock_sleep, \
                patch.object(wrapper, '_execute_tool', side_effect=[
                    Exception("Fail 1"),
                    Exception("Fail 2"),
                    "Success"
                ]):

            result = asyncio.run(wrapper._run_async(query="test"))

            # Should succeed after retries
            assert result == "Success"

            # Verify exponential backoff sleep calls
            expected_sleeps = [1, 2]  # 2^0=1, 2^1=2
            actual_sleeps = [call.args[0] for call in mock_sleep.call_args_list]
            assert actual_sleeps == expected_sleeps
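
A sketch of the retry loop consistent with this test and the two that follow: transient failures are retried up to three attempts with `2**attempt` backoff (1s, then 2s), while authentication errors return immediately. Reconstructed from the assertions, not copied from the wrapper source:

```python Code
import asyncio

MAX_ATTEMPTS = 3

async def run_with_backoff(execute, tool_name: str, **kwargs):
    """Retry transient failures with exponential backoff; fail fast on auth errors."""
    for attempt in range(MAX_ATTEMPTS):
        try:
            return await execute(**kwargs)
        except Exception as e:
            if "authentication" in str(e).lower():
                # Non-retryable: retrying a bad credential only wastes time
                return f"Tool '{tool_name}': authentication failed: {e}"
            if attempt == MAX_ATTEMPTS - 1:
                return f"Tool '{tool_name}' failed after {MAX_ATTEMPTS} attempts: {e}"
            await asyncio.sleep(2 ** attempt)  # 1s after the first failure, 2s after the second
```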

    def test_non_retryable_errors_fail_fast(self):
        """Test that non-retryable errors (like auth) fail fast without retries."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        # Authentication errors should not be retried
        with patch.object(wrapper, '_execute_tool', side_effect=Exception("Authentication failed")), \
                patch('crewai.tools.mcp_tool_wrapper.asyncio.sleep') as mock_sleep:

            result = asyncio.run(wrapper._run_async(query="test"))

            assert "authentication failed" in result.lower()
            # Should not have retried (no sleep calls)
            mock_sleep.assert_not_called()

    def test_cache_invalidation_on_persistent_errors(self, sample_agent):
        """Test that persistent errors don't get cached."""
        server_params = {"url": "https://persistently-failing.com/mcp"}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', side_effect=Exception("Persistent failure")), \
                patch('crewai.agent.time.time', return_value=1000):

            # First call should attempt and fail
            schemas1 = sample_agent._get_mcp_tool_schemas(server_params)
            assert schemas1 == {}

            # Second call should attempt again (not use a cached failure)
            with patch('crewai.agent.time.time', return_value=1001):
                schemas2 = sample_agent._get_mcp_tool_schemas(server_params)
                assert schemas2 == {}

    def test_error_context_preservation_through_call_stack(self, sample_agent):
        """Test that error context is preserved through the entire call stack."""
        original_error = Exception("Original detailed error with context information")

        with patch.object(sample_agent, '_get_mcp_tool_schemas', side_effect=original_error), \
                patch.object(sample_agent, '_logger') as mock_logger:

            # Call through the full stack
            tools = sample_agent.get_mcp_tools(["https://error-context-server.com/mcp"])

            # Original error message should be preserved in logs
            assert tools == []
            log_call = mock_logger.log.call_args
            assert "Original detailed error with context information" in log_call[0][1]

548
lib/crewai/tests/mcp/test_mcp_performance.py
Normal file
@@ -0,0 +1,548 @@
"""Tests for MCP performance and timeout behavior."""

import asyncio
import pytest
import time
from unittest.mock import AsyncMock, Mock, patch

# Import from the source directory
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from crewai.agent import Agent, MCP_CONNECTION_TIMEOUT, MCP_TOOL_EXECUTION_TIMEOUT, MCP_DISCOVERY_TIMEOUT
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper


class TestMCPPerformanceAndTimeouts:
    """Test suite for MCP performance and timeout behavior."""

    @pytest.fixture
    def sample_agent(self):
        """Create a sample agent for performance testing."""
        return Agent(
            role="Performance Test Agent",
            goal="Test MCP performance characteristics",
            backstory="Agent designed for performance testing",
            mcps=["https://api.example.com/mcp"]
        )

    @pytest.fixture
    def performance_wrapper(self):
        """Create MCPToolWrapper for performance testing."""
        return MCPToolWrapper(
            mcp_server_params={"url": "https://performance-test.com/mcp"},
            tool_name="performance_tool",
            tool_schema={"description": "Tool for performance testing"},
            server_name="performance_server"
        )

    def test_connection_timeout_constant_value(self):
        """Test that connection timeout constant is set correctly."""
        assert MCP_CONNECTION_TIMEOUT == 10
        assert isinstance(MCP_CONNECTION_TIMEOUT, int)

    def test_tool_execution_timeout_constant_value(self):
        """Test that tool execution timeout constant is set correctly."""
        assert MCP_TOOL_EXECUTION_TIMEOUT == 30
        assert isinstance(MCP_TOOL_EXECUTION_TIMEOUT, int)

    def test_discovery_timeout_constant_value(self):
        """Test that discovery timeout constant is set correctly."""
        assert MCP_DISCOVERY_TIMEOUT == 15
        assert isinstance(MCP_DISCOVERY_TIMEOUT, int)
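
These constants bound the three phases of an MCP call: connect/initialize (10s), tool discovery (15s), and tool execution (30s). A sketch of how such per-phase budgets are typically enforced with `asyncio.wait_for`, using the constants imported above (the wrapper's actual control flow may interleave these differently):

```python Code
import asyncio

async def call_with_timeouts(session, tool_name: str, args: dict):
    """Apply a per-phase timeout budget to one MCP call (illustrative sketch)."""
    # Phase 1: session handshake, bounded by the connection timeout (10s)
    await asyncio.wait_for(session.initialize(), timeout=MCP_CONNECTION_TIMEOUT)
    # Phase 2: the tool call itself, bounded by the execution timeout (30s)
    return await asyncio.wait_for(
        session.call_tool(tool_name, args),
        timeout=MCP_TOOL_EXECUTION_TIMEOUT,
    )
```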

    @pytest.mark.asyncio
    async def test_connection_timeout_enforcement(self, performance_wrapper):
        """Test that connection timeout is properly enforced."""
        # Mock slow connection that exceeds the timeout
        async def slow_init(*args, **kwargs):
            await asyncio.sleep(MCP_CONNECTION_TIMEOUT + 5)  # Exceed timeout

        with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client') as mock_client, \
                patch('crewai.tools.mcp_tool_wrapper.ClientSession') as mock_session_class:

            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = slow_init

            mock_client.return_value.__aenter__.return_value = (None, None, None)

            start_time = time.time()
            result = await performance_wrapper._run_async(query="test")
            elapsed_time = time.time() - start_time

            # Should time out and not take much longer than the timeout period
            assert elapsed_time < MCP_TOOL_EXECUTION_TIMEOUT + 5
            assert "timed out" in result.lower()

    @pytest.mark.asyncio
    async def test_tool_execution_timeout_enforcement(self, performance_wrapper):
        """Test that tool execution timeout is properly enforced."""
        # Mock slow tool execution
        with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client') as mock_client, \
                patch('crewai.tools.mcp_tool_wrapper.ClientSession') as mock_session_class:

            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = AsyncMock()

            # Mock slow tool call
            async def slow_tool_call(*args, **kwargs):
                await asyncio.sleep(MCP_TOOL_EXECUTION_TIMEOUT + 5)  # Exceed timeout
                return Mock(content="Should not reach here")

            mock_session.call_tool = slow_tool_call
            mock_client.return_value.__aenter__.return_value = (None, None, None)

            start_time = time.time()
            result = await performance_wrapper._run_async(query="test")
            elapsed_time = time.time() - start_time

            # Should timeout within reasonable time
            assert elapsed_time < MCP_TOOL_EXECUTION_TIMEOUT + 10
            assert "timed out" in result.lower()

    @pytest.mark.asyncio
    async def test_discovery_timeout_enforcement(self, sample_agent):
        """Test that discovery timeout is properly enforced."""
        server_params = {"url": "https://slow-discovery.com/mcp"}

        # Mock slow discovery operation
        async def slow_discover(server_url):
            await asyncio.sleep(MCP_DISCOVERY_TIMEOUT + 5)  # Exceed timeout
            return {"tool": {"description": "Should not reach here"}}

        with patch.object(sample_agent, '_discover_mcp_tools', side_effect=slow_discover):

            start_time = time.time()

            with pytest.raises(RuntimeError, match="Failed to discover MCP tools after 3 attempts"):
                await sample_agent._get_mcp_tool_schemas_async(server_params)

            elapsed_time = time.time() - start_time

            # Should timeout within reasonable bounds (including retries)
            max_expected_time = (MCP_DISCOVERY_TIMEOUT + 5) * 3 + 10  # Retries + buffer
            assert elapsed_time < max_expected_time

    def test_cache_performance_improvement(self, sample_agent):
        """Test that caching provides significant performance improvement."""
        server_params = {"url": "https://cached-server.com/mcp"}
        mock_schemas = {"tool1": {"description": "Cached tool"}}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async:

            # First call - should hit server
            start_time = time.time()
            with patch('crewai.agent.time.time', return_value=1000):
                schemas1 = sample_agent._get_mcp_tool_schemas(server_params)
            first_call_time = time.time() - start_time

            assert mock_async.call_count == 1
            assert schemas1 == mock_schemas

            # Second call - should use cache
            start_time = time.time()
            with patch('crewai.agent.time.time', return_value=1100):  # Within 300s TTL
                schemas2 = sample_agent._get_mcp_tool_schemas(server_params)
            second_call_time = time.time() - start_time

            # Async method should not be called again
            assert mock_async.call_count == 1
            assert schemas2 == mock_schemas

            # Second call should be much faster (cache hit)
            assert second_call_time < first_call_time / 10  # At least 10x faster

    def test_cache_ttl_expiration_behavior(self, sample_agent):
        """Test cache TTL expiration and refresh behavior."""
        server_params = {"url": "https://ttl-test.com/mcp"}
        mock_schemas = {"tool1": {"description": "TTL test tool"}}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async:

            # Initial call at time 1000
            with patch('crewai.agent.time.time', return_value=1000):
                schemas1 = sample_agent._get_mcp_tool_schemas(server_params)

            assert mock_async.call_count == 1

            # Call within TTL (300 seconds) - should use cache
            with patch('crewai.agent.time.time', return_value=1200):  # 200s later, within TTL
                schemas2 = sample_agent._get_mcp_tool_schemas(server_params)

            assert mock_async.call_count == 1  # No additional call

            # Call after TTL expiration - should refresh
            with patch('crewai.agent.time.time', return_value=1400):  # 400s later, beyond 300s TTL
                schemas3 = sample_agent._get_mcp_tool_schemas(server_params)

            assert mock_async.call_count == 2  # Additional call made

    @pytest.mark.asyncio
    async def test_retry_logic_exponential_backoff_timing(self, performance_wrapper):
        """Test that retry logic uses proper exponential backoff timing."""
        failure_count = 0
        sleep_times = []

        async def mock_failing_execute(**kwargs):
            nonlocal failure_count
            failure_count += 1
            if failure_count < 3:
                raise Exception("Network connection failed")  # Retryable error
            return "Success after retries"

        async def track_sleep(seconds):
            sleep_times.append(seconds)

        with patch.object(performance_wrapper, '_execute_tool', side_effect=mock_failing_execute), \
                patch('crewai.tools.mcp_tool_wrapper.asyncio.sleep', side_effect=track_sleep):

            result = await performance_wrapper._run_async(query="test")

            assert result == "Success after retries"
            assert failure_count == 3

            # Verify exponential backoff: 2^0=1, 2^1=2
            assert sleep_times == [1, 2]

    @pytest.mark.asyncio
    async def test_concurrent_mcp_operations_performance(self, sample_agent):
        """Test performance of concurrent MCP operations."""
        server_urls = [
            "https://concurrent1.com/mcp",
            "https://concurrent2.com/mcp",
            "https://concurrent3.com/mcp",
            "https://concurrent4.com/mcp",
            "https://concurrent5.com/mcp"
        ]

        async def mock_discovery(server_url):
            # Simulate some processing time
            await asyncio.sleep(0.1)
            return {f"tool_from_{server_url.split('//')[1].split('.')[0]}": {"description": "Concurrent tool"}}

        with patch.object(sample_agent, '_discover_mcp_tools', side_effect=mock_discovery):

            start_time = time.time()

            # Run concurrent operations
            tasks = []
            for url in server_urls:
                server_params = {"url": url}
                task = sample_agent._get_mcp_tool_schemas_async(server_params)
                tasks.append(task)

            results = await asyncio.gather(*tasks, return_exceptions=True)

            elapsed_time = time.time() - start_time

            # Concurrent operations should complete faster than sequential
            # With 0.1s per operation, concurrent should be ~0.1s, sequential would be ~0.5s
            assert elapsed_time < 0.5
            assert len(results) == len(server_urls)
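
The fan-out pattern this test times, extracted as a standalone sketch: launching one discovery coroutine per server and gathering them makes total latency track the slowest server rather than the sum. `return_exceptions=True` is what keeps one failing server from cancelling the rest:

```python Code
import asyncio
from typing import Any, List

async def discover_all(agent: Any, urls: List[str]) -> list:
    """Fan out schema discovery across servers; latency ~= max, not sum."""
    tasks = [agent._get_mcp_tool_schemas_async({"url": url}) for url in urls]
    # Failed discoveries come back as exception objects instead of propagating
    return await asyncio.gather(*tasks, return_exceptions=True)
```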

    def test_mcp_tool_creation_performance(self, sample_agent):
        """Test performance of MCP tool creation."""
        # Large number of tools to test scaling
        large_tool_schemas = {}
        for i in range(100):
            large_tool_schemas[f"tool_{i}"] = {"description": f"Tool {i}"}

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=large_tool_schemas), \
                patch.object(sample_agent, '_extract_server_name', return_value="test_server"):

            start_time = time.time()

            tools = sample_agent._get_external_mcp_tools("https://many-tools-server.com/mcp")

            creation_time = time.time() - start_time

            # Should create 100 tools quickly (less than 1 second)
            assert len(tools) == 100
            assert creation_time < 1.0

    def test_memory_usage_with_large_mcp_tool_sets(self, sample_agent):
        """Test memory usage with large MCP tool sets."""
        # Create large tool schema set
        large_schemas = {}
        for i in range(1000):
            large_schemas[f"tool_{i}"] = {
                "description": f"Tool {i} with description " * 10,  # Larger descriptions
                "args_schema": None
            }

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=large_schemas):

            # Measure memory usage (sys is imported at module level)
            initial_size = sys.getsizeof(sample_agent)

            tools = sample_agent._get_external_mcp_tools("https://large-server.com/mcp")

            final_size = sys.getsizeof(sample_agent)

            # Memory usage should be reasonable
            assert len(tools) == 1000
            memory_increase = final_size - initial_size
            # Should not use excessive memory (less than 10MB increase)
            assert memory_increase < 10 * 1024 * 1024

    @pytest.mark.asyncio
    async def test_mcp_async_operation_timing_accuracy(self, performance_wrapper):
        """Test that async MCP operations respect timing constraints accurately."""
        # Test various timeout scenarios
        timeout_tests = [
            (5, "Should complete within timeout"),
            (15, "Should complete within longer timeout"),
        ]

        for test_timeout, description in timeout_tests:
            mock_result = Mock()
            mock_result.content = [Mock(text=f"Result for {description}")]

            with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client') as mock_client, \
                    patch('crewai.tools.mcp_tool_wrapper.ClientSession') as mock_session_class:

                mock_session = AsyncMock()
                mock_session_class.return_value.__aenter__.return_value = mock_session
                mock_session.initialize = AsyncMock()

                # Mock tool call with controlled timing
                async def timed_call(*args, **kwargs):
                    await asyncio.sleep(test_timeout - 2)  # Complete just before timeout
                    return mock_result

                mock_session.call_tool = timed_call
                mock_client.return_value.__aenter__.return_value = (None, None, None)

                start_time = time.time()
                result = await performance_wrapper._run_async(query="test")
                elapsed_time = time.time() - start_time

                # Should complete successfully within expected timeframe
                assert description.lower() in result.lower()
                assert elapsed_time < test_timeout + 2  # Small buffer for test execution

    def test_cache_performance_under_concurrent_access(self, sample_agent):
        """Test cache performance under concurrent access."""
        server_params = {"url": "https://concurrent-cache-test.com/mcp"}
        mock_schemas = {"tool1": {"description": "Concurrent test tool"}}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async, \
                patch('crewai.agent.time.time', return_value=1000):

            # First call populates cache
            sample_agent._get_mcp_tool_schemas(server_params)
            assert mock_async.call_count == 1

            # Multiple concurrent cache accesses
            with patch('crewai.agent.time.time', return_value=1100):  # Within TTL

                start_time = time.time()

                # Simulate concurrent access to cache
                for _ in range(10):
                    schemas = sample_agent._get_mcp_tool_schemas(server_params)
                    assert schemas == mock_schemas

                concurrent_time = time.time() - start_time

                # All cache hits should be very fast
                assert concurrent_time < 0.1
                assert mock_async.call_count == 1  # Should not call async method again

    def test_mcp_tool_discovery_batch_performance(self, sample_agent):
        """Test performance when discovering tools from multiple MCP servers."""
        mcps = [
            "https://server1.com/mcp",
            "https://server2.com/mcp",
            "https://server3.com/mcp",
            "https://server4.com/mcp",
            "https://server5.com/mcp"
        ]

        def mock_get_tools(mcp_ref):
            # Simulate processing time per server
            time.sleep(0.05)  # Small delay per server
            return [Mock(name=f"tool_from_{mcp_ref}")]

        with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=mock_get_tools):

            start_time = time.time()

            all_tools = sample_agent.get_mcp_tools(mcps)

            batch_time = time.time() - start_time

            # Should process all servers efficiently
            assert len(all_tools) == len(mcps)
            # Should complete in reasonable time despite multiple servers
            assert batch_time < 2.0

    def test_mcp_agent_initialization_performance_impact(self):
        """Test that MCP field addition doesn't impact agent initialization performance."""
        start_time = time.time()

        # Create agents with MCP configuration
        agents = []
        for i in range(50):
            agent = Agent(
                role=f"Agent {i}",
                goal=f"Goal {i}",
                backstory=f"Backstory {i}",
                mcps=[f"https://server{i}.com/mcp"]
            )
            agents.append(agent)

        initialization_time = time.time() - start_time

        # Should initialize quickly (less than 5 seconds for 50 agents)
        assert len(agents) == 50
        assert initialization_time < 5.0

        # Each agent should have MCP configuration
        for agent in agents:
            assert hasattr(agent, 'mcps')
            assert len(agent.mcps) == 1

    @pytest.mark.asyncio
    async def test_mcp_retry_backoff_total_time_bounds(self, performance_wrapper):
        """Test that retry backoff total time stays within reasonable bounds."""
        # Mock 3 failures (max retries)
        failure_count = 0

        async def always_fail(**kwargs):
            nonlocal failure_count
            failure_count += 1
            raise Exception("Retryable network error")

        with patch.object(performance_wrapper, '_execute_tool', side_effect=always_fail), \
                patch('crewai.tools.mcp_tool_wrapper.asyncio.sleep'):  # Don't actually sleep in test

            start_time = time.time()
            result = await performance_wrapper._run_async(query="test")
            total_time = time.time() - start_time

            # Should fail after 3 attempts without excessive delay
            assert "failed after 3 attempts" in result
            assert failure_count == 3
            # Total time should be reasonable (not including actual sleep time due to patch)
            assert total_time < 1.0

    def test_mcp_cache_memory_efficiency(self, sample_agent):
        """Test that MCP cache doesn't consume excessive memory."""
        # Get initial cache size (sys is imported at module level)
        from crewai.agent import _mcp_schema_cache
        initial_cache_size = sys.getsizeof(_mcp_schema_cache)

        # Add multiple cached entries
        test_servers = []
        for i in range(20):
            server_url = f"https://server{i}.com/mcp"
            test_servers.append(server_url)

            mock_schemas = {f"tool_{i}": {"description": f"Tool {i}"}}

            with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas), \
                    patch('crewai.agent.time.time', return_value=1000 + i):

                sample_agent._get_mcp_tool_schemas({"url": server_url})

        final_cache_size = sys.getsizeof(_mcp_schema_cache)
        cache_growth = final_cache_size - initial_cache_size

        # Cache should not grow excessively (less than 1MB for 20 entries)
        assert len(_mcp_schema_cache) == 20
        assert cache_growth < 1024 * 1024  # Less than 1MB

    @pytest.mark.asyncio
    async def test_mcp_operation_cancellation_handling(self, performance_wrapper):
        """Test handling of cancelled MCP operations."""
        # Mock operation that gets cancelled
        async def cancellable_operation(**kwargs):
            try:
                await asyncio.sleep(10)  # Long operation
                return "Should not complete"
            except asyncio.CancelledError:
                raise asyncio.CancelledError("Operation was cancelled")

        with patch.object(performance_wrapper, '_execute_tool', side_effect=cancellable_operation):

            # Start operation and cancel it
            task = asyncio.create_task(performance_wrapper._run_async(query="test"))
            await asyncio.sleep(0.1)  # Let it start
            task.cancel()

            try:
                await task
            except asyncio.CancelledError:
                # Cancellation should be handled gracefully
                assert task.cancelled()

    def test_mcp_performance_monitoring_integration(self, sample_agent):
        """Test integration with performance monitoring systems."""
        with patch.object(sample_agent, '_logger') as mock_logger:

            # Successful operation should log info
            with patch.object(sample_agent, '_get_external_mcp_tools', return_value=[Mock()]):
                tools = sample_agent.get_mcp_tools(["https://monitored-server.com/mcp"])

                # Should log successful tool loading
                info_calls = [call for call in mock_logger.log.call_args_list if call[0][0] == "info"]
                assert len(info_calls) > 0
                assert "successfully loaded" in info_calls[0][0][1].lower()

    @pytest.mark.asyncio
    async def test_mcp_resource_cleanup_after_operations(self, performance_wrapper):
        """Test that MCP operations clean up resources properly."""
        # This is more of a structural test since resource cleanup
        # is handled by context managers in the implementation

        with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client') as mock_client:
            mock_context = Mock()
            mock_context.__aenter__ = AsyncMock(return_value=(None, None, None))
            mock_context.__aexit__ = AsyncMock()
            mock_client.return_value = mock_context

            with patch('crewai.tools.mcp_tool_wrapper.ClientSession') as mock_session_class:
                mock_session = AsyncMock()
                mock_session_class.return_value.__aenter__.return_value = mock_session
                mock_session.initialize = AsyncMock()
                mock_session.call_tool = AsyncMock(return_value=Mock(content="Test"))

                result = await performance_wrapper._run_async(query="test")

                # Verify context managers were properly exited
                mock_context.__aexit__.assert_called_once()

    def test_mcp_performance_baseline_establishment(self, sample_agent):
        """Establish performance baselines for MCP operations."""
        performance_metrics = {}

        # Test agent creation performance
        start = time.time()
        agent = Agent(
            role="Baseline Agent",
            goal="Establish performance baselines",
            backstory="Agent for performance baseline testing",
            mcps=["https://baseline-server.com/mcp"]
        )
        performance_metrics["agent_creation"] = time.time() - start

        # Test tool discovery performance (mocked)
        with patch.object(agent, '_get_mcp_tool_schemas', return_value={"tool1": {"description": "Baseline tool"}}):
            start = time.time()
            tools = agent._get_external_mcp_tools("https://baseline-server.com/mcp")
            performance_metrics["tool_discovery"] = time.time() - start

        # Establish reasonable performance expectations
        assert performance_metrics["agent_creation"] < 0.1  # < 100ms
        assert performance_metrics["tool_discovery"] < 0.1  # < 100ms
        assert len(tools) == 1

        # Log performance metrics for future reference
        print("\nMCP Performance Baselines:")
        for metric, value in performance_metrics.items():
            print(f"  {metric}: {value:.3f}s")

1
lib/crewai/tests/mocks/__init__.py
Normal file
@@ -0,0 +1 @@
"""Mock infrastructure for testing."""

267
lib/crewai/tests/mocks/mcp_server_mock.py
Normal file
@@ -0,0 +1,267 @@
"""Mock MCP server implementation for testing."""

import asyncio
from typing import Any, Dict, List, Optional
from unittest.mock import Mock

import pytest


class MockMCPTool:
    """Mock MCP tool for testing."""

    def __init__(self, name: str, description: str, input_schema: Optional[Dict[str, Any]] = None):
        self.name = name
        self.description = description
        self.inputSchema = input_schema or {"type": "object", "properties": {}}


class MockMCPServer:
    """Mock MCP server for testing various scenarios."""

    def __init__(self, server_url: str, tools: Optional[List[MockMCPTool]] = None, behavior: str = "normal"):
        self.server_url = server_url
        self.tools = tools or []
        self.behavior = behavior
        self.call_count = 0
        self.initialize_count = 0
        self.list_tools_count = 0

    def add_tool(self, name: str, description: str, input_schema: Optional[Dict[str, Any]] = None):
        """Add a tool to the mock server."""
        tool = MockMCPTool(name, description, input_schema)
        self.tools.append(tool)
        return tool

    async def simulate_initialize(self):
        """Simulate MCP session initialization."""
        self.initialize_count += 1

        if self.behavior == "slow_init":
            await asyncio.sleep(15)  # Exceed connection timeout
        elif self.behavior == "init_error":
            raise Exception("Initialization failed")
        elif self.behavior == "auth_error":
            raise Exception("Authentication failed")

    async def simulate_list_tools(self):
        """Simulate MCP tools listing."""
        self.list_tools_count += 1

        if self.behavior == "slow_list":
            await asyncio.sleep(20)  # Exceed discovery timeout
        elif self.behavior == "list_error":
            raise Exception("Failed to list tools")
        elif self.behavior == "json_error":
            raise Exception("JSON parsing error in list_tools")

        mock_result = Mock()
        mock_result.tools = self.tools
        return mock_result

    async def simulate_call_tool(self, tool_name: str, arguments: Dict[str, Any]):
        """Simulate MCP tool execution."""
        self.call_count += 1

        if self.behavior == "slow_execution":
            await asyncio.sleep(35)  # Exceed execution timeout
        elif self.behavior == "execution_error":
            raise Exception("Tool execution failed")
        elif self.behavior == "tool_not_found":
            raise Exception(f"Tool {tool_name} not found")

        # Find the tool
        tool = next((t for t in self.tools if t.name == tool_name), None)
        if not tool and self.behavior == "normal":
            raise Exception(f"Tool {tool_name} not found")

        # Create mock successful response
        mock_result = Mock()
        mock_result.content = [Mock(text=f"Result from {tool_name} with args: {arguments}")]
        return mock_result


class MockMCPServerFactory:
    """Factory for creating various types of mock MCP servers."""

    @staticmethod
    def create_working_server(server_url: str) -> MockMCPServer:
        """Create a mock server that works normally."""
        server = MockMCPServer(server_url, behavior="normal")
        server.add_tool("search_tool", "Search for information")
        server.add_tool("analysis_tool", "Analyze data")
        return server

    @staticmethod
    def create_slow_server(server_url: str, slow_operation: str = "init") -> MockMCPServer:
        """Create a mock server that is slow for testing timeouts."""
        behavior_map = {
            "init": "slow_init",
            "list": "slow_list",
            "execution": "slow_execution"
        }

        server = MockMCPServer(server_url, behavior=behavior_map.get(slow_operation, "slow_init"))
        server.add_tool("slow_tool", "A slow tool")
        return server

    @staticmethod
    def create_failing_server(server_url: str, failure_type: str = "connection") -> MockMCPServer:
        """Create a mock server that fails in various ways."""
        behavior_map = {
            "connection": "init_error",
            "auth": "auth_error",
            "list": "list_error",
            "json": "json_error",
            "execution": "execution_error",
            "tool_missing": "tool_not_found"
        }

        server = MockMCPServer(server_url, behavior=behavior_map.get(failure_type, "init_error"))
        if failure_type != "tool_missing":
            server.add_tool("failing_tool", "A tool that fails")
        return server

    @staticmethod
    def create_exa_like_server(server_url: str) -> MockMCPServer:
        """Create a mock server that mimics the Exa MCP server."""
        server = MockMCPServer(server_url, behavior="normal")
        server.add_tool(
            "web_search_exa",
            "Search the web using Exa AI - performs real-time web searches and can scrape content from specific URLs",
            {"type": "object", "properties": {"query": {"type": "string"}, "num_results": {"type": "integer"}}}
        )
        server.add_tool(
            "get_code_context_exa",
            "Search and get relevant context for any programming task. Exa-code has the highest quality context",
            {"type": "object", "properties": {"query": {"type": "string"}, "language": {"type": "string"}}}
        )
        return server

    @staticmethod
    def create_weather_like_server(server_url: str) -> MockMCPServer:
        """Create a mock server that mimics a weather MCP server."""
        server = MockMCPServer(server_url, behavior="normal")
        server.add_tool(
            "get_current_weather",
            "Get current weather conditions for a location",
            {"type": "object", "properties": {"location": {"type": "string"}}}
        )
        server.add_tool(
            "get_forecast",
            "Get weather forecast for the next 5 days",
            {"type": "object", "properties": {"location": {"type": "string"}, "days": {"type": "integer"}}}
        )
        server.add_tool(
            "get_alerts",
            "Get active weather alerts for a region",
            {"type": "object", "properties": {"region": {"type": "string"}}}
        )
        return server


class MCPServerContextManager:
    """Context manager for mock MCP servers."""

    def __init__(self, mock_server: MockMCPServer):
        self.mock_server = mock_server

    async def __aenter__(self):
        return (None, None, None)  # read, write, cleanup

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        pass


class MCPSessionContextManager:
    """Context manager for mock MCP sessions."""

    def __init__(self, mock_server: MockMCPServer):
        self.mock_server = mock_server

    async def __aenter__(self):
        return MockMCPSession(self.mock_server)

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        pass


class MockMCPSession:
    """Mock MCP session for testing."""

    def __init__(self, mock_server: MockMCPServer):
        self.mock_server = mock_server

    async def initialize(self):
        """Mock session initialization."""
        await self.mock_server.simulate_initialize()

    async def list_tools(self):
        """Mock tools listing."""
        return await self.mock_server.simulate_list_tools()

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]):
        """Mock tool execution."""
        return await self.mock_server.simulate_call_tool(tool_name, arguments)


def mock_streamablehttp_client(server_url: str, mock_server: MockMCPServer):
    """Create a mock streamable HTTP client for testing."""
    return MCPServerContextManager(mock_server)


def mock_client_session(read, write, mock_server: MockMCPServer):
    """Create a mock client session for testing."""
    return MCPSessionContextManager(mock_server)
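
A sketch of how these helpers are meant to be wired into a test: bind a `MockMCPServer` into each factory with `functools.partial`, then patch the transport and session symbols where the wrapper looks them up. The patch targets mirror the ones used throughout the wrapper tests; the helper name here is illustrative:

```python Code
from functools import partial
from unittest.mock import patch

def with_mock_mcp_server(mock_server: MockMCPServer):
    """Build (client_patch, session_patch) that route MCP I/O onto `mock_server`."""
    client_patch = patch(
        'crewai.tools.mcp_tool_wrapper.streamablehttp_client',
        new=partial(mock_streamablehttp_client, mock_server=mock_server),
    )
    session_patch = patch(
        'crewai.tools.mcp_tool_wrapper.ClientSession',
        new=partial(mock_client_session, mock_server=mock_server),
    )
    return client_patch, session_patch

# Usage inside a test:
#   server = MockMCPServerFactory.create_working_server("https://working.com/mcp")
#   client_p, session_p = with_mock_mcp_server(server)
#   with client_p, session_p:
#       ...exercise MCPToolWrapper against the scripted server...
```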


# Convenience functions for common test scenarios

def create_successful_exa_mock():
    """Create a successful Exa-like mock server."""
    return MockMCPServerFactory.create_exa_like_server("https://mcp.exa.ai/mcp")


def create_failing_connection_mock():
    """Create a mock server that fails to connect."""
    return MockMCPServerFactory.create_failing_server("https://failing.com/mcp", "connection")


def create_timeout_mock():
    """Create a mock server that times out."""
    return MockMCPServerFactory.create_slow_server("https://slow.com/mcp", "init")


def create_mixed_servers_scenario():
    """Create a mixed scenario with working and failing servers."""
    return {
        "working": MockMCPServerFactory.create_working_server("https://working.com/mcp"),
        "failing": MockMCPServerFactory.create_failing_server("https://failing.com/mcp"),
        "slow": MockMCPServerFactory.create_slow_server("https://slow.com/mcp"),
        "auth_fail": MockMCPServerFactory.create_failing_server("https://auth-fail.com/mcp", "auth")
    }


# Pytest fixtures for common mock scenarios

@pytest.fixture
def mock_exa_server():
    """Provide mock Exa server for tests."""
    return create_successful_exa_mock()


@pytest.fixture
def mock_failing_server():
    """Provide mock failing server for tests."""
    return create_failing_connection_mock()


@pytest.fixture
def mock_slow_server():
    """Provide mock slow server for tests."""
    return create_timeout_mock()


@pytest.fixture
def mixed_mock_servers():
    """Provide mixed mock servers scenario."""
    return create_mixed_servers_scenario()

295
lib/crewai/tests/tools/test_mcp_tool_wrapper.py
Normal file
@@ -0,0 +1,295 @@
"""Tests for MCPToolWrapper class."""

import asyncio
import pytest
from unittest.mock import AsyncMock, Mock, patch, MagicMock

# Import from the source directory
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from crewai.tools.mcp_tool_wrapper import MCPToolWrapper


class TestMCPToolWrapper:
    """Test suite for MCPToolWrapper class."""

    def test_tool_wrapper_creation(self):
        """Test MCPToolWrapper creation with valid parameters."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool", "args_schema": None},
            server_name="test_server"
        )

        assert wrapper.name == "test_server_test_tool"
        assert wrapper.original_tool_name == "test_tool"
        assert wrapper.server_name == "test_server"
        assert wrapper.mcp_server_params == {"url": "https://test.com/mcp"}
        assert "Test tool" in wrapper.description

    def test_tool_wrapper_creation_with_custom_description(self):
        """Test MCPToolWrapper creation with custom description."""
        custom_description = "Custom test tool for analysis"
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://api.example.com/mcp"},
            tool_name="analysis_tool",
            tool_schema={"description": custom_description, "args_schema": None},
            server_name="example_server"
        )

        assert wrapper.name == "example_server_analysis_tool"
        assert custom_description in wrapper.description

    def test_tool_wrapper_creation_without_args_schema(self):
        """Test MCPToolWrapper creation when args_schema is None."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},  # No args_schema
            server_name="test_server"
        )

        assert wrapper.name == "test_server_test_tool"

    @pytest.mark.asyncio
    async def test_tool_execution_success(self):
        """Test successful MCP tool execution."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        # Mock successful MCP response
        mock_result = Mock()
        mock_result.content = [Mock(text="Test result from MCP server")]

        with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client') as mock_client, \
                patch('crewai.tools.mcp_tool_wrapper.ClientSession') as mock_session_class:

            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = AsyncMock()
            mock_session.call_tool = AsyncMock(return_value=mock_result)

            mock_client.return_value.__aenter__.return_value = (None, None, None)

            result = await wrapper._run_async(query="test query")

            assert result == "Test result from MCP server"
            mock_session.initialize.assert_called_once()
            mock_session.call_tool.assert_called_once_with("test_tool", {"query": "test query"})
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_tool_execution_timeout(self):
        """Test MCP tool execution timeout handling."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="slow_tool",
            tool_schema={"description": "Slow tool"},
            server_name="test_server"
        )

        # Mock timeout scenario
        with patch('crewai.tools.mcp_tool_wrapper.asyncio.wait_for', side_effect=asyncio.TimeoutError):
            result = await wrapper._run_async(query="test")

            assert "timed out" in result.lower()
            assert "30 seconds" in result

    @pytest.mark.asyncio
    async def test_tool_execution_connection_error(self):
        """Test MCP tool execution with connection error."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        # Mock connection error
        with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client',
                   side_effect=Exception("Connection refused")):
            result = await wrapper._run_async(query="test")

            assert "failed after 3 attempts" in result.lower()
            assert "connection" in result.lower()

    @pytest.mark.asyncio
    async def test_tool_execution_authentication_error(self):
        """Test MCP tool execution with authentication error."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        # Mock authentication error
        with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client',
                   side_effect=Exception("Authentication failed")):
            result = await wrapper._run_async(query="test")

            assert "authentication failed" in result.lower()

    @pytest.mark.asyncio
    async def test_tool_execution_json_parsing_error(self):
        """Test MCP tool execution with JSON parsing error."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        # Mock JSON parsing error
        with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client',
                   side_effect=Exception("JSON parsing error")):
            result = await wrapper._run_async(query="test")

            assert "failed after 3 attempts" in result.lower()
            assert "parsing error" in result.lower()

    @pytest.mark.asyncio
    async def test_tool_execution_retry_logic(self):
        """Test MCP tool execution retry logic with exponential backoff."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        call_count = 0

        async def mock_execute_tool(**kwargs):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise Exception("Network connection failed")
            # Success on the third attempt
            return "Success after retry"

        with patch.object(wrapper, '_execute_tool', side_effect=mock_execute_tool):
            with patch('crewai.tools.mcp_tool_wrapper.asyncio.sleep') as mock_sleep:
                result = await wrapper._run_async(query="test")

                assert result == "Success after retry"
                assert call_count == 3
                # Verify exponential backoff sleep calls
                assert mock_sleep.call_count == 2  # 2 retries
                mock_sleep.assert_any_call(1)  # 2^0 = 1
                mock_sleep.assert_any_call(2)  # 2^1 = 2

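    # The assertions above imply a retry loop of roughly this shape (a minimal
    # sketch of the assumed behavior, not the actual MCPToolWrapper code):
    #
    #     for attempt in range(3):
    #         try:
    #             return await self._execute_tool(**kwargs)
    #         except Exception as e:
    #             if attempt == 2:
    #                 return f"MCP tool failed after 3 attempts: {e}"
    #             await asyncio.sleep(2 ** attempt)  # 1s, then 2s
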
    @pytest.mark.asyncio
    async def test_tool_execution_mcp_library_missing(self):
        """Test MCP tool execution when the MCP library is missing."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        # Mock ImportError for the MCP library
        with patch('crewai.tools.mcp_tool_wrapper.ClientSession', side_effect=ImportError("No module named 'mcp'")):
            result = await wrapper._run_async(query="test")

            assert "mcp library not available" in result.lower()
            assert "pip install mcp" in result.lower()

    @pytest.mark.asyncio
    async def test_tool_execution_various_content_formats(self):
        """Test MCP tool execution with various response content formats."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        test_cases = [
            # Content as list with text attribute
            ([Mock(text="List text content")], "List text content"),
            # Content as list without text attribute
            ([Mock(spec=[])], "Mock object"),
            # Content as string
            ("String content", "String content"),
            # No content
            (None, "Mock object"),
        ]

        for content, expected_substring in test_cases:
            mock_result = Mock()
            mock_result.content = content

            with patch('crewai.tools.mcp_tool_wrapper.streamablehttp_client') as mock_client, \
                 patch('crewai.tools.mcp_tool_wrapper.ClientSession') as mock_session_class:

                mock_session = AsyncMock()
                mock_session_class.return_value.__aenter__.return_value = mock_session
                mock_session.initialize = AsyncMock()
                mock_session.call_tool = AsyncMock(return_value=mock_result)

                mock_client.return_value.__aenter__.return_value = (None, None, None)

                result = await wrapper._run_async(query="test")

                if expected_substring != "Mock object":
                    assert expected_substring in result
                else:
                    # For mock objects, just verify it's a string
                    assert isinstance(result, str)

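    # Content handling implied by these cases (an assumption, not the actual
    # implementation): prefer item.text when the result content is a list of
    # text blocks, otherwise fall back to str(...), so _run_async always
    # returns a string.
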
    def test_sync_run_method(self):
        """Test the synchronous _run method wrapper."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        # Mock successful async execution
        async def mock_async_run(**kwargs):
            return "Async result"

        with patch.object(wrapper, '_run_async', side_effect=mock_async_run):
            result = wrapper._run(query="test")

            assert result == "Async result"

    def test_sync_run_method_timeout_error(self):
        """Test the synchronous _run method handling asyncio.TimeoutError."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        with patch('asyncio.run', side_effect=asyncio.TimeoutError()):
            result = wrapper._run(query="test")

            assert "test_tool" in result
            assert "timed out after 30 seconds" in result

    def test_sync_run_method_general_error(self):
        """Test the synchronous _run method handling general exceptions."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://test.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server"
        )

        with patch('asyncio.run', side_effect=Exception("General error")):
            result = wrapper._run(query="test")

            assert "error executing mcp tool test_tool" in result.lower()
            assert "general error" in result.lower()
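

# End-to-end usage sketch (assumes a reachable MCP server at the URL and a
# tool named "search" exposed by it; illustrative only, not part of the
# committed suite):
#
#     wrapper = MCPToolWrapper(
#         mcp_server_params={"url": "https://example.com/mcp"},
#         tool_name="search",
#         tool_schema={"description": "Search tool"},
#         server_name="example",
#     )
#     print(wrapper._run(query="crewai"))  # drives _run_async via asyncio.run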