Adding MCP implementation

This commit is contained in:
Joao Moura
2025-10-19 22:51:47 -07:00
parent 14a8926214
commit ccd3c2163e
4 changed files with 438 additions and 0 deletions

View File

@@ -1,4 +1,5 @@
from collections.abc import Callable, Sequence
import asyncio
import shutil
import subprocess
import time
@@ -55,6 +56,16 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.types import LLMMessage
# MCP timeouts, in seconds.
# MCP_CONNECTION_TIMEOUT bounds the session handshake (`session.initialize()`).
MCP_CONNECTION_TIMEOUT = 10
# Not referenced in the visible part of this module; the MCP tool wrapper module defines its own copy.
MCP_TOOL_EXECUTION_TIMEOUT = 30
# Bounds one complete discovery attempt; `list_tools()` is given what remains
# after MCP_CONNECTION_TIMEOUT is subtracted.
MCP_DISCOVERY_TIMEOUT = 15
# Total number of discovery attempts (used as `range(MCP_MAX_RETRIES)`), not the number of extra retries.
MCP_MAX_RETRIES = 3
# Module-level in-memory cache of MCP tool schemas, keyed by server URL.
# Each value is a `(schemas, cached_at)` tuple where `cached_at` is a `time.time()` timestamp.
_mcp_schema_cache: dict[str, tuple[dict[str, dict], float]] = {}
_cache_ttl = 300 # seconds (5 minutes) during which a cached entry is reused
class Agent(BaseAgent):
"""Represents an agent in a system.
@@ -80,6 +91,7 @@ class Agent(BaseAgent):
knowledge_sources: Knowledge sources for the agent.
embedder: Embedder configuration for the agent.
apps: List of applications that the agent can access through CrewAI Platform.
mcps: List of MCP server references for tool integration.
"""
_times_executed: int = PrivateAttr(default=0)
@@ -611,6 +623,214 @@ class Agent(BaseAgent):
self._logger.log("error", f"Error getting platform tools: {e!s}")
return []
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
    """Resolve MCP server references into CrewAI tools.

    Each reference is dispatched on its prefix: ``crewai-amp:`` references go
    through the AMP lookup, ``https://`` references are treated as external
    MCP servers, and anything else is ignored. A failure on one reference is
    logged as a warning and does not prevent the others from loading.

    Args:
        mcps: MCP server references to resolve.

    Returns:
        All tools loaded from the references, in reference order.
    """
    collected: list[BaseTool] = []

    for reference in mcps:
        try:
            if reference.startswith('crewai-amp:'):
                loader = self._get_amp_mcp_tools
            elif reference.startswith('https://'):
                loader = self._get_external_mcp_tools
            else:
                # Unsupported prefix: nothing to load for this reference.
                continue

            loaded = loader(reference)
            collected.extend(loaded)
            self._logger.log("info", f"Successfully loaded {len(loaded)} tools from {reference}")
        except Exception as error:
            # Best effort: one broken server must not take the others down.
            self._logger.log("warning", f"Skipping MCP {reference} due to error: {error}")

    return collected
def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
    """Build tool wrappers for an external HTTPS MCP server.

    Args:
        mcp_ref: Server URL, optionally followed by ``#tool_name`` to select
            a single tool from that server.

    Returns:
        One ``MCPToolWrapper`` per selected tool. An empty list is returned
        when no schemas are discovered or when anything fails; failures are
        logged as warnings rather than raised.
    """
    from crewai.tools.mcp_tool_wrapper import MCPToolWrapper

    # Everything after the first '#' names the single tool the caller wants.
    server_url, separator, fragment = mcp_ref.partition('#')
    specific_tool = fragment if separator else None

    server_params = {"url": server_url}
    server_name = self._extract_server_name(server_url)

    try:
        tool_schemas = self._get_mcp_tool_schemas(server_params)
        if not tool_schemas:
            self._logger.log("warning", f"No tools discovered from MCP server: {server_url}")
            return []

        tools = []
        for tool_name, schema in tool_schemas.items():
            wanted = not specific_tool or tool_name == specific_tool
            if not wanted:
                continue

            try:
                tools.append(
                    MCPToolWrapper(
                        mcp_server_params=server_params,
                        tool_name=tool_name,
                        tool_schema=schema,
                        server_name=server_name
                    )
                )
            except Exception as error:
                # A single bad tool definition should not hide the remaining tools.
                self._logger.log("warning", f"Failed to create MCP tool wrapper for {tool_name}: {error}")

        if specific_tool and not tools:
            self._logger.log("warning", f"Specific tool '{specific_tool}' not found on MCP server: {server_url}")

        return tools
    except Exception as error:
        self._logger.log("warning", f"Failed to connect to MCP server {server_url}: {error}")
        return []
