mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Adding MCP implementation
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
import asyncio
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
@@ -55,6 +56,16 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
# Timeouts (in seconds) and retry budget used when talking to MCP servers.
MCP_CONNECTION_TIMEOUT = 10
MCP_TOOL_EXECUTION_TIMEOUT = 30
MCP_DISCOVERY_TIMEOUT = 15
MCP_MAX_RETRIES = 3

# Module-level, in-memory cache of discovered MCP tool schemas, keyed by
# server URL. Each value is a ``(schemas, stored_at)`` pair, and an entry is
# reused only while it is younger than ``_cache_ttl`` seconds.
_mcp_schema_cache: dict[str, tuple[dict[str, dict], float]] = {}
_cache_ttl = 300  # 5 minutes
|
||||
|
||||
|
||||
class Agent(BaseAgent):
|
||||
"""Represents an agent in a system.
|
||||
@@ -80,6 +91,7 @@ class Agent(BaseAgent):
|
||||
knowledge_sources: Knowledge sources for the agent.
|
||||
embedder: Embedder configuration for the agent.
|
||||
apps: List of applications that the agent can access through CrewAI Platform.
|
||||
mcps: List of MCP server references for tool integration.
|
||||
"""
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
@@ -611,6 +623,214 @@ class Agent(BaseAgent):
|
||||
self._logger.log("error", f"Error getting platform tools: {e!s}")
|
||||
return []
|
||||
|
||||
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
    """Turn a list of MCP server references into CrewAI tools.

    References starting with ``crewai-amp:`` are resolved through
    ``_get_amp_mcp_tools`` and references starting with ``https://`` through
    ``_get_external_mcp_tools``. Any other reference is ignored. A reference
    whose resolver raises is logged as a warning and skipped, so one failing
    server does not prevent the others from loading.

    Args:
        mcps: MCP server references to resolve.

    Returns:
        All tools loaded from the references, in reference order.
    """
    # Checked in order; the first matching prefix decides the resolver.
    resolvers = (
        ('crewai-amp:', self._get_amp_mcp_tools),
        ('https://', self._get_external_mcp_tools),
    )

    collected = []
    for reference in mcps:
        try:
            resolve = next(
                (handler for prefix, handler in resolvers if reference.startswith(prefix)),
                None,
            )
            if resolve is None:
                continue

            loaded = resolve(reference)
            collected.extend(loaded)
            self._logger.log("info", f"Successfully loaded {len(loaded)} tools from {reference}")
        except Exception as e:
            self._logger.log("warning", f"Skipping MCP {reference} due to error: {e}")

    return collected
|
||||
|
||||
def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
    """Build tool wrappers for an external HTTPS MCP server.

    ``mcp_ref`` is a server URL, optionally followed by ``#tool_name`` to
    select a single tool. Failures are logged as warnings and produce an
    empty (or partial) list instead of an exception.

    Args:
        mcp_ref: Server URL, with an optional ``#tool_name`` suffix.

    Returns:
        One ``MCPToolWrapper`` per selected tool that could be wrapped.
    """
    from crewai.tools.mcp_tool_wrapper import MCPToolWrapper

    # Everything after the first '#' names the single tool the caller wants.
    server_url, has_fragment, fragment = mcp_ref.partition('#')
    specific_tool = fragment if has_fragment else None

    server_params = {"url": server_url}
    server_name = self._extract_server_name(server_url)

    try:
        tool_schemas = self._get_mcp_tool_schemas(server_params)
        if not tool_schemas:
            self._logger.log("warning", f"No tools discovered from MCP server: {server_url}")
            return []

        # Keep every schema unless a specific tool was requested.
        selected = {
            name: schema
            for name, schema in tool_schemas.items()
            if not specific_tool or name == specific_tool
        }

        tools = []
        for tool_name, schema in selected.items():
            try:
                tools.append(
                    MCPToolWrapper(
                        mcp_server_params=server_params,
                        tool_name=tool_name,
                        tool_schema=schema,
                        server_name=server_name,
                    )
                )
            except Exception as e:
                self._logger.log("warning", f"Failed to create MCP tool wrapper for {tool_name}: {e}")

        if specific_tool and not tools:
            self._logger.log("warning", f"Specific tool '{specific_tool}' not found on MCP server: {server_url}")

        return tools

    except Exception as e:
        self._logger.log("warning", f"Failed to connect to MCP server {server_url}: {e}")
        return []
