mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Lorenze/feat mcp first class support (#3850)
* WIP transport support mcp * refactor: streamline MCP tool loading and error handling * linted * Self type from typing with typing_extensions in MCP transport modules * added tests for mcp setup * added tests for mcp setup * docs: enhance MCP overview with detailed integration examples and structured configurations * feat: implement MCP event handling and logging in event listener and client - Added MCP event types and handlers for connection and tool execution events. - Enhanced MCPClient to emit events on connection status and tool execution. - Updated ConsoleFormatter to handle MCP event logging. - Introduced new MCP event types for better integration and monitoring.
This commit is contained in:
@@ -11,9 +11,13 @@ The [Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP)
|
|||||||
|
|
||||||
CrewAI offers **two approaches** for MCP integration:
|
CrewAI offers **two approaches** for MCP integration:
|
||||||
|
|
||||||
### Simple DSL Integration** (Recommended)
|
### 🚀 **Simple DSL Integration** (Recommended)
|
||||||
|
|
||||||
Use the `mcps` field directly on agents for seamless MCP tool integration:
|
Use the `mcps` field directly on agents for seamless MCP tool integration. The DSL supports both **string references** (for quick setup) and **structured configurations** (for full control).
|
||||||
|
|
||||||
|
#### String-Based References (Quick Setup)
|
||||||
|
|
||||||
|
Perfect for remote HTTPS servers and CrewAI AMP marketplace:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from crewai import Agent
|
from crewai import Agent
|
||||||
@@ -32,6 +36,46 @@ agent = Agent(
|
|||||||
# MCP tools are now automatically available to your agent!
|
# MCP tools are now automatically available to your agent!
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Structured Configurations (Full Control)
|
||||||
|
|
||||||
|
For complete control over connection settings, tool filtering, and all transport types:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent
|
||||||
|
from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE
|
||||||
|
from crewai.mcp.filters import create_static_tool_filter
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Advanced Research Analyst",
|
||||||
|
goal="Research with full control over MCP connections",
|
||||||
|
backstory="Expert researcher with advanced tool access",
|
||||||
|
mcps=[
|
||||||
|
# Stdio transport for local servers
|
||||||
|
MCPServerStdio(
|
||||||
|
command="npx",
|
||||||
|
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||||
|
env={"API_KEY": "your_key"},
|
||||||
|
tool_filter=create_static_tool_filter(
|
||||||
|
allowed_tool_names=["read_file", "list_directory"]
|
||||||
|
),
|
||||||
|
cache_tools_list=True,
|
||||||
|
),
|
||||||
|
# HTTP/Streamable HTTP transport for remote servers
|
||||||
|
MCPServerHTTP(
|
||||||
|
url="https://api.example.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer your_token"},
|
||||||
|
streamable=True,
|
||||||
|
cache_tools_list=True,
|
||||||
|
),
|
||||||
|
# SSE transport for real-time streaming
|
||||||
|
MCPServerSSE(
|
||||||
|
url="https://stream.example.com/mcp/sse",
|
||||||
|
headers={"Authorization": "Bearer your_token"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
### 🔧 **Advanced: MCPServerAdapter** (For Complex Scenarios)
|
### 🔧 **Advanced: MCPServerAdapter** (For Complex Scenarios)
|
||||||
|
|
||||||
For advanced use cases requiring manual connection management, the `crewai-tools` library provides the `MCPServerAdapter` class.
|
For advanced use cases requiring manual connection management, the `crewai-tools` library provides the `MCPServerAdapter` class.
|
||||||
@@ -68,12 +112,14 @@ uv pip install 'crewai-tools[mcp]'
|
|||||||
|
|
||||||
## Quick Start: Simple DSL Integration
|
## Quick Start: Simple DSL Integration
|
||||||
|
|
||||||
The easiest way to integrate MCP servers is using the `mcps` field on your agents:
|
The easiest way to integrate MCP servers is using the `mcps` field on your agents. You can use either string references or structured configurations.
|
||||||
|
|
||||||
|
### Quick Start with String References
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from crewai import Agent, Task, Crew
|
from crewai import Agent, Task, Crew
|
||||||
|
|
||||||
# Create agent with MCP tools
|
# Create agent with MCP tools using string references
|
||||||
research_agent = Agent(
|
research_agent = Agent(
|
||||||
role="Research Analyst",
|
role="Research Analyst",
|
||||||
goal="Find and analyze information using advanced search tools",
|
goal="Find and analyze information using advanced search tools",
|
||||||
@@ -96,13 +142,53 @@ crew = Crew(agents=[research_agent], tasks=[research_task])
|
|||||||
result = crew.kickoff()
|
result = crew.kickoff()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Quick Start with Structured Configurations
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import Agent, Task, Crew
|
||||||
|
from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE
|
||||||
|
|
||||||
|
# Create agent with structured MCP configurations
|
||||||
|
research_agent = Agent(
|
||||||
|
role="Research Analyst",
|
||||||
|
goal="Find and analyze information using advanced search tools",
|
||||||
|
backstory="Expert researcher with access to multiple data sources",
|
||||||
|
mcps=[
|
||||||
|
# Local stdio server
|
||||||
|
MCPServerStdio(
|
||||||
|
command="python",
|
||||||
|
args=["local_server.py"],
|
||||||
|
env={"API_KEY": "your_key"},
|
||||||
|
),
|
||||||
|
# Remote HTTP server
|
||||||
|
MCPServerHTTP(
|
||||||
|
url="https://api.research.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer your_token"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create task
|
||||||
|
research_task = Task(
|
||||||
|
description="Research the latest developments in AI agent frameworks",
|
||||||
|
expected_output="Comprehensive research report with citations",
|
||||||
|
agent=research_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and run crew
|
||||||
|
crew = Crew(agents=[research_agent], tasks=[research_task])
|
||||||
|
result = crew.kickoff()
|
||||||
|
```
|
||||||
|
|
||||||
That's it! The MCP tools are automatically discovered and available to your agent.
|
That's it! The MCP tools are automatically discovered and available to your agent.
|
||||||
|
|
||||||
## MCP Reference Formats
|
## MCP Reference Formats
|
||||||
|
|
||||||
The `mcps` field supports various reference formats for maximum flexibility:
|
The `mcps` field supports both **string references** (for quick setup) and **structured configurations** (for full control). You can mix both formats in the same list.
|
||||||
|
|
||||||
### External MCP Servers
|
### String-Based References
|
||||||
|
|
||||||
|
#### External MCP Servers
|
||||||
|
|
||||||
```python
|
```python
|
||||||
mcps=[
|
mcps=[
|
||||||
@@ -117,7 +203,7 @@ mcps=[
|
|||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
### CrewAI AMP Marketplace
|
#### CrewAI AMP Marketplace
|
||||||
|
|
||||||
```python
|
```python
|
||||||
mcps=[
|
mcps=[
|
||||||
@@ -133,17 +219,166 @@ mcps=[
|
|||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Mixed References
|
### Structured Configurations
|
||||||
|
|
||||||
|
#### Stdio Transport (Local Servers)
|
||||||
|
|
||||||
|
Perfect for local MCP servers that run as processes:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
from crewai.mcp import MCPServerStdio
|
||||||
|
from crewai.mcp.filters import create_static_tool_filter
|
||||||
|
|
||||||
mcps=[
|
mcps=[
|
||||||
"https://external-api.com/mcp", # External server
|
MCPServerStdio(
|
||||||
"https://weather.service.com/mcp#forecast", # Specific external tool
|
command="npx",
|
||||||
"crewai-amp:financial-insights", # AMP service
|
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||||
"crewai-amp:data-analysis#sentiment_tool" # Specific AMP tool
|
env={"API_KEY": "your_key"},
|
||||||
|
tool_filter=create_static_tool_filter(
|
||||||
|
allowed_tool_names=["read_file", "write_file"]
|
||||||
|
),
|
||||||
|
cache_tools_list=True,
|
||||||
|
),
|
||||||
|
# Python-based server
|
||||||
|
MCPServerStdio(
|
||||||
|
command="python",
|
||||||
|
args=["path/to/server.py"],
|
||||||
|
env={"UV_PYTHON": "3.12", "API_KEY": "your_key"},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### HTTP/Streamable HTTP Transport (Remote Servers)
|
||||||
|
|
||||||
|
For remote MCP servers over HTTP/HTTPS:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai.mcp import MCPServerHTTP
|
||||||
|
|
||||||
|
mcps=[
|
||||||
|
# Streamable HTTP (default)
|
||||||
|
MCPServerHTTP(
|
||||||
|
url="https://api.example.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer your_token"},
|
||||||
|
streamable=True,
|
||||||
|
cache_tools_list=True,
|
||||||
|
),
|
||||||
|
# Standard HTTP
|
||||||
|
MCPServerHTTP(
|
||||||
|
url="https://api.example.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer your_token"},
|
||||||
|
streamable=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### SSE Transport (Real-Time Streaming)
|
||||||
|
|
||||||
|
For remote servers using Server-Sent Events:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai.mcp import MCPServerSSE
|
||||||
|
|
||||||
|
mcps=[
|
||||||
|
MCPServerSSE(
|
||||||
|
url="https://stream.example.com/mcp/sse",
|
||||||
|
headers={"Authorization": "Bearer your_token"},
|
||||||
|
cache_tools_list=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Mixed References
|
||||||
|
|
||||||
|
You can combine string references and structured configurations:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai.mcp import MCPServerStdio, MCPServerHTTP
|
||||||
|
|
||||||
|
mcps=[
|
||||||
|
# String references
|
||||||
|
"https://external-api.com/mcp", # External server
|
||||||
|
"crewai-amp:financial-insights", # AMP service
|
||||||
|
|
||||||
|
# Structured configurations
|
||||||
|
MCPServerStdio(
|
||||||
|
command="npx",
|
||||||
|
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||||
|
),
|
||||||
|
MCPServerHTTP(
|
||||||
|
url="https://api.example.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer token"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Filtering
|
||||||
|
|
||||||
|
Structured configurations support advanced tool filtering:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai.mcp import MCPServerStdio
|
||||||
|
from crewai.mcp.filters import create_static_tool_filter, create_dynamic_tool_filter, ToolFilterContext
|
||||||
|
|
||||||
|
# Static filtering (allow/block lists)
|
||||||
|
static_filter = create_static_tool_filter(
|
||||||
|
allowed_tool_names=["read_file", "write_file"],
|
||||||
|
blocked_tool_names=["delete_file"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dynamic filtering (context-aware)
|
||||||
|
def dynamic_filter(context: ToolFilterContext, tool: dict) -> bool:
|
||||||
|
# Block dangerous tools for certain agent roles
|
||||||
|
if context.agent.role == "Code Reviewer":
|
||||||
|
if "delete" in tool.get("name", "").lower():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
mcps=[
|
||||||
|
MCPServerStdio(
|
||||||
|
command="npx",
|
||||||
|
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||||
|
tool_filter=static_filter, # or dynamic_filter
|
||||||
|
),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Parameters
|
||||||
|
|
||||||
|
Each transport type supports specific configuration options:
|
||||||
|
|
||||||
|
### MCPServerStdio Parameters
|
||||||
|
|
||||||
|
- **`command`** (required): Command to execute (e.g., `"python"`, `"node"`, `"npx"`, `"uvx"`)
|
||||||
|
- **`args`** (optional): List of command arguments (e.g., `["server.py"]` or `["-y", "@mcp/server"]`)
|
||||||
|
- **`env`** (optional): Dictionary of environment variables to pass to the process
|
||||||
|
- **`tool_filter`** (optional): Tool filter function for filtering available tools
|
||||||
|
- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`)
|
||||||
|
|
||||||
|
### MCPServerHTTP Parameters
|
||||||
|
|
||||||
|
- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp"`)
|
||||||
|
- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes
|
||||||
|
- **`streamable`** (optional): Whether to use streamable HTTP transport (default: `True`)
|
||||||
|
- **`tool_filter`** (optional): Tool filter function for filtering available tools
|
||||||
|
- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`)
|
||||||
|
|
||||||
|
### MCPServerSSE Parameters
|
||||||
|
|
||||||
|
- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp/sse"`)
|
||||||
|
- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes
|
||||||
|
- **`tool_filter`** (optional): Tool filter function for filtering available tools
|
||||||
|
- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`)
|
||||||
|
|
||||||
|
### Common Parameters
|
||||||
|
|
||||||
|
All transport types support:
|
||||||
|
- **`tool_filter`**: Filter function to control which tools are available. Can be:
|
||||||
|
- `None` (default): All tools are available
|
||||||
|
- Static filter: Created with `create_static_tool_filter()` for allow/block lists
|
||||||
|
- Dynamic filter: Created with `create_dynamic_tool_filter()` for context-aware filtering
|
||||||
|
- **`cache_tools_list`**: When `True`, caches the tool list after first discovery to improve performance on subsequent connections
|
||||||
|
|
||||||
## Key Features
|
## Key Features
|
||||||
|
|
||||||
- 🔄 **Automatic Tool Discovery**: Tools are automatically discovered and integrated
|
- 🔄 **Automatic Tool Discovery**: Tools are automatically discovered and integrated
|
||||||
@@ -152,26 +387,47 @@ mcps=[
|
|||||||
- 🛡️ **Error Resilience**: Graceful handling of unavailable servers
|
- 🛡️ **Error Resilience**: Graceful handling of unavailable servers
|
||||||
- ⏱️ **Timeout Protection**: Built-in timeouts prevent hanging connections
|
- ⏱️ **Timeout Protection**: Built-in timeouts prevent hanging connections
|
||||||
- 📊 **Transparent Integration**: Works seamlessly with existing CrewAI features
|
- 📊 **Transparent Integration**: Works seamlessly with existing CrewAI features
|
||||||
|
- 🔧 **Full Transport Support**: Stdio, HTTP/Streamable HTTP, and SSE transports
|
||||||
|
- 🎯 **Advanced Filtering**: Static and dynamic tool filtering capabilities
|
||||||
|
- 🔐 **Flexible Authentication**: Support for headers, environment variables, and query parameters
|
||||||
|
|
||||||
## Error Handling
|
## Error Handling
|
||||||
|
|
||||||
The MCP DSL integration is designed to be resilient:
|
The MCP DSL integration is designed to be resilient and handles failures gracefully:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
from crewai import Agent
|
||||||
|
from crewai.mcp import MCPServerStdio, MCPServerHTTP
|
||||||
|
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
role="Resilient Agent",
|
role="Resilient Agent",
|
||||||
goal="Continue working despite server issues",
|
goal="Continue working despite server issues",
|
||||||
backstory="Agent that handles failures gracefully",
|
backstory="Agent that handles failures gracefully",
|
||||||
mcps=[
|
mcps=[
|
||||||
|
# String references
|
||||||
"https://reliable-server.com/mcp", # Will work
|
"https://reliable-server.com/mcp", # Will work
|
||||||
"https://unreachable-server.com/mcp", # Will be skipped gracefully
|
"https://unreachable-server.com/mcp", # Will be skipped gracefully
|
||||||
"https://slow-server.com/mcp", # Will timeout gracefully
|
"crewai-amp:working-service", # Will work
|
||||||
"crewai-amp:working-service" # Will work
|
|
||||||
|
# Structured configs
|
||||||
|
MCPServerStdio(
|
||||||
|
command="python",
|
||||||
|
args=["reliable_server.py"], # Will work
|
||||||
|
),
|
||||||
|
MCPServerHTTP(
|
||||||
|
url="https://slow-server.com/mcp", # Will timeout gracefully
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
# Agent will use tools from working servers and log warnings for failing ones
|
# Agent will use tools from working servers and log warnings for failing ones
|
||||||
```
|
```
|
||||||
|
|
||||||
|
All connection errors are handled gracefully:
|
||||||
|
- **Connection failures**: Logged as warnings, agent continues with available tools
|
||||||
|
- **Timeout errors**: Connections timeout after 30 seconds (configurable)
|
||||||
|
- **Authentication errors**: Logged clearly for debugging
|
||||||
|
- **Invalid configurations**: Validation errors are raised at agent creation time
|
||||||
|
|
||||||
## Advanced: MCPServerAdapter
|
## Advanced: MCPServerAdapter
|
||||||
|
|
||||||
For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server.
|
For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server.
|
||||||
|
|||||||
@@ -40,6 +40,16 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
|||||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||||
from crewai.lite_agent import LiteAgent
|
from crewai.lite_agent import LiteAgent
|
||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
|
from crewai.mcp import (
|
||||||
|
MCPClient,
|
||||||
|
MCPServerConfig,
|
||||||
|
MCPServerHTTP,
|
||||||
|
MCPServerSSE,
|
||||||
|
MCPServerStdio,
|
||||||
|
)
|
||||||
|
from crewai.mcp.transports.http import HTTPTransport
|
||||||
|
from crewai.mcp.transports.sse import SSETransport
|
||||||
|
from crewai.mcp.transports.stdio import StdioTransport
|
||||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||||
from crewai.rag.embeddings.types import EmbedderConfig
|
from crewai.rag.embeddings.types import EmbedderConfig
|
||||||
from crewai.security.fingerprint import Fingerprint
|
from crewai.security.fingerprint import Fingerprint
|
||||||
@@ -108,6 +118,7 @@ class Agent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_times_executed: int = PrivateAttr(default=0)
|
_times_executed: int = PrivateAttr(default=0)
|
||||||
|
_mcp_clients: list[Any] = PrivateAttr(default_factory=list)
|
||||||
max_execution_time: int | None = Field(
|
max_execution_time: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Maximum execution time for an agent to execute a task",
|
description="Maximum execution time for an agent to execute a task",
|
||||||
@@ -526,6 +537,9 @@ class Agent(BaseAgent):
|
|||||||
self,
|
self,
|
||||||
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._cleanup_mcp_clients()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any:
|
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any:
|
||||||
@@ -649,30 +663,70 @@ class Agent(BaseAgent):
|
|||||||
self._logger.log("error", f"Error getting platform tools: {e!s}")
|
self._logger.log("error", f"Error getting platform tools: {e!s}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
|
def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
|
||||||
"""Convert MCP server references to CrewAI tools."""
|
"""Convert MCP server references/configs to CrewAI tools.
|
||||||
|
|
||||||
|
Supports both string references (backwards compatible) and structured
|
||||||
|
configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcps: List of MCP server references (strings) or configurations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BaseTool instances from MCP servers.
|
||||||
|
"""
|
||||||
all_tools = []
|
all_tools = []
|
||||||
|
clients = []
|
||||||
|
|
||||||
for mcp_ref in mcps:
|
for mcp_config in mcps:
|
||||||
try:
|
if isinstance(mcp_config, str):
|
||||||
if mcp_ref.startswith("crewai-amp:"):
|
tools = self._get_mcp_tools_from_string(mcp_config)
|
||||||
tools = self._get_amp_mcp_tools(mcp_ref)
|
else:
|
||||||
elif mcp_ref.startswith("https://"):
|
tools, client = self._get_native_mcp_tools(mcp_config)
|
||||||
tools = self._get_external_mcp_tools(mcp_ref)
|
if client:
|
||||||
else:
|
clients.append(client)
|
||||||
continue
|
|
||||||
|
|
||||||
all_tools.extend(tools)
|
all_tools.extend(tools)
|
||||||
self._logger.log(
|
|
||||||
"info", f"Successfully loaded {len(tools)} tools from {mcp_ref}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.log("warning", f"Skipping MCP {mcp_ref} due to error: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
# Store clients for cleanup
|
||||||
|
self._mcp_clients.extend(clients)
|
||||||
return all_tools
|
return all_tools
|
||||||
|
|
||||||
|
def _cleanup_mcp_clients(self) -> None:
|
||||||
|
"""Cleanup MCP client connections after task execution."""
|
||||||
|
if not self._mcp_clients:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _disconnect_all() -> None:
|
||||||
|
for client in self._mcp_clients:
|
||||||
|
if client and hasattr(client, "connected") and client.connected:
|
||||||
|
await client.disconnect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(_disconnect_all())
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.log("error", f"Error during MCP client cleanup: {e}")
|
||||||
|
finally:
|
||||||
|
self._mcp_clients.clear()
|
||||||
|
|
||||||
|
def _get_mcp_tools_from_string(self, mcp_ref: str) -> list[BaseTool]:
|
||||||
|
"""Get tools from legacy string-based MCP references.
|
||||||
|
|
||||||
|
This method maintains backwards compatibility with string-based
|
||||||
|
MCP references (https://... and crewai-amp:...).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_ref: String reference to MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BaseTool instances.
|
||||||
|
"""
|
||||||
|
if mcp_ref.startswith("crewai-amp:"):
|
||||||
|
return self._get_amp_mcp_tools(mcp_ref)
|
||||||
|
if mcp_ref.startswith("https://"):
|
||||||
|
return self._get_external_mcp_tools(mcp_ref)
|
||||||
|
return []
|
||||||
|
|
||||||
def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
|
def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
|
||||||
"""Get tools from external HTTPS MCP server with graceful error handling."""
|
"""Get tools from external HTTPS MCP server with graceful error handling."""
|
||||||
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
|
||||||
@@ -731,6 +785,154 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def _get_native_mcp_tools(
|
||||||
|
self, mcp_config: MCPServerConfig
|
||||||
|
) -> tuple[list[BaseTool], Any | None]:
|
||||||
|
"""Get tools from MCP server using structured configuration.
|
||||||
|
|
||||||
|
This method creates an MCP client based on the configuration type,
|
||||||
|
connects to the server, discovers tools, applies filtering, and
|
||||||
|
returns wrapped tools along with the client instance for cleanup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_config: MCP server configuration (MCPServerStdio, MCPServerHTTP, or MCPServerSSE).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (list of BaseTool instances, MCPClient instance for cleanup).
|
||||||
|
"""
|
||||||
|
from crewai.tools.base_tool import BaseTool
|
||||||
|
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||||
|
|
||||||
|
if isinstance(mcp_config, MCPServerStdio):
|
||||||
|
transport = StdioTransport(
|
||||||
|
command=mcp_config.command,
|
||||||
|
args=mcp_config.args,
|
||||||
|
env=mcp_config.env,
|
||||||
|
)
|
||||||
|
server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}"
|
||||||
|
elif isinstance(mcp_config, MCPServerHTTP):
|
||||||
|
transport = HTTPTransport(
|
||||||
|
url=mcp_config.url,
|
||||||
|
headers=mcp_config.headers,
|
||||||
|
streamable=mcp_config.streamable,
|
||||||
|
)
|
||||||
|
server_name = self._extract_server_name(mcp_config.url)
|
||||||
|
elif isinstance(mcp_config, MCPServerSSE):
|
||||||
|
transport = SSETransport(
|
||||||
|
url=mcp_config.url,
|
||||||
|
headers=mcp_config.headers,
|
||||||
|
)
|
||||||
|
server_name = self._extract_server_name(mcp_config.url)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")
|
||||||
|
|
||||||
|
client = MCPClient(
|
||||||
|
transport=transport,
|
||||||
|
cache_tools_list=mcp_config.cache_tools_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
|
||||||
|
"""Async helper to connect and list tools in same event loop."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not client.connected:
|
||||||
|
await client.connect()
|
||||||
|
|
||||||
|
tools_list = await client.list_tools()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.disconnect()
|
||||||
|
# Small delay to allow background tasks to finish cleanup
|
||||||
|
# This helps prevent "cancel scope in different task" errors
|
||||||
|
# when asyncio.run() closes the event loop
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.log("error", f"Error during disconnect: {e}")
|
||||||
|
|
||||||
|
return tools_list
|
||||||
|
except Exception as e:
|
||||||
|
if client.connected:
|
||||||
|
await client.disconnect()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Error during setup client and list tools: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
tools_list = asyncio.run(_setup_client_and_list_tools())
|
||||||
|
except RuntimeError as e:
|
||||||
|
error_msg = str(e).lower()
|
||||||
|
if "cancel scope" in error_msg or "task" in error_msg:
|
||||||
|
raise ConnectionError(
|
||||||
|
"MCP connection failed due to event loop cleanup issues. "
|
||||||
|
"This may be due to authentication errors or server unavailability."
|
||||||
|
) from e
|
||||||
|
except asyncio.CancelledError as e:
|
||||||
|
raise ConnectionError(
|
||||||
|
"MCP connection was cancelled. This may indicate an authentication "
|
||||||
|
"error or server unavailability."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if mcp_config.tool_filter:
|
||||||
|
filtered_tools = []
|
||||||
|
for tool in tools_list:
|
||||||
|
if callable(mcp_config.tool_filter):
|
||||||
|
try:
|
||||||
|
from crewai.mcp.filters import ToolFilterContext
|
||||||
|
|
||||||
|
context = ToolFilterContext(
|
||||||
|
agent=self,
|
||||||
|
server_name=server_name,
|
||||||
|
run_context=None,
|
||||||
|
)
|
||||||
|
if mcp_config.tool_filter(context, tool):
|
||||||
|
filtered_tools.append(tool)
|
||||||
|
except (TypeError, AttributeError):
|
||||||
|
if mcp_config.tool_filter(tool):
|
||||||
|
filtered_tools.append(tool)
|
||||||
|
else:
|
||||||
|
# Not callable - include tool
|
||||||
|
filtered_tools.append(tool)
|
||||||
|
tools_list = filtered_tools
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
for tool_def in tools_list:
|
||||||
|
tool_name = tool_def.get("name", "")
|
||||||
|
if not tool_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert inputSchema to Pydantic model if present
|
||||||
|
args_schema = None
|
||||||
|
if tool_def.get("inputSchema"):
|
||||||
|
args_schema = self._json_schema_to_pydantic(
|
||||||
|
tool_name, tool_def["inputSchema"]
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_schema = {
|
||||||
|
"description": tool_def.get("description", ""),
|
||||||
|
"args_schema": args_schema,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
native_tool = MCPNativeTool(
|
||||||
|
mcp_client=client,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_schema=tool_schema,
|
||||||
|
server_name=server_name,
|
||||||
|
)
|
||||||
|
tools.append(native_tool)
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.log("error", f"Failed to create native MCP tool: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return cast(list[BaseTool], tools), client
|
||||||
|
except Exception as e:
|
||||||
|
if client.connected:
|
||||||
|
asyncio.run(client.disconnect())
|
||||||
|
|
||||||
|
raise RuntimeError(f"Failed to get native MCP tools: {e}") from e
|
||||||
|
|
||||||
def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
|
def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
|
||||||
"""Get tools from CrewAI AMP MCP marketplace."""
|
"""Get tools from CrewAI AMP MCP marketplace."""
|
||||||
# Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"
|
# Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from crewai.agents.tools_handler import ToolsHandler
|
|||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
|
from crewai.mcp.config import MCPServerConfig
|
||||||
from crewai.rag.embeddings.types import EmbedderConfig
|
from crewai.rag.embeddings.types import EmbedderConfig
|
||||||
from crewai.security.security_config import SecurityConfig
|
from crewai.security.security_config import SecurityConfig
|
||||||
from crewai.tools.base_tool import BaseTool, Tool
|
from crewai.tools.base_tool import BaseTool, Tool
|
||||||
@@ -194,7 +195,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
|||||||
default=None,
|
default=None,
|
||||||
description="List of applications or application/action combinations that the agent can access through CrewAI Platform. Can contain app names (e.g., 'gmail') or specific actions (e.g., 'gmail/send_email')",
|
description="List of applications or application/action combinations that the agent can access through CrewAI Platform. Can contain app names (e.g., 'gmail') or specific actions (e.g., 'gmail/send_email')",
|
||||||
)
|
)
|
||||||
mcps: list[str] | None = Field(
|
mcps: list[str | MCPServerConfig] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
|
description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
|
||||||
)
|
)
|
||||||
@@ -253,20 +254,36 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
|||||||
|
|
||||||
@field_validator("mcps")
|
@field_validator("mcps")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_mcps(cls, mcps: list[str] | None) -> list[str] | None:
|
def validate_mcps(
|
||||||
|
cls, mcps: list[str | MCPServerConfig] | None
|
||||||
|
) -> list[str | MCPServerConfig] | None:
|
||||||
|
"""Validate MCP server references and configurations.
|
||||||
|
|
||||||
|
Supports both string references (for backwards compatibility) and
|
||||||
|
structured configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE).
|
||||||
|
"""
|
||||||
if not mcps:
|
if not mcps:
|
||||||
return mcps
|
return mcps
|
||||||
|
|
||||||
validated_mcps = []
|
validated_mcps = []
|
||||||
for mcp in mcps:
|
for mcp in mcps:
|
||||||
if mcp.startswith(("https://", "crewai-amp:")):
|
if isinstance(mcp, str):
|
||||||
|
if mcp.startswith(("https://", "crewai-amp:")):
|
||||||
|
validated_mcps.append(mcp)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid MCP reference: {mcp}. "
|
||||||
|
"String references must start with 'https://' or 'crewai-amp:'"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(mcp, (MCPServerConfig)):
|
||||||
validated_mcps.append(mcp)
|
validated_mcps.append(mcp)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid MCP reference: {mcp}. Must start with 'https://' or 'crewai-amp:'"
|
f"Invalid MCP configuration: {type(mcp)}. "
|
||||||
|
"Must be a string reference or MCPServerConfig instance."
|
||||||
)
|
)
|
||||||
|
return validated_mcps
|
||||||
return list(set(validated_mcps))
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_and_set_attributes(self) -> Self:
|
def validate_and_set_attributes(self) -> Self:
|
||||||
@@ -343,7 +360,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
|||||||
"""Get platform tools for the specified list of applications and/or application/action combinations."""
|
"""Get platform tools for the specified list of applications and/or application/action combinations."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
|
def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
|
||||||
"""Get MCP tools for the specified list of MCP server references."""
|
"""Get MCP tools for the specified list of MCP server references."""
|
||||||
|
|
||||||
def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
|
def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from crewai.events.base_event_listener import BaseEventListener
|
|||||||
from crewai.events.depends import Depends
|
from crewai.events.depends import Depends
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.handler_graph import CircularDependencyError
|
from crewai.events.handler_graph import CircularDependencyError
|
||||||
|
|
||||||
from crewai.events.types.crew_events import (
|
from crewai.events.types.crew_events import (
|
||||||
CrewKickoffCompletedEvent,
|
CrewKickoffCompletedEvent,
|
||||||
CrewKickoffFailedEvent,
|
CrewKickoffFailedEvent,
|
||||||
@@ -61,6 +60,14 @@ from crewai.events.types.logging_events import (
|
|||||||
AgentLogsExecutionEvent,
|
AgentLogsExecutionEvent,
|
||||||
AgentLogsStartedEvent,
|
AgentLogsStartedEvent,
|
||||||
)
|
)
|
||||||
|
from crewai.events.types.mcp_events import (
|
||||||
|
MCPConnectionCompletedEvent,
|
||||||
|
MCPConnectionFailedEvent,
|
||||||
|
MCPConnectionStartedEvent,
|
||||||
|
MCPToolExecutionCompletedEvent,
|
||||||
|
MCPToolExecutionFailedEvent,
|
||||||
|
MCPToolExecutionStartedEvent,
|
||||||
|
)
|
||||||
from crewai.events.types.memory_events import (
|
from crewai.events.types.memory_events import (
|
||||||
MemoryQueryCompletedEvent,
|
MemoryQueryCompletedEvent,
|
||||||
MemoryQueryFailedEvent,
|
MemoryQueryFailedEvent,
|
||||||
@@ -153,6 +160,12 @@ __all__ = [
|
|||||||
"LiteAgentExecutionCompletedEvent",
|
"LiteAgentExecutionCompletedEvent",
|
||||||
"LiteAgentExecutionErrorEvent",
|
"LiteAgentExecutionErrorEvent",
|
||||||
"LiteAgentExecutionStartedEvent",
|
"LiteAgentExecutionStartedEvent",
|
||||||
|
"MCPConnectionCompletedEvent",
|
||||||
|
"MCPConnectionFailedEvent",
|
||||||
|
"MCPConnectionStartedEvent",
|
||||||
|
"MCPToolExecutionCompletedEvent",
|
||||||
|
"MCPToolExecutionFailedEvent",
|
||||||
|
"MCPToolExecutionStartedEvent",
|
||||||
"MemoryQueryCompletedEvent",
|
"MemoryQueryCompletedEvent",
|
||||||
"MemoryQueryFailedEvent",
|
"MemoryQueryFailedEvent",
|
||||||
"MemoryQueryStartedEvent",
|
"MemoryQueryStartedEvent",
|
||||||
|
|||||||
@@ -65,6 +65,14 @@ from crewai.events.types.logging_events import (
|
|||||||
AgentLogsExecutionEvent,
|
AgentLogsExecutionEvent,
|
||||||
AgentLogsStartedEvent,
|
AgentLogsStartedEvent,
|
||||||
)
|
)
|
||||||
|
from crewai.events.types.mcp_events import (
|
||||||
|
MCPConnectionCompletedEvent,
|
||||||
|
MCPConnectionFailedEvent,
|
||||||
|
MCPConnectionStartedEvent,
|
||||||
|
MCPToolExecutionCompletedEvent,
|
||||||
|
MCPToolExecutionFailedEvent,
|
||||||
|
MCPToolExecutionStartedEvent,
|
||||||
|
)
|
||||||
from crewai.events.types.reasoning_events import (
|
from crewai.events.types.reasoning_events import (
|
||||||
AgentReasoningCompletedEvent,
|
AgentReasoningCompletedEvent,
|
||||||
AgentReasoningFailedEvent,
|
AgentReasoningFailedEvent,
|
||||||
@@ -615,5 +623,67 @@ class EventListener(BaseEventListener):
|
|||||||
event.total_turns,
|
event.total_turns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ----------- MCP EVENTS -----------
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MCPConnectionStartedEvent)
|
||||||
|
def on_mcp_connection_started(source, event: MCPConnectionStartedEvent):
|
||||||
|
self.formatter.handle_mcp_connection_started(
|
||||||
|
event.server_name,
|
||||||
|
event.server_url,
|
||||||
|
event.transport_type,
|
||||||
|
event.is_reconnect,
|
||||||
|
event.connect_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MCPConnectionCompletedEvent)
|
||||||
|
def on_mcp_connection_completed(source, event: MCPConnectionCompletedEvent):
|
||||||
|
self.formatter.handle_mcp_connection_completed(
|
||||||
|
event.server_name,
|
||||||
|
event.server_url,
|
||||||
|
event.transport_type,
|
||||||
|
event.connection_duration_ms,
|
||||||
|
event.is_reconnect,
|
||||||
|
)
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MCPConnectionFailedEvent)
|
||||||
|
def on_mcp_connection_failed(source, event: MCPConnectionFailedEvent):
|
||||||
|
self.formatter.handle_mcp_connection_failed(
|
||||||
|
event.server_name,
|
||||||
|
event.server_url,
|
||||||
|
event.transport_type,
|
||||||
|
event.error,
|
||||||
|
event.error_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MCPToolExecutionStartedEvent)
|
||||||
|
def on_mcp_tool_execution_started(source, event: MCPToolExecutionStartedEvent):
|
||||||
|
self.formatter.handle_mcp_tool_execution_started(
|
||||||
|
event.server_name,
|
||||||
|
event.tool_name,
|
||||||
|
event.tool_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MCPToolExecutionCompletedEvent)
|
||||||
|
def on_mcp_tool_execution_completed(
|
||||||
|
source, event: MCPToolExecutionCompletedEvent
|
||||||
|
):
|
||||||
|
self.formatter.handle_mcp_tool_execution_completed(
|
||||||
|
event.server_name,
|
||||||
|
event.tool_name,
|
||||||
|
event.tool_args,
|
||||||
|
event.result,
|
||||||
|
event.execution_duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MCPToolExecutionFailedEvent)
|
||||||
|
def on_mcp_tool_execution_failed(source, event: MCPToolExecutionFailedEvent):
|
||||||
|
self.formatter.handle_mcp_tool_execution_failed(
|
||||||
|
event.server_name,
|
||||||
|
event.tool_name,
|
||||||
|
event.tool_args,
|
||||||
|
event.error,
|
||||||
|
event.error_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
event_listener = EventListener()
|
event_listener = EventListener()
|
||||||
|
|||||||
@@ -40,6 +40,14 @@ from crewai.events.types.llm_guardrail_events import (
|
|||||||
LLMGuardrailCompletedEvent,
|
LLMGuardrailCompletedEvent,
|
||||||
LLMGuardrailStartedEvent,
|
LLMGuardrailStartedEvent,
|
||||||
)
|
)
|
||||||
|
from crewai.events.types.mcp_events import (
|
||||||
|
MCPConnectionCompletedEvent,
|
||||||
|
MCPConnectionFailedEvent,
|
||||||
|
MCPConnectionStartedEvent,
|
||||||
|
MCPToolExecutionCompletedEvent,
|
||||||
|
MCPToolExecutionFailedEvent,
|
||||||
|
MCPToolExecutionStartedEvent,
|
||||||
|
)
|
||||||
from crewai.events.types.memory_events import (
|
from crewai.events.types.memory_events import (
|
||||||
MemoryQueryCompletedEvent,
|
MemoryQueryCompletedEvent,
|
||||||
MemoryQueryFailedEvent,
|
MemoryQueryFailedEvent,
|
||||||
@@ -115,4 +123,10 @@ EventTypes = (
|
|||||||
| MemoryQueryFailedEvent
|
| MemoryQueryFailedEvent
|
||||||
| MemoryRetrievalStartedEvent
|
| MemoryRetrievalStartedEvent
|
||||||
| MemoryRetrievalCompletedEvent
|
| MemoryRetrievalCompletedEvent
|
||||||
|
| MCPConnectionStartedEvent
|
||||||
|
| MCPConnectionCompletedEvent
|
||||||
|
| MCPConnectionFailedEvent
|
||||||
|
| MCPToolExecutionStartedEvent
|
||||||
|
| MCPToolExecutionCompletedEvent
|
||||||
|
| MCPToolExecutionFailedEvent
|
||||||
)
|
)
|
||||||
|
|||||||
85
lib/crewai/src/crewai/events/types/mcp_events.py
Normal file
85
lib/crewai/src/crewai/events/types/mcp_events.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from crewai.events.base_events import BaseEvent
|
||||||
|
|
||||||
|
|
||||||
|
class MCPEvent(BaseEvent):
|
||||||
|
"""Base event for MCP operations."""
|
||||||
|
|
||||||
|
server_name: str
|
||||||
|
server_url: str | None = None
|
||||||
|
transport_type: str | None = None # "stdio", "http", "sse"
|
||||||
|
agent_id: str | None = None
|
||||||
|
agent_role: str | None = None
|
||||||
|
from_agent: Any | None = None
|
||||||
|
from_task: Any | None = None
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
super().__init__(**data)
|
||||||
|
self._set_agent_params(data)
|
||||||
|
self._set_task_params(data)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnectionStartedEvent(MCPEvent):
|
||||||
|
"""Event emitted when starting to connect to an MCP server."""
|
||||||
|
|
||||||
|
type: str = "mcp_connection_started"
|
||||||
|
connect_timeout: int | None = None
|
||||||
|
is_reconnect: bool = (
|
||||||
|
False # True if this is a reconnection, False for first connection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnectionCompletedEvent(MCPEvent):
|
||||||
|
"""Event emitted when successfully connected to an MCP server."""
|
||||||
|
|
||||||
|
type: str = "mcp_connection_completed"
|
||||||
|
started_at: datetime | None = None
|
||||||
|
completed_at: datetime | None = None
|
||||||
|
connection_duration_ms: float | None = None
|
||||||
|
is_reconnect: bool = (
|
||||||
|
False # True if this was a reconnection, False for first connection
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnectionFailedEvent(MCPEvent):
|
||||||
|
"""Event emitted when connection to an MCP server fails."""
|
||||||
|
|
||||||
|
type: str = "mcp_connection_failed"
|
||||||
|
error: str
|
||||||
|
error_type: str | None = None # "timeout", "authentication", "network", etc.
|
||||||
|
started_at: datetime | None = None
|
||||||
|
failed_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolExecutionStartedEvent(MCPEvent):
|
||||||
|
"""Event emitted when starting to execute an MCP tool."""
|
||||||
|
|
||||||
|
type: str = "mcp_tool_execution_started"
|
||||||
|
tool_name: str
|
||||||
|
tool_args: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolExecutionCompletedEvent(MCPEvent):
|
||||||
|
"""Event emitted when MCP tool execution completes."""
|
||||||
|
|
||||||
|
type: str = "mcp_tool_execution_completed"
|
||||||
|
tool_name: str
|
||||||
|
tool_args: dict[str, Any] | None = None
|
||||||
|
result: Any | None = None
|
||||||
|
started_at: datetime | None = None
|
||||||
|
completed_at: datetime | None = None
|
||||||
|
execution_duration_ms: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolExecutionFailedEvent(MCPEvent):
|
||||||
|
"""Event emitted when MCP tool execution fails."""
|
||||||
|
|
||||||
|
type: str = "mcp_tool_execution_failed"
|
||||||
|
tool_name: str
|
||||||
|
tool_args: dict[str, Any] | None = None
|
||||||
|
error: str
|
||||||
|
error_type: str | None = None # "timeout", "validation", "server_error", etc.
|
||||||
|
started_at: datetime | None = None
|
||||||
|
failed_at: datetime | None = None
|
||||||
@@ -2248,3 +2248,203 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
self.current_a2a_conversation_branch = None
|
self.current_a2a_conversation_branch = None
|
||||||
self.current_a2a_turn_count = 0
|
self.current_a2a_turn_count = 0
|
||||||
|
|
||||||
|
# ----------- MCP EVENTS -----------
|
||||||
|
|
||||||
|
def handle_mcp_connection_started(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
server_url: str | None = None,
|
||||||
|
transport_type: str | None = None,
|
||||||
|
is_reconnect: bool = False,
|
||||||
|
connect_timeout: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle MCP connection started event."""
|
||||||
|
if not self.verbose:
|
||||||
|
return
|
||||||
|
|
||||||
|
content = Text()
|
||||||
|
reconnect_text = " (Reconnecting)" if is_reconnect else ""
|
||||||
|
content.append(f"MCP Connection Started{reconnect_text}\n\n", style="cyan bold")
|
||||||
|
content.append("Server: ", style="white")
|
||||||
|
content.append(f"{server_name}\n", style="cyan")
|
||||||
|
|
||||||
|
if server_url:
|
||||||
|
content.append("URL: ", style="white")
|
||||||
|
content.append(f"{server_url}\n", style="cyan dim")
|
||||||
|
|
||||||
|
if transport_type:
|
||||||
|
content.append("Transport: ", style="white")
|
||||||
|
content.append(f"{transport_type}\n", style="cyan")
|
||||||
|
|
||||||
|
if connect_timeout:
|
||||||
|
content.append("Timeout: ", style="white")
|
||||||
|
content.append(f"{connect_timeout}s\n", style="cyan")
|
||||||
|
|
||||||
|
panel = self.create_panel(content, "🔌 MCP Connection", "cyan")
|
||||||
|
self.print(panel)
|
||||||
|
self.print()
|
||||||
|
|
||||||
|
def handle_mcp_connection_completed(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
server_url: str | None = None,
|
||||||
|
transport_type: str | None = None,
|
||||||
|
connection_duration_ms: float | None = None,
|
||||||
|
is_reconnect: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Handle MCP connection completed event."""
|
||||||
|
if not self.verbose:
|
||||||
|
return
|
||||||
|
|
||||||
|
content = Text()
|
||||||
|
reconnect_text = " (Reconnected)" if is_reconnect else ""
|
||||||
|
content.append(
|
||||||
|
f"MCP Connection Completed{reconnect_text}\n\n", style="green bold"
|
||||||
|
)
|
||||||
|
content.append("Server: ", style="white")
|
||||||
|
content.append(f"{server_name}\n", style="green")
|
||||||
|
|
||||||
|
if server_url:
|
||||||
|
content.append("URL: ", style="white")
|
||||||
|
content.append(f"{server_url}\n", style="green dim")
|
||||||
|
|
||||||
|
if transport_type:
|
||||||
|
content.append("Transport: ", style="white")
|
||||||
|
content.append(f"{transport_type}\n", style="green")
|
||||||
|
|
||||||
|
if connection_duration_ms is not None:
|
||||||
|
content.append("Duration: ", style="white")
|
||||||
|
content.append(f"{connection_duration_ms:.2f}ms\n", style="green")
|
||||||
|
|
||||||
|
panel = self.create_panel(content, "✅ MCP Connected", "green")
|
||||||
|
self.print(panel)
|
||||||
|
self.print()
|
||||||
|
|
||||||
|
def handle_mcp_connection_failed(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
server_url: str | None = None,
|
||||||
|
transport_type: str | None = None,
|
||||||
|
error: str = "",
|
||||||
|
error_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle MCP connection failed event."""
|
||||||
|
if not self.verbose:
|
||||||
|
return
|
||||||
|
|
||||||
|
content = Text()
|
||||||
|
content.append("MCP Connection Failed\n\n", style="red bold")
|
||||||
|
content.append("Server: ", style="white")
|
||||||
|
content.append(f"{server_name}\n", style="red")
|
||||||
|
|
||||||
|
if server_url:
|
||||||
|
content.append("URL: ", style="white")
|
||||||
|
content.append(f"{server_url}\n", style="red dim")
|
||||||
|
|
||||||
|
if transport_type:
|
||||||
|
content.append("Transport: ", style="white")
|
||||||
|
content.append(f"{transport_type}\n", style="red")
|
||||||
|
|
||||||
|
if error_type:
|
||||||
|
content.append("Error Type: ", style="white")
|
||||||
|
content.append(f"{error_type}\n", style="red")
|
||||||
|
|
||||||
|
if error:
|
||||||
|
content.append("\nError: ", style="white bold")
|
||||||
|
error_preview = error[:500] + "..." if len(error) > 500 else error
|
||||||
|
content.append(f"{error_preview}\n", style="red")
|
||||||
|
|
||||||
|
panel = self.create_panel(content, "❌ MCP Connection Failed", "red")
|
||||||
|
self.print(panel)
|
||||||
|
self.print()
|
||||||
|
|
||||||
|
def handle_mcp_tool_execution_started(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
tool_name: str,
|
||||||
|
tool_args: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle MCP tool execution started event."""
|
||||||
|
if not self.verbose:
|
||||||
|
return
|
||||||
|
|
||||||
|
content = self.create_status_content(
|
||||||
|
"MCP Tool Execution Started",
|
||||||
|
tool_name,
|
||||||
|
"yellow",
|
||||||
|
tool_args=tool_args or {},
|
||||||
|
Server=server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
panel = self.create_panel(content, "🔧 MCP Tool", "yellow")
|
||||||
|
self.print(panel)
|
||||||
|
self.print()
|
||||||
|
|
||||||
|
def handle_mcp_tool_execution_completed(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
tool_name: str,
|
||||||
|
tool_args: dict[str, Any] | None = None,
|
||||||
|
result: Any | None = None,
|
||||||
|
execution_duration_ms: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle MCP tool execution completed event."""
|
||||||
|
if not self.verbose:
|
||||||
|
return
|
||||||
|
|
||||||
|
content = self.create_status_content(
|
||||||
|
"MCP Tool Execution Completed",
|
||||||
|
tool_name,
|
||||||
|
"green",
|
||||||
|
tool_args=tool_args or {},
|
||||||
|
Server=server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if execution_duration_ms is not None:
|
||||||
|
content.append("Duration: ", style="white")
|
||||||
|
content.append(f"{execution_duration_ms:.2f}ms\n", style="green")
|
||||||
|
|
||||||
|
if result is not None:
|
||||||
|
result_str = str(result)
|
||||||
|
if len(result_str) > 500:
|
||||||
|
result_str = result_str[:497] + "..."
|
||||||
|
content.append("\nResult: ", style="white bold")
|
||||||
|
content.append(f"{result_str}\n", style="green")
|
||||||
|
|
||||||
|
panel = self.create_panel(content, "✅ MCP Tool Completed", "green")
|
||||||
|
self.print(panel)
|
||||||
|
self.print()
|
||||||
|
|
||||||
|
def handle_mcp_tool_execution_failed(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
tool_name: str,
|
||||||
|
tool_args: dict[str, Any] | None = None,
|
||||||
|
error: str = "",
|
||||||
|
error_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle MCP tool execution failed event."""
|
||||||
|
if not self.verbose:
|
||||||
|
return
|
||||||
|
|
||||||
|
content = self.create_status_content(
|
||||||
|
"MCP Tool Execution Failed",
|
||||||
|
tool_name,
|
||||||
|
"red",
|
||||||
|
tool_args=tool_args or {},
|
||||||
|
Server=server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if error_type:
|
||||||
|
content.append("Error Type: ", style="white")
|
||||||
|
content.append(f"{error_type}\n", style="red")
|
||||||
|
|
||||||
|
if error:
|
||||||
|
content.append("\nError: ", style="white bold")
|
||||||
|
error_preview = error[:500] + "..." if len(error) > 500 else error
|
||||||
|
content.append(f"{error_preview}\n", style="red")
|
||||||
|
|
||||||
|
panel = self.create_panel(content, "❌ MCP Tool Failed", "red")
|
||||||
|
self.print(panel)
|
||||||
|
self.print()
|
||||||
|
|||||||
37
lib/crewai/src/crewai/mcp/__init__.py
Normal file
37
lib/crewai/src/crewai/mcp/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""MCP (Model Context Protocol) client support for CrewAI agents.
|
||||||
|
|
||||||
|
This module provides native MCP client functionality, allowing CrewAI agents
|
||||||
|
to connect to any MCP-compliant server using various transport types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from crewai.mcp.client import MCPClient
|
||||||
|
from crewai.mcp.config import (
|
||||||
|
MCPServerConfig,
|
||||||
|
MCPServerHTTP,
|
||||||
|
MCPServerSSE,
|
||||||
|
MCPServerStdio,
|
||||||
|
)
|
||||||
|
from crewai.mcp.filters import (
|
||||||
|
StaticToolFilter,
|
||||||
|
ToolFilter,
|
||||||
|
ToolFilterContext,
|
||||||
|
create_dynamic_tool_filter,
|
||||||
|
create_static_tool_filter,
|
||||||
|
)
|
||||||
|
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseTransport",
|
||||||
|
"MCPClient",
|
||||||
|
"MCPServerConfig",
|
||||||
|
"MCPServerHTTP",
|
||||||
|
"MCPServerSSE",
|
||||||
|
"MCPServerStdio",
|
||||||
|
"StaticToolFilter",
|
||||||
|
"ToolFilter",
|
||||||
|
"ToolFilterContext",
|
||||||
|
"TransportType",
|
||||||
|
"create_dynamic_tool_filter",
|
||||||
|
"create_static_tool_filter",
|
||||||
|
]
|
||||||
742
lib/crewai/src/crewai/mcp/client.py
Normal file
742
lib/crewai/src/crewai/mcp/client.py
Normal file
@@ -0,0 +1,742 @@
|
|||||||
|
"""MCP client with session management for CrewAI agents."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from datetime import datetime
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
# BaseExceptionGroup is available in Python 3.11+
|
||||||
|
try:
|
||||||
|
from builtins import BaseExceptionGroup
|
||||||
|
except ImportError:
|
||||||
|
# Fallback for Python < 3.11 (shouldn't happen in practice)
|
||||||
|
BaseExceptionGroup = Exception
|
||||||
|
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.mcp_events import (
|
||||||
|
MCPConnectionCompletedEvent,
|
||||||
|
MCPConnectionFailedEvent,
|
||||||
|
MCPConnectionStartedEvent,
|
||||||
|
MCPToolExecutionCompletedEvent,
|
||||||
|
MCPToolExecutionFailedEvent,
|
||||||
|
MCPToolExecutionStartedEvent,
|
||||||
|
)
|
||||||
|
from crewai.mcp.transports.base import BaseTransport
|
||||||
|
from crewai.mcp.transports.http import HTTPTransport
|
||||||
|
from crewai.mcp.transports.sse import SSETransport
|
||||||
|
from crewai.mcp.transports.stdio import StdioTransport
|
||||||
|
|
||||||
|
|
||||||
|
# MCP Connection timeout constants (in seconds)
|
||||||
|
MCP_CONNECTION_TIMEOUT = 30 # Increased for slow servers
|
||||||
|
MCP_TOOL_EXECUTION_TIMEOUT = 30
|
||||||
|
MCP_DISCOVERY_TIMEOUT = 30 # Increased for slow servers
|
||||||
|
MCP_MAX_RETRIES = 3
|
||||||
|
|
||||||
|
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
||||||
|
_mcp_schema_cache: dict[str, tuple[dict[str, Any], float]] = {}
|
||||||
|
_cache_ttl = 300 # 5 minutes
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
"""MCP client with session management.
|
||||||
|
|
||||||
|
This client manages connections to MCP servers and provides a high-level
|
||||||
|
interface for interacting with MCP tools, prompts, and resources.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
transport = StdioTransport(command="python", args=["server.py"])
|
||||||
|
client = MCPClient(transport)
|
||||||
|
async with client:
|
||||||
|
tools = await client.list_tools()
|
||||||
|
result = await client.call_tool("tool_name", {"arg": "value"})
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transport: BaseTransport,
|
||||||
|
connect_timeout: int = MCP_CONNECTION_TIMEOUT,
|
||||||
|
execution_timeout: int = MCP_TOOL_EXECUTION_TIMEOUT,
|
||||||
|
discovery_timeout: int = MCP_DISCOVERY_TIMEOUT,
|
||||||
|
max_retries: int = MCP_MAX_RETRIES,
|
||||||
|
cache_tools_list: bool = False,
|
||||||
|
logger: logging.Logger | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize MCP client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transport: Transport instance for MCP server connection.
|
||||||
|
connect_timeout: Connection timeout in seconds.
|
||||||
|
execution_timeout: Tool execution timeout in seconds.
|
||||||
|
discovery_timeout: Tool discovery timeout in seconds.
|
||||||
|
max_retries: Maximum retry attempts for operations.
|
||||||
|
cache_tools_list: Whether to cache tool list results.
|
||||||
|
logger: Optional logger instance.
|
||||||
|
"""
|
||||||
|
self.transport = transport
|
||||||
|
self.connect_timeout = connect_timeout
|
||||||
|
self.execution_timeout = execution_timeout
|
||||||
|
self.discovery_timeout = discovery_timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.cache_tools_list = cache_tools_list
|
||||||
|
# self._logger = logger or logging.getLogger(__name__)
|
||||||
|
self._session: Any = None
|
||||||
|
self._initialized = False
|
||||||
|
self._exit_stack = AsyncExitStack()
|
||||||
|
self._was_connected = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connected(self) -> bool:
|
||||||
|
"""Check if client is connected to server."""
|
||||||
|
return self.transport.connected and self._initialized
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self) -> Any:
|
||||||
|
"""Get the MCP session."""
|
||||||
|
if self._session is None:
|
||||||
|
raise RuntimeError("Client not connected. Call connect() first.")
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
def _get_server_info(self) -> tuple[str, str | None, str | None]:
|
||||||
|
"""Get server information for events.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (server_name, server_url, transport_type).
|
||||||
|
"""
|
||||||
|
if isinstance(self.transport, StdioTransport):
|
||||||
|
server_name = f"{self.transport.command} {' '.join(self.transport.args)}"
|
||||||
|
server_url = None
|
||||||
|
transport_type = self.transport.transport_type.value
|
||||||
|
elif isinstance(self.transport, HTTPTransport):
|
||||||
|
server_name = self.transport.url
|
||||||
|
server_url = self.transport.url
|
||||||
|
transport_type = self.transport.transport_type.value
|
||||||
|
elif isinstance(self.transport, SSETransport):
|
||||||
|
server_name = self.transport.url
|
||||||
|
server_url = self.transport.url
|
||||||
|
transport_type = self.transport.transport_type.value
|
||||||
|
else:
|
||||||
|
server_name = "Unknown MCP Server"
|
||||||
|
server_url = None
|
||||||
|
transport_type = (
|
||||||
|
self.transport.transport_type.value
|
||||||
|
if hasattr(self.transport, "transport_type")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return server_name, server_url, transport_type
|
||||||
|
|
||||||
|
async def connect(self) -> Self:
|
||||||
|
"""Connect to MCP server and initialize session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Self for method chaining.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConnectionError: If connection fails.
|
||||||
|
ImportError: If MCP SDK not available.
|
||||||
|
"""
|
||||||
|
if self.connected:
|
||||||
|
return self
|
||||||
|
|
||||||
|
# Get server info for events
|
||||||
|
server_name, server_url, transport_type = self._get_server_info()
|
||||||
|
is_reconnect = self._was_connected
|
||||||
|
|
||||||
|
# Emit connection started event
|
||||||
|
started_at = datetime.now()
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
MCPConnectionStartedEvent(
|
||||||
|
server_name=server_name,
|
||||||
|
server_url=server_url,
|
||||||
|
transport_type=transport_type,
|
||||||
|
is_reconnect=is_reconnect,
|
||||||
|
connect_timeout=self.connect_timeout,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mcp import ClientSession
|
||||||
|
|
||||||
|
# Use AsyncExitStack to manage transport and session contexts together
|
||||||
|
# This ensures they're in the same async scope and prevents cancel scope errors
|
||||||
|
# Always enter transport context via exit stack (it handles already-connected state)
|
||||||
|
await self._exit_stack.enter_async_context(self.transport)
|
||||||
|
|
||||||
|
# Create ClientSession with transport streams
|
||||||
|
self._session = ClientSession(
|
||||||
|
self.transport.read_stream,
|
||||||
|
self.transport.write_stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enter the session's async context manager via exit stack
|
||||||
|
await self._exit_stack.enter_async_context(self._session)
|
||||||
|
|
||||||
|
# Initialize the session (required by MCP protocol)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._session.initialize(),
|
||||||
|
timeout=self.connect_timeout,
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# If initialization was cancelled (e.g., event loop closing),
|
||||||
|
# cleanup and re-raise - don't suppress cancellation
|
||||||
|
await self._cleanup_on_error()
|
||||||
|
raise
|
||||||
|
except BaseExceptionGroup as eg:
|
||||||
|
# Handle exception groups from anyio task groups
|
||||||
|
# Extract the actual meaningful error (not GeneratorExit)
|
||||||
|
actual_error = None
|
||||||
|
for exc in eg.exceptions:
|
||||||
|
if isinstance(exc, Exception) and not isinstance(
|
||||||
|
exc, GeneratorExit
|
||||||
|
):
|
||||||
|
# Check if it's an HTTP error (like 401)
|
||||||
|
error_msg = str(exc).lower()
|
||||||
|
if "401" in error_msg or "unauthorized" in error_msg:
|
||||||
|
actual_error = exc
|
||||||
|
break
|
||||||
|
if "cancel scope" not in error_msg and "task" not in error_msg:
|
||||||
|
actual_error = exc
|
||||||
|
break
|
||||||
|
|
||||||
|
await self._cleanup_on_error()
|
||||||
|
if actual_error:
|
||||||
|
raise ConnectionError(
|
||||||
|
f"Failed to connect to MCP server: {actual_error}"
|
||||||
|
) from actual_error
|
||||||
|
raise ConnectionError(f"Failed to connect to MCP server: {eg}") from eg
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
self._was_connected = True
|
||||||
|
|
||||||
|
completed_at = datetime.now()
|
||||||
|
connection_duration_ms = (completed_at - started_at).total_seconds() * 1000
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
MCPConnectionCompletedEvent(
|
||||||
|
server_name=server_name,
|
||||||
|
server_url=server_url,
|
||||||
|
transport_type=transport_type,
|
||||||
|
started_at=started_at,
|
||||||
|
completed_at=completed_at,
|
||||||
|
connection_duration_ms=connection_duration_ms,
|
||||||
|
is_reconnect=is_reconnect,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
except ImportError as e:
|
||||||
|
await self._cleanup_on_error()
|
||||||
|
error_msg = (
|
||||||
|
"MCP library not available. Please install with: pip install mcp"
|
||||||
|
)
|
||||||
|
self._emit_connection_failed(
|
||||||
|
server_name,
|
||||||
|
server_url,
|
||||||
|
transport_type,
|
||||||
|
error_msg,
|
||||||
|
"import_error",
|
||||||
|
started_at,
|
||||||
|
)
|
||||||
|
raise ImportError(error_msg) from e
|
||||||
|
except asyncio.TimeoutError as e:
|
||||||
|
await self._cleanup_on_error()
|
||||||
|
error_msg = f"MCP connection timed out after {self.connect_timeout} seconds. The server may be slow or unreachable."
|
||||||
|
self._emit_connection_failed(
|
||||||
|
server_name,
|
||||||
|
server_url,
|
||||||
|
transport_type,
|
||||||
|
error_msg,
|
||||||
|
"timeout",
|
||||||
|
started_at,
|
||||||
|
)
|
||||||
|
raise ConnectionError(error_msg) from e
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Re-raise cancellation - don't suppress it
|
||||||
|
await self._cleanup_on_error()
|
||||||
|
self._emit_connection_failed(
|
||||||
|
server_name,
|
||||||
|
server_url,
|
||||||
|
transport_type,
|
||||||
|
"Connection cancelled",
|
||||||
|
"cancelled",
|
||||||
|
started_at,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except BaseExceptionGroup as eg:
|
||||||
|
# Handle exception groups from anyio task groups at outer level
|
||||||
|
actual_error = None
|
||||||
|
for exc in eg.exceptions:
|
||||||
|
if isinstance(exc, Exception) and not isinstance(exc, GeneratorExit):
|
||||||
|
error_msg = str(exc).lower()
|
||||||
|
if "401" in error_msg or "unauthorized" in error_msg:
|
||||||
|
actual_error = exc
|
||||||
|
break
|
||||||
|
if "cancel scope" not in error_msg and "task" not in error_msg:
|
||||||
|
actual_error = exc
|
||||||
|
break
|
||||||
|
|
||||||
|
await self._cleanup_on_error()
|
||||||
|
error_type = (
|
||||||
|
"authentication"
|
||||||
|
if actual_error
|
||||||
|
and (
|
||||||
|
"401" in str(actual_error).lower()
|
||||||
|
or "unauthorized" in str(actual_error).lower()
|
||||||
|
)
|
||||||
|
else "network"
|
||||||
|
)
|
||||||
|
error_msg = str(actual_error) if actual_error else str(eg)
|
||||||
|
self._emit_connection_failed(
|
||||||
|
server_name,
|
||||||
|
server_url,
|
||||||
|
transport_type,
|
||||||
|
error_msg,
|
||||||
|
error_type,
|
||||||
|
started_at,
|
||||||
|
)
|
||||||
|
if actual_error:
|
||||||
|
raise ConnectionError(
|
||||||
|
f"Failed to connect to MCP server: {actual_error}"
|
||||||
|
) from actual_error
|
||||||
|
raise ConnectionError(f"Failed to connect to MCP server: {eg}") from eg
|
||||||
|
except Exception as e:
|
||||||
|
await self._cleanup_on_error()
|
||||||
|
error_type = (
|
||||||
|
"authentication"
|
||||||
|
if "401" in str(e).lower() or "unauthorized" in str(e).lower()
|
||||||
|
else "network"
|
||||||
|
)
|
||||||
|
self._emit_connection_failed(
|
||||||
|
server_name, server_url, transport_type, str(e), error_type, started_at
|
||||||
|
)
|
||||||
|
raise ConnectionError(f"Failed to connect to MCP server: {e}") from e
|
||||||
|
|
||||||
|
def _emit_connection_failed(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
server_url: str | None,
|
||||||
|
transport_type: str | None,
|
||||||
|
error: str,
|
||||||
|
error_type: str,
|
||||||
|
started_at: datetime,
|
||||||
|
) -> None:
|
||||||
|
"""Emit connection failed event."""
|
||||||
|
failed_at = datetime.now()
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
MCPConnectionFailedEvent(
|
||||||
|
server_name=server_name,
|
||||||
|
server_url=server_url,
|
||||||
|
transport_type=transport_type,
|
||||||
|
error=error,
|
||||||
|
error_type=error_type,
|
||||||
|
started_at=started_at,
|
||||||
|
failed_at=failed_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _cleanup_on_error(self) -> None:
|
||||||
|
"""Cleanup resources when an error occurs during connection."""
|
||||||
|
try:
|
||||||
|
await self._exit_stack.aclose()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Best effort cleanup - ignore all other errors
|
||||||
|
raise RuntimeError(f"Error during MCP client cleanup: {e}") from e
|
||||||
|
finally:
|
||||||
|
self._session = None
|
||||||
|
self._initialized = False
|
||||||
|
self._exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Disconnect from MCP server and cleanup resources."""
|
||||||
|
if not self.connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._exit_stack.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error during MCP client disconnect: {e}") from e
|
||||||
|
finally:
|
||||||
|
self._session = None
|
||||||
|
self._initialized = False
|
||||||
|
self._exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
|
async def list_tools(self, use_cache: bool | None = None) -> list[dict[str, Any]]:
|
||||||
|
"""List available tools from MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_cache: Whether to use cached results. If None, uses
|
||||||
|
client's cache_tools_list setting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool definitions with name, description, and inputSchema.
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
# Check cache if enabled
|
||||||
|
use_cache = use_cache if use_cache is not None else self.cache_tools_list
|
||||||
|
if use_cache:
|
||||||
|
cache_key = self._get_cache_key("tools")
|
||||||
|
if cache_key in _mcp_schema_cache:
|
||||||
|
cached_data, cache_time = _mcp_schema_cache[cache_key]
|
||||||
|
if time.time() - cache_time < _cache_ttl:
|
||||||
|
# Logger removed - return cached data
|
||||||
|
return cached_data
|
||||||
|
|
||||||
|
# List tools with timeout and retries
|
||||||
|
tools = await self._retry_operation(
|
||||||
|
self._list_tools_impl,
|
||||||
|
timeout=self.discovery_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache results if enabled
|
||||||
|
if use_cache:
|
||||||
|
cache_key = self._get_cache_key("tools")
|
||||||
|
_mcp_schema_cache[cache_key] = (tools, time.time())
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
async def _list_tools_impl(self) -> list[dict[str, Any]]:
|
||||||
|
"""Internal implementation of list_tools."""
|
||||||
|
tools_result = await asyncio.wait_for(
|
||||||
|
self.session.list_tools(),
|
||||||
|
timeout=self.discovery_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": tool.name,
|
||||||
|
"description": getattr(tool, "description", ""),
|
||||||
|
"inputSchema": getattr(tool, "inputSchema", {}),
|
||||||
|
}
|
||||||
|
for tool in tools_result.tools
|
||||||
|
]
|
||||||
|
|
||||||
|
async def call_tool(
|
||||||
|
self, tool_name: str, arguments: dict[str, Any] | None = None
|
||||||
|
) -> Any:
|
||||||
|
"""Call a tool on the MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool to call.
|
||||||
|
arguments: Tool arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result.
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
arguments = arguments or {}
|
||||||
|
cleaned_arguments = self._clean_tool_arguments(arguments)
|
||||||
|
|
||||||
|
# Get server info for events
|
||||||
|
server_name, server_url, transport_type = self._get_server_info()
|
||||||
|
|
||||||
|
# Emit tool execution started event
|
||||||
|
started_at = datetime.now()
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
MCPToolExecutionStartedEvent(
|
||||||
|
server_name=server_name,
|
||||||
|
server_url=server_url,
|
||||||
|
transport_type=transport_type,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_args=cleaned_arguments,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self._retry_operation(
|
||||||
|
lambda: self._call_tool_impl(tool_name, cleaned_arguments),
|
||||||
|
timeout=self.execution_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
completed_at = datetime.now()
|
||||||
|
execution_duration_ms = (completed_at - started_at).total_seconds() * 1000
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
MCPToolExecutionCompletedEvent(
|
||||||
|
server_name=server_name,
|
||||||
|
server_url=server_url,
|
||||||
|
transport_type=transport_type,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_args=cleaned_arguments,
|
||||||
|
result=result,
|
||||||
|
started_at=started_at,
|
||||||
|
completed_at=completed_at,
|
||||||
|
execution_duration_ms=execution_duration_ms,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
failed_at = datetime.now()
|
||||||
|
error_type = (
|
||||||
|
"timeout"
|
||||||
|
if isinstance(e, (asyncio.TimeoutError, ConnectionError))
|
||||||
|
and "timeout" in str(e).lower()
|
||||||
|
else "server_error"
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
MCPToolExecutionFailedEvent(
|
||||||
|
server_name=server_name,
|
||||||
|
server_url=server_url,
|
||||||
|
transport_type=transport_type,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_args=cleaned_arguments,
|
||||||
|
error=str(e),
|
||||||
|
error_type=error_type,
|
||||||
|
started_at=started_at,
|
||||||
|
failed_at=failed_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _clean_tool_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Clean tool arguments by removing None values and fixing formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: Raw tool arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cleaned arguments ready for MCP server.
|
||||||
|
"""
|
||||||
|
cleaned = {}
|
||||||
|
|
||||||
|
for key, value in arguments.items():
|
||||||
|
# Skip None values
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Fix sources array format: convert ["web"] to [{"type": "web"}]
|
||||||
|
if key == "sources" and isinstance(value, list):
|
||||||
|
fixed_sources = []
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, str):
|
||||||
|
# Convert string to object format
|
||||||
|
fixed_sources.append({"type": item})
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
# Already in correct format
|
||||||
|
fixed_sources.append(item)
|
||||||
|
else:
|
||||||
|
# Keep as is if unknown format
|
||||||
|
fixed_sources.append(item)
|
||||||
|
if fixed_sources:
|
||||||
|
cleaned[key] = fixed_sources
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Recursively clean nested dictionaries
|
||||||
|
if isinstance(value, dict):
|
||||||
|
nested_cleaned = self._clean_tool_arguments(value)
|
||||||
|
if nested_cleaned: # Only add if not empty
|
||||||
|
cleaned[key] = nested_cleaned
|
||||||
|
elif isinstance(value, list):
|
||||||
|
# Clean list items
|
||||||
|
cleaned_list = []
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
cleaned_item = self._clean_tool_arguments(item)
|
||||||
|
if cleaned_item:
|
||||||
|
cleaned_list.append(cleaned_item)
|
||||||
|
elif item is not None:
|
||||||
|
cleaned_list.append(item)
|
||||||
|
if cleaned_list:
|
||||||
|
cleaned[key] = cleaned_list
|
||||||
|
else:
|
||||||
|
# Keep primitive values
|
||||||
|
cleaned[key] = value
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
async def _call_tool_impl(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||||
|
"""Internal implementation of call_tool."""
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self.session.call_tool(tool_name, arguments),
|
||||||
|
timeout=self.execution_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract result content
|
||||||
|
if hasattr(result, "content") and result.content:
|
||||||
|
if isinstance(result.content, list) and len(result.content) > 0:
|
||||||
|
content_item = result.content[0]
|
||||||
|
if hasattr(content_item, "text"):
|
||||||
|
return str(content_item.text)
|
||||||
|
return str(content_item)
|
||||||
|
return str(result.content)
|
||||||
|
|
||||||
|
return str(result)
|
||||||
|
|
||||||
|
async def list_prompts(self) -> list[dict[str, Any]]:
|
||||||
|
"""List available prompts from MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of prompt definitions.
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
return await self._retry_operation(
|
||||||
|
self._list_prompts_impl,
|
||||||
|
timeout=self.discovery_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _list_prompts_impl(self) -> list[dict[str, Any]]:
|
||||||
|
"""Internal implementation of list_prompts."""
|
||||||
|
prompts_result = await asyncio.wait_for(
|
||||||
|
self.session.list_prompts(),
|
||||||
|
timeout=self.discovery_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": prompt.name,
|
||||||
|
"description": getattr(prompt, "description", ""),
|
||||||
|
"arguments": getattr(prompt, "arguments", []),
|
||||||
|
}
|
||||||
|
for prompt in prompts_result.prompts
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_prompt(
|
||||||
|
self, prompt_name: str, arguments: dict[str, Any] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Get a prompt from the MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_name: Name of the prompt to get.
|
||||||
|
arguments: Optional prompt arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prompt content and metadata.
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
arguments = arguments or {}
|
||||||
|
|
||||||
|
return await self._retry_operation(
|
||||||
|
lambda: self._get_prompt_impl(prompt_name, arguments),
|
||||||
|
timeout=self.execution_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_prompt_impl(
|
||||||
|
self, prompt_name: str, arguments: dict[str, Any]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Internal implementation of get_prompt."""
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self.session.get_prompt(prompt_name, arguments),
|
||||||
|
timeout=self.execution_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": prompt_name,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content,
|
||||||
|
}
|
||||||
|
for msg in result.messages
|
||||||
|
],
|
||||||
|
"arguments": arguments,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _retry_operation(
|
||||||
|
self,
|
||||||
|
operation: Callable[[], Any],
|
||||||
|
timeout: int | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Retry an operation with exponential backoff.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operation: Async operation to retry.
|
||||||
|
timeout: Operation timeout in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Operation result.
|
||||||
|
"""
|
||||||
|
last_error = None
|
||||||
|
timeout = timeout or self.execution_timeout
|
||||||
|
|
||||||
|
for attempt in range(self.max_retries):
|
||||||
|
try:
|
||||||
|
if timeout:
|
||||||
|
return await asyncio.wait_for(operation(), timeout=timeout)
|
||||||
|
return await operation()
|
||||||
|
|
||||||
|
except asyncio.TimeoutError as e: # noqa: PERF203
|
||||||
|
last_error = f"Operation timed out after {timeout} seconds"
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
wait_time = 2**attempt
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
raise ConnectionError(last_error) from e
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_str = str(e).lower()
|
||||||
|
|
||||||
|
# Classify errors as retryable or non-retryable
|
||||||
|
if "authentication" in error_str or "unauthorized" in error_str:
|
||||||
|
raise ConnectionError(f"Authentication failed: {e}") from e
|
||||||
|
|
||||||
|
if "not found" in error_str:
|
||||||
|
raise ValueError(f"Resource not found: {e}") from e
|
||||||
|
|
||||||
|
# Retryable errors
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
wait_time = 2**attempt
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
raise ConnectionError(
|
||||||
|
f"Operation failed after {self.max_retries} attempts: {last_error}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
raise ConnectionError(f"Operation failed: {last_error}")
|
||||||
|
|
||||||
|
def _get_cache_key(self, resource_type: str) -> str:
|
||||||
|
"""Generate cache key for resource.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resource_type: Type of resource (e.g., "tools", "prompts").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cache key string.
|
||||||
|
"""
|
||||||
|
# Use transport type and URL/command as cache key
|
||||||
|
if isinstance(self.transport, StdioTransport):
|
||||||
|
key = f"stdio:{self.transport.command}:{':'.join(self.transport.args)}"
|
||||||
|
elif isinstance(self.transport, HTTPTransport):
|
||||||
|
key = f"http:{self.transport.url}"
|
||||||
|
elif isinstance(self.transport, SSETransport):
|
||||||
|
key = f"sse:{self.transport.url}"
|
||||||
|
else:
|
||||||
|
key = f"{self.transport.transport_type}:unknown"
|
||||||
|
|
||||||
|
return f"mcp:{key}:{resource_type}"
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Self:
|
||||||
|
"""Async context manager entry."""
|
||||||
|
return await self.connect()
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Async context manager exit."""
|
||||||
|
await self.disconnect()
|
||||||
124
lib/crewai/src/crewai/mcp/config.py
Normal file
124
lib/crewai/src/crewai/mcp/config.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""MCP server configuration models for CrewAI agents.
|
||||||
|
|
||||||
|
This module provides Pydantic models for configuring MCP servers with
|
||||||
|
various transport types, similar to OpenAI's Agents SDK.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crewai.mcp.filters import ToolFilter
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerStdio(BaseModel):
|
||||||
|
"""Stdio MCP server configuration.
|
||||||
|
|
||||||
|
This configuration is used for connecting to local MCP servers
|
||||||
|
that run as processes and communicate via standard input/output.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
mcp_server = MCPServerStdio(
|
||||||
|
command="python",
|
||||||
|
args=["path/to/server.py"],
|
||||||
|
env={"API_KEY": "..."},
|
||||||
|
tool_filter=create_static_tool_filter(
|
||||||
|
allowed_tool_names=["read_file", "write_file"]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
command: str = Field(
|
||||||
|
...,
|
||||||
|
description="Command to execute (e.g., 'python', 'node', 'npx', 'uvx').",
|
||||||
|
)
|
||||||
|
args: list[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Command arguments (e.g., ['server.py'] or ['-y', '@mcp/server']).",
|
||||||
|
)
|
||||||
|
env: dict[str, str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Environment variables to pass to the process.",
|
||||||
|
)
|
||||||
|
tool_filter: ToolFilter | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional tool filter for filtering available tools.",
|
||||||
|
)
|
||||||
|
cache_tools_list: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether to cache the tool list for faster subsequent access.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerHTTP(BaseModel):
|
||||||
|
"""HTTP/Streamable HTTP MCP server configuration.
|
||||||
|
|
||||||
|
This configuration is used for connecting to remote MCP servers
|
||||||
|
over HTTP/HTTPS using streamable HTTP transport.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
mcp_server = MCPServerHTTP(
|
||||||
|
url="https://api.example.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer ..."},
|
||||||
|
cache_tools_list=True,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
url: str = Field(
|
||||||
|
..., description="Server URL (e.g., 'https://api.example.com/mcp')."
|
||||||
|
)
|
||||||
|
headers: dict[str, str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional HTTP headers for authentication or other purposes.",
|
||||||
|
)
|
||||||
|
streamable: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether to use streamable HTTP transport (default: True).",
|
||||||
|
)
|
||||||
|
tool_filter: ToolFilter | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional tool filter for filtering available tools.",
|
||||||
|
)
|
||||||
|
cache_tools_list: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether to cache the tool list for faster subsequent access.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerSSE(BaseModel):
|
||||||
|
"""Server-Sent Events (SSE) MCP server configuration.
|
||||||
|
|
||||||
|
This configuration is used for connecting to remote MCP servers
|
||||||
|
using Server-Sent Events for real-time streaming communication.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
mcp_server = MCPServerSSE(
|
||||||
|
url="https://api.example.com/mcp/sse",
|
||||||
|
headers={"Authorization": "Bearer ..."},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
url: str = Field(
|
||||||
|
...,
|
||||||
|
description="Server URL (e.g., 'https://api.example.com/mcp/sse').",
|
||||||
|
)
|
||||||
|
headers: dict[str, str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional HTTP headers for authentication or other purposes.",
|
||||||
|
)
|
||||||
|
tool_filter: ToolFilter | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional tool filter for filtering available tools.",
|
||||||
|
)
|
||||||
|
cache_tools_list: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether to cache the tool list for faster subsequent access.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for all MCP server configurations
|
||||||
|
MCPServerConfig = MCPServerStdio | MCPServerHTTP | MCPServerSSE
|
||||||
166
lib/crewai/src/crewai/mcp/filters.py
Normal file
166
lib/crewai/src/crewai/mcp/filters.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""Tool filtering support for MCP servers.
|
||||||
|
|
||||||
|
This module provides utilities for filtering tools from MCP servers,
|
||||||
|
including static allow/block lists and dynamic context-aware filtering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ToolFilterContext(BaseModel):
|
||||||
|
"""Context for dynamic tool filtering.
|
||||||
|
|
||||||
|
This context is passed to dynamic tool filters to provide
|
||||||
|
information about the agent, run context, and server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
agent: Any = Field(..., description="The agent requesting tools.")
|
||||||
|
server_name: str = Field(..., description="Name of the MCP server.")
|
||||||
|
run_context: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional run context for additional filtering logic.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for tool filter functions
|
||||||
|
ToolFilter = (
|
||||||
|
Callable[[ToolFilterContext, dict[str, Any]], bool]
|
||||||
|
| Callable[[dict[str, Any]], bool]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StaticToolFilter:
|
||||||
|
"""Static tool filter with allow/block lists.
|
||||||
|
|
||||||
|
This filter provides simple allow/block list filtering based on
|
||||||
|
tool names. Useful for restricting which tools are available
|
||||||
|
from an MCP server.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
filter = StaticToolFilter(
|
||||||
|
allowed_tool_names=["read_file", "write_file"],
|
||||||
|
blocked_tool_names=["delete_file"],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
allowed_tool_names: list[str] | None = None,
|
||||||
|
blocked_tool_names: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize static tool filter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_tool_names: List of tool names to allow. If None,
|
||||||
|
all tools are allowed (unless blocked).
|
||||||
|
blocked_tool_names: List of tool names to block. Blocked tools
|
||||||
|
take precedence over allowed tools.
|
||||||
|
"""
|
||||||
|
self.allowed_tool_names = set(allowed_tool_names or [])
|
||||||
|
self.blocked_tool_names = set(blocked_tool_names or [])
|
||||||
|
|
||||||
|
def __call__(self, tool: dict[str, Any]) -> bool:
|
||||||
|
"""Filter tool based on allow/block lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool: Tool definition dictionary with at least 'name' key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if tool should be included, False otherwise.
|
||||||
|
"""
|
||||||
|
tool_name = tool.get("name", "")
|
||||||
|
|
||||||
|
# Blocked tools take precedence
|
||||||
|
if self.blocked_tool_names and tool_name in self.blocked_tool_names:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If allow list exists, tool must be in it
|
||||||
|
if self.allowed_tool_names:
|
||||||
|
return tool_name in self.allowed_tool_names
|
||||||
|
|
||||||
|
# No restrictions - allow all
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def create_static_tool_filter(
|
||||||
|
allowed_tool_names: list[str] | None = None,
|
||||||
|
blocked_tool_names: list[str] | None = None,
|
||||||
|
) -> Callable[[dict[str, Any]], bool]:
|
||||||
|
"""Create a static tool filter function.
|
||||||
|
|
||||||
|
This is a convenience function for creating static tool filters
|
||||||
|
with allow/block lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_tool_names: List of tool names to allow. If None,
|
||||||
|
all tools are allowed (unless blocked).
|
||||||
|
blocked_tool_names: List of tool names to block. Blocked tools
|
||||||
|
take precedence over allowed tools.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool filter function that returns True for allowed tools.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
filter_fn = create_static_tool_filter(
|
||||||
|
allowed_tool_names=["read_file", "write_file"],
|
||||||
|
blocked_tool_names=["delete_file"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use in MCPServerStdio
|
||||||
|
mcp_server = MCPServerStdio(
|
||||||
|
command="npx",
|
||||||
|
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||||
|
tool_filter=filter_fn,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return StaticToolFilter(
|
||||||
|
allowed_tool_names=allowed_tool_names,
|
||||||
|
blocked_tool_names=blocked_tool_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dynamic_tool_filter(
|
||||||
|
filter_func: Callable[[ToolFilterContext, dict[str, Any]], bool],
|
||||||
|
) -> Callable[[ToolFilterContext, dict[str, Any]], bool]:
|
||||||
|
"""Create a dynamic tool filter function.
|
||||||
|
|
||||||
|
This function wraps a dynamic filter function that has access
|
||||||
|
to the tool filter context (agent, server, run context).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_func: Function that takes (context, tool) and returns bool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool filter function that can be used with MCP server configs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
async def context_aware_filter(
|
||||||
|
context: ToolFilterContext, tool: dict[str, Any]
|
||||||
|
) -> bool:
|
||||||
|
# Block dangerous tools for code reviewers
|
||||||
|
if context.agent.role == "Code Reviewer":
|
||||||
|
if tool["name"].startswith("danger_"):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
filter_fn = create_dynamic_tool_filter(context_aware_filter)
|
||||||
|
|
||||||
|
mcp_server = MCPServerStdio(
|
||||||
|
command="python", args=["server.py"], tool_filter=filter_fn
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return filter_func
|
||||||
15
lib/crewai/src/crewai/mcp/transports/__init__.py
Normal file
15
lib/crewai/src/crewai/mcp/transports/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""MCP transport implementations for various connection types."""
|
||||||
|
|
||||||
|
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||||
|
from crewai.mcp.transports.http import HTTPTransport
|
||||||
|
from crewai.mcp.transports.sse import SSETransport
|
||||||
|
from crewai.mcp.transports.stdio import StdioTransport
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseTransport",
|
||||||
|
"HTTPTransport",
|
||||||
|
"SSETransport",
|
||||||
|
"StdioTransport",
|
||||||
|
"TransportType",
|
||||||
|
]
|
||||||
125
lib/crewai/src/crewai/mcp/transports/base.py
Normal file
125
lib/crewai/src/crewai/mcp/transports/base.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Base transport interface for MCP connections."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
class TransportType(str, Enum):
|
||||||
|
"""MCP transport types."""
|
||||||
|
|
||||||
|
STDIO = "stdio"
|
||||||
|
HTTP = "http"
|
||||||
|
STREAMABLE_HTTP = "streamable-http"
|
||||||
|
SSE = "sse"
|
||||||
|
|
||||||
|
|
||||||
|
class ReadStream(Protocol):
|
||||||
|
"""Protocol for read streams."""
|
||||||
|
|
||||||
|
async def read(self, n: int = -1) -> bytes:
|
||||||
|
"""Read bytes from stream."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class WriteStream(Protocol):
|
||||||
|
"""Protocol for write streams."""
|
||||||
|
|
||||||
|
async def write(self, data: bytes) -> None:
|
||||||
|
"""Write bytes to stream."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTransport(ABC):
|
||||||
|
"""Base class for MCP transport implementations.
|
||||||
|
|
||||||
|
This abstract base class defines the interface that all transport
|
||||||
|
implementations must follow. Transports handle the low-level communication
|
||||||
|
with MCP servers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
"""Initialize the transport.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Transport-specific configuration options.
|
||||||
|
"""
|
||||||
|
self._read_stream: ReadStream | None = None
|
||||||
|
self._write_stream: WriteStream | None = None
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def transport_type(self) -> TransportType:
|
||||||
|
"""Return the transport type."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connected(self) -> bool:
|
||||||
|
"""Check if transport is connected."""
|
||||||
|
return self._connected
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_stream(self) -> ReadStream:
|
||||||
|
"""Get the read stream."""
|
||||||
|
if self._read_stream is None:
|
||||||
|
raise RuntimeError("Transport not connected. Call connect() first.")
|
||||||
|
return self._read_stream
|
||||||
|
|
||||||
|
@property
|
||||||
|
def write_stream(self) -> WriteStream:
|
||||||
|
"""Get the write stream."""
|
||||||
|
if self._write_stream is None:
|
||||||
|
raise RuntimeError("Transport not connected. Call connect() first.")
|
||||||
|
return self._write_stream
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def connect(self) -> Self:
|
||||||
|
"""Establish connection to MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Self for method chaining.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConnectionError: If connection fails.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Close connection to MCP server."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def __aenter__(self) -> Self:
|
||||||
|
"""Async context manager entry."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Async context manager exit."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def _set_streams(self, read: ReadStream, write: WriteStream) -> None:
|
||||||
|
"""Set the read and write streams.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
read: Read stream.
|
||||||
|
write: Write stream.
|
||||||
|
"""
|
||||||
|
self._read_stream = read
|
||||||
|
self._write_stream = write
|
||||||
|
self._connected = True
|
||||||
|
|
||||||
|
def _clear_streams(self) -> None:
|
||||||
|
"""Clear the read and write streams."""
|
||||||
|
self._read_stream = None
|
||||||
|
self._write_stream = None
|
||||||
|
self._connected = False
|
||||||
174
lib/crewai/src/crewai/mcp/transports/http.py
Normal file
174
lib/crewai/src/crewai/mcp/transports/http.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""HTTP and Streamable HTTP transport for MCP servers."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
# BaseExceptionGroup is available in Python 3.11+
|
||||||
|
try:
|
||||||
|
from builtins import BaseExceptionGroup
|
||||||
|
except ImportError:
|
||||||
|
# Fallback for Python < 3.11 (shouldn't happen in practice)
|
||||||
|
BaseExceptionGroup = Exception
|
||||||
|
|
||||||
|
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPTransport(BaseTransport):
|
||||||
|
"""HTTP/Streamable HTTP transport for connecting to remote MCP servers.
|
||||||
|
|
||||||
|
This transport connects to MCP servers over HTTP/HTTPS using the
|
||||||
|
streamable HTTP client from the MCP SDK.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
transport = HTTPTransport(
|
||||||
|
url="https://api.example.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer ..."}
|
||||||
|
)
|
||||||
|
async with transport:
|
||||||
|
# Use transport...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
streamable: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize HTTP transport.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Server URL (e.g., "https://api.example.com/mcp").
|
||||||
|
headers: Optional HTTP headers.
|
||||||
|
streamable: Whether to use streamable HTTP (default: True).
|
||||||
|
**kwargs: Additional transport options.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.url = url
|
||||||
|
self.headers = headers or {}
|
||||||
|
self.streamable = streamable
|
||||||
|
self._transport_context: Any = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def transport_type(self) -> TransportType:
|
||||||
|
"""Return the transport type."""
|
||||||
|
return TransportType.STREAMABLE_HTTP if self.streamable else TransportType.HTTP
|
||||||
|
|
||||||
|
async def connect(self) -> Self:
|
||||||
|
"""Establish HTTP connection to MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Self for method chaining.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConnectionError: If connection fails.
|
||||||
|
ImportError: If MCP SDK not available.
|
||||||
|
"""
|
||||||
|
if self._connected:
|
||||||
|
return self
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
|
||||||
|
self._transport_context = streamablehttp_client(
|
||||||
|
self.url,
|
||||||
|
headers=self.headers if self.headers else None,
|
||||||
|
terminate_on_close=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
read, write, _ = await asyncio.wait_for(
|
||||||
|
self._transport_context.__aenter__(), timeout=30.0
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as e:
|
||||||
|
self._transport_context = None
|
||||||
|
raise ConnectionError(
|
||||||
|
"Transport context entry timed out after 30 seconds. "
|
||||||
|
"Server may be slow or unreachable."
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
self._transport_context = None
|
||||||
|
raise ConnectionError(f"Failed to enter transport context: {e}") from e
|
||||||
|
self._set_streams(read=read, write=write)
|
||||||
|
return self
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"MCP library not available. Please install with: pip install mcp"
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
self._clear_streams()
|
||||||
|
if self._transport_context is not None:
|
||||||
|
self._transport_context = None
|
||||||
|
raise ConnectionError(f"Failed to connect to MCP server: {e}") from e
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Close HTTP connection."""
|
||||||
|
if not self._connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Clear streams first
|
||||||
|
self._clear_streams()
|
||||||
|
# await self._exit_stack.aclose()
|
||||||
|
|
||||||
|
# Exit transport context - this will clean up background tasks
|
||||||
|
# Give a small delay to allow background tasks to complete
|
||||||
|
if self._transport_context is not None:
|
||||||
|
try:
|
||||||
|
# Wait a tiny bit for any pending operations
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
await self._transport_context.__aexit__(None, None, None)
|
||||||
|
except (RuntimeError, asyncio.CancelledError) as e:
|
||||||
|
# Ignore "exit cancel scope in different task" errors and cancellation
|
||||||
|
# These happen when asyncio.run() closes the event loop
|
||||||
|
# while background tasks are still running
|
||||||
|
error_msg = str(e).lower()
|
||||||
|
if "cancel scope" not in error_msg and "task" not in error_msg:
|
||||||
|
# Only suppress cancel scope/task errors, re-raise others
|
||||||
|
if isinstance(e, RuntimeError):
|
||||||
|
raise
|
||||||
|
# For CancelledError, just suppress it
|
||||||
|
except BaseExceptionGroup as eg:
|
||||||
|
# Handle exception groups from anyio task groups
|
||||||
|
# Suppress if they contain cancel scope errors
|
||||||
|
should_suppress = False
|
||||||
|
for exc in eg.exceptions:
|
||||||
|
error_msg = str(exc).lower()
|
||||||
|
if "cancel scope" in error_msg or "task" in error_msg:
|
||||||
|
should_suppress = True
|
||||||
|
break
|
||||||
|
if not should_suppress:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Error during HTTP transport disconnect: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log but don't raise - cleanup should be best effort
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.warning(f"Error during HTTP transport disconnect: {e}")
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Self:
|
||||||
|
"""Async context manager entry."""
|
||||||
|
return await self.connect()
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Async context manager exit."""
|
||||||
|
|
||||||
|
await self.disconnect()
|
||||||
113
lib/crewai/src/crewai/mcp/transports/sse.py
Normal file
113
lib/crewai/src/crewai/mcp/transports/sse.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Server-Sent Events (SSE) transport for MCP servers."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||||
|
|
||||||
|
|
||||||
|
class SSETransport(BaseTransport):
|
||||||
|
"""SSE transport for connecting to remote MCP servers.
|
||||||
|
|
||||||
|
This transport connects to MCP servers using Server-Sent Events (SSE)
|
||||||
|
for real-time streaming communication.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
transport = SSETransport(
|
||||||
|
url="https://api.example.com/mcp/sse",
|
||||||
|
headers={"Authorization": "Bearer ..."}
|
||||||
|
)
|
||||||
|
async with transport:
|
||||||
|
# Use transport...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize SSE transport.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Server URL (e.g., "https://api.example.com/mcp/sse").
|
||||||
|
headers: Optional HTTP headers.
|
||||||
|
**kwargs: Additional transport options.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.url = url
|
||||||
|
self.headers = headers or {}
|
||||||
|
self._transport_context: Any = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def transport_type(self) -> TransportType:
|
||||||
|
"""Return the transport type."""
|
||||||
|
return TransportType.SSE
|
||||||
|
|
||||||
|
async def connect(self) -> Self:
|
||||||
|
"""Establish SSE connection to MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Self for method chaining.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConnectionError: If connection fails.
|
||||||
|
ImportError: If MCP SDK not available.
|
||||||
|
"""
|
||||||
|
if self._connected:
|
||||||
|
return self
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
|
||||||
|
self._transport_context = sse_client(
|
||||||
|
self.url,
|
||||||
|
headers=self.headers if self.headers else None,
|
||||||
|
terminate_on_close=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
read, write = await self._transport_context.__aenter__()
|
||||||
|
|
||||||
|
self._set_streams(read=read, write=write)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"MCP library not available. Please install with: pip install mcp"
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
self._clear_streams()
|
||||||
|
raise ConnectionError(f"Failed to connect to SSE MCP server: {e}") from e
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Close SSE connection."""
|
||||||
|
if not self._connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._clear_streams()
|
||||||
|
if self._transport_context is not None:
|
||||||
|
await self._transport_context.__aexit__(None, None, None)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.warning(f"Error during SSE transport disconnect: {e}")
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Self:
|
||||||
|
"""Async context manager entry."""
|
||||||
|
return await self.connect()
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Async context manager exit."""
|
||||||
|
await self.disconnect()
|
||||||
153
lib/crewai/src/crewai/mcp/transports/stdio.py
Normal file
153
lib/crewai/src/crewai/mcp/transports/stdio.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""Stdio transport for MCP servers running as local processes."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||||
|
|
||||||
|
|
||||||
|
class StdioTransport(BaseTransport):
|
||||||
|
"""Stdio transport for connecting to local MCP servers.
|
||||||
|
|
||||||
|
This transport connects to MCP servers running as local processes,
|
||||||
|
communicating via standard input/output streams. Supports Python,
|
||||||
|
Node.js, and other command-line servers.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
transport = StdioTransport(
|
||||||
|
command="python",
|
||||||
|
args=["path/to/server.py"],
|
||||||
|
env={"API_KEY": "..."}
|
||||||
|
)
|
||||||
|
async with transport:
|
||||||
|
# Use transport...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
args: list[str] | None = None,
|
||||||
|
env: dict[str, str] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize stdio transport.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: Command to execute (e.g., "python", "node", "npx").
|
||||||
|
args: Command arguments (e.g., ["server.py"] or ["-y", "@mcp/server"]).
|
||||||
|
env: Environment variables to pass to the process.
|
||||||
|
**kwargs: Additional transport options.
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.command = command
|
||||||
|
self.args = args or []
|
||||||
|
self.env = env or {}
|
||||||
|
self._process: subprocess.Popen[bytes] | None = None
|
||||||
|
self._transport_context: Any = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def transport_type(self) -> TransportType:
|
||||||
|
"""Return the transport type."""
|
||||||
|
return TransportType.STDIO
|
||||||
|
|
||||||
|
async def connect(self) -> Self:
|
||||||
|
"""Start the MCP server process and establish connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Self for method chaining.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConnectionError: If process fails to start.
|
||||||
|
ImportError: If MCP SDK not available.
|
||||||
|
"""
|
||||||
|
if self._connected:
|
||||||
|
return self
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mcp import StdioServerParameters
|
||||||
|
from mcp.client.stdio import stdio_client
|
||||||
|
|
||||||
|
process_env = os.environ.copy()
|
||||||
|
process_env.update(self.env)
|
||||||
|
|
||||||
|
server_params = StdioServerParameters(
|
||||||
|
command=self.command,
|
||||||
|
args=self.args,
|
||||||
|
env=process_env if process_env else None,
|
||||||
|
)
|
||||||
|
self._transport_context = stdio_client(server_params)
|
||||||
|
|
||||||
|
try:
|
||||||
|
read, write = await self._transport_context.__aenter__()
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
self._transport_context = None
|
||||||
|
raise ConnectionError(
|
||||||
|
f"Failed to enter stdio transport context: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
self._set_streams(read=read, write=write)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"MCP library not available. Please install with: pip install mcp"
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
self._clear_streams()
|
||||||
|
if self._transport_context is not None:
|
||||||
|
self._transport_context = None
|
||||||
|
raise ConnectionError(f"Failed to start MCP server process: {e}") from e
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Terminate the MCP server process and close connection."""
|
||||||
|
if not self._connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._clear_streams()
|
||||||
|
|
||||||
|
if self._transport_context is not None:
|
||||||
|
await self._transport_context.__aexit__(None, None, None)
|
||||||
|
|
||||||
|
if self._process is not None:
|
||||||
|
try:
|
||||||
|
self._process.terminate()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._process.kill()
|
||||||
|
await self._process.wait()
|
||||||
|
# except ProcessLookupError:
|
||||||
|
# pass
|
||||||
|
finally:
|
||||||
|
self._process = None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log but don't raise - cleanup should be best effort
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.warning(f"Error during stdio transport disconnect: {e}")
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Self:
|
||||||
|
"""Async context manager entry."""
|
||||||
|
return await self.connect()
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Async context manager exit."""
|
||||||
|
await self.disconnect()
|
||||||
154
lib/crewai/src/crewai/tools/mcp_native_tool.py
Normal file
154
lib/crewai/src/crewai/tools/mcp_native_tool.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
"""Native MCP tool wrapper for CrewAI agents.
|
||||||
|
|
||||||
|
This module provides a tool wrapper that reuses existing MCP client sessions
|
||||||
|
for better performance and connection management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class MCPNativeTool(BaseTool):
|
||||||
|
"""Native MCP tool that reuses client sessions.
|
||||||
|
|
||||||
|
This tool wrapper is used when agents connect to MCP servers using
|
||||||
|
structured configurations. It reuses existing client sessions for
|
||||||
|
better performance and proper connection lifecycle management.
|
||||||
|
|
||||||
|
Unlike MCPToolWrapper which connects on-demand, this tool uses
|
||||||
|
a shared MCP client instance that maintains a persistent connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_client: Any,
|
||||||
|
tool_name: str,
|
||||||
|
tool_schema: dict[str, Any],
|
||||||
|
server_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize native MCP tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_client: MCPClient instance with active session.
|
||||||
|
tool_name: Original name of the tool on the MCP server.
|
||||||
|
tool_schema: Schema information for the tool.
|
||||||
|
server_name: Name of the MCP server for prefixing.
|
||||||
|
"""
|
||||||
|
# Create tool name with server prefix to avoid conflicts
|
||||||
|
prefixed_name = f"{server_name}_{tool_name}"
|
||||||
|
|
||||||
|
# Handle args_schema properly - BaseTool expects a BaseModel subclass
|
||||||
|
args_schema = tool_schema.get("args_schema")
|
||||||
|
|
||||||
|
# Only pass args_schema if it's provided
|
||||||
|
kwargs = {
|
||||||
|
"name": prefixed_name,
|
||||||
|
"description": tool_schema.get(
|
||||||
|
"description", f"Tool {tool_name} from {server_name}"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
if args_schema is not None:
|
||||||
|
kwargs["args_schema"] = args_schema
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
# Set instance attributes after super().__init__
|
||||||
|
self._mcp_client = mcp_client
|
||||||
|
self._original_tool_name = tool_name
|
||||||
|
self._server_name = server_name
|
||||||
|
# self._logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mcp_client(self) -> Any:
|
||||||
|
"""Get the MCP client instance."""
|
||||||
|
return self._mcp_client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def original_tool_name(self) -> str:
|
||||||
|
"""Get the original tool name."""
|
||||||
|
return self._original_tool_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def server_name(self) -> str:
|
||||||
|
"""Get the server name."""
|
||||||
|
return self._server_name
|
||||||
|
|
||||||
|
def _run(self, **kwargs) -> str:
|
||||||
|
"""Execute tool using the MCP client session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Arguments to pass to the MCP tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result from the MCP tool execution.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Always use asyncio.run() to create a fresh event loop
|
||||||
|
# This ensures the async context managers work correctly
|
||||||
|
return asyncio.run(self._run_async(**kwargs))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
async def _run_async(self, **kwargs) -> str:
|
||||||
|
"""Async implementation of tool execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Arguments to pass to the MCP tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result from the MCP tool execution.
|
||||||
|
"""
|
||||||
|
# Note: Since we use asyncio.run() which creates a new event loop each time,
|
||||||
|
# Always reconnect on-demand because asyncio.run() creates new event loops per call
|
||||||
|
# All MCP transport context managers (stdio, streamablehttp_client, sse_client)
|
||||||
|
# use anyio.create_task_group() which can't span different event loops
|
||||||
|
if self._mcp_client.connected:
|
||||||
|
await self._mcp_client.disconnect()
|
||||||
|
|
||||||
|
await self._mcp_client.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self._mcp_client.call_tool(self.original_tool_name, kwargs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_str = str(e).lower()
|
||||||
|
if (
|
||||||
|
"not connected" in error_str
|
||||||
|
or "connection" in error_str
|
||||||
|
or "send" in error_str
|
||||||
|
):
|
||||||
|
await self._mcp_client.disconnect()
|
||||||
|
await self._mcp_client.connect()
|
||||||
|
# Retry the call
|
||||||
|
result = await self._mcp_client.call_tool(
|
||||||
|
self.original_tool_name, kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Always disconnect after tool call to ensure clean context manager lifecycle
|
||||||
|
# This prevents "exit cancel scope in different task" errors
|
||||||
|
# All transport context managers must be exited in the same event loop they were entered
|
||||||
|
await self._mcp_client.disconnect()
|
||||||
|
|
||||||
|
# Extract result content
|
||||||
|
if isinstance(result, str):
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Handle various result formats
|
||||||
|
if hasattr(result, "content") and result.content:
|
||||||
|
if isinstance(result.content, list) and len(result.content) > 0:
|
||||||
|
content_item = result.content[0]
|
||||||
|
if hasattr(content_item, "text"):
|
||||||
|
return str(content_item.text)
|
||||||
|
return str(content_item)
|
||||||
|
return str(result.content)
|
||||||
|
|
||||||
|
return str(result)
|
||||||
4
lib/crewai/tests/mcp/__init__.py
Normal file
4
lib/crewai/tests/mcp/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
"""Tests for MCP (Model Context Protocol) integration."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
136
lib/crewai/tests/mcp/test_mcp_config.py
Normal file
136
lib/crewai/tests/mcp/test_mcp_config.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from crewai.agent.core import Agent
|
||||||
|
from crewai.mcp.config import MCPServerHTTP, MCPServerSSE, MCPServerStdio
|
||||||
|
from crewai.tools.base_tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tool_definitions():
|
||||||
|
"""Create mock MCP tool definitions (as returned by list_tools)."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "test_tool_1",
|
||||||
|
"description": "Test tool 1 description",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "Search query"}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "test_tool_2",
|
||||||
|
"description": "Test tool 2 description",
|
||||||
|
"inputSchema": {}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_with_stdio_mcp_config(mock_tool_definitions):
|
||||||
|
"""Test agent setup with MCPServerStdio configuration."""
|
||||||
|
stdio_config = MCPServerStdio(
|
||||||
|
command="python",
|
||||||
|
args=["server.py"],
|
||||||
|
env={"API_KEY": "test_key"},
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcps=[stdio_config],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||||
|
mock_client.connected = False # Will trigger connect
|
||||||
|
mock_client.connect = AsyncMock()
|
||||||
|
mock_client.disconnect = AsyncMock()
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
tools = agent.get_mcp_tools([stdio_config])
|
||||||
|
|
||||||
|
assert len(tools) == 2
|
||||||
|
assert all(isinstance(tool, BaseTool) for tool in tools)
|
||||||
|
|
||||||
|
mock_client_class.assert_called_once()
|
||||||
|
call_args = mock_client_class.call_args
|
||||||
|
transport = call_args.kwargs["transport"]
|
||||||
|
assert transport.command == "python"
|
||||||
|
assert transport.args == ["server.py"]
|
||||||
|
assert transport.env == {"API_KEY": "test_key"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_with_http_mcp_config(mock_tool_definitions):
|
||||||
|
"""Test agent setup with MCPServerHTTP configuration."""
|
||||||
|
http_config = MCPServerHTTP(
|
||||||
|
url="https://api.example.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer test_token"},
|
||||||
|
streamable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcps=[http_config],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||||
|
mock_client.connected = False # Will trigger connect
|
||||||
|
mock_client.connect = AsyncMock()
|
||||||
|
mock_client.disconnect = AsyncMock()
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
tools = agent.get_mcp_tools([http_config])
|
||||||
|
|
||||||
|
assert len(tools) == 2
|
||||||
|
assert all(isinstance(tool, BaseTool) for tool in tools)
|
||||||
|
|
||||||
|
mock_client_class.assert_called_once()
|
||||||
|
call_args = mock_client_class.call_args
|
||||||
|
transport = call_args.kwargs["transport"]
|
||||||
|
assert transport.url == "https://api.example.com/mcp"
|
||||||
|
assert transport.headers == {"Authorization": "Bearer test_token"}
|
||||||
|
assert transport.streamable is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_with_sse_mcp_config(mock_tool_definitions):
|
||||||
|
"""Test agent setup with MCPServerSSE configuration."""
|
||||||
|
sse_config = MCPServerSSE(
|
||||||
|
url="https://api.example.com/mcp/sse",
|
||||||
|
headers={"Authorization": "Bearer test_token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test goal",
|
||||||
|
backstory="Test backstory",
|
||||||
|
mcps=[sse_config],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("crewai.agent.core.MCPClient") as mock_client_class:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
|
||||||
|
mock_client.connected = False
|
||||||
|
mock_client.connect = AsyncMock()
|
||||||
|
mock_client.disconnect = AsyncMock()
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
tools = agent.get_mcp_tools([sse_config])
|
||||||
|
|
||||||
|
assert len(tools) == 2
|
||||||
|
assert all(isinstance(tool, BaseTool) for tool in tools)
|
||||||
|
|
||||||
|
mock_client_class.assert_called_once()
|
||||||
|
call_args = mock_client_class.call_args
|
||||||
|
transport = call_args.kwargs["transport"]
|
||||||
|
assert transport.url == "https://api.example.com/mcp/sse"
|
||||||
|
assert transport.headers == {"Authorization": "Bearer test_token"}
|
||||||
Reference in New Issue
Block a user