def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
    """Get tools from CrewAI AMP MCP marketplace.

    Args:
        amp_ref: Reference of the form ``crewai-amp:mcp-name`` or
            ``crewai-amp:mcp-name#tool_name``.

    Returns:
        Tools collected from every server configuration returned for the
        MCP name. Configurations without a usable ``url`` are skipped with a
        warning instead of aborting the whole reference.
    """
    # Strip only the leading scheme. ``str.replace`` would also delete any
    # later occurrence of 'crewai-amp:' inside the MCP or tool name.
    amp_part = amp_ref.removeprefix('crewai-amp:')
    mcp_name, _, specific_tool = amp_part.partition('#')

    tools: list[BaseTool] = []
    for server_config in self._fetch_amp_mcp_servers(mcp_name):
        server_url = server_config.get('url')
        if not server_url:
            # Skipping keeps the tools already gathered from earlier servers.
            self._logger.log("warning", f"Skipping AMP MCP server config without a URL for {mcp_name}")
            continue

        # Forward the tool selector so the external loader filters for us.
        server_ref = f'{server_url}#{specific_tool}' if specific_tool else server_url
        tools.extend(self._get_external_mcp_tools(server_ref))

    return tools
def _extract_server_name(self, server_url: str) -> str:
    """Extract clean server name from URL for tool prefixing.

    The name is built from the URL's host (and port, if any) and its path.
    Every character outside ``[A-Za-z0-9_-]`` is replaced with ``_`` so the
    result is safe to embed in a tool name. Any ``user:password@`` userinfo
    is dropped so credentials never leak into tool names.

    Args:
        server_url: URL of the MCP server.

    Returns:
        ``"<host>_<path>"`` when the URL has a non-empty path, otherwise
        just ``"<host>"``.
    """
    import re
    from urllib.parse import urlparse

    parsed = urlparse(server_url)

    # netloc may be "user:secret@host:port"; keep only the "host:port" part.
    host = parsed.netloc.rpartition('@')[2]

    unsafe_char = re.compile(r'[^A-Za-z0-9_-]')
    domain = unsafe_char.sub('_', host)
    path = unsafe_char.sub('_', parsed.path).strip('_')

    return f"{domain}_{path}" if path else domain
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
    """Get tool schemas from MCP server for wrapper creation with caching.

    Results are cached per server URL in the module-level
    ``_mcp_schema_cache`` and reused for ``_cache_ttl`` seconds. The cached
    dict is returned by reference, so callers should not mutate it.

    Args:
        server_params: Connection parameters; must contain a ``"url"`` key.

    Returns:
        Mapping of tool name to schema dict, or an empty dict when discovery
        fails (the failure is logged as a warning, never raised).
    """
    import concurrent.futures

    server_url = server_params["url"]
    cache_key = server_url

    cached_entry = _mcp_schema_cache.get(cache_key)
    if cached_entry is not None:
        cached_data, cache_time = cached_entry
        if time.time() - cache_time < _cache_ttl:
            self._logger.log("debug", f"Using cached MCP tool schemas for {server_url}")
            return cached_data

    try:
        try:
            asyncio.get_running_loop()
            loop_is_running = True
        except RuntimeError:
            loop_is_running = False

        discovery = self._get_mcp_tool_schemas_async(server_params)
        if loop_is_running:
            # asyncio.run() refuses to start inside a running event loop, so
            # drive the coroutine on a fresh loop in a short-lived thread.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                schemas = pool.submit(asyncio.run, discovery).result()
        else:
            schemas = asyncio.run(discovery)

        # Stamp the entry after the fetch so a slow discovery does not eat
        # into the entry's time-to-live.
        _mcp_schema_cache[cache_key] = (schemas, time.time())
        return schemas
    except Exception as e:
        # Log warning but don't raise - this allows graceful degradation
        self._logger.log("warning", f"Failed to get MCP tool schemas from {server_url}: {e}")
        return {}