|
||||
|
||||
def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
    """Get tools from the CrewAI AMP MCP marketplace.

    Accepts ``crewai-amp:mcp-name`` or ``crewai-amp:mcp-name#tool_name``.
    The MCP name is looked up via ``_fetch_amp_mcp_servers`` and each
    returned server URL is loaded through ``_get_external_mcp_tools``.

    Args:
        amp_ref: AMP reference, with an optional ``#tool_name`` suffix.

    Returns:
        Tools gathered from every server configured for the MCP name.

    Raises:
        KeyError: If a server configuration has no ``'url'`` entry.
    """
    # Strip only the leading scheme marker. ``str.replace`` would also remove
    # any later occurrence of the same text inside the MCP or tool name.
    amp_part = amp_ref.removeprefix('crewai-amp:')

    mcp_name, has_fragment, fragment = amp_part.partition('#')
    specific_tool = fragment if has_fragment else None

    mcp_servers = self._fetch_amp_mcp_servers(mcp_name)

    tools = []
    for server_config in mcp_servers:
        server_ref = server_config['url']
        if specific_tool:
            # Forward the tool selection to the external loader.
            server_ref += f'#{specific_tool}'
        tools.extend(self._get_external_mcp_tools(server_ref))

    return tools
|
||||
|
||||
def _extract_server_name(self, server_url: str) -> str:
    """Derive an underscore-separated server name from a URL.

    Dots in the network location and slashes in the path are replaced by
    underscores, and leading/trailing underscores are trimmed from the path
    part. The result is used as a prefix for tool names.

    Args:
        server_url: URL of the MCP server.

    Returns:
        ``<netloc>_<path>`` when the URL has a non-empty path, else ``<netloc>``.
    """
    from urllib.parse import urlparse

    url_parts = urlparse(server_url)
    host_slug = url_parts.netloc.replace('.', '_')
    path_slug = url_parts.path.replace('/', '_').strip('_')

    if not path_slug:
        return host_slug
    return f"{host_slug}_{path_slug}"
|
||||
|
||||
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
    """Return the tool schemas of an MCP server, using a short-lived cache.

    Schemas are cached in the module-level ``_mcp_schema_cache`` keyed by
    server URL and reused while younger than ``_cache_ttl`` seconds. On a
    cache miss the async discovery routine is run to completion. Any failure
    is logged as a warning and an empty dict is returned, so callers degrade
    gracefully instead of raising.

    Args:
        server_params: Connection parameters; must contain a ``"url"`` key.

    Returns:
        Mapping of tool name to schema dict, or ``{}`` if discovery failed.
    """
    server_url = server_params["url"]

    cached_entry = _mcp_schema_cache.get(server_url)
    if cached_entry is not None:
        cached_data, cache_time = cached_entry
        if time.time() - cache_time < _cache_ttl:
            self._logger.log("debug", f"Using cached MCP tool schemas for {server_url}")
            return cached_data

    try:
        try:
            asyncio.get_running_loop()
            loop_running = True
        except RuntimeError:
            loop_running = False

        discovery = self._get_mcp_tool_schemas_async(server_params)
        if loop_running:
            # asyncio.run() refuses to start inside a running event loop, so
            # drive the coroutine on a fresh loop in a worker thread. The
            # calling thread blocks until discovery finishes.
            from concurrent.futures import ThreadPoolExecutor

            with ThreadPoolExecutor(max_workers=1) as executor:
                schemas = executor.submit(asyncio.run, discovery).result()
        else:
            schemas = asyncio.run(discovery)

        # Stamp the entry when it is stored so a slow discovery does not
        # shorten the effective TTL.
        _mcp_schema_cache[server_url] = (schemas, time.time())
        return schemas

    except Exception as e:
        # Log a warning but don't raise - this allows graceful degradation.
        self._logger.log("warning", f"Failed to get MCP tool schemas from {server_url}: {e}")
        return {}
