# Mirror of https://github.com/crewAIInc/crewAI.git
# Synced 2026-01-08 23:58:34 +00:00 · 471 lines · 21 KiB · Python
"""Tests for Agent MCP integration functionality."""

import asyncio
import os
import sys
import time
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

# Import from the source directory: put the ``src`` directory two levels
# above this file at the front of ``sys.path`` so the ``crewai`` imports
# below resolve against it first. This must run before those imports.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from crewai.agent import Agent  # noqa: E402  (needs the sys.path entry above)
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper  # noqa: E402


class TestAgentMCPIntegration:
    """Test suite for Agent MCP integration functionality."""

    @pytest.fixture(autouse=True)
    def clear_mcp_schema_cache(self):
        """Clear the module-level MCP schema cache before and after each test.

        ``crewai.agent._mcp_schema_cache`` is shared module state, and several
        tests use the same server URL (``https://api.example.com/mcp``). A
        schema cached by one test could otherwise satisfy a later test's
        lookup (see ``test_get_mcp_tool_schemas_caching``) and hide the
        behaviour under test. Fixture teardown runs even when the test body
        fails, unlike a trailing ``clear()`` call inside the test.
        """
        from crewai.agent import _mcp_schema_cache

        _mcp_schema_cache.clear()
        yield
        _mcp_schema_cache.clear()

    @pytest.fixture
    def sample_agent(self):
        """Create a sample agent for testing."""
        return Agent(
            role="Test Research Agent",
            goal="Test MCP integration capabilities",
            backstory="Agent designed for testing MCP functionality",
            mcps=["https://api.example.com/mcp"],
        )

    @pytest.fixture
    def mock_mcp_tools_response(self):
        """Mock MCP server tools response (a ``list_tools`` result with two tools)."""
        mock_tool1 = Mock()
        mock_tool1.name = "search_tool"
        mock_tool1.description = "Search for information"

        mock_tool2 = Mock()
        mock_tool2.name = "analysis_tool"
        mock_tool2.description = "Analyze data"

        mock_result = Mock()
        mock_result.tools = [mock_tool1, mock_tool2]

        return mock_result

    def test_get_mcp_tools_empty_list(self, sample_agent):
        """Test get_mcp_tools with empty list."""
        tools = sample_agent.get_mcp_tools([])
        assert tools == []

    def test_get_mcp_tools_with_https_url(self, sample_agent):
        """Test get_mcp_tools with HTTPS URL."""
        with patch.object(sample_agent, '_get_external_mcp_tools', return_value=[Mock()]) as mock_get:
            tools = sample_agent.get_mcp_tools(["https://api.example.com/mcp"])

            mock_get.assert_called_once_with("https://api.example.com/mcp")
            assert len(tools) == 1

    def test_get_mcp_tools_with_crewai_amp_reference(self, sample_agent):
        """Test get_mcp_tools with CrewAI AMP reference."""
        with patch.object(sample_agent, '_get_amp_mcp_tools', return_value=[Mock()]) as mock_get:
            tools = sample_agent.get_mcp_tools(["crewai-amp:financial-data"])

            mock_get.assert_called_once_with("crewai-amp:financial-data")
            assert len(tools) == 1

    def test_get_mcp_tools_mixed_references(self, sample_agent):
        """Test get_mcp_tools with mixed reference types."""
        # NOTE: Mock(name=...) only sets the mock's repr name, not ``.name``;
        # it is used here purely to make failure output readable.
        mock_external_tools = [Mock(name="external_tool")]
        mock_amp_tools = [Mock(name="amp_tool")]

        with patch.object(sample_agent, '_get_external_mcp_tools', return_value=mock_external_tools), \
             patch.object(sample_agent, '_get_amp_mcp_tools', return_value=mock_amp_tools):
            tools = sample_agent.get_mcp_tools([
                "https://api.example.com/mcp",
                "crewai-amp:research-tools",
            ])

            assert len(tools) == 2

    def test_get_mcp_tools_error_handling(self, sample_agent):
        """Test get_mcp_tools error handling and graceful degradation."""
        with patch.object(sample_agent, '_get_external_mcp_tools', side_effect=Exception("Connection failed")), \
             patch.object(sample_agent, '_logger') as mock_logger:
            tools = sample_agent.get_mcp_tools(["https://api.example.com/mcp"])

            # Should return empty list and log warning
            assert tools == []
            mock_logger.log.assert_called_with("warning", "Skipping MCP https://api.example.com/mcp due to error: Connection failed")

    def test_extract_server_name_basic_url(self, sample_agent):
        """Test server name extraction from basic URLs."""
        server_name = sample_agent._extract_server_name("https://api.example.com/mcp")
        assert server_name == "api_example_com_mcp"

    def test_extract_server_name_with_path(self, sample_agent):
        """Test server name extraction from URLs with paths."""
        server_name = sample_agent._extract_server_name("https://mcp.exa.ai/api/v1/mcp")
        assert server_name == "mcp_exa_ai_api_v1_mcp"

    def test_extract_server_name_no_path(self, sample_agent):
        """Test server name extraction from URLs without path."""
        server_name = sample_agent._extract_server_name("https://example.com")
        assert server_name == "example_com"

    def test_extract_server_name_with_query_params(self, sample_agent):
        """Test server name extraction ignores query parameters."""
        server_name = sample_agent._extract_server_name("https://api.example.com/mcp?api_key=test")
        assert server_name == "api_example_com_mcp"

    def test_get_mcp_tool_schemas_success(self, sample_agent):
        """Test successful MCP tool schema retrieval.

        ``_get_mcp_tool_schemas`` is synchronous (the caching tests call it
        without awaiting), so this test is deliberately not ``async``. If the
        method drives ``_get_mcp_tool_schemas_async`` itself, calling it from
        inside an already-running event loop could fail.
        """
        server_params = {"url": "https://api.example.com/mcp"}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value={
            "search_tool": {"description": "Search tool", "args_schema": None},
            "analysis_tool": {"description": "Analysis tool", "args_schema": None},
        }) as mock_async:
            schemas = sample_agent._get_mcp_tool_schemas(server_params)

            assert len(schemas) == 2
            assert "search_tool" in schemas
            assert "analysis_tool" in schemas
            mock_async.assert_called_once()

    def test_get_mcp_tool_schemas_caching(self, sample_agent):
        """Test MCP tool schema caching behavior.

        The autouse ``clear_mcp_schema_cache`` fixture guarantees an empty
        cache on entry and clears it again on exit, even if an assertion fails.
        """
        server_params = {"url": "https://api.example.com/mcp"}
        mock_schemas = {"tool1": {"description": "Tool 1"}}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async:
            # First call at time 1000 - should hit server
            with patch('crewai.agent.time.time', return_value=1000):
                schemas1 = sample_agent._get_mcp_tool_schemas(server_params)
                assert mock_async.call_count == 1
                assert schemas1 == mock_schemas

            # Second call within TTL - should use cache
            with patch('crewai.agent.time.time', return_value=1100):  # 100 seconds later, within 300s TTL
                schemas2 = sample_agent._get_mcp_tool_schemas(server_params)
                assert mock_async.call_count == 1  # Not called again
                assert schemas1 == schemas2

    def test_get_mcp_tool_schemas_cache_expiration(self, sample_agent):
        """Test MCP tool schema cache expiration.

        The autouse ``clear_mcp_schema_cache`` fixture guarantees an empty
        cache on entry and clears it again on exit, even if an assertion fails.
        """
        server_params = {"url": "https://api.example.com/mcp"}
        mock_schemas = {"tool1": {"description": "Tool 1"}}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', return_value=mock_schemas) as mock_async:
            # First call at time 1000
            with patch('crewai.agent.time.time', return_value=1000):
                schemas1 = sample_agent._get_mcp_tool_schemas(server_params)
                assert mock_async.call_count == 1
                assert schemas1 == mock_schemas

            # Call after cache expiration (> 300s TTL)
            with patch('crewai.agent.time.time', return_value=1400):  # 400 seconds later, beyond 300s TTL
                schemas2 = sample_agent._get_mcp_tool_schemas(server_params)
                assert mock_async.call_count == 2  # Called again after cache expiration
                assert schemas2 == mock_schemas

    def test_get_mcp_tool_schemas_error_handling(self, sample_agent):
        """Test MCP tool schema retrieval error handling."""
        server_params = {"url": "https://api.example.com/mcp"}

        with patch.object(sample_agent, '_get_mcp_tool_schemas_async', side_effect=Exception("Connection failed")), \
             patch.object(sample_agent, '_logger') as mock_logger:
            schemas = sample_agent._get_mcp_tool_schemas(server_params)

            # Should return empty dict and log warning
            assert schemas == {}
            mock_logger.log.assert_called_with("warning", "Failed to get MCP tool schemas from https://api.example.com/mcp: Connection failed")

    def test_get_external_mcp_tools_full_server(self, sample_agent):
        """Test getting tools from external MCP server (full server)."""
        mcp_ref = "https://api.example.com/mcp"
        mock_schemas = {
            "tool1": {"description": "Tool 1"},
            "tool2": {"description": "Tool 2"},
        }

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=mock_schemas), \
             patch.object(sample_agent, '_extract_server_name', return_value="example_server"):
            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            assert len(tools) == 2
            assert all(isinstance(tool, MCPToolWrapper) for tool in tools)
            assert tools[0].server_name == "example_server"

    def test_get_external_mcp_tools_specific_tool(self, sample_agent):
        """Test getting specific tool from external MCP server."""
        mcp_ref = "https://api.example.com/mcp#tool1"
        mock_schemas = {
            "tool1": {"description": "Tool 1"},
            "tool2": {"description": "Tool 2"},
        }

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=mock_schemas), \
             patch.object(sample_agent, '_extract_server_name', return_value="example_server"):
            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            # Should only get tool1
            assert len(tools) == 1
            assert tools[0].original_tool_name == "tool1"

    def test_get_external_mcp_tools_specific_tool_not_found(self, sample_agent):
        """Test getting specific tool that doesn't exist on MCP server."""
        mcp_ref = "https://api.example.com/mcp#nonexistent_tool"
        mock_schemas = {
            "tool1": {"description": "Tool 1"},
            "tool2": {"description": "Tool 2"},
        }

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=mock_schemas), \
             patch.object(sample_agent, '_logger') as mock_logger:
            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            # Should return empty list and log warning
            assert tools == []
            mock_logger.log.assert_called_with("warning", "Specific tool 'nonexistent_tool' not found on MCP server: https://api.example.com/mcp")

    def test_get_external_mcp_tools_no_schemas(self, sample_agent):
        """Test getting tools when no schemas are discovered."""
        mcp_ref = "https://api.example.com/mcp"

        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value={}), \
             patch.object(sample_agent, '_logger') as mock_logger:
            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            assert tools == []
            mock_logger.log.assert_called_with("warning", "No tools discovered from MCP server: https://api.example.com/mcp")

    def test_get_amp_mcp_tools_full_mcp(self, sample_agent):
        """Test getting tools from CrewAI AMP MCP marketplace (full MCP)."""
        amp_ref = "crewai-amp:financial-data"
        mock_servers = [{"url": "https://amp.crewai.com/mcp/financial"}]

        with patch.object(sample_agent, '_fetch_amp_mcp_servers', return_value=mock_servers), \
             patch.object(sample_agent, '_get_external_mcp_tools', return_value=[Mock()]) as mock_get_tools:
            tools = sample_agent._get_amp_mcp_tools(amp_ref)

            mock_get_tools.assert_called_once_with("https://amp.crewai.com/mcp/financial")
            assert len(tools) == 1

    def test_get_amp_mcp_tools_specific_tool(self, sample_agent):
        """Test getting specific tool from CrewAI AMP MCP marketplace."""
        amp_ref = "crewai-amp:financial-data#get_stock_price"
        mock_servers = [{"url": "https://amp.crewai.com/mcp/financial"}]

        with patch.object(sample_agent, '_fetch_amp_mcp_servers', return_value=mock_servers), \
             patch.object(sample_agent, '_get_external_mcp_tools', return_value=[Mock()]) as mock_get_tools:
            tools = sample_agent._get_amp_mcp_tools(amp_ref)

            mock_get_tools.assert_called_once_with("https://amp.crewai.com/mcp/financial#get_stock_price")
            assert len(tools) == 1

    def test_get_amp_mcp_tools_multiple_servers(self, sample_agent):
        """Test getting tools from multiple AMP MCP servers."""
        amp_ref = "crewai-amp:multi-server-mcp"
        mock_servers = [
            {"url": "https://amp.crewai.com/mcp/server1"},
            {"url": "https://amp.crewai.com/mcp/server2"},
        ]

        with patch.object(sample_agent, '_fetch_amp_mcp_servers', return_value=mock_servers), \
             patch.object(sample_agent, '_get_external_mcp_tools', return_value=[Mock()]) as mock_get_tools:
            tools = sample_agent._get_amp_mcp_tools(amp_ref)

            assert mock_get_tools.call_count == 2
            assert len(tools) == 2

    @pytest.mark.asyncio
    async def test_get_mcp_tool_schemas_async_success(self, sample_agent, mock_mcp_tools_response):
        """Test successful async MCP tool schema retrieval."""
        server_params = {"url": "https://api.example.com/mcp"}

        with patch('crewai.agent.streamablehttp_client') as mock_client, \
             patch('crewai.agent.ClientSession') as mock_session_class:
            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = AsyncMock()
            mock_session.list_tools = AsyncMock(return_value=mock_mcp_tools_response)

            mock_client.return_value.__aenter__.return_value = (None, None, None)

            schemas = await sample_agent._get_mcp_tool_schemas_async(server_params)

            assert len(schemas) == 2
            assert "search_tool" in schemas
            assert "analysis_tool" in schemas
            assert schemas["search_tool"]["description"] == "Search for information"

    @pytest.mark.asyncio
    async def test_get_mcp_tool_schemas_async_timeout(self, sample_agent):
        """Test async MCP tool schema retrieval timeout handling."""
        server_params = {"url": "https://api.example.com/mcp"}

        # asyncio.sleep is patched too: the retry loop backs off between
        # attempts (see test_get_mcp_tool_schemas_async_retry_logic). If it
        # also does so on timeouts, that would otherwise add real seconds of
        # waiting to this test.
        with patch('crewai.agent.asyncio.wait_for', side_effect=asyncio.TimeoutError), \
             patch('crewai.agent.asyncio.sleep', new_callable=AsyncMock):
            with pytest.raises(RuntimeError, match="Failed to discover MCP tools after 3 attempts"):
                await sample_agent._get_mcp_tool_schemas_async(server_params)

    @pytest.mark.asyncio
    async def test_get_mcp_tool_schemas_async_import_error(self, sample_agent):
        """Test async MCP tool schema retrieval with missing MCP library."""
        server_params = {"url": "https://api.example.com/mcp"}

        with patch('crewai.agent.ClientSession', side_effect=ImportError("No module named 'mcp'")):
            with pytest.raises(RuntimeError, match="MCP library not available"):
                await sample_agent._get_mcp_tool_schemas_async(server_params)

    @pytest.mark.asyncio
    async def test_get_mcp_tool_schemas_async_retry_logic(self, sample_agent):
        """Test retry logic with exponential backoff in async schema retrieval."""
        server_params = {"url": "https://api.example.com/mcp"}

        call_count = 0

        async def mock_discover_tools(url):
            # Fail the first two attempts, succeed on the third.
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise Exception("Network connection failed")
            return {"tool1": {"description": "Tool 1"}}

        with patch.object(sample_agent, '_discover_mcp_tools', side_effect=mock_discover_tools), \
             patch('crewai.agent.asyncio.sleep') as mock_sleep:
            schemas = await sample_agent._get_mcp_tool_schemas_async(server_params)

            assert schemas == {"tool1": {"description": "Tool 1"}}
            assert call_count == 3
            # Verify exponential backoff
            assert mock_sleep.call_count == 2
            mock_sleep.assert_any_call(1)  # First retry: 2^0 = 1
            mock_sleep.assert_any_call(2)  # Second retry: 2^1 = 2

    @pytest.mark.asyncio
    async def test_discover_mcp_tools_success(self, sample_agent, mock_mcp_tools_response):
        """Test successful MCP tool discovery."""
        server_url = "https://api.example.com/mcp"

        with patch('crewai.agent.streamablehttp_client') as mock_client, \
             patch('crewai.agent.ClientSession') as mock_session_class:
            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_session.initialize = AsyncMock()
            mock_session.list_tools = AsyncMock(return_value=mock_mcp_tools_response)

            mock_client.return_value.__aenter__.return_value = (None, None, None)

            schemas = await sample_agent._discover_mcp_tools(server_url)

            assert len(schemas) == 2
            assert schemas["search_tool"]["description"] == "Search for information"
            assert schemas["analysis_tool"]["description"] == "Analyze data"

    def test_fetch_amp_mcp_servers_placeholder(self, sample_agent):
        """Test AMP MCP server fetching (currently returns empty list)."""
        result = sample_agent._fetch_amp_mcp_servers("test-mcp")

        # Currently returns empty list - placeholder implementation
        assert result == []

    def test_get_external_mcp_tools_error_handling(self, sample_agent):
        """Test external MCP tools error handling."""
        mcp_ref = "https://failing-server.com/mcp"

        with patch.object(sample_agent, '_get_mcp_tool_schemas', side_effect=Exception("Server unavailable")), \
             patch.object(sample_agent, '_logger') as mock_logger:
            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            assert tools == []
            mock_logger.log.assert_called_with("warning", "Failed to connect to MCP server https://failing-server.com/mcp: Server unavailable")

    def test_get_external_mcp_tools_wrapper_creation_error(self, sample_agent):
        """Test handling of MCPToolWrapper creation errors."""
        mcp_ref = "https://api.example.com/mcp"
        mock_schemas = {"tool1": {"description": "Tool 1"}}

        # NOTE(review): patching the defining module only takes effect if
        # crewai.agent looks MCPToolWrapper up at call time (e.g. a local
        # import); a module-level import in crewai.agent would bypass it.
        with patch.object(sample_agent, '_get_mcp_tool_schemas', return_value=mock_schemas), \
             patch.object(sample_agent, '_extract_server_name', return_value="example_server"), \
             patch('crewai.tools.mcp_tool_wrapper.MCPToolWrapper', side_effect=Exception("Wrapper creation failed")), \
             patch.object(sample_agent, '_logger') as mock_logger:
            tools = sample_agent._get_external_mcp_tools(mcp_ref)

            assert tools == []
            mock_logger.log.assert_called_with("warning", "Failed to create MCP tool wrapper for tool1: Wrapper creation failed")

    def test_mcp_tools_integration_with_existing_agent_features(self, sample_agent):
        """Test MCP tools integration with existing agent features."""
        # Test that MCP field works alongside other agent features
        agent_with_all_features = Agent(
            role="Full Feature Agent",
            goal="Test all features together",
            backstory="Agent with all features enabled",
            mcps=["https://api.example.com/mcp", "crewai-amp:financial-data"],
            apps=["gmail", "slack"],  # Platform apps
            tools=[],  # Regular tools
            verbose=True,
            max_iter=15,
            allow_delegation=True,
        )

        assert len(agent_with_all_features.mcps) == 2
        assert len(agent_with_all_features.apps) == 2
        assert agent_with_all_features.verbose is True
        assert agent_with_all_features.max_iter == 15
        assert agent_with_all_features.allow_delegation is True

    def test_mcp_reference_parsing_edge_cases(self, sample_agent):
        """Test MCP reference parsing with edge cases.

        NOTE(review): this parses the references locally rather than through
        an Agent method, so it documents the reference format but does not
        exercise agent code.
        """
        test_cases = [
            # URL with complex query parameters
            ("https://api.example.com/mcp?api_key=abc123&profile=test&version=1.0", "api.example.com", None),
            # URL with tool name and query params
            ("https://api.example.com/mcp?api_key=test#search_tool", "api.example.com", "search_tool"),
            # AMP reference with dashes and underscores
            ("crewai-amp:financial_data-v2", "financial_data-v2", None),
            # AMP reference with tool name
            ("crewai-amp:research-tools#pubmed_search", "research-tools", "pubmed_search"),
        ]

        for mcp_ref, expected_server_part, expected_tool in test_cases:
            if mcp_ref.startswith("https://"):
                if '#' in mcp_ref:
                    server_url, tool_name = mcp_ref.split('#', 1)
                    assert expected_server_part in server_url
                    assert tool_name == expected_tool
                else:
                    assert expected_server_part in mcp_ref
                    assert expected_tool is None
            else:  # AMP reference
                amp_part = mcp_ref.replace('crewai-amp:', '')
                if '#' in amp_part:
                    mcp_name, tool_name = amp_part.split('#', 1)
                    assert mcp_name == expected_server_part
                    assert tool_name == expected_tool
                else:
                    assert amp_part == expected_server_part
                    assert expected_tool is None