async def _get_mcp_tool_schemas_async(self, server_params: dict) -> dict[str, dict]:
    """Async implementation of MCP tool schema retrieval with timeouts and retries.

    Makes up to ``MCP_MAX_RETRIES`` attempts. Timeouts and errors whose
    message mentions 'connection', 'network' or 'json' are retried with
    exponential backoff (1s, 2s, ...). Any other error ends the loop at once.

    Args:
        server_params: Connection parameters; must contain a ``"url"`` key.

    Returns:
        Mapping of tool name to schema dict.

    Raises:
        RuntimeError: If the MCP library is missing, authentication fails,
            or discovery does not succeed. The underlying exception is
            attached as ``__cause__``.
    """
    server_url = server_params["url"]
    last_error = None
    last_exception = None

    for attempt in range(MCP_MAX_RETRIES):
        try:
            # Wrap entire operation in timeout
            return await asyncio.wait_for(
                self._discover_mcp_tools(server_url),
                timeout=MCP_DISCOVERY_TIMEOUT
            )
        except asyncio.TimeoutError as e:
            last_error = f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds"
            last_exception = e
            retryable = True
        except ImportError as e:
            raise RuntimeError("MCP library not available. Please install with: pip install mcp") from e
        except Exception as e:
            error_str = str(e).lower()

            # Classify by message text; the first matching branch wins.
            if 'connection' in error_str or 'network' in error_str:
                last_error = f"Network connection failed: {str(e)}"
            elif 'authentication' in error_str or 'unauthorized' in error_str:
                # Retrying cannot fix bad credentials, so fail immediately.
                raise RuntimeError(f"Authentication failed for MCP server: {str(e)}") from e
            elif 'json' in error_str or 'parsing' in error_str:
                last_error = f"Server response parsing error: {str(e)}"
            else:
                last_error = f"MCP discovery error: {str(e)}"

            last_exception = e
            # Only these error kinds are treated as transient.
            retryable = any(marker in error_str for marker in ('connection', 'network', 'json'))

        if not retryable or attempt >= MCP_MAX_RETRIES - 1:
            break
        await asyncio.sleep(2 ** attempt)  # Exponential backoff

    # Chain the last failure so its traceback is not lost.
    raise RuntimeError(
        f"Failed to discover MCP tools after {MCP_MAX_RETRIES} attempts: {last_error}"
    ) from last_exception
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict]:
    """Discover tools from MCP server with proper timeout handling.

    Opens a streamable-HTTP client session to the server, initializes it and
    lists the tools it offers.

    Args:
        server_url: URL of the MCP server.

    Returns:
        Mapping of tool name to a schema dict with a ``'description'``
        string (never ``None``) and ``'args_schema'`` set to ``None``.

    Raises:
        asyncio.TimeoutError: If initialization or tool listing exceeds its
            time budget.
    """
    from mcp import ClientSession
    from mcp.client.streamable_http import streamablehttp_client

    async with streamablehttp_client(server_url) as (read, write, _):
        async with ClientSession(read, write) as session:
            # Initialize the connection with timeout
            await asyncio.wait_for(
                session.initialize(),
                timeout=MCP_CONNECTION_TIMEOUT
            )

            # Listing gets whatever remains of the overall discovery budget.
            tools_result = await asyncio.wait_for(
                session.list_tools(),
                timeout=MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT
            )

            return {
                tool.name: {
                    # The attribute may exist but be None; normalise to ''
                    # so the wrapper always receives a string.
                    'description': getattr(tool, 'description', None) or '',
                    'args_schema': None,  # Keep simple for now
                }
                for tool in tools_result.tools
            }