|
||||
|
||||
async def _get_mcp_tool_schemas_async(self, server_params: dict) -> dict[str, dict]:
    """Discover MCP tool schemas with a per-attempt timeout and retries.

    Each attempt runs ``_discover_mcp_tools`` under ``MCP_DISCOVERY_TIMEOUT``.
    Timeouts and errors whose message mentions a connection, network or JSON
    problem are retried with exponential backoff (1s, 2s, ...) up to
    ``MCP_MAX_RETRIES`` attempts. Other errors stop the loop immediately.

    Args:
        server_params: Connection parameters; must contain a ``"url"`` key.

    Returns:
        Mapping of tool name to schema dict.

    Raises:
        RuntimeError: If the MCP library is missing, authentication fails,
            or discovery did not succeed. The triggering exception, when
            there is one, is attached as ``__cause__``.
    """
    server_url = server_params["url"]
    last_error = None
    last_exception = None
    attempts_made = 0

    for attempt in range(MCP_MAX_RETRIES):
        attempts_made = attempt + 1
        try:
            # Wrap the entire discovery operation in a timeout.
            return await asyncio.wait_for(
                self._discover_mcp_tools(server_url),
                timeout=MCP_DISCOVERY_TIMEOUT,
            )

        except asyncio.TimeoutError as e:
            last_exception = e
            last_error = f"MCP discovery timed out after {MCP_DISCOVERY_TIMEOUT} seconds"
            transient = True

        except ImportError as e:
            raise RuntimeError(
                "MCP library not available. Please install with: pip install mcp"
            ) from e

        except Exception as e:
            last_exception = e
            error_str = str(e).lower()
            network_issue = 'connection' in error_str or 'network' in error_str

            # Classify by message text; the order of these checks matters.
            if network_issue:
                last_error = f"Network connection failed: {str(e)}"
            elif 'authentication' in error_str or 'unauthorized' in error_str:
                raise RuntimeError(f"Authentication failed for MCP server: {str(e)}") from e
            elif 'json' in error_str or 'parsing' in error_str:
                last_error = f"Server response parsing error: {str(e)}"
            else:
                last_error = f"MCP discovery error: {str(e)}"

            # Only connection/network/JSON problems are worth retrying.
            transient = network_issue or 'json' in error_str

        if not transient or attempt >= MCP_MAX_RETRIES - 1:
            break
        await asyncio.sleep(2 ** attempt)  # Exponential backoff

    # Report the number of attempts actually made, not the retry budget.
    raise RuntimeError(
        f"Failed to discover MCP tools after {attempts_made} attempt(s): {last_error}"
    ) from last_exception
|
||||
|
||||
async def _discover_mcp_tools(self, server_url: str) -> dict[str, dict]:
    """Connect to an MCP server and list the tools it offers.

    Opens a client session via ``streamablehttp_client``, initialises it
    within ``MCP_CONNECTION_TIMEOUT`` seconds, then lists tools within the
    remainder of ``MCP_DISCOVERY_TIMEOUT``.

    Args:
        server_url: URL of the MCP server.

    Returns:
        Mapping of tool name to ``{'description': str, 'args_schema': None}``.

    Raises:
        ImportError: If the ``mcp`` package is not installed.
        asyncio.TimeoutError: If initialisation or listing times out.
    """
    from mcp import ClientSession
    from mcp.client.streamable_http import streamablehttp_client

    async with streamablehttp_client(server_url) as (read, write, _):
        async with ClientSession(read, write) as session:
            # Initialize the connection with timeout
            await asyncio.wait_for(
                session.initialize(),
                timeout=MCP_CONNECTION_TIMEOUT,
            )

            # List available tools within what is left of the discovery budget
            tools_result = await asyncio.wait_for(
                session.list_tools(),
                timeout=MCP_DISCOVERY_TIMEOUT - MCP_CONNECTION_TIMEOUT,
            )

            return {
                tool.name: {
                    # The attribute may exist but be None; the getattr default
                    # only covers a missing attribute, so normalise both to ''.
                    'description': getattr(tool, 'description', None) or '',
                    'args_schema': None,  # Keep simple for now
                }
                for tool in tools_result.tools
            }
