From 6f36d7003bc465dcbc267a1f840676d17df19e3f Mon Sep 17 00:00:00 2001 From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:45:16 -0800 Subject: [PATCH 1/4] Lorenze/feat mcp first class support (#3850) * WIP transport support mcp * refactor: streamline MCP tool loading and error handling * linted * Self type from typing with typing_extensions in MCP transport modules * added tests for mcp setup * added tests for mcp setup * docs: enhance MCP overview with detailed integration examples and structured configurations * feat: implement MCP event handling and logging in event listener and client - Added MCP event types and handlers for connection and tool execution events. - Enhanced MCPClient to emit events on connection status and tool execution. - Updated ConsoleFormatter to handle MCP event logging. - Introduced new MCP event types for better integration and monitoring. --- docs/en/mcp/overview.mdx | 286 ++++++- lib/crewai/src/crewai/agent/core.py | 238 +++++- .../crewai/agents/agent_builder/base_agent.py | 31 +- lib/crewai/src/crewai/events/__init__.py | 15 +- .../src/crewai/events/event_listener.py | 70 ++ lib/crewai/src/crewai/events/event_types.py | 14 + .../src/crewai/events/types/mcp_events.py | 85 ++ .../crewai/events/utils/console_formatter.py | 200 +++++ lib/crewai/src/crewai/mcp/__init__.py | 37 + lib/crewai/src/crewai/mcp/client.py | 742 ++++++++++++++++++ lib/crewai/src/crewai/mcp/config.py | 124 +++ lib/crewai/src/crewai/mcp/filters.py | 166 ++++ .../src/crewai/mcp/transports/__init__.py | 15 + lib/crewai/src/crewai/mcp/transports/base.py | 125 +++ lib/crewai/src/crewai/mcp/transports/http.py | 174 ++++ lib/crewai/src/crewai/mcp/transports/sse.py | 113 +++ lib/crewai/src/crewai/mcp/transports/stdio.py | 153 ++++ .../src/crewai/tools/mcp_native_tool.py | 154 ++++ lib/crewai/tests/mcp/__init__.py | 4 + lib/crewai/tests/mcp/test_mcp_config.py | 136 ++++ 20 files changed, 2841 insertions(+), 41 deletions(-) create mode 100644 lib/crewai/src/crewai/events/types/mcp_events.py create mode 100644 lib/crewai/src/crewai/mcp/__init__.py create mode 100644 lib/crewai/src/crewai/mcp/client.py create mode 100644 lib/crewai/src/crewai/mcp/config.py create mode 100644 lib/crewai/src/crewai/mcp/filters.py create mode 100644 lib/crewai/src/crewai/mcp/transports/__init__.py create mode 100644 lib/crewai/src/crewai/mcp/transports/base.py create mode 100644 lib/crewai/src/crewai/mcp/transports/http.py create mode 100644 lib/crewai/src/crewai/mcp/transports/sse.py create mode 100644 lib/crewai/src/crewai/mcp/transports/stdio.py create mode 100644 lib/crewai/src/crewai/tools/mcp_native_tool.py create mode 100644 lib/crewai/tests/mcp/__init__.py create mode 100644 lib/crewai/tests/mcp/test_mcp_config.py diff --git a/docs/en/mcp/overview.mdx b/docs/en/mcp/overview.mdx index 63bdab5d6..d8eb2743c 100644 --- a/docs/en/mcp/overview.mdx +++ b/docs/en/mcp/overview.mdx @@ -11,9 +11,13 @@ The [Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP) CrewAI offers **two approaches** for MCP integration: -### Simple DSL Integration** (Recommended) +### 🚀 **Simple DSL Integration** (Recommended) -Use the `mcps` field directly on agents for seamless MCP tool integration: +Use the `mcps` field directly on agents for seamless MCP tool integration. The DSL supports both **string references** (for quick setup) and **structured configurations** (for full control). + +#### String-Based References (Quick Setup) + +Perfect for remote HTTPS servers and CrewAI AMP marketplace: ```python from crewai import Agent @@ -32,6 +36,46 @@ agent = Agent( # MCP tools are now automatically available to your agent! ``` +#### Structured Configurations (Full Control) + +For complete control over connection settings, tool filtering, and all transport types: + +```python +from crewai import Agent +from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE +from crewai.mcp.filters import create_static_tool_filter + +agent = Agent( + role="Advanced Research Analyst", + goal="Research with full control over MCP connections", + backstory="Expert researcher with advanced tool access", + mcps=[ + # Stdio transport for local servers + MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + env={"API_KEY": "your_key"}, + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "list_directory"] + ), + cache_tools_list=True, + ), + # HTTP/Streamable HTTP transport for remote servers + MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer your_token"}, + streamable=True, + cache_tools_list=True, + ), + # SSE transport for real-time streaming + MCPServerSSE( + url="https://stream.example.com/mcp/sse", + headers={"Authorization": "Bearer your_token"}, + ), + ] +) +``` + ### 🔧 **Advanced: MCPServerAdapter** (For Complex Scenarios) For advanced use cases requiring manual connection management, the `crewai-tools` library provides the `MCPServerAdapter` class. @@ -68,12 +112,14 @@ uv pip install 'crewai-tools[mcp]' ## Quick Start: Simple DSL Integration -The easiest way to integrate MCP servers is using the `mcps` field on your agents: +The easiest way to integrate MCP servers is using the `mcps` field on your agents. You can use either string references or structured configurations. + +### Quick Start with String References ```python from crewai import Agent, Task, Crew -# Create agent with MCP tools +# Create agent with MCP tools using string references research_agent = Agent( role="Research Analyst", goal="Find and analyze information using advanced search tools", @@ -96,13 +142,53 @@ crew = Crew(agents=[research_agent], tasks=[research_task]) result = crew.kickoff() ``` +### Quick Start with Structured Configurations + +```python +from crewai import Agent, Task, Crew +from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE + +# Create agent with structured MCP configurations +research_agent = Agent( + role="Research Analyst", + goal="Find and analyze information using advanced search tools", + backstory="Expert researcher with access to multiple data sources", + mcps=[ + # Local stdio server + MCPServerStdio( + command="python", + args=["local_server.py"], + env={"API_KEY": "your_key"}, + ), + # Remote HTTP server + MCPServerHTTP( + url="https://api.research.com/mcp", + headers={"Authorization": "Bearer your_token"}, + ), + ] +) + +# Create task +research_task = Task( + description="Research the latest developments in AI agent frameworks", + expected_output="Comprehensive research report with citations", + agent=research_agent +) + +# Create and run crew +crew = Crew(agents=[research_agent], tasks=[research_task]) +result = crew.kickoff() +``` + That's it! The MCP tools are automatically discovered and available to your agent. ## MCP Reference Formats -The `mcps` field supports various reference formats for maximum flexibility: +The `mcps` field supports both **string references** (for quick setup) and **structured configurations** (for full control). You can mix both formats in the same list. -### External MCP Servers +### String-Based References + +#### External MCP Servers ```python mcps=[ @@ -117,7 +203,7 @@ mcps=[ ] ``` -### CrewAI AMP Marketplace +#### CrewAI AMP Marketplace ```python mcps=[ @@ -133,17 +219,166 @@ mcps=[ ] ``` -### Mixed References +### Structured Configurations + +#### Stdio Transport (Local Servers) + +Perfect for local MCP servers that run as processes: ```python +from crewai.mcp import MCPServerStdio +from crewai.mcp.filters import create_static_tool_filter + mcps=[ - "https://external-api.com/mcp", # External server - "https://weather.service.com/mcp#forecast", # Specific external tool - "crewai-amp:financial-insights", # AMP service - "crewai-amp:data-analysis#sentiment_tool" # Specific AMP tool + MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + env={"API_KEY": "your_key"}, + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"] + ), + cache_tools_list=True, + ), + # Python-based server + MCPServerStdio( + command="python", + args=["path/to/server.py"], + env={"UV_PYTHON": "3.12", "API_KEY": "your_key"}, + ), ] ``` +#### HTTP/Streamable HTTP Transport (Remote Servers) + +For remote MCP servers over HTTP/HTTPS: + +```python +from crewai.mcp import MCPServerHTTP + +mcps=[ + # Streamable HTTP (default) + MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer your_token"}, + streamable=True, + cache_tools_list=True, + ), + # Standard HTTP + MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer your_token"}, + streamable=False, + ), +] +``` + +#### SSE Transport (Real-Time Streaming) + +For remote servers using Server-Sent Events: + +```python +from crewai.mcp import MCPServerSSE + +mcps=[ + MCPServerSSE( + url="https://stream.example.com/mcp/sse", + headers={"Authorization": "Bearer your_token"}, + cache_tools_list=True, + ), +] +``` + +### Mixed References + +You can combine string references and structured configurations: + +```python +from crewai.mcp import MCPServerStdio, MCPServerHTTP + +mcps=[ + # String references + "https://external-api.com/mcp", # External server + "crewai-amp:financial-insights", # AMP service + + # Structured configurations + MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + ), + MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer token"}, + ), +] +``` + +### Tool Filtering + +Structured configurations support advanced tool filtering: + +```python +from crewai.mcp import MCPServerStdio +from crewai.mcp.filters import create_static_tool_filter, create_dynamic_tool_filter, ToolFilterContext + +# Static filtering (allow/block lists) +static_filter = create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"], + blocked_tool_names=["delete_file"], +) + +# Dynamic filtering (context-aware) +def dynamic_filter(context: ToolFilterContext, tool: dict) -> bool: + # Block dangerous tools for certain agent roles + if context.agent.role == "Code Reviewer": + if "delete" in tool.get("name", "").lower(): + return False + return True + +mcps=[ + MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + tool_filter=static_filter, # or dynamic_filter + ), +] +``` + +## Configuration Parameters + +Each transport type supports specific configuration options: + +### MCPServerStdio Parameters + +- **`command`** (required): Command to execute (e.g., `"python"`, `"node"`, `"npx"`, `"uvx"`) +- **`args`** (optional): List of command arguments (e.g., `["server.py"]` or `["-y", "@mcp/server"]`) +- **`env`** (optional): Dictionary of environment variables to pass to the process +- **`tool_filter`** (optional): Tool filter function for filtering available tools +- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`) + +### MCPServerHTTP Parameters + +- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp"`) +- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes +- **`streamable`** (optional): Whether to use streamable HTTP transport (default: `True`) +- **`tool_filter`** (optional): Tool filter function for filtering available tools +- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`) + +### MCPServerSSE Parameters + +- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp/sse"`) +- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes +- **`tool_filter`** (optional): Tool filter function for filtering available tools +- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`) + +### Common Parameters + +All transport types support: +- **`tool_filter`**: Filter function to control which tools are available. Can be: + - `None` (default): All tools are available + - Static filter: Created with `create_static_tool_filter()` for allow/block lists + - Dynamic filter: Created with `create_dynamic_tool_filter()` for context-aware filtering +- **`cache_tools_list`**: When `True`, caches the tool list after first discovery to improve performance on subsequent connections + ## Key Features - 🔄 **Automatic Tool Discovery**: Tools are automatically discovered and integrated @@ -152,26 +387,47 @@ mcps=[ - 🛡️ **Error Resilience**: Graceful handling of unavailable servers - ⏱️ **Timeout Protection**: Built-in timeouts prevent hanging connections - 📊 **Transparent Integration**: Works seamlessly with existing CrewAI features +- 🔧 **Full Transport Support**: Stdio, HTTP/Streamable HTTP, and SSE transports +- 🎯 **Advanced Filtering**: Static and dynamic tool filtering capabilities +- 🔐 **Flexible Authentication**: Support for headers, environment variables, and query parameters ## Error Handling -The MCP DSL integration is designed to be resilient: +The MCP DSL integration is designed to be resilient and handles failures gracefully: ```python +from crewai import Agent +from crewai.mcp import MCPServerStdio, MCPServerHTTP + agent = Agent( role="Resilient Agent", goal="Continue working despite server issues", backstory="Agent that handles failures gracefully", mcps=[ + # String references "https://reliable-server.com/mcp", # Will work "https://unreachable-server.com/mcp", # Will be skipped gracefully - "https://slow-server.com/mcp", # Will timeout gracefully - "crewai-amp:working-service" # Will work + "crewai-amp:working-service", # Will work + + # Structured configs + MCPServerStdio( + command="python", + args=["reliable_server.py"], # Will work + ), + MCPServerHTTP( + url="https://slow-server.com/mcp", # Will timeout gracefully + ), ] ) # Agent will use tools from working servers and log warnings for failing ones ``` +All connection errors are handled gracefully: +- **Connection failures**: Logged as warnings, agent continues with available tools +- **Timeout errors**: Connections timeout after 30 seconds (configurable) +- **Authentication errors**: Logged clearly for debugging +- **Invalid configurations**: Validation errors are raised at agent creation time + ## Advanced: MCPServerAdapter For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server. diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 3e925cef6..54ae52ba8 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -40,6 +40,16 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context from crewai.lite_agent import LiteAgent from crewai.llms.base_llm import BaseLLM +from crewai.mcp import ( + MCPClient, + MCPServerConfig, + MCPServerHTTP, + MCPServerSSE, + MCPServerStdio, +) +from crewai.mcp.transports.http import HTTPTransport +from crewai.mcp.transports.sse import SSETransport +from crewai.mcp.transports.stdio import StdioTransport from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.fingerprint import Fingerprint @@ -108,6 +118,7 @@ class Agent(BaseAgent): """ _times_executed: int = PrivateAttr(default=0) + _mcp_clients: list[Any] = PrivateAttr(default_factory=list) max_execution_time: int | None = Field( default=None, description="Maximum execution time for an agent to execute a task", @@ -526,6 +537,9 @@ class Agent(BaseAgent): self, event=AgentExecutionCompletedEvent(agent=self, task=task, output=result), ) + + self._cleanup_mcp_clients() + return result def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any: @@ -649,30 +663,70 @@ class Agent(BaseAgent): self._logger.log("error", f"Error getting platform tools: {e!s}") return [] - def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]: - """Convert MCP server references to CrewAI tools.""" + def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]: + """Convert MCP server references/configs to CrewAI tools. + + Supports both string references (backwards compatible) and structured + configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE). + + Args: + mcps: List of MCP server references (strings) or configurations. + + Returns: + List of BaseTool instances from MCP servers. + """ all_tools = [] + clients = [] - for mcp_ref in mcps: - try: - if mcp_ref.startswith("crewai-amp:"): - tools = self._get_amp_mcp_tools(mcp_ref) - elif mcp_ref.startswith("https://"): - tools = self._get_external_mcp_tools(mcp_ref) - else: - continue + for mcp_config in mcps: + if isinstance(mcp_config, str): + tools = self._get_mcp_tools_from_string(mcp_config) + else: + tools, client = self._get_native_mcp_tools(mcp_config) + if client: + clients.append(client) - all_tools.extend(tools) - self._logger.log( - "info", f"Successfully loaded {len(tools)} tools from {mcp_ref}" - ) - - except Exception as e: - self._logger.log("warning", f"Skipping MCP {mcp_ref} due to error: {e}") - continue + all_tools.extend(tools) + # Store clients for cleanup + self._mcp_clients.extend(clients) return all_tools + def _cleanup_mcp_clients(self) -> None: + """Cleanup MCP client connections after task execution.""" + if not self._mcp_clients: + return + + async def _disconnect_all() -> None: + for client in self._mcp_clients: + if client and hasattr(client, "connected") and client.connected: + await client.disconnect() + + try: + asyncio.run(_disconnect_all()) + except Exception as e: + self._logger.log("error", f"Error during MCP client cleanup: {e}") + finally: + self._mcp_clients.clear() + + def _get_mcp_tools_from_string(self, mcp_ref: str) -> list[BaseTool]: + """Get tools from legacy string-based MCP references. + + This method maintains backwards compatibility with string-based + MCP references (https://... and crewai-amp:...). + + Args: + mcp_ref: String reference to MCP server. + + Returns: + List of BaseTool instances. + """ + if mcp_ref.startswith("crewai-amp:"): + return self._get_amp_mcp_tools(mcp_ref) + if mcp_ref.startswith("https://"): + return self._get_external_mcp_tools(mcp_ref) + return [] + def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]: """Get tools from external HTTPS MCP server with graceful error handling.""" from crewai.tools.mcp_tool_wrapper import MCPToolWrapper @@ -731,6 +785,154 @@ class Agent(BaseAgent): ) return [] + def _get_native_mcp_tools( + self, mcp_config: MCPServerConfig + ) -> tuple[list[BaseTool], Any | None]: + """Get tools from MCP server using structured configuration. + + This method creates an MCP client based on the configuration type, + connects to the server, discovers tools, applies filtering, and + returns wrapped tools along with the client instance for cleanup. + + Args: + mcp_config: MCP server configuration (MCPServerStdio, MCPServerHTTP, or MCPServerSSE). + + Returns: + Tuple of (list of BaseTool instances, MCPClient instance for cleanup). + """ + from crewai.tools.base_tool import BaseTool + from crewai.tools.mcp_native_tool import MCPNativeTool + + if isinstance(mcp_config, MCPServerStdio): + transport = StdioTransport( + command=mcp_config.command, + args=mcp_config.args, + env=mcp_config.env, + ) + server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}" + elif isinstance(mcp_config, MCPServerHTTP): + transport = HTTPTransport( + url=mcp_config.url, + headers=mcp_config.headers, + streamable=mcp_config.streamable, + ) + server_name = self._extract_server_name(mcp_config.url) + elif isinstance(mcp_config, MCPServerSSE): + transport = SSETransport( + url=mcp_config.url, + headers=mcp_config.headers, + ) + server_name = self._extract_server_name(mcp_config.url) + else: + raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}") + + client = MCPClient( + transport=transport, + cache_tools_list=mcp_config.cache_tools_list, + ) + + async def _setup_client_and_list_tools() -> list[dict[str, Any]]: + """Async helper to connect and list tools in same event loop.""" + + try: + if not client.connected: + await client.connect() + + tools_list = await client.list_tools() + + try: + await client.disconnect() + # Small delay to allow background tasks to finish cleanup + # This helps prevent "cancel scope in different task" errors + # when asyncio.run() closes the event loop + await asyncio.sleep(0.1) + except Exception as e: + self._logger.log("error", f"Error during disconnect: {e}") + + return tools_list + except Exception as e: + if client.connected: + await client.disconnect() + await asyncio.sleep(0.1) + raise RuntimeError( + f"Error during setup client and list tools: {e}" + ) from e + + try: + try: + tools_list = asyncio.run(_setup_client_and_list_tools()) + except RuntimeError as e: + error_msg = str(e).lower() + if "cancel scope" in error_msg or "task" in error_msg: + raise ConnectionError( + "MCP connection failed due to event loop cleanup issues. " + "This may be due to authentication errors or server unavailability." + ) from e + except asyncio.CancelledError as e: + raise ConnectionError( + "MCP connection was cancelled. This may indicate an authentication " + "error or server unavailability." + ) from e + + if mcp_config.tool_filter: + filtered_tools = [] + for tool in tools_list: + if callable(mcp_config.tool_filter): + try: + from crewai.mcp.filters import ToolFilterContext + + context = ToolFilterContext( + agent=self, + server_name=server_name, + run_context=None, + ) + if mcp_config.tool_filter(context, tool): + filtered_tools.append(tool) + except (TypeError, AttributeError): + if mcp_config.tool_filter(tool): + filtered_tools.append(tool) + else: + # Not callable - include tool + filtered_tools.append(tool) + tools_list = filtered_tools + + tools = [] + for tool_def in tools_list: + tool_name = tool_def.get("name", "") + if not tool_name: + continue + + # Convert inputSchema to Pydantic model if present + args_schema = None + if tool_def.get("inputSchema"): + args_schema = self._json_schema_to_pydantic( + tool_name, tool_def["inputSchema"] + ) + + tool_schema = { + "description": tool_def.get("description", ""), + "args_schema": args_schema, + } + + try: + native_tool = MCPNativeTool( + mcp_client=client, + tool_name=tool_name, + tool_schema=tool_schema, + server_name=server_name, + ) + tools.append(native_tool) + except Exception as e: + self._logger.log("error", f"Failed to create native MCP tool: {e}") + continue + + return cast(list[BaseTool], tools), client + except Exception as e: + if client.connected: + asyncio.run(client.disconnect()) + + raise RuntimeError(f"Failed to get native MCP tools: {e}") from e + def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]: """Get tools from CrewAI AMP MCP marketplace.""" # Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name" diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index b26c24515..932c98611 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -25,6 +25,7 @@ from crewai.agents.tools_handler import ToolsHandler from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge_config import KnowledgeConfig from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource +from crewai.mcp.config import MCPServerConfig from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig from crewai.tools.base_tool import BaseTool, Tool @@ -194,7 +195,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=None, description="List of applications or application/action combinations that the agent can access through CrewAI Platform. Can contain app names (e.g., 'gmail') or specific actions (e.g., 'gmail/send_email')", ) - mcps: list[str] | None = Field( + mcps: list[str | MCPServerConfig] | None = Field( default=None, description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.", ) @@ -253,20 +254,36 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): @field_validator("mcps") @classmethod - def validate_mcps(cls, mcps: list[str] | None) -> list[str] | None: + def validate_mcps( + cls, mcps: list[str | MCPServerConfig] | None + ) -> list[str | MCPServerConfig] | None: + """Validate MCP server references and configurations. + + Supports both string references (for backwards compatibility) and + structured configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE). + """ if not mcps: return mcps validated_mcps = [] for mcp in mcps: - if mcp.startswith(("https://", "crewai-amp:")): + if isinstance(mcp, str): + if mcp.startswith(("https://", "crewai-amp:")): + validated_mcps.append(mcp) + else: + raise ValueError( + f"Invalid MCP reference: {mcp}. " + "String references must start with 'https://' or 'crewai-amp:'" + ) + + elif isinstance(mcp, (MCPServerConfig)): validated_mcps.append(mcp) else: raise ValueError( - f"Invalid MCP reference: {mcp}. Must start with 'https://' or 'crewai-amp:'" + f"Invalid MCP configuration: {type(mcp)}. " + "Must be a string reference or MCPServerConfig instance." ) - - return list(set(validated_mcps)) + return validated_mcps @model_validator(mode="after") def validate_and_set_attributes(self) -> Self: @@ -343,7 +360,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): """Get platform tools for the specified list of applications and/or application/action combinations.""" @abstractmethod - def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]: + def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]: """Get MCP tools for the specified list of MCP server references.""" def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel" diff --git a/lib/crewai/src/crewai/events/__init__.py b/lib/crewai/src/crewai/events/__init__.py index 66c441bed..4147965e1 100644 --- a/lib/crewai/src/crewai/events/__init__.py +++ b/lib/crewai/src/crewai/events/__init__.py @@ -16,7 +16,6 @@ from crewai.events.base_event_listener import BaseEventListener from crewai.events.depends import Depends from crewai.events.event_bus import crewai_event_bus from crewai.events.handler_graph import CircularDependencyError - from crewai.events.types.crew_events import ( CrewKickoffCompletedEvent, CrewKickoffFailedEvent, @@ -61,6 +60,14 @@ from crewai.events.types.logging_events import ( AgentLogsExecutionEvent, AgentLogsStartedEvent, ) +from crewai.events.types.mcp_events import ( + MCPConnectionCompletedEvent, + MCPConnectionFailedEvent, + MCPConnectionStartedEvent, + MCPToolExecutionCompletedEvent, + MCPToolExecutionFailedEvent, + MCPToolExecutionStartedEvent, +) from crewai.events.types.memory_events import ( MemoryQueryCompletedEvent, MemoryQueryFailedEvent, @@ -153,6 +160,12 @@ __all__ = [ "LiteAgentExecutionCompletedEvent", "LiteAgentExecutionErrorEvent", "LiteAgentExecutionStartedEvent", + "MCPConnectionCompletedEvent", + "MCPConnectionFailedEvent", + "MCPConnectionStartedEvent", + "MCPToolExecutionCompletedEvent", + "MCPToolExecutionFailedEvent", + "MCPToolExecutionStartedEvent", "MemoryQueryCompletedEvent", "MemoryQueryFailedEvent", "MemoryQueryStartedEvent", diff --git a/lib/crewai/src/crewai/events/event_listener.py b/lib/crewai/src/crewai/events/event_listener.py index 6bb604b64..e07ee193c 100644 --- a/lib/crewai/src/crewai/events/event_listener.py +++ b/lib/crewai/src/crewai/events/event_listener.py @@ -65,6 +65,14 @@ from crewai.events.types.logging_events import ( AgentLogsExecutionEvent, AgentLogsStartedEvent, ) +from crewai.events.types.mcp_events import ( + MCPConnectionCompletedEvent, + MCPConnectionFailedEvent, + MCPConnectionStartedEvent, + MCPToolExecutionCompletedEvent, + MCPToolExecutionFailedEvent, + MCPToolExecutionStartedEvent, +) from crewai.events.types.reasoning_events import ( AgentReasoningCompletedEvent, AgentReasoningFailedEvent, @@ -615,5 +623,67 @@ class EventListener(BaseEventListener): event.total_turns, ) + # ----------- MCP EVENTS ----------- + + @crewai_event_bus.on(MCPConnectionStartedEvent) + def on_mcp_connection_started(source, event: MCPConnectionStartedEvent): + self.formatter.handle_mcp_connection_started( + event.server_name, + event.server_url, + event.transport_type, + event.is_reconnect, + event.connect_timeout, + ) + + @crewai_event_bus.on(MCPConnectionCompletedEvent) + def on_mcp_connection_completed(source, event: MCPConnectionCompletedEvent): + self.formatter.handle_mcp_connection_completed( + event.server_name, + event.server_url, + event.transport_type, + event.connection_duration_ms, + event.is_reconnect, + ) + + @crewai_event_bus.on(MCPConnectionFailedEvent) + def on_mcp_connection_failed(source, event: MCPConnectionFailedEvent): + self.formatter.handle_mcp_connection_failed( + event.server_name, + event.server_url, + event.transport_type, + event.error, + event.error_type, + ) + + @crewai_event_bus.on(MCPToolExecutionStartedEvent) + def on_mcp_tool_execution_started(source, event: MCPToolExecutionStartedEvent): + self.formatter.handle_mcp_tool_execution_started( + event.server_name, + event.tool_name, + event.tool_args, + ) + + @crewai_event_bus.on(MCPToolExecutionCompletedEvent) + def on_mcp_tool_execution_completed( + source, event: MCPToolExecutionCompletedEvent + ): + self.formatter.handle_mcp_tool_execution_completed( + event.server_name, + event.tool_name, + event.tool_args, + event.result, + event.execution_duration_ms, + ) + + @crewai_event_bus.on(MCPToolExecutionFailedEvent) + def on_mcp_tool_execution_failed(source, event: MCPToolExecutionFailedEvent): + self.formatter.handle_mcp_tool_execution_failed( + event.server_name, + event.tool_name, + event.tool_args, + event.error, + event.error_type, + ) + event_listener = EventListener() diff --git a/lib/crewai/src/crewai/events/event_types.py b/lib/crewai/src/crewai/events/event_types.py index f7a4d1f72..ea00aa9ae 100644 --- a/lib/crewai/src/crewai/events/event_types.py +++ b/lib/crewai/src/crewai/events/event_types.py @@ -40,6 +40,14 @@ from crewai.events.types.llm_guardrail_events import ( LLMGuardrailCompletedEvent, LLMGuardrailStartedEvent, ) +from crewai.events.types.mcp_events import ( + MCPConnectionCompletedEvent, + MCPConnectionFailedEvent, + MCPConnectionStartedEvent, + MCPToolExecutionCompletedEvent, + MCPToolExecutionFailedEvent, + MCPToolExecutionStartedEvent, +) from crewai.events.types.memory_events import ( MemoryQueryCompletedEvent, MemoryQueryFailedEvent, @@ -115,4 +123,10 @@ EventTypes = ( | MemoryQueryFailedEvent | MemoryRetrievalStartedEvent | MemoryRetrievalCompletedEvent + | MCPConnectionStartedEvent + | MCPConnectionCompletedEvent + | MCPConnectionFailedEvent + | MCPToolExecutionStartedEvent + | MCPToolExecutionCompletedEvent + | MCPToolExecutionFailedEvent ) diff --git a/lib/crewai/src/crewai/events/types/mcp_events.py b/lib/crewai/src/crewai/events/types/mcp_events.py new file mode 100644 index 000000000..d360aa62a --- /dev/null +++ b/lib/crewai/src/crewai/events/types/mcp_events.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import Any + +from crewai.events.base_events import BaseEvent + + +class MCPEvent(BaseEvent): + """Base event for MCP operations.""" + + server_name: str + server_url: str | None = None + transport_type: str | None = None # "stdio", "http", "sse" + agent_id: str | None = None + agent_role: str | None = None + from_agent: Any | None = None + from_task: Any | None = None + + def __init__(self, **data): + super().__init__(**data) + self._set_agent_params(data) + self._set_task_params(data) + + +class MCPConnectionStartedEvent(MCPEvent): + """Event emitted when starting to connect to an MCP server.""" + + type: str = "mcp_connection_started" + connect_timeout: int | None = None + is_reconnect: bool = ( + False # True if this is a reconnection, False for first connection + ) + + +class MCPConnectionCompletedEvent(MCPEvent): + """Event emitted when successfully connected to an MCP server.""" + + type: str = "mcp_connection_completed" + started_at: datetime | None = None + completed_at: datetime | None = None + connection_duration_ms: float | None = None + is_reconnect: bool = ( + False # True if this was a reconnection, False for first connection + ) + + +class MCPConnectionFailedEvent(MCPEvent): + """Event emitted when connection to an MCP server fails.""" + + type: str = "mcp_connection_failed" + error: str + error_type: str | None = None # "timeout", "authentication", "network", etc. + started_at: datetime | None = None + failed_at: datetime | None = None + + +class MCPToolExecutionStartedEvent(MCPEvent): + """Event emitted when starting to execute an MCP tool.""" + + type: str = "mcp_tool_execution_started" + tool_name: str + tool_args: dict[str, Any] | None = None + + +class MCPToolExecutionCompletedEvent(MCPEvent): + """Event emitted when MCP tool execution completes.""" + + type: str = "mcp_tool_execution_completed" + tool_name: str + tool_args: dict[str, Any] | None = None + result: Any | None = None + started_at: datetime | None = None + completed_at: datetime | None = None + execution_duration_ms: float | None = None + + +class MCPToolExecutionFailedEvent(MCPEvent): + """Event emitted when MCP tool execution fails.""" + + type: str = "mcp_tool_execution_failed" + tool_name: str + tool_args: dict[str, Any] | None = None + error: str + error_type: str | None = None # "timeout", "validation", "server_error", etc. + started_at: datetime | None = None + failed_at: datetime | None = None diff --git a/lib/crewai/src/crewai/events/utils/console_formatter.py b/lib/crewai/src/crewai/events/utils/console_formatter.py index 4ee2aa52b..32aa8d208 100644 --- a/lib/crewai/src/crewai/events/utils/console_formatter.py +++ b/lib/crewai/src/crewai/events/utils/console_formatter.py @@ -2248,3 +2248,203 @@ class ConsoleFormatter: self.current_a2a_conversation_branch = None self.current_a2a_turn_count = 0 + + # ----------- MCP EVENTS ----------- + + def handle_mcp_connection_started( + self, + server_name: str, + server_url: str | None = None, + transport_type: str | None = None, + is_reconnect: bool = False, + connect_timeout: int | None = None, + ) -> None: + """Handle MCP connection started event.""" + if not self.verbose: + return + + content = Text() + reconnect_text = " (Reconnecting)" if is_reconnect else "" + content.append(f"MCP Connection Started{reconnect_text}\n\n", style="cyan bold") + content.append("Server: ", style="white") + content.append(f"{server_name}\n", style="cyan") + + if server_url: + content.append("URL: ", style="white") + content.append(f"{server_url}\n", style="cyan dim") + + if transport_type: + content.append("Transport: ", style="white") + content.append(f"{transport_type}\n", style="cyan") + + if connect_timeout: + content.append("Timeout: ", style="white") + content.append(f"{connect_timeout}s\n", style="cyan") + + panel = self.create_panel(content, "🔌 MCP Connection", "cyan") + self.print(panel) + self.print() + + def handle_mcp_connection_completed( + self, + server_name: str, + server_url: str | None = None, + transport_type: str | None = None, + connection_duration_ms: float | None = None, + is_reconnect: bool = False, + ) -> None: + """Handle MCP connection completed event.""" + if not self.verbose: + return + + content = Text() + reconnect_text = " (Reconnected)" if is_reconnect else "" + content.append( + f"MCP Connection Completed{reconnect_text}\n\n", style="green bold" + ) + content.append("Server: ", style="white") + content.append(f"{server_name}\n", style="green") + + if server_url: + content.append("URL: ", style="white") + content.append(f"{server_url}\n", style="green dim") + + if transport_type: + content.append("Transport: ", style="white") + content.append(f"{transport_type}\n", style="green") + + if connection_duration_ms is not None: + content.append("Duration: ", style="white") + content.append(f"{connection_duration_ms:.2f}ms\n", style="green") + + panel = self.create_panel(content, "✅ MCP Connected", "green") + self.print(panel) + self.print() + + def handle_mcp_connection_failed( + self, + server_name: str, + server_url: str | None = None, + transport_type: str | None = None, + error: str = "", + error_type: str | None = None, + ) -> None: + """Handle MCP connection failed event.""" + if not self.verbose: + return + + content = Text() + content.append("MCP Connection Failed\n\n", style="red bold") + content.append("Server: ", style="white") + content.append(f"{server_name}\n", style="red") + + if server_url: + content.append("URL: ", style="white") + content.append(f"{server_url}\n", style="red dim") + + if transport_type: + content.append("Transport: ", style="white") + content.append(f"{transport_type}\n", style="red") + + if error_type: + content.append("Error Type: ", style="white") + content.append(f"{error_type}\n", style="red") + + if error: + content.append("\nError: ", style="white bold") + error_preview = error[:500] + "..." if len(error) > 500 else error + content.append(f"{error_preview}\n", style="red") + + panel = self.create_panel(content, "❌ MCP Connection Failed", "red") + self.print(panel) + self.print() + + def handle_mcp_tool_execution_started( + self, + server_name: str, + tool_name: str, + tool_args: dict[str, Any] | None = None, + ) -> None: + """Handle MCP tool execution started event.""" + if not self.verbose: + return + + content = self.create_status_content( + "MCP Tool Execution Started", + tool_name, + "yellow", + tool_args=tool_args or {}, + Server=server_name, + ) + + panel = self.create_panel(content, "🔧 MCP Tool", "yellow") + self.print(panel) + self.print() + + def handle_mcp_tool_execution_completed( + self, + server_name: str, + tool_name: str, + tool_args: dict[str, Any] | None = None, + result: Any | None = None, + execution_duration_ms: float | None = None, + ) -> None: + """Handle MCP tool execution completed event.""" + if not self.verbose: + return + + content = self.create_status_content( + "MCP Tool Execution Completed", + tool_name, + "green", + tool_args=tool_args or {}, + Server=server_name, + ) + + if execution_duration_ms is not None: + content.append("Duration: ", style="white") + content.append(f"{execution_duration_ms:.2f}ms\n", style="green") + + if result is not None: + result_str = str(result) + if len(result_str) > 500: + result_str = result_str[:497] + "..." + content.append("\nResult: ", style="white bold") + content.append(f"{result_str}\n", style="green") + + panel = self.create_panel(content, "✅ MCP Tool Completed", "green") + self.print(panel) + self.print() + + def handle_mcp_tool_execution_failed( + self, + server_name: str, + tool_name: str, + tool_args: dict[str, Any] | None = None, + error: str = "", + error_type: str | None = None, + ) -> None: + """Handle MCP tool execution failed event.""" + if not self.verbose: + return + + content = self.create_status_content( + "MCP Tool Execution Failed", + tool_name, + "red", + tool_args=tool_args or {}, + Server=server_name, + ) + + if error_type: + content.append("Error Type: ", style="white") + content.append(f"{error_type}\n", style="red") + + if error: + content.append("\nError: ", style="white bold") + error_preview = error[:500] + "..." if len(error) > 500 else error + content.append(f"{error_preview}\n", style="red") + + panel = self.create_panel(content, "❌ MCP Tool Failed", "red") + self.print(panel) + self.print() diff --git a/lib/crewai/src/crewai/mcp/__init__.py b/lib/crewai/src/crewai/mcp/__init__.py new file mode 100644 index 000000000..282cb1f56 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/__init__.py @@ -0,0 +1,37 @@ +"""MCP (Model Context Protocol) client support for CrewAI agents. + +This module provides native MCP client functionality, allowing CrewAI agents +to connect to any MCP-compliant server using various transport types. +""" + +from crewai.mcp.client import MCPClient +from crewai.mcp.config import ( + MCPServerConfig, + MCPServerHTTP, + MCPServerSSE, + MCPServerStdio, +) +from crewai.mcp.filters import ( + StaticToolFilter, + ToolFilter, + ToolFilterContext, + create_dynamic_tool_filter, + create_static_tool_filter, +) +from crewai.mcp.transports.base import BaseTransport, TransportType + + +__all__ = [ + "BaseTransport", + "MCPClient", + "MCPServerConfig", + "MCPServerHTTP", + "MCPServerSSE", + "MCPServerStdio", + "StaticToolFilter", + "ToolFilter", + "ToolFilterContext", + "TransportType", + "create_dynamic_tool_filter", + "create_static_tool_filter", +] diff --git a/lib/crewai/src/crewai/mcp/client.py b/lib/crewai/src/crewai/mcp/client.py new file mode 100644 index 000000000..aff07a397 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/client.py @@ -0,0 +1,742 @@ +"""MCP client with session management for CrewAI agents.""" + +import asyncio +from collections.abc import Callable +from contextlib import AsyncExitStack +from datetime import datetime +import logging +import time +from typing import Any + +from typing_extensions import Self + + +# BaseExceptionGroup is available in Python 3.11+ +try: + from builtins import BaseExceptionGroup +except ImportError: + # Fallback for Python < 3.11 (shouldn't happen in practice) + BaseExceptionGroup = Exception + +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.mcp_events import ( + MCPConnectionCompletedEvent, + MCPConnectionFailedEvent, + MCPConnectionStartedEvent, + MCPToolExecutionCompletedEvent, + MCPToolExecutionFailedEvent, + MCPToolExecutionStartedEvent, +) +from crewai.mcp.transports.base import BaseTransport +from crewai.mcp.transports.http import HTTPTransport +from crewai.mcp.transports.sse import SSETransport +from crewai.mcp.transports.stdio import StdioTransport + + +# MCP Connection timeout constants (in seconds) +MCP_CONNECTION_TIMEOUT = 30 # Increased for slow servers +MCP_TOOL_EXECUTION_TIMEOUT = 30 +MCP_DISCOVERY_TIMEOUT = 30 # Increased for slow servers +MCP_MAX_RETRIES = 3 + +# Simple in-memory cache for MCP tool schemas (duration: 5 minutes) +_mcp_schema_cache: dict[str, tuple[dict[str, Any], float]] = {} +_cache_ttl = 300 # 5 minutes + + +class MCPClient: + """MCP client with session management. + + This client manages connections to MCP servers and provides a high-level + interface for interacting with MCP tools, prompts, and resources. + + Example: + ```python + transport = StdioTransport(command="python", args=["server.py"]) + client = MCPClient(transport) + async with client: + tools = await client.list_tools() + result = await client.call_tool("tool_name", {"arg": "value"}) + ``` + """ + + def __init__( + self, + transport: BaseTransport, + connect_timeout: int = MCP_CONNECTION_TIMEOUT, + execution_timeout: int = MCP_TOOL_EXECUTION_TIMEOUT, + discovery_timeout: int = MCP_DISCOVERY_TIMEOUT, + max_retries: int = MCP_MAX_RETRIES, + cache_tools_list: bool = False, + logger: logging.Logger | None = None, + ) -> None: + """Initialize MCP client. + + Args: + transport: Transport instance for MCP server connection. + connect_timeout: Connection timeout in seconds. + execution_timeout: Tool execution timeout in seconds. + discovery_timeout: Tool discovery timeout in seconds. + max_retries: Maximum retry attempts for operations. + cache_tools_list: Whether to cache tool list results. + logger: Optional logger instance. + """ + self.transport = transport + self.connect_timeout = connect_timeout + self.execution_timeout = execution_timeout + self.discovery_timeout = discovery_timeout + self.max_retries = max_retries + self.cache_tools_list = cache_tools_list + # self._logger = logger or logging.getLogger(__name__) + self._session: Any = None + self._initialized = False + self._exit_stack = AsyncExitStack() + self._was_connected = False + + @property + def connected(self) -> bool: + """Check if client is connected to server.""" + return self.transport.connected and self._initialized + + @property + def session(self) -> Any: + """Get the MCP session.""" + if self._session is None: + raise RuntimeError("Client not connected. Call connect() first.") + return self._session + + def _get_server_info(self) -> tuple[str, str | None, str | None]: + """Get server information for events. + + Returns: + Tuple of (server_name, server_url, transport_type). + """ + if isinstance(self.transport, StdioTransport): + server_name = f"{self.transport.command} {' '.join(self.transport.args)}" + server_url = None + transport_type = self.transport.transport_type.value + elif isinstance(self.transport, HTTPTransport): + server_name = self.transport.url + server_url = self.transport.url + transport_type = self.transport.transport_type.value + elif isinstance(self.transport, SSETransport): + server_name = self.transport.url + server_url = self.transport.url + transport_type = self.transport.transport_type.value + else: + server_name = "Unknown MCP Server" + server_url = None + transport_type = ( + self.transport.transport_type.value + if hasattr(self.transport, "transport_type") + else None + ) + + return server_name, server_url, transport_type + + async def connect(self) -> Self: + """Connect to MCP server and initialize session. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If connection fails. + ImportError: If MCP SDK not available. + """ + if self.connected: + return self + + # Get server info for events + server_name, server_url, transport_type = self._get_server_info() + is_reconnect = self._was_connected + + # Emit connection started event + started_at = datetime.now() + crewai_event_bus.emit( + self, + MCPConnectionStartedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + is_reconnect=is_reconnect, + connect_timeout=self.connect_timeout, + ), + ) + + try: + from mcp import ClientSession + + # Use AsyncExitStack to manage transport and session contexts together + # This ensures they're in the same async scope and prevents cancel scope errors + # Always enter transport context via exit stack (it handles already-connected state) + await self._exit_stack.enter_async_context(self.transport) + + # Create ClientSession with transport streams + self._session = ClientSession( + self.transport.read_stream, + self.transport.write_stream, + ) + + # Enter the session's async context manager via exit stack + await self._exit_stack.enter_async_context(self._session) + + # Initialize the session (required by MCP protocol) + try: + await asyncio.wait_for( + self._session.initialize(), + timeout=self.connect_timeout, + ) + except asyncio.CancelledError: + # If initialization was cancelled (e.g., event loop closing), + # cleanup and re-raise - don't suppress cancellation + await self._cleanup_on_error() + raise + except BaseExceptionGroup as eg: + # Handle exception groups from anyio task groups + # Extract the actual meaningful error (not GeneratorExit) + actual_error = None + for exc in eg.exceptions: + if isinstance(exc, Exception) and not isinstance( + exc, GeneratorExit + ): + # Check if it's an HTTP error (like 401) + error_msg = str(exc).lower() + if "401" in error_msg or "unauthorized" in error_msg: + actual_error = exc + break + if "cancel scope" not in error_msg and "task" not in error_msg: + actual_error = exc + break + + await self._cleanup_on_error() + if actual_error: + raise ConnectionError( + f"Failed to connect to MCP server: {actual_error}" + ) from actual_error + raise ConnectionError(f"Failed to connect to MCP server: {eg}") from eg + + self._initialized = True + self._was_connected = True + + completed_at = datetime.now() + connection_duration_ms = (completed_at - started_at).total_seconds() * 1000 + crewai_event_bus.emit( + self, + MCPConnectionCompletedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + started_at=started_at, + completed_at=completed_at, + connection_duration_ms=connection_duration_ms, + is_reconnect=is_reconnect, + ), + ) + + return self + except ImportError as e: + await self._cleanup_on_error() + error_msg = ( + "MCP library not available. Please install with: pip install mcp" + ) + self._emit_connection_failed( + server_name, + server_url, + transport_type, + error_msg, + "import_error", + started_at, + ) + raise ImportError(error_msg) from e + except asyncio.TimeoutError as e: + await self._cleanup_on_error() + error_msg = f"MCP connection timed out after {self.connect_timeout} seconds. The server may be slow or unreachable." + self._emit_connection_failed( + server_name, + server_url, + transport_type, + error_msg, + "timeout", + started_at, + ) + raise ConnectionError(error_msg) from e + except asyncio.CancelledError: + # Re-raise cancellation - don't suppress it + await self._cleanup_on_error() + self._emit_connection_failed( + server_name, + server_url, + transport_type, + "Connection cancelled", + "cancelled", + started_at, + ) + raise + except BaseExceptionGroup as eg: + # Handle exception groups from anyio task groups at outer level + actual_error = None + for exc in eg.exceptions: + if isinstance(exc, Exception) and not isinstance(exc, GeneratorExit): + error_msg = str(exc).lower() + if "401" in error_msg or "unauthorized" in error_msg: + actual_error = exc + break + if "cancel scope" not in error_msg and "task" not in error_msg: + actual_error = exc + break + + await self._cleanup_on_error() + error_type = ( + "authentication" + if actual_error + and ( + "401" in str(actual_error).lower() + or "unauthorized" in str(actual_error).lower() + ) + else "network" + ) + error_msg = str(actual_error) if actual_error else str(eg) + self._emit_connection_failed( + server_name, + server_url, + transport_type, + error_msg, + error_type, + started_at, + ) + if actual_error: + raise ConnectionError( + f"Failed to connect to MCP server: {actual_error}" + ) from actual_error + raise ConnectionError(f"Failed to connect to MCP server: {eg}") from eg + except Exception as e: + await self._cleanup_on_error() + error_type = ( + "authentication" + if "401" in str(e).lower() or "unauthorized" in str(e).lower() + else "network" + ) + self._emit_connection_failed( + server_name, server_url, transport_type, str(e), error_type, started_at + ) + raise ConnectionError(f"Failed to connect to MCP server: {e}") from e + + def _emit_connection_failed( + self, + server_name: str, + server_url: str | None, + transport_type: str | None, + error: str, + error_type: str, + started_at: datetime, + ) -> None: + """Emit connection failed event.""" + failed_at = datetime.now() + crewai_event_bus.emit( + self, + MCPConnectionFailedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + error=error, + error_type=error_type, + started_at=started_at, + failed_at=failed_at, + ), + ) + + async def _cleanup_on_error(self) -> None: + """Cleanup resources when an error occurs during connection.""" + try: + await self._exit_stack.aclose() + + except Exception as e: + # Best effort cleanup - ignore all other errors + raise RuntimeError(f"Error during MCP client cleanup: {e}") from e + finally: + self._session = None + self._initialized = False + self._exit_stack = AsyncExitStack() + + async def disconnect(self) -> None: + """Disconnect from MCP server and cleanup resources.""" + if not self.connected: + return + + try: + await self._exit_stack.aclose() + except Exception as e: + raise RuntimeError(f"Error during MCP client disconnect: {e}") from e + finally: + self._session = None + self._initialized = False + self._exit_stack = AsyncExitStack() + + async def list_tools(self, use_cache: bool | None = None) -> list[dict[str, Any]]: + """List available tools from MCP server. + + Args: + use_cache: Whether to use cached results. If None, uses + client's cache_tools_list setting. + + Returns: + List of tool definitions with name, description, and inputSchema. + """ + if not self.connected: + await self.connect() + + # Check cache if enabled + use_cache = use_cache if use_cache is not None else self.cache_tools_list + if use_cache: + cache_key = self._get_cache_key("tools") + if cache_key in _mcp_schema_cache: + cached_data, cache_time = _mcp_schema_cache[cache_key] + if time.time() - cache_time < _cache_ttl: + # Logger removed - return cached data + return cached_data + + # List tools with timeout and retries + tools = await self._retry_operation( + self._list_tools_impl, + timeout=self.discovery_timeout, + ) + + # Cache results if enabled + if use_cache: + cache_key = self._get_cache_key("tools") + _mcp_schema_cache[cache_key] = (tools, time.time()) + + return tools + + async def _list_tools_impl(self) -> list[dict[str, Any]]: + """Internal implementation of list_tools.""" + tools_result = await asyncio.wait_for( + self.session.list_tools(), + timeout=self.discovery_timeout, + ) + + return [ + { + "name": tool.name, + "description": getattr(tool, "description", ""), + "inputSchema": getattr(tool, "inputSchema", {}), + } + for tool in tools_result.tools + ] + + async def call_tool( + self, tool_name: str, arguments: dict[str, Any] | None = None + ) -> Any: + """Call a tool on the MCP server. + + Args: + tool_name: Name of the tool to call. + arguments: Tool arguments. + + Returns: + Tool execution result. + """ + if not self.connected: + await self.connect() + + arguments = arguments or {} + cleaned_arguments = self._clean_tool_arguments(arguments) + + # Get server info for events + server_name, server_url, transport_type = self._get_server_info() + + # Emit tool execution started event + started_at = datetime.now() + crewai_event_bus.emit( + self, + MCPToolExecutionStartedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + tool_name=tool_name, + tool_args=cleaned_arguments, + ), + ) + + try: + result = await self._retry_operation( + lambda: self._call_tool_impl(tool_name, cleaned_arguments), + timeout=self.execution_timeout, + ) + + completed_at = datetime.now() + execution_duration_ms = (completed_at - started_at).total_seconds() * 1000 + crewai_event_bus.emit( + self, + MCPToolExecutionCompletedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + tool_name=tool_name, + tool_args=cleaned_arguments, + result=result, + started_at=started_at, + completed_at=completed_at, + execution_duration_ms=execution_duration_ms, + ), + ) + + return result + except Exception as e: + failed_at = datetime.now() + error_type = ( + "timeout" + if isinstance(e, (asyncio.TimeoutError, ConnectionError)) + and "timeout" in str(e).lower() + else "server_error" + ) + crewai_event_bus.emit( + self, + MCPToolExecutionFailedEvent( + server_name=server_name, + server_url=server_url, + transport_type=transport_type, + tool_name=tool_name, + tool_args=cleaned_arguments, + error=str(e), + error_type=error_type, + started_at=started_at, + failed_at=failed_at, + ), + ) + raise + + def _clean_tool_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]: + """Clean tool arguments by removing None values and fixing formats. + + Args: + arguments: Raw tool arguments. + + Returns: + Cleaned arguments ready for MCP server. + """ + cleaned = {} + + for key, value in arguments.items(): + # Skip None values + if value is None: + continue + + # Fix sources array format: convert ["web"] to [{"type": "web"}] + if key == "sources" and isinstance(value, list): + fixed_sources = [] + for item in value: + if isinstance(item, str): + # Convert string to object format + fixed_sources.append({"type": item}) + elif isinstance(item, dict): + # Already in correct format + fixed_sources.append(item) + else: + # Keep as is if unknown format + fixed_sources.append(item) + if fixed_sources: + cleaned[key] = fixed_sources + continue + + # Recursively clean nested dictionaries + if isinstance(value, dict): + nested_cleaned = self._clean_tool_arguments(value) + if nested_cleaned: # Only add if not empty + cleaned[key] = nested_cleaned + elif isinstance(value, list): + # Clean list items + cleaned_list = [] + for item in value: + if isinstance(item, dict): + cleaned_item = self._clean_tool_arguments(item) + if cleaned_item: + cleaned_list.append(cleaned_item) + elif item is not None: + cleaned_list.append(item) + if cleaned_list: + cleaned[key] = cleaned_list + else: + # Keep primitive values + cleaned[key] = value + + return cleaned + + async def _call_tool_impl(self, tool_name: str, arguments: dict[str, Any]) -> Any: + """Internal implementation of call_tool.""" + result = await asyncio.wait_for( + self.session.call_tool(tool_name, arguments), + timeout=self.execution_timeout, + ) + + # Extract result content + if hasattr(result, "content") and result.content: + if isinstance(result.content, list) and len(result.content) > 0: + content_item = result.content[0] + if hasattr(content_item, "text"): + return str(content_item.text) + return str(content_item) + return str(result.content) + + return str(result) + + async def list_prompts(self) -> list[dict[str, Any]]: + """List available prompts from MCP server. + + Returns: + List of prompt definitions. + """ + if not self.connected: + await self.connect() + + return await self._retry_operation( + self._list_prompts_impl, + timeout=self.discovery_timeout, + ) + + async def _list_prompts_impl(self) -> list[dict[str, Any]]: + """Internal implementation of list_prompts.""" + prompts_result = await asyncio.wait_for( + self.session.list_prompts(), + timeout=self.discovery_timeout, + ) + + return [ + { + "name": prompt.name, + "description": getattr(prompt, "description", ""), + "arguments": getattr(prompt, "arguments", []), + } + for prompt in prompts_result.prompts + ] + + async def get_prompt( + self, prompt_name: str, arguments: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Get a prompt from the MCP server. + + Args: + prompt_name: Name of the prompt to get. + arguments: Optional prompt arguments. + + Returns: + Prompt content and metadata. + """ + if not self.connected: + await self.connect() + + arguments = arguments or {} + + return await self._retry_operation( + lambda: self._get_prompt_impl(prompt_name, arguments), + timeout=self.execution_timeout, + ) + + async def _get_prompt_impl( + self, prompt_name: str, arguments: dict[str, Any] + ) -> dict[str, Any]: + """Internal implementation of get_prompt.""" + result = await asyncio.wait_for( + self.session.get_prompt(prompt_name, arguments), + timeout=self.execution_timeout, + ) + + return { + "name": prompt_name, + "messages": [ + { + "role": msg.role, + "content": msg.content, + } + for msg in result.messages + ], + "arguments": arguments, + } + + async def _retry_operation( + self, + operation: Callable[[], Any], + timeout: int | None = None, + ) -> Any: + """Retry an operation with exponential backoff. + + Args: + operation: Async operation to retry. + timeout: Operation timeout in seconds. + + Returns: + Operation result. + """ + last_error = None + timeout = timeout or self.execution_timeout + + for attempt in range(self.max_retries): + try: + if timeout: + return await asyncio.wait_for(operation(), timeout=timeout) + return await operation() + + except asyncio.TimeoutError as e: # noqa: PERF203 + last_error = f"Operation timed out after {timeout} seconds" + if attempt < self.max_retries - 1: + wait_time = 2**attempt + await asyncio.sleep(wait_time) + else: + raise ConnectionError(last_error) from e + + except Exception as e: + error_str = str(e).lower() + + # Classify errors as retryable or non-retryable + if "authentication" in error_str or "unauthorized" in error_str: + raise ConnectionError(f"Authentication failed: {e}") from e + + if "not found" in error_str: + raise ValueError(f"Resource not found: {e}") from e + + # Retryable errors + last_error = str(e) + if attempt < self.max_retries - 1: + wait_time = 2**attempt + await asyncio.sleep(wait_time) + else: + raise ConnectionError( + f"Operation failed after {self.max_retries} attempts: {last_error}" + ) from e + + raise ConnectionError(f"Operation failed: {last_error}") + + def _get_cache_key(self, resource_type: str) -> str: + """Generate cache key for resource. + + Args: + resource_type: Type of resource (e.g., "tools", "prompts"). + + Returns: + Cache key string. + """ + # Use transport type and URL/command as cache key + if isinstance(self.transport, StdioTransport): + key = f"stdio:{self.transport.command}:{':'.join(self.transport.args)}" + elif isinstance(self.transport, HTTPTransport): + key = f"http:{self.transport.url}" + elif isinstance(self.transport, SSETransport): + key = f"sse:{self.transport.url}" + else: + key = f"{self.transport.transport_type}:unknown" + + return f"mcp:{key}:{resource_type}" + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return await self.connect() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + await self.disconnect() diff --git a/lib/crewai/src/crewai/mcp/config.py b/lib/crewai/src/crewai/mcp/config.py new file mode 100644 index 000000000..775f9403d --- /dev/null +++ b/lib/crewai/src/crewai/mcp/config.py @@ -0,0 +1,124 @@ +"""MCP server configuration models for CrewAI agents. + +This module provides Pydantic models for configuring MCP servers with +various transport types, similar to OpenAI's Agents SDK. +""" + +from pydantic import BaseModel, Field + +from crewai.mcp.filters import ToolFilter + + +class MCPServerStdio(BaseModel): + """Stdio MCP server configuration. + + This configuration is used for connecting to local MCP servers + that run as processes and communicate via standard input/output. + + Example: + ```python + mcp_server = MCPServerStdio( + command="python", + args=["path/to/server.py"], + env={"API_KEY": "..."}, + tool_filter=create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"] + ), + ) + ``` + """ + + command: str = Field( + ..., + description="Command to execute (e.g., 'python', 'node', 'npx', 'uvx').", + ) + args: list[str] = Field( + default_factory=list, + description="Command arguments (e.g., ['server.py'] or ['-y', '@mcp/server']).", + ) + env: dict[str, str] | None = Field( + default=None, + description="Environment variables to pass to the process.", + ) + tool_filter: ToolFilter | None = Field( + default=None, + description="Optional tool filter for filtering available tools.", + ) + cache_tools_list: bool = Field( + default=False, + description="Whether to cache the tool list for faster subsequent access.", + ) + + +class MCPServerHTTP(BaseModel): + """HTTP/Streamable HTTP MCP server configuration. + + This configuration is used for connecting to remote MCP servers + over HTTP/HTTPS using streamable HTTP transport. + + Example: + ```python + mcp_server = MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer ..."}, + cache_tools_list=True, + ) + ``` + """ + + url: str = Field( + ..., description="Server URL (e.g., 'https://api.example.com/mcp')." + ) + headers: dict[str, str] | None = Field( + default=None, + description="Optional HTTP headers for authentication or other purposes.", + ) + streamable: bool = Field( + default=True, + description="Whether to use streamable HTTP transport (default: True).", + ) + tool_filter: ToolFilter | None = Field( + default=None, + description="Optional tool filter for filtering available tools.", + ) + cache_tools_list: bool = Field( + default=False, + description="Whether to cache the tool list for faster subsequent access.", + ) + + +class MCPServerSSE(BaseModel): + """Server-Sent Events (SSE) MCP server configuration. + + This configuration is used for connecting to remote MCP servers + using Server-Sent Events for real-time streaming communication. + + Example: + ```python + mcp_server = MCPServerSSE( + url="https://api.example.com/mcp/sse", + headers={"Authorization": "Bearer ..."}, + ) + ``` + """ + + url: str = Field( + ..., + description="Server URL (e.g., 'https://api.example.com/mcp/sse').", + ) + headers: dict[str, str] | None = Field( + default=None, + description="Optional HTTP headers for authentication or other purposes.", + ) + tool_filter: ToolFilter | None = Field( + default=None, + description="Optional tool filter for filtering available tools.", + ) + cache_tools_list: bool = Field( + default=False, + description="Whether to cache the tool list for faster subsequent access.", + ) + + +# Type alias for all MCP server configurations +MCPServerConfig = MCPServerStdio | MCPServerHTTP | MCPServerSSE diff --git a/lib/crewai/src/crewai/mcp/filters.py b/lib/crewai/src/crewai/mcp/filters.py new file mode 100644 index 000000000..ee2f7a560 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/filters.py @@ -0,0 +1,166 @@ +"""Tool filtering support for MCP servers. + +This module provides utilities for filtering tools from MCP servers, +including static allow/block lists and dynamic context-aware filtering. +""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, Field + + +if TYPE_CHECKING: + pass + + +class ToolFilterContext(BaseModel): + """Context for dynamic tool filtering. + + This context is passed to dynamic tool filters to provide + information about the agent, run context, and server. + """ + + agent: Any = Field(..., description="The agent requesting tools.") + server_name: str = Field(..., description="Name of the MCP server.") + run_context: dict[str, Any] | None = Field( + default=None, + description="Optional run context for additional filtering logic.", + ) + + +# Type alias for tool filter functions +ToolFilter = ( + Callable[[ToolFilterContext, dict[str, Any]], bool] + | Callable[[dict[str, Any]], bool] +) + + +class StaticToolFilter: + """Static tool filter with allow/block lists. + + This filter provides simple allow/block list filtering based on + tool names. Useful for restricting which tools are available + from an MCP server. + + Example: + ```python + filter = StaticToolFilter( + allowed_tool_names=["read_file", "write_file"], + blocked_tool_names=["delete_file"], + ) + ``` + """ + + def __init__( + self, + allowed_tool_names: list[str] | None = None, + blocked_tool_names: list[str] | None = None, + ) -> None: + """Initialize static tool filter. + + Args: + allowed_tool_names: List of tool names to allow. If None, + all tools are allowed (unless blocked). + blocked_tool_names: List of tool names to block. Blocked tools + take precedence over allowed tools. + """ + self.allowed_tool_names = set(allowed_tool_names or []) + self.blocked_tool_names = set(blocked_tool_names or []) + + def __call__(self, tool: dict[str, Any]) -> bool: + """Filter tool based on allow/block lists. + + Args: + tool: Tool definition dictionary with at least 'name' key. + + Returns: + True if tool should be included, False otherwise. + """ + tool_name = tool.get("name", "") + + # Blocked tools take precedence + if self.blocked_tool_names and tool_name in self.blocked_tool_names: + return False + + # If allow list exists, tool must be in it + if self.allowed_tool_names: + return tool_name in self.allowed_tool_names + + # No restrictions - allow all + return True + + +def create_static_tool_filter( + allowed_tool_names: list[str] | None = None, + blocked_tool_names: list[str] | None = None, +) -> Callable[[dict[str, Any]], bool]: + """Create a static tool filter function. + + This is a convenience function for creating static tool filters + with allow/block lists. + + Args: + allowed_tool_names: List of tool names to allow. If None, + all tools are allowed (unless blocked). + blocked_tool_names: List of tool names to block. Blocked tools + take precedence over allowed tools. + + Returns: + Tool filter function that returns True for allowed tools. + + Example: + ```python + filter_fn = create_static_tool_filter( + allowed_tool_names=["read_file", "write_file"], + blocked_tool_names=["delete_file"], + ) + + # Use in MCPServerStdio + mcp_server = MCPServerStdio( + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + tool_filter=filter_fn, + ) + ``` + """ + return StaticToolFilter( + allowed_tool_names=allowed_tool_names, + blocked_tool_names=blocked_tool_names, + ) + + +def create_dynamic_tool_filter( + filter_func: Callable[[ToolFilterContext, dict[str, Any]], bool], +) -> Callable[[ToolFilterContext, dict[str, Any]], bool]: + """Create a dynamic tool filter function. + + This function wraps a dynamic filter function that has access + to the tool filter context (agent, server, run context). + + Args: + filter_func: Function that takes (context, tool) and returns bool. + + Returns: + Tool filter function that can be used with MCP server configs. + + Example: + ```python + async def context_aware_filter( + context: ToolFilterContext, tool: dict[str, Any] + ) -> bool: + # Block dangerous tools for code reviewers + if context.agent.role == "Code Reviewer": + if tool["name"].startswith("danger_"): + return False + return True + + + filter_fn = create_dynamic_tool_filter(context_aware_filter) + + mcp_server = MCPServerStdio( + command="python", args=["server.py"], tool_filter=filter_fn + ) + ``` + """ + return filter_func diff --git a/lib/crewai/src/crewai/mcp/transports/__init__.py b/lib/crewai/src/crewai/mcp/transports/__init__.py new file mode 100644 index 000000000..4e579f50e --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/__init__.py @@ -0,0 +1,15 @@ +"""MCP transport implementations for various connection types.""" + +from crewai.mcp.transports.base import BaseTransport, TransportType +from crewai.mcp.transports.http import HTTPTransport +from crewai.mcp.transports.sse import SSETransport +from crewai.mcp.transports.stdio import StdioTransport + + +__all__ = [ + "BaseTransport", + "HTTPTransport", + "SSETransport", + "StdioTransport", + "TransportType", +] diff --git a/lib/crewai/src/crewai/mcp/transports/base.py b/lib/crewai/src/crewai/mcp/transports/base.py new file mode 100644 index 000000000..d6e5f958d --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/base.py @@ -0,0 +1,125 @@ +"""Base transport interface for MCP connections.""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Protocol + +from typing_extensions import Self + + +class TransportType(str, Enum): + """MCP transport types.""" + + STDIO = "stdio" + HTTP = "http" + STREAMABLE_HTTP = "streamable-http" + SSE = "sse" + + +class ReadStream(Protocol): + """Protocol for read streams.""" + + async def read(self, n: int = -1) -> bytes: + """Read bytes from stream.""" + ... + + +class WriteStream(Protocol): + """Protocol for write streams.""" + + async def write(self, data: bytes) -> None: + """Write bytes to stream.""" + ... + + +class BaseTransport(ABC): + """Base class for MCP transport implementations. + + This abstract base class defines the interface that all transport + implementations must follow. Transports handle the low-level communication + with MCP servers. + """ + + def __init__(self, **kwargs: Any) -> None: + """Initialize the transport. + + Args: + **kwargs: Transport-specific configuration options. + """ + self._read_stream: ReadStream | None = None + self._write_stream: WriteStream | None = None + self._connected = False + + @property + @abstractmethod + def transport_type(self) -> TransportType: + """Return the transport type.""" + ... + + @property + def connected(self) -> bool: + """Check if transport is connected.""" + return self._connected + + @property + def read_stream(self) -> ReadStream: + """Get the read stream.""" + if self._read_stream is None: + raise RuntimeError("Transport not connected. Call connect() first.") + return self._read_stream + + @property + def write_stream(self) -> WriteStream: + """Get the write stream.""" + if self._write_stream is None: + raise RuntimeError("Transport not connected. Call connect() first.") + return self._write_stream + + @abstractmethod + async def connect(self) -> Self: + """Establish connection to MCP server. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If connection fails. + """ + ... + + @abstractmethod + async def disconnect(self) -> None: + """Close connection to MCP server.""" + ... + + @abstractmethod + async def __aenter__(self) -> Self: + """Async context manager entry.""" + ... + + @abstractmethod + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + ... + + def _set_streams(self, read: ReadStream, write: WriteStream) -> None: + """Set the read and write streams. + + Args: + read: Read stream. + write: Write stream. + """ + self._read_stream = read + self._write_stream = write + self._connected = True + + def _clear_streams(self) -> None: + """Clear the read and write streams.""" + self._read_stream = None + self._write_stream = None + self._connected = False diff --git a/lib/crewai/src/crewai/mcp/transports/http.py b/lib/crewai/src/crewai/mcp/transports/http.py new file mode 100644 index 000000000..d531d8906 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/http.py @@ -0,0 +1,174 @@ +"""HTTP and Streamable HTTP transport for MCP servers.""" + +import asyncio +from typing import Any + +from typing_extensions import Self + + +# BaseExceptionGroup is available in Python 3.11+ +try: + from builtins import BaseExceptionGroup +except ImportError: + # Fallback for Python < 3.11 (shouldn't happen in practice) + BaseExceptionGroup = Exception + +from crewai.mcp.transports.base import BaseTransport, TransportType + + +class HTTPTransport(BaseTransport): + """HTTP/Streamable HTTP transport for connecting to remote MCP servers. + + This transport connects to MCP servers over HTTP/HTTPS using the + streamable HTTP client from the MCP SDK. + + Example: + ```python + transport = HTTPTransport( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer ..."} + ) + async with transport: + # Use transport... + ``` + """ + + def __init__( + self, + url: str, + headers: dict[str, str] | None = None, + streamable: bool = True, + **kwargs: Any, + ) -> None: + """Initialize HTTP transport. + + Args: + url: Server URL (e.g., "https://api.example.com/mcp"). + headers: Optional HTTP headers. + streamable: Whether to use streamable HTTP (default: True). + **kwargs: Additional transport options. + """ + super().__init__(**kwargs) + self.url = url + self.headers = headers or {} + self.streamable = streamable + self._transport_context: Any = None + + @property + def transport_type(self) -> TransportType: + """Return the transport type.""" + return TransportType.STREAMABLE_HTTP if self.streamable else TransportType.HTTP + + async def connect(self) -> Self: + """Establish HTTP connection to MCP server. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If connection fails. + ImportError: If MCP SDK not available. + """ + if self._connected: + return self + + try: + from mcp.client.streamable_http import streamablehttp_client + + self._transport_context = streamablehttp_client( + self.url, + headers=self.headers if self.headers else None, + terminate_on_close=True, + ) + + try: + read, write, _ = await asyncio.wait_for( + self._transport_context.__aenter__(), timeout=30.0 + ) + except asyncio.TimeoutError as e: + self._transport_context = None + raise ConnectionError( + "Transport context entry timed out after 30 seconds. " + "Server may be slow or unreachable." + ) from e + except Exception as e: + self._transport_context = None + raise ConnectionError(f"Failed to enter transport context: {e}") from e + self._set_streams(read=read, write=write) + return self + + except ImportError as e: + raise ImportError( + "MCP library not available. Please install with: pip install mcp" + ) from e + except Exception as e: + self._clear_streams() + if self._transport_context is not None: + self._transport_context = None + raise ConnectionError(f"Failed to connect to MCP server: {e}") from e + + async def disconnect(self) -> None: + """Close HTTP connection.""" + if not self._connected: + return + + try: + # Clear streams first + self._clear_streams() + # await self._exit_stack.aclose() + + # Exit transport context - this will clean up background tasks + # Give a small delay to allow background tasks to complete + if self._transport_context is not None: + try: + # Wait a tiny bit for any pending operations + await asyncio.sleep(0.1) + await self._transport_context.__aexit__(None, None, None) + except (RuntimeError, asyncio.CancelledError) as e: + # Ignore "exit cancel scope in different task" errors and cancellation + # These happen when asyncio.run() closes the event loop + # while background tasks are still running + error_msg = str(e).lower() + if "cancel scope" not in error_msg and "task" not in error_msg: + # Only suppress cancel scope/task errors, re-raise others + if isinstance(e, RuntimeError): + raise + # For CancelledError, just suppress it + except BaseExceptionGroup as eg: + # Handle exception groups from anyio task groups + # Suppress if they contain cancel scope errors + should_suppress = False + for exc in eg.exceptions: + error_msg = str(exc).lower() + if "cancel scope" in error_msg or "task" in error_msg: + should_suppress = True + break + if not should_suppress: + raise + except Exception as e: + raise RuntimeError( + f"Error during HTTP transport disconnect: {e}" + ) from e + + self._connected = False + + except Exception as e: + # Log but don't raise - cleanup should be best effort + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Error during HTTP transport disconnect: {e}") + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return await self.connect() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + + await self.disconnect() diff --git a/lib/crewai/src/crewai/mcp/transports/sse.py b/lib/crewai/src/crewai/mcp/transports/sse.py new file mode 100644 index 000000000..ce418c51f --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/sse.py @@ -0,0 +1,113 @@ +"""Server-Sent Events (SSE) transport for MCP servers.""" + +from typing import Any + +from typing_extensions import Self + +from crewai.mcp.transports.base import BaseTransport, TransportType + + +class SSETransport(BaseTransport): + """SSE transport for connecting to remote MCP servers. + + This transport connects to MCP servers using Server-Sent Events (SSE) + for real-time streaming communication. + + Example: + ```python + transport = SSETransport( + url="https://api.example.com/mcp/sse", + headers={"Authorization": "Bearer ..."} + ) + async with transport: + # Use transport... + ``` + """ + + def __init__( + self, + url: str, + headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + """Initialize SSE transport. + + Args: + url: Server URL (e.g., "https://api.example.com/mcp/sse"). + headers: Optional HTTP headers. + **kwargs: Additional transport options. + """ + super().__init__(**kwargs) + self.url = url + self.headers = headers or {} + self._transport_context: Any = None + + @property + def transport_type(self) -> TransportType: + """Return the transport type.""" + return TransportType.SSE + + async def connect(self) -> Self: + """Establish SSE connection to MCP server. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If connection fails. + ImportError: If MCP SDK not available. + """ + if self._connected: + return self + + try: + from mcp.client.sse import sse_client + + self._transport_context = sse_client( + self.url, + headers=self.headers if self.headers else None, + terminate_on_close=True, + ) + + read, write = await self._transport_context.__aenter__() + + self._set_streams(read=read, write=write) + + return self + + except ImportError as e: + raise ImportError( + "MCP library not available. Please install with: pip install mcp" + ) from e + except Exception as e: + self._clear_streams() + raise ConnectionError(f"Failed to connect to SSE MCP server: {e}") from e + + async def disconnect(self) -> None: + """Close SSE connection.""" + if not self._connected: + return + + try: + self._clear_streams() + if self._transport_context is not None: + await self._transport_context.__aexit__(None, None, None) + + except Exception as e: + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Error during SSE transport disconnect: {e}") + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return await self.connect() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + await self.disconnect() diff --git a/lib/crewai/src/crewai/mcp/transports/stdio.py b/lib/crewai/src/crewai/mcp/transports/stdio.py new file mode 100644 index 000000000..65f288505 --- /dev/null +++ b/lib/crewai/src/crewai/mcp/transports/stdio.py @@ -0,0 +1,153 @@ +"""Stdio transport for MCP servers running as local processes.""" + +import asyncio +import os +import subprocess +from typing import Any + +from typing_extensions import Self + +from crewai.mcp.transports.base import BaseTransport, TransportType + + +class StdioTransport(BaseTransport): + """Stdio transport for connecting to local MCP servers. + + This transport connects to MCP servers running as local processes, + communicating via standard input/output streams. Supports Python, + Node.js, and other command-line servers. + + Example: + ```python + transport = StdioTransport( + command="python", + args=["path/to/server.py"], + env={"API_KEY": "..."} + ) + async with transport: + # Use transport... + ``` + """ + + def __init__( + self, + command: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + """Initialize stdio transport. + + Args: + command: Command to execute (e.g., "python", "node", "npx"). + args: Command arguments (e.g., ["server.py"] or ["-y", "@mcp/server"]). + env: Environment variables to pass to the process. + **kwargs: Additional transport options. + """ + super().__init__(**kwargs) + self.command = command + self.args = args or [] + self.env = env or {} + self._process: subprocess.Popen[bytes] | None = None + self._transport_context: Any = None + + @property + def transport_type(self) -> TransportType: + """Return the transport type.""" + return TransportType.STDIO + + async def connect(self) -> Self: + """Start the MCP server process and establish connection. + + Returns: + Self for method chaining. + + Raises: + ConnectionError: If process fails to start. + ImportError: If MCP SDK not available. + """ + if self._connected: + return self + + try: + from mcp import StdioServerParameters + from mcp.client.stdio import stdio_client + + process_env = os.environ.copy() + process_env.update(self.env) + + server_params = StdioServerParameters( + command=self.command, + args=self.args, + env=process_env if process_env else None, + ) + self._transport_context = stdio_client(server_params) + + try: + read, write = await self._transport_context.__aenter__() + except Exception as e: + import traceback + + traceback.print_exc() + self._transport_context = None + raise ConnectionError( + f"Failed to enter stdio transport context: {e}" + ) from e + + self._set_streams(read=read, write=write) + + return self + + except ImportError as e: + raise ImportError( + "MCP library not available. Please install with: pip install mcp" + ) from e + except Exception as e: + self._clear_streams() + if self._transport_context is not None: + self._transport_context = None + raise ConnectionError(f"Failed to start MCP server process: {e}") from e + + async def disconnect(self) -> None: + """Terminate the MCP server process and close connection.""" + if not self._connected: + return + + try: + self._clear_streams() + + if self._transport_context is not None: + await self._transport_context.__aexit__(None, None, None) + + if self._process is not None: + try: + self._process.terminate() + try: + await asyncio.wait_for(self._process.wait(), timeout=5.0) + except asyncio.TimeoutError: + self._process.kill() + await self._process.wait() + # except ProcessLookupError: + # pass + finally: + self._process = None + + except Exception as e: + # Log but don't raise - cleanup should be best effort + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Error during stdio transport disconnect: {e}") + + async def __aenter__(self) -> Self: + """Async context manager entry.""" + return await self.connect() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + await self.disconnect() diff --git a/lib/crewai/src/crewai/tools/mcp_native_tool.py b/lib/crewai/src/crewai/tools/mcp_native_tool.py new file mode 100644 index 000000000..c10d51eee --- /dev/null +++ b/lib/crewai/src/crewai/tools/mcp_native_tool.py @@ -0,0 +1,154 @@ +"""Native MCP tool wrapper for CrewAI agents. + +This module provides a tool wrapper that reuses existing MCP client sessions +for better performance and connection management. +""" + +import asyncio +from typing import Any + +from crewai.tools import BaseTool + + +class MCPNativeTool(BaseTool): + """Native MCP tool that reuses client sessions. + + This tool wrapper is used when agents connect to MCP servers using + structured configurations. It reuses existing client sessions for + better performance and proper connection lifecycle management. + + Unlike MCPToolWrapper which connects on-demand, this tool uses + a shared MCP client instance that maintains a persistent connection. + """ + + def __init__( + self, + mcp_client: Any, + tool_name: str, + tool_schema: dict[str, Any], + server_name: str, + ) -> None: + """Initialize native MCP tool. + + Args: + mcp_client: MCPClient instance with active session. + tool_name: Original name of the tool on the MCP server. + tool_schema: Schema information for the tool. + server_name: Name of the MCP server for prefixing. + """ + # Create tool name with server prefix to avoid conflicts + prefixed_name = f"{server_name}_{tool_name}" + + # Handle args_schema properly - BaseTool expects a BaseModel subclass + args_schema = tool_schema.get("args_schema") + + # Only pass args_schema if it's provided + kwargs = { + "name": prefixed_name, + "description": tool_schema.get( + "description", f"Tool {tool_name} from {server_name}" + ), + } + + if args_schema is not None: + kwargs["args_schema"] = args_schema + + super().__init__(**kwargs) + + # Set instance attributes after super().__init__ + self._mcp_client = mcp_client + self._original_tool_name = tool_name + self._server_name = server_name + # self._logger = logging.getLogger(__name__) + + @property + def mcp_client(self) -> Any: + """Get the MCP client instance.""" + return self._mcp_client + + @property + def original_tool_name(self) -> str: + """Get the original tool name.""" + return self._original_tool_name + + @property + def server_name(self) -> str: + """Get the server name.""" + return self._server_name + + def _run(self, **kwargs) -> str: + """Execute tool using the MCP client session. + + Args: + **kwargs: Arguments to pass to the MCP tool. + + Returns: + Result from the MCP tool execution. + """ + try: + # Always use asyncio.run() to create a fresh event loop + # This ensures the async context managers work correctly + return asyncio.run(self._run_async(**kwargs)) + + except Exception as e: + raise RuntimeError( + f"Error executing MCP tool {self.original_tool_name}: {e!s}" + ) from e + + async def _run_async(self, **kwargs) -> str: + """Async implementation of tool execution. + + Args: + **kwargs: Arguments to pass to the MCP tool. + + Returns: + Result from the MCP tool execution. + """ + # Note: Since we use asyncio.run() which creates a new event loop each time, + # Always reconnect on-demand because asyncio.run() creates new event loops per call + # All MCP transport context managers (stdio, streamablehttp_client, sse_client) + # use anyio.create_task_group() which can't span different event loops + if self._mcp_client.connected: + await self._mcp_client.disconnect() + + await self._mcp_client.connect() + + try: + result = await self._mcp_client.call_tool(self.original_tool_name, kwargs) + + except Exception as e: + error_str = str(e).lower() + if ( + "not connected" in error_str + or "connection" in error_str + or "send" in error_str + ): + await self._mcp_client.disconnect() + await self._mcp_client.connect() + # Retry the call + result = await self._mcp_client.call_tool( + self.original_tool_name, kwargs + ) + else: + raise + + finally: + # Always disconnect after tool call to ensure clean context manager lifecycle + # This prevents "exit cancel scope in different task" errors + # All transport context managers must be exited in the same event loop they were entered + await self._mcp_client.disconnect() + + # Extract result content + if isinstance(result, str): + return result + + # Handle various result formats + if hasattr(result, "content") and result.content: + if isinstance(result.content, list) and len(result.content) > 0: + content_item = result.content[0] + if hasattr(content_item, "text"): + return str(content_item.text) + return str(content_item) + return str(result.content) + + return str(result) diff --git a/lib/crewai/tests/mcp/__init__.py b/lib/crewai/tests/mcp/__init__.py new file mode 100644 index 000000000..740ce3f01 --- /dev/null +++ b/lib/crewai/tests/mcp/__init__.py @@ -0,0 +1,4 @@ +"""Tests for MCP (Model Context Protocol) integration.""" + + + diff --git a/lib/crewai/tests/mcp/test_mcp_config.py b/lib/crewai/tests/mcp/test_mcp_config.py new file mode 100644 index 000000000..627ceb6e2 --- /dev/null +++ b/lib/crewai/tests/mcp/test_mcp_config.py @@ -0,0 +1,136 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from crewai.agent.core import Agent +from crewai.mcp.config import MCPServerHTTP, MCPServerSSE, MCPServerStdio +from crewai.tools.base_tool import BaseTool + + +@pytest.fixture +def mock_tool_definitions(): + """Create mock MCP tool definitions (as returned by list_tools).""" + return [ + { + "name": "test_tool_1", + "description": "Test tool 1 description", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"] + } + }, + { + "name": "test_tool_2", + "description": "Test tool 2 description", + "inputSchema": {} + } + ] + + +def test_agent_with_stdio_mcp_config(mock_tool_definitions): + """Test agent setup with MCPServerStdio configuration.""" + stdio_config = MCPServerStdio( + command="python", + args=["server.py"], + env={"API_KEY": "test_key"}, + ) + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[stdio_config], + ) + + + with patch("crewai.agent.core.MCPClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False # Will trigger connect + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools([stdio_config]) + + assert len(tools) == 2 + assert all(isinstance(tool, BaseTool) for tool in tools) + + mock_client_class.assert_called_once() + call_args = mock_client_class.call_args + transport = call_args.kwargs["transport"] + assert transport.command == "python" + assert transport.args == ["server.py"] + assert transport.env == {"API_KEY": "test_key"} + + +def test_agent_with_http_mcp_config(mock_tool_definitions): + """Test agent setup with MCPServerHTTP configuration.""" + http_config = MCPServerHTTP( + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer test_token"}, + streamable=True, + ) + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + with patch("crewai.agent.core.MCPClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False # Will trigger connect + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools([http_config]) + + assert len(tools) == 2 + assert all(isinstance(tool, BaseTool) for tool in tools) + + mock_client_class.assert_called_once() + call_args = mock_client_class.call_args + transport = call_args.kwargs["transport"] + assert transport.url == "https://api.example.com/mcp" + assert transport.headers == {"Authorization": "Bearer test_token"} + assert transport.streamable is True + + +def test_agent_with_sse_mcp_config(mock_tool_definitions): + """Test agent setup with MCPServerSSE configuration.""" + sse_config = MCPServerSSE( + url="https://api.example.com/mcp/sse", + headers={"Authorization": "Bearer test_token"}, + ) + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[sse_config], + ) + + with patch("crewai.agent.core.MCPClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client_class.return_value = mock_client + + tools = agent.get_mcp_tools([sse_config]) + + assert len(tools) == 2 + assert all(isinstance(tool, BaseTool) for tool in tools) + + mock_client_class.assert_called_once() + call_args = mock_client_class.call_args + transport = call_args.kwargs["transport"] + assert transport.url == "https://api.example.com/mcp/sse" + assert transport.headers == {"Authorization": "Bearer test_token"} From 40a2d387a1ae6f34ecc6c57e8090d09d3405cd90 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 6 Nov 2025 21:10:25 -0500 Subject: [PATCH 2/4] fix: keep stopwords updated --- lib/crewai/src/crewai/llms/hooks/transport.py | 11 +- .../llms/providers/anthropic/completion.py | 24 +++ .../llms/providers/bedrock/completion.py | 24 +++ .../llms/providers/gemini/completion.py | 24 +++ ..._anthropic_stop_sequences_sent_to_api.yaml | 202 ++++++++++++++++++ .../tests/llms/anthropic/test_anthropic.py | 34 +++ lib/crewai/tests/llms/bedrock/test_bedrock.py | 53 +++++ lib/crewai/tests/llms/google/test_google.py | 52 +++++ 8 files changed, 418 insertions(+), 6 deletions(-) create mode 100644 lib/crewai/tests/cassettes/test_anthropic_stop_sequences_sent_to_api.yaml diff --git a/lib/crewai/src/crewai/llms/hooks/transport.py b/lib/crewai/src/crewai/llms/hooks/transport.py index ee3f9224c..27a0972ab 100644 --- a/lib/crewai/src/crewai/llms/hooks/transport.py +++ b/lib/crewai/src/crewai/llms/hooks/transport.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from crewai.llms.hooks.base import BaseInterceptor -class HTTPTransportKwargs(TypedDict): +class HTTPTransportKwargs(TypedDict, total=False): """Typed dictionary for httpx.HTTPTransport initialization parameters. These parameters configure the underlying HTTP transport behavior including @@ -33,14 +33,14 @@ class HTTPTransportKwargs(TypedDict): """ verify: bool | str | SSLContext - cert: NotRequired[CertTypes | None] + cert: NotRequired[CertTypes] trust_env: bool http1: bool http2: bool limits: Limits - proxy: NotRequired[ProxyTypes | None] - uds: NotRequired[str | None] - local_address: NotRequired[str | None] + proxy: NotRequired[ProxyTypes] + uds: NotRequired[str] + local_address: NotRequired[str] retries: int socket_options: NotRequired[ Iterable[ @@ -48,7 +48,6 @@ class HTTPTransportKwargs(TypedDict): | tuple[int, int, bytes | bytearray] | tuple[int, int, None, int] ] - | None ] diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index 50298eb77..ea161fc63 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -94,6 +94,30 @@ class AnthropicCompletion(BaseLLM): self.is_claude_3 = "claude-3" in model.lower() self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use + @property + def stop(self) -> list[str]: + """Get stop sequences sent to the API.""" + return self.stop_sequences + + @stop.setter + def stop(self, value: list[str] | str | None) -> None: + """Set stop sequences. + + Synchronizes stop_sequences to ensure values set by CrewAgentExecutor + are properly sent to the Anthropic API. + + Args: + value: Stop sequences as a list, single string, or None + """ + if value is None: + self.stop_sequences = [] + elif isinstance(value, str): + self.stop_sequences = [value] + elif isinstance(value, list): + self.stop_sequences = value + else: + self.stop_sequences = [] + def _get_client_params(self) -> dict[str, Any]: """Get client parameters.""" diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py index ff0808937..20eabf763 100644 --- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py @@ -243,6 +243,30 @@ class BedrockCompletion(BaseLLM): # Handle inference profiles for newer models self.model_id = model + @property + def stop(self) -> list[str]: + """Get stop sequences sent to the API.""" + return list(self.stop_sequences) + + @stop.setter + def stop(self, value: Sequence[str] | str | None) -> None: + """Set stop sequences. + + Synchronizes stop_sequences to ensure values set by CrewAgentExecutor + are properly sent to the Bedrock API. + + Args: + value: Stop sequences as a Sequence, single string, or None + """ + if value is None: + self.stop_sequences = [] + elif isinstance(value, str): + self.stop_sequences = [value] + elif isinstance(value, Sequence): + self.stop_sequences = list(value) + else: + self.stop_sequences = [] + def call( self, messages: str | list[LLMMessage], diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index 45b603c19..8668a8f58 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -104,6 +104,30 @@ class GeminiCompletion(BaseLLM): self.is_gemini_1_5 = "gemini-1.5" in model.lower() self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2 + @property + def stop(self) -> list[str]: + """Get stop sequences sent to the API.""" + return self.stop_sequences + + @stop.setter + def stop(self, value: list[str] | str | None) -> None: + """Set stop sequences. + + Synchronizes stop_sequences to ensure values set by CrewAgentExecutor + are properly sent to the Gemini API. + + Args: + value: Stop sequences as a list, single string, or None + """ + if value is None: + self.stop_sequences = [] + elif isinstance(value, str): + self.stop_sequences = [value] + elif isinstance(value, list): + self.stop_sequences = value + else: + self.stop_sequences = [] + def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported] """Initialize the Google Gen AI client with proper parameter handling. diff --git a/lib/crewai/tests/cassettes/test_anthropic_stop_sequences_sent_to_api.yaml b/lib/crewai/tests/cassettes/test_anthropic_stop_sequences_sent_to_api.yaml new file mode 100644 index 000000000..8759062f9 --- /dev/null +++ b/lib/crewai/tests/cassettes/test_anthropic_stop_sequences_sent_to_api.yaml @@ -0,0 +1,202 @@ +interactions: +- request: + body: '{"trace_id": "1703c4e0-d3be-411c-85e7-48018c2df384", "execution_type": + "crew", "user_identifier": null, "execution_context": {"crew_fingerprint": null, + "crew_name": "Unknown Crew", "flow_name": null, "crewai_version": "1.3.0", "privacy_level": + "standard"}, "execution_metadata": {"expected_duration_estimate": 300, "agent_count": + 0, "task_count": 0, "flow_method_count": 0, "execution_started_at": "2025-11-07T01:58:22.260309+00:00"}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '434' + Content-Type: + - application/json + User-Agent: + - CrewAI-CLI/1.3.0 + X-Crewai-Version: + - 1.3.0 + method: POST + uri: https://app.crewai.com/crewai_plus/api/v1/tracing/batches + response: + body: + string: '{"error":"bad_credentials","message":"Bad credentials"}' + headers: + Connection: + - keep-alive + Content-Length: + - '55' + Content-Type: + - application/json; charset=utf-8 + Date: + - Fri, 07 Nov 2025 01:58:22 GMT + cache-control: + - no-store + content-security-policy: + - 'default-src ''self'' *.app.crewai.com app.crewai.com; script-src ''self'' + ''unsafe-inline'' *.app.crewai.com app.crewai.com https://cdn.jsdelivr.net/npm/apexcharts + https://www.gstatic.com https://run.pstmn.io https://apis.google.com https://apis.google.com/js/api.js + https://accounts.google.com https://accounts.google.com/gsi/client https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.min.css.map + https://*.google.com https://docs.google.com https://slides.google.com https://js.hs-scripts.com + https://js.sentry-cdn.com https://browser.sentry-cdn.com https://www.googletagmanager.com + https://js-na1.hs-scripts.com https://js.hubspot.com http://js-na1.hs-scripts.com + https://bat.bing.com https://cdn.amplitude.com https://cdn.segment.com https://d1d3n03t5zntha.cloudfront.net/ + https://descriptusercontent.com https://edge.fullstory.com https://googleads.g.doubleclick.net + https://js.hs-analytics.net https://js.hs-banner.com https://js.hsadspixel.net + https://js.hscollectedforms.net https://js.usemessages.com https://snap.licdn.com + https://static.cloudflareinsights.com https://static.reo.dev https://www.google-analytics.com + https://share.descript.com/; style-src ''self'' ''unsafe-inline'' *.app.crewai.com + app.crewai.com https://cdn.jsdelivr.net/npm/apexcharts; img-src ''self'' data: + *.app.crewai.com app.crewai.com https://zeus.tools.crewai.com https://dashboard.tools.crewai.com + https://cdn.jsdelivr.net https://forms.hsforms.com https://track.hubspot.com + https://px.ads.linkedin.com https://px4.ads.linkedin.com https://www.google.com + https://www.google.com.br; font-src ''self'' data: *.app.crewai.com app.crewai.com; + connect-src ''self'' *.app.crewai.com app.crewai.com https://zeus.tools.crewai.com + https://connect.useparagon.com/ https://zeus.useparagon.com/* https://*.useparagon.com/* + https://run.pstmn.io https://connect.tools.crewai.com/ https://*.sentry.io + https://www.google-analytics.com https://edge.fullstory.com https://rs.fullstory.com + https://api.hubspot.com https://forms.hscollectedforms.net https://api.hubapi.com + https://px.ads.linkedin.com https://px4.ads.linkedin.com https://google.com/pagead/form-data/16713662509 + https://google.com/ccm/form-data/16713662509 https://www.google.com/ccm/collect + https://worker-actionkit.tools.crewai.com https://api.reo.dev; frame-src ''self'' + *.app.crewai.com app.crewai.com https://connect.useparagon.com/ https://zeus.tools.crewai.com + https://zeus.useparagon.com/* https://connect.tools.crewai.com/ https://docs.google.com + https://drive.google.com https://slides.google.com https://accounts.google.com + https://*.google.com https://app.hubspot.com/ https://td.doubleclick.net https://www.googletagmanager.com/ + https://www.youtube.com https://share.descript.com' + expires: + - '0' + permissions-policy: + - camera=(), microphone=(self), geolocation=() + pragma: + - no-cache + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=63072000; includeSubDomains + vary: + - Accept + x-content-type-options: + - nosniff + x-frame-options: + - SAMEORIGIN + x-permitted-cross-domain-policies: + - none + x-request-id: + - 4124c4ce-02cf-4d08-9b0b-8983c2e9da6e + x-runtime: + - '0.073764' + x-xss-protection: + - 1; mode=block + status: + code: 401 + message: Unauthorized +- request: + body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in one + word"}],"model":"claude-3-5-haiku-20241022","stop_sequences":["\nObservation:","\nThought:"],"stream":false}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + anthropic-version: + - '2023-06-01' + connection: + - keep-alive + content-length: + - '182' + content-type: + - application/json + host: + - api.anthropic.com + user-agent: + - Anthropic/Python 0.71.0 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 0.71.0 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.12.9 + x-stainless-timeout: + - NOT_GIVEN + method: POST + uri: https://api.anthropic.com/v1/messages + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//dJBdS8QwEEX/y31Ope26u5J3dwWfBH0SCTEZtmHTpCYTXSn979LF4hc+ + DdxzZgbuiD5a8pAwXhdL1apaV512x1K1dXvZ1G0LAWch0eeDqpvN1f3q7bTPz91tc/ewLcfr/Xaz + gwC/DzRblLM+EARS9HOgc3aZdWAImBiYAkM+jovPdJrJeUjckPfxAtOTQOY4qEQ6xwAJClZxSQGf + INNLoWAIMhTvBcr5qRzhwlBYcTxSyJBNK2C06UiZRJpdDOqnUC88kbb/sWV3vk9DRz0l7dW6/+t/ + 0ab7TSeBWPh7tBbIlF6dIcWOEiTmoqxOFtP0AQAA//8DAM5WvkqaAQAA + headers: + CF-RAY: + - 99a939a5a931556e-EWR + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Fri, 07 Nov 2025 01:58:22 GMT + Server: + - cloudflare + Transfer-Encoding: + - chunked + X-Robots-Tag: + - none + anthropic-organization-id: + - SCRUBBED-ORG-ID + anthropic-ratelimit-input-tokens-limit: + - '400000' + anthropic-ratelimit-input-tokens-remaining: + - '400000' + anthropic-ratelimit-input-tokens-reset: + - '2025-11-07T01:58:22Z' + anthropic-ratelimit-output-tokens-limit: + - '80000' + anthropic-ratelimit-output-tokens-remaining: + - '80000' + anthropic-ratelimit-output-tokens-reset: + - '2025-11-07T01:58:22Z' + anthropic-ratelimit-requests-limit: + - '4000' + anthropic-ratelimit-requests-remaining: + - '3999' + anthropic-ratelimit-requests-reset: + - '2025-11-07T01:58:22Z' + anthropic-ratelimit-tokens-limit: + - '480000' + anthropic-ratelimit-tokens-remaining: + - '480000' + anthropic-ratelimit-tokens-reset: + - '2025-11-07T01:58:22Z' + cf-cache-status: + - DYNAMIC + request-id: + - req_011CUshbL7CEVoner91hUvxL + retry-after: + - '41' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-envoy-upstream-service-time: + - '390' + status: + code: 200 + message: OK +version: 1 diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py index 37ba366b9..6ba294d8b 100644 --- a/lib/crewai/tests/llms/anthropic/test_anthropic.py +++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py @@ -664,3 +664,37 @@ def test_anthropic_token_usage_tracking(): assert usage["input_tokens"] == 50 assert usage["output_tokens"] == 25 assert usage["total_tokens"] == 75 + + +def test_anthropic_stop_sequences_sync(): + """Test that stop and stop_sequences attributes stay synchronized.""" + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Test setting stop as a list + llm.stop = ["\nObservation:", "\nThought:"] + assert llm.stop_sequences == ["\nObservation:", "\nThought:"] + assert llm.stop == ["\nObservation:", "\nThought:"] + + # Test setting stop as a string + llm.stop = "\nFinal Answer:" + assert llm.stop_sequences == ["\nFinal Answer:"] + assert llm.stop == ["\nFinal Answer:"] + + # Test setting stop as None + llm.stop = None + assert llm.stop_sequences == [] + assert llm.stop == [] + + +@pytest.mark.vcr(filter_headers=["authorization", "x-api-key"]) +def test_anthropic_stop_sequences_sent_to_api(): + """Test that stop_sequences are properly sent to the Anthropic API.""" + llm = LLM(model="anthropic/claude-3-5-haiku-20241022") + + llm.stop = ["\nObservation:", "\nThought:"] + + result = llm.call("Say hello in one word") + + assert result is not None + assert isinstance(result, str) + assert len(result) > 0 diff --git a/lib/crewai/tests/llms/bedrock/test_bedrock.py b/lib/crewai/tests/llms/bedrock/test_bedrock.py index 9fd172cc6..aecbdde0e 100644 --- a/lib/crewai/tests/llms/bedrock/test_bedrock.py +++ b/lib/crewai/tests/llms/bedrock/test_bedrock.py @@ -736,3 +736,56 @@ def test_bedrock_client_error_handling(): with pytest.raises(RuntimeError) as exc_info: llm.call("Hello") assert "throttled" in str(exc_info.value).lower() + + +def test_bedrock_stop_sequences_sync(): + """Test that stop and stop_sequences attributes stay synchronized.""" + llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + # Test setting stop as a list + llm.stop = ["\nObservation:", "\nThought:"] + assert list(llm.stop_sequences) == ["\nObservation:", "\nThought:"] + assert llm.stop == ["\nObservation:", "\nThought:"] + + # Test setting stop as a string + llm.stop = "\nFinal Answer:" + assert list(llm.stop_sequences) == ["\nFinal Answer:"] + assert llm.stop == ["\nFinal Answer:"] + + # Test setting stop as None + llm.stop = None + assert list(llm.stop_sequences) == [] + assert llm.stop == [] + + +def test_bedrock_stop_sequences_sent_to_api(): + """Test that stop_sequences are properly sent to the Bedrock API.""" + llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") + + # Set stop sequences via the stop attribute (simulating CrewAgentExecutor) + llm.stop = ["\nObservation:", "\nThought:"] + + # Patch the API call to capture parameters without making real call + with patch.object(llm.client, 'converse') as mock_converse: + mock_response = { + 'output': { + 'message': { + 'role': 'assistant', + 'content': [{'text': 'Hello'}] + } + }, + 'usage': { + 'inputTokens': 10, + 'outputTokens': 5, + 'totalTokens': 15 + } + } + mock_converse.return_value = mock_response + + llm.call("Say hello in one word") + + # Verify stop_sequences were passed to the API in the inference config + call_kwargs = mock_converse.call_args[1] + assert "inferenceConfig" in call_kwargs + assert "stopSequences" in call_kwargs["inferenceConfig"] + assert call_kwargs["inferenceConfig"]["stopSequences"] == ["\nObservation:", "\nThought:"] diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py index fc3ff9099..f7b721d1d 100644 --- a/lib/crewai/tests/llms/google/test_google.py +++ b/lib/crewai/tests/llms/google/test_google.py @@ -648,3 +648,55 @@ def test_gemini_token_usage_tracking(): assert usage["candidates_token_count"] == 25 assert usage["total_token_count"] == 75 assert usage["total_tokens"] == 75 + + +def test_gemini_stop_sequences_sync(): + """Test that stop and stop_sequences attributes stay synchronized.""" + llm = LLM(model="google/gemini-2.0-flash-001") + + # Test setting stop as a list + llm.stop = ["\nObservation:", "\nThought:"] + assert llm.stop_sequences == ["\nObservation:", "\nThought:"] + assert llm.stop == ["\nObservation:", "\nThought:"] + + # Test setting stop as a string + llm.stop = "\nFinal Answer:" + assert llm.stop_sequences == ["\nFinal Answer:"] + assert llm.stop == ["\nFinal Answer:"] + + # Test setting stop as None + llm.stop = None + assert llm.stop_sequences == [] + assert llm.stop == [] + + +def test_gemini_stop_sequences_sent_to_api(): + """Test that stop_sequences are properly sent to the Gemini API.""" + llm = LLM(model="google/gemini-2.0-flash-001") + + # Set stop sequences via the stop attribute (simulating CrewAgentExecutor) + llm.stop = ["\nObservation:", "\nThought:"] + + # Patch the API call to capture parameters without making real call + with patch.object(llm.client.models, 'generate_content') as mock_generate: + mock_response = MagicMock() + mock_response.text = "Hello" + mock_response.candidates = [] + mock_response.usage_metadata = MagicMock( + prompt_token_count=10, + candidates_token_count=5, + total_token_count=15 + ) + mock_generate.return_value = mock_response + + llm.call("Say hello in one word") + + # Verify stop_sequences were passed to the API in the config + call_kwargs = mock_generate.call_args[1] + assert "config" in call_kwargs + # The config object should have stop_sequences set + config = call_kwargs["config"] + # Check if the config has stop_sequences attribute + assert hasattr(config, 'stop_sequences') or 'stop_sequences' in config.__dict__ + if hasattr(config, 'stop_sequences'): + assert config.stop_sequences == ["\nObservation:", "\nThought:"] From f6aed9798bae90bcb8f619fe7dc5c13aeeaa355b Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 6 Nov 2025 21:17:29 -0500 Subject: [PATCH 3/4] feat: allow non-ast plot routes --- lib/crewai/src/crewai/flow/flow.py | 2 + lib/crewai/src/crewai/flow/types.py | 1 + lib/crewai/src/crewai/flow/utils.py | 132 +- .../flow/visualization/assets/interactive.js | 1293 +++++++++++++---- .../assets/interactive_flow.html.j2 | 173 ++- .../flow/visualization/assets/style.css | 303 +++- .../src/crewai/flow/visualization/builder.py | 67 +- .../visualization/renderers/interactive.py | 195 ++- .../src/crewai/flow/visualization/types.py | 4 +- 9 files changed, 1746 insertions(+), 424 deletions(-) diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 187ff482c..42b36eb1f 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -428,6 +428,8 @@ class FlowMeta(type): possible_returns = get_possible_return_constants(attr_value) if possible_returns: router_paths[attr_name] = possible_returns + else: + router_paths[attr_name] = [] cls._start_methods = start_methods # type: ignore[attr-defined] cls._listeners = listeners # type: ignore[attr-defined] diff --git a/lib/crewai/src/crewai/flow/types.py b/lib/crewai/src/crewai/flow/types.py index 819f9b09a..024de41df 100644 --- a/lib/crewai/src/crewai/flow/types.py +++ b/lib/crewai/src/crewai/flow/types.py @@ -21,6 +21,7 @@ P = ParamSpec("P") R = TypeVar("R", covariant=True) FlowMethodName = NewType("FlowMethodName", str) +FlowRouteName = NewType("FlowRouteName", str) PendingListenerKey = NewType( "PendingListenerKey", Annotated[str, "nested flow conditions use 'listener_name:object_id'"], diff --git a/lib/crewai/src/crewai/flow/utils.py b/lib/crewai/src/crewai/flow/utils.py index bad9d9670..55db5d9c5 100644 --- a/lib/crewai/src/crewai/flow/utils.py +++ b/lib/crewai/src/crewai/flow/utils.py @@ -19,11 +19,11 @@ import ast from collections import defaultdict, deque import inspect import textwrap -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any from typing_extensions import TypeIs -from crewai.flow.constants import OR_CONDITION, AND_CONDITION +from crewai.flow.constants import AND_CONDITION, OR_CONDITION from crewai.flow.flow_wrappers import ( FlowCondition, FlowConditions, @@ -33,6 +33,7 @@ from crewai.flow.flow_wrappers import ( from crewai.flow.types import FlowMethodCallable, FlowMethodName from crewai.utilities.printer import Printer + if TYPE_CHECKING: from crewai.flow.flow import Flow @@ -40,6 +41,22 @@ _printer = Printer() def get_possible_return_constants(function: Any) -> list[str] | None: + """Extract possible string return values from a function using AST parsing. + + This function analyzes the source code of a router method to identify + all possible string values it might return. It handles: + - Direct string literals: return "value" + - Variable assignments: x = "value"; return x + - Dictionary lookups: d = {"k": "v"}; return d[key] + - Conditional returns: return "a" if cond else "b" + - State attributes: return self.state.attr (infers from class context) + + Args: + function: The function to analyze. + + Returns: + List of possible string return values, or None if analysis fails. + """ try: source = inspect.getsource(function) except OSError: @@ -82,6 +99,7 @@ def get_possible_return_constants(function: Any) -> list[str] | None: return_values: set[str] = set() dict_definitions: dict[str, list[str]] = {} variable_values: dict[str, list[str]] = {} + state_attribute_values: dict[str, list[str]] = {} def extract_string_constants(node: ast.expr) -> list[str]: """Recursively extract all string constants from an AST node.""" @@ -91,6 +109,17 @@ def get_possible_return_constants(function: Any) -> list[str] | None: elif isinstance(node, ast.IfExp): strings.extend(extract_string_constants(node.body)) strings.extend(extract_string_constants(node.orelse)) + elif isinstance(node, ast.Call): + if ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "get" + and len(node.args) >= 2 + ): + default_arg = node.args[1] + if isinstance(default_arg, ast.Constant) and isinstance( + default_arg.value, str + ): + strings.append(default_arg.value) return strings class VariableAssignmentVisitor(ast.NodeVisitor): @@ -124,6 +153,22 @@ def get_possible_return_constants(function: Any) -> list[str] | None: self.generic_visit(node) + def get_attribute_chain(node: ast.expr) -> str | None: + """Extract the full attribute chain from an AST node. + + Examples: + self.state.run_type -> "self.state.run_type" + x.y.z -> "x.y.z" + simple_var -> "simple_var" + """ + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + base = get_attribute_chain(node.value) + if base: + return f"{base}.{node.attr}" + return None + class ReturnVisitor(ast.NodeVisitor): def visit_Return(self, node: ast.Return) -> None: if ( @@ -139,21 +184,94 @@ def get_possible_return_constants(function: Any) -> list[str] | None: for v in dict_definitions[var_name_dict]: return_values.add(v) elif node.value: - var_name_ret: str | None = None - if isinstance(node.value, ast.Name): - var_name_ret = node.value.id - elif isinstance(node.value, ast.Attribute): - var_name_ret = f"{node.value.value.id if isinstance(node.value.value, ast.Name) else '_'}.{node.value.attr}" + var_name_ret = get_attribute_chain(node.value) if var_name_ret and var_name_ret in variable_values: for v in variable_values[var_name_ret]: return_values.add(v) + elif var_name_ret and var_name_ret in state_attribute_values: + for v in state_attribute_values[var_name_ret]: + return_values.add(v) self.generic_visit(node) def visit_If(self, node: ast.If) -> None: self.generic_visit(node) + # Try to get the class context to infer state attribute values + try: + if hasattr(function, "__self__"): + # Method is bound, get the class + class_obj = function.__self__.__class__ + elif hasattr(function, "__qualname__") and "." in function.__qualname__: + # Method is unbound but we can try to get class from module + class_name = function.__qualname__.rsplit(".", 1)[0] + if hasattr(function, "__globals__"): + class_obj = function.__globals__.get(class_name) + else: + class_obj = None + else: + class_obj = None + + if class_obj is not None: + try: + class_source = inspect.getsource(class_obj) + class_source = textwrap.dedent(class_source) + class_ast = ast.parse(class_source) + + # Look for comparisons and assignments involving state attributes + class StateAttributeVisitor(ast.NodeVisitor): + def visit_Compare(self, node: ast.Compare) -> None: + """Find comparisons like: self.state.attr == "value" """ + left_attr = get_attribute_chain(node.left) + + if left_attr: + for comparator in node.comparators: + if isinstance(comparator, ast.Constant) and isinstance( + comparator.value, str + ): + if left_attr not in state_attribute_values: + state_attribute_values[left_attr] = [] + if ( + comparator.value + not in state_attribute_values[left_attr] + ): + state_attribute_values[left_attr].append( + comparator.value + ) + + # Also check right side + for comparator in node.comparators: + right_attr = get_attribute_chain(comparator) + if ( + right_attr + and isinstance(node.left, ast.Constant) + and isinstance(node.left.value, str) + ): + if right_attr not in state_attribute_values: + state_attribute_values[right_attr] = [] + if ( + node.left.value + not in state_attribute_values[right_attr] + ): + state_attribute_values[right_attr].append( + node.left.value + ) + + self.generic_visit(node) + + StateAttributeVisitor().visit(class_ast) + except Exception as e: + _printer.print( + f"Could not analyze class context for {function.__name__}: {e}", + color="yellow", + ) + except Exception as e: + _printer.print( + f"Could not introspect class for {function.__name__}: {e}", + color="yellow", + ) + VariableAssignmentVisitor().visit(code_ast) ReturnVisitor().visit(code_ast) diff --git a/lib/crewai/src/crewai/flow/visualization/assets/interactive.js b/lib/crewai/src/crewai/flow/visualization/assets/interactive.js index 8d4fe9bd9..10788727f 100644 --- a/lib/crewai/src/crewai/flow/visualization/assets/interactive.js +++ b/lib/crewai/src/crewai/flow/visualization/assets/interactive.js @@ -1,24 +1,16 @@ "use strict"; -/** - * Flow Visualization Interactive Script - * Handles the interactive network visualization for CrewAI flows - */ - -// ============================================================================ -// Constants -// ============================================================================ - const CONSTANTS = { NODE: { - BASE_WIDTH: 200, - BASE_HEIGHT: 60, + BASE_WIDTH: 220, + BASE_HEIGHT: 100, BORDER_RADIUS: 20, TEXT_SIZE: 13, - TEXT_PADDING: 8, + TEXT_PADDING: 16, TEXT_BG_RADIUS: 6, - HOVER_SCALE: 1.04, - PRESSED_SCALE: 0.98, + HOVER_SCALE: 1.00, + PRESSED_SCALE: 1.16, + SELECTED_SCALE: 1.05, }, EDGE: { DEFAULT_WIDTH: 2, @@ -32,26 +24,18 @@ const CONSTANTS = { EASE_OUT_CUBIC: (t) => 1 - Math.pow(1 - t, 3), }, NETWORK: { - STABILIZATION_ITERATIONS: 200, - NODE_DISTANCE: 180, - SPRING_LENGTH: 150, - LEVEL_SEPARATION: 180, - NODE_SPACING: 220, + STABILIZATION_ITERATIONS: 300, + NODE_DISTANCE: 225, + SPRING_LENGTH: 100, + LEVEL_SEPARATION: 150, + NODE_SPACING: 350, TREE_SPACING: 250, }, DRAWER: { WIDTH: 400, - OFFSET_SCALE: 0.3, }, }; -// ============================================================================ -// Utility Functions -// ============================================================================ - -/** - * Loads the vis-network library from CDN - */ function loadVisCDN() { return new Promise((resolve, reject) => { const script = document.createElement("script"); @@ -62,9 +46,6 @@ function loadVisCDN() { }); } -/** - * Draws a rounded rectangle on a canvas context - */ function drawRoundedRect(ctx, x, y, width, height, radius) { ctx.beginPath(); ctx.moveTo(x + radius, y); @@ -79,50 +60,54 @@ function drawRoundedRect(ctx, x, y, width, height, radius) { ctx.closePath(); } -/** - * Highlights Python code using Prism - */ function highlightPython(code) { return Prism.highlight(code, Prism.languages.python, "python"); } -// ============================================================================ -// Node Renderer -// ============================================================================ - class NodeRenderer { constructor(nodes, networkManager) { this.nodes = nodes; this.networkManager = networkManager; + this.nodeScales = new Map(); + this.scaleAnimations = new Map(); + this.hoverGlowIntensities = new Map(); + this.glowAnimations = new Map(); + this.colorCache = new Map(); + this.tempCanvas = document.createElement('canvas'); + this.tempCanvas.width = 1; + this.tempCanvas.height = 1; + this.tempCtx = this.tempCanvas.getContext('2d'); } - render({ ctx, id, x, y, state, style, label }) { + render({ ctx, id, x, y }) { const node = this.nodes.get(id); - if (!node || !node.nodeStyle) return {}; + if (!node?.nodeStyle) return {}; const scale = this.getNodeScale(id); - const isActiveDrawer = - this.networkManager.drawerManager?.activeNodeId === id; + const isActiveDrawer = this.networkManager.drawerManager?.activeNodeId === id; + const isHovered = this.networkManager.hoveredNodeId === id && !isActiveDrawer; const nodeStyle = node.nodeStyle; - const width = CONSTANTS.NODE.BASE_WIDTH * scale; - const height = CONSTANTS.NODE.BASE_HEIGHT * scale; + + // Manage hover glow intensity animation + const glowIntensity = this.getHoverGlowIntensity(id, isHovered); + + ctx.font = `500 ${CONSTANTS.NODE.TEXT_SIZE * scale}px 'JetBrains Mono', 'SF Mono', 'Monaco', 'Menlo', 'Consolas', monospace`; + const textMetrics = ctx.measureText(nodeStyle.name); + const textWidth = textMetrics.width; + const textHeight = CONSTANTS.NODE.TEXT_SIZE * scale; + const textPadding = CONSTANTS.NODE.TEXT_PADDING * scale; + + const width = textWidth + textPadding * 5; + const height = textHeight + textPadding * 2.5; return { drawNode: () => { ctx.save(); - this.applyNodeOpacity(ctx, node); - this.applyShadow(ctx, node, isActiveDrawer); - this.drawNodeShape( - ctx, - x, - y, - width, - height, - scale, - nodeStyle, - isActiveDrawer, - ); - this.drawNodeText(ctx, x, y, scale, nodeStyle); + const opacity = node.opacity !== undefined ? node.opacity : 1.0; + this.applyShadow(ctx, node, glowIntensity, opacity); + ctx.globalAlpha = opacity; + this.drawNodeShape(ctx, x, y, width, height, scale, nodeStyle, opacity, node); + this.drawNodeText(ctx, x, y, scale, nodeStyle, opacity, node); ctx.restore(); }, nodeDimensions: { width, height }, @@ -130,54 +115,378 @@ class NodeRenderer { } getNodeScale(id) { - if (this.networkManager.pressedNodeId === id) { - return CONSTANTS.NODE.PRESSED_SCALE; + const isActiveDrawer = this.networkManager.drawerManager?.activeNodeId === id; + + let targetScale = 1.0; + if (isActiveDrawer) { + targetScale = CONSTANTS.NODE.SELECTED_SCALE; + } else if (this.networkManager.pressedNodeId === id) { + targetScale = CONSTANTS.NODE.PRESSED_SCALE; } else if (this.networkManager.hoveredNodeId === id) { - return CONSTANTS.NODE.HOVER_SCALE; + targetScale = CONSTANTS.NODE.HOVER_SCALE; } - return 1.0; + + const currentScale = this.nodeScales.get(id) ?? 1.0; + const runningAnimation = this.scaleAnimations.get(id); + const animationTarget = runningAnimation?.targetScale; + + if (Math.abs(targetScale - currentScale) > 0.001) { + if (runningAnimation && animationTarget !== targetScale) { + cancelAnimationFrame(runningAnimation.frameId); + this.scaleAnimations.delete(id); + } + + if (!this.scaleAnimations.has(id)) { + this.animateScale(id, currentScale, targetScale); + } + } + + return currentScale; } - applyNodeOpacity(ctx, node) { - const nodeOpacity = node.opacity !== undefined ? node.opacity : 1.0; - ctx.globalAlpha = nodeOpacity; + animateScale(id, startScale, targetScale) { + const startTime = performance.now(); + const duration = 150; + + const animate = () => { + const elapsed = performance.now() - startTime; + const progress = Math.min(elapsed / duration, 1); + const eased = CONSTANTS.ANIMATION.EASE_OUT_CUBIC(progress); + + const currentScale = startScale + (targetScale - startScale) * eased; + this.nodeScales.set(id, currentScale); + + if (progress < 1) { + const frameId = requestAnimationFrame(animate); + this.scaleAnimations.set(id, { frameId, targetScale }); + } else { + this.scaleAnimations.delete(id); + this.nodeScales.set(id, targetScale); + } + + this.networkManager.network?.redraw(); + }; + + animate(); } - applyShadow(ctx, node, isActiveDrawer) { - if (node.shadow && node.shadow.enabled) { + getHoverGlowIntensity(id, isHovered) { + const targetIntensity = isHovered ? 1.0 : 0.0; + const currentIntensity = this.hoverGlowIntensities.get(id) ?? 0.0; + const runningAnimation = this.glowAnimations.get(id); + const animationTarget = runningAnimation?.targetIntensity; + + if (Math.abs(targetIntensity - currentIntensity) > 0.001) { + if (runningAnimation && animationTarget !== targetIntensity) { + cancelAnimationFrame(runningAnimation.frameId); + this.glowAnimations.delete(id); + } + + if (!this.glowAnimations.has(id)) { + this.animateGlowIntensity(id, currentIntensity, targetIntensity); + } + } + + return currentIntensity; + } + + animateGlowIntensity(id, startIntensity, targetIntensity) { + const startTime = performance.now(); + const duration = 200; + + const animate = () => { + const elapsed = performance.now() - startTime; + const progress = Math.min(elapsed / duration, 1); + const eased = CONSTANTS.ANIMATION.EASE_OUT_CUBIC(progress); + + const currentIntensity = startIntensity + (targetIntensity - startIntensity) * eased; + this.hoverGlowIntensities.set(id, currentIntensity); + + if (progress < 1) { + const frameId = requestAnimationFrame(animate); + this.glowAnimations.set(id, { frameId, targetIntensity }); + } else { + this.glowAnimations.delete(id); + this.hoverGlowIntensities.set(id, targetIntensity); + } + + this.networkManager.network?.redraw(); + }; + + animate(); + } + + applyShadow(ctx, node, glowIntensity = 0, nodeOpacity = 1.0) { + if (glowIntensity > 0.001) { + // Save current alpha and apply glow at full opacity + const currentAlpha = ctx.globalAlpha; + ctx.globalAlpha = 1.0; + + const isDarkMode = document.documentElement.getAttribute('data-theme') === 'dark'; + + // Use CrewAI orange for hover glow in both themes + const glowR = 255; + const glowG = 90; + const glowB = 80; + const blurRadius = isDarkMode ? 20 : 35; + + // Scale glow intensity proportionally based on node opacity + // When node is inactive (opacity < 1.0), reduce glow intensity accordingly + const scaledGlowIntensity = glowIntensity * nodeOpacity; + + const glowColor = `rgba(${glowR}, ${glowG}, ${glowB}, ${scaledGlowIntensity})`; + + ctx.shadowColor = glowColor; + ctx.shadowBlur = blurRadius * scaledGlowIntensity; + ctx.shadowOffsetX = 0; + ctx.shadowOffsetY = 0; + + // Restore the original alpha + ctx.globalAlpha = currentAlpha; + return; + } + + if (node.shadow?.enabled) { ctx.shadowColor = node.shadow.color || "rgba(0,0,0,0.1)"; ctx.shadowBlur = node.shadow.size || 8; ctx.shadowOffsetX = node.shadow.x || 0; ctx.shadowOffsetY = node.shadow.y || 0; - } else if (isActiveDrawer) { - ctx.shadowColor = "{{ CREWAI_ORANGE }}"; - ctx.shadowBlur = 20; - ctx.shadowOffsetX = 0; - ctx.shadowOffsetY = 0; + return; } + + ctx.shadowColor = "transparent"; + ctx.shadowBlur = 0; + ctx.shadowOffsetX = 0; + ctx.shadowOffsetY = 0; } - drawNodeShape(ctx, x, y, width, height, scale, nodeStyle, isActiveDrawer) { + resolveCSSVariable(color) { + if (color?.startsWith('var(')) { + const varName = color.match(/var\((--[^)]+)\)/)?.[1]; + if (varName) { + return getComputedStyle(document.documentElement).getPropertyValue(varName).trim(); + } + } + return color; + } + + + parseColor(color) { + const cacheKey = `parse_${color}`; + if (this.colorCache.has(cacheKey)) { + return this.colorCache.get(cacheKey); + } + + this.tempCtx.fillStyle = color; + this.tempCtx.fillRect(0, 0, 1, 1); + const [r, g, b] = this.tempCtx.getImageData(0, 0, 1, 1).data; + + const result = { r, g, b }; + this.colorCache.set(cacheKey, result); + return result; + } + + darkenColor(color, opacity) { + if (opacity >= 0.9) return color; + + const { r, g, b } = this.parseColor(color); + + const t = (opacity - 0.85) / (1.0 - 0.85); + const normalizedT = Math.max(0, Math.min(1, t)); + + const minBrightness = 0.4; + const brightness = minBrightness + (1.0 - minBrightness) * normalizedT; + + const newR = Math.floor(r * brightness); + const newG = Math.floor(g * brightness); + const newB = Math.floor(b * brightness); + + return `rgb(${newR}, ${newG}, ${newB})`; + } + + desaturateColor(color, opacity) { + if (opacity >= 0.9) return color; + + const { r, g, b } = this.parseColor(color); + + // Convert to HSL to adjust saturation and lightness + const max = Math.max(r, g, b) / 255; + const min = Math.min(r, g, b) / 255; + const l = (max + min) / 2; + let h = 0, s = 0; + + if (max !== min) { + const d = max - min; + s = l > 0.5 ? d / (2 - max - min) : d / (max + min); + + if (max === r / 255) { + h = ((g / 255 - b / 255) / d + (g < b ? 6 : 0)) / 6; + } else if (max === g / 255) { + h = ((b / 255 - r / 255) / d + 2) / 6; + } else { + h = ((r / 255 - g / 255) / d + 4) / 6; + } + } + + // Reduce saturation and lightness by 40% + s = s * 0.6; + const newL = l * 0.6; + + // Convert back to RGB + const hue2rgb = (p, q, t) => { + if (t < 0) t += 1; + if (t > 1) t -= 1; + if (t < 1/6) return p + (q - p) * 6 * t; + if (t < 1/2) return q; + if (t < 2/3) return p + (q - p) * (2/3 - t) * 6; + return p; + }; + + let newR, newG, newB; + if (s === 0) { + newR = newG = newB = Math.floor(newL * 255); + } else { + const q = newL < 0.5 ? newL * (1 + s) : newL + s - newL * s; + const p = 2 * newL - q; + newR = Math.floor(hue2rgb(p, q, h + 1/3) * 255); + newG = Math.floor(hue2rgb(p, q, h) * 255); + newB = Math.floor(hue2rgb(p, q, h - 1/3) * 255); + } + + return `rgb(${newR}, ${newG}, ${newB})`; + } + + drawNodeShape(ctx, x, y, width, height, scale, nodeStyle, opacity = 1.0, node = null) { const radius = CONSTANTS.NODE.BORDER_RADIUS * scale; const rectX = x - width / 2; const rectY = y - height / 2; - drawRoundedRect(ctx, rectX, rectY, width, height, radius); + const isDarkMode = document.documentElement.getAttribute('data-theme') === 'dark'; + const nodeData = '{{ nodeData }}'; + const metadata = node ? nodeData[node.id] : null; + const isStartNode = metadata && metadata.type === 'start'; - ctx.fillStyle = nodeStyle.bgColor; + let nodeColor; + + if (isDarkMode || isStartNode) { + // In dark mode or for start nodes, use the theme color + nodeColor = this.resolveCSSVariable(nodeStyle.bgColor); + } else { + // In light mode for non-start nodes, use white + nodeColor = 'rgb(255, 255, 255)'; + } + + // Parse the base color to get RGB values + let { r, g, b } = this.parseColor(nodeColor); + + // For inactive nodes, check if node is in highlighted list + // If drawer is open and node is not highlighted, it's inactive + const isDrawerOpen = this.networkManager.drawerManager?.activeNodeId !== null; + const isHighlighted = this.networkManager.triggeredByHighlighter?.highlightedNodes?.includes(node?.id); + const isActiveNode = this.networkManager.drawerManager?.activeNodeId === node?.id; + const hasHighlightedNodes = this.networkManager.triggeredByHighlighter?.highlightedNodes?.length > 0; + + // Non-prominent nodes: drawer is open, has highlighted nodes, but this node is not highlighted or active + const isNonProminent = isDrawerOpen && hasHighlightedNodes && !isHighlighted && !isActiveNode; + + // Inactive nodes: drawer is open but no highlighted nodes, and this node is not active + const isInactive = isDrawerOpen && !hasHighlightedNodes && !isActiveNode; + + if (isNonProminent || isInactive) { + // Make non-prominent and inactive nodes a darker version of the normal active color + const darkenFactor = 0.4; // Keep 40% of original color (darken by 60%) + r = Math.round(r * darkenFactor); + g = Math.round(g * darkenFactor); + b = Math.round(b * darkenFactor); + } + + // Draw base shape with frosted glass effect + ctx.beginPath(); + drawRoundedRect(ctx, rectX, rectY, width, height, radius); + // Use full opacity for all nodes + const glassOpacity = 1.0; + ctx.fillStyle = `rgba(${r}, ${g}, ${b}, ${glassOpacity})`; ctx.fill(); + // Calculate text label area to exclude from frosted overlay + const textPadding = CONSTANTS.NODE.TEXT_PADDING * scale; + const textBgRadius = CONSTANTS.NODE.TEXT_BG_RADIUS * scale; + + ctx.font = `500 ${CONSTANTS.NODE.TEXT_SIZE * scale}px 'JetBrains Mono', 'SF Mono', 'Monaco', 'Menlo', 'Consolas', monospace`; + const textMetrics = ctx.measureText(nodeStyle.name); + const textWidth = textMetrics.width; + const textHeight = CONSTANTS.NODE.TEXT_SIZE * scale; + const textBgWidth = textWidth + textPadding * 2; + const textBgHeight = textHeight + textPadding * 0.75; + const textBgX = x - textBgWidth / 2; + const textBgY = y - textBgHeight / 2; + + // Add frosted overlay (clipped to node shape, excluding text area) + ctx.save(); + ctx.beginPath(); + drawRoundedRect(ctx, rectX, rectY, width, height, radius); + ctx.clip(); + + // Cut out the text label area from the frosted overlay + ctx.beginPath(); + drawRoundedRect(ctx, rectX, rectY, width, height, radius); + drawRoundedRect(ctx, textBgX, textBgY, textBgWidth, textBgHeight, textBgRadius); + ctx.clip('evenodd'); + + // For inactive nodes, use stronger absolute frost values + // For active nodes, scale frost with opacity + let frostTop, frostMid, frostBottom; + if (isInactive) { + // Inactive nodes get stronger, more consistent frost + frostTop = 0.45; + frostMid = 0.35; + frostBottom = 0.25; + } else { + // Active nodes get opacity-scaled frost + frostTop = opacity * 0.3; + frostMid = opacity * 0.2; + frostBottom = opacity * 0.15; + } + + // Stronger white overlay for frosted appearance + const frostOverlay = ctx.createLinearGradient(rectX, rectY, rectX, rectY + height); + frostOverlay.addColorStop(0, `rgba(255, 255, 255, ${frostTop})`); + frostOverlay.addColorStop(0.5, `rgba(255, 255, 255, ${frostMid})`); + frostOverlay.addColorStop(1, `rgba(255, 255, 255, ${frostBottom})`); + + ctx.fillStyle = frostOverlay; + ctx.fillRect(rectX, rectY, width, height); + ctx.restore(); + ctx.shadowColor = "transparent"; ctx.shadowBlur = 0; - ctx.strokeStyle = isActiveDrawer - ? "{{ CREWAI_ORANGE }}" - : nodeStyle.borderColor; + // Draw border at full opacity (desaturated for inactive nodes) + // Reset globalAlpha to 1.0 so the border is always fully visible + ctx.save(); + ctx.globalAlpha = 1.0; + ctx.beginPath(); + drawRoundedRect(ctx, rectX, rectY, width, height, radius); + const borderColor = this.resolveCSSVariable(nodeStyle.borderColor); + let finalBorderColor = this.desaturateColor(borderColor, opacity); + + // Darken border color for non-prominent and inactive nodes + if (isNonProminent || isInactive) { + const borderRGB = this.parseColor(finalBorderColor); + const darkenFactor = 0.4; + const darkenedR = Math.round(borderRGB.r * darkenFactor); + const darkenedG = Math.round(borderRGB.g * darkenFactor); + const darkenedB = Math.round(borderRGB.b * darkenFactor); + finalBorderColor = `rgb(${darkenedR}, ${darkenedG}, ${darkenedB})`; + } + + ctx.strokeStyle = finalBorderColor; ctx.lineWidth = nodeStyle.borderWidth * scale; ctx.stroke(); + ctx.restore(); } - drawNodeText(ctx, x, y, scale, nodeStyle) { + drawNodeText(ctx, x, y, scale, nodeStyle, opacity = 1.0, node = null) { ctx.font = `500 ${CONSTANTS.NODE.TEXT_SIZE * scale}px 'JetBrains Mono', 'SF Mono', 'Monaco', 'Menlo', 'Consolas', monospace`; ctx.textAlign = "center"; ctx.textBaseline = "middle"; @@ -188,10 +497,10 @@ class NodeRenderer { const textPadding = CONSTANTS.NODE.TEXT_PADDING * scale; const textBgRadius = CONSTANTS.NODE.TEXT_BG_RADIUS * scale; - const textBgX = x - textWidth / 2 - textPadding; - const textBgY = y - textHeight / 2 - textPadding / 2; const textBgWidth = textWidth + textPadding * 2; - const textBgHeight = textHeight + textPadding; + const textBgHeight = textHeight + textPadding * 0.75; + const textBgX = x - textBgWidth / 2; + const textBgY = y - textBgHeight / 2; drawRoundedRect( ctx, @@ -202,18 +511,71 @@ class NodeRenderer { textBgRadius, ); - ctx.fillStyle = "rgba(255, 255, 255, 0.2)"; + const isDarkMode = document.documentElement.getAttribute('data-theme') === 'dark'; + const nodeData = '{{ nodeData }}'; + const metadata = node ? nodeData[node.id] : null; + const isStartNode = metadata && metadata.type === 'start'; + + // Check if this is an inactive or non-prominent node using the same logic as drawNodeShape + const isDrawerOpen = this.networkManager.drawerManager?.activeNodeId !== null; + const isHighlighted = this.networkManager.triggeredByHighlighter?.highlightedNodes?.includes(node?.id); + const isActiveNode = this.networkManager.drawerManager?.activeNodeId === node?.id; + const hasHighlightedNodes = this.networkManager.triggeredByHighlighter?.highlightedNodes?.length > 0; + + const isNonProminent = isDrawerOpen && hasHighlightedNodes && !isHighlighted && !isActiveNode; + const isInactive = isDrawerOpen && !hasHighlightedNodes && !isActiveNode; + + // Get the base node color to darken it for inactive nodes + let nodeColor; + if (isDarkMode || isStartNode) { + nodeColor = this.resolveCSSVariable(nodeStyle.bgColor); + } else { + nodeColor = 'rgb(255, 255, 255)'; + } + const { r, g, b } = this.parseColor(nodeColor); + + let labelBgR = 255, labelBgG = 255, labelBgB = 255; + let labelBgOpacity = 0.2 * opacity; + + if (isNonProminent || isInactive) { + // Darken the base node color for non-prominent and inactive label backgrounds + const darkenFactor = 0.4; + labelBgR = Math.round(r * darkenFactor); + labelBgG = Math.round(g * darkenFactor); + labelBgB = Math.round(b * darkenFactor); + labelBgOpacity = 0.5; + } else if (!isDarkMode && !isStartNode) { + // In light mode for non-start nodes, use gray for active node labels + labelBgR = labelBgG = labelBgB = 128; + labelBgOpacity = 0.25; + } + + ctx.fillStyle = `rgba(${labelBgR}, ${labelBgG}, ${labelBgB}, ${labelBgOpacity})`; ctx.fill(); - ctx.fillStyle = nodeStyle.fontColor; + // For start nodes or dark mode, use theme color; in light mode, use dark text + let fontColor; + if (isDarkMode || isStartNode) { + fontColor = this.resolveCSSVariable(nodeStyle.fontColor); + } else { + fontColor = 'rgb(30, 30, 30)'; + } + + // Darken font color for non-prominent and inactive nodes + if (isNonProminent || isInactive) { + const fontRGB = this.parseColor(fontColor); + const darkenFactor = 0.4; + const darkenedR = Math.round(fontRGB.r * darkenFactor); + const darkenedG = Math.round(fontRGB.g * darkenFactor); + const darkenedB = Math.round(fontRGB.b * darkenFactor); + fontColor = `rgb(${darkenedR}, ${darkenedG}, ${darkenedB})`; + } + + ctx.fillStyle = fontColor; ctx.fillText(nodeStyle.name, x, y); } } -// ============================================================================ -// Animation Manager -// ============================================================================ - class AnimationManager { constructor() { this.animations = new Map(); @@ -265,10 +627,6 @@ class AnimationManager { } } -// ============================================================================ -// Triggered By Highlighter -// ============================================================================ - class TriggeredByHighlighter { constructor(network, nodes, edges, highlightCanvas) { this.network = network; @@ -305,7 +663,6 @@ class TriggeredByHighlighter { this.clear(); if (!this.activeDrawerNodeId || !triggerNodeIds || triggerNodeIds.length === 0) { - console.warn("TriggeredByHighlighter: Missing activeDrawerNodeId or triggerNodeIds"); return; } @@ -333,38 +690,74 @@ class TriggeredByHighlighter { const routerEdges = allEdges.filter( (edge) => edge.from === routerNode && edge.dashes ); + let foundEdge = false; for (const routerEdge of routerEdges) { - if (routerEdge.to === this.activeDrawerNodeId) { + if (routerEdge.label === triggerNodeId) { connectingEdges.push(routerEdge); pathNodes.add(routerNode); - pathNodes.add(this.activeDrawerNodeId); - break; - } + pathNodes.add(routerEdge.to); - const intermediateNode = routerEdge.to; - const pathToActive = allEdges.filter( - (edge) => edge.from === intermediateNode && edge.to === this.activeDrawerNodeId - ); + if (routerEdge.to !== this.activeDrawerNodeId) { + const pathToActive = allEdges.filter( + (edge) => edge.from === routerEdge.to && edge.to === this.activeDrawerNodeId + ); - if (pathToActive.length > 0) { - connectingEdges.push(routerEdge); - connectingEdges.push(...pathToActive); - pathNodes.add(routerNode); - pathNodes.add(intermediateNode); - pathNodes.add(this.activeDrawerNodeId); + if (pathToActive.length > 0) { + connectingEdges.push(...pathToActive); + pathNodes.add(this.activeDrawerNodeId); + } + } + + foundEdge = true; break; } } - if (connectingEdges.length > 0) break; + if (!foundEdge) { + for (const routerEdge of routerEdges) { + if (routerEdge.to === triggerNodeId) { + connectingEdges.push(routerEdge); + pathNodes.add(routerNode); + pathNodes.add(routerEdge.to); + + const pathToActive = allEdges.filter( + (edge) => edge.from === triggerNodeId && edge.to === this.activeDrawerNodeId + ); + + if (pathToActive.length > 0) { + connectingEdges.push(...pathToActive); + pathNodes.add(this.activeDrawerNodeId); + } + + foundEdge = true; + break; + } + } + } + + if (!foundEdge) { + const directRouterEdge = routerEdges.find( + (edge) => edge.to === this.activeDrawerNodeId + ); + + if (directRouterEdge) { + connectingEdges.push(directRouterEdge); + pathNodes.add(routerNode); + pathNodes.add(this.activeDrawerNodeId); + foundEdge = true; + } + } + + if (foundEdge) { + break; + } } } } }); if (connectingEdges.length === 0) { - console.warn("TriggeredByHighlighter: No connecting edges found for group", { triggerNodeIds }); return; } @@ -379,7 +772,6 @@ class TriggeredByHighlighter { this.clear(); if (!this.activeDrawerNodeId) { - console.warn("TriggeredByHighlighter: Missing activeDrawerNodeId"); return; } @@ -419,11 +811,6 @@ class TriggeredByHighlighter { } if (routerEdges.length === 0) { - console.warn("TriggeredByHighlighter: No router paths found for node", { - activeDrawerNodeId: this.activeDrawerNodeId, - outgoingEdges: outgoingRouterEdges.length, - hasRouterPathsMetadata: !!activeMetadata?.router_paths, - }); return; } @@ -438,24 +825,12 @@ class TriggeredByHighlighter { this.clear(); if (this.activeDrawerEdges && this.activeDrawerEdges.length > 0) { - this.activeDrawerEdges.forEach((edgeId) => { - this.edges.update({ - id: edgeId, - width: CONSTANTS.EDGE.DEFAULT_WIDTH, - opacity: 1.0, - }); - }); + // Animate the activeDrawerEdges back to default + this.resetEdgesToDefault(this.activeDrawerEdges); this.activeDrawerEdges = []; } if (!this.activeDrawerNodeId || !triggerNodeId) { - console.warn( - "TriggeredByHighlighter: Missing activeDrawerNodeId or triggerNodeId", - { - activeDrawerNodeId: this.activeDrawerNodeId, - triggerNodeId: triggerNodeId, - }, - ); return; } @@ -570,17 +945,6 @@ class TriggeredByHighlighter { } if (connectingEdges.length === 0) { - console.warn("TriggeredByHighlighter: No connecting edges found", { - triggerNodeId, - activeDrawerNodeId: this.activeDrawerNodeId, - allEdges: allEdges.length, - edgeDetails: allEdges.map((e) => ({ - from: e.from, - to: e.to, - label: e.label, - dashes: e.dashes, - })), - }); return; } @@ -601,6 +965,7 @@ class TriggeredByHighlighter { const allNodesList = this.nodes.get(); const nodeAnimDuration = CONSTANTS.ANIMATION.DURATION; const nodeAnimStart = performance.now(); + const isDarkMode = document.documentElement.getAttribute('data-theme') === 'dark'; const animate = () => { const elapsed = performance.now() - nodeAnimStart; @@ -609,9 +974,11 @@ class TriggeredByHighlighter { allNodesList.forEach((node) => { const currentOpacity = node.opacity !== undefined ? node.opacity : 1.0; + // Keep inactive nodes at full opacity + const inactiveOpacity = 1.0; const targetOpacity = this.highlightedNodes.includes(node.id) ? 1.0 - : 0.2; + : inactiveOpacity; const newOpacity = currentOpacity + (targetOpacity - currentOpacity) * eased; @@ -621,6 +988,8 @@ class TriggeredByHighlighter { }); }); + this.network.redraw(); + if (progress < 1) { requestAnimationFrame(animate); } @@ -654,18 +1023,23 @@ class TriggeredByHighlighter { const newShadowSize = currentShadowSize + (targetShadowSize - currentShadowSize) * eased; + const isAndOrRouter = edge.dashes || edge.label === "AND"; + const highlightColor = isAndOrRouter + ? "{{ CREWAI_ORANGE }}" + : getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim(); + const updateData = { id: edge.id, hidden: false, opacity: 1.0, width: newWidth, color: { - color: "{{ CREWAI_ORANGE }}", - highlight: "{{ CREWAI_ORANGE }}", + color: highlightColor, + highlight: highlightColor, }, shadow: { enabled: true, - color: "{{ CREWAI_ORANGE }}", + color: highlightColor, size: newShadowSize, x: 0, y: 0, @@ -686,30 +1060,52 @@ class TriggeredByHighlighter { }; updateData.color = { - color: "{{ CREWAI_ORANGE }}", - highlight: "{{ CREWAI_ORANGE }}", - hover: "{{ CREWAI_ORANGE }}", + color: highlightColor, + highlight: highlightColor, + hover: highlightColor, inherit: "to", }; this.edges.update(updateData); } else { const currentOpacity = edge.opacity !== undefined ? edge.opacity : 1.0; - const targetOpacity = 0.25; + // Keep inactive edges at full opacity + const targetOpacity = 1.0; const newOpacity = currentOpacity + (targetOpacity - currentOpacity) * eased; const currentWidth = edge.width !== undefined ? edge.width : CONSTANTS.EDGE.DEFAULT_WIDTH; - const targetWidth = 1; + const targetWidth = 1.2; const newWidth = currentWidth + (targetWidth - currentWidth) * eased; + // Keep the original edge color instead of turning gray + const isAndOrRouter = edge.dashes || edge.label === "AND"; + const baseColor = isAndOrRouter + ? "{{ CREWAI_ORANGE }}" + : getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim(); + + // Convert color to rgba with opacity for vis.js + let inactiveEdgeColor; + if (baseColor.startsWith('#')) { + // Convert hex to rgba + const hex = baseColor.replace('#', ''); + const r = parseInt(hex.substr(0, 2), 16); + const g = parseInt(hex.substr(2, 2), 16); + const b = parseInt(hex.substr(4, 2), 16); + inactiveEdgeColor = `rgba(${r}, ${g}, ${b}, ${newOpacity})`; + } else if (baseColor.startsWith('rgb(')) { + inactiveEdgeColor = baseColor.replace('rgb(', `rgba(`).replace(')', `, ${newOpacity})`); + } else { + inactiveEdgeColor = baseColor; + } + this.edges.update({ id: edge.id, hidden: false, - opacity: newOpacity, width: newWidth, color: { - color: "rgba(128, 128, 128, 0.3)", - highlight: "rgba(128, 128, 128, 0.3)", + color: inactiveEdgeColor, + highlight: inactiveEdgeColor, + hover: inactiveEdgeColor, }, shadow: { enabled: false, @@ -726,55 +1122,91 @@ class TriggeredByHighlighter { animate(); } - drawHighlightLayer() { - this.ctx.clearRect(0, 0, this.canvas.width, this.canvas.height); + resetEdgesToDefault(edgeIds = null, excludeEdges = []) { + const targetEdgeIds = edgeIds || this.edges.getIds(); + const edgeAnimDuration = CONSTANTS.ANIMATION.DURATION; + const edgeAnimStart = performance.now(); - if (this.highlightedNodes.length === 0) return; + const animate = () => { + const elapsed = performance.now() - edgeAnimStart; + const progress = Math.min(elapsed / edgeAnimDuration, 1); + const eased = CONSTANTS.ANIMATION.EASE_OUT_CUBIC(progress); - this.highlightedNodes.forEach((nodeId) => { - const nodePosition = this.network.getPositions([nodeId])[nodeId]; - if (!nodePosition) return; + targetEdgeIds.forEach((edgeId) => { + if (excludeEdges.includes(edgeId)) { + return; + } - const canvasPos = this.network.canvasToDOM(nodePosition); - const node = this.nodes.get(nodeId); - if (!node || !node.nodeStyle) return; + const edge = this.edges.get(edgeId); + if (!edge) return; - const nodeStyle = node.nodeStyle; - const scale = 1.0; - const width = CONSTANTS.NODE.BASE_WIDTH * scale; - const height = CONSTANTS.NODE.BASE_HEIGHT * scale; + const defaultColor = + edge.dashes || edge.label === "AND" + ? "{{ CREWAI_ORANGE }}" + : getComputedStyle(document.documentElement).getPropertyValue('--edge-or-color').trim(); + const currentOpacity = edge.opacity !== undefined ? edge.opacity : 1.0; + const currentWidth = + edge.width !== undefined ? edge.width : CONSTANTS.EDGE.DEFAULT_WIDTH; + const currentShadowSize = + edge.shadow && edge.shadow.size !== undefined + ? edge.shadow.size + : CONSTANTS.EDGE.DEFAULT_SHADOW_SIZE; - this.ctx.save(); + const targetOpacity = 1.0; + const targetWidth = CONSTANTS.EDGE.DEFAULT_WIDTH; + const targetShadowSize = CONSTANTS.EDGE.DEFAULT_SHADOW_SIZE; - this.ctx.shadowColor = "transparent"; - this.ctx.shadowBlur = 0; - this.ctx.shadowOffsetX = 0; - this.ctx.shadowOffsetY = 0; + const newOpacity = + currentOpacity + (targetOpacity - currentOpacity) * eased; + const newWidth = currentWidth + (targetWidth - currentWidth) * eased; + const newShadowSize = + currentShadowSize + (targetShadowSize - currentShadowSize) * eased; - const radius = CONSTANTS.NODE.BORDER_RADIUS * scale; - const rectX = canvasPos.x - width / 2; - const rectY = canvasPos.y - height / 2; + const updateData = { + id: edge.id, + hidden: false, + opacity: newOpacity, + width: newWidth, + color: { + color: defaultColor, + highlight: defaultColor, + hover: defaultColor, + inherit: false, + }, + shadow: { + enabled: true, + color: "rgba(0,0,0,0.08)", + size: newShadowSize, + x: 1, + y: 1, + }, + font: { + color: "transparent", + background: "transparent", + }, + arrows: { + to: { + enabled: true, + scaleFactor: 0.8, + type: "triangle", + }, + }, + }; - drawRoundedRect(this.ctx, rectX, rectY, width, height, radius); + if (edge.dashes) { + const scale = Math.sqrt(newWidth / CONSTANTS.EDGE.DEFAULT_WIDTH); + updateData.dashes = [15 * scale, 10 * scale]; + } - this.ctx.fillStyle = nodeStyle.bgColor; - this.ctx.fill(); + this.edges.update(updateData); + }); - this.ctx.shadowColor = "transparent"; - this.ctx.shadowBlur = 0; + if (progress < 1) { + requestAnimationFrame(animate); + } + }; - this.ctx.strokeStyle = "{{ CREWAI_ORANGE }}"; - this.ctx.lineWidth = nodeStyle.borderWidth * scale; - this.ctx.stroke(); - - this.ctx.fillStyle = nodeStyle.fontColor; - this.ctx.font = `500 ${15 * scale}px Inter, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif`; - this.ctx.textAlign = "center"; - this.ctx.textBaseline = "middle"; - this.ctx.fillText(nodeStyle.name, canvasPos.x, canvasPos.y); - - this.ctx.restore(); - }); + animate(); } clear() { @@ -890,7 +1322,7 @@ class TriggeredByHighlighter { this.highlightedNodes = []; this.highlightedEdges = []; - this.canvas.style.transition = "opacity 300ms ease-out"; + this.canvas.style.transition = `opacity ${CONSTANTS.ANIMATION.DURATION}ms ease-out`; this.canvas.style.opacity = "0"; setTimeout(() => { this.canvas.classList.remove("visible"); @@ -898,21 +1330,18 @@ class TriggeredByHighlighter { this.canvas.style.transition = ""; this.ctx.clearRect(0, 0, this.canvas.width, this.canvas.height); this.network.redraw(); - }, 300); + }, CONSTANTS.ANIMATION.DURATION); } } -// ============================================================================ -// Drawer Manager -// ============================================================================ - class DrawerManager { - constructor(network, nodes, edges, animationManager, triggeredByHighlighter) { + constructor(network, nodes, edges, animationManager, triggeredByHighlighter, networkManager) { this.network = network; this.nodes = nodes; this.edges = edges; this.animationManager = animationManager; this.triggeredByHighlighter = triggeredByHighlighter; + this.networkManager = networkManager; this.elements = { drawer: document.getElementById("drawer"), @@ -922,6 +1351,7 @@ class DrawerManager { openIdeButton: document.getElementById("drawer-open-ide"), closeButton: document.getElementById("drawer-close"), navControls: document.querySelector(".nav-controls"), + legendPanel: document.getElementById("legend-panel"), }; this.activeNodeId = null; @@ -979,9 +1409,7 @@ class DrawerManager { document.body.removeChild(link); const fallbackText = `${filePath}:${lineNum}`; - navigator.clipboard.writeText(fallbackText).catch((err) => { - console.error("Failed to copy:", err); - }); + navigator.clipboard.writeText(fallbackText).catch(() => {}); } detectIDE() { @@ -1002,21 +1430,109 @@ class DrawerManager { this.elements.content.innerHTML = content; this.attachContentEventListeners(nodeName); + + // Initialize Lucide icons in the newly rendered drawer content + if (typeof lucide !== 'undefined') { + lucide.createIcons(); + } } renderTriggerCondition(metadata) { if (metadata.trigger_condition) { return this.renderConditionTree(metadata.trigger_condition); } else if (metadata.trigger_methods) { + const uniqueTriggers = [...new Set(metadata.trigger_methods)]; + const grouped = this.groupByIdenticalAction(uniqueTriggers); + return ` `; } return ""; } + groupByIdenticalAction(triggerIds) { + const nodeData = '{{ nodeData }}'; + const allEdges = this.edges.get(); + const activeNodeId = this.activeNodeId; + + const triggerPaths = new Map(); + + triggerIds.forEach(triggerId => { + const pathSignature = this.getPathSignature(triggerId, activeNodeId, allEdges, nodeData); + if (!triggerPaths.has(pathSignature)) { + triggerPaths.set(pathSignature, []); + } + triggerPaths.get(pathSignature).push(triggerId); + }); + + return Array.from(triggerPaths.values()).map(items => ({ items })); + } + + getPathSignature(triggerNodeId, activeNodeId, allEdges, nodeData) { + const connectingEdges = []; + const direct = allEdges.filter( + (edge) => edge.from === triggerNodeId && edge.to === activeNodeId + ); + if (direct.length > 0) { + return direct.map(e => e.id).sort().join(','); + } + + const activeMetadata = nodeData[activeNodeId]; + if (activeMetadata && activeMetadata.trigger_methods && activeMetadata.trigger_methods.includes(triggerNodeId)) { + for (const [nodeName, nodeInfo] of Object.entries(nodeData)) { + if (nodeInfo.router_paths && nodeInfo.router_paths.includes(triggerNodeId)) { + const routerEdges = allEdges.filter( + (edge) => edge.from === nodeName && edge.dashes + ); + + const matchingEdge = routerEdges.find(edge => edge.label === triggerNodeId); + if (matchingEdge) { + if (matchingEdge.to === activeNodeId) { + return matchingEdge.id; + } + + const pathToActive = allEdges.filter( + (edge) => edge.from === matchingEdge.to && edge.to === activeNodeId + ); + + if (pathToActive.length > 0) { + return [matchingEdge.id, ...pathToActive.map(e => e.id)].sort().join(','); + } + } + + for (const routerEdge of routerEdges) { + if (routerEdge.to === activeNodeId) { + return routerEdge.id; + } + } + } + } + } + + return triggerNodeId; + } + renderConditionTree(condition, depth = 0) { if (typeof condition === "string") { return `${condition}`; @@ -1031,12 +1547,43 @@ class DrawerManager { const triggerIds = this.extractTriggerIds(condition); const triggerIdsJson = JSON.stringify(triggerIds).replace(/"/g, '"'); - const children = condition.conditions.map(sub => this.renderConditionTree(sub, depth + 1)).join(""); + const stringChildren = condition.conditions.filter(c => typeof c === 'string'); + const nonStringChildren = condition.conditions.filter(c => typeof c !== 'string'); + + let children = ""; + + if (nonStringChildren.length > 0) { + children += nonStringChildren.map(sub => this.renderConditionTree(sub, depth + 1)).join(""); + } + + if (stringChildren.length > 0) { + const grouped = this.groupByIdenticalAction(stringChildren); + children += grouped.map((group) => { + if (group.items.length === 1) { + return this.renderConditionTree(group.items[0], depth + 1); + } else { + const groupId = group.items.join(','); + const groupColor = conditionType === "AND" ? "{{ CREWAI_ORANGE }}" : "var(--text-secondary)"; + const groupBgColor = conditionType === "AND" ? "rgba(255,90,80,0.08)" : "rgba(102,102,102,0.06)"; + const groupHoverBg = conditionType === "AND" ? "rgba(255,90,80,0.15)" : "rgba(102,102,102,0.12)"; + return ` +
+
+ ${group.items.length} routes +
+
+ ${group.items.map((t) => `${t}`).join("")} +
+
+ `; + } + }).join(""); + } return `
-
- ${conditionType} +
+ ${conditionType}
${children} @@ -1065,6 +1612,7 @@ class DrawerManager { } renderMetadata(metadata) { + console.log('renderMetadata called with:', metadata); let metadataContent = ""; const nodeType = metadata.type || "unknown"; @@ -1097,7 +1645,13 @@ class DrawerManager { `; } + console.log('Checking trigger data:', { + has_trigger_condition: !!metadata.trigger_condition, + has_trigger_methods: !!(metadata.trigger_methods && metadata.trigger_methods.length > 0) + }); + if (metadata.trigger_condition || (metadata.trigger_methods && metadata.trigger_methods.length > 0)) { + console.log('Rendering Triggered By section'); metadataContent += `
Triggered By
@@ -1107,14 +1661,15 @@ class DrawerManager { } if (metadata.router_paths && metadata.router_paths.length > 0) { - const routerPathsJson = JSON.stringify(metadata.router_paths).replace(/"/g, '"'); + const uniqueRouterPaths = [...new Set(metadata.router_paths)]; + const routerPathsJson = JSON.stringify(uniqueRouterPaths).replace(/"/g, '"'); metadataContent += `
- Router Paths + Router Paths
    - ${metadata.router_paths.map((p) => `
  • ${p}
  • `).join("")} + ${uniqueRouterPaths.map((p) => `
  • ${p}
  • `).join("")}
`; @@ -1243,6 +1798,19 @@ class DrawerManager { }); }); + const triggerGroups = this.elements.content.querySelectorAll( + ".trigger-group[data-trigger-items]", + ); + triggerGroups.forEach((group) => { + group.addEventListener("click", (e) => { + e.preventDefault(); + e.stopPropagation(); + e.stopImmediatePropagation(); + const triggerItems = group.getAttribute("data-trigger-items").split(','); + this.triggeredByHighlighter.highlightTriggeredByGroup(triggerItems); + }); + }); + const conditionGroups = this.elements.content.querySelectorAll( ".condition-group[data-trigger-group]", ); @@ -1271,39 +1839,55 @@ class DrawerManager { this.elements.drawer.style.visibility = "visible"; const wasAlreadyOpen = this.elements.drawer.classList.contains("open"); - requestAnimationFrame(() => { + if (!wasAlreadyOpen) { + // Save current position and scale before opening drawer + const currentPosition = this.networkManager.network.getViewPosition(); + const currentScale = this.networkManager.network.getScale(); + this.networkManager.positionBeforeDrawer = { + position: currentPosition, + scale: currentScale + }; + + const targetPosition = this.networkManager.calculateNetworkPosition(true); this.elements.drawer.classList.add("open"); this.elements.overlay.classList.add("visible"); this.elements.navControls.classList.add("drawer-open"); - - if (!wasAlreadyOpen) { - setTimeout(() => { - const currentScale = this.network.getScale(); - const currentPosition = this.network.getViewPosition(); - const offsetX = - (CONSTANTS.DRAWER.WIDTH * CONSTANTS.DRAWER.OFFSET_SCALE) / - currentScale; - - this.network.moveTo({ - position: { - x: currentPosition.x + offsetX, - y: currentPosition.y, - }, - scale: currentScale, - animation: { - duration: 300, - easingFunction: "easeInOutQuad", - }, - }); - }, 50); - } - }); + this.elements.legendPanel.classList.add("drawer-open"); + this.networkManager.animateToPosition(targetPosition); + } else { + this.elements.drawer.classList.add("open"); + this.elements.overlay.classList.add("visible"); + this.elements.navControls.classList.add("drawer-open"); + this.elements.legendPanel.classList.add("drawer-open"); + } } close() { + // Animate accordions closed before removing classes + const accordions = this.elements.drawer.querySelectorAll(".accordion-section.expanded"); + accordions.forEach(accordion => { + const content = accordion.querySelector(".accordion-content"); + if (content) { + // Set explicit height for smooth animation + content.style.height = content.scrollHeight + "px"; + // Force reflow + content.offsetHeight; + // Trigger collapse animation + content.style.height = "0px"; + } + // Remove expanded class after animation + setTimeout(() => { + accordion.classList.remove("expanded"); + if (content) { + content.style.height = ""; + } + }, CONSTANTS.ANIMATION.DURATION); + }); + this.elements.drawer.classList.remove("open"); this.elements.overlay.classList.remove("visible"); this.elements.navControls.classList.remove("drawer-open"); + this.elements.legendPanel.classList.remove("drawer-open"); if (this.activeNodeId) { this.activeEdges.forEach((edgeId) => { @@ -1311,7 +1895,7 @@ class DrawerManager { this.edges, edgeId, CONSTANTS.EDGE.DEFAULT_WIDTH, - 200, + CONSTANTS.ANIMATION.DURATION, ); }); this.activeNodeId = null; @@ -1319,21 +1903,21 @@ class DrawerManager { } this.triggeredByHighlighter.clear(); + this.elements.drawer.offsetHeight; - setTimeout(() => { - this.network.fit({ - animation: { - duration: 300, - easingFunction: "easeInOutQuad", - }, - }); - }, 50); + // Restore the position before the drawer was opened + if (this.networkManager.positionBeforeDrawer) { + this.networkManager.animateToPosition(this.networkManager.positionBeforeDrawer); + this.networkManager.positionBeforeDrawer = null; + } else { + this.networkManager.fitToAvailableSpace(); + } setTimeout(() => { if (!this.elements.drawer.classList.contains("open")) { this.elements.drawer.style.visibility = "hidden"; } - }, 300); + }, CONSTANTS.ANIMATION.DURATION); } setActiveNode(nodeId, connectedEdges) { @@ -1359,6 +1943,7 @@ class NetworkManager { this.pressedNodeId = null; this.pressedEdges = []; this.isClicking = false; + this.positionBeforeDrawer = null; } async initialize() { @@ -1390,6 +1975,7 @@ class NetworkManager { this.edges, this.animationManager, this.triggeredByHighlighter, + this, ); this.setupEventListeners(); @@ -1397,7 +1983,7 @@ class NetworkManager { this.setupTheme(); this.network.once("stabilizationIterationsDone", () => { - this.network.fit(); + this.fitToAvailableSpace(true); }); } catch (error) { console.error("Failed to initialize network:", error); @@ -1405,7 +1991,11 @@ class NetworkManager { } createNetworkOptions() { - const nodeRenderer = new NodeRenderer(this.nodes, this); + this.nodeRenderer = new NodeRenderer(this.nodes, this); + const nodesArray = this.nodes.get(); + const hasExplicitPositions = nodesArray.some(node => + node.x !== undefined && node.y !== undefined + ); return { nodes: { @@ -1413,7 +2003,7 @@ class NetworkManager { shadow: false, chosen: false, size: 30, - ctxRenderer: (params) => nodeRenderer.render(params), + ctxRenderer: (params) => this.nodeRenderer.render(params), scaling: { min: 1, max: 100, @@ -1425,8 +2015,10 @@ class NetworkManager { labelHighlightBold: false, shadow: false, smooth: { + enabled: true, type: "cubicBezier", - roundness: 0.5, + roundness: 0.35, + forceDirection: 'vertical', }, font: { size: 13, @@ -1451,7 +2043,7 @@ class NetworkManager { }, }, physics: { - enabled: true, + enabled: false, hierarchicalRepulsion: { nodeDistance: CONSTANTS.NETWORK.NODE_DISTANCE, centralGravity: 0.0, @@ -1461,19 +2053,22 @@ class NetworkManager { }, solver: "hierarchicalRepulsion", stabilization: { - enabled: true, + enabled: false, iterations: CONSTANTS.NETWORK.STABILIZATION_ITERATIONS, updateInterval: 25, }, }, layout: { hierarchical: { - enabled: true, + enabled: !hasExplicitPositions, direction: "UD", sortMethod: "directed", levelSeparation: CONSTANTS.NETWORK.LEVEL_SEPARATION, nodeSpacing: CONSTANTS.NETWORK.NODE_SPACING, treeSpacing: CONSTANTS.NETWORK.TREE_SPACING, + edgeMinimization: false, + blockShifting: true, + parentCentralization: true, }, }, interaction: { @@ -1550,11 +2145,6 @@ class NetworkManager { } }); - this.network.on("afterDrawing", () => { - if (this.triggeredByHighlighter.canvas.classList.contains("visible")) { - this.triggeredByHighlighter.drawHighlightLayer(); - } - }); } handleNodeClick(nodeId) { @@ -1563,23 +2153,36 @@ class NetworkManager { const metadata = nodeData[nodeId]; this.isClicking = true; - - if ( - this.drawerManager.activeNodeId && - this.drawerManager.activeNodeId !== nodeId - ) { - this.drawerManager.activeEdges.forEach((edgeId) => { - this.animationManager.animateEdgeWidth( - this.edges, - edgeId, - CONSTANTS.EDGE.DEFAULT_WIDTH, - 200, - ); - }); - this.triggeredByHighlighter.clear(); + if (this.drawerManager && this.drawerManager.activeNodeId === nodeId) { + this.drawerManager.close(); + return; } const connectedEdges = this.network.getConnectedEdges(nodeId); + + const allEdges = this.edges.get(); + const connectedNodeIds = new Set([nodeId]); + + connectedEdges.forEach((edgeId) => { + const edge = allEdges.find(e => e.id === edgeId); + if (edge) { + if (edge.from === nodeId) { + connectedNodeIds.add(edge.to); + } else if (edge.to === nodeId) { + connectedNodeIds.add(edge.from); + } + } + }); + const allNodes = this.nodes.get(); + allNodes.forEach((n) => { + this.nodes.update({ id: n.id, opacity: 1.0 }); + }); + this.triggeredByHighlighter.highlightedNodes = []; + this.triggeredByHighlighter.highlightedEdges = []; + + // Animate all edges back to default, excluding the ones we'll highlight + this.triggeredByHighlighter.resetEdgesToDefault(null, connectedEdges); + this.drawerManager.setActiveNode(nodeId, connectedEdges); this.triggeredByHighlighter.setActiveDrawer(nodeId, connectedEdges); @@ -1590,6 +2193,7 @@ class NetworkManager { }, 15); this.drawerManager.open(nodeId, metadata); + this.network.redraw(); } setupControls() { @@ -1617,9 +2221,7 @@ class NetworkManager { }); document.getElementById("fit").addEventListener("click", () => { - this.network.fit({ - animation: { duration: 300, easingFunction: "easeInOutQuad" }, - }); + this.fitToAvailableSpace(); }); this.setupExportControls(); @@ -1634,7 +2236,7 @@ class NetworkManager { html2canvas(document.getElementById("network-container")).then( (canvas) => { const link = document.createElement("a"); - link.download = "flow_dag.png"; + link.download = "flow.png"; link.href = canvas.toDataURL(); link.click(); }, @@ -1663,7 +2265,7 @@ class NetworkManager { format: [canvas.width, canvas.height], }); pdf.addImage(imgData, "PNG", 0, 0, canvas.width, canvas.height); - pdf.save("flow_dag.pdf"); + pdf.save("flow.pdf"); }, ); }; @@ -1672,17 +2274,90 @@ class NetworkManager { document.head.appendChild(script1); }); - document.getElementById("export-json").addEventListener("click", () => { - const dagData = '{{ dagData }}'; - const dataStr = JSON.stringify(dagData, null, 2); - const blob = new Blob([dataStr], { type: "application/json" }); - const url = URL.createObjectURL(blob); - const link = document.createElement("a"); - link.download = "flow_dag.json"; - link.href = url; - link.click(); - URL.revokeObjectURL(url); + // document.getElementById("export-json").addEventListener("click", () => { + // const dagData = '{{ dagData }}'; + // const dataStr = JSON.stringify(dagData, null, 2); + // const blob = new Blob([dataStr], { type: "application/json" }); + // const url = URL.createObjectURL(blob); + // const link = document.createElement("a"); + // link.download = "flow_dag.json"; + // link.href = url; + // link.click(); + // URL.revokeObjectURL(url); + // }); + } + + calculateNetworkPosition(isDrawerOpen, centerScreen = false) { + const infoBox = document.getElementById("info"); + const infoRect = infoBox.getBoundingClientRect(); + const leftEdge = infoRect.right + 40; // 40px padding after legend + const rightEdge = isDrawerOpen ? window.innerWidth - CONSTANTS.DRAWER.WIDTH - 40 : window.innerWidth - 40; + const availableWidth = rightEdge - leftEdge; + + // Use true screen center for initial position, otherwise use available space center + const canvas = this.network ? this.network.canvas.frame.canvas : document.getElementById("network"); + const canvasRect = canvas.getBoundingClientRect(); + const domCenterX = centerScreen ? canvasRect.left + canvas.clientWidth / 2 : leftEdge + (availableWidth / 2); + + const nodePositions = this.network.getPositions(); + const nodeIds = Object.keys(nodePositions); + + if (nodeIds.length === 0) return null; + const canvasWidth = canvas.clientWidth; + const canvasHeight = canvas.clientHeight; + + const padding = 30; + const maxNodeWidth = 200; + const maxNodeHeight = 60; + + let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity; + nodeIds.forEach(id => { + const pos = nodePositions[id]; + minX = Math.min(minX, pos.x - maxNodeWidth / 2); + maxX = Math.max(maxX, pos.x + maxNodeWidth / 2); + minY = Math.min(minY, pos.y - maxNodeHeight / 2); + maxY = Math.max(maxY, pos.y + maxNodeHeight / 2); }); + + const networkWidth = maxX - minX; + const networkHeight = maxY - minY; + const networkCenterX = (minX + maxX) / 2; + const networkCenterY = (minY + maxY) / 2; + const scaleX = availableWidth / (networkWidth + padding * 2); + const scaleY = (canvasHeight - padding * 2) / (networkHeight + padding * 2); + const scale = Math.min(scaleX, scaleY); + const targetDOMX = domCenterX; + const canvasCenterDOMX = canvasRect.left + canvasWidth / 2; + const domShift = targetDOMX - canvasCenterDOMX; + const networkShift = domShift / scale; + + return { + position: { + x: networkCenterX - networkShift, + y: networkCenterY, + }, + scale: scale, + }; + } + + animateToPosition(targetPosition) { + if (!targetPosition) return; + + this.network.moveTo({ + position: targetPosition.position, + scale: targetPosition.scale, + animation: { + duration: 300, + easingFunction: "easeInOutCubic" + }, + }); + } + + fitToAvailableSpace(centerScreen = false) { + const drawer = document.getElementById("drawer"); + const isDrawerOpen = drawer.classList.contains("open"); + const targetPosition = this.calculateNetworkPosition(isDrawerOpen, centerScreen); + this.animateToPosition(targetPosition); } setupTheme() { @@ -1718,33 +2393,55 @@ class NetworkManager { this.network.redraw(); }; - const savedTheme = localStorage.getItem("theme") || "light"; - if (savedTheme === "dark") { - htmlElement.setAttribute("data-theme", "dark"); - themeToggle.textContent = "☀️"; - themeToggle.title = "Toggle Light Mode"; - setTimeout(updateEdgeColors, 0); - } else { - setTimeout(updateEdgeColors, 0); - } + const updateThemeIcon = (isDark) => { + const iconName = isDark ? 'sun' : 'moon'; + themeToggle.title = isDark ? "Toggle Light Mode" : "Toggle Dark Mode"; + // Replace the icon HTML entirely and reinitialize Lucide + themeToggle.innerHTML = ``; + + // Reinitialize Lucide icons for the specific button + if (typeof lucide !== 'undefined') { + lucide.createIcons({ + elements: themeToggle.querySelectorAll('[data-lucide]') + }); + } + }; + + // Set up click handler FIRST before any icon updates themeToggle.addEventListener("click", () => { const currentTheme = htmlElement.getAttribute("data-theme"); const newTheme = currentTheme === "dark" ? "light" : "dark"; if (newTheme === "dark") { htmlElement.setAttribute("data-theme", "dark"); - themeToggle.textContent = "☀️"; - themeToggle.title = "Toggle Light Mode"; + updateThemeIcon(true); } else { htmlElement.removeAttribute("data-theme"); - themeToggle.textContent = "🌙"; - themeToggle.title = "Toggle Dark Mode"; + updateThemeIcon(false); } localStorage.setItem("theme", newTheme); + + // Clear color cache to ensure theme-dependent colors are recalculated + if (this.nodeRenderer) { + this.nodeRenderer.colorCache.clear(); + } + + // Update edge colors and redraw network with new theme setTimeout(updateEdgeColors, 50); }); + + // Initialize theme after click handler is set up + const savedTheme = localStorage.getItem("theme") || "dark"; + if (savedTheme === "dark") { + htmlElement.setAttribute("data-theme", "dark"); + updateThemeIcon(true); + setTimeout(updateEdgeColors, 0); + } else { + updateThemeIcon(false); + setTimeout(updateEdgeColors, 0); + } } } @@ -1753,6 +2450,16 @@ class NetworkManager { // ============================================================================ (async () => { + // Initialize Lucide icons first (before theme setup) + if (typeof lucide !== 'undefined') { + lucide.createIcons(); + } + const networkManager = new NetworkManager(); await networkManager.initialize(); + + // Re-initialize Lucide icons after theme is set up + if (typeof lucide !== 'undefined') { + lucide.createIcons(); + } })(); diff --git a/lib/crewai/src/crewai/flow/visualization/assets/interactive_flow.html.j2 b/lib/crewai/src/crewai/flow/visualization/assets/interactive_flow.html.j2 index 2f374f4bc..876286e67 100644 --- a/lib/crewai/src/crewai/flow/visualization/assets/interactive_flow.html.j2 +++ b/lib/crewai/src/crewai/flow/visualization/assets/interactive_flow.html.j2 @@ -6,6 +6,7 @@ + @@ -23,93 +24,129 @@
Node Details
- +
-
+
CrewAI Logo -
-

Flow Execution

-
-

Nodes: '{{ dag_nodes_count }}'

-

Edges: '{{ dag_edges_count }}'

-

Topological Paths: '{{ execution_paths }}'

-
-
-
Node Types
-
-
- Start Methods -
-
-
- Router Methods -
-
-
- Listen Methods -
-
-
-
Edge Types
-
- - - - Router Paths -
-
- - - - OR Conditions -
-
- - - - AND Conditions -
-
-
- Interactions:
- • Drag to pan
- • Scroll to zoom

- IDE: - + style="width: 144px; height: auto;">
+
+ + +
+ +
+
+
+ '{{ dag_nodes_count }}' + Nodes +
+
+ '{{ dag_edges_count }}' + Edges +
+
+ '{{ execution_paths }}' + Paths +
+
+
+ + +
+
+
+
+ Start +
+
+
+ Router +
+
+
+ Listen +
+
+
+ + +
+
+
+ + + + Router +
+
+ + + + OR +
+
+ + + + AND +
+
+
+ + +
+
+ + +
+
+
diff --git a/lib/crewai/src/crewai/flow/visualization/assets/style.css b/lib/crewai/src/crewai/flow/visualization/assets/style.css index 724ec5cbb..7566d822e 100644 --- a/lib/crewai/src/crewai/flow/visualization/assets/style.css +++ b/lib/crewai/src/crewai/flow/visualization/assets/style.css @@ -13,6 +13,14 @@ --edge-label-text: '{{ GRAY }}'; --edge-label-bg: rgba(255, 255, 255, 0.8); --edge-or-color: #000000; + --edge-router-color: '{{ CREWAI_ORANGE }}'; + --node-border-start: #C94238; + --node-border-listen: #3D3D3D; + --node-bg-start: #FF7066; + --node-bg-router: #FFFFFF; + --node-bg-listen: #FFFFFF; + --node-text-color: #FFFFFF; + --nav-button-hover: #f5f5f5; } [data-theme="dark"] { @@ -30,6 +38,14 @@ --edge-label-text: #c9d1d9; --edge-label-bg: rgba(22, 27, 34, 0.9); --edge-or-color: #ffffff; + --edge-router-color: '{{ CREWAI_ORANGE }}'; + --node-border-start: #FF5A50; + --node-border-listen: #666666; + --node-bg-start: #B33830; + --node-bg-router: #3D3D3D; + --node-bg-listen: #3D3D3D; + --node-text-color: #FFFFFF; + --nav-button-hover: #30363d; } @keyframes dash { @@ -72,12 +88,10 @@ body { position: absolute; top: 20px; left: 20px; - background: var(--bg-secondary); + background: transparent; padding: 20px; border-radius: 8px; - box-shadow: 0 4px 12px var(--shadow-strong); max-width: 320px; - border: 1px solid var(--border-color); z-index: 10000; pointer-events: auto; transition: background 0.3s ease, border-color 0.3s ease, box-shadow 0.3s ease; @@ -125,12 +139,16 @@ h3 { margin-right: 12px; border-radius: 3px; box-sizing: border-box; + transition: background 0.3s ease, border-color 0.3s ease; } .legend-item span { color: var(--text-secondary); font-size: 13px; transition: color 0.3s ease; } +.legend-item svg line { + transition: stroke 0.3s ease; +} .instructions { margin-top: 15px; padding-top: 15px; @@ -155,7 +173,7 @@ h3 { bottom: 20px; right: auto; display: grid; - grid-template-columns: repeat(4, 40px); + grid-template-columns: repeat(3, 40px); gap: 8px; z-index: 10002; pointer-events: auto; @@ -165,10 +183,187 @@ h3 { .nav-controls.drawer-open { } +#legend-panel { + position: fixed; + left: 164px; + bottom: 20px; + right: 20px; + height: 92px; + background: var(--bg-secondary); + backdrop-filter: blur(12px) saturate(180%); + -webkit-backdrop-filter: blur(12px) saturate(180%); + border: 1px solid var(--border-subtle); + border-radius: 6px; + box-shadow: 0 2px 8px var(--shadow-color); + display: grid; + grid-template-columns: repeat(4, 1fr); + align-items: center; + gap: 0; + padding: 0 24px; + box-sizing: border-box; + z-index: 10001; + pointer-events: auto; + transition: background 0.3s ease, border-color 0.3s ease, box-shadow 0.3s ease, right 0.3s cubic-bezier(0.4, 0, 0.2, 1); +} + +#legend-panel.drawer-open { + right: 405px; +} + +.legend-section { + display: flex; + align-items: center; + justify-content: center; + min-width: 0; + width: -webkit-fill-available; + width: -moz-available; + width: stretch; + position: relative; +} + +.legend-section:not(:last-child)::after { + content: ''; + position: absolute; + right: 0; + top: 50%; + transform: translateY(-50%); + width: 1px; + height: 48px; + background: var(--border-color); + transition: background 0.3s ease; +} + +.legend-stats-row { + display: flex; + gap: 32px; + justify-content: center; + align-items: center; + min-width: 0; +} + +.legend-stat-item { + display: flex; + flex-direction: column; + align-items: center; + gap: 4px; +} + +.stat-value { + font-size: 19px; + font-weight: 700; + color: var(--text-primary); + line-height: 1; + transition: color 0.3s ease; +} + +.stat-label { + font-size: 8px; + font-weight: 500; + text-transform: uppercase; + color: var(--text-secondary); + letter-spacing: 0.5px; + transition: color 0.3s ease; +} + +.legend-items-row { + display: flex; + gap: 16px; + align-items: center; + justify-content: center; + min-width: 0; +} + +.legend-group { + display: flex; + gap: 16px; + align-items: center; +} + +.legend-item-compact { + display: flex; + align-items: center; + gap: 6px; +} + +.legend-item-compact span { + font-size: 12px; + font-weight: 500; + text-transform: uppercase; + color: var(--text-secondary); + letter-spacing: 0.5px; + white-space: nowrap; + font-family: inherit; + line-height: 1; + transition: color 0.3s ease; +} + +.legend-color-small { + width: 17px; + height: 17px; + border-radius: 2px; + box-sizing: border-box; + flex-shrink: 0; + transition: background 0.3s ease, border-color 0.3s ease; +} + +.legend-item-compact svg { + display: block; + flex-shrink: 0; + width: 29px; + height: 14px; +} + +.legend-item-compact svg line { + transition: stroke 0.3s ease; +} + +.legend-ide-column { + display: flex; + flex-direction: row; + gap: 8px; + align-items: center; + justify-content: center; + min-width: 0; + width: 100%; +} + +.legend-ide-label { + font-size: 12px; + font-weight: 500; + text-transform: uppercase; + color: var(--text-secondary); + letter-spacing: 0.5px; + transition: color 0.3s ease; + white-space: nowrap; +} + +.legend-ide-select { + width: 120px; + padding: 6px 10px; + border-radius: 4px; + border: 1px solid var(--border-subtle); + background: var(--bg-primary); + color: var(--text-primary); + font-size: 11px; + cursor: pointer; + transition: all 0.3s ease; +} + +.legend-ide-select:hover { + border-color: var(--text-secondary); +} + +.legend-ide-select:focus { + outline: none; + border-color: '{{ CREWAI_ORANGE }}'; +} + .nav-button { width: 40px; height: 40px; background: var(--bg-secondary); + backdrop-filter: blur(12px) saturate(180%); + -webkit-backdrop-filter: blur(12px) saturate(180%); border: 1px solid var(--border-subtle); border-radius: 6px; display: flex; @@ -181,12 +376,12 @@ h3 { user-select: none; pointer-events: auto; position: relative; - z-index: 10001; + z-index: 10002; transition: background 0.3s ease, border-color 0.3s ease, color 0.3s ease, box-shadow 0.3s ease; } .nav-button:hover { - background: var(--border-subtle); + background: var(--nav-button-hover); } #drawer { @@ -198,9 +393,10 @@ h3 { background: var(--bg-drawer); box-shadow: -4px 0 12px var(--shadow-strong); transition: right 0.3s cubic-bezier(0.4, 0, 0.2, 1), background 0.3s ease, box-shadow 0.3s ease; - z-index: 2000; - overflow-y: auto; - padding: 24px; + z-index: 10003; + overflow: hidden; + transform: translateZ(0); + isolation: isolate; } #drawer.open { @@ -247,17 +443,22 @@ h3 { justify-content: space-between; align-items: center; margin-bottom: 20px; - padding-bottom: 16px; + padding: 24px 24px 16px 24px; border-bottom: 2px solid '{{ CREWAI_ORANGE }}'; position: relative; z-index: 2001; } .drawer-title { - font-size: 20px; + font-size: 15px; font-weight: 700; color: var(--text-primary); transition: color 0.3s ease; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + flex: 1; + min-width: 0; } .drawer-close { @@ -269,12 +470,19 @@ h3 { padding: 4px 8px; line-height: 1; transition: color 0.3s ease; + display: flex; + align-items: center; + justify-content: center; } .drawer-close:hover { color: '{{ CREWAI_ORANGE }}'; } +.drawer-close i { + display: block; +} + .drawer-open-ide { background: '{{ CREWAI_ORANGE }}'; border: none; @@ -292,6 +500,9 @@ h3 { position: relative; z-index: 9999; pointer-events: auto; + white-space: nowrap; + flex-shrink: 0; + min-width: fit-content; } .drawer-open-ide:hover { @@ -305,14 +516,19 @@ h3 { box-shadow: 0 1px 4px rgba(255, 90, 80, 0.2); } -.drawer-open-ide svg { +.drawer-open-ide svg, +.drawer-open-ide i { width: 14px; height: 14px; + display: block; } .drawer-content { color: '{{ DARK_GRAY }}'; line-height: 1.6; + padding: 0 24px 24px 24px; + overflow-y: auto; + height: calc(100vh - 95px); } .drawer-section { @@ -328,6 +544,10 @@ h3 { position: relative; } +.drawer-metadata-grid:has(.drawer-section:nth-child(3):nth-last-child(1)) { + grid-template-columns: 1fr 2fr; +} + .drawer-metadata-grid::before { content: ''; position: absolute; @@ -419,20 +639,35 @@ h3 { grid-column: 2; display: flex; flex-direction: column; - justify-content: center; + justify-content: flex-start; + align-items: flex-start; } .drawer-metadata-grid:has(.drawer-section:nth-child(3):nth-last-child(1))::after { - right: 50%; + right: 66.666%; +} + +.drawer-metadata-grid:has(.drawer-section:nth-child(3):nth-last-child(1))::before { + left: 33.333%; +} + +.drawer-metadata-grid .drawer-section:nth-child(3):nth-last-child(1) .drawer-section-title { + align-self: flex-start; +} + +.drawer-metadata-grid .drawer-section:nth-child(3):nth-last-child(1) > *:not(.drawer-section-title) { + width: 100%; + align-self: stretch; } .drawer-section-title { font-size: 12px; text-transform: uppercase; - color: '{{ GRAY }}'; + color: var(--text-secondary); letter-spacing: 0.5px; margin-bottom: 8px; font-weight: 600; + transition: color 0.3s ease; } .drawer-badge { @@ -465,9 +700,44 @@ h3 { padding: 3px 0; } +.drawer-metadata-grid .drawer-section .drawer-list { + display: flex; + flex-direction: column; + gap: 6px; +} + +.drawer-metadata-grid .drawer-section .drawer-list li { + border-bottom: none; + padding: 0; +} + .drawer-metadata-grid .drawer-section:nth-child(3) .drawer-list li { border-bottom: none; - padding: 3px 0; + padding: 0; +} + +.drawer-metadata-grid .drawer-section { + overflow: visible; +} + +.drawer-metadata-grid .drawer-section .condition-group, +.drawer-metadata-grid .drawer-section .trigger-group { + width: 100%; + box-sizing: border-box; +} + +.drawer-metadata-grid .drawer-section .condition-children { + width: 100%; +} + +.drawer-metadata-grid .drawer-section .trigger-group-items { + width: 100%; +} + +.drawer-metadata-grid .drawer-section .drawer-code-link { + word-break: break-word; + overflow-wrap: break-word; + max-width: 100%; } .drawer-code { @@ -491,6 +761,7 @@ h3 { cursor: pointer; transition: all 0.2s; display: inline-block; + margin: 3px 2px; } .drawer-code-link:hover { diff --git a/lib/crewai/src/crewai/flow/visualization/builder.py b/lib/crewai/src/crewai/flow/visualization/builder.py index 8a7ffece1..33ec2c114 100644 --- a/lib/crewai/src/crewai/flow/visualization/builder.py +++ b/lib/crewai/src/crewai/flow/visualization/builder.py @@ -3,12 +3,13 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Iterable import inspect from typing import TYPE_CHECKING, Any from crewai.flow.constants import AND_CONDITION, OR_CONDITION from crewai.flow.flow_wrappers import FlowCondition -from crewai.flow.types import FlowMethodName +from crewai.flow.types import FlowMethodName, FlowRouteName from crewai.flow.utils import ( is_flow_condition_dict, is_simple_flow_condition, @@ -197,8 +198,6 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure: node_metadata["type"] = "router" router_methods.append(method_name) - node_metadata["condition_type"] = "IF" - if method_name in flow._router_paths: node_metadata["router_paths"] = [ str(p) for p in flow._router_paths[method_name] @@ -210,9 +209,13 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure: ] if hasattr(method, "__condition_type__") and method.__condition_type__: + node_metadata["trigger_condition_type"] = method.__condition_type__ if "condition_type" not in node_metadata: node_metadata["condition_type"] = method.__condition_type__ + if node_metadata.get("is_router") and "condition_type" not in node_metadata: + node_metadata["condition_type"] = "IF" + if ( hasattr(method, "__trigger_condition__") and method.__trigger_condition__ is not None @@ -298,6 +301,9 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure: nodes[method_name] = node_metadata for listener_name, condition_data in flow._listeners.items(): + if listener_name in router_methods: + continue + if is_simple_flow_condition(condition_data): cond_type, methods = condition_data edges.extend( @@ -315,6 +321,60 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure: _create_edges_from_condition(condition_data, str(listener_name), nodes) ) + for method_name, node_metadata in nodes.items(): # type: ignore[assignment] + if node_metadata.get("is_router") and "trigger_methods" in node_metadata: + trigger_methods = node_metadata["trigger_methods"] + condition_type = node_metadata.get("trigger_condition_type", OR_CONDITION) + + if "trigger_condition" in node_metadata: + edges.extend( + _create_edges_from_condition( + node_metadata["trigger_condition"], # type: ignore[arg-type] + method_name, + nodes, + ) + ) + else: + edges.extend( + StructureEdge( + source=trigger_method, + target=method_name, + condition_type=condition_type, + is_router_path=False, + ) + for trigger_method in trigger_methods + if trigger_method in nodes + ) + + for router_method_name in router_methods: + if router_method_name not in flow._router_paths: + flow._router_paths[FlowMethodName(router_method_name)] = [] + + inferred_paths: Iterable[FlowMethodName | FlowRouteName] = set( + flow._router_paths.get(FlowMethodName(router_method_name), []) + ) + + for condition_data in flow._listeners.values(): + trigger_strings: list[str] = [] + + if is_simple_flow_condition(condition_data): + _, methods = condition_data + trigger_strings = [str(m) for m in methods] + elif is_flow_condition_dict(condition_data): + trigger_strings = _extract_direct_or_triggers(condition_data) + + for trigger_str in trigger_strings: + if trigger_str not in nodes: + # This is likely a router path output + inferred_paths.add(trigger_str) # type: ignore[attr-defined] + + if inferred_paths: + flow._router_paths[FlowMethodName(router_method_name)] = list( + inferred_paths # type: ignore[arg-type] + ) + if router_method_name in nodes: + nodes[router_method_name]["router_paths"] = list(inferred_paths) + for router_method_name in router_methods: if router_method_name not in flow._router_paths: continue @@ -340,6 +400,7 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure: target=str(listener_name), condition_type=None, is_router_path=True, + router_path_label=str(path), ) ) diff --git a/lib/crewai/src/crewai/flow/visualization/renderers/interactive.py b/lib/crewai/src/crewai/flow/visualization/renderers/interactive.py index 6ce0c0fc7..88242bea6 100644 --- a/lib/crewai/src/crewai/flow/visualization/renderers/interactive.py +++ b/lib/crewai/src/crewai/flow/visualization/renderers/interactive.py @@ -20,7 +20,7 @@ class CSSExtension(Extension): Provides {% css 'path/to/file.css' %} tag syntax. """ - tags: ClassVar[set[str]] = {"css"} # type: ignore[assignment] + tags: ClassVar[set[str]] = {"css"} # type: ignore[misc] def parse(self, parser: Parser) -> nodes.Node: """Parse {% css 'styles.css' %} tag. @@ -53,7 +53,7 @@ class JSExtension(Extension): Provides {% js 'path/to/file.js' %} tag syntax. """ - tags: ClassVar[set[str]] = {"js"} # type: ignore[assignment] + tags: ClassVar[set[str]] = {"js"} # type: ignore[misc] def parse(self, parser: Parser) -> nodes.Node: """Parse {% js 'script.js' %} tag. @@ -91,6 +91,116 @@ TEXT_PRIMARY = "#e6edf3" TEXT_SECONDARY = "#7d8590" +def calculate_node_positions( + dag: FlowStructure, +) -> dict[str, dict[str, int | float]]: + """Calculate hierarchical positions (level, x, y) for each node. + + Args: + dag: FlowStructure containing nodes and edges. + + Returns: + Dictionary mapping node names to their position data (level, x, y). + """ + children: dict[str, list[str]] = {name: [] for name in dag["nodes"]} + parents: dict[str, list[str]] = {name: [] for name in dag["nodes"]} + + for edge in dag["edges"]: + source = edge["source"] + target = edge["target"] + if source in children and target in children: + children[source].append(target) + parents[target].append(source) + + levels: dict[str, int] = {} + queue: list[tuple[str, int]] = [] + + for start_method in dag["start_methods"]: + if start_method in dag["nodes"]: + levels[start_method] = 0 + queue.append((start_method, 0)) + + visited: set[str] = set() + while queue: + node, level = queue.pop(0) + if node in visited: + continue + visited.add(node) + + if node not in levels or levels[node] < level: + levels[node] = level + + for child in children.get(node, []): + if child not in visited: + child_level = level + 1 + if child not in levels or levels[child] < child_level: + levels[child] = child_level + queue.append((child, child_level)) + + for name in dag["nodes"]: + if name not in levels: + levels[name] = 0 + + nodes_by_level: dict[int, list[str]] = {} + for node, level in levels.items(): + if level not in nodes_by_level: + nodes_by_level[level] = [] + nodes_by_level[level].append(node) + + positions: dict[str, dict[str, int | float]] = {} + level_separation = 300 # Vertical spacing between levels + node_spacing = 400 # Horizontal spacing between nodes + + parent_count: dict[str, int] = {} + for node, parent_list in parents.items(): + parent_count[node] = len(parent_list) + + for level, nodes_at_level in sorted(nodes_by_level.items()): + y = level * level_separation + + if level == 0: + num_nodes = len(nodes_at_level) + for i, node in enumerate(nodes_at_level): + x = (i - (num_nodes - 1) / 2) * node_spacing + positions[node] = {"level": level, "x": x, "y": y} + else: + for i, node in enumerate(nodes_at_level): + parent_list = parents.get(node, []) + parent_positions: list[float] = [ + positions[parent]["x"] + for parent in parent_list + if parent in positions + ] + + if parent_positions: + if len(parent_positions) > 1 and len(set(parent_positions)) == 1: + base_x = parent_positions[0] + avg_x = base_x + node_spacing * 0.4 + else: + avg_x = sum(parent_positions) / len(parent_positions) + else: + avg_x = i * node_spacing * 0.5 + + positions[node] = {"level": level, "x": avg_x, "y": y} + + nodes_at_level_sorted = sorted( + nodes_at_level, key=lambda n: positions[n]["x"] + ) + min_spacing = node_spacing * 0.6 # Minimum horizontal distance + + for i in range(len(nodes_at_level_sorted) - 1): + current_node = nodes_at_level_sorted[i] + next_node = nodes_at_level_sorted[i + 1] + + current_x = positions[current_node]["x"] + next_x = positions[next_node]["x"] + + if next_x - current_x < min_spacing: + positions[next_node]["x"] = current_x + min_spacing + + return positions + + def render_interactive( dag: FlowStructure, filename: str = "flow_dag.html", @@ -110,6 +220,8 @@ def render_interactive( Returns: Absolute path to generated HTML file in temporary directory. """ + node_positions = calculate_node_positions(dag) + nodes_list: list[dict[str, Any]] = [] for name, metadata in dag["nodes"].items(): node_type: str = metadata.get("type", "listen") @@ -120,37 +232,37 @@ def render_interactive( if node_type == "start": color_config = { - "background": CREWAI_ORANGE, - "border": CREWAI_ORANGE, + "background": "var(--node-bg-start)", + "border": "var(--node-border-start)", "highlight": { - "background": CREWAI_ORANGE, - "border": CREWAI_ORANGE, + "background": "var(--node-bg-start)", + "border": "var(--node-border-start)", }, } - font_color = WHITE - border_width = 2 + font_color = "var(--node-text-color)" + border_width = 3 elif node_type == "router": color_config = { - "background": DARK_GRAY, + "background": "var(--node-bg-router)", "border": CREWAI_ORANGE, "highlight": { - "background": DARK_GRAY, + "background": "var(--node-bg-router)", "border": CREWAI_ORANGE, }, } - font_color = WHITE + font_color = "var(--node-text-color)" border_width = 3 else: color_config = { - "background": DARK_GRAY, - "border": DARK_GRAY, + "background": "var(--node-bg-listen)", + "border": "var(--node-border-listen)", "highlight": { - "background": DARK_GRAY, - "border": DARK_GRAY, + "background": "var(--node-bg-listen)", + "border": "var(--node-border-listen)", }, } - font_color = WHITE - border_width = 2 + font_color = "var(--node-text-color)" + border_width = 3 title_parts: list[str] = [] @@ -215,25 +327,34 @@ def render_interactive( bg_color = color_config["background"] border_color = color_config["border"] - nodes_list.append( - { - "id": name, - "label": name, - "title": "".join(title_parts), - "shape": "custom", - "size": 30, - "nodeStyle": { - "name": name, - "bgColor": bg_color, - "borderColor": border_color, - "borderWidth": border_width, - "fontColor": font_color, - }, - "opacity": 1.0, - "glowSize": 0, - "glowColor": None, - } - ) + position_data = node_positions.get(name, {"level": 0, "x": 0, "y": 0}) + + node_data: dict[str, Any] = { + "id": name, + "label": name, + "title": "".join(title_parts), + "shape": "custom", + "size": 30, + "level": position_data["level"], + "nodeStyle": { + "name": name, + "bgColor": bg_color, + "borderColor": border_color, + "borderWidth": border_width, + "fontColor": font_color, + }, + "opacity": 1.0, + "glowSize": 0, + "glowColor": None, + } + + # Add x,y only for graphs with 3-4 nodes + total_nodes = len(dag["nodes"]) + if 3 <= total_nodes <= 4: + node_data["x"] = position_data["x"] + node_data["y"] = position_data["y"] + + nodes_list.append(node_data) execution_paths: int = calculate_execution_paths(dag) @@ -246,6 +367,8 @@ def render_interactive( if edge["is_router_path"]: edge_color = CREWAI_ORANGE edge_dashes = [15, 10] + if "router_path_label" in edge: + edge_label = edge["router_path_label"] elif edge["condition_type"] == "AND": edge_label = "AND" edge_color = CREWAI_ORANGE diff --git a/lib/crewai/src/crewai/flow/visualization/types.py b/lib/crewai/src/crewai/flow/visualization/types.py index 6cb165bc4..6ce57069e 100644 --- a/lib/crewai/src/crewai/flow/visualization/types.py +++ b/lib/crewai/src/crewai/flow/visualization/types.py @@ -10,6 +10,7 @@ class NodeMetadata(TypedDict, total=False): is_router: bool router_paths: list[str] condition_type: str | None + trigger_condition_type: str | None trigger_methods: list[str] trigger_condition: dict[str, Any] | None method_signature: dict[str, Any] @@ -22,13 +23,14 @@ class NodeMetadata(TypedDict, total=False): class_line_number: int -class StructureEdge(TypedDict): +class StructureEdge(TypedDict, total=False): """Represents a connection in the flow structure.""" source: str target: str condition_type: str | None is_router_path: bool + router_path_label: str class FlowStructure(TypedDict): From b2c278ed22c5dbf7c1f3e5a0b2eb891544b0670a Mon Sep 17 00:00:00 2001 From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Date: Thu, 6 Nov 2025 19:28:08 -0800 Subject: [PATCH 4/4] refactor: improve MCP tool execution handling with concurrent futures (#3854) - Enhanced the MCP tool execution in both synchronous and asynchronous contexts by utilizing for better event loop management. - Updated error handling to provide clearer messages for connection issues and task cancellations. - Added tests to validate MCP tool execution in both sync and async scenarios, ensuring robust functionality across different contexts. --- lib/crewai/src/crewai/agent/core.py | 32 +++++---- .../src/crewai/tools/mcp_native_tool.py | 14 +++- lib/crewai/tests/mcp/test_mcp_config.py | 66 ++++++++++++++++++- 3 files changed, 97 insertions(+), 15 deletions(-) diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 54ae52ba8..1d94c4d19 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -860,19 +860,29 @@ class Agent(BaseAgent): try: try: - tools_list = asyncio.run(_setup_client_and_list_tools()) - except RuntimeError as e: - error_msg = str(e).lower() - if "cancel scope" in error_msg or "task" in error_msg: + asyncio.get_running_loop() + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit( + asyncio.run, _setup_client_and_list_tools() + ) + tools_list = future.result() + except RuntimeError: + try: + tools_list = asyncio.run(_setup_client_and_list_tools()) + except RuntimeError as e: + error_msg = str(e).lower() + if "cancel scope" in error_msg or "task" in error_msg: + raise ConnectionError( + "MCP connection failed due to event loop cleanup issues. " + "This may be due to authentication errors or server unavailability." + ) from e + except asyncio.CancelledError as e: raise ConnectionError( - "MCP connection failed due to event loop cleanup issues. " - "This may be due to authentication errors or server unavailability." + "MCP connection was cancelled. This may indicate an authentication " + "error or server unavailability." ) from e - except asyncio.CancelledError as e: - raise ConnectionError( - "MCP connection was cancelled. This may indicate an authentication " - "error or server unavailability." - ) from e if mcp_config.tool_filter: filtered_tools = [] diff --git a/lib/crewai/src/crewai/tools/mcp_native_tool.py b/lib/crewai/src/crewai/tools/mcp_native_tool.py index c10d51eee..f25b2f4d7 100644 --- a/lib/crewai/src/crewai/tools/mcp_native_tool.py +++ b/lib/crewai/src/crewai/tools/mcp_native_tool.py @@ -86,9 +86,17 @@ class MCPNativeTool(BaseTool): Result from the MCP tool execution. """ try: - # Always use asyncio.run() to create a fresh event loop - # This ensures the async context managers work correctly - return asyncio.run(self._run_async(**kwargs)) + try: + asyncio.get_running_loop() + + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + coro = self._run_async(**kwargs) + future = executor.submit(asyncio.run, coro) + return future.result() + except RuntimeError: + return asyncio.run(self._run_async(**kwargs)) except Exception as e: raise RuntimeError( diff --git a/lib/crewai/tests/mcp/test_mcp_config.py b/lib/crewai/tests/mcp/test_mcp_config.py index 627ceb6e2..e55a7d504 100644 --- a/lib/crewai/tests/mcp/test_mcp_config.py +++ b/lib/crewai/tests/mcp/test_mcp_config.py @@ -1,4 +1,5 @@ -from unittest.mock import AsyncMock, patch +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch import pytest from crewai.agent.core import Agent @@ -134,3 +135,66 @@ def test_agent_with_sse_mcp_config(mock_tool_definitions): transport = call_args.kwargs["transport"] assert transport.url == "https://api.example.com/mcp/sse" assert transport.headers == {"Authorization": "Bearer test_token"} + + +def test_mcp_tool_execution_in_sync_context(mock_tool_definitions): + """Test MCPNativeTool execution in synchronous context (normal crew execution).""" + http_config = MCPServerHTTP(url="https://api.example.com/mcp") + + with patch("crewai.agent.core.MCPClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client.call_tool = AsyncMock(return_value="test result") + mock_client_class.return_value = mock_client + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + tools = agent.get_mcp_tools([http_config]) + assert len(tools) == 2 + + + tool = tools[0] + result = tool.run(query="test query") + + assert result == "test result" + mock_client.call_tool.assert_called() + + +@pytest.mark.asyncio +async def test_mcp_tool_execution_in_async_context(mock_tool_definitions): + """Test MCPNativeTool execution in async context (e.g., from a Flow).""" + http_config = MCPServerHTTP(url="https://api.example.com/mcp") + + with patch("crewai.agent.core.MCPClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions) + mock_client.connected = False + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client.call_tool = AsyncMock(return_value="test result") + mock_client_class.return_value = mock_client + + agent = Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + mcps=[http_config], + ) + + tools = agent.get_mcp_tools([http_config]) + assert len(tools) == 2 + + + tool = tools[0] + result = tool.run(query="test query") + + assert result == "test result" + mock_client.call_tool.assert_called()