def _fetch_amp_mcp_servers(self, mcp_name: str) -> list[dict]:
    """Look up MCP server configurations for an AMP marketplace entry.

    Placeholder: no API call is made yet, so the result is always empty.

    Args:
        mcp_name: Name of the MCP in the AMP marketplace (currently unused).

    Returns:
        An empty list.
    """
    # TODO: Implement AMP API call to "integrations/mcps" endpoint
    # Should return list of server configs with URLs
    server_configs: list[dict] = []
    return server_configs
def get_multimodal_tools(self) -> Sequence[BaseTool]:
from crewai.tools.agent_tools.add_image_tool import AddImageTool

View File

@@ -190,6 +190,10 @@ class BaseAgent(ABC, BaseModel):
default=None,
description="List of applications or application/action combinations that the agent can access through CrewAI Platform. Can contain app names (e.g., 'gmail') or specific actions (e.g., 'gmail/send_email')",
)
# MCP server references used to load tools for this agent. Entries are checked
# by the `validate_mcps` field validator, which accepts only the 'https://'
# and 'crewai-amp:' prefixes.
mcps: list[str] | None = Field(
default=None,
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
)
@model_validator(mode="before")
@classmethod
@@ -243,6 +247,21 @@ class BaseAgent(ABC, BaseModel):
return list(set(validated_apps))
@field_validator("mcps")
@classmethod
def validate_mcps(cls, mcps: list[str] | None) -> list[str] | None:
if not mcps:
return mcps
validated_mcps = []
for mcp in mcps:
if mcp.startswith('https://') or mcp.startswith('crewai-amp:'):
validated_mcps.append(mcp)
else:
raise ValueError(f"Invalid MCP reference: {mcp}. Must start with 'https://' or 'crewai-amp:'")
return list(set(validated_mcps))
@model_validator(mode="after")
def validate_and_set_attributes(self):
# Validate required fields
@@ -317,6 +336,10 @@ class BaseAgent(ABC, BaseModel):
def get_platform_tools(self, apps: list[PlatformAppOrAction]) -> list[BaseTool]:
"""Get platform tools for the specified list of applications and/or application/action combinations."""
@abstractmethod
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
    """Get MCP tools for the specified list of MCP server references.

    Declared abstract: this body holds no implementation, so concrete agent
    classes must provide one.

    Args:
        mcps: MCP server references, e.g. 'https://server.com/path' or
            'crewai-amp:mcp-name', optionally with a '#tool_name' suffix.

    Returns:
        The tools resolved from the given references.
    """
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
"""Create a deep copy of the Agent."""
exclude = {
@@ -334,6 +357,7 @@ class BaseAgent(ABC, BaseModel):
"knowledge_storage",
"knowledge",
"apps",
"mcps",
"actions",
}

View File

@@ -990,6 +990,9 @@ class Crew(FlowTrackable, BaseModel):
if agent and (hasattr(agent, "apps") and getattr(agent, "apps", None)):
tools = self._add_platform_tools(task, tools)
if agent and (hasattr(agent, "mcps") and getattr(agent, "mcps", None)):
tools = self._add_mcp_tools(task, tools)
# Return a list[BaseTool] compatible with Task.execute_sync and execute_async
return cast(list[BaseTool], tools)
@@ -1042,6 +1045,18 @@ class Crew(FlowTrackable, BaseModel):
return self._merge_tools(tools, cast(list[BaseTool], platform_tools))
return cast(list[BaseTool], tools)
def _inject_mcp_tools(
    self,
    tools: list[Tool] | list[BaseTool],
    task_agent: BaseAgent,
) -> list[BaseTool]:
    """Merge the agent's MCP tools into an existing tool list.

    Args:
        tools: Tools gathered so far for the task.
        task_agent: Agent whose ``mcps`` references should be resolved.

    Returns:
        ``tools`` merged with the agent's MCP tools, or ``tools`` unchanged
        when the agent has no MCP references or no ``get_mcp_tools`` method.
    """
    references = getattr(task_agent, "mcps", None) or []

    # Guard clause: nothing to add without a resolver and at least one reference.
    if not hasattr(task_agent, "get_mcp_tools") or not references:
        return cast(list[BaseTool], tools)

    discovered = task_agent.get_mcp_tools(mcps=references)
    return self._merge_tools(tools, cast(list[BaseTool], discovered))