|
||||
|
||||
def _fetch_amp_mcp_servers(self, mcp_name: str) -> list[dict]:
    """Look up the server configurations for an AMP marketplace MCP.

    Placeholder: the AMP lookup is not implemented yet, so no servers are
    ever returned regardless of ``mcp_name``.

    Args:
        mcp_name: Name of the MCP in the AMP marketplace.

    Returns:
        Currently always an empty list.
    """
    # TODO: Implement AMP API call to "integrations/mcps" endpoint.
    # It should return a list of server configs, each carrying a URL.
    return []
|
||||
|
||||
def get_multimodal_tools(self) -> Sequence[BaseTool]:
|
||||
from crewai.tools.agent_tools.add_image_tool import AddImageTool
|
||||
|
||||
|
||||
@@ -190,6 +190,10 @@ class BaseAgent(ABC, BaseModel):
|
||||
default=None,
|
||||
description="List of applications or application/action combinations that the agent can access through CrewAI Platform. Can contain app names (e.g., 'gmail') or specific actions (e.g., 'gmail/send_email')",
|
||||
)
|
||||
# MCP server references for this agent; None means no MCP tools.
mcps: list[str] | None = Field(
    default=None,
    description=(
        "List of MCP server references. "
        "Supports 'https://server.com/path' for external servers "
        "and 'crewai-amp:mcp-name' for AMP marketplace. "
        "Use '#tool_name' suffix for specific tools."
    ),
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -243,6 +247,21 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
return list(set(validated_apps))
|
||||
|
||||
@field_validator("mcps")
@classmethod
def validate_mcps(cls, mcps: list[str] | None) -> list[str] | None:
    """Validate and de-duplicate the agent's MCP server references.

    Args:
        mcps: Candidate references, or ``None``.

    Returns:
        The input unchanged when it is ``None`` or empty; otherwise the
        references with duplicates removed, keeping first-seen order.

    Raises:
        ValueError: If a reference starts with neither ``https://`` nor
            ``crewai-amp:``.
    """
    if not mcps:
        return mcps

    for mcp in mcps:
        if not mcp.startswith(('https://', 'crewai-amp:')):
            raise ValueError(f"Invalid MCP reference: {mcp}. Must start with 'https://' or 'crewai-amp:'")

    # dict.fromkeys de-duplicates while keeping insertion order. A set would
    # make the order of references, and so of the loaded tools, depend on
    # string hashing and differ between runs.
    return list(dict.fromkeys(mcps))
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_and_set_attributes(self):
|
||||
# Validate required fields
|
||||
@@ -317,6 +336,10 @@ class BaseAgent(ABC, BaseModel):
|
||||
def get_platform_tools(self, apps: list[PlatformAppOrAction]) -> list[BaseTool]:
|
||||
"""Get platform tools for the specified list of applications and/or application/action combinations."""
|
||||
|
||||
@abstractmethod
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
    """Return the tools provided by the given MCP server references.

    Args:
        mcps: MCP server references to resolve into tools.

    Returns:
        The tools loaded for those references.
    """
|
||||
|
||||
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
|
||||
"""Create a deep copy of the Agent."""
|
||||
exclude = {
|
||||
@@ -334,6 +357,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
"knowledge_storage",
|
||||
"knowledge",
|
||||
"apps",
|
||||
"mcps",
|
||||
"actions",
|
||||
}
|
||||
|
||||
|
||||
@@ -990,6 +990,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if agent and (hasattr(agent, "apps") and getattr(agent, "apps", None)):
|
||||
tools = self._add_platform_tools(task, tools)
|
||||
|
||||
if agent and (hasattr(agent, "mcps") and getattr(agent, "mcps", None)):
|
||||
tools = self._add_mcp_tools(task, tools)
|
||||
|
||||
# Return a list[BaseTool] compatible with Task.execute_sync and execute_async
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
@@ -1042,6 +1045,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self._merge_tools(tools, cast(list[BaseTool], platform_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _inject_mcp_tools(
    self,
    tools: list[Tool] | list[BaseTool],
    task_agent: BaseAgent,
) -> list[BaseTool]:
    """Merge the agent's MCP tools into ``tools``.

    Args:
        tools: Tools already selected for the task.
        task_agent: Agent whose ``mcps`` references should be resolved.

    Returns:
        ``tools`` merged with the agent's MCP tools via ``_merge_tools``, or
        ``tools`` unchanged when the agent has no MCP references or no
        ``get_mcp_tools`` method.
    """
    references = getattr(task_agent, "mcps", None) or []
    can_resolve = hasattr(task_agent, "get_mcp_tools")

    if not (can_resolve and references):
        return cast(list[BaseTool], tools)

    resolved = task_agent.get_mcp_tools(mcps=references)
    return self._merge_tools(tools, cast(list[BaseTool], resolved))
|
||||
|
||||
def _add_multimodal_tools(
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
@@ -1080,6 +1095,14 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return cast(list[BaseTool], tools or [])
|
||||
|
||||
def _add_mcp_tools(
    self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
    """Add MCP tools from the task's agent, if the task has one.

    Args:
        task: Task whose agent may declare MCP references.
        tools: Tools already selected for the task.

    Returns:
        The possibly extended tool list; an empty list when it is falsy.
    """
    agent = task.agent
    merged = self._inject_mcp_tools(tools, agent) if agent else tools
    return cast(list[BaseTool], merged or [])
|
||||
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
if self.output_log_file:
|
||||
self._file_handler.log(
|
||||
|
||||
171
lib/crewai/src/crewai/tools/mcp_tool_wrapper.py
Normal file
171
lib/crewai/src/crewai/tools/mcp_tool_wrapper.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""MCP Tool Wrapper for on-demand MCP server connections."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
# Timeouts (in seconds) and retry budget for on-demand MCP tool calls.
MCP_CONNECTION_TIMEOUT = 10  # session initialisation
MCP_TOOL_EXECUTION_TIMEOUT = 30  # one full execution attempt
MCP_DISCOVERY_TIMEOUT = 15
MCP_MAX_RETRIES = 3  # attempts per tool call
|
||||
|
||||
|
||||
class MCPToolWrapper(BaseTool):
    """Lightweight wrapper for an MCP tool that connects on demand.

    No connection is kept open: every call opens a new session to the MCP
    server, invokes the tool and closes the session again. Failures are
    returned as human-readable strings rather than raised.
    """

    def __init__(
        self,
        mcp_server_params: dict,
        tool_name: str,
        tool_schema: dict,
        server_name: str,
    ):
        """Initialize the MCP tool wrapper.

        Args:
            mcp_server_params: Parameters for connecting to the MCP server;
                the ``"url"`` entry is read when the tool is executed.
            tool_name: Original name of the tool on the MCP server.
            tool_schema: Schema information for the tool; the optional keys
                ``"description"`` and ``"args_schema"`` are read.
            server_name: Name of the MCP server, used to prefix the tool name.
        """
        # Prefix the tool name with the server name to avoid conflicts.
        prefixed_name = f"{server_name}_{tool_name}"

        kwargs = {
            "name": prefixed_name,
            # Fall back when the key is missing AND when it is present but
            # None or empty, so the tool always has a usable description.
            "description": tool_schema.get("description")
            or f"Tool {tool_name} from {server_name}",
        }

        # Only pass args_schema if it's provided, otherwise the base class
        # default applies.
        args_schema = tool_schema.get("args_schema")
        if args_schema is not None:
            kwargs["args_schema"] = args_schema

        super().__init__(**kwargs)

        # Set instance attributes after super().__init__
        self._mcp_server_params = mcp_server_params
        self._original_tool_name = tool_name
        self._server_name = server_name

    @property
    def mcp_server_params(self) -> dict:
        """Get the MCP server parameters."""
        return self._mcp_server_params

    @property
    def original_tool_name(self) -> str:
        """Get the original (unprefixed) tool name."""
        return self._original_tool_name

    @property
    def server_name(self) -> str:
        """Get the server name."""
        return self._server_name

    def _run(self, **kwargs) -> str:
        """Connect to the MCP server and execute the tool synchronously.

        Args:
            **kwargs: Arguments to pass to the MCP tool.

        Returns:
            Result from the MCP tool execution, or an error description.
        """
        try:
            try:
                asyncio.get_running_loop()
                loop_running = True
            except RuntimeError:
                loop_running = False

            execution = self._run_async(**kwargs)
            if loop_running:
                # asyncio.run() cannot be used inside a running event loop;
                # run the coroutine on its own loop in a worker thread and
                # block until it completes.
                from concurrent.futures import ThreadPoolExecutor

                with ThreadPoolExecutor(max_workers=1) as executor:
                    return executor.submit(asyncio.run, execution).result()
            return asyncio.run(execution)
        except asyncio.TimeoutError:
            return f"MCP tool '{self.original_tool_name}' timed out after {MCP_TOOL_EXECUTION_TIMEOUT} seconds"
        except Exception as e:
            return f"Error executing MCP tool {self.original_tool_name}: {str(e)}"

    async def _run_async(self, **kwargs) -> str:
        """Execute the tool with a per-attempt timeout and retries.

        Timeouts and errors whose message mentions a connection, network or
        JSON problem are retried with exponential backoff up to
        ``MCP_MAX_RETRIES`` attempts. NOTE(review): a retry re-invokes the
        remote tool, so a tool with side effects may run more than once.

        Args:
            **kwargs: Arguments to pass to the MCP tool.

        Returns:
            The tool result, or a message describing why execution failed.
        """
        last_error = None
        attempts_made = 0

        for attempt in range(MCP_MAX_RETRIES):
            attempts_made = attempt + 1
            try:
                return await asyncio.wait_for(
                    self._execute_tool(**kwargs),
                    timeout=MCP_TOOL_EXECUTION_TIMEOUT,
                )

            except asyncio.TimeoutError:
                last_error = f"Connection timed out after {MCP_TOOL_EXECUTION_TIMEOUT} seconds"
                transient = True

            except ImportError:
                return "MCP library not available. Please install with: pip install mcp"

            except Exception as e:
                error_str = str(e).lower()
                network_issue = 'connection' in error_str or 'network' in error_str

                # Classify by message text; the order of these checks matters.
                if network_issue:
                    last_error = f"Network connection failed: {str(e)}"
                elif 'authentication' in error_str or 'unauthorized' in error_str:
                    return f"Authentication failed for MCP server: {str(e)}"
                elif 'json' in error_str or 'parsing' in error_str:
                    last_error = f"Server response parsing error: {str(e)}"
                elif 'not found' in error_str:
                    return f"Tool '{self.original_tool_name}' not found on MCP server"
                else:
                    last_error = f"MCP execution error: {str(e)}"

                # Only connection/network/JSON problems are worth retrying.
                transient = network_issue or 'json' in error_str

            if not transient or attempt >= MCP_MAX_RETRIES - 1:
                break
            await asyncio.sleep(2 ** attempt)  # Exponential backoff

        # Report the number of attempts actually made, not the retry budget.
        return f"MCP tool execution failed after {attempts_made} attempt(s): {last_error}"

    async def _execute_tool(self, **kwargs) -> str:
        """Open a session to the MCP server and call the tool once.

        Args:
            **kwargs: Arguments forwarded to the tool as its argument mapping.

        Returns:
            The text of every content item joined by newlines (items without
            a ``text`` attribute are stringified), or ``str(result)`` when
            the result carries no content.
        """
        from mcp import ClientSession
        from mcp.client.streamable_http import streamablehttp_client

        server_url = self.mcp_server_params["url"]

        async with streamablehttp_client(server_url) as (read, write, _):
            async with ClientSession(read, write) as session:
                # Initialize the connection with timeout
                await asyncio.wait_for(
                    session.initialize(),
                    timeout=MCP_CONNECTION_TIMEOUT,
                )

                # Call the tool within what is left of the execution budget
                result = await asyncio.wait_for(
                    session.call_tool(self.original_tool_name, kwargs),
                    timeout=MCP_TOOL_EXECUTION_TIMEOUT - MCP_CONNECTION_TIMEOUT,
                )

                content = getattr(result, 'content', None)
                if not content:
                    return str(result)
                if not isinstance(content, list):
                    return str(content)

                # Keep every content item instead of only the first, so
                # multi-part results are not silently truncated.
                return "\n".join(
                    str(item.text) if hasattr(item, 'text') else str(item)
                    for item in content
                )
|
||||
Reference in New Issue
Block a user