diff --git a/docs/en/mcp/overview.mdx b/docs/en/mcp/overview.mdx index 63bdab5d6..d8eb2743c 100644 --- a/docs/en/mcp/overview.mdx +++ b/docs/en/mcp/overview.mdx @@ -11,9 +11,13 @@ The [Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP) CrewAI offers **two approaches** for MCP integration: -### Simple DSL Integration** (Recommended) +### 🚀 **Simple DSL Integration** (Recommended) -Use the `mcps` field directly on agents for seamless MCP tool integration: +Use the `mcps` field directly on agents for seamless MCP tool integration. The DSL supports both **string references** (for quick setup) and **structured configurations** (for full control). + +#### String-Based References (Quick Setup) + +Perfect for remote HTTPS servers and CrewAI AMP marketplace: ```python from crewai import Agent @@ -32,6 +36,46 @@ agent = Agent( # MCP tools are now automatically available to your agent! ``` +#### Structured Configurations (Full Control) + +For complete control over connection settings, tool filtering, and all transport types: + +```python +from crewai import Agent +from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE +from crewai.mcp.filters import create_static_tool_filter + +agent = Agent( + role="Advanced Research Analyst", + goal="Research with full control over MCP connections", + backstory="Expert researcher with advanced tool access", + mcps=[ + # Stdio transport for local servers + MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + env={"API_KEY": "your_key"}, + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "list_directory"] + ), + cache_tools_list=True, + ), + # HTTP/Streamable HTTP transport for remote servers + MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer your_token"}, + streamable=True, + cache_tools_list=True, + ), + # SSE transport for real-time streaming + MCPServerSSE( + url="https://stream.example.com/mcp/sse", + headers={"Authorization": "Bearer your_token"}, + ), + ] +) +``` + ### 🔧 **Advanced: MCPServerAdapter** (For Complex Scenarios) For advanced use cases requiring manual connection management, the `crewai-tools` library provides the `MCPServerAdapter` class. @@ -68,12 +112,14 @@ uv pip install 'crewai-tools[mcp]' ## Quick Start: Simple DSL Integration -The easiest way to integrate MCP servers is using the `mcps` field on your agents: +The easiest way to integrate MCP servers is using the `mcps` field on your agents. You can use either string references or structured configurations. + +### Quick Start with String References ```python from crewai import Agent, Task, Crew -# Create agent with MCP tools +# Create agent with MCP tools using string references research_agent = Agent( role="Research Analyst", goal="Find and analyze information using advanced search tools", @@ -96,13 +142,53 @@ crew = Crew(agents=[research_agent], tasks=[research_task]) result = crew.kickoff() ``` +### Quick Start with Structured Configurations + +```python +from crewai import Agent, Task, Crew +from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE + +# Create agent with structured MCP configurations +research_agent = Agent( + role="Research Analyst", + goal="Find and analyze information using advanced search tools", + backstory="Expert researcher with access to multiple data sources", + mcps=[ + # Local stdio server + MCPServerStdio( + command="python", + args=["local_server.py"], + env={"API_KEY": "your_key"}, + ), + # Remote HTTP server + MCPServerHTTP( + url="https://api.research.com/mcp", + headers={"Authorization": "Bearer your_token"}, + ), + ] +) + +# Create task +research_task = Task( + description="Research the latest developments in AI agent frameworks", + expected_output="Comprehensive research report with citations", + agent=research_agent +) + +# Create and run crew +crew = Crew(agents=[research_agent], tasks=[research_task]) +result = crew.kickoff() +``` + That's it! The MCP tools are automatically discovered and available to your agent. ## MCP Reference Formats -The `mcps` field supports various reference formats for maximum flexibility: +The `mcps` field supports both **string references** (for quick setup) and **structured configurations** (for full control). You can mix both formats in the same list. -### External MCP Servers +### String-Based References + +#### External MCP Servers ```python mcps=[ @@ -117,7 +203,7 @@ mcps=[ ] ``` -### CrewAI AMP Marketplace +#### CrewAI AMP Marketplace ```python mcps=[ @@ -133,17 +219,166 @@ mcps=[ ] ``` -### Mixed References +### Structured Configurations + +#### Stdio Transport (Local Servers) + +Perfect for local MCP servers that run as processes: ```python +from crewai.mcp import MCPServerStdio +from crewai.mcp.filters import create_static_tool_filter + mcps=[ - "https://external-api.com/mcp", # External server - "https://weather.service.com/mcp#forecast", # Specific external tool - "crewai-amp:financial-insights", # AMP service - "crewai-amp:data-analysis#sentiment_tool" # Specific AMP tool + MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + env={"API_KEY": "your_key"}, + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"] + ), + cache_tools_list=True, + ), + # Python-based server + MCPServerStdio( + command="python", + args=["path/to/server.py"], + env={"UV_PYTHON": "3.12", "API_KEY": "your_key"}, + ), ] ``` +#### HTTP/Streamable HTTP Transport (Remote Servers) + +For remote MCP servers over HTTP/HTTPS: + +```python +from crewai.mcp import MCPServerHTTP + +mcps=[ + # Streamable HTTP (default) + MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer your_token"}, + streamable=True, + cache_tools_list=True, + ), + # Standard HTTP + MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer your_token"}, + streamable=False, + ), +] +``` + +#### SSE Transport (Real-Time Streaming) + +For remote servers using Server-Sent Events: + +```python +from crewai.mcp import MCPServerSSE + +mcps=[ + MCPServerSSE( + url="https://stream.example.com/mcp/sse", + headers={"Authorization": "Bearer your_token"}, + cache_tools_list=True, + ), +] +``` + +### Mixed References + +You can combine string references and structured configurations: + +```python +from crewai.mcp import MCPServerStdio, MCPServerHTTP + +mcps=[ + # String references + "https://external-api.com/mcp", # External server + "crewai-amp:financial-insights", # AMP service + + # Structured configurations + MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + ), + MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer token"}, + ), +] +``` + +### Tool Filtering + +Structured configurations support advanced tool filtering: + +```python +from crewai.mcp import MCPServerStdio +from crewai.mcp.filters import create_static_tool_filter, create_dynamic_tool_filter, ToolFilterContext + +# Static filtering (allow/block lists) +static_filter = create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"], + blocked_tool_names=["delete_file"], +) + +# Dynamic filtering (context-aware) +def dynamic_filter(context: ToolFilterContext, tool: dict) -> bool: + # Block dangerous tools for certain agent roles + if context.agent.role == "Code Reviewer": + if "delete" in tool.get("name", "").lower(): + return False + return True + +mcps=[ + MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + tool_filter=static_filter, # or dynamic_filter + ), +] +``` + +## Configuration Parameters + +Each transport type supports specific configuration options: + +### MCPServerStdio Parameters + +- **`command`** (required): Command to execute (e.g., `"python"`, `"node"`, `"npx"`, `"uvx"`) +- **`args`** (optional): List of command arguments (e.g., `["server.py"]` or `["-y", "@mcp/server"]`) +- **`env`** (optional): Dictionary of environment variables to pass to the process +- **`tool_filter`** (optional): Tool filter function for filtering available tools +- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`) + +### MCPServerHTTP Parameters + +- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp"`) +- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes +- **`streamable`** (optional): Whether to use streamable HTTP transport (default: `True`) +- **`tool_filter`** (optional): Tool filter function for filtering available tools +- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`) + +### MCPServerSSE Parameters + +- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp/sse"`) +- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes +- **`tool_filter`** (optional): Tool filter function for filtering available tools +- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`) + +### Common Parameters + +All transport types support: +- **`tool_filter`**: Filter function to control which tools are available. Can be: + - `None` (default): All tools are available + - Static filter: Created with `create_static_tool_filter()` for allow/block lists + - Dynamic filter: Created with `create_dynamic_tool_filter()` for context-aware filtering +- **`cache_tools_list`**: When `True`, caches the tool list after first discovery to improve performance on subsequent connections + ## Key Features - 🔄 **Automatic Tool Discovery**: Tools are automatically discovered and integrated @@ -152,26 +387,47 @@ mcps=[ - 🛡️ **Error Resilience**: Graceful handling of unavailable servers - ⏱️ **Timeout Protection**: Built-in timeouts prevent hanging connections - 📊 **Transparent Integration**: Works seamlessly with existing CrewAI features +- 🔧 **Full Transport Support**: Stdio, HTTP/Streamable HTTP, and SSE transports +- 🎯 **Advanced Filtering**: Static and dynamic tool filtering capabilities +- 🔐 **Flexible Authentication**: Support for headers, environment variables, and query parameters ## Error Handling -The MCP DSL integration is designed to be resilient: +The MCP DSL integration is designed to be resilient and handles failures gracefully: ```python +from crewai import Agent +from crewai.mcp import MCPServerStdio, MCPServerHTTP + agent = Agent( role="Resilient Agent", goal="Continue working despite server issues", backstory="Agent that handles failures gracefully", mcps=[ + # String references "https://reliable-server.com/mcp", # Will work "https://unreachable-server.com/mcp", # Will be skipped gracefully - "https://slow-server.com/mcp", # Will timeout gracefully - "crewai-amp:working-service" # Will work + "crewai-amp:working-service", # Will work + + # Structured configs + MCPServerStdio( + command="python", + args=["reliable_server.py"], # Will work + ), + MCPServerHTTP( + url="https://slow-server.com/mcp", # Will timeout gracefully + ), ] ) # Agent will use tools from working servers and log warnings for failing ones ``` +All connection errors are handled gracefully: +- **Connection failures**: Logged as warnings, agent continues with available tools +- **Timeout errors**: Connections timeout after 30 seconds (configurable) +- **Authentication errors**: Logged clearly for debugging +- **Invalid configurations**: Validation errors are raised at agent creation time + ## Advanced: MCPServerAdapter For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server. diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 3e925cef6..54ae52ba8 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -40,6 +40,16 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context from crewai.lite_agent import LiteAgent from crewai.llms.base_llm import BaseLLM +from crewai.mcp import ( + MCPClient, + MCPServerConfig, + MCPServerHTTP, + MCPServerSSE, + MCPServerStdio, +) +from crewai.mcp.transports.http import HTTPTransport +from crewai.mcp.transports.sse import SSETransport +from crewai.mcp.transports.stdio import StdioTransport from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.fingerprint import Fingerprint @@ -108,6 +118,7 @@ class Agent(BaseAgent): """ _times_executed: int = PrivateAttr(default=0) + _mcp_clients: list[Any] = PrivateAttr(default_factory=list) max_execution_time: int | None = Field( default=None, description="Maximum execution time for an agent to execute a task", @@ -526,6 +537,9 @@ class Agent(BaseAgent): self, event=AgentExecutionCompletedEvent(agent=self, task=task, output=result), ) + + self._cleanup_mcp_clients() + return result def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any: @@ -649,30 +663,70 @@ class Agent(BaseAgent): self._logger.log("error", f"Error getting platform tools: {e!s}") return [] - def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]: - """Convert MCP server references to CrewAI tools.""" + def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]: + """Convert MCP server references/configs to CrewAI tools. + + Supports both string references (backwards compatible) and structured + configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE). + + Args: + mcps: List of MCP server references (strings) or configurations. + + Returns: + List of BaseTool instances from MCP servers. + """ all_tools = [] + clients = [] - for mcp_ref in mcps: - try: - if mcp_ref.startswith("crewai-amp:"): - tools = self._get_amp_mcp_tools(mcp_ref) - elif mcp_ref.startswith("https://"): - tools = self._get_external_mcp_tools(mcp_ref) - else: - continue + for mcp_config in mcps: + if isinstance(mcp_config, str): + tools = self._get_mcp_tools_from_string(mcp_config) + else: + tools, client = self._get_native_mcp_tools(mcp_config) + if client: + clients.append(client) - all_tools.extend(tools) - self._logger.log( - "info", f"Successfully loaded {len(tools)} tools from {mcp_ref}" - ) - - except Exception as e: - self._logger.log("warning", f"Skipping MCP {mcp_ref} due to error: {e}") - continue + all_tools.extend(tools) + # Store clients for cleanup + self._mcp_clients.extend(clients) return all_tools + def _cleanup_mcp_clients(self) -> None: + """Cleanup MCP client connections after task execution.""" + if not self._mcp_clients: + return + + async def _disconnect_all() -> None: + for client in self._mcp_clients: + if client and hasattr(client, "connected") and client.connected: + await client.disconnect() + + try: + asyncio.run(_disconnect_all()) + except Exception as e: + self._logger.log("error", f"Error during MCP client cleanup: {e}") + finally: + self._mcp_clients.clear() + + def _get_mcp_tools_from_string(self, mcp_ref: str) -> list[BaseTool]: + """Get tools from legacy string-based MCP references. + + This method maintains backwards compatibility with string-based + MCP references (https://... and crewai-amp:...). + + Args: + mcp_ref: String reference to MCP server. + + Returns: + List of BaseTool instances. + """ + if mcp_ref.startswith("crewai-amp:"): + return self._get_amp_mcp_tools(mcp_ref) + if mcp_ref.startswith("https://"): + return self._get_external_mcp_tools(mcp_ref) + return [] + def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]: """Get tools from external HTTPS MCP server with graceful error handling.""" from crewai.tools.mcp_tool_wrapper import MCPToolWrapper @@ -731,6 +785,154 @@ class Agent(BaseAgent): ) return [] + def _get_native_mcp_tools( + self, mcp_config: MCPServerConfig + ) -> tuple[list[BaseTool], Any | None]: + """Get tools from MCP server using structured configuration. + + This method creates an MCP client based on the configuration type, + connects to the server, discovers tools, applies filtering, and + returns wrapped tools along with the client instance for cleanup. + + Args: + mcp_config: MCP server configuration (MCPServerStdio, MCPServerHTTP, or MCPServerSSE). + + Returns: + Tuple of (list of BaseTool instances, MCPClient instance for cleanup). + """ + from crewai.tools.base_tool import BaseTool + from crewai.tools.mcp_native_tool import MCPNativeTool + + if isinstance(mcp_config, MCPServerStdio): + transport = StdioTransport( + command=mcp_config.command, + args=mcp_config.args, + env=mcp_config.env, + ) + server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}" + elif isinstance(mcp_config, MCPServerHTTP): + transport = HTTPTransport( + url=mcp_config.url, + headers=mcp_config.headers, + streamable=mcp_config.streamable, + ) + server_name = self._extract_server_name(mcp_config.url) + elif isinstance(mcp_config, MCPServerSSE): + transport = SSETransport( + url=mcp_config.url, + headers=mcp_config.headers, + ) + server_name = self._extract_server_name(mcp_config.url) + else: + raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}") + + client = MCPClient( + transport=transport, + cache_tools_list=mcp_config.cache_tools_list, + ) + + async def _setup_client_and_list_tools() -> list[dict[str, Any]]: + """Async helper to connect and list tools in same event loop.""" + + try: + if not client.connected: + await client.connect() + + tools_list = await client.list_tools() + + try: + await client.disconnect() + # Small delay to allow background tasks to finish cleanup + # This helps prevent "cancel scope in different task" errors + # when asyncio.run() closes the event loop + await asyncio.sleep(0.1) + except Exception as e: + self._logger.log("error", f"Error during disconnect: {e}") + + return tools_list + except Exception as e: + if client.connected: + await client.disconnect() + await asyncio.sleep(0.1) + raise RuntimeError( + f"Error during setup client and list tools: {e}" + ) from e + + try: + try: + tools_list = asyncio.run(_setup_client_and_list_tools()) + except RuntimeError as e: + error_msg = str(e).lower() + if "cancel scope" in error_msg or "task" in error_msg: + raise ConnectionError( + "MCP connection failed due to event loop cleanup issues. " + "This may be due to authentication errors or server unavailability." + ) from e + except asyncio.CancelledError as e: + raise ConnectionError( + "MCP connection was cancelled. This may indicate an authentication " + "error or server unavailability." + ) from e + + if mcp_config.tool_filter: + filtered_tools = [] + for tool in tools_list: + if callable(mcp_config.tool_filter): + try: + from crewai.mcp.filters import ToolFilterContext + + context = ToolFilterContext( + agent=self, + server_name=server_name, + run_context=None, + ) + if mcp_config.tool_filter(context, tool): + filtered_tools.append(tool) + except (TypeError, AttributeError): + if mcp_config.tool_filter(tool): + filtered_tools.append(tool) + else: + # Not callable - include tool + filtered_tools.append(tool) + tools_list = filtered_tools + + tools = [] + for tool_def in tools_list: + tool_name = tool_def.get("name", "") + if not tool_name: + continue + + # Convert inputSchema to Pydantic model if present + args_schema = None + if tool_def.get("inputSchema"): + args_schema = self._json_schema_to_pydantic( + tool_name, tool_def["inputSchema"] + ) + + tool_schema = { + "description": tool_def.get("description", ""), + "args_schema": args_schema, + } + + try: + native_tool = MCPNativeTool( + mcp_client=client, + tool_name=tool_name, + tool_schema=tool_schema, + server_name=server_name, + ) + tools.append(native_tool) + except Exception as e: + self._logger.log("error", f"Failed to create native MCP tool: {e}") + continue + + return cast(list[BaseTool], tools), client + except Exception as e: + if client.connected: + asyncio.run(client.disconnect()) + + raise RuntimeError(f"Failed to get native MCP tools: {e}") from e + def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]: """Get tools from CrewAI AMP MCP marketplace.""" # Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name" diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index b26c24515..932c98611 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -25,6 +25,7 @@ from crewai.agents.tools_handler import ToolsHandler from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge_config import KnowledgeConfig from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource +from crewai.mcp.config import MCPServerConfig from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig from crewai.tools.base_tool import BaseTool, Tool @@ -194,7 +195,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=None, description="List of applications or application/action combinations that the agent can access through CrewAI Platform. Can contain app names (e.g., 'gmail') or specific actions (e.g., 'gmail/send_email')", ) - mcps: list[str] | None = Field( + mcps: list[str | MCPServerConfig] | None = Field( default=None, description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.", ) @@ -253,20 +254,36 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): @field_validator("mcps") @classmethod - def validate_mcps(cls, mcps: list[str] | None) -> list[str] | None: + def validate_mcps( + cls, mcps: list[str | MCPServerConfig] | None + ) -> list[str | MCPServerConfig] | None: + """Validate MCP server references and configurations. + + Supports both string references (for backwards compatibility) and + structured configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE). + """ if not mcps: return mcps validated_mcps = [] for mcp in mcps: - if mcp.startswith(("https://", "crewai-amp:")): + if isinstance(mcp, str): + if mcp.startswith(("https://", "crewai-amp:")): + validated_mcps.append(mcp) + else: + raise ValueError( + f"Invalid MCP reference: {mcp}. " + "String references must start with 'https://' or 'crewai-amp:'" + ) + + elif isinstance(mcp, (MCPServerConfig)): validated_mcps.append(mcp) else: raise ValueError( - f"Invalid MCP reference: {mcp}. Must start with 'https://' or 'crewai-amp:'" + f"Invalid MCP configuration: {type(mcp)}. " + "Must be a string reference or MCPServerConfig instance." ) - - return list(set(validated_mcps)) + return validated_mcps @model_validator(mode="after") def validate_and_set_attributes(self) -> Self: @@ -343,7 +360,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): """Get platform tools for the specified list of applications and/or application/action combinations.""" @abstractmethod - def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]: + def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]: """Get MCP tools for the specified list of MCP server references.""" def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel" diff --git a/lib/crewai/src/crewai/events/__init__.py b/lib/crewai/src/crewai/events/__init__.py index 66c441bed..4147965e1 100644 --- a/lib/crewai/src/crewai/events/__init__.py +++ b/lib/crewai/src/crewai/events/__init__.py @@ -16,7 +16,6 @@ from crewai.events.base_event_listener import BaseEventListener from crewai.events.depends import Depends from crewai.events.event_bus import crewai_event_bus from crewai.events.handler_graph import CircularDependencyError - from crewai.events.types.crew_events import ( CrewKickoffCompletedEvent, CrewKickoffFailedEvent, @@ -61,6 +60,14 @@ from crewai.events.types.logging_events import ( AgentLogsExecutionEvent, AgentLogsStartedEvent, ) +from crewai.events.types.mcp_events import ( + MCPConnectionCompletedEvent, + MCPConnectionFailedEvent, + MCPConnectionStartedEvent, + MCPToolExecutionCompletedEvent, + MCPToolExecutionFailedEvent, + MCPToolExecutionStartedEvent, +) from crewai.events.types.memory_events import ( MemoryQueryCompletedEvent, MemoryQueryFailedEvent, @@ -153,6 +160,12 @@ __all__ = [ "LiteAgentExecutionCompletedEvent", "LiteAgentExecutionErrorEvent", "LiteAgentExecutionStartedEvent", + "MCPConnectionCompletedEvent", + "MCPConnectionFailedEvent", + "MCPConnectionStartedEvent", + "MCPToolExecutionCompletedEvent", + "MCPToolExecutionFailedEvent", + "MCPToolExecutionStartedEvent", "MemoryQueryCompletedEvent", "MemoryQueryFailedEvent", "MemoryQueryStartedEvent", diff --git a/lib/crewai/src/crewai/events/event_listener.py b/lib/crewai/src/crewai/events/event_listener.py index 6bb604b64..e07ee193c 100644 --- a/lib/crewai/src/crewai/events/event_listener.py +++ b/lib/crewai/src/crewai/events/event_listener.py @@ -65,6 +65,14 @@ from crewai.events.types.logging_events import ( AgentLogsExecutionEvent, AgentLogsStartedEvent, ) +from crewai.events.types.mcp_events import ( + MCPConnectionCompletedEvent, + MCPConnectionFailedEvent, + MCPConnectionStartedEvent, + MCPToolExecutionCompletedEvent, + MCPToolExecutionFailedEvent, + MCPToolExecutionStartedEvent, +) from crewai.events.types.reasoning_events import ( AgentReasoningCompletedEvent, AgentReasoningFailedEvent, @@ -615,5 +623,67 @@ class EventListener(BaseEventListener): event.total_turns, ) + # ----------- MCP EVENTS ----------- + + @crewai_event_bus.on(MCPConnectionStartedEvent) + def on_mcp_connection_started(source, event: MCPConnectionStartedEvent): + self.formatter.handle_mcp_connection_started( + event.server_name, + event.server_url, + event.transport_type, + event.is_reconnect, + event.connect_timeout, + ) + + @crewai_event_bus.on(MCPConnectionCompletedEvent) + def on_mcp_connection_completed(source, event: MCPConnectionCompletedEvent): + self.formatter.handle_mcp_connection_completed( + event.server_name, + event.server_url, + event.transport_type, + event.connection_duration_ms, + event.is_reconnect, + ) + + @crewai_event_bus.on(MCPConnectionFailedEvent) + def on_mcp_connection_failed(source, event: MCPConnectionFailedEvent): + self.formatter.handle_mcp_connection_failed( + event.server_name, + event.server_url, + event.transport_type, + event.error, + event.error_type, + ) + + @crewai_event_bus.on(MCPToolExecutionStartedEvent) + def on_mcp_tool_execution_started(source, event: MCPToolExecutionStartedEvent): + self.formatter.handle_mcp_tool_execution_started( + event.server_name, + event.tool_name, + event.tool_args, + ) + + @crewai_event_bus.on(MCPToolExecutionCompletedEvent) + def on_mcp_tool_execution_completed( + source, event: MCPToolExecutionCompletedEvent + ): + self.formatter.handle_mcp_tool_execution_completed( + event.server_name, + event.tool_name, + event.tool_args, + event.result, + event.execution_duration_ms, + ) + + @crewai_event_bus.on(MCPToolExecutionFailedEvent) + def on_mcp_tool_execution_failed(source, event: MCPToolExecutionFailedEvent): + self.formatter.handle_mcp_tool_execution_failed( + event.server_name, + event.tool_name, + event.tool_args, + event.error, + event.error_type, + ) + event_listener = EventListener() diff --git a/lib/crewai/src/crewai/events/event_types.py b/lib/crewai/src/crewai/events/event_types.py index f7a4d1f72..ea00aa9ae 100644 --- a/lib/crewai/src/crewai/events/event_types.py +++ b/lib/crewai/src/crewai/events/event_types.py @@ -40,6 +40,14 @@ from crewai.events.types.llm_guardrail_events import ( LLMGuardrailCompletedEvent, LLMGuardrailStartedEvent, ) +from crewai.events.types.mcp_events import ( + MCPConnectionCompletedEvent, + MCPConnectionFailedEvent, + MCPConnectionStartedEvent, + MCPToolExecutionCompletedEvent, + MCPToolExecutionFailedEvent, + MCPToolExecutionStartedEvent, +) from crewai.events.types.memory_events import ( MemoryQueryCompletedEvent, MemoryQueryFailedEvent, @@ -115,4 +123,10 @@ EventTypes = ( | MemoryQueryFailedEvent | MemoryRetrievalStartedEvent | MemoryRetrievalCompletedEvent + | MCPConnectionStartedEvent + | MCPConnectionCompletedEvent + | MCPConnectionFailedEvent + | MCPToolExecutionStartedEvent + | MCPToolExecutionCompletedEvent + | MCPToolExecutionFailedEvent ) diff --git a/lib/crewai/src/crewai/events/types/mcp_events.py b/lib/crewai/src/crewai/events/types/mcp_events.py new file mode 100644 index 000000000..d360aa62a --- /dev/null +++ b/lib/crewai/src/crewai/events/types/mcp_events.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import Any + +from crewai.events.base_events import BaseEvent + + +class MCPEvent(BaseEvent): + """Base event for MCP operations.""" + + server_name: str + server_url: str | None = None + transport_type: str | None = None # "stdio", "http", "sse" + agent_id: str | None = None + agent_role: str | None = None + from_agent: Any | None = None + from_task: Any | None = None + + def __init__(self, **data): + super().__init__(**data) + self._set_agent_params(data) + self._set_task_params(data) + + +class MCPConnectionStartedEvent(MCPEvent): + """Event emitted when starting to connect to an MCP server.""" + + type: str = "mcp_connection_started" + connect_timeout: int | None = None + is_reconnect: bool = ( + False # True if this is a reconnection, False for first connection + ) + + +class MCPConnectionCompletedEvent(MCPEvent): + """Event emitted when successfully connected to an MCP server.""" + + type: str = "mcp_connection_completed" + started_at: datetime | None = None + completed_at: datetime | None = None + connection_duration_ms: float | None = None + is_reconnect: bool = ( + False # True if this was a reconnection, False for first connection + ) + + +class MCPConnectionFailedEvent(MCPEvent): + """Event emitted when connection to an MCP server fails.""" + + type: str = "mcp_connection_failed" + error: str + error_type: str | None = None # "timeout", "authentication", "network", etc. + started_at: datetime | None = None + failed_at: datetime | None = None + + +class MCPToolExecutionStartedEvent(MCPEvent): + """Event emitted when starting to execute an MCP tool.""" + + type: str = "mcp_tool_execution_started" + tool_name: str + tool_args: dict[str, Any] | None = None + + +class MCPToolExecutionCompletedEvent(MCPEvent): + """Event emitted when MCP tool execution completes.""" + + type: str = "mcp_tool_execution_completed" + tool_name: str + tool_args: dict[str, Any] | None = None + result: Any | None = None + started_at: datetime | None = None + completed_at: datetime | None = None + execution_duration_ms: float | None = None + + +class MCPToolExecutionFailedEvent(MCPEvent): + """Event emitted when MCP tool execution fails.""" + + type: str = "mcp_tool_execution_failed" + tool_name: str + tool_args: dict[str, Any] | None = None + error: str + error_type: str | None = None # "timeout", "validation", "server_error", etc. + started_at: datetime | None = None + failed_at: datetime | None = None diff --git a/lib/crewai/src/crewai/events/utils/console_formatter.py b/lib/crewai/src/crewai/events/utils/console_formatter.py index 4ee2aa52b..32aa8d208 100644 --- a/lib/crewai/src/crewai/events/utils/console_formatter.py +++ b/lib/crewai/src/crewai/events/utils/console_formatter.py @@ -2248,3 +2248,203 @@ class ConsoleFormatter: self.current_a2a_conversation_branch = None self.current_a2a_turn_count = 0 + + # ----------- MCP EVENTS ----------- + + def handle_mcp_connection_started( + self, + server_name: str, + server_url: str | None = None, + transport_type: str | None = None, + is_reconnect: bool = False, + connect_timeout: int | None = None, + ) -> None: + """Handle MCP connection started event.""" + if not self.verbose: + return + + content = Text() + reconnect_text = " (Reconnecting)" if is_reconnect else "" + content.append(f"MCP Connection Started{reconnect_text}\n\n", style="cyan bold") + content.append("Server: ", style="white") + content.append(f"{server_name}\n", style="cyan") + + if server_url: + content.append("URL: ", style="white") + content.append(f"{server_url}\n", style="cyan dim") + + if transport_type: + content.append("Transport: ", style="white") + content.append(f"{transport_type}\n", style="cyan") + + if connect_timeout: + content.append("Timeout: ", style="white") + content.append(f"{connect_timeout}s\n", style="cyan") + + panel = self.create_panel(content, "🔌 MCP Connection", "cyan") + self.print(panel) + self.print() + + def handle_mcp_connection_completed( + self, + server_name: str, + server_url: str | None = None, + transport_type: str | None = None, + connection_duration_ms: float | None = None, + is_reconnect: bool = False, + ) -> None: + """Handle MCP connection completed event.""" + if not self.verbose: + return + + content = Text() + reconnect_text = " (Reconnected)" if is_reconnect else "" + content.append( + f"MCP Connection Completed{reconnect_text}\n\n", style="green bold" + ) + content.append("Server: ", style="white") + content.append(f"{server_name}\n", style="green") + + if server_url: + content.append("URL: ", style="white") + content.append(f"{server_url}\n", style="green dim") + + if transport_type: + content.append("Transport: ", style="white") + content.append(f"{transport_type}\n", style="green") + + if connection_duration_ms is not None: + content.append("Duration: ", style="white") + content.append(f"{connection_duration_ms:.2f}ms\n", style="green") + + panel = self.create_panel(content, "✅ MCP Connected", "green") + self.print(panel) + self.print() + + def handle_mcp_connection_failed( + self, + server_name: str, + server_url: str | None = None, + transport_type: str | None = None, + error: str = "", + error_type: str | None = None, + ) -> None: + """Handle MCP connection failed event.""" + if not self.verbose: + return + + content = Text() + content.append("MCP Connection Failed\n\n", style="red bold") + content.append("Server: ", style="white") + content.append(f"{server_name}\n", style="red") + + if server_url: + content.append("URL: ", style="white") + content.append(f"{server_url}\n", style="red dim") + + if transport_type: + content.append("Transport: ", style="white") + content.append(f"{transport_type}\n", style="red") + + if error_type: + content.append("Error Type: ", style="white") + content.append(f"{error_type}\n", style="red") + + if error: + content.append("\nError: ", style="white bold") + error_preview = error[:500] + "..." if len(error) > 500 else error + content.append(f"{error_preview}\n", style="red") + + panel = self.create_panel(content, "❌ MCP Connection Failed", "red") + self.print(panel) + self.print() + + def handle_mcp_tool_execution_started( + self, + server_name: str, + tool_name: str, + tool_args: dict[str, Any] | None = None, + ) -> None: + """Handle MCP tool execution started event.""" + if not self.verbose: + return + + content = self.create_status_content( + "MCP Tool Execution Started", + tool_name, + "yellow", + tool_args=tool_args or {}, + Server=server_name, + ) + + panel = self.create_panel(content, "🔧 MCP Tool", "yellow") + self.print(panel) + self.print() + + def handle_mcp_tool_execution_completed( + self, + server_name: str, + tool_name: str, + tool_args: dict[str, Any] | None = None, + result: Any | None = None, + execution_duration_ms: float | None = None, + ) -> None: + """Handle MCP tool execution completed event.""" + if not self.verbose: + return + + content = self.create_status_content( + "MCP Tool Execution Completed", + tool_name, + "green", + tool_args=tool_args or {}, + Server=server_name, + ) + + if execution_duration_ms is not None: + content.append("Duration: ", style="white") + content.append(f"{execution_duration_ms:.2f}ms\n", style="green") + + if result is not None: + result_str = str(result) + if len(result_str) > 500: + result_str = result_str[:497] + "..." + content.append("\nResult: ", style="white bold") + content.append(f"{result_str}\n", style="green") + + panel = self.create_panel(content, "✅ MCP Tool Completed", "green") + self.print(panel) + self.print() + + def handle_mcp_tool_execution_failed( + self, + server_name: str, + tool_name: str, + tool_args: dict[str, Any] | None = None, + error: str = "", + error_type: str | None = None, + ) -> None: + """Handle MCP tool execution failed event.""" + if not self.verbose: + return + + content = self.create_status_content( + "MCP Tool Execution Failed", + tool_name, + "red", + tool_args=tool_args or {}, + Server=server_name, + ) + + if error_type: + content.append("Error Type: ", style="white") + content.append(f"{error_type}\n", style="red") + + if error: + content.append("\nError: ", style="white bold") + error_preview = error[:500] + "..." if len(error) > 500 else error + content.append(f"{error_preview}\n", style="red") + + panel = self.create_panel(content, "❌ MCP Tool Failed", "red") + self.print(panel) + self.print() diff --git a/lib/crewai/src/crewai/mcp/__init__.py b/lib/crewai/src/crewai/mcp/__init__.py new file mode 100644 index 000000000..282cb1f56 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/__init__.py @@ -0,0 +1,37 @@ +"""MCP (Model Context Protocol) client support for CrewAI agents. + +This module provides native MCP client functionality, allowing CrewAI agents +to connect to any MCP-compliant server using various transport types. +""" + +from crewai.mcp.client import MCPClient +from crewai.mcp.config import ( + MCPServerConfig, + MCPServerHTTP, + MCPServerSSE, + MCPServerStdio, +) +from crewai.mcp.filters import ( + StaticToolFilter, + ToolFilter, + ToolFilterContext, + create_dynamic_tool_filter, + create_static_tool_filter, +) +from crewai.mcp.transports.base import BaseTransport, TransportType + + +__all__ = [ + "BaseTransport", + "MCPClient", + "MCPServerConfig", + "MCPServerHTTP", + "MCPServerSSE", + "MCPServerStdio", + "StaticToolFilter", + "ToolFilter", + "ToolFilterContext", + "TransportType", + "create_dynamic_tool_filter", + "create_static_tool_filter", +] diff --git a/lib/crewai/src/crewai/mcp/client.py b/lib/crewai/src/crewai/mcp/client.py new file mode 100644 index 000000000..aff07a397 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/client.py @@ -0,0 +1,742 @@ +"""MCP client with session management for CrewAI agents.""" + +import asyncio +from collections.abc import Callable +from contextlib import AsyncExitStack +from datetime import datetime +import logging +import time +from typing import Any + +from typing_extensions import Self + + +# BaseExceptionGroup is available in Python 3.11+ +try: + from builtins import BaseExceptionGroup +except ImportError: + # Fallback for Python < 3.11 (shouldn't happen in practice) + BaseExceptionGroup = Exception + +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.mcp_events import ( + MCPConnectionCompletedEvent, + MCPConnectionFailedEvent, + MCPConnectionStartedEvent, + MCPToolExecutionCompletedEvent, + MCPToolExecutionFailedEvent, + MCPToolExecutionStartedEvent, +) +from crewai.mcp.transports.base import BaseTransport +from crewai.mcp.transports.http import HTTPTransport +from crewai.mcp.transports.sse import SSETransport +from crewai.mcp.transports.stdio import StdioTransport + + +# MCP Connection timeout constants (in seconds) +MCP_CONNECTION_TIMEOUT = 30 # Increased for slow servers +MCP_TOOL_EXECUTION_TIMEOUT = 30 +MCP_DISCOVERY_TIMEOUT = 30 # Increased for slow servers +MCP_MAX_RETRIES = 3 + +# Simple in-memory cache for MCP tool schemas (duration: 5 minutes) +_mcp_schema_cache: dict[str, tuple[dict[str, Any], float]] = {} +_cache_ttl = 300 # 5 minutes + + +class MCPClient: + """MCP client with session management. + + This client manages connections to MCP servers and provides a high-level + interface for interacting with MCP tools, prompts, and resources. + + Example: + ```python + transport = StdioTransport(command="python", args=["server.py"]) + client = MCPClient(transport) + async with client: + tools = await client.list_tools() + result = await client.call_tool("tool_name", {"arg": "value"}) + ``` + """ + + def __init__( + self, + transport: BaseTransport, + connect_timeout: int = MCP_CONNECTION_TIMEOUT, + execution_timeout: int = MCP_TOOL_EXECUTION_TIMEOUT, + discovery_timeout: int = MCP_DISCOVERY_TIMEOUT, + max_retries: int = MCP_MAX_RETRIES, + cache_tools_list: bool = False, + logger: logging.Logger | None = None, + ) -> None: + """Initialize MCP client. + + Args: + transport: Transport instance for MCP server connection. + connect_timeout: Connection timeout in seconds. + execution_timeout: Tool execution timeout in seconds. + discovery_timeout: Tool discovery timeout in seconds. + max_retries: Maximum retry attempts for operations. + cache_tools_list: Whether to cache tool list results. + logger: Optional logger instance. + """ + self.transport = transport + self.connect_timeout = connect_timeout + self.execution_timeout = execution_timeout + self.discovery_timeout = discovery_timeout + self.max_retries = max_retries + self.cache_tools_list = cache_tools_list + # self._logger = logger or logging.getLogger(__name__) + self._session: Any = None + self._initialized = False + self._exit_stack = AsyncExitStack() + self._was_connected = False + + @property + def connected(self) -> bool: + """Check if client is connected to server.""" + return self.transport.connected and self._initialized + + @property + def session(self) -> Any: + """Get the MCP session.""" + if self._session is None: + raise RuntimeError("Client not connected. Call connect() first.") + return self._session + + def _get_server_info(self) -> tuple[str, str | None, str | None]: + """Get server information for events. + + Returns: + Tuple of (server_name, server_url, transport_type). + """ + if isinstance(self.transport, StdioTransport): + server_name = f"{self.transport.command} {' '.join(self.transport.args)}" + server_url = None + transport_type = self.transport.transport_type.value + elif isinstance(self.transport, HTTPTransport): + server_name = self.transport.url + server_url = self.transport.url + transport_type = self.transport.transport_type.value + elif isinstance(self.transport, SSETransport): + server_name = self.transport.url + server_url = self.transport.url + transport_type = self.transport.transport_type.value + else: + server_name = "Unknown MCP Server" + server_url = None + transport_type = ( + self.transport.transport_type.value + if hasattr(self.transport, "transport_type") + else None + ) + + return server_name, server_url, transport_type + + async def connect(self) -> Self: + """Connect to MCP server and initialize session. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If connection fails. + ImportError: If MCP SDK not available. + """ + if self.connected: + return self + + # Get server info for events + server_name, server_url, transport_type = self._get_server_info() + is_reconnect = self._was_connected + + # Emit connection started event + started_at = datetime.now() + crewai_event_bus.emit( + self, + MCPConnectionStartedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + is_reconnect=is_reconnect, + connect_timeout=self.connect_timeout, + ), + ) + + try: + from mcp import ClientSession + + # Use AsyncExitStack to manage transport and session contexts together + # This ensures they're in the same async scope and prevents cancel scope errors + # Always enter transport context via exit stack (it handles already-connected state) + await self._exit_stack.enter_async_context(self.transport) + + # Create ClientSession with transport streams + self._session = ClientSession( + self.transport.read_stream, + self.transport.write_stream, + ) + + # Enter the session's async context manager via exit stack + await self._exit_stack.enter_async_context(self._session) + + # Initialize the session (required by MCP protocol) + try: + await asyncio.wait_for( + self._session.initialize(), + timeout=self.connect_timeout, + ) + except asyncio.CancelledError: + # If initialization was cancelled (e.g., event loop closing), + # cleanup and re-raise - don't suppress cancellation + await self._cleanup_on_error() + raise + except BaseExceptionGroup as eg: + # Handle exception groups from anyio task groups + # Extract the actual meaningful error (not GeneratorExit) + actual_error = None + for exc in eg.exceptions: + if isinstance(exc, Exception) and not isinstance( + exc, GeneratorExit + ): + # Check if it's an HTTP error (like 401) + error_msg = str(exc).lower() + if "401" in error_msg or "unauthorized" in error_msg: + actual_error = exc + break + if "cancel scope" not in error_msg and "task" not in error_msg: + actual_error = exc + break + + await self._cleanup_on_error() + if actual_error: + raise ConnectionError( + f"Failed to connect to MCP server: {actual_error}" + ) from actual_error + raise ConnectionError(f"Failed to connect to MCP server: {eg}") from eg + + self._initialized = True + self._was_connected = True + + completed_at = datetime.now() + connection_duration_ms = (completed_at - started_at).total_seconds() * 1000 + crewai_event_bus.emit( + self, + MCPConnectionCompletedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + started_at=started_at, + completed_at=completed_at, + connection_duration_ms=connection_duration_ms, + is_reconnect=is_reconnect, + ), + ) + + return self + except ImportError as e: + await self._cleanup_on_error() + error_msg = ( + "MCP library not available. Please install with: pip install mcp" + ) + self._emit_connection_failed( + server_name, + server_url, + transport_type, + error_msg, + "import_error", + started_at, + ) + raise ImportError(error_msg) from e + except asyncio.TimeoutError as e: + await self._cleanup_on_error() + error_msg = f"MCP connection timed out after {self.connect_timeout} seconds. The server may be slow or unreachable." + self._emit_connection_failed( + server_name, + server_url, + transport_type, + error_msg, + "timeout", + started_at, + ) + raise ConnectionError(error_msg) from e + except asyncio.CancelledError: + # Re-raise cancellation - don't suppress it + await self._cleanup_on_error() + self._emit_connection_failed( + server_name, + server_url, + transport_type, + "Connection cancelled", + "cancelled", + started_at, + ) + raise + except BaseExceptionGroup as eg: + # Handle exception groups from anyio task groups at outer level + actual_error = None + for exc in eg.exceptions: + if isinstance(exc, Exception) and not isinstance(exc, GeneratorExit): + error_msg = str(exc).lower() + if "401" in error_msg or "unauthorized" in error_msg: + actual_error = exc + break + if "cancel scope" not in error_msg and "task" not in error_msg: + actual_error = exc + break + + await self._cleanup_on_error() + error_type = ( + "authentication" + if actual_error + and ( + "401" in str(actual_error).lower() + or "unauthorized" in str(actual_error).lower() + ) + else "network" + ) + error_msg = str(actual_error) if actual_error else str(eg) + self._emit_connection_failed( + server_name, + server_url, + transport_type, + error_msg, + error_type, + started_at, + ) + if actual_error: + raise ConnectionError( + f"Failed to connect to MCP server: {actual_error}" + ) from actual_error + raise ConnectionError(f"Failed to connect to MCP server: {eg}") from eg + except Exception as e: + await self._cleanup_on_error() + error_type = ( + "authentication" + if "401" in str(e).lower() or "unauthorized" in str(e).lower() + else "network" + ) + self._emit_connection_failed( + server_name, server_url, transport_type, str(e), error_type, started_at + ) + raise ConnectionError(f"Failed to connect to MCP server: {e}") from e + + def _emit_connection_failed( + self, + server_name: str, + server_url: str | None, + transport_type: str | None, + error: str, + error_type: str, + started_at: datetime, + ) -> None: + """Emit connection failed event.""" + failed_at = datetime.now() + crewai_event_bus.emit( + self, + MCPConnectionFailedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + error=error, + error_type=error_type, + started_at=started_at, + failed_at=failed_at, + ), + ) + + async def _cleanup_on_error(self) -> None: + """Cleanup resources when an error occurs during connection.""" + try: + await self._exit_stack.aclose() + + except Exception as e: + # Best effort cleanup - ignore all other errors + raise RuntimeError(f"Error during MCP client cleanup: {e}") from e + finally: + self._session = None + self._initialized = False + self._exit_stack = AsyncExitStack() + + async def disconnect(self) -> None: + """Disconnect from MCP server and cleanup resources.""" + if not self.connected: + return + + try: + await self._exit_stack.aclose() + except Exception as e: + raise RuntimeError(f"Error during MCP client disconnect: {e}") from e + finally: + self._session = None + self._initialized = False + self._exit_stack = AsyncExitStack() + + async def list_tools(self, use_cache: bool | None = None) -> list[dict[str, Any]]: + """List available tools from MCP server. + + Args: + use_cache: Whether to use cached results. If None, uses + client's cache_tools_list setting. + + Returns: + List of tool definitions with name, description, and inputSchema. + """ + if not self.connected: + await self.connect() + + # Check cache if enabled + use_cache = use_cache if use_cache is not None else self.cache_tools_list + if use_cache: + cache_key = self._get_cache_key("tools") + if cache_key in _mcp_schema_cache: + cached_data, cache_time = _mcp_schema_cache[cache_key] + if time.time() - cache_time < _cache_ttl: + # Logger removed - return cached data + return cached_data + + # List tools with timeout and retries + tools = await self._retry_operation( + self._list_tools_impl, + timeout=self.discovery_timeout, + ) + + # Cache results if enabled + if use_cache: + cache_key = self._get_cache_key("tools") + _mcp_schema_cache[cache_key] = (tools, time.time()) + + return tools + + async def _list_tools_impl(self) -> list[dict[str, Any]]: + """Internal implementation of list_tools.""" + tools_result = await asyncio.wait_for( + self.session.list_tools(), + timeout=self.discovery_timeout, + ) + + return [ + { + "name": tool.name, + "description": getattr(tool, "description", ""), + "inputSchema": getattr(tool, "inputSchema", {}), + } + for tool in tools_result.tools + ] + + async def call_tool( + self, tool_name: str, arguments: dict[str, Any] | None = None + ) -> Any: + """Call a tool on the MCP server. + + Args: + tool_name: Name of the tool to call. + arguments: Tool arguments. + + Returns: + Tool execution result. + """ + if not self.connected: + await self.connect() + + arguments = arguments or {} + cleaned_arguments = self._clean_tool_arguments(arguments) + + # Get server info for events + server_name, server_url, transport_type = self._get_server_info() + + # Emit tool execution started event + started_at = datetime.now() + crewai_event_bus.emit( + self, + MCPToolExecutionStartedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + tool_name=tool_name, + tool_args=cleaned_arguments, + ), + ) + + try: + result = await self._retry_operation( + lambda: self._call_tool_impl(tool_name, cleaned_arguments), + timeout=self.execution_timeout, + ) + + completed_at = datetime.now() + execution_duration_ms = (completed_at - started_at).total_seconds() * 1000 + crewai_event_bus.emit( + self, + MCPToolExecutionCompletedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + tool_name=tool_name, + tool_args=cleaned_arguments, + result=result, + started_at=started_at, + completed_at=completed_at, + execution_duration_ms=execution_duration_ms, + ), + ) + + return result + except Exception as e: + failed_at = datetime.now() + error_type = ( + "timeout" + if isinstance(e, (asyncio.TimeoutError, ConnectionError)) + and "timeout" in str(e).lower() + else "server_error" + ) + crewai_event_bus.emit( + self, + MCPToolExecutionFailedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + tool_name=tool_name, + tool_args=cleaned_arguments, + error=str(e), + error_type=error_type, + started_at=started_at, + failed_at=failed_at, + ), + ) + raise + + def _clean_tool_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]: + """Clean tool arguments by removing None values and fixing formats. + + Args: + arguments: Raw tool arguments. + + Returns: + Cleaned arguments ready for MCP server. + """ + cleaned = {} + + for key, value in arguments.items(): + # Skip None values + if value is None: + continue + + # Fix sources array format: convert ["web"] to [{"type": "web"}] + if key == "sources" and isinstance(value, list): + fixed_sources = [] + for item in value: + if isinstance(item, str): + # Convert string to object format + fixed_sources.append({"type": item}) + elif isinstance(item, dict): + # Already in correct format + fixed_sources.append(item) + else: + # Keep as is if unknown format + fixed_sources.append(item) + if fixed_sources: + cleaned[key] = fixed_sources + continue + + # Recursively clean nested dictionaries + if isinstance(value, dict): + nested_cleaned = self._clean_tool_arguments(value) + if nested_cleaned: # Only add if not empty + cleaned[key] = nested_cleaned + elif isinstance(value, list): + # Clean list items + cleaned_list = [] + for item in value: + if isinstance(item, dict): + cleaned_item = self._clean_tool_arguments(item) + if cleaned_item: + cleaned_list.append(cleaned_item) + elif item is not None: + cleaned_list.append(item) + if cleaned_list: + cleaned[key] = cleaned_list + else: + # Keep primitive values + cleaned[key] = value + + return cleaned + + async def _call_tool_impl(self, tool_name: str, arguments: dict[str, Any]) -> Any: + """Internal implementation of call_tool.""" + result = await asyncio.wait_for( + self.session.call_tool(tool_name, arguments), + timeout=self.execution_timeout, + ) + + # Extract result content + if hasattr(result, "content") and result.content: + if isinstance(result.content, list) and len(result.content) > 0: + content_item = result.content[0] + if hasattr(content_item, "text"): + return str(content_item.text) + return str(content_item) + return str(result.content) + + return str(result) + + async def list_prompts(self) -> list[dict[str, Any]]: + """List available prompts from MCP server. + + Returns: + List of prompt definitions. + """ + if not self.connected: + await self.connect() + + return await self._retry_operation( + self._list_prompts_impl, + timeout=self.discovery_timeout, + ) + + async def _list_prompts_impl(self) -> list[dict[str, Any]]: + """Internal implementation of list_prompts.""" + prompts_result = await asyncio.wait_for( + self.session.list_prompts(), + timeout=self.discovery_timeout, + ) + + return [ + { + "name": prompt.name, + "description": getattr(prompt, "description", ""), + "arguments": getattr(prompt, "arguments", []), + } + for prompt in prompts_result.prompts + ] + + async def get_prompt( + self, prompt_name: str, arguments: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Get a prompt from the MCP server. + + Args: + prompt_name: Name of the prompt to get. + arguments: Optional prompt arguments. + + Returns: + Prompt content and metadata. + """ + if not self.connected: + await self.connect() + + arguments = arguments or {} + + return await self._retry_operation( + lambda: self._get_prompt_impl(prompt_name, arguments), + timeout=self.execution_timeout, + ) + + async def _get_prompt_impl( + self, prompt_name: str, arguments: dict[str, Any] + ) -> dict[str, Any]: + """Internal implementation of get_prompt.""" + result = await asyncio.wait_for( + self.session.get_prompt(prompt_name, arguments), + timeout=self.execution_timeout, + ) + + return { + "name": prompt_name, + "messages": [ + { + "role": msg.role, + "content": msg.content, + } + for msg in result.messages + ], + "arguments": arguments, + } + + async def _retry_operation( + self, + operation: Callable[[], Any], + timeout: int | None = None, + ) -> Any: + """Retry an operation with exponential backoff. + + Args: + operation: Async operation to retry. + timeout: Operation timeout in seconds. + + Returns: + Operation result. + """ + last_error = None + timeout = timeout or self.execution_timeout + + for attempt in range(self.max_retries): + try: + if timeout: + return await asyncio.wait_for(operation(), timeout=timeout) + return await operation() + + except asyncio.TimeoutError as e: # noqa: PERF203 + last_error = f"Operation timed out after {timeout} seconds" + if attempt < self.max_retries - 1: + wait_time = 2**attempt + await asyncio.sleep(wait_time) + else: + raise ConnectionError(last_error) from e + + except Exception as e: + error_str = str(e).lower() + + # Classify errors as retryable or non-retryable + if "authentication" in error_str or "unauthorized" in error_str: + raise ConnectionError(f"Authentication failed: {e}") from e + + if "not found" in error_str: + raise ValueError(f"Resource not found: {e}") from e + + # Retryable errors + last_error = str(e) + if attempt < self.max_retries - 1: + wait_time = 2**attempt + await asyncio.sleep(wait_time) + else: + raise ConnectionError( + f"Operation failed after {self.max_retries} attempts: {last_error}" + ) from e + + raise ConnectionError(f"Operation failed: {last_error}") + + def _get_cache_key(self, resource_type: str) -> str: + """Generate cache key for resource. + + Args: + resource_type: Type of resource (e.g., "tools", "prompts"). + + Returns: + Cache key string. + """ + # Use transport type and URL/command as cache key + if isinstance(self.transport, StdioTransport): + key = f"stdio:{self.transport.command}:{':'.join(self.transport.args)}" + elif isinstance(self.transport, HTTPTransport): + key = f"http:{self.transport.url}" + elif isinstance(self.transport, SSETransport): + key = f"sse:{self.transport.url}" + else: + key = f"{self.transport.transport_type}:unknown" + + return f"mcp:{key}:{resource_type}" + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return await self.connect() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + await self.disconnect() diff --git a/lib/crewai/src/crewai/mcp/config.py b/lib/crewai/src/crewai/mcp/config.py new file mode 100644 index 000000000..775f9403d --- /dev/null +++ b/lib/crewai/src/crewai/mcp/config.py @@ -0,0 +1,124 @@ +"""MCP server configuration models for CrewAI agents. + +This module provides Pydantic models for configuring MCP servers with +various transport types, similar to OpenAI's Agents SDK. +""" + +from pydantic import BaseModel, Field + +from crewai.mcp.filters import ToolFilter + + +class MCPServerStdio(BaseModel): + """Stdio MCP server configuration. + + This configuration is used for connecting to local MCP servers + that run as processes and communicate via standard input/output. + + Example: + ```python + mcp_server = MCPServerStdio( + command="python", + args=["path/to/server.py"], + env={"API_KEY": "..."}, + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"] + ), + ) + ``` + """ + + command: str = Field( + ..., + description="Command to execute (e.g., 'python', 'node', 'npx', 'uvx').", + ) + args: list[str] = Field( + default_factory=list, + description="Command arguments (e.g., ['server.py'] or ['-y', '@mcp/server']).", + ) + env: dict[str, str] | None = Field( + default=None, + description="Environment variables to pass to the process.", + ) + tool_filter: ToolFilter | None = Field( + default=None, + description="Optional tool filter for filtering available tools.", + ) + cache_tools_list: bool = Field( + default=False, + description="Whether to cache the tool list for faster subsequent access.", + ) + + +class MCPServerHTTP(BaseModel): + """HTTP/Streamable HTTP MCP server configuration. + + This configuration is used for connecting to remote MCP servers + over HTTP/HTTPS using streamable HTTP transport. + + Example: + ```python + mcp_server = MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer ..."}, + cache_tools_list=True, + ) + ``` + """ + + url: str = Field( + ..., description="Server URL (e.g., 'https://api.example.com/mcp')." + ) + headers: dict[str, str] | None = Field( + default=None, + description="Optional HTTP headers for authentication or other purposes.", + ) + streamable: bool = Field( + default=True, + description="Whether to use streamable HTTP transport (default: True).", + ) + tool_filter: ToolFilter | None = Field( + default=None, + description="Optional tool filter for filtering available tools.", + ) + cache_tools_list: bool = Field( + default=False, + description="Whether to cache the tool list for faster subsequent access.", + ) + + +class MCPServerSSE(BaseModel): + """Server-Sent Events (SSE) MCP server configuration. + + This configuration is used for connecting to remote MCP servers + using Server-Sent Events for real-time streaming communication. + + Example: + ```python + mcp_server = MCPServerSSE( + url="https://api.example.com/mcp/sse", + headers={"Authorization": "Bearer ..."}, + ) + ``` + """ + + url: str = Field( + ..., + description="Server URL (e.g., 'https://api.example.com/mcp/sse').", + ) + headers: dict[str, str] | None = Field( + default=None, + description="Optional HTTP headers for authentication or other purposes.", + ) + tool_filter: ToolFilter | None = Field( + default=None, + description="Optional tool filter for filtering available tools.", + ) + cache_tools_list: bool = Field( + default=False, + description="Whether to cache the tool list for faster subsequent access.", + ) + + +# Type alias for all MCP server configurations +MCPServerConfig = MCPServerStdio | MCPServerHTTP | MCPServerSSE diff --git a/lib/crewai/src/crewai/mcp/filters.py b/lib/crewai/src/crewai/mcp/filters.py new file mode 100644 index 000000000..ee2f7a560 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/filters.py @@ -0,0 +1,166 @@ +"""Tool filtering support for MCP servers. + +This module provides utilities for filtering tools from MCP servers, +including static allow/block lists and dynamic context-aware filtering. +""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, Field + + +if TYPE_CHECKING: + pass + + +class ToolFilterContext(BaseModel): + """Context for dynamic tool filtering. + + This context is passed to dynamic tool filters to provide + information about the agent, run context, and server. + """ + + agent: Any = Field(..., description="The agent requesting tools.") + server_name: str = Field(..., description="Name of the MCP server.") + run_context: dict[str, Any] | None = Field( + default=None, + description="Optional run context for additional filtering logic.", + ) + + +# Type alias for tool filter functions +ToolFilter = ( + Callable[[ToolFilterContext, dict[str, Any]], bool] + | Callable[[dict[str, Any]], bool] +) + + +class StaticToolFilter: + """Static tool filter with allow/block lists. + + This filter provides simple allow/block list filtering based on + tool names. Useful for restricting which tools are available + from an MCP server. + + Example: + ```python + filter = StaticToolFilter( + allowed_tool_names=["read_file", "write_file"], + blocked_tool_names=["delete_file"], + ) + ``` + """ + + def __init__( + self, + allowed_tool_names: list[str] | None = None, + blocked_tool_names: list[str] | None = None, + ) -> None: + """Initialize static tool filter. + + Args: + allowed_tool_names: List of tool names to allow. If None, + all tools are allowed (unless blocked). + blocked_tool_names: List of tool names to block. Blocked tools + take precedence over allowed tools. + """ + self.allowed_tool_names = set(allowed_tool_names or []) + self.blocked_tool_names = set(blocked_tool_names or []) + + def __call__(self, tool: dict[str, Any]) -> bool: + """Filter tool based on allow/block lists. + + Args: + tool: Tool definition dictionary with at least 'name' key. + + Returns: + True if tool should be included, False otherwise. + """ + tool_name = tool.get("name", "") + + # Blocked tools take precedence + if self.blocked_tool_names and tool_name in self.blocked_tool_names: + return False + + # If allow list exists, tool must be in it + if self.allowed_tool_names: + return tool_name in self.allowed_tool_names + + # No restrictions - allow all + return True + + +def create_static_tool_filter( + allowed_tool_names: list[str] | None = None, + blocked_tool_names: list[str] | None = None, +) -> Callable[[dict[str, Any]], bool]: + """Create a static tool filter function. + + This is a convenience function for creating static tool filters + with allow/block lists. + + Args: + allowed_tool_names: List of tool names to allow. If None, + all tools are allowed (unless blocked). + blocked_tool_names: List of tool names to block. Blocked tools + take precedence over allowed tools. + + Returns: + Tool filter function that returns True for allowed tools. + + Example: + ```python + filter_fn = create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"], + blocked_tool_names=["delete_file"], + ) + + # Use in MCPServerStdio + mcp_server = MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + tool_filter=filter_fn, + ) + ``` + """ + return StaticToolFilter( + allowed_tool_names=allowed_tool_names, + blocked_tool_names=blocked_tool_names, + ) + + +def create_dynamic_tool_filter( + filter_func: Callable[[ToolFilterContext, dict[str, Any]], bool], +) -> Callable[[ToolFilterContext, dict[str, Any]], bool]: + """Create a dynamic tool filter function. + + This function wraps a dynamic filter function that has access + to the tool filter context (agent, server, run context). + + Args: + filter_func: Function that takes (context, tool) and returns bool. + + Returns: + Tool filter function that can be used with MCP server configs. + + Example: + ```python + async def context_aware_filter( + context: ToolFilterContext, tool: dict[str, Any] + ) -> bool: + # Block dangerous tools for code reviewers + if context.agent.role == "Code Reviewer": + if tool["name"].startswith("danger_"): + return False + return True + + + filter_fn = create_dynamic_tool_filter(context_aware_filter) + + mcp_server = MCPServerStdio( + command="python", args=["server.py"], tool_filter=filter_fn + ) + ``` + """ + return filter_func diff --git a/lib/crewai/src/crewai/mcp/transports/__init__.py b/lib/crewai/src/crewai/mcp/transports/__init__.py new file mode 100644 index 000000000..4e579f50e --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/__init__.py @@ -0,0 +1,15 @@ +"""MCP transport implementations for various connection types.""" + +from crewai.mcp.transports.base import BaseTransport, TransportType +from crewai.mcp.transports.http import HTTPTransport +from crewai.mcp.transports.sse import SSETransport +from crewai.mcp.transports.stdio import StdioTransport + + +__all__ = [ + "BaseTransport", + "HTTPTransport", + "SSETransport", + "StdioTransport", + "TransportType", +] diff --git a/lib/crewai/src/crewai/mcp/transports/base.py b/lib/crewai/src/crewai/mcp/transports/base.py new file mode 100644 index 000000000..d6e5f958d --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/base.py @@ -0,0 +1,125 @@ +"""Base transport interface for MCP connections.""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Protocol + +from typing_extensions import Self + + +class TransportType(str, Enum): + """MCP transport types.""" + + STDIO = "stdio" + HTTP = "http" + STREAMABLE_HTTP = "streamable-http" + SSE = "sse" + + +class ReadStream(Protocol): + """Protocol for read streams.""" + + async def read(self, n: int = -1) -> bytes: + """Read bytes from stream.""" + ... + + +class WriteStream(Protocol): + """Protocol for write streams.""" + + async def write(self, data: bytes) -> None: + """Write bytes to stream.""" + ... + + +class BaseTransport(ABC): + """Base class for MCP transport implementations. + + This abstract base class defines the interface that all transport + implementations must follow. Transports handle the low-level communication + with MCP servers. + """ + + def __init__(self, **kwargs: Any) -> None: + """Initialize the transport. + + Args: + **kwargs: Transport-specific configuration options. + """ + self._read_stream: ReadStream | None = None + self._write_stream: WriteStream | None = None + self._connected = False + + @property + @abstractmethod + def transport_type(self) -> TransportType: + """Return the transport type.""" + ... + + @property + def connected(self) -> bool: + """Check if transport is connected.""" + return self._connected + + @property + def read_stream(self) -> ReadStream: + """Get the read stream.""" + if self._read_stream is None: + raise RuntimeError("Transport not connected. Call connect() first.") + return self._read_stream + + @property + def write_stream(self) -> WriteStream: + """Get the write stream.""" + if self._write_stream is None: + raise RuntimeError("Transport not connected. Call connect() first.") + return self._write_stream + + @abstractmethod + async def connect(self) -> Self: + """Establish connection to MCP server. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If connection fails. + """ + ... + + @abstractmethod + async def disconnect(self) -> None: + """Close connection to MCP server.""" + ... + + @abstractmethod + async def __aenter__(self) -> Self: + """Async context manager entry.""" + ... + + @abstractmethod + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + ... + + def _set_streams(self, read: ReadStream, write: WriteStream) -> None: + """Set the read and write streams. + + Args: + read: Read stream. + write: Write stream. + """ + self._read_stream = read + self._write_stream = write + self._connected = True + + def _clear_streams(self) -> None: + """Clear the read and write streams.""" + self._read_stream = None + self._write_stream = None + self._connected = False diff --git a/lib/crewai/src/crewai/mcp/transports/http.py b/lib/crewai/src/crewai/mcp/transports/http.py new file mode 100644 index 000000000..d531d8906 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/http.py @@ -0,0 +1,174 @@ +"""HTTP and Streamable HTTP transport for MCP servers.""" + +import asyncio +from typing import Any + +from typing_extensions import Self + + +# BaseExceptionGroup is available in Python 3.11+ +try: + from builtins import BaseExceptionGroup +except ImportError: + # Fallback for Python < 3.11 (shouldn't happen in practice) + BaseExceptionGroup = Exception + +from crewai.mcp.transports.base import BaseTransport, TransportType + + +class HTTPTransport(BaseTransport): + """HTTP/Streamable HTTP transport for connecting to remote MCP servers. + + This transport connects to MCP servers over HTTP/HTTPS using the + streamable HTTP client from the MCP SDK. + + Example: + ```python + transport = HTTPTransport( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer ..."} + ) + async with transport: + # Use transport... + ``` + """ + + def __init__( + self, + url: str, + headers: dict[str, str] | None = None, + streamable: bool = True, + **kwargs: Any, + ) -> None: + """Initialize HTTP transport. + + Args: + url: Server URL (e.g., "https://api.example.com/mcp"). + headers: Optional HTTP headers. + streamable: Whether to use streamable HTTP (default: True). + **kwargs: Additional transport options. + """ + super().__init__(**kwargs) + self.url = url + self.headers = headers or {} + self.streamable = streamable + self._transport_context: Any = None + + @property + def transport_type(self) -> TransportType: + """Return the transport type.""" + return TransportType.STREAMABLE_HTTP if self.streamable else TransportType.HTTP + + async def connect(self) -> Self: + """Establish HTTP connection to MCP server. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If connection fails. + ImportError: If MCP SDK not available. + """ + if self._connected: + return self + + try: + from mcp.client.streamable_http import streamablehttp_client + + self._transport_context = streamablehttp_client( + self.url, + headers=self.headers if self.headers else None, + terminate_on_close=True, + ) + + try: + read, write, _ = await asyncio.wait_for( + self._transport_context.__aenter__(), timeout=30.0 + ) + except asyncio.TimeoutError as e: + self._transport_context = None + raise ConnectionError( + "Transport context entry timed out after 30 seconds. " + "Server may be slow or unreachable." + ) from e + except Exception as e: + self._transport_context = None + raise ConnectionError(f"Failed to enter transport context: {e}") from e + self._set_streams(read=read, write=write) + return self + + except ImportError as e: + raise ImportError( + "MCP library not available. Please install with: pip install mcp" + ) from e + except Exception as e: + self._clear_streams() + if self._transport_context is not None: + self._transport_context = None + raise ConnectionError(f"Failed to connect to MCP server: {e}") from e + + async def disconnect(self) -> None: + """Close HTTP connection.""" + if not self._connected: + return + + try: + # Clear streams first + self._clear_streams() + # await self._exit_stack.aclose() + + # Exit transport context - this will clean up background tasks + # Give a small delay to allow background tasks to complete + if self._transport_context is not None: + try: + # Wait a tiny bit for any pending operations + await asyncio.sleep(0.1) + await self._transport_context.__aexit__(None, None, None) + except (RuntimeError, asyncio.CancelledError) as e: + # Ignore "exit cancel scope in different task" errors and cancellation + # These happen when asyncio.run() closes the event loop + # while background tasks are still running + error_msg = str(e).lower() + if "cancel scope" not in error_msg and "task" not in error_msg: + # Only suppress cancel scope/task errors, re-raise others + if isinstance(e, RuntimeError): + raise + # For CancelledError, just suppress it + except BaseExceptionGroup as eg: + # Handle exception groups from anyio task groups + # Suppress if they contain cancel scope errors + should_suppress = False + for exc in eg.exceptions: + error_msg = str(exc).lower() + if "cancel scope" in error_msg or "task" in error_msg: + should_suppress = True + break + if not should_suppress: + raise + except Exception as e: + raise RuntimeError( + f"Error during HTTP transport disconnect: {e}" + ) from e + + self._connected = False + + except Exception as e: + # Log but don't raise - cleanup should be best effort + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Error during HTTP transport disconnect: {e}") + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return await self.connect() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + + await self.disconnect() diff --git a/lib/crewai/src/crewai/mcp/transports/sse.py b/lib/crewai/src/crewai/mcp/transports/sse.py new file mode 100644 index 000000000..ce418c51f --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/sse.py @@ -0,0 +1,113 @@ +"""Server-Sent Events (SSE) transport for MCP servers.""" + +from typing import Any + +from typing_extensions import Self + +from crewai.mcp.transports.base import BaseTransport, TransportType + + +class SSETransport(BaseTransport): + """SSE transport for connecting to remote MCP servers. + + This transport connects to MCP servers using Server-Sent Events (SSE) + for real-time streaming communication. + + Example: + ```python + transport = SSETransport( + url="https://api.example.com/mcp/sse", + headers={"Authorization": "Bearer ..."} + ) + async with transport: + # Use transport... + ``` + """ + + def __init__( + self, + url: str, + headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + """Initialize SSE transport. + + Args: + url: Server URL (e.g., "https://api.example.com/mcp/sse"). + headers: Optional HTTP headers. + **kwargs: Additional transport options. + """ + super().__init__(**kwargs) + self.url = url + self.headers = headers or {} + self._transport_context: Any = None + + @property + def transport_type(self) -> TransportType: + """Return the transport type.""" + return TransportType.SSE + + async def connect(self) -> Self: + """Establish SSE connection to MCP server. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If connection fails. + ImportError: If MCP SDK not available. + """ + if self._connected: + return self + + try: + from mcp.client.sse import sse_client + + self._transport_context = sse_client( + self.url, + headers=self.headers if self.headers else None, + terminate_on_close=True, + ) + + read, write = await self._transport_context.__aenter__() + + self._set_streams(read=read, write=write) + + return self + + except ImportError as e: + raise ImportError( + "MCP library not available. Please install with: pip install mcp" + ) from e + except Exception as e: + self._clear_streams() + raise ConnectionError(f"Failed to connect to SSE MCP server: {e}") from e + + async def disconnect(self) -> None: + """Close SSE connection.""" + if not self._connected: + return + + try: + self._clear_streams() + if self._transport_context is not None: + await self._transport_context.__aexit__(None, None, None) + + except Exception as e: + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Error during SSE transport disconnect: {e}") + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return await self.connect() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + await self.disconnect() diff --git a/lib/crewai/src/crewai/mcp/transports/stdio.py b/lib/crewai/src/crewai/mcp/transports/stdio.py new file mode 100644 index 000000000..65f288505 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/stdio.py @@ -0,0 +1,153 @@ +"""Stdio transport for MCP servers running as local processes.""" + +import asyncio +import os +import subprocess +from typing import Any + +from typing_extensions import Self + +from crewai.mcp.transports.base import BaseTransport, TransportType + + +class StdioTransport(BaseTransport): + """Stdio transport for connecting to local MCP servers. + + This transport connects to MCP servers running as local processes, + communicating via standard input/output streams. Supports Python, + Node.js, and other command-line servers. + + Example: + ```python + transport = StdioTransport( + command="python", + args=["path/to/server.py"], + env={"API_KEY": "..."} + ) + async with transport: + # Use transport... + ``` + """ + + def __init__( + self, + command: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + """Initialize stdio transport. + + Args: + command: Command to execute (e.g., "python", "node", "npx"). + args: Command arguments (e.g., ["server.py"] or ["-y", "@mcp/server"]). + env: Environment variables to pass to the process. + **kwargs: Additional transport options. + """ + super().__init__(**kwargs) + self.command = command + self.args = args or [] + self.env = env or {} + self._process: subprocess.Popen[bytes] | None = None + self._transport_context: Any = None + + @property + def transport_type(self) -> TransportType: + """Return the transport type.""" + return TransportType.STDIO + + async def connect(self) -> Self: + """Start the MCP server process and establish connection. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If process fails to start. + ImportError: If MCP SDK not available. + """ + if self._connected: + return self + + try: + from mcp import StdioServerParameters + from mcp.client.stdio import stdio_client + + process_env = os.environ.copy() + process_env.update(self.env) + + server_params = StdioServerParameters( + command=self.command, + args=self.args, + env=process_env if process_env else None, + ) + self._transport_context = stdio_client(server_params) + + try: + read, write = await self._transport_context.__aenter__() + except Exception as e: + import traceback + + traceback.print_exc() + self._transport_context = None + raise ConnectionError( + f"Failed to enter stdio transport context: {e}" + ) from e + + self._set_streams(read=read, write=write) + + return self + + except ImportError as e: + raise ImportError( + "MCP library not available. Please install with: pip install mcp" + ) from e + except Exception as e: + self._clear_streams() + if self._transport_context is not None: + self._transport_context = None + raise ConnectionError(f"Failed to start MCP server process: {e}") from e + + async def disconnect(self) -> None: + """Terminate the MCP server process and close connection.""" + if not self._connected: + return + + try: + self._clear_streams() + + if self._transport_context is not None: + await self._transport_context.__aexit__(None, None, None) + + if self._process is not None: + try: + self._process.terminate() + try: + await asyncio.wait_for(self._process.wait(), timeout=5.0) + except asyncio.TimeoutError: + self._process.kill() + await self._process.wait() + # except ProcessLookupError: + # pass + finally: + self._process = None + + except Exception as e: + # Log but don't raise - cleanup should be best effort + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Error during stdio transport disconnect: {e}") + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return await self.connect() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + await self.disconnect() diff --git a/lib/crewai/src/crewai/tools/mcp_native_tool.py b/lib/crewai/src/crewai/tools/mcp_native_tool.py new file mode 100644 index 000000000..c10d51eee --- /dev/null +++ b/lib/crewai/src/crewai/tools/mcp_native_tool.py @@ -0,0 +1,154 @@ +"""Native MCP tool wrapper for CrewAI agents. + +This module provides a tool wrapper that reuses existing MCP client sessions +for better performance and connection management. +""" + +import asyncio +from typing import Any + +from crewai.tools import BaseTool + + +class MCPNativeTool(BaseTool): + """Native MCP tool that reuses client sessions. + + This tool wrapper is used when agents connect to MCP servers using + structured configurations. It reuses existing client sessions for + better performance and proper connection lifecycle management. + + Unlike MCPToolWrapper which connects on-demand, this tool uses + a shared MCP client instance that maintains a persistent connection. + """ + + def __init__( + self, + mcp_client: Any, + tool_name: str, + tool_schema: dict[str, Any], + server_name: str, + ) -> None: + """Initialize native MCP tool. + + Args: + mcp_client: MCPClient instance with active session. + tool_name: Original name of the tool on the MCP server. + tool_schema: Schema information for the tool. + server_name: Name of the MCP server for prefixing. + """ + # Create tool name with server prefix to avoid conflicts + prefixed_name = f"{server_name}_{tool_name}" + + # Handle args_schema properly - BaseTool expects a BaseModel subclass + args_schema = tool_schema.get("args_schema") + + # Only pass args_schema if it's provided + kwargs = { + "name": prefixed_name, + "description": tool_schema.get( + "description", f"Tool {tool_name} from {server_name}" + ), + } + + if args_schema is not None: + kwargs["args_schema"] = args_schema + + super().__init__(**kwargs) + + # Set instance attributes after super().__init__ + self._mcp_client = mcp_client + self._original_tool_name = tool_name + self._server_name = server_name + # self._logger = logging.getLogger(__name__) + + @property + def mcp_client(self) -> Any: + """Get the MCP client instance.""" + return self._mcp_client + + @property + def original_tool_name(self) -> str: + """Get the original tool name.""" + return self._original_tool_name + + @property + def server_name(self) -> str: + """Get the server name.""" + return self._server_name + + def _run(self, **kwargs) -> str: + """Execute tool using the MCP client session. + + Args: + **kwargs: Arguments to pass to the MCP tool. + + Returns: + Result from the MCP tool execution. + """ + try: + # Always use asyncio.run() to create a fresh event loop + # This ensures the async context managers work correctly + return asyncio.run(self._run_async(**kwargs)) + + except Exception as e: + raise RuntimeError( + f"Error executing MCP tool {self.original_tool_name}: {e!s}" + ) from e + + async def _run_async(self, **kwargs) -> str: + """Async implementation of tool execution. + + Args: + **kwargs: Arguments to pass to the MCP tool. + + Returns: + Result from the MCP tool execution. + """ + # Note: Since we use asyncio.run() which creates a new event loop each time, + # Always reconnect on-demand because asyncio.run() creates new event loops per call + # All MCP transport context managers (stdio, streamablehttp_client, sse_client) + # use anyio.create_task_group() which can't span different event loops + if self._mcp_client.connected: + await self._mcp_client.disconnect() + + await self._mcp_client.connect() + + try: + result = await self._mcp_client.call_tool(self.original_tool_name, kwargs) + + except Exception as e: + error_str = str(e).lower() + if ( + "not connected" in error_str + or "connection" in error_str + or "send" in error_str + ): + await self._mcp_client.disconnect() + await self._mcp_client.connect() + # Retry the call + result = await self._mcp_client.call_tool( + self.original_tool_name, kwargs + ) + else: + raise + + finally: + # Always disconnect after tool call to ensure clean context manager lifecycle + # This prevents "exit cancel scope in different task" errors + # All transport context managers must be exited in the same event loop they were entered + await self._mcp_client.disconnect() + + # Extract result content + if isinstance(result, str): + return result + + # Handle various result formats + if hasattr(result, "content") and result.content: + if isinstance(result.content, list) and len(result.content) > 0: + content_item = result.content[0] + if hasattr(content_item, "text"): + return str(content_item.text) + return str(content_item) + return str(result.content) + + return str(result) diff --git a/lib/crewai/tests/mcp/__init__.py b/lib/crewai/tests/mcp/__init__.py new file mode 100644 index 000000000..740ce3f01 --- /dev/null +++ b/lib/crewai/tests/mcp/__init__.py @@ -0,0 +1,4 @@ +"""Tests for MCP (Model Context Protocol) integration.""" + + + diff --git a/lib/crewai/tests/mcp/test_mcp_config.py b/lib/crewai/tests/mcp/test_mcp_config.py new file mode 100644 index 000000000..627ceb6e2 --- /dev/null +++ b/lib/crewai/tests/mcp/test_mcp_config.py @@ -0,0 +1,136 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from crewai.agent.core import Agent +from crewai.mcp.config import MCPServerHTTP, MCPServerSSE, MCPServerStdio +from crewai.tools.base_tool import BaseTool + + +@pytest.fixture +def mock_tool_definitions(): + """Create mock MCP tool definitions (as returned by list_tools).""" + return [ + { + "name": "test_tool_1", + "description": "Test tool 1 description", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"] + } + }, + { + "name": "test_tool_2", + "description": "Test tool 2 description", + "inputSchema": {} + } + ] + + +def test_agent_with_stdio_mcp_config(mock_tool_definitions): + """Test agent setup with MCPServerStdio configuration.""" + stdio_config = MCPServerStdio( + command="python", + args=["server.py"], + env={"API_KEY": "test_key"}, + ) + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[stdio_config], + ) + + + with patch("crewai.agent.core.MCPClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False # Will trigger connect + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools([stdio_config]) + + assert len(tools) == 2 + assert all(isinstance(tool, BaseTool) for tool in tools) + + mock_client_class.assert_called_once() + call_args = mock_client_class.call_args + transport = call_args.kwargs["transport"] + assert transport.command == "python" + assert transport.args == ["server.py"] + assert transport.env == {"API_KEY": "test_key"} + + +def test_agent_with_http_mcp_config(mock_tool_definitions): + """Test agent setup with MCPServerHTTP configuration.""" + http_config = MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer test_token"}, + streamable=True, + ) + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + with patch("crewai.agent.core.MCPClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False # Will trigger connect + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools([http_config]) + + assert len(tools) == 2 + assert all(isinstance(tool, BaseTool) for tool in tools) + + mock_client_class.assert_called_once() + call_args = mock_client_class.call_args + transport = call_args.kwargs["transport"] + assert transport.url == "https://api.example.com/mcp" + assert transport.headers == {"Authorization": "Bearer test_token"} + assert transport.streamable is True + + +def test_agent_with_sse_mcp_config(mock_tool_definitions): + """Test agent setup with MCPServerSSE configuration.""" + sse_config = MCPServerSSE( + url="https://api.example.com/mcp/sse", + headers={"Authorization": "Bearer test_token"}, + ) + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[sse_config], + ) + + with patch("crewai.agent.core.MCPClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools([sse_config]) + + assert len(tools) == 2 + assert all(isinstance(tool, BaseTool) for tool in tools) + + mock_client_class.assert_called_once() + call_args = mock_client_class.call_args + transport = call_args.kwargs["transport"] + assert transport.url == "https://api.example.com/mcp/sse" + assert transport.headers == {"Authorization": "Bearer test_token"}