def _add_multimodal_tools(
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
@@ -1080,6 +1095,14 @@ class Crew(FlowTrackable, BaseModel):
return cast(list[BaseTool], tools or [])
def _add_mcp_tools(
    self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
    """Add MCP tools from the task's agent, if the task has one.

    Args:
        task: Task whose agent may declare MCP references.
        tools: Tools gathered so far for the task.

    Returns:
        The resulting tool list; an empty list when it would otherwise be
        empty or ``None``.
    """
    enriched = self._inject_mcp_tools(tools, task.agent) if task.agent else tools
    return cast(list[BaseTool], enriched or [])
def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
self._file_handler.log(

View File

@@ -0,0 +1,171 @@
"""MCP Tool Wrapper for on-demand MCP server connections."""
import asyncio
import time
from typing import Any
from crewai.tools import BaseTool
# MCP timeouts, in seconds.
# MCP_CONNECTION_TIMEOUT bounds the session handshake (`session.initialize()`).
MCP_CONNECTION_TIMEOUT = 10
# Bounds one complete tool-execution attempt; the tool call itself is given
# what remains after MCP_CONNECTION_TIMEOUT is subtracted.
MCP_TOOL_EXECUTION_TIMEOUT = 30
# Not referenced in the visible part of this module.
MCP_DISCOVERY_TIMEOUT = 15
# Total number of execution attempts (used as `range(MCP_MAX_RETRIES)`), not the number of extra retries.
MCP_MAX_RETRIES = 3
class MCPToolWrapper(BaseTool):
    """Lightweight wrapper for MCP tools that connects on-demand.

    No connection is held between calls: every execution opens a new
    streamable-HTTP session to the MCP server, calls the tool and closes the
    session again. Failures are reported as returned strings, not raised.
    """

    def __init__(
        self,
        mcp_server_params: dict,
        tool_name: str,
        tool_schema: dict,
        server_name: str,
    ):
        """Initialize the MCP tool wrapper.

        Args:
            mcp_server_params: Parameters for connecting to the MCP server
                (the ``"url"`` key is read at execution time).
            tool_name: Original name of the tool on the MCP server.
            tool_schema: Schema information for the tool; ``"description"``
                and ``"args_schema"`` are read from it.
            server_name: Name of the MCP server for prefixing.
        """
        # Create tool name with server prefix to avoid conflicts
        prefixed_name = f"{server_name}_{tool_name}"

        # Use ``or`` rather than a ``.get`` default: the key may be present
        # but hold None or '', and the base tool needs a real description.
        description = tool_schema.get("description") or f"Tool {tool_name} from {server_name}"

        kwargs = {
            "name": prefixed_name,
            "description": description,
        }

        # BaseTool expects a BaseModel subclass, so only pass args_schema
        # when one is actually provided.
        args_schema = tool_schema.get("args_schema")
        if args_schema is not None:
            kwargs["args_schema"] = args_schema

        super().__init__(**kwargs)

        # Set instance attributes after super().__init__
        self._mcp_server_params = mcp_server_params
        self._original_tool_name = tool_name
        self._server_name = server_name

    @property
    def mcp_server_params(self) -> dict:
        """Get the MCP server parameters."""
        return self._mcp_server_params

    @property
    def original_tool_name(self) -> str:
        """Get the original tool name."""
        return self._original_tool_name

    @property
    def server_name(self) -> str:
        """Get the server name."""
        return self._server_name

    def _run(self, **kwargs) -> str:
        """Connect to MCP server and execute tool.

        Works both from plain synchronous code and from code that is already
        running inside an asyncio event loop.

        Args:
            **kwargs: Arguments to pass to the MCP tool.

        Returns:
            Result from the MCP tool execution, or an error message string
            if the execution failed.
        """
        import concurrent.futures

        try:
            try:
                asyncio.get_running_loop()
                loop_is_running = True
            except RuntimeError:
                loop_is_running = False

            execution = self._run_async(**kwargs)
            if loop_is_running:
                # asyncio.run() refuses to start inside a running event loop,
                # so drive the coroutine on a fresh loop in a worker thread.
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                    return pool.submit(asyncio.run, execution).result()
            return asyncio.run(execution)
        except asyncio.TimeoutError:
            return f"MCP tool '{self.original_tool_name}' timed out after {MCP_TOOL_EXECUTION_TIMEOUT} seconds"
        except Exception as e:
            return f"Error executing MCP tool {self.original_tool_name}: {str(e)}"

    async def _run_async(self, **kwargs) -> str:
        """Async implementation of MCP tool execution with timeouts and retry logic.

        Makes up to ``MCP_MAX_RETRIES`` attempts. Timeouts and errors whose
        message mentions 'connection', 'network' or 'json' are retried with
        exponential backoff. Other errors end the loop immediately.

        Args:
            **kwargs: Arguments to pass to the MCP tool.

        Returns:
            The tool result, or a message describing why execution failed.
        """
        last_error = None

        for attempt in range(MCP_MAX_RETRIES):
            try:
                return await asyncio.wait_for(
                    self._execute_tool(**kwargs),
                    timeout=MCP_TOOL_EXECUTION_TIMEOUT
                )
            except asyncio.TimeoutError:
                last_error = f"Connection timed out after {MCP_TOOL_EXECUTION_TIMEOUT} seconds"
                retryable = True
            except ImportError:
                return "MCP library not available. Please install with: pip install mcp"
            except Exception as e:
                error_str = str(e).lower()

                # Classify by message text; the first matching branch wins.
                if 'connection' in error_str or 'network' in error_str:
                    last_error = f"Network connection failed: {str(e)}"
                elif 'authentication' in error_str or 'unauthorized' in error_str:
                    return f"Authentication failed for MCP server: {str(e)}"
                elif 'json' in error_str or 'parsing' in error_str:
                    last_error = f"Server response parsing error: {str(e)}"
                elif 'not found' in error_str:
                    return f"Tool '{self.original_tool_name}' not found on MCP server"
                else:
                    last_error = f"MCP execution error: {str(e)}"

                # Only these error kinds are treated as transient.
                retryable = any(marker in error_str for marker in ('connection', 'network', 'json'))

            if not retryable or attempt >= MCP_MAX_RETRIES - 1:
                break
            await asyncio.sleep(2 ** attempt)  # Exponential backoff

        return f"MCP tool execution failed after {MCP_MAX_RETRIES} attempts: {last_error}"

    async def _execute_tool(self, **kwargs) -> str:
        """Execute the actual MCP tool call.

        Args:
            **kwargs: Arguments forwarded to the tool as its argument mapping.

        Returns:
            The result content as text. When the result holds a list of
            content items, the text of every item is joined with newlines.
        """
        from mcp import ClientSession
        from mcp.client.streamable_http import streamablehttp_client

        server_url = self.mcp_server_params["url"]

        async with streamablehttp_client(server_url) as (read, write, _):
            async with ClientSession(read, write) as session:
                # Initialize the connection with timeout
                await asyncio.wait_for(
                    session.initialize(),
                    timeout=MCP_CONNECTION_TIMEOUT
                )

                # The call gets whatever remains of the overall execution budget.
                result = await asyncio.wait_for(
                    session.call_tool(self.original_tool_name, kwargs),
                    timeout=MCP_TOOL_EXECUTION_TIMEOUT - MCP_CONNECTION_TIMEOUT
                )

                content = getattr(result, 'content', None)
                if not content:
                    return str(result)
                if not isinstance(content, list):
                    return str(content)

                # Keep every content item; returning only the first one would
                # silently truncate multi-part results.
                return "\n".join(
                    str(item.text) if hasattr(item, 'text') else str(item)
                    for item in content
                )