Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-08 15:48:29 +00:00)

Compare commits: devin/1762...gl/feat/py (12 commits)

81850350e8, b2c278ed22, f6aed9798b, 40a2d387a1, 6f36d7003b, 7404d8f198, 138b9af274, 9e5906c52f, fc521839e4, a5e0803f20, c4279b0339, 965aa48ea1
@@ -11,9 +11,13 @@ The [Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP)

CrewAI offers **two approaches** for MCP integration:

-### **Simple DSL Integration** (Recommended)
+### 🚀 **Simple DSL Integration** (Recommended)

-Use the `mcps` field directly on agents for seamless MCP tool integration:
+Use the `mcps` field directly on agents for seamless MCP tool integration. The DSL supports both **string references** (for quick setup) and **structured configurations** (for full control).
+
+#### String-Based References (Quick Setup)

Perfect for remote HTTPS servers and CrewAI AMP marketplace:

```python
from crewai import Agent

@@ -32,6 +36,46 @@ agent = Agent(
# MCP tools are now automatically available to your agent!
```

#### Structured Configurations (Full Control)

For complete control over connection settings, tool filtering, and all transport types:

```python
from crewai import Agent
from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE
from crewai.mcp.filters import create_static_tool_filter

agent = Agent(
    role="Advanced Research Analyst",
    goal="Research with full control over MCP connections",
    backstory="Expert researcher with advanced tool access",
    mcps=[
        # Stdio transport for local servers
        MCPServerStdio(
            command="npx",
            args=["-y", "@modelcontextprotocol/server-filesystem"],
            env={"API_KEY": "your_key"},
            tool_filter=create_static_tool_filter(
                allowed_tool_names=["read_file", "list_directory"]
            ),
            cache_tools_list=True,
        ),
        # HTTP/Streamable HTTP transport for remote servers
        MCPServerHTTP(
            url="https://api.example.com/mcp",
            headers={"Authorization": "Bearer your_token"},
            streamable=True,
            cache_tools_list=True,
        ),
        # SSE transport for real-time streaming
        MCPServerSSE(
            url="https://stream.example.com/mcp/sse",
            headers={"Authorization": "Bearer your_token"},
        ),
    ]
)
```

### 🔧 **Advanced: MCPServerAdapter** (For Complex Scenarios)

For advanced use cases requiring manual connection management, the `crewai-tools` library provides the `MCPServerAdapter` class.

@@ -68,12 +112,14 @@ uv pip install 'crewai-tools[mcp]'

## Quick Start: Simple DSL Integration

-The easiest way to integrate MCP servers is using the `mcps` field on your agents:
+The easiest way to integrate MCP servers is using the `mcps` field on your agents. You can use either string references or structured configurations.

### Quick Start with String References

```python
from crewai import Agent, Task, Crew

-# Create agent with MCP tools
+# Create agent with MCP tools using string references
research_agent = Agent(
    role="Research Analyst",
    goal="Find and analyze information using advanced search tools",

@@ -96,13 +142,53 @@ crew = Crew(agents=[research_agent], tasks=[research_task])
result = crew.kickoff()
```

### Quick Start with Structured Configurations

```python
from crewai import Agent, Task, Crew
from crewai.mcp import MCPServerStdio, MCPServerHTTP, MCPServerSSE

# Create agent with structured MCP configurations
research_agent = Agent(
    role="Research Analyst",
    goal="Find and analyze information using advanced search tools",
    backstory="Expert researcher with access to multiple data sources",
    mcps=[
        # Local stdio server
        MCPServerStdio(
            command="python",
            args=["local_server.py"],
            env={"API_KEY": "your_key"},
        ),
        # Remote HTTP server
        MCPServerHTTP(
            url="https://api.research.com/mcp",
            headers={"Authorization": "Bearer your_token"},
        ),
    ]
)

# Create task
research_task = Task(
    description="Research the latest developments in AI agent frameworks",
    expected_output="Comprehensive research report with citations",
    agent=research_agent
)

# Create and run crew
crew = Crew(agents=[research_agent], tasks=[research_task])
result = crew.kickoff()
```

That's it! The MCP tools are automatically discovered and available to your agent.

## MCP Reference Formats

-The `mcps` field supports various reference formats for maximum flexibility:
+The `mcps` field supports both **string references** (for quick setup) and **structured configurations** (for full control). You can mix both formats in the same list.

-### External MCP Servers
+### String-Based References
+
+#### External MCP Servers

```python
mcps=[
@@ -117,7 +203,7 @@ mcps=[
]
```

-### CrewAI AMP Marketplace
+#### CrewAI AMP Marketplace

```python
mcps=[
@@ -133,17 +219,166 @@ mcps=[
]
```

-### Mixed References
+### Structured Configurations
+
+#### Stdio Transport (Local Servers)
+
+Perfect for local MCP servers that run as processes:

```python
from crewai.mcp import MCPServerStdio
from crewai.mcp.filters import create_static_tool_filter

mcps=[
-    "https://external-api.com/mcp",              # External server
-    "https://weather.service.com/mcp#forecast",  # Specific external tool
-    "crewai-amp:financial-insights",             # AMP service
-    "crewai-amp:data-analysis#sentiment_tool"    # Specific AMP tool
+    MCPServerStdio(
+        command="npx",
+        args=["-y", "@modelcontextprotocol/server-filesystem"],
+        env={"API_KEY": "your_key"},
+        tool_filter=create_static_tool_filter(
+            allowed_tool_names=["read_file", "write_file"]
+        ),
+        cache_tools_list=True,
+    ),
+    # Python-based server
+    MCPServerStdio(
+        command="python",
+        args=["path/to/server.py"],
+        env={"UV_PYTHON": "3.12", "API_KEY": "your_key"},
+    ),
]
```

#### HTTP/Streamable HTTP Transport (Remote Servers)

For remote MCP servers over HTTP/HTTPS:

```python
from crewai.mcp import MCPServerHTTP

mcps=[
    # Streamable HTTP (default)
    MCPServerHTTP(
        url="https://api.example.com/mcp",
        headers={"Authorization": "Bearer your_token"},
        streamable=True,
        cache_tools_list=True,
    ),
    # Standard HTTP
    MCPServerHTTP(
        url="https://api.example.com/mcp",
        headers={"Authorization": "Bearer your_token"},
        streamable=False,
    ),
]
```

#### SSE Transport (Real-Time Streaming)

For remote servers using Server-Sent Events:

```python
from crewai.mcp import MCPServerSSE

mcps=[
    MCPServerSSE(
        url="https://stream.example.com/mcp/sse",
        headers={"Authorization": "Bearer your_token"},
        cache_tools_list=True,
    ),
]
```

### Mixed References

You can combine string references and structured configurations:

```python
from crewai.mcp import MCPServerStdio, MCPServerHTTP

mcps=[
    # String references
    "https://external-api.com/mcp",   # External server
    "crewai-amp:financial-insights",  # AMP service

    # Structured configurations
    MCPServerStdio(
        command="npx",
        args=["-y", "@modelcontextprotocol/server-filesystem"],
    ),
    MCPServerHTTP(
        url="https://api.example.com/mcp",
        headers={"Authorization": "Bearer token"},
    ),
]
```

### Tool Filtering

Structured configurations support advanced tool filtering:

```python
from crewai.mcp import MCPServerStdio
from crewai.mcp.filters import create_static_tool_filter, create_dynamic_tool_filter, ToolFilterContext

# Static filtering (allow/block lists)
static_filter = create_static_tool_filter(
    allowed_tool_names=["read_file", "write_file"],
    blocked_tool_names=["delete_file"],
)

# Dynamic filtering (context-aware)
def dynamic_filter(context: ToolFilterContext, tool: dict) -> bool:
    # Block dangerous tools for certain agent roles
    if context.agent.role == "Code Reviewer":
        if "delete" in tool.get("name", "").lower():
            return False
    return True

mcps=[
    MCPServerStdio(
        command="npx",
        args=["-y", "@modelcontextprotocol/server-filesystem"],
        tool_filter=static_filter,  # or dynamic_filter
    ),
]
```

## Configuration Parameters

Each transport type supports specific configuration options:

### MCPServerStdio Parameters

- **`command`** (required): Command to execute (e.g., `"python"`, `"node"`, `"npx"`, `"uvx"`)
- **`args`** (optional): List of command arguments (e.g., `["server.py"]` or `["-y", "@mcp/server"]`)
- **`env`** (optional): Dictionary of environment variables to pass to the process
- **`tool_filter`** (optional): Tool filter function for filtering available tools
- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`)

### MCPServerHTTP Parameters

- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp"`)
- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes
- **`streamable`** (optional): Whether to use streamable HTTP transport (default: `True`)
- **`tool_filter`** (optional): Tool filter function for filtering available tools
- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`)

### MCPServerSSE Parameters

- **`url`** (required): Server URL (e.g., `"https://api.example.com/mcp/sse"`)
- **`headers`** (optional): Dictionary of HTTP headers for authentication or other purposes
- **`tool_filter`** (optional): Tool filter function for filtering available tools
- **`cache_tools_list`** (optional): Whether to cache the tool list for faster subsequent access (default: `False`)

### Common Parameters

All transport types support:

- **`tool_filter`**: Filter function to control which tools are available. Can be:
  - `None` (default): All tools are available
  - Static filter: Created with `create_static_tool_filter()` for allow/block lists
  - Dynamic filter: Created with `create_dynamic_tool_filter()` for context-aware filtering
- **`cache_tools_list`**: When `True`, caches the tool list after first discovery to improve performance on subsequent connections
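
Putting the common parameters together: a minimal sketch that combines a static filter with tool-list caching on a single config, reusing only the classes and helpers documented above (the URL, token, and tool names are placeholders):

```python
from crewai.mcp import MCPServerHTTP
from crewai.mcp.filters import create_static_tool_filter

# Expose only two read-only tools and cache the discovered tool list
# so subsequent connections skip rediscovery.
server = MCPServerHTTP(
    url="https://api.example.com/mcp",               # placeholder URL
    headers={"Authorization": "Bearer your_token"},  # placeholder credential
    tool_filter=create_static_tool_filter(
        allowed_tool_names=["search", "fetch_page"],  # hypothetical tool names
    ),
    cache_tools_list=True,
)
```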

## Key Features

- 🔄 **Automatic Tool Discovery**: Tools are automatically discovered and integrated
@@ -152,26 +387,47 @@ mcps=[
- 🛡️ **Error Resilience**: Graceful handling of unavailable servers
- ⏱️ **Timeout Protection**: Built-in timeouts prevent hanging connections
- 📊 **Transparent Integration**: Works seamlessly with existing CrewAI features
- 🔧 **Full Transport Support**: Stdio, HTTP/Streamable HTTP, and SSE transports
- 🎯 **Advanced Filtering**: Static and dynamic tool filtering capabilities
- 🔐 **Flexible Authentication**: Support for headers, environment variables, and query parameters

## Error Handling

-The MCP DSL integration is designed to be resilient:
+The MCP DSL integration is designed to be resilient and handles failures gracefully:

```python
from crewai import Agent
from crewai.mcp import MCPServerStdio, MCPServerHTTP

agent = Agent(
    role="Resilient Agent",
    goal="Continue working despite server issues",
    backstory="Agent that handles failures gracefully",
    mcps=[
        # String references
        "https://reliable-server.com/mcp",     # Will work
        "https://unreachable-server.com/mcp",  # Will be skipped gracefully
        "https://slow-server.com/mcp",         # Will timeout gracefully
-        "crewai-amp:working-service"          # Will work
+        "crewai-amp:working-service",         # Will work
+
+        # Structured configs
+        MCPServerStdio(
+            command="python",
+            args=["reliable_server.py"],  # Will work
+        ),
+        MCPServerHTTP(
+            url="https://slow-server.com/mcp",  # Will timeout gracefully
+        ),
    ]
)
# Agent will use tools from working servers and log warnings for failing ones
```

All connection errors are handled gracefully:

- **Connection failures**: Logged as warnings, agent continues with available tools
- **Timeout errors**: Connections time out after 30 seconds (configurable)
- **Authentication errors**: Logged clearly for debugging
- **Invalid configurations**: Validation errors are raised at agent creation time
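
Because invalid configurations fail at agent creation, they can be caught up front. A minimal sketch (the `http://` reference below is deliberately invalid, since string references must use `https://` or `crewai-amp:`; pydantic surfaces the validator's error as a `ValueError` subclass):

```python
from crewai import Agent

try:
    Agent(
        role="Researcher",
        goal="Demonstrate config validation",
        backstory="Example agent",
        mcps=["http://insecure.example.com/mcp"],  # invalid scheme
    )
except ValueError as e:
    # Raised at agent creation time, before any server is contacted.
    print(f"Invalid MCP reference: {e}")
```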

## Advanced: MCPServerAdapter

For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server.
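
A minimal sketch of that pattern, following the usual `crewai-tools` examples (the local server script is a placeholder, and exact parameter types may differ by version):

```python
from crewai import Agent
from crewai_tools import MCPServerAdapter
from mcp import StdioServerParameters

server_params = StdioServerParameters(
    command="python",
    args=["local_server.py"],  # placeholder server script
)

# The adapter connects on entry and tears the connection down on exit.
with MCPServerAdapter(server_params) as mcp_tools:
    agent = Agent(
        role="Researcher",
        goal="Use manually managed MCP tools",
        backstory="Example agent",
        tools=list(mcp_tools),
    )
```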

@@ -40,6 +40,16 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent
from crewai.llms.base_llm import BaseLLM
+from crewai.mcp import (
+    MCPClient,
+    MCPServerConfig,
+    MCPServerHTTP,
+    MCPServerSSE,
+    MCPServerStdio,
+)
+from crewai.mcp.transports.http import HTTPTransport
+from crewai.mcp.transports.sse import SSETransport
+from crewai.mcp.transports.stdio import StdioTransport
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.fingerprint import Fingerprint

@@ -108,6 +118,7 @@ class Agent(BaseAgent):
    """

    _times_executed: int = PrivateAttr(default=0)
+   _mcp_clients: list[Any] = PrivateAttr(default_factory=list)
    max_execution_time: int | None = Field(
        default=None,
        description="Maximum execution time for an agent to execute a task",

@@ -526,6 +537,9 @@ class Agent(BaseAgent):
            self,
            event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
        )

+       self._cleanup_mcp_clients()

        return result

    def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any:

@@ -604,22 +618,22 @@ class Agent(BaseAgent):
            response_template=self.response_template,
        ).task_execution()

-       stop_words = [self.i18n.slice("observation")]
+       stop_sequences = [self.i18n.slice("observation")]

        if self.response_template:
-           stop_words.append(
+           stop_sequences.append(
                self.response_template.split("{{ .Response }}")[1].strip()
            )

        self.agent_executor = CrewAgentExecutor(
-           llm=self.llm,
+           llm=self.llm,  # type: ignore[arg-type]
            task=task,  # type: ignore[arg-type]
            agent=self,
            crew=self.crew,
            tools=parsed_tools,
            prompt=prompt,
            original_tools=raw_tools,
-           stop_words=stop_words,
+           stop_sequences=stop_sequences,
            max_iter=self.max_iter,
            tools_handler=self.tools_handler,
            tools_names=get_tool_names(parsed_tools),

@@ -649,30 +663,70 @@ class Agent(BaseAgent):
            self._logger.log("error", f"Error getting platform tools: {e!s}")
            return []

-   def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
-       """Convert MCP server references to CrewAI tools."""
+   def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
+       """Convert MCP server references/configs to CrewAI tools.
+
+       Supports both string references (backwards compatible) and structured
+       configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE).
+
+       Args:
+           mcps: List of MCP server references (strings) or configurations.
+
+       Returns:
+           List of BaseTool instances from MCP servers.
+       """
        all_tools = []
+       clients = []

-       for mcp_ref in mcps:
-           try:
-               if mcp_ref.startswith("crewai-amp:"):
-                   tools = self._get_amp_mcp_tools(mcp_ref)
-               elif mcp_ref.startswith("https://"):
-                   tools = self._get_external_mcp_tools(mcp_ref)
-               else:
-                   continue
+       for mcp_config in mcps:
+           if isinstance(mcp_config, str):
+               tools = self._get_mcp_tools_from_string(mcp_config)
+           else:
+               tools, client = self._get_native_mcp_tools(mcp_config)
+               if client:
+                   clients.append(client)

-               all_tools.extend(tools)
-               self._logger.log(
-                   "info", f"Successfully loaded {len(tools)} tools from {mcp_ref}"
-               )
-
-           except Exception as e:
-               self._logger.log("warning", f"Skipping MCP {mcp_ref} due to error: {e}")
-               continue
+           all_tools.extend(tools)

+       # Store clients for cleanup
+       self._mcp_clients.extend(clients)
        return all_tools

    def _cleanup_mcp_clients(self) -> None:
        """Cleanup MCP client connections after task execution."""
        if not self._mcp_clients:
            return

        async def _disconnect_all() -> None:
            for client in self._mcp_clients:
                if client and hasattr(client, "connected") and client.connected:
                    await client.disconnect()

        try:
            asyncio.run(_disconnect_all())
        except Exception as e:
            self._logger.log("error", f"Error during MCP client cleanup: {e}")
        finally:
            self._mcp_clients.clear()

    def _get_mcp_tools_from_string(self, mcp_ref: str) -> list[BaseTool]:
        """Get tools from legacy string-based MCP references.

        This method maintains backwards compatibility with string-based
        MCP references (https://... and crewai-amp:...).

        Args:
            mcp_ref: String reference to MCP server.

        Returns:
            List of BaseTool instances.
        """
        if mcp_ref.startswith("crewai-amp:"):
            return self._get_amp_mcp_tools(mcp_ref)
        if mcp_ref.startswith("https://"):
            return self._get_external_mcp_tools(mcp_ref)
        return []

    def _get_external_mcp_tools(self, mcp_ref: str) -> list[BaseTool]:
        """Get tools from external HTTPS MCP server with graceful error handling."""
        from crewai.tools.mcp_tool_wrapper import MCPToolWrapper

@@ -731,6 +785,164 @@ class Agent(BaseAgent):
        )
        return []

    def _get_native_mcp_tools(
        self, mcp_config: MCPServerConfig
    ) -> tuple[list[BaseTool], Any | None]:
        """Get tools from MCP server using structured configuration.

        This method creates an MCP client based on the configuration type,
        connects to the server, discovers tools, applies filtering, and
        returns wrapped tools along with the client instance for cleanup.

        Args:
            mcp_config: MCP server configuration (MCPServerStdio, MCPServerHTTP, or MCPServerSSE).

        Returns:
            Tuple of (list of BaseTool instances, MCPClient instance for cleanup).
        """
        from crewai.tools.base_tool import BaseTool
        from crewai.tools.mcp_native_tool import MCPNativeTool

        if isinstance(mcp_config, MCPServerStdio):
            transport = StdioTransport(
                command=mcp_config.command,
                args=mcp_config.args,
                env=mcp_config.env,
            )
            server_name = f"{mcp_config.command}_{'_'.join(mcp_config.args)}"
        elif isinstance(mcp_config, MCPServerHTTP):
            transport = HTTPTransport(
                url=mcp_config.url,
                headers=mcp_config.headers,
                streamable=mcp_config.streamable,
            )
            server_name = self._extract_server_name(mcp_config.url)
        elif isinstance(mcp_config, MCPServerSSE):
            transport = SSETransport(
                url=mcp_config.url,
                headers=mcp_config.headers,
            )
            server_name = self._extract_server_name(mcp_config.url)
        else:
            raise ValueError(f"Unsupported MCP server config type: {type(mcp_config)}")

        client = MCPClient(
            transport=transport,
            cache_tools_list=mcp_config.cache_tools_list,
        )

        async def _setup_client_and_list_tools() -> list[dict[str, Any]]:
            """Async helper to connect and list tools in same event loop."""

            try:
                if not client.connected:
                    await client.connect()

                tools_list = await client.list_tools()

                try:
                    await client.disconnect()
                    # Small delay to allow background tasks to finish cleanup
                    # This helps prevent "cancel scope in different task" errors
                    # when asyncio.run() closes the event loop
                    await asyncio.sleep(0.1)
                except Exception as e:
                    self._logger.log("error", f"Error during disconnect: {e}")

                return tools_list
            except Exception as e:
                if client.connected:
                    await client.disconnect()
                    await asyncio.sleep(0.1)
                raise RuntimeError(
                    f"Error during setup client and list tools: {e}"
                ) from e

        try:
            try:
                asyncio.get_running_loop()
                import concurrent.futures

                with concurrent.futures.ThreadPoolExecutor() as executor:
                    future = executor.submit(
                        asyncio.run, _setup_client_and_list_tools()
                    )
                    tools_list = future.result()
            except RuntimeError:
                try:
                    tools_list = asyncio.run(_setup_client_and_list_tools())
                except RuntimeError as e:
                    error_msg = str(e).lower()
                    if "cancel scope" in error_msg or "task" in error_msg:
                        raise ConnectionError(
                            "MCP connection failed due to event loop cleanup issues. "
                            "This may be due to authentication errors or server unavailability."
                        ) from e
                except asyncio.CancelledError as e:
                    raise ConnectionError(
                        "MCP connection was cancelled. This may indicate an authentication "
                        "error or server unavailability."
                    ) from e

            if mcp_config.tool_filter:
                filtered_tools = []
                for tool in tools_list:
                    if callable(mcp_config.tool_filter):
                        try:
                            from crewai.mcp.filters import ToolFilterContext

                            context = ToolFilterContext(
                                agent=self,
                                server_name=server_name,
                                run_context=None,
                            )
                            if mcp_config.tool_filter(context, tool):
                                filtered_tools.append(tool)
                        except (TypeError, AttributeError):
                            if mcp_config.tool_filter(tool):
                                filtered_tools.append(tool)
                    else:
                        # Not callable - include tool
                        filtered_tools.append(tool)
                tools_list = filtered_tools

            tools = []
            for tool_def in tools_list:
                tool_name = tool_def.get("name", "")
                if not tool_name:
                    continue

                # Convert inputSchema to Pydantic model if present
                args_schema = None
                if tool_def.get("inputSchema"):
                    args_schema = self._json_schema_to_pydantic(
                        tool_name, tool_def["inputSchema"]
                    )

                tool_schema = {
                    "description": tool_def.get("description", ""),
                    "args_schema": args_schema,
                }

                try:
                    native_tool = MCPNativeTool(
                        mcp_client=client,
                        tool_name=tool_name,
                        tool_schema=tool_schema,
                        server_name=server_name,
                    )
                    tools.append(native_tool)
                except Exception as e:
                    self._logger.log("error", f"Failed to create native MCP tool: {e}")
                    continue

            return cast(list[BaseTool], tools), client
        except Exception as e:
            if client.connected:
                asyncio.run(client.disconnect())

            raise RuntimeError(f"Failed to get native MCP tools: {e}") from e

    def _get_amp_mcp_tools(self, amp_ref: str) -> list[BaseTool]:
        """Get tools from CrewAI AMP MCP marketplace."""
        # Parse: "crewai-amp:mcp-name" or "crewai-amp:mcp-name#tool_name"

@@ -762,7 +974,9 @@ class Agent(BaseAgent):
        path = parsed.path.replace("/", "_").strip("_")
        return f"{domain}_{path}" if path else domain

-   def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
+   def _get_mcp_tool_schemas(
+       self, server_params: dict[str, Any]
+   ) -> dict[str, dict[str, Any]] | Any:
        """Get tool schemas from MCP server for wrapper creation with caching."""
        server_url = server_params["url"]

@@ -794,7 +1008,7 @@ class Agent(BaseAgent):

    async def _get_mcp_tool_schemas_async(
        self, server_params: dict[str, Any]
-   ) -> dict[str, dict]:
+   ) -> dict[str, dict[str, Any]]:
        """Async implementation of MCP tool schema retrieval with timeouts and retries."""
        server_url = server_params["url"]
        return await self._retry_mcp_discovery(

@@ -802,7 +1016,7 @@ class Agent(BaseAgent):
        )

    async def _retry_mcp_discovery(
-       self, operation_func, server_url: str
+       self, operation_func: Any, server_url: str
    ) -> dict[str, dict[str, Any]]:
        """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
        last_error = None

@@ -833,7 +1047,7 @@ class Agent(BaseAgent):

    @staticmethod
    async def _attempt_mcp_discovery(
-       operation_func, server_url: str
+       operation_func: Any, server_url: str
    ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
        """Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
        try:

@@ -937,13 +1151,13 @@ class Agent(BaseAgent):
                Field(..., description=field_description),
            )
        else:
-           field_definitions[field_name] = (
+           field_definitions[field_name] = (  # type: ignore[assignment]
                field_type | None,
                Field(default=None, description=field_description),
            )

        model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
-       return create_model(model_name, **field_definitions)
+       return create_model(model_name, **field_definitions)  # type: ignore[no-any-return,call-overload]

    def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
        """Convert JSON Schema type to Python type.

@@ -963,12 +1177,12 @@ class Agent(BaseAgent):
            if "const" in option:
                types.append(str)
            else:
-               types.append(self._json_type_to_python(option))
+               types.append(self._json_type_to_python(option))  # type: ignore[arg-type]
        unique_types = list(set(types))
        if len(unique_types) > 1:
            result = unique_types[0]
            for t in unique_types[1:]:
-               result = result | t
+               result = result | t  # type: ignore[assignment]
            return result
        return unique_types[0]

@@ -981,10 +1195,10 @@ class Agent(BaseAgent):
            "object": dict,
        }

-       return type_mapping.get(json_type, Any)
+       return type_mapping.get(json_type, Any)  # type: ignore[arg-type]

    @staticmethod
-   def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
+   def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
        """Fetch MCP server configurations from CrewAI AMP API."""
        # TODO: Implement AMP API call to "integrations/mcps" endpoint
        # Should return list of server configs with URLs

@@ -1223,7 +1437,7 @@ class Agent(BaseAgent):
            goal=self.goal,
            backstory=self.backstory,
            llm=self.llm,
-           tools=self.tools or [],
+           tools=self.tools,
            max_iterations=self.max_iter,
            max_execution_time=self.max_execution_time,
            respect_context_window=self.respect_context_window,

@@ -25,6 +25,7 @@ from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
+from crewai.mcp.config import MCPServerConfig
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool

@@ -136,7 +137,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
        default=False,
        description="Enable agent to delegate and ask questions among each other.",
    )
-   tools: list[BaseTool] | None = Field(
+   tools: list[BaseTool] = Field(
        default_factory=list, description="Tools at agents' disposal"
    )
    max_iter: int = Field(

@@ -194,7 +195,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
        default=None,
        description="List of applications or application/action combinations that the agent can access through CrewAI Platform. Can contain app names (e.g., 'gmail') or specific actions (e.g., 'gmail/send_email')",
    )
-   mcps: list[str] | None = Field(
+   mcps: list[str | MCPServerConfig] | None = Field(
        default=None,
        description="List of MCP server references. Supports 'https://server.com/path' for external servers and 'crewai-amp:mcp-name' for AMP marketplace. Use '#tool_name' suffix for specific tools.",
    )

@@ -253,20 +254,36 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):

    @field_validator("mcps")
    @classmethod
-   def validate_mcps(cls, mcps: list[str] | None) -> list[str] | None:
+   def validate_mcps(
+       cls, mcps: list[str | MCPServerConfig] | None
+   ) -> list[str | MCPServerConfig] | None:
+       """Validate MCP server references and configurations.
+
+       Supports both string references (for backwards compatibility) and
+       structured configuration objects (MCPServerStdio, MCPServerHTTP, MCPServerSSE).
+       """
        if not mcps:
            return mcps

        validated_mcps = []
        for mcp in mcps:
-           if mcp.startswith(("https://", "crewai-amp:")):
+           if isinstance(mcp, str):
+               if mcp.startswith(("https://", "crewai-amp:")):
+                   validated_mcps.append(mcp)
+               else:
+                   raise ValueError(
+                       f"Invalid MCP reference: {mcp}. "
+                       "String references must start with 'https://' or 'crewai-amp:'"
+                   )
+           elif isinstance(mcp, MCPServerConfig):
                validated_mcps.append(mcp)
            else:
                raise ValueError(
-                   f"Invalid MCP reference: {mcp}. Must start with 'https://' or 'crewai-amp:'"
+                   f"Invalid MCP configuration: {type(mcp)}. "
+                   "Must be a string reference or MCPServerConfig instance."
                )

-       return list(set(validated_mcps))
+       return validated_mcps
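
In practice the validator's contract looks like this (hypothetical values, using only the classes shown in this diff; the `http://` entry would raise `ValueError`):

```python
from crewai.mcp import MCPServerHTTP

# Accepted: https:// strings, crewai-amp: strings, MCPServerConfig instances.
valid_mcps = [
    "https://api.example.com/mcp",
    "crewai-amp:financial-insights",
    MCPServerHTTP(url="https://api.example.com/mcp"),
]

# Rejected at agent creation: any other string scheme.
invalid_mcps = ["http://insecure.example.com/mcp"]  # raises ValueError
```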

    @model_validator(mode="after")
    def validate_and_set_attributes(self) -> Self:

@@ -343,7 +360,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
        """Get platform tools for the specified list of applications and/or application/action combinations."""

    @abstractmethod
-   def get_mcp_tools(self, mcps: list[str]) -> list[BaseTool]:
+   def get_mcp_tools(self, mcps: list[str | MCPServerConfig]) -> list[BaseTool]:
        """Get MCP tools for the specified list of MCP server references."""

    def copy(self) -> Self:  # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"

@@ -73,7 +73,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
        max_iter: int,
        tools: list[CrewStructuredTool],
        tools_names: str,
-       stop_words: list[str],
+       stop_sequences: list[str],
        tools_description: str,
        tools_handler: ToolsHandler,
        step_callback: Any = None,

@@ -95,7 +95,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
            max_iter: Maximum iterations.
            tools: Available tools.
            tools_names: Tool names string.
-           stop_words: Stop word list.
+           stop_sequences: Stop sequences list for halting generation.
            tools_description: Tool descriptions.
            tools_handler: Tool handler instance.
            step_callback: Optional step callback.

@@ -114,7 +114,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
        self.prompt = prompt
        self.tools = tools
        self.tools_names = tools_names
-       self.stop = stop_words
        self.max_iter = max_iter
        self.callbacks = callbacks or []
        self._printer: Printer = Printer()

@@ -131,15 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
        self.iterations = 0
        self.log_error_after = 3
        if self.llm:
-           # This may be mutating the shared llm object and needs further evaluation
-           existing_stop = getattr(self.llm, "stop", [])
-           self.llm.stop = list(
-               set(
-                   existing_stop + self.stop
-                   if isinstance(existing_stop, list)
-                   else self.stop
-               )
-           )
+           self.llm.stop_sequences.extend(stop_sequences)

    @property
    def use_stop_words(self) -> bool:

@@ -148,7 +139,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
        Returns:
            bool: True if stop words should be used.
        """
-       return self.llm.supports_stop_words() if self.llm else False
+       return self.llm.supports_stop_words if self.llm else False

    def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
        """Execute the agent with given inputs.

@@ -16,7 +16,6 @@ from crewai.events.base_event_listener import BaseEventListener
from crewai.events.depends import Depends
from crewai.events.event_bus import crewai_event_bus
from crewai.events.handler_graph import CircularDependencyError
-
from crewai.events.types.crew_events import (
    CrewKickoffCompletedEvent,
    CrewKickoffFailedEvent,

@@ -61,6 +60,14 @@ from crewai.events.types.logging_events import (
    AgentLogsExecutionEvent,
    AgentLogsStartedEvent,
)
+from crewai.events.types.mcp_events import (
+    MCPConnectionCompletedEvent,
+    MCPConnectionFailedEvent,
+    MCPConnectionStartedEvent,
+    MCPToolExecutionCompletedEvent,
+    MCPToolExecutionFailedEvent,
+    MCPToolExecutionStartedEvent,
+)
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemoryQueryFailedEvent,

@@ -153,6 +160,12 @@ __all__ = [
    "LiteAgentExecutionCompletedEvent",
    "LiteAgentExecutionErrorEvent",
    "LiteAgentExecutionStartedEvent",
+   "MCPConnectionCompletedEvent",
+   "MCPConnectionFailedEvent",
+   "MCPConnectionStartedEvent",
+   "MCPToolExecutionCompletedEvent",
+   "MCPToolExecutionFailedEvent",
+   "MCPToolExecutionStartedEvent",
    "MemoryQueryCompletedEvent",
    "MemoryQueryFailedEvent",
    "MemoryQueryStartedEvent",

@@ -65,6 +65,14 @@ from crewai.events.types.logging_events import (
    AgentLogsExecutionEvent,
    AgentLogsStartedEvent,
)
+from crewai.events.types.mcp_events import (
+    MCPConnectionCompletedEvent,
+    MCPConnectionFailedEvent,
+    MCPConnectionStartedEvent,
+    MCPToolExecutionCompletedEvent,
+    MCPToolExecutionFailedEvent,
+    MCPToolExecutionStartedEvent,
+)
from crewai.events.types.reasoning_events import (
    AgentReasoningCompletedEvent,
    AgentReasoningFailedEvent,

@@ -615,5 +623,67 @@ class EventListener(BaseEventListener):
                event.total_turns,
            )

        # ----------- MCP EVENTS -----------

        @crewai_event_bus.on(MCPConnectionStartedEvent)
        def on_mcp_connection_started(source, event: MCPConnectionStartedEvent):
            self.formatter.handle_mcp_connection_started(
                event.server_name,
                event.server_url,
                event.transport_type,
                event.is_reconnect,
                event.connect_timeout,
            )

        @crewai_event_bus.on(MCPConnectionCompletedEvent)
        def on_mcp_connection_completed(source, event: MCPConnectionCompletedEvent):
            self.formatter.handle_mcp_connection_completed(
                event.server_name,
                event.server_url,
                event.transport_type,
                event.connection_duration_ms,
                event.is_reconnect,
            )

        @crewai_event_bus.on(MCPConnectionFailedEvent)
        def on_mcp_connection_failed(source, event: MCPConnectionFailedEvent):
            self.formatter.handle_mcp_connection_failed(
                event.server_name,
                event.server_url,
                event.transport_type,
                event.error,
                event.error_type,
            )

        @crewai_event_bus.on(MCPToolExecutionStartedEvent)
        def on_mcp_tool_execution_started(source, event: MCPToolExecutionStartedEvent):
            self.formatter.handle_mcp_tool_execution_started(
                event.server_name,
                event.tool_name,
                event.tool_args,
            )

        @crewai_event_bus.on(MCPToolExecutionCompletedEvent)
        def on_mcp_tool_execution_completed(
            source, event: MCPToolExecutionCompletedEvent
        ):
            self.formatter.handle_mcp_tool_execution_completed(
                event.server_name,
                event.tool_name,
                event.tool_args,
                event.result,
                event.execution_duration_ms,
            )

        @crewai_event_bus.on(MCPToolExecutionFailedEvent)
        def on_mcp_tool_execution_failed(source, event: MCPToolExecutionFailedEvent):
            self.formatter.handle_mcp_tool_execution_failed(
                event.server_name,
                event.tool_name,
                event.tool_args,
                event.error,
                event.error_type,
            )


event_listener = EventListener()

@@ -40,6 +40,14 @@ from crewai.events.types.llm_guardrail_events import (
    LLMGuardrailCompletedEvent,
    LLMGuardrailStartedEvent,
)
+from crewai.events.types.mcp_events import (
+    MCPConnectionCompletedEvent,
+    MCPConnectionFailedEvent,
+    MCPConnectionStartedEvent,
+    MCPToolExecutionCompletedEvent,
+    MCPToolExecutionFailedEvent,
+    MCPToolExecutionStartedEvent,
+)
from crewai.events.types.memory_events import (
    MemoryQueryCompletedEvent,
    MemoryQueryFailedEvent,

@@ -115,4 +123,10 @@ EventTypes = (
    | MemoryQueryFailedEvent
    | MemoryRetrievalStartedEvent
    | MemoryRetrievalCompletedEvent
+   | MCPConnectionStartedEvent
+   | MCPConnectionCompletedEvent
+   | MCPConnectionFailedEvent
+   | MCPToolExecutionStartedEvent
+   | MCPToolExecutionCompletedEvent
+   | MCPToolExecutionFailedEvent
)

lib/crewai/src/crewai/events/types/mcp_events.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from datetime import datetime
from typing import Any

from crewai.events.base_events import BaseEvent


class MCPEvent(BaseEvent):
    """Base event for MCP operations."""

    server_name: str
    server_url: str | None = None
    transport_type: str | None = None  # "stdio", "http", "sse"
    agent_id: str | None = None
    agent_role: str | None = None
    from_agent: Any | None = None
    from_task: Any | None = None

    def __init__(self, **data):
        super().__init__(**data)
        self._set_agent_params(data)
        self._set_task_params(data)


class MCPConnectionStartedEvent(MCPEvent):
    """Event emitted when starting to connect to an MCP server."""

    type: str = "mcp_connection_started"
    connect_timeout: int | None = None
    is_reconnect: bool = (
        False  # True if this is a reconnection, False for first connection
    )


class MCPConnectionCompletedEvent(MCPEvent):
    """Event emitted when successfully connected to an MCP server."""

    type: str = "mcp_connection_completed"
    started_at: datetime | None = None
    completed_at: datetime | None = None
    connection_duration_ms: float | None = None
    is_reconnect: bool = (
        False  # True if this was a reconnection, False for first connection
    )


class MCPConnectionFailedEvent(MCPEvent):
    """Event emitted when connection to an MCP server fails."""

    type: str = "mcp_connection_failed"
    error: str
    error_type: str | None = None  # "timeout", "authentication", "network", etc.
    started_at: datetime | None = None
    failed_at: datetime | None = None


class MCPToolExecutionStartedEvent(MCPEvent):
    """Event emitted when starting to execute an MCP tool."""

    type: str = "mcp_tool_execution_started"
    tool_name: str
    tool_args: dict[str, Any] | None = None


class MCPToolExecutionCompletedEvent(MCPEvent):
    """Event emitted when MCP tool execution completes."""

    type: str = "mcp_tool_execution_completed"
    tool_name: str
    tool_args: dict[str, Any] | None = None
    result: Any | None = None
    started_at: datetime | None = None
    completed_at: datetime | None = None
    execution_duration_ms: float | None = None


class MCPToolExecutionFailedEvent(MCPEvent):
    """Event emitted when MCP tool execution fails."""

    type: str = "mcp_tool_execution_failed"
    tool_name: str
    tool_args: dict[str, Any] | None = None
    error: str
    error_type: str | None = None  # "timeout", "validation", "server_error", etc.
    started_at: datetime | None = None
    failed_at: datetime | None = None

@@ -2248,3 +2248,203 @@ class ConsoleFormatter:

        self.current_a2a_conversation_branch = None
        self.current_a2a_turn_count = 0

    # ----------- MCP EVENTS -----------

    def handle_mcp_connection_started(
        self,
        server_name: str,
        server_url: str | None = None,
        transport_type: str | None = None,
        is_reconnect: bool = False,
        connect_timeout: int | None = None,
    ) -> None:
        """Handle MCP connection started event."""
        if not self.verbose:
            return

        content = Text()
        reconnect_text = " (Reconnecting)" if is_reconnect else ""
        content.append(f"MCP Connection Started{reconnect_text}\n\n", style="cyan bold")
        content.append("Server: ", style="white")
        content.append(f"{server_name}\n", style="cyan")

        if server_url:
            content.append("URL: ", style="white")
            content.append(f"{server_url}\n", style="cyan dim")

        if transport_type:
            content.append("Transport: ", style="white")
            content.append(f"{transport_type}\n", style="cyan")

        if connect_timeout:
            content.append("Timeout: ", style="white")
            content.append(f"{connect_timeout}s\n", style="cyan")

        panel = self.create_panel(content, "🔌 MCP Connection", "cyan")
        self.print(panel)
        self.print()

    def handle_mcp_connection_completed(
        self,
        server_name: str,
        server_url: str | None = None,
        transport_type: str | None = None,
        connection_duration_ms: float | None = None,
        is_reconnect: bool = False,
    ) -> None:
        """Handle MCP connection completed event."""
        if not self.verbose:
            return

        content = Text()
        reconnect_text = " (Reconnected)" if is_reconnect else ""
        content.append(
            f"MCP Connection Completed{reconnect_text}\n\n", style="green bold"
        )
        content.append("Server: ", style="white")
        content.append(f"{server_name}\n", style="green")

        if server_url:
            content.append("URL: ", style="white")
            content.append(f"{server_url}\n", style="green dim")

        if transport_type:
            content.append("Transport: ", style="white")
            content.append(f"{transport_type}\n", style="green")

        if connection_duration_ms is not None:
            content.append("Duration: ", style="white")
            content.append(f"{connection_duration_ms:.2f}ms\n", style="green")

        panel = self.create_panel(content, "✅ MCP Connected", "green")
        self.print(panel)
        self.print()

    def handle_mcp_connection_failed(
        self,
        server_name: str,
        server_url: str | None = None,
        transport_type: str | None = None,
        error: str = "",
        error_type: str | None = None,
    ) -> None:
        """Handle MCP connection failed event."""
        if not self.verbose:
            return

        content = Text()
        content.append("MCP Connection Failed\n\n", style="red bold")
        content.append("Server: ", style="white")
        content.append(f"{server_name}\n", style="red")

        if server_url:
            content.append("URL: ", style="white")
            content.append(f"{server_url}\n", style="red dim")

        if transport_type:
            content.append("Transport: ", style="white")
            content.append(f"{transport_type}\n", style="red")

        if error_type:
            content.append("Error Type: ", style="white")
            content.append(f"{error_type}\n", style="red")

        if error:
            content.append("\nError: ", style="white bold")
            error_preview = error[:500] + "..." if len(error) > 500 else error
            content.append(f"{error_preview}\n", style="red")

        panel = self.create_panel(content, "❌ MCP Connection Failed", "red")
        self.print(panel)
        self.print()

    def handle_mcp_tool_execution_started(
        self,
        server_name: str,
        tool_name: str,
        tool_args: dict[str, Any] | None = None,
    ) -> None:
        """Handle MCP tool execution started event."""
        if not self.verbose:
            return

        content = self.create_status_content(
            "MCP Tool Execution Started",
            tool_name,
            "yellow",
            tool_args=tool_args or {},
            Server=server_name,
        )

        panel = self.create_panel(content, "🔧 MCP Tool", "yellow")
        self.print(panel)
        self.print()

    def handle_mcp_tool_execution_completed(
        self,
        server_name: str,
        tool_name: str,
        tool_args: dict[str, Any] | None = None,
        result: Any | None = None,
        execution_duration_ms: float | None = None,
    ) -> None:
        """Handle MCP tool execution completed event."""
        if not self.verbose:
            return

        content = self.create_status_content(
            "MCP Tool Execution Completed",
            tool_name,
            "green",
            tool_args=tool_args or {},
            Server=server_name,
        )

        if execution_duration_ms is not None:
            content.append("Duration: ", style="white")
            content.append(f"{execution_duration_ms:.2f}ms\n", style="green")

        if result is not None:
            result_str = str(result)
            if len(result_str) > 500:
                result_str = result_str[:497] + "..."
            content.append("\nResult: ", style="white bold")
            content.append(f"{result_str}\n", style="green")

        panel = self.create_panel(content, "✅ MCP Tool Completed", "green")
        self.print(panel)
        self.print()

    def handle_mcp_tool_execution_failed(
        self,
        server_name: str,
        tool_name: str,
        tool_args: dict[str, Any] | None = None,
        error: str = "",
        error_type: str | None = None,
    ) -> None:
        """Handle MCP tool execution failed event."""
        if not self.verbose:
            return

        content = self.create_status_content(
            "MCP Tool Execution Failed",
            tool_name,
            "red",
            tool_args=tool_args or {},
            Server=server_name,
        )

        if error_type:
            content.append("Error Type: ", style="white")
            content.append(f"{error_type}\n", style="red")

        if error:
            content.append("\nError: ", style="white bold")
            error_preview = error[:500] + "..." if len(error) > 500 else error
            content.append(f"{error_preview}\n", style="red")

        panel = self.create_panel(content, "❌ MCP Tool Failed", "red")
        self.print(panel)
        self.print()

@@ -428,6 +428,8 @@ class FlowMeta(type):
                possible_returns = get_possible_return_constants(attr_value)
                if possible_returns:
                    router_paths[attr_name] = possible_returns
+               else:
+                   router_paths[attr_name] = []

        cls._start_methods = start_methods  # type: ignore[attr-defined]
        cls._listeners = listeners  # type: ignore[attr-defined]

@@ -21,6 +21,7 @@ P = ParamSpec("P")
R = TypeVar("R", covariant=True)

FlowMethodName = NewType("FlowMethodName", str)
+FlowRouteName = NewType("FlowRouteName", str)
PendingListenerKey = NewType(
    "PendingListenerKey",
    Annotated[str, "nested flow conditions use 'listener_name:object_id'"],

@@ -19,11 +19,11 @@ import ast
from collections import defaultdict, deque
import inspect
import textwrap
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

from typing_extensions import TypeIs

-from crewai.flow.constants import OR_CONDITION, AND_CONDITION
+from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.flow_wrappers import (
    FlowCondition,
    FlowConditions,

@@ -33,6 +33,7 @@ from crewai.flow.flow_wrappers import (
from crewai.flow.types import FlowMethodCallable, FlowMethodName
from crewai.utilities.printer import Printer


if TYPE_CHECKING:
    from crewai.flow.flow import Flow

@@ -40,6 +41,22 @@ _printer = Printer()


def get_possible_return_constants(function: Any) -> list[str] | None:
+   """Extract possible string return values from a function using AST parsing.
+
+   This function analyzes the source code of a router method to identify
+   all possible string values it might return. It handles:
+   - Direct string literals: return "value"
+   - Variable assignments: x = "value"; return x
+   - Dictionary lookups: d = {"k": "v"}; return d[key]
+   - Conditional returns: return "a" if cond else "b"
+   - State attributes: return self.state.attr (infers from class context)
+
+   Args:
+       function: The function to analyze.
+
+   Returns:
+       List of possible string return values, or None if analysis fails.
+   """
    try:
        source = inspect.getsource(function)
    except OSError:

@@ -82,6 +99,7 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
    return_values: set[str] = set()
    dict_definitions: dict[str, list[str]] = {}
    variable_values: dict[str, list[str]] = {}
+   state_attribute_values: dict[str, list[str]] = {}

    def extract_string_constants(node: ast.expr) -> list[str]:
        """Recursively extract all string constants from an AST node."""

@@ -91,6 +109,17 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
        elif isinstance(node, ast.IfExp):
            strings.extend(extract_string_constants(node.body))
            strings.extend(extract_string_constants(node.orelse))
+       elif isinstance(node, ast.Call):
+           if (
+               isinstance(node.func, ast.Attribute)
+               and node.func.attr == "get"
+               and len(node.args) >= 2
+           ):
+               default_arg = node.args[1]
+               if isinstance(default_arg, ast.Constant) and isinstance(
+                   default_arg.value, str
+               ):
+                   strings.append(default_arg.value)
        return strings

    class VariableAssignmentVisitor(ast.NodeVisitor):

@@ -124,6 +153,22 @@ def get_possible_return_constants(function: Any) -> list[str] | None:

            self.generic_visit(node)

    def get_attribute_chain(node: ast.expr) -> str | None:
        """Extract the full attribute chain from an AST node.

        Examples:
            self.state.run_type -> "self.state.run_type"
            x.y.z -> "x.y.z"
            simple_var -> "simple_var"
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            base = get_attribute_chain(node.value)
            if base:
                return f"{base}.{node.attr}"
        return None

    class ReturnVisitor(ast.NodeVisitor):
        def visit_Return(self, node: ast.Return) -> None:
            if (

@@ -139,21 +184,94 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
                    for v in dict_definitions[var_name_dict]:
                        return_values.add(v)
            elif node.value:
-               var_name_ret: str | None = None
-               if isinstance(node.value, ast.Name):
-                   var_name_ret = node.value.id
-               elif isinstance(node.value, ast.Attribute):
-                   var_name_ret = f"{node.value.value.id if isinstance(node.value.value, ast.Name) else '_'}.{node.value.attr}"
+               var_name_ret = get_attribute_chain(node.value)

                if var_name_ret and var_name_ret in variable_values:
                    for v in variable_values[var_name_ret]:
                        return_values.add(v)
+               elif var_name_ret and var_name_ret in state_attribute_values:
+                   for v in state_attribute_values[var_name_ret]:
+                       return_values.add(v)

            self.generic_visit(node)

        def visit_If(self, node: ast.If) -> None:
            self.generic_visit(node)

    # Try to get the class context to infer state attribute values
    try:
        if hasattr(function, "__self__"):
            # Method is bound, get the class
            class_obj = function.__self__.__class__
        elif hasattr(function, "__qualname__") and "." in function.__qualname__:
            # Method is unbound but we can try to get class from module
            class_name = function.__qualname__.rsplit(".", 1)[0]
            if hasattr(function, "__globals__"):
                class_obj = function.__globals__.get(class_name)
            else:
                class_obj = None
        else:
            class_obj = None

        if class_obj is not None:
            try:
                class_source = inspect.getsource(class_obj)
                class_source = textwrap.dedent(class_source)
                class_ast = ast.parse(class_source)

                # Look for comparisons and assignments involving state attributes
                class StateAttributeVisitor(ast.NodeVisitor):
                    def visit_Compare(self, node: ast.Compare) -> None:
                        """Find comparisons like: self.state.attr == "value" """
                        left_attr = get_attribute_chain(node.left)

                        if left_attr:
                            for comparator in node.comparators:
                                if isinstance(comparator, ast.Constant) and isinstance(
                                    comparator.value, str
                                ):
                                    if left_attr not in state_attribute_values:
                                        state_attribute_values[left_attr] = []
                                    if (
                                        comparator.value
                                        not in state_attribute_values[left_attr]
                                    ):
                                        state_attribute_values[left_attr].append(
                                            comparator.value
                                        )

                        # Also check right side
                        for comparator in node.comparators:
                            right_attr = get_attribute_chain(comparator)
                            if (
                                right_attr
                                and isinstance(node.left, ast.Constant)
                                and isinstance(node.left.value, str)
                            ):
                                if right_attr not in state_attribute_values:
                                    state_attribute_values[right_attr] = []
                                if (
                                    node.left.value
                                    not in state_attribute_values[right_attr]
                                ):
                                    state_attribute_values[right_attr].append(
                                        node.left.value
                                    )

                        self.generic_visit(node)

                StateAttributeVisitor().visit(class_ast)
            except Exception as e:
                _printer.print(
                    f"Could not analyze class context for {function.__name__}: {e}",
                    color="yellow",
                )
    except Exception as e:
        _printer.print(
            f"Could not introspect class for {function.__name__}: {e}",
            color="yellow",
        )

    VariableAssignmentVisitor().visit(code_ast)
    ReturnVisitor().visit(code_ast)
||||
|
||||
|
||||
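For reference, a minimal sketch (not part of the diff) of what the new `get_attribute_chain` helper returns; the `ast.parse` wrapper is illustrative:

```python
import ast

# Assumes the get_attribute_chain function from the hunk above is in scope.
expr = ast.parse("self.state.run_type", mode="eval").body
print(get_attribute_chain(expr))  # "self.state.run_type"

expr = ast.parse("func().attr", mode="eval").body
print(get_attribute_chain(expr))  # None: the base is a call, not a Name/Attribute
```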
File diff suppressed because it is too large
@@ -6,6 +6,7 @@
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
    <link rel="stylesheet" href="'{{ css_path }}'" />
    <script src="https://unpkg.com/lucide@latest"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/prism.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/components/prism-python.min.js"></script>
    <script src="'{{ js_path }}'"></script>
@@ -23,93 +24,129 @@
    <div class="drawer-title" id="drawer-node-name">Node Details</div>
    <div style="display: flex; align-items: center;">
        <button class="drawer-open-ide" id="drawer-open-ide" style="display: none;">
            <svg viewBox="0 0 16 16" fill="none" stroke="currentColor" stroke-width="2">
                <path d="M4 2 L12 2 L12 14 L4 14 Z" stroke-linecap="round" stroke-linejoin="round"/>
                <path d="M6 5 L10 5 M6 8 L10 8 M6 11 L10 11" stroke-linecap="round"/>
            </svg>
            <i data-lucide="file-code" style="width: 16px; height: 16px;"></i>
            Open in IDE
        </button>
        <button class="drawer-close" id="drawer-close">×</button>
        <button class="drawer-close" id="drawer-close">
            <i data-lucide="x" style="width: 20px; height: 20px;"></i>
        </button>
    </div>
</div>
<div class="drawer-content" id="drawer-content"></div>
</div>

<div id="info">
    <div style="text-align: center; margin-bottom: 20px;">
    <div style="text-align: center;">
        <img src="https://cdn.prod.website-files.com/68de1ee6d7c127849807d7a6/68de1ee6d7c127849807d7ef_Logo.svg"
             alt="CrewAI Logo"
             style="width: 120px; height: auto;">
    </div>
    <h3>Flow Execution</h3>
    <div class="stats">
        <p><strong>Nodes:</strong> '{{ dag_nodes_count }}'</p>
        <p><strong>Edges:</strong> '{{ dag_edges_count }}'</p>
        <p><strong>Topological Paths:</strong> '{{ execution_paths }}'</p>
    </div>
    <div class="legend">
        <div class="legend-title">Node Types</div>
        <div class="legend-item">
            <div class="legend-color" style="background: '{{ CREWAI_ORANGE }}';"></div>
            <span>Start Methods</span>
        </div>
        <div class="legend-item">
            <div class="legend-color" style="background: '{{ DARK_GRAY }}'; border: 3px solid '{{ CREWAI_ORANGE }}';"></div>
            <span>Router Methods</span>
        </div>
        <div class="legend-item">
            <div class="legend-color" style="background: '{{ DARK_GRAY }}';"></div>
            <span>Listen Methods</span>
        </div>
    </div>
    <div class="legend">
        <div class="legend-title">Edge Types</div>
        <div class="legend-item">
            <svg width="24" height="12" style="margin-right: 12px;">
                <line x1="0" y1="6" x2="24" y2="6" stroke="'{{ CREWAI_ORANGE }}'" stroke-width="2" stroke-dasharray="5,5"/>
            </svg>
            <span>Router Paths</span>
        </div>
        <div class="legend-item">
            <svg width="24" height="12" style="margin-right: 12px;" class="legend-or-line">
                <line x1="0" y1="6" x2="24" y2="6" stroke="var(--edge-or-color)" stroke-width="2"/>
            </svg>
            <span>OR Conditions</span>
        </div>
        <div class="legend-item">
            <svg width="24" height="12" style="margin-right: 12px;">
                <line x1="0" y1="6" x2="24" y2="6" stroke="'{{ CREWAI_ORANGE }}'" stroke-width="2"/>
            </svg>
            <span>AND Conditions</span>
        </div>
    </div>
    <div class="instructions">
        <strong>Interactions:</strong><br>
        • Drag to pan<br>
        • Scroll to zoom<br><br>
        <strong>IDE:</strong>
        <select id="ide-selector" style="width: 100%; padding: 4px; margin-top: 4px; border-radius: 3px; border: 1px solid #e0e0e0; background: white; font-size: 12px; cursor: pointer; pointer-events: auto; position: relative; z-index: 10;">
            <option value="auto">Auto-detect</option>
            <option value="pycharm">PyCharm</option>
            <option value="vscode">VS Code</option>
            <option value="jetbrains">JetBrains (Toolbox)</option>
        </select>
             style="width: 144px; height: auto;">
    </div>
</div>

<!-- Custom navigation controls -->
<div class="nav-controls">
    <div class="nav-button" id="theme-toggle" title="Toggle Dark Mode">🌙</div>
    <div class="nav-button" id="zoom-in" title="Zoom In">+</div>
    <div class="nav-button" id="zoom-out" title="Zoom Out">−</div>
    <div class="nav-button" id="fit" title="Fit to Screen">⊡</div>
    <div class="nav-button" id="export-png" title="Export to PNG">🖼</div>
    <div class="nav-button" id="export-pdf" title="Export to PDF">📄</div>
    <div class="nav-button" id="export-json" title="Export to JSON">{}</div>
    <div class="nav-button" id="theme-toggle" title="Toggle Dark Mode">
        <i data-lucide="moon" style="width: 18px; height: 18px;"></i>
    </div>
    <div class="nav-button" id="zoom-in" title="Zoom In">
        <i data-lucide="zoom-in" style="width: 18px; height: 18px;"></i>
    </div>
    <div class="nav-button" id="zoom-out" title="Zoom Out">
        <i data-lucide="zoom-out" style="width: 18px; height: 18px;"></i>
    </div>
    <div class="nav-button" id="fit" title="Fit to Screen">
        <i data-lucide="maximize-2" style="width: 18px; height: 18px;"></i>
    </div>
    <div class="nav-button" id="export-png" title="Export to PNG">
        <i data-lucide="image" style="width: 18px; height: 18px;"></i>
    </div>
    <div class="nav-button" id="export-pdf" title="Export to PDF">
        <i data-lucide="file-text" style="width: 18px; height: 18px;"></i>
    </div>
    <!-- <div class="nav-button" id="export-json" title="Export to JSON">
        <i data-lucide="braces" style="width: 18px; height: 18px;"></i>
    </div> -->
</div>

<div id="network-container">
    <div id="network"></div>
</div>

<!-- Info panel at bottom -->
<div id="legend-panel">
    <!-- Stats Section -->
    <div class="legend-section">
        <div class="legend-stats-row">
            <div class="legend-stat-item">
                <span class="stat-value">'{{ dag_nodes_count }}'</span>
                <span class="stat-label">Nodes</span>
            </div>
            <div class="legend-stat-item">
                <span class="stat-value">'{{ dag_edges_count }}'</span>
                <span class="stat-label">Edges</span>
            </div>
            <div class="legend-stat-item">
                <span class="stat-value">'{{ execution_paths }}'</span>
                <span class="stat-label">Paths</span>
            </div>
        </div>
    </div>

    <!-- Node Types Section -->
    <div class="legend-section">
        <div class="legend-group">
            <div class="legend-item-compact">
                <div class="legend-color-small" style="background: var(--node-bg-start);"></div>
                <span>Start</span>
            </div>
            <div class="legend-item-compact">
                <div class="legend-color-small" style="background: var(--node-bg-router); border: 2px solid var(--node-border-start);"></div>
                <span>Router</span>
            </div>
            <div class="legend-item-compact">
                <div class="legend-color-small" style="background: var(--node-bg-listen); border: 2px solid var(--node-border-listen);"></div>
                <span>Listen</span>
            </div>
        </div>
    </div>

    <!-- Edge Types Section -->
    <div class="legend-section">
        <div class="legend-group">
            <div class="legend-item-compact">
                <svg>
                    <line x1="0" y1="7" x2="29" y2="7" stroke="var(--edge-router-color)" stroke-width="2" stroke-dasharray="4,4"/>
                </svg>
                <span>Router</span>
            </div>
            <div class="legend-item-compact">
                <svg class="legend-or-line">
                    <line x1="0" y1="7" x2="29" y2="7" stroke="var(--edge-or-color)" stroke-width="2"/>
                </svg>
                <span>OR</span>
            </div>
            <div class="legend-item-compact">
                <svg>
                    <line x1="0" y1="7" x2="29" y2="7" stroke="var(--edge-router-color)" stroke-width="2"/>
                </svg>
                <span>AND</span>
            </div>
        </div>
    </div>

    <!-- IDE Selector Section -->
    <div class="legend-section">
        <div class="legend-ide-column">
            <label class="legend-ide-label">IDE</label>
            <select id="ide-selector" class="legend-ide-select">
                <option value="auto">Auto-detect</option>
                <option value="pycharm">PyCharm</option>
                <option value="vscode">VS Code</option>
                <option value="jetbrains">JetBrains</option>
            </select>
        </div>
    </div>
</div>
</body>
</html>
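A minimal sketch of how the `'{{ ... }}'` placeholders above are filled in, assuming the template is rendered with Jinja2 (the library the extensions below build on); the context value here is made up:

```python
from jinja2 import Template

# Render one of the placeholder spans from the template above.
snippet = '<span class="stat-value">{{ dag_nodes_count }}</span>'
print(Template(snippet).render(dag_nodes_count=5))
# <span class="stat-value">5</span>
```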
@@ -13,6 +13,14 @@
    --edge-label-text: '{{ GRAY }}';
    --edge-label-bg: rgba(255, 255, 255, 0.8);
    --edge-or-color: #000000;
    --edge-router-color: '{{ CREWAI_ORANGE }}';
    --node-border-start: #C94238;
    --node-border-listen: #3D3D3D;
    --node-bg-start: #FF7066;
    --node-bg-router: #FFFFFF;
    --node-bg-listen: #FFFFFF;
    --node-text-color: #FFFFFF;
    --nav-button-hover: #f5f5f5;
}

[data-theme="dark"] {
@@ -30,6 +38,14 @@
    --edge-label-text: #c9d1d9;
    --edge-label-bg: rgba(22, 27, 34, 0.9);
    --edge-or-color: #ffffff;
    --edge-router-color: '{{ CREWAI_ORANGE }}';
    --node-border-start: #FF5A50;
    --node-border-listen: #666666;
    --node-bg-start: #B33830;
    --node-bg-router: #3D3D3D;
    --node-bg-listen: #3D3D3D;
    --node-text-color: #FFFFFF;
    --nav-button-hover: #30363d;
}

@keyframes dash {
@@ -72,12 +88,10 @@ body {
    position: absolute;
    top: 20px;
    left: 20px;
    background: var(--bg-secondary);
    background: transparent;
    padding: 20px;
    border-radius: 8px;
    box-shadow: 0 4px 12px var(--shadow-strong);
    max-width: 320px;
    border: 1px solid var(--border-color);
    z-index: 10000;
    pointer-events: auto;
    transition: background 0.3s ease, border-color 0.3s ease, box-shadow 0.3s ease;
@@ -125,12 +139,16 @@ h3 {
    margin-right: 12px;
    border-radius: 3px;
    box-sizing: border-box;
    transition: background 0.3s ease, border-color 0.3s ease;
}
.legend-item span {
    color: var(--text-secondary);
    font-size: 13px;
    transition: color 0.3s ease;
}
.legend-item svg line {
    transition: stroke 0.3s ease;
}
.instructions {
    margin-top: 15px;
    padding-top: 15px;
@@ -155,7 +173,7 @@ h3 {
    bottom: 20px;
    right: auto;
    display: grid;
    grid-template-columns: repeat(4, 40px);
    grid-template-columns: repeat(3, 40px);
    gap: 8px;
    z-index: 10002;
    pointer-events: auto;
@@ -165,10 +183,187 @@ h3 {
.nav-controls.drawer-open {
}

#legend-panel {
    position: fixed;
    left: 164px;
    bottom: 20px;
    right: 20px;
    height: 92px;
    background: var(--bg-secondary);
    backdrop-filter: blur(12px) saturate(180%);
    -webkit-backdrop-filter: blur(12px) saturate(180%);
    border: 1px solid var(--border-subtle);
    border-radius: 6px;
    box-shadow: 0 2px 8px var(--shadow-color);
    display: grid;
    grid-template-columns: repeat(4, 1fr);
    align-items: center;
    gap: 0;
    padding: 0 24px;
    box-sizing: border-box;
    z-index: 10001;
    pointer-events: auto;
    transition: background 0.3s ease, border-color 0.3s ease, box-shadow 0.3s ease, right 0.3s cubic-bezier(0.4, 0, 0.2, 1);
}

#legend-panel.drawer-open {
    right: 405px;
}

.legend-section {
    display: flex;
    align-items: center;
    justify-content: center;
    min-width: 0;
    width: -webkit-fill-available;
    width: -moz-available;
    width: stretch;
    position: relative;
}

.legend-section:not(:last-child)::after {
    content: '';
    position: absolute;
    right: 0;
    top: 50%;
    transform: translateY(-50%);
    width: 1px;
    height: 48px;
    background: var(--border-color);
    transition: background 0.3s ease;
}

.legend-stats-row {
    display: flex;
    gap: 32px;
    justify-content: center;
    align-items: center;
    min-width: 0;
}

.legend-stat-item {
    display: flex;
    flex-direction: column;
    align-items: center;
    gap: 4px;
}

.stat-value {
    font-size: 19px;
    font-weight: 700;
    color: var(--text-primary);
    line-height: 1;
    transition: color 0.3s ease;
}

.stat-label {
    font-size: 8px;
    font-weight: 500;
    text-transform: uppercase;
    color: var(--text-secondary);
    letter-spacing: 0.5px;
    transition: color 0.3s ease;
}

.legend-items-row {
    display: flex;
    gap: 16px;
    align-items: center;
    justify-content: center;
    min-width: 0;
}

.legend-group {
    display: flex;
    gap: 16px;
    align-items: center;
}

.legend-item-compact {
    display: flex;
    align-items: center;
    gap: 6px;
}

.legend-item-compact span {
    font-size: 12px;
    font-weight: 500;
    text-transform: uppercase;
    color: var(--text-secondary);
    letter-spacing: 0.5px;
    white-space: nowrap;
    font-family: inherit;
    line-height: 1;
    transition: color 0.3s ease;
}

.legend-color-small {
    width: 17px;
    height: 17px;
    border-radius: 2px;
    box-sizing: border-box;
    flex-shrink: 0;
    transition: background 0.3s ease, border-color 0.3s ease;
}

.legend-item-compact svg {
    display: block;
    flex-shrink: 0;
    width: 29px;
    height: 14px;
}

.legend-item-compact svg line {
    transition: stroke 0.3s ease;
}

.legend-ide-column {
    display: flex;
    flex-direction: row;
    gap: 8px;
    align-items: center;
    justify-content: center;
    min-width: 0;
    width: 100%;
}

.legend-ide-label {
    font-size: 12px;
    font-weight: 500;
    text-transform: uppercase;
    color: var(--text-secondary);
    letter-spacing: 0.5px;
    transition: color 0.3s ease;
    white-space: nowrap;
}

.legend-ide-select {
    width: 120px;
    padding: 6px 10px;
    border-radius: 4px;
    border: 1px solid var(--border-subtle);
    background: var(--bg-primary);
    color: var(--text-primary);
    font-size: 11px;
    cursor: pointer;
    transition: all 0.3s ease;
}

.legend-ide-select:hover {
    border-color: var(--text-secondary);
}

.legend-ide-select:focus {
    outline: none;
    border-color: '{{ CREWAI_ORANGE }}';
}

.nav-button {
    width: 40px;
    height: 40px;
    background: var(--bg-secondary);
    backdrop-filter: blur(12px) saturate(180%);
    -webkit-backdrop-filter: blur(12px) saturate(180%);
    border: 1px solid var(--border-subtle);
    border-radius: 6px;
    display: flex;
@@ -181,12 +376,12 @@ h3 {
    user-select: none;
    pointer-events: auto;
    position: relative;
    z-index: 10001;
    z-index: 10002;
    transition: background 0.3s ease, border-color 0.3s ease, color 0.3s ease, box-shadow 0.3s ease;
}

.nav-button:hover {
    background: var(--border-subtle);
    background: var(--nav-button-hover);
}

#drawer {
@@ -198,9 +393,10 @@ h3 {
    background: var(--bg-drawer);
    box-shadow: -4px 0 12px var(--shadow-strong);
    transition: right 0.3s cubic-bezier(0.4, 0, 0.2, 1), background 0.3s ease, box-shadow 0.3s ease;
    z-index: 2000;
    overflow-y: auto;
    padding: 24px;
    z-index: 10003;
    overflow: hidden;
    transform: translateZ(0);
    isolation: isolate;
}

#drawer.open {
@@ -247,17 +443,22 @@ h3 {
    justify-content: space-between;
    align-items: center;
    margin-bottom: 20px;
    padding-bottom: 16px;
    padding: 24px 24px 16px 24px;
    border-bottom: 2px solid '{{ CREWAI_ORANGE }}';
    position: relative;
    z-index: 2001;
}

.drawer-title {
    font-size: 20px;
    font-size: 15px;
    font-weight: 700;
    color: var(--text-primary);
    transition: color 0.3s ease;
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
    flex: 1;
    min-width: 0;
}

.drawer-close {
@@ -269,12 +470,19 @@ h3 {
    padding: 4px 8px;
    line-height: 1;
    transition: color 0.3s ease;
    display: flex;
    align-items: center;
    justify-content: center;
}

.drawer-close:hover {
    color: '{{ CREWAI_ORANGE }}';
}

.drawer-close i {
    display: block;
}

.drawer-open-ide {
    background: '{{ CREWAI_ORANGE }}';
    border: none;
@@ -292,6 +500,9 @@ h3 {
    position: relative;
    z-index: 9999;
    pointer-events: auto;
    white-space: nowrap;
    flex-shrink: 0;
    min-width: fit-content;
}

.drawer-open-ide:hover {
@@ -305,14 +516,19 @@ h3 {
    box-shadow: 0 1px 4px rgba(255, 90, 80, 0.2);
}

.drawer-open-ide svg {
.drawer-open-ide svg,
.drawer-open-ide i {
    width: 14px;
    height: 14px;
    display: block;
}

.drawer-content {
    color: '{{ DARK_GRAY }}';
    line-height: 1.6;
    padding: 0 24px 24px 24px;
    overflow-y: auto;
    height: calc(100vh - 95px);
}

.drawer-section {
@@ -328,6 +544,10 @@ h3 {
    position: relative;
}

.drawer-metadata-grid:has(.drawer-section:nth-child(3):nth-last-child(1)) {
    grid-template-columns: 1fr 2fr;
}

.drawer-metadata-grid::before {
    content: '';
    position: absolute;
@@ -419,20 +639,35 @@ h3 {
    grid-column: 2;
    display: flex;
    flex-direction: column;
    justify-content: center;
    justify-content: flex-start;
    align-items: flex-start;
}

.drawer-metadata-grid:has(.drawer-section:nth-child(3):nth-last-child(1))::after {
    right: 50%;
    right: 66.666%;
}

.drawer-metadata-grid:has(.drawer-section:nth-child(3):nth-last-child(1))::before {
    left: 33.333%;
}

.drawer-metadata-grid .drawer-section:nth-child(3):nth-last-child(1) .drawer-section-title {
    align-self: flex-start;
}

.drawer-metadata-grid .drawer-section:nth-child(3):nth-last-child(1) > *:not(.drawer-section-title) {
    width: 100%;
    align-self: stretch;
}

.drawer-section-title {
    font-size: 12px;
    text-transform: uppercase;
    color: '{{ GRAY }}';
    color: var(--text-secondary);
    letter-spacing: 0.5px;
    margin-bottom: 8px;
    font-weight: 600;
    transition: color 0.3s ease;
}

.drawer-badge {
@@ -465,9 +700,44 @@ h3 {
    padding: 3px 0;
}

.drawer-metadata-grid .drawer-section .drawer-list {
    display: flex;
    flex-direction: column;
    gap: 6px;
}

.drawer-metadata-grid .drawer-section .drawer-list li {
    border-bottom: none;
    padding: 0;
}

.drawer-metadata-grid .drawer-section:nth-child(3) .drawer-list li {
    border-bottom: none;
    padding: 3px 0;
    padding: 0;
}

.drawer-metadata-grid .drawer-section {
    overflow: visible;
}

.drawer-metadata-grid .drawer-section .condition-group,
.drawer-metadata-grid .drawer-section .trigger-group {
    width: 100%;
    box-sizing: border-box;
}

.drawer-metadata-grid .drawer-section .condition-children {
    width: 100%;
}

.drawer-metadata-grid .drawer-section .trigger-group-items {
    width: 100%;
}

.drawer-metadata-grid .drawer-section .drawer-code-link {
    word-break: break-word;
    overflow-wrap: break-word;
    max-width: 100%;
}

.drawer-code {
@@ -491,6 +761,7 @@ h3 {
    cursor: pointer;
    transition: all 0.2s;
    display: inline-block;
    margin: 3px 2px;
}

.drawer-code-link:hover {
@@ -3,12 +3,13 @@

from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
import inspect
from typing import TYPE_CHECKING, Any

from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.flow_wrappers import FlowCondition
from crewai.flow.types import FlowMethodName
from crewai.flow.types import FlowMethodName, FlowRouteName
from crewai.flow.utils import (
    is_flow_condition_dict,
    is_simple_flow_condition,
@@ -197,8 +198,6 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
            node_metadata["type"] = "router"
            router_methods.append(method_name)

            node_metadata["condition_type"] = "IF"

            if method_name in flow._router_paths:
                node_metadata["router_paths"] = [
                    str(p) for p in flow._router_paths[method_name]
@@ -210,9 +209,13 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
                ]

        if hasattr(method, "__condition_type__") and method.__condition_type__:
            node_metadata["trigger_condition_type"] = method.__condition_type__
            if "condition_type" not in node_metadata:
                node_metadata["condition_type"] = method.__condition_type__

        if node_metadata.get("is_router") and "condition_type" not in node_metadata:
            node_metadata["condition_type"] = "IF"

        if (
            hasattr(method, "__trigger_condition__")
            and method.__trigger_condition__ is not None
@@ -298,6 +301,9 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
        nodes[method_name] = node_metadata

    for listener_name, condition_data in flow._listeners.items():
        if listener_name in router_methods:
            continue

        if is_simple_flow_condition(condition_data):
            cond_type, methods = condition_data
            edges.extend(
@@ -315,6 +321,60 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
                _create_edges_from_condition(condition_data, str(listener_name), nodes)
            )

    for method_name, node_metadata in nodes.items():  # type: ignore[assignment]
        if node_metadata.get("is_router") and "trigger_methods" in node_metadata:
            trigger_methods = node_metadata["trigger_methods"]
            condition_type = node_metadata.get("trigger_condition_type", OR_CONDITION)

            if "trigger_condition" in node_metadata:
                edges.extend(
                    _create_edges_from_condition(
                        node_metadata["trigger_condition"],  # type: ignore[arg-type]
                        method_name,
                        nodes,
                    )
                )
            else:
                edges.extend(
                    StructureEdge(
                        source=trigger_method,
                        target=method_name,
                        condition_type=condition_type,
                        is_router_path=False,
                    )
                    for trigger_method in trigger_methods
                    if trigger_method in nodes
                )

    for router_method_name in router_methods:
        if router_method_name not in flow._router_paths:
            flow._router_paths[FlowMethodName(router_method_name)] = []

        inferred_paths: Iterable[FlowMethodName | FlowRouteName] = set(
            flow._router_paths.get(FlowMethodName(router_method_name), [])
        )

        for condition_data in flow._listeners.values():
            trigger_strings: list[str] = []

            if is_simple_flow_condition(condition_data):
                _, methods = condition_data
                trigger_strings = [str(m) for m in methods]
            elif is_flow_condition_dict(condition_data):
                trigger_strings = _extract_direct_or_triggers(condition_data)

            for trigger_str in trigger_strings:
                if trigger_str not in nodes:
                    # This is likely a router path output
                    inferred_paths.add(trigger_str)  # type: ignore[attr-defined]

        if inferred_paths:
            flow._router_paths[FlowMethodName(router_method_name)] = list(
                inferred_paths  # type: ignore[arg-type]
            )
            if router_method_name in nodes:
                nodes[router_method_name]["router_paths"] = list(inferred_paths)

    for router_method_name in router_methods:
        if router_method_name not in flow._router_paths:
            continue
@@ -340,6 +400,7 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
                    target=str(listener_name),
                    condition_type=None,
                    is_router_path=True,
                    router_path_label=str(path),
                )
            )
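A simplified, hypothetical restatement of the inference step above: any listener trigger string that does not match a known node name is treated as a likely router path output.

```python
def infer_router_paths(trigger_strings: list[str], node_names: set[str]) -> set[str]:
    """Collect trigger strings that are not nodes; these are likely router paths."""
    return {t for t in trigger_strings if t not in node_names}

print(infer_router_paths(["analyze", "success", "failure"], {"analyze", "route"}))
# {'success', 'failure'}
```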
@@ -20,7 +20,7 @@ class CSSExtension(Extension):
    Provides {% css 'path/to/file.css' %} tag syntax.
    """

    tags: ClassVar[set[str]] = {"css"}  # type: ignore[assignment]
    tags: ClassVar[set[str]] = {"css"}  # type: ignore[misc]

    def parse(self, parser: Parser) -> nodes.Node:
        """Parse {% css 'styles.css' %} tag.
@@ -53,7 +53,7 @@ class JSExtension(Extension):
    Provides {% js 'path/to/file.js' %} tag syntax.
    """

    tags: ClassVar[set[str]] = {"js"}  # type: ignore[assignment]
    tags: ClassVar[set[str]] = {"js"}  # type: ignore[misc]

    def parse(self, parser: Parser) -> nodes.Node:
        """Parse {% js 'script.js' %} tag.
@@ -91,6 +91,116 @@ TEXT_PRIMARY = "#e6edf3"
TEXT_SECONDARY = "#7d8590"


def calculate_node_positions(
    dag: FlowStructure,
) -> dict[str, dict[str, int | float]]:
    """Calculate hierarchical positions (level, x, y) for each node.

    Args:
        dag: FlowStructure containing nodes and edges.

    Returns:
        Dictionary mapping node names to their position data (level, x, y).
    """
    children: dict[str, list[str]] = {name: [] for name in dag["nodes"]}
    parents: dict[str, list[str]] = {name: [] for name in dag["nodes"]}

    for edge in dag["edges"]:
        source = edge["source"]
        target = edge["target"]
        if source in children and target in children:
            children[source].append(target)
            parents[target].append(source)

    levels: dict[str, int] = {}
    queue: list[tuple[str, int]] = []

    for start_method in dag["start_methods"]:
        if start_method in dag["nodes"]:
            levels[start_method] = 0
            queue.append((start_method, 0))

    visited: set[str] = set()
    while queue:
        node, level = queue.pop(0)
        if node in visited:
            continue
        visited.add(node)

        if node not in levels or levels[node] < level:
            levels[node] = level

        for child in children.get(node, []):
            if child not in visited:
                child_level = level + 1
                if child not in levels or levels[child] < child_level:
                    levels[child] = child_level
                queue.append((child, child_level))

    for name in dag["nodes"]:
        if name not in levels:
            levels[name] = 0

    nodes_by_level: dict[int, list[str]] = {}
    for node, level in levels.items():
        if level not in nodes_by_level:
            nodes_by_level[level] = []
        nodes_by_level[level].append(node)

    positions: dict[str, dict[str, int | float]] = {}
    level_separation = 300  # Vertical spacing between levels
    node_spacing = 400  # Horizontal spacing between nodes

    parent_count: dict[str, int] = {}
    for node, parent_list in parents.items():
        parent_count[node] = len(parent_list)

    for level, nodes_at_level in sorted(nodes_by_level.items()):
        y = level * level_separation

        if level == 0:
            num_nodes = len(nodes_at_level)
            for i, node in enumerate(nodes_at_level):
                x = (i - (num_nodes - 1) / 2) * node_spacing
                positions[node] = {"level": level, "x": x, "y": y}
        else:
            for i, node in enumerate(nodes_at_level):
                parent_list = parents.get(node, [])
                parent_positions: list[float] = [
                    positions[parent]["x"]
                    for parent in parent_list
                    if parent in positions
                ]

                if parent_positions:
                    if len(parent_positions) > 1 and len(set(parent_positions)) == 1:
                        base_x = parent_positions[0]
                        avg_x = base_x + node_spacing * 0.4
                    else:
                        avg_x = sum(parent_positions) / len(parent_positions)
                else:
                    avg_x = i * node_spacing * 0.5

                positions[node] = {"level": level, "x": avg_x, "y": y}

            nodes_at_level_sorted = sorted(
                nodes_at_level, key=lambda n: positions[n]["x"]
            )
            min_spacing = node_spacing * 0.6  # Minimum horizontal distance

            for i in range(len(nodes_at_level_sorted) - 1):
                current_node = nodes_at_level_sorted[i]
                next_node = nodes_at_level_sorted[i + 1]

                current_x = positions[current_node]["x"]
                next_x = positions[next_node]["x"]

                if next_x - current_x < min_spacing:
                    positions[next_node]["x"] = current_x + min_spacing

    return positions

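A minimal usage sketch for `calculate_node_positions`; the `dag` dict below is a hand-built stand-in for a real `FlowStructure`:

```python
dag = {
    "nodes": {"begin": {}, "left": {}, "right": {}},
    "edges": [
        {"source": "begin", "target": "left"},
        {"source": "begin", "target": "right"},
    ],
    "start_methods": ["begin"],
}
for name, pos in calculate_node_positions(dag).items():
    print(name, pos)  # begin at level 0; left/right on level 1, spaced apart
```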
def render_interactive(
    dag: FlowStructure,
    filename: str = "flow_dag.html",
@@ -110,6 +220,8 @@ def render_interactive(
    Returns:
        Absolute path to generated HTML file in temporary directory.
    """
    node_positions = calculate_node_positions(dag)

    nodes_list: list[dict[str, Any]] = []
    for name, metadata in dag["nodes"].items():
        node_type: str = metadata.get("type", "listen")
@@ -120,37 +232,37 @@ def render_interactive(

        if node_type == "start":
            color_config = {
                "background": CREWAI_ORANGE,
                "border": CREWAI_ORANGE,
                "background": "var(--node-bg-start)",
                "border": "var(--node-border-start)",
                "highlight": {
                    "background": CREWAI_ORANGE,
                    "border": CREWAI_ORANGE,
                    "background": "var(--node-bg-start)",
                    "border": "var(--node-border-start)",
                },
            }
            font_color = WHITE
            border_width = 2
            font_color = "var(--node-text-color)"
            border_width = 3
        elif node_type == "router":
            color_config = {
                "background": DARK_GRAY,
                "background": "var(--node-bg-router)",
                "border": CREWAI_ORANGE,
                "highlight": {
                    "background": DARK_GRAY,
                    "background": "var(--node-bg-router)",
                    "border": CREWAI_ORANGE,
                },
            }
            font_color = WHITE
            font_color = "var(--node-text-color)"
            border_width = 3
        else:
            color_config = {
                "background": DARK_GRAY,
                "border": DARK_GRAY,
                "background": "var(--node-bg-listen)",
                "border": "var(--node-border-listen)",
                "highlight": {
                    "background": DARK_GRAY,
                    "border": DARK_GRAY,
                    "background": "var(--node-bg-listen)",
                    "border": "var(--node-border-listen)",
                },
            }
            font_color = WHITE
            border_width = 2
            font_color = "var(--node-text-color)"
            border_width = 3

        title_parts: list[str] = []

@@ -215,25 +327,34 @@ def render_interactive(
        bg_color = color_config["background"]
        border_color = color_config["border"]

        nodes_list.append(
            {
                "id": name,
                "label": name,
                "title": "".join(title_parts),
                "shape": "custom",
                "size": 30,
                "nodeStyle": {
                    "name": name,
                    "bgColor": bg_color,
                    "borderColor": border_color,
                    "borderWidth": border_width,
                    "fontColor": font_color,
                },
                "opacity": 1.0,
                "glowSize": 0,
                "glowColor": None,
            }
        )
        position_data = node_positions.get(name, {"level": 0, "x": 0, "y": 0})

        node_data: dict[str, Any] = {
            "id": name,
            "label": name,
            "title": "".join(title_parts),
            "shape": "custom",
            "size": 30,
            "level": position_data["level"],
            "nodeStyle": {
                "name": name,
                "bgColor": bg_color,
                "borderColor": border_color,
                "borderWidth": border_width,
                "fontColor": font_color,
            },
            "opacity": 1.0,
            "glowSize": 0,
            "glowColor": None,
        }

        # Add x,y only for graphs with 3-4 nodes
        total_nodes = len(dag["nodes"])
        if 3 <= total_nodes <= 4:
            node_data["x"] = position_data["x"]
            node_data["y"] = position_data["y"]

        nodes_list.append(node_data)

    execution_paths: int = calculate_execution_paths(dag)

@@ -246,6 +367,8 @@ def render_interactive(
        if edge["is_router_path"]:
            edge_color = CREWAI_ORANGE
            edge_dashes = [15, 10]
            if "router_path_label" in edge:
                edge_label = edge["router_path_label"]
        elif edge["condition_type"] == "AND":
            edge_label = "AND"
            edge_color = CREWAI_ORANGE
@@ -10,6 +10,7 @@ class NodeMetadata(TypedDict, total=False):
    is_router: bool
    router_paths: list[str]
    condition_type: str | None
    trigger_condition_type: str | None
    trigger_methods: list[str]
    trigger_condition: dict[str, Any] | None
    method_signature: dict[str, Any]
@@ -22,13 +23,14 @@ class NodeMetadata(TypedDict, total=False):
    class_line_number: int


class StructureEdge(TypedDict):
class StructureEdge(TypedDict, total=False):
    """Represents a connection in the flow structure."""

    source: str
    target: str
    condition_type: str | None
    is_router_path: bool
    router_path_label: str


class FlowStructure(TypedDict):
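With `total=False`, every key of `StructureEdge` becomes optional, so plain edges and router-path edges can share one TypedDict and `router_path_label` is only set on router paths. A minimal sketch:

```python
from typing import TypedDict

class StructureEdge(TypedDict, total=False):
    source: str
    target: str
    condition_type: str | None
    is_router_path: bool
    router_path_label: str

# An ordinary edge omits router_path_label entirely.
plain: StructureEdge = {"source": "a", "target": "b", "condition_type": "OR", "is_router_path": False}
# A router path carries the label of the route it represents.
routed: StructureEdge = {"source": "route", "target": "done", "is_router_path": True, "router_path_label": "success"}
```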
@@ -20,8 +20,7 @@ from typing import (
)

from dotenv import load_dotenv
import httpx
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

from crewai.events.event_bus import crewai_event_bus
@@ -54,7 +53,6 @@ if TYPE_CHECKING:
    from litellm.utils import supports_response_schema

    from crewai.agent.core import Agent
    from crewai.llms.hooks.base import BaseInterceptor
    from crewai.task import Task
    from crewai.tools.base_tool import BaseTool
    from crewai.utilities.types import LLMMessage
@@ -320,7 +318,138 @@ class AccumulatedToolArgs(BaseModel):


class LLM(BaseLLM):
    completion_cost: float | None = None
    completion_cost: float | None = Field(
        default=None, description="The completion cost of the LLM."
    )
    top_p: float | None = Field(
        default=None, description="Sampling probability threshold."
    )
    n: int | None = Field(
        default=None, description="Number of completions to generate."
    )
    max_completion_tokens: int | None = Field(
        default=None,
        description="Maximum number of tokens to generate in the completion.",
    )
    max_tokens: int | None = Field(
        default=None,
        description="Maximum number of tokens allowed in the prompt + completion.",
    )
    presence_penalty: float | None = Field(
        default=None, description="Penalizes new tokens based on their presence in the text so far."
    )
    frequency_penalty: float | None = Field(
        default=None, description="Penalizes new tokens based on their frequency in the text so far."
    )
    logit_bias: dict[int, float] | None = Field(
        default=None,
        description="Modifies the likelihood of specified tokens appearing in the completion.",
    )
    response_format: type[BaseModel] | None = Field(
        default=None,
        description="Pydantic model class for structured response parsing.",
    )
    seed: int | None = Field(
        default=None,
        description="Random seed for reproducibility.",
    )
    logprobs: int | None = Field(
        default=None,
        description="Number of log probabilities to return.",
    )
    top_logprobs: int | None = Field(
        default=None,
        description="Number of top logprobs to return.",
    )
    api_base: str | None = Field(
        default=None,
        description="Base URL for the API endpoint.",
    )
    api_version: str | None = Field(
        default=None,
        description="API version to use.",
    )
    callbacks: list[Any] = Field(
        default_factory=list,
        description="List of callback handlers for LLM events.",
    )
    reasoning_effort: Literal["none", "low", "medium", "high"] | None = Field(
        default=None,
        description="Level of reasoning effort for the LLM.",
    )
    context_window_size: int = Field(
        default=0,
        description="The context window size of the LLM.",
    )
    is_anthropic: bool = Field(
        default=False,
        description="Indicates if the model is from Anthropic provider.",
    )
    supports_function_calling: bool = Field(
        default=False,
        description="Indicates if the model supports function calling.",
    )
    supports_stop_words: bool = Field(
        default=False,
        description="Indicates if the model supports stop words.",
    )

    @model_validator(mode="after")
    def initialize_client(self) -> Self:
        self.is_anthropic = any(
            prefix in self.model.lower() for prefix in ANTHROPIC_PREFIXES
        )
        try:
            provider = self._get_custom_llm_provider()
            self.supports_function_calling = litellm.utils.supports_function_calling(
                self.model, custom_llm_provider=provider
            )
        except Exception as e:
            logging.error(f"Failed to check function calling support: {e!s}")
            self.supports_function_calling = False
        try:
            params = get_supported_openai_params(model=self.model)
            self.supports_stop_words = params is not None and "stop" in params
        except Exception as e:
            logging.error(f"Failed to get supported params: {e!s}")
            self.supports_stop_words = False

        with suppress_warnings():
            callback_types = [type(callback) for callback in self.callbacks]
            for callback in litellm.success_callback[:]:
                if type(callback) in callback_types:
                    litellm.success_callback.remove(callback)

            for callback in litellm._async_success_callback[:]:
                if type(callback) in callback_types:
                    litellm._async_success_callback.remove(callback)

            litellm.callbacks = self.callbacks

        with suppress_warnings():
            success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
            success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
            if success_callbacks_str:
                success_callbacks = [
                    cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
                ]

            failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
            failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = []  # default when env var unset
            if failure_callbacks_str:
                failure_callbacks = [
                    cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
                ]

            litellm.success_callback = success_callbacks
            litellm.failure_callback = failure_callbacks
        return self

    # @computed_field
    # @property
    # def is_anthropic(self) -> bool:
    #     """Determine if the model is from Anthropic provider."""
    #     anthropic_prefixes = ("anthropic/", "claude-", "claude/")
    #     return any(prefix in self.model.lower() for prefix in anthropic_prefixes)

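The pattern used above, in isolation: a `model_validator(mode="after")` runs once all fields are set and may derive field values from them. This standalone sketch mirrors the shape of `initialize_client` without the litellm calls; the class and constant names here are illustrative:

```python
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")  # mirrors the diff

class ModelInfo(BaseModel):
    model: str
    is_anthropic: bool = Field(default=False)

    @model_validator(mode="after")
    def derive_flags(self) -> Self:
        # Derived fields are populated after normal field validation.
        self.is_anthropic = any(p in self.model.lower() for p in ANTHROPIC_PREFIXES)
        return self

print(ModelInfo(model="claude-3-opus").is_anthropic)  # True
```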
    def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
        """Factory method that routes to native SDK or falls back to LiteLLM."""
@@ -383,98 +512,6 @@ class LLM(BaseLLM):

        return None

    def __init__(
        self,
        model: str,
        timeout: float | int | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        n: int | None = None,
        stop: str | list[str] | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | float | None = None,
        presence_penalty: float | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[int, float] | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        logprobs: int | None = None,
        top_logprobs: int | None = None,
        base_url: str | None = None,
        api_base: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        callbacks: list[Any] | None = None,
        reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
        stream: bool = False,
        interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize LLM instance.

        Note: This __init__ method is only called for fallback instances.
        Native provider instances handle their own initialization in their respective classes.
        """
        super().__init__(
            model=model,
            temperature=temperature,
            api_key=api_key,
            base_url=base_url,
            timeout=timeout,
            **kwargs,
        )
        self.model = model
        self.timeout = timeout
        self.temperature = temperature
        self.top_p = top_p
        self.n = n
        self.max_completion_tokens = max_completion_tokens
        self.max_tokens = max_tokens
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.logit_bias = logit_bias
        self.response_format = response_format
        self.seed = seed
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.base_url = base_url
        self.api_base = api_base
        self.api_version = api_version
        self.api_key = api_key
        self.callbacks = callbacks
        self.context_window_size = 0
        self.reasoning_effort = reasoning_effort
        self.additional_params = kwargs
        self.is_anthropic = self._is_anthropic_model(model)
        self.stream = stream
        self.interceptor = interceptor

        litellm.drop_params = True

        # Normalize self.stop to always be a list[str]
        if stop is None:
            self.stop: list[str] = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = stop

        self.set_callbacks(callbacks or [])
        self.set_env_callbacks()

    @staticmethod
    def _is_anthropic_model(model: str) -> bool:
        """Determine if the model is from Anthropic provider.

        Args:
            model: The model identifier string.

        Returns:
            bool: True if the model is from Anthropic, False otherwise.
        """
        anthropic_prefixes = ("anthropic/", "claude-", "claude/")
        return any(prefix in model.lower() for prefix in anthropic_prefixes)

    def _prepare_completion_params(
        self,
        messages: str | list[LLMMessage],
@@ -1188,8 +1225,6 @@ class LLM(BaseLLM):
                message["role"] = msg_role
        # --- 5) Set up callbacks if provided
        with suppress_warnings():
            if callbacks and len(callbacks) > 0:
                self.set_callbacks(callbacks)
        try:
            # --- 6) Prepare parameters for the completion call
            params = self._prepare_completion_params(messages, tools)
@@ -1378,24 +1413,6 @@ class LLM(BaseLLM):
                "Please remove response_format or use a supported model."
            )

    def supports_function_calling(self) -> bool:
        try:
            provider = self._get_custom_llm_provider()
            return litellm.utils.supports_function_calling(
                self.model, custom_llm_provider=provider
            )
        except Exception as e:
            logging.error(f"Failed to check function calling support: {e!s}")
            return False

    def supports_stop_words(self) -> bool:
        try:
            params = get_supported_openai_params(model=self.model)
            return params is not None and "stop" in params
        except Exception as e:
            logging.error(f"Failed to get supported params: {e!s}")
            return False

    def get_context_window_size(self) -> int:
        """
        Returns the context window size, using 75% of the maximum to avoid
@@ -1425,60 +1442,6 @@ class LLM(BaseLLM):
        self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
        return self.context_window_size

    @staticmethod
    def set_callbacks(callbacks: list[Any]) -> None:
        """
        Attempt to keep a single set of callbacks in litellm by removing old
        duplicates and adding new ones.
        """
        with suppress_warnings():
            callback_types = [type(callback) for callback in callbacks]
            for callback in litellm.success_callback[:]:
                if type(callback) in callback_types:
                    litellm.success_callback.remove(callback)

            for callback in litellm._async_success_callback[:]:
                if type(callback) in callback_types:
                    litellm._async_success_callback.remove(callback)

            litellm.callbacks = callbacks

    @staticmethod
    def set_env_callbacks() -> None:
        """Sets the success and failure callbacks for the LiteLLM library from environment variables.

        This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
        environment variables, which should contain comma-separated lists of callback names.
        It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`,
        respectively.

        If the environment variables are not set or are empty, the corresponding callback lists
        will be set to empty lists.

        Examples:
            LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
            LITELLM_FAILURE_CALLBACKS="langfuse"

        This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
        `litellm.failure_callback` to ["langfuse"].
        """
        with suppress_warnings():
            success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
            success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
            if success_callbacks_str:
                success_callbacks = [
                    cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
                ]

            failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
            if failure_callbacks_str:
                failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
                    cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
                ]

            litellm.success_callback = success_callbacks
            litellm.failure_callback = failure_callbacks

    def __copy__(self) -> LLM:
        """Create a shallow copy of the LLM instance."""
        # Filter out parameters that are already explicitly passed to avoid conflicts
@@ -1539,7 +1502,7 @@ class LLM(BaseLLM):
            **filtered_params,
        )

    def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
    def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:  # type: ignore[override]
        """Create a deep copy of the LLM instance."""
        import copy

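A hedged sketch of the resulting API change (exact construction behavior depends on the provider routing in `__new__`): `supports_function_calling` and `supports_stop_words` are now read as fields populated during validation, rather than called as methods.

```python
from crewai import LLM

llm = LLM(model="gpt-4o")  # illustrative model name
# Before this change: llm.supports_function_calling()
# After: a plain bool field set by the model validator.
if llm.supports_function_calling:
    print("tool calling available")
```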
@@ -13,8 +13,9 @@ import logging
import re
from typing import TYPE_CHECKING, Any, Final

from pydantic import BaseModel
from pydantic import AliasChoices, BaseModel, Field, PrivateAttr, field_validator

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
    LLMCallCompletedEvent,
@@ -28,6 +29,7 @@ from crewai.events.types.tool_usage_events import (
    ToolUsageFinishedEvent,
    ToolUsageStartedEvent,
)
from crewai.llms.hooks import BaseInterceptor
from crewai.types.usage_metrics import UsageMetrics


@@ -43,7 +45,7 @@ DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)


class BaseLLM(ABC):
class BaseLLM(BaseModel, ABC):
    """Abstract base class for LLM implementations.

    This class defines the interface that all LLM implementations must follow.
@@ -55,70 +57,105 @@ class BaseLLM(ABC):
    implement proper validation for input parameters and provide clear error
    messages when things go wrong.

    Attributes:
        model: The model identifier/name.
        temperature: Optional temperature setting for response generation.
        stop: A list of stop sequences that the LLM should use to stop generation.
        additional_params: Additional provider-specific parameters.
    """

    is_litellm: bool = False
    provider: str | re.Pattern[str] = Field(
        default="openai", description="The provider of the LLM."
    )
    model: str = Field(description="The model identifier/name.")
    temperature: float | None = Field(
        default=None, ge=0, le=2, description="Temperature for response generation."
    )
    api_key: str | None = Field(default=None, description="API key for authentication.")
    base_url: str | None = Field(default=None, description="Base URL for API calls.")
    timeout: float | None = Field(default=None, description="Timeout for API calls.")
    max_retries: int = Field(
        default=2, description="Maximum number of retries for API requests."
    )
    max_tokens: int | None = Field(
        default=None, description="Maximum tokens for response generation."
    )
    stream: bool | None = Field(default=False, description="Stream the API requests.")
    client: Any = Field(description="Underlying LLM client instance.")
    interceptor: BaseInterceptor[Any, Any] | None = Field(
        default=None,
        description="An optional HTTPX interceptor for modifying requests/responses.",
    )
    client_params: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional parameters for the underlying LLM client.",
    )
    supports_stop_words: bool = Field(
        default=DEFAULT_SUPPORTS_STOP_WORDS,
        description="Whether or not to support stop words.",
    )
    stop_sequences: list[str] = Field(
        default_factory=list,
        validation_alias=AliasChoices("stop_sequences", "stop"),
        description="Stop sequences for generation (synchronized with stop).",
    )
    is_litellm: bool = Field(
        default=False, description="Whether this LLM implementation goes through LiteLLM."
    )
    additional_params: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional parameters for LLM calls.",
    )
    _token_usage: TokenProcess = PrivateAttr(default_factory=TokenProcess)

    def __init__(
        self,
        model: str,
        temperature: float | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        provider: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the BaseLLM with default attributes.
    @field_validator("provider", mode="before")
    @classmethod
    def extract_provider_from_model(
        cls, v: str | re.Pattern[str] | None, info: Any
    ) -> str | re.Pattern[str]:
        """Extract provider from model string if not explicitly provided.

        Args:
            model: The model identifier/name.
            temperature: Optional temperature setting for response generation.
            stop: Optional list of stop sequences for generation.
            **kwargs: Additional provider-specific parameters.
            v: Provided provider value (can be str, Pattern, or None)
            info: Validation info containing other field values

        Returns:
            Provider name (str) or Pattern
        """
        if not model:
            raise ValueError("Model name is required and cannot be empty")
        # If provider explicitly provided, validate and return it
        if v is not None:
            if not isinstance(v, (str, re.Pattern)):
                raise ValueError(f"Provider must be str or Pattern, got {type(v)}")
            return v

        self.model = model
        self.temperature = temperature
        self.api_key = api_key
        self.base_url = base_url
        # Store additional parameters for provider-specific use
        self.additional_params = kwargs
        self._provider = provider or "openai"
        model: str = info.data.get("model", "")
        if "/" in model:
            return model.partition("/")[0]
        return "openai"

        stop = kwargs.pop("stop", None)
        if stop is None:
            self.stop: list[str] = []
        elif isinstance(stop, str):
            self.stop = [stop]
        elif isinstance(stop, list):
            self.stop = stop
        else:
            self.stop = []
    @field_validator("stop_sequences", mode="before")
    @classmethod
    def normalize_stop_sequences(
        cls, v: str | list[str] | set[str] | None
    ) -> list[str]:
        """Validate and normalize stop sequences.

        self._token_usage = {
            "total_tokens": 0,
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "successful_requests": 0,
            "cached_prompt_tokens": 0,
        }
        Converts string to list and handles None values.
        AliasChoices handles accepting both 'stop' and 'stop_sequences' parameter names.
        """
        if v is None:
            return []
        if isinstance(v, str):
            return [v]
        if isinstance(v, set):
            return list(v)
        if isinstance(v, list):
            return v
        return []

    @property
    def provider(self) -> str:
        """Get the provider of the LLM."""
        return self._provider

    @provider.setter
    def provider(self, value: str) -> None:
        """Set the provider of the LLM."""
        self._provider = value
    def stop(self) -> list[str]:
        """Alias for stop_sequences to maintain backward compatibility."""
        return self.stop_sequences
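A minimal sketch of the `stop`/`stop_sequences` aliasing shown above; the model here is illustrative, but the `AliasChoices` and `field_validator` usage mirrors the diff:

```python
from pydantic import AliasChoices, BaseModel, Field, field_validator

class Sketch(BaseModel):
    stop_sequences: list[str] = Field(
        default_factory=list,
        validation_alias=AliasChoices("stop_sequences", "stop"),
    )

    @field_validator("stop_sequences", mode="before")
    @classmethod
    def _normalize(cls, v):
        # Accept None, a single string, or any iterable of strings.
        if v is None:
            return []
        return [v] if isinstance(v, str) else list(v)

print(Sketch(stop="END").stop_sequences)                 # ['END']
print(Sketch(stop_sequences=["a", "b"]).stop_sequences)  # ['a', 'b']
```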
@abstractmethod
|
||||
def call(
|
||||
@@ -171,14 +208,6 @@ class BaseLLM(ABC):
|
||||
"""
|
||||
return tools
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports stop words, False otherwise.
|
||||
"""
|
||||
return DEFAULT_SUPPORTS_STOP_WORDS
|
||||
|
||||
def _supports_stop_words_implementation(self) -> bool:
|
||||
"""Check if stop words are configured for this LLM instance.
|
||||
|
||||
@@ -506,7 +535,7 @@ class BaseLLM(ABC):
|
||||
"""
|
||||
if "/" in model:
|
||||
return model.partition("/")[0]
|
||||
return "openai" # Default provider
|
||||
return "openai"
|
||||
|
||||
def _track_token_usage_internal(self, usage_data: dict[str, Any]) -> None:
|
||||
"""Track token usage internally in the LLM instance.
|
||||
@@ -535,11 +564,11 @@ class BaseLLM(ABC):
|
||||
or 0
|
||||
)
|
||||
|
||||
self._token_usage["prompt_tokens"] += prompt_tokens
|
||||
self._token_usage["completion_tokens"] += completion_tokens
|
||||
self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
|
||||
self._token_usage["successful_requests"] += 1
|
||||
self._token_usage["cached_prompt_tokens"] += cached_tokens
|
||||
self._token_usage.prompt_tokens += prompt_tokens
|
||||
self._token_usage.completion_tokens += completion_tokens
|
||||
self._token_usage.total_tokens += prompt_tokens + completion_tokens
|
||||
self._token_usage.successful_requests += 1
|
||||
self._token_usage.cached_prompt_tokens += cached_tokens
|
||||
|
||||
def get_token_usage_summary(self) -> UsageMetrics:
|
||||
"""Get summary of token usage for this LLM instance.
|
||||
@@ -547,4 +576,10 @@ class BaseLLM(ABC):
|
||||
Returns:
|
||||
Dictionary with token usage totals
|
||||
"""
|
||||
return UsageMetrics(**self._token_usage)
|
||||
return UsageMetrics(
|
||||
prompt_tokens=self._token_usage.prompt_tokens,
|
||||
completion_tokens=self._token_usage.completion_tokens,
|
||||
total_tokens=self._token_usage.total_tokens,
|
||||
successful_requests=self._token_usage.successful_requests,
|
||||
cached_prompt_tokens=self._token_usage.cached_prompt_tokens,
|
||||
)
|
||||
|
||||
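The refactor above moves imperative `__init__` logic into Pydantic field validators. A minimal standalone sketch of the two normalization rules, mirroring the validator logic shown in the hunk (the helper functions are illustrative stand-ins, not the actual class):

```python
import re


def extract_provider(
    model: str, provider: str | re.Pattern[str] | None = None
) -> str | re.Pattern[str]:
    """Mirror of extract_provider_from_model: explicit value wins, else model prefix, else 'openai'."""
    if provider is not None:
        return provider
    if "/" in model:
        return model.partition("/")[0]
    return "openai"


def normalize_stop(v: str | list[str] | set[str] | None) -> list[str]:
    """Mirror of normalize_stop_sequences: None -> [], str -> [str], set/list -> list."""
    if v is None:
        return []
    if isinstance(v, str):
        return [v]
    return list(v) if isinstance(v, (set, list)) else []


assert extract_provider("anthropic/claude-3-5-sonnet") == "anthropic"
assert extract_provider("gpt-4o") == "openai"
assert normalize_stop("Observation:") == ["Observation:"]
```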
@@ -7,7 +7,14 @@ outbound and inbound messages at the transport level.
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from pydantic_core import core_schema


if TYPE_CHECKING:
    from pydantic import GetCoreSchemaHandler
    from pydantic_core import CoreSchema


T = TypeVar("T")
@@ -25,6 +32,7 @@ class BaseInterceptor(ABC, Generic[T, U]):
        U: Inbound message type (e.g., httpx.Response)

    Example:
        >>> import httpx
        >>> class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
        ...     def on_outbound(self, message: httpx.Request) -> httpx.Request:
        ...         message.headers["X-Custom-Header"] = "value"
@@ -80,3 +88,46 @@ class BaseInterceptor(ABC, Generic[T, U]):
            Modified message object.
        """
        raise NotImplementedError

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Generate Pydantic core schema for BaseInterceptor.

        This allows the generic BaseInterceptor to be used in Pydantic models
        without requiring arbitrary_types_allowed=True. The schema validates
        that the value is an instance of BaseInterceptor.

        Args:
            _source_type: The source type being validated (unused).
            _handler: Handler for generating schemas (unused).

        Returns:
            A Pydantic core schema that validates BaseInterceptor instances.
        """
        return core_schema.no_info_plain_validator_function(
            _validate_interceptor,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda x: x, return_schema=core_schema.any_schema()
            ),
        )


def _validate_interceptor(value: Any) -> BaseInterceptor[T, U]:
    """Validate that the value is a BaseInterceptor instance.

    Args:
        value: The value to validate.

    Returns:
        The validated BaseInterceptor instance.

    Raises:
        ValueError: If the value is not a BaseInterceptor instance.
    """
    if not isinstance(value, BaseInterceptor):
        raise ValueError(
            f"Expected BaseInterceptor instance, got {type(value).__name__}"
        )
    return value
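For orientation, a hedged sketch of what the schema hook above enables: a concrete interceptor declared as a plain Pydantic field. `HeaderInterceptor` and `ClientConfig` are hypothetical names, and the sketch assumes `on_outbound`/`on_inbound` are the two hooks `BaseInterceptor` requires, as the surrounding diff suggests:

```python
import httpx
from pydantic import BaseModel

from crewai.llms.hooks.base import BaseInterceptor


class HeaderInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        message.headers["X-Trace-Id"] = "demo"  # tag every outgoing request
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        return message  # pass responses through unchanged


class ClientConfig(BaseModel):
    # Accepted without arbitrary_types_allowed=True, thanks to
    # __get_pydantic_core_schema__ on BaseInterceptor.
    interceptor: BaseInterceptor[httpx.Request, httpx.Response]


config = ClientConfig(interceptor=HeaderInterceptor())  # validated via isinstance check
```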
@@ -6,16 +6,52 @@ to enable request/response modification at the transport level.

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypedDict

import httpx
from httpx import (
    AsyncHTTPTransport as _AsyncHTTPTransport,
    HTTPTransport as _HTTPTransport,
)
from typing_extensions import NotRequired, Unpack


if TYPE_CHECKING:
    from ssl import SSLContext

    from httpx import Limits, Request, Response
    from httpx._types import CertTypes, ProxyTypes

    from crewai.llms.hooks.base import BaseInterceptor


class HTTPTransport(httpx.HTTPTransport):
class HTTPTransportKwargs(TypedDict, total=False):
    """Typed dictionary for httpx.HTTPTransport initialization parameters.

    These parameters configure the underlying HTTP transport behavior including
    SSL verification, proxies, connection limits, and low-level socket options.
    """

    verify: bool | str | SSLContext
    cert: NotRequired[CertTypes]
    trust_env: bool
    http1: bool
    http2: bool
    limits: Limits
    proxy: NotRequired[ProxyTypes]
    uds: NotRequired[str]
    local_address: NotRequired[str]
    retries: int
    socket_options: NotRequired[
        Iterable[
            tuple[int, int, int]
            | tuple[int, int, bytes | bytearray]
            | tuple[int, int, None, int]
        ]
    ]


class HTTPTransport(_HTTPTransport):
    """HTTP transport that uses an interceptor for request/response modification.

    This transport is used internally when a user provides a BaseInterceptor.
@@ -25,19 +61,19 @@ class HTTPTransport(httpx.HTTPTransport):

    def __init__(
        self,
        interceptor: BaseInterceptor[httpx.Request, httpx.Response],
        **kwargs: Any,
        interceptor: BaseInterceptor[Request, Response],
        **kwargs: Unpack[HTTPTransportKwargs],
    ) -> None:
        """Initialize transport with interceptor.

        Args:
            interceptor: HTTP interceptor for modifying raw request/response objects.
            **kwargs: Additional arguments passed to httpx.HTTPTransport.
            **kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
        """
        super().__init__(**kwargs)
        self.interceptor = interceptor

    def handle_request(self, request: httpx.Request) -> httpx.Response:
    def handle_request(self, request: Request) -> Response:
        """Handle request with interception.

        Args:
@@ -51,7 +87,7 @@ class HTTPTransport(httpx.HTTPTransport):
        return self.interceptor.on_inbound(response)


class AsyncHTTPransport(httpx.AsyncHTTPTransport):
class AsyncHTTPTransport(_AsyncHTTPTransport):
    """Async HTTP transport that uses an interceptor for request/response modification.

    This transport is used internally when a user provides a BaseInterceptor.
@@ -61,19 +97,19 @@ class AsyncHTTPransport(httpx.AsyncHTTPTransport):

    def __init__(
        self,
        interceptor: BaseInterceptor[httpx.Request, httpx.Response],
        **kwargs: Any,
        interceptor: BaseInterceptor[Request, Response],
        **kwargs: Unpack[HTTPTransportKwargs],
    ) -> None:
        """Initialize async transport with interceptor.

        Args:
            interceptor: HTTP interceptor for modifying raw request/response objects.
            **kwargs: Additional arguments passed to httpx.AsyncHTTPTransport.
            **kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
        """
        super().__init__(**kwargs)
        self.interceptor = interceptor

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
    async def handle_async_request(self, request: Request) -> Response:
        """Handle async request with interception.

        Args:
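A short sketch of how these transports are meant to be wired into an httpx client. This mirrors the provider code later in the diff; the interceptor instance is the hypothetical one from the previous sketch:

```python
import httpx

from crewai.llms.hooks.transport import HTTPTransport

transport = HTTPTransport(
    interceptor=HeaderInterceptor(),  # hypothetical interceptor from the sketch above
    retries=2,                        # one of the typed HTTPTransportKwargs entries
)
client = httpx.Client(transport=transport)
# Every request/response now flows through on_outbound/on_inbound at the transport level.
```

The `Unpack[HTTPTransportKwargs]` signature means type checkers now reject misspelled or mistyped transport options that the old `**kwargs: Any` silently accepted.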
@@ -5,11 +5,14 @@ import logging
import os
from typing import TYPE_CHECKING, Any, cast

from pydantic import BaseModel
from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
from typing_extensions import Self

from crewai.events.types.llm_events import LLMCallType
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.llms.providers.utils.common import safe_tool_conversion
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
@@ -18,7 +21,8 @@ from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
    from crewai.llms.hooks.base import BaseInterceptor
    from crewai.agent import Agent
    from crewai.task import Task

try:
    from anthropic import Anthropic
@@ -31,6 +35,19 @@ except ImportError:
    ) from None


ANTHROPIC_CONTEXT_WINDOWS: dict[str, int] = {
    "claude-3-5-sonnet": 200000,
    "claude-3-5-haiku": 200000,
    "claude-3-opus": 200000,
    "claude-3-sonnet": 200000,
    "claude-3-haiku": 200000,
    "claude-3-7-sonnet": 200000,
    "claude-2.1": 200000,
    "claude-2": 100000,
    "claude-instant": 100000,
}


class AnthropicCompletion(BaseLLM):
    """Anthropic native completion implementation.

@@ -38,86 +55,69 @@ class AnthropicCompletion(BaseLLM):
    offering native tool use, streaming support, and proper message formatting.
    """

    def __init__(
        self,
        model: str = "claude-3-5-sonnet-20241022",
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float | None = None,
        max_retries: int = 2,
        temperature: float | None = None,
        max_tokens: int = 4096,  # Required for Anthropic
        top_p: float | None = None,
        stop_sequences: list[str] | None = None,
        stream: bool = False,
        client_params: dict[str, Any] | None = None,
        interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
        **kwargs: Any,
    ):
        """Initialize Anthropic chat completion client.
    model: str = Field(
        default="claude-3-5-sonnet-20241022",
        description="Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')",
    )
    max_tokens: int = Field(
        default=4096,
        description="Maximum number of allowed tokens in response.",
    )
    top_p: float | None = Field(
        default=None,
        description="Nucleus sampling parameter.",
    )
    _client: Anthropic = PrivateAttr(
        default_factory=Anthropic,
    )

        Args:
            model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
            api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
            base_url: Custom base URL for Anthropic API
            timeout: Request timeout in seconds
            max_retries: Maximum number of retries
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens in response (required for Anthropic)
            top_p: Nucleus sampling parameter
            stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
            stream: Enable streaming responses
            client_params: Additional parameters for the Anthropic client
            interceptor: HTTP interceptor for modifying requests/responses at transport level.
            **kwargs: Additional parameters
    @model_validator(mode="after")
    def initialize_client(self) -> Self:
        """Initialize the Anthropic client after Pydantic validation.

        This runs after all field validation is complete, ensuring that:
        - All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
        - Field validators have run (stop_sequences is normalized to set[str])
        - API key and other configuration is ready
        """
        super().__init__(
            model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
        )

        # Client params
        self.interceptor = interceptor
        self.client_params = client_params
        self.base_url = base_url
        self.timeout = timeout
        self.max_retries = max_retries

        self.client = Anthropic(**self._get_client_params())

        # Store completion parameters
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.stream = stream
        self.stop_sequences = stop_sequences or []

        # Model-specific settings
        self.is_claude_3 = "claude-3" in model.lower()
        self.supports_tools = self.is_claude_3  # Claude 3+ supports tool use

    def _get_client_params(self) -> dict[str, Any]:
        """Get client parameters."""

        if self.api_key is None:
            self.api_key = os.getenv("ANTHROPIC_API_KEY")
        if self.api_key is None:
            raise ValueError("ANTHROPIC_API_KEY is required")

        client_params = {
            "api_key": self.api_key,
            "base_url": self.base_url,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
        }
        params = self.model_dump(
            include={"api_key", "base_url", "timeout", "max_retries"},
            exclude_none=True,
        )

        if self.interceptor:
            transport = HTTPTransport(interceptor=self.interceptor)
            http_client = httpx.Client(transport=transport)
            client_params["http_client"] = http_client  # type: ignore[assignment]
            params["http_client"] = http_client

        if self.client_params:
            client_params.update(self.client_params)
            params.update(self.client_params)

        return client_params
        self._client = Anthropic(**params)
        return self

    @computed_field  # type: ignore[prop-decorator]
    @property
    def is_claude_3(self) -> bool:
        """Check if the model is Claude 3 or higher."""
        return "claude-3" in self.model.lower()

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_tools(self) -> bool:
        """Check if the model supports tool use."""
        return self.is_claude_3

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        return self.supports_tools

    def call(
        self,
@@ -125,8 +125,8 @@ class AnthropicCompletion(BaseLLM):
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Call Anthropic messages API.
@@ -205,25 +205,21 @@ class AnthropicCompletion(BaseLLM):
        Returns:
            Parameters dictionary for Anthropic API
        """
        params = {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "stream": self.stream,
        }

        params = self.model_dump(
            include={
                "model",
                "max_tokens",
                "stream",
                "temperature",
                "top_p",
                "stop_sequences",
            },
        )
        params["messages"] = messages
        # Add system message if present
        if system_message:
            params["system"] = system_message

        # Add optional parameters if set
        if self.temperature is not None:
            params["temperature"] = self.temperature
        if self.top_p is not None:
            params["top_p"] = self.top_p
        if self.stop_sequences:
            params["stop_sequences"] = self.stop_sequences

        # Handle tools for Claude 3+
        if tools and self.supports_tools:
            params["tools"] = self._convert_tools_for_interference(tools)
@@ -242,8 +238,6 @@ class AnthropicCompletion(BaseLLM):
                continue

            try:
                from crewai.llms.providers.utils.common import safe_tool_conversion

                name, description, parameters = safe_tool_conversion(tool, "Anthropic")
            except (ImportError, KeyError, ValueError) as e:
                logging.error(f"Error converting tool to Anthropic format: {e}")
@@ -317,8 +311,8 @@ class AnthropicCompletion(BaseLLM):
        self,
        params: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming message completion."""
@@ -333,7 +327,7 @@ class AnthropicCompletion(BaseLLM):
            params["tool_choice"] = {"type": "tool", "name": "structured_output"}

        try:
            response: Message = self.client.messages.create(**params)
            response: Message = self._client.messages.create(**params)

        except Exception as e:
            if is_context_length_exceeded(e):
@@ -405,8 +399,8 @@ class AnthropicCompletion(BaseLLM):
        self,
        params: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str:
        """Handle streaming message completion."""
@@ -427,7 +421,7 @@ class AnthropicCompletion(BaseLLM):
        stream_params = {k: v for k, v in params.items() if k != "stream"}

        # Make streaming API call
        with self.client.messages.stream(**stream_params) as stream:
        with self._client.messages.stream(**stream_params) as stream:
            for event in stream:
                if hasattr(event, "delta") and hasattr(event.delta, "text"):
                    text_delta = event.delta.text
@@ -501,8 +495,8 @@ class AnthropicCompletion(BaseLLM):
        tool_uses: list[ToolUseBlock],
        params: dict[str, Any],
        available_functions: dict[str, Any],
        from_task: Any | None = None,
        from_agent: Any | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
    ) -> str:
        """Handle the complete tool use conversation flow.

@@ -555,7 +549,7 @@ class AnthropicCompletion(BaseLLM):

        try:
            # Send tool results back to Claude for final response
            final_response: Message = self.client.messages.create(**follow_up_params)
            final_response: Message = self._client.messages.create(**follow_up_params)

            # Track token usage for follow-up call
            follow_up_usage = self._extract_anthropic_token_usage(final_response)
@@ -602,48 +596,24 @@ class AnthropicCompletion(BaseLLM):
                return tool_results[0]["content"]
            raise e

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        return self.supports_tools

    def supports_stop_words(self) -> bool:
        """Check if the model supports stop words."""
        return True  # All Claude models support stop sequences

    def get_context_window_size(self) -> int:
        """Get the context window size for the model."""
        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO

        # Context window sizes for Anthropic models
        context_windows = {
            "claude-3-5-sonnet": 200000,
            "claude-3-5-haiku": 200000,
            "claude-3-opus": 200000,
            "claude-3-sonnet": 200000,
            "claude-3-haiku": 200000,
            "claude-3-7-sonnet": 200000,
            "claude-2.1": 200000,
            "claude-2": 100000,
            "claude-instant": 100000,
        }

        # Find the best match for the model name
        for model_prefix, size in context_windows.items():
        for model_prefix, size in ANTHROPIC_CONTEXT_WINDOWS.items():
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)

        # Default context window size for Claude models
        return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)

    def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]:
    @staticmethod
    def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
        """Extract token usage from Anthropic response."""
        if hasattr(response, "usage") and response.usage:
        if response.usage:
            usage = response.usage
            input_tokens = getattr(usage, "input_tokens", 0)
            output_tokens = getattr(usage, "output_tokens", 0)
            return {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
                "input_tokens": usage.input_tokens,
                "output_tokens": usage.output_tokens,
                "total_tokens": usage.input_tokens + usage.output_tokens,
            }
        return {"total_tokens": 0}
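To make the context-window arithmetic concrete, here is a standalone mirror of the prefix lookup in `get_context_window_size`. The real `CONTEXT_WINDOW_USAGE_RATIO` lives in `crewai.llm`; the 0.75 below is an assumed stand-in:

```python
ANTHROPIC_CONTEXT_WINDOWS = {"claude-2.1": 200000, "claude-2": 100000}

def usable_window(model: str, usage_ratio: float = 0.75) -> int:
    # First dict entry whose key is a prefix of the model name wins
    # (insertion order, not longest match).
    for prefix, size in ANTHROPIC_CONTEXT_WINDOWS.items():
        if model.startswith(prefix):
            return int(size * usage_ratio)
    return int(200000 * usage_ratio)  # default for unknown Claude models

print(usable_window("claude-2.1"))  # 150000 with the assumed ratio
```

Note that insertion order doubles as a precedence rule: the real table lists "claude-2.1" before "claude-2", so the more specific prefix matches first.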
@@ -243,6 +243,30 @@ class BedrockCompletion(BaseLLM):
        # Handle inference profiles for newer models
        self.model_id = model

    @property
    def stop(self) -> list[str]:
        """Get stop sequences sent to the API."""
        return list(self.stop_sequences)

    @stop.setter
    def stop(self, value: Sequence[str] | str | None) -> None:
        """Set stop sequences.

        Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
        are properly sent to the Bedrock API.

        Args:
            value: Stop sequences as a Sequence, single string, or None
        """
        if value is None:
            self.stop_sequences = []
        elif isinstance(value, str):
            self.stop_sequences = [value]
        elif isinstance(value, Sequence):
            self.stop_sequences = list(value)
        else:
            self.stop_sequences = []

    def call(
        self,
        messages: str | list[LLMMessage],
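Behaviorally, the `stop` property pair above keeps `stop` and `stop_sequences` in lockstep. A minimal standalone mirror (not the actual `BedrockCompletion` class):

```python
from collections.abc import Sequence


class StopSync:
    def __init__(self) -> None:
        self.stop_sequences: list[str] = []

    @property
    def stop(self) -> list[str]:
        return list(self.stop_sequences)

    @stop.setter
    def stop(self, value: Sequence[str] | str | None) -> None:
        if value is None:
            self.stop_sequences = []
        elif isinstance(value, str):
            self.stop_sequences = [value]  # a lone string becomes a one-element list
        else:
            self.stop_sequences = list(value)


s = StopSync()
s.stop = "Observation:"
assert s.stop == ["Observation:"] and s.stop_sequences == ["Observation:"]
```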
@@ -1,12 +1,14 @@
import logging
import os
from typing import Any, cast
from __future__ import annotations

from pydantic import BaseModel
import logging
from typing import TYPE_CHECKING, Any, cast

from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
from typing_extensions import Self

from crewai.events.types.llm_events import LLMCallType
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.base import BaseInterceptor
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
@@ -14,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
    from crewai.agent import Agent
    from crewai.task import Task


try:
    from google import genai  # type: ignore[import-untyped]
    from google.genai import types  # type: ignore[import-untyped]
@@ -24,6 +31,27 @@ except ImportError:
    ) from None


GEMINI_CONTEXT_WINDOWS: dict[str, int] = {
    "gemini-2.0-flash": 1048576,  # 1M tokens
    "gemini-2.0-flash-thinking": 32768,
    "gemini-2.0-flash-lite": 1048576,
    "gemini-2.5-flash": 1048576,
    "gemini-2.5-pro": 1048576,
    "gemini-1.5-pro": 2097152,  # 2M tokens
    "gemini-1.5-flash": 1048576,
    "gemini-1.5-flash-8b": 1048576,
    "gemini-1.0-pro": 32768,
    "gemma-3-1b": 32000,
    "gemma-3-4b": 128000,
    "gemma-3-12b": 128000,
    "gemma-3-27b": 128000,
}

# Context window validation constraints
MIN_CONTEXT_WINDOW: int = 1024
MAX_CONTEXT_WINDOW: int = 2097152


class GeminiCompletion(BaseLLM):
    """Google Gemini native completion implementation.

@@ -31,78 +59,164 @@ class GeminiCompletion(BaseLLM):
    offering native function calling, streaming support, and proper Gemini formatting.
    """

    def __init__(
        self,
        model: str = "gemini-2.0-flash-001",
        api_key: str | None = None,
        project: str | None = None,
        location: str | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        max_output_tokens: int | None = None,
        stop_sequences: list[str] | None = None,
        stream: bool = False,
        safety_settings: dict[str, Any] | None = None,
        client_params: dict[str, Any] | None = None,
        interceptor: BaseInterceptor[Any, Any] | None = None,
        **kwargs: Any,
    ):
        """Initialize Google Gemini chat completion client.
    model: str = Field(
        default="gemini-2.0-flash-001",
        description="Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')",
    )
    project: str | None = Field(
        default=None,
        description="Google Cloud project ID (for Vertex AI)",
    )
    location: str = Field(
        default="us-central1",
        description="Google Cloud location (for Vertex AI)",
    )
    top_p: float | None = Field(
        default=None,
        description="Nucleus sampling parameter",
    )
    top_k: int | None = Field(
        default=None,
        description="Top-k sampling parameter",
    )
    max_output_tokens: int | None = Field(
        default=None,
        description="Maximum tokens in response",
    )
    safety_settings: dict[str, Any] | None = Field(
        default=None,
        description="Safety filter settings",
    )
    _client: genai.Client = PrivateAttr(  # type: ignore[no-any-unimported]
        default_factory=genai.Client,
    )

    @model_validator(mode="after")
    def initialize_client(self) -> Self:
        """Initialize the Gemini client after Pydantic validation.

        This runs after all field validation is complete, ensuring that:
        - All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
        - Field validators have run (stop_sequences is normalized to set[str])
        - API key and other configuration is ready
        """
        self._client = genai.Client(**self._get_client_params())
        return self

    # def __init__(
    #     self,
    #     model: str = "gemini-2.0-flash-001",
    #     api_key: str | None = None,
    #     project: str | None = None,
    #     location: str | None = None,
    #     temperature: float | None = None,
    #     top_p: float | None = None,
    #     top_k: int | None = None,
    #     max_output_tokens: int | None = None,
    #     stop_sequences: list[str] | None = None,
    #     stream: bool = False,
    #     safety_settings: dict[str, Any] | None = None,
    #     client_params: dict[str, Any] | None = None,
    #     interceptor: BaseInterceptor[Any, Any] | None = None,
    #     **kwargs: Any,
    # ):
    #     """Initialize Google Gemini chat completion client.
    #
    #     Args:
    #         model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
    #         api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
    #         project: Google Cloud project ID (for Vertex AI)
    #         location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
    #         temperature: Sampling temperature (0-2)
    #         top_p: Nucleus sampling parameter
    #         top_k: Top-k sampling parameter
    #         max_output_tokens: Maximum tokens in response
    #         stop_sequences: Stop sequences
    #         stream: Enable streaming responses
    #         safety_settings: Safety filter settings
    #         client_params: Additional parameters to pass to the Google Gen AI Client constructor.
    #             Supports parameters like http_options, credentials, debug_config, etc.
    #         interceptor: HTTP interceptor (not yet supported for Gemini).
    #         **kwargs: Additional parameters
    #     """
    #     if interceptor is not None:
    #         raise NotImplementedError(
    #             "HTTP interceptors are not yet supported for Google Gemini provider. "
    #             "Interceptors are currently supported for OpenAI and Anthropic providers only."
    #         )
    #
    #     super().__init__(
    #         model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
    #     )
    #
    #     # Store client params for later use
    #     self.client_params = client_params or {}
    #
    #     # Get API configuration with environment variable fallbacks
    #     self.api_key = (
    #         api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
    #     )
    #     self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
    #     self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
    #
    #     use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
    #
    #     self.client = self._initialize_client(use_vertexai)
    #
    #     # Store completion parameters
    #     self.top_p = top_p
    #     self.top_k = top_k
    #     self.max_output_tokens = max_output_tokens
    #     self.stream = stream
    #     self.safety_settings = safety_settings or {}
    #     self.stop_sequences = stop_sequences or []
    #
    #     # Model-specific settings
    #     self.is_gemini_2 = "gemini-2" in model.lower()
    #     self.is_gemini_1_5 = "gemini-1.5" in model.lower()
    #     self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2

    @computed_field  # type: ignore[prop-decorator]
    @property
    def is_gemini_2(self) -> bool:
        """Check if the model is Gemini 2.x."""
        return "gemini-2" in self.model.lower()

    @computed_field  # type: ignore[prop-decorator]
    @property
    def is_gemini_1_5(self) -> bool:
        """Check if the model is Gemini 1.5.x."""
        return "gemini-1.5" in self.model.lower()

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_tools(self) -> bool:
        """Check if the model supports tool/function calling."""
        return self.is_gemini_1_5 or self.is_gemini_2

    @property
    def stop(self) -> list[str]:
        """Get stop sequences sent to the API."""
        return self.stop_sequences

    @stop.setter
    def stop(self, value: list[str] | str | None) -> None:
        """Set stop sequences.

        Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
        are properly sent to the Gemini API.

        Args:
            model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
            api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
            project: Google Cloud project ID (for Vertex AI)
            location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
            temperature: Sampling temperature (0-2)
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            max_output_tokens: Maximum tokens in response
            stop_sequences: Stop sequences
            stream: Enable streaming responses
            safety_settings: Safety filter settings
            client_params: Additional parameters to pass to the Google Gen AI Client constructor.
                Supports parameters like http_options, credentials, debug_config, etc.
            interceptor: HTTP interceptor (not yet supported for Gemini).
            **kwargs: Additional parameters
            value: Stop sequences as a list, single string, or None
        """
        if interceptor is not None:
            raise NotImplementedError(
                "HTTP interceptors are not yet supported for Google Gemini provider. "
                "Interceptors are currently supported for OpenAI and Anthropic providers only."
            )

        super().__init__(
            model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
        )

        # Store client params for later use
        self.client_params = client_params or {}

        # Get API configuration with environment variable fallbacks
        self.api_key = (
            api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
        )
        self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
        self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"

        use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"

        self.client = self._initialize_client(use_vertexai)

        # Store completion parameters
        self.top_p = top_p
        self.top_k = top_k
        self.max_output_tokens = max_output_tokens
        self.stream = stream
        self.safety_settings = safety_settings or {}
        self.stop_sequences = stop_sequences or []

        # Model-specific settings
        self.is_gemini_2 = "gemini-2" in model.lower()
        self.is_gemini_1_5 = "gemini-1.5" in model.lower()
        self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
        if value is None:
            self.stop_sequences = []
        elif isinstance(value, str):
            self.stop_sequences = [value]
        elif isinstance(value, list):
            self.stop_sequences = value
        else:
            self.stop_sequences = []

    def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:  # type: ignore[no-any-unimported]
        """Initialize the Google Gen AI client with proper parameter handling.
@@ -118,6 +232,12 @@ class GeminiCompletion(BaseLLM):
        if self.client_params:
            client_params.update(self.client_params)

        if self.interceptor:
            raise NotImplementedError(
                "HTTP interceptors are not yet supported for Google Gemini provider. "
                "Interceptors are currently supported for OpenAI and Anthropic providers only."
            )

        if use_vertexai or self.project:
            client_params.update(
                {
@@ -157,7 +277,7 @@ class GeminiCompletion(BaseLLM):

        if (
            hasattr(self, "client")
            and hasattr(self.client, "vertexai")
            and hasattr(self._client, "vertexai")
            and self.client.vertexai
        ):
            # Vertex AI configuration
@@ -182,8 +302,8 @@ class GeminiCompletion(BaseLLM):
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Call Google Gemini generate content API.
@@ -270,7 +390,16 @@ class GeminiCompletion(BaseLLM):
            GenerateContentConfig object for Gemini API
        """
        self.tools = tools
        config_params = {}
        config_params = self.model_dump(
            include={
                "temperature",
                "top_p",
                "top_k",
                "max_output_tokens",
                "stop_sequences",
                "safety_settings",
            }
        )

        # Add system instruction if present
        if system_instruction:
@@ -280,18 +409,6 @@ class GeminiCompletion(BaseLLM):
            )
            config_params["system_instruction"] = system_content

        # Add generation config parameters
        if self.temperature is not None:
            config_params["temperature"] = self.temperature
        if self.top_p is not None:
            config_params["top_p"] = self.top_p
        if self.top_k is not None:
            config_params["top_k"] = self.top_k
        if self.max_output_tokens is not None:
            config_params["max_output_tokens"] = self.max_output_tokens
        if self.stop_sequences:
            config_params["stop_sequences"] = self.stop_sequences

        if response_model:
            config_params["response_mime_type"] = "application/json"
            config_params["response_schema"] = response_model.model_json_schema()
@@ -300,9 +417,6 @@ class GeminiCompletion(BaseLLM):
        if tools and self.supports_tools:
            config_params["tools"] = self._convert_tools_for_interference(tools)

        if self.safety_settings:
            config_params["safety_settings"] = self.safety_settings

        return types.GenerateContentConfig(**config_params)

    def _convert_tools_for_interference(  # type: ignore[no-any-unimported]
@@ -380,8 +494,8 @@ class GeminiCompletion(BaseLLM):
        system_instruction: str | None,
        config: types.GenerateContentConfig,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming content generation."""
@@ -392,7 +506,7 @@ class GeminiCompletion(BaseLLM):
        }

        try:
            response = self.client.models.generate_content(**api_params)
            response = self._client.models.generate_content(**api_params)

            usage = self._extract_token_usage(response)
        except Exception as e:
@@ -446,8 +560,8 @@ class GeminiCompletion(BaseLLM):
        contents: list[types.Content],
        config: types.GenerateContentConfig,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str:
        """Handle streaming content generation."""
@@ -460,7 +574,7 @@ class GeminiCompletion(BaseLLM):
            "config": config,
        }

        for chunk in self.client.models.generate_content_stream(**api_params):
        for chunk in self._client.models.generate_content_stream(**api_params):
            if hasattr(chunk, "text") and chunk.text:
                full_response += chunk.text
                self._emit_stream_chunk_event(
@@ -513,52 +627,30 @@ class GeminiCompletion(BaseLLM):

        return full_response

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        return self.supports_tools

    def supports_stop_words(self) -> bool:
        """Check if the model supports stop words."""
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size for the model."""
        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES

        min_context = 1024
        max_context = 2097152

        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
            if value < min_context or value > max_context:
            if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
                raise ValueError(
                    f"Context window for {key} must be between {min_context} and {max_context}"
                    f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
                )

        context_windows = {
            "gemini-2.0-flash": 1048576,  # 1M tokens
            "gemini-2.0-flash-thinking": 32768,
            "gemini-2.0-flash-lite": 1048576,
            "gemini-2.5-flash": 1048576,
            "gemini-2.5-pro": 1048576,
            "gemini-1.5-pro": 2097152,  # 2M tokens
            "gemini-1.5-flash": 1048576,
            "gemini-1.5-flash-8b": 1048576,
            "gemini-1.0-pro": 32768,
            "gemma-3-1b": 32000,
            "gemma-3-4b": 128000,
            "gemma-3-12b": 128000,
            "gemma-3-27b": 128000,
        }

        # Find the best match for the model name
        for model_prefix, size in context_windows.items():
        for model_prefix, size in GEMINI_CONTEXT_WINDOWS.items():
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)

        # Default context window size for Gemini models
        return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO)  # 1M tokens

    def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
    @staticmethod
    def _extract_token_usage(response: dict[str, Any]) -> dict[str, Any]:
        """Extract token usage from Gemini response."""
        if hasattr(response, "usage_metadata"):
            usage = response.usage_metadata
@@ -570,8 +662,8 @@ class GeminiCompletion(BaseLLM):
        }
        return {"total_tokens": 0}

    @staticmethod
    def _convert_contents_to_dict(  # type: ignore[no-any-unimported]
        self,
        contents: list[types.Content],
    ) -> list[dict[str, str]]:
        """Convert contents to dict format."""
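The Gemini capability checks above are pure string tests on the model name, so their behavior is easy to verify in isolation. A standalone mirror of the computed fields (illustrative only, not the actual class):

```python
def gemini_supports_tools(model: str) -> bool:
    """Mirror of supports_tools: true for Gemini 1.5.x and 2.x families."""
    name = model.lower()
    return "gemini-1.5" in name or "gemini-2" in name


assert gemini_supports_tools("gemini-2.0-flash-001")
assert gemini_supports_tools("gemini-1.5-pro")
assert not gemini_supports_tools("gemini-1.0-pro")
```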
@@ -4,16 +4,23 @@ from collections.abc import Iterator
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
|
||||
import httpx
|
||||
from openai import APIConnectionError, NotFoundError, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from pydantic import BaseModel
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.hooks.transport import HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
@@ -25,11 +32,28 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
OPENAI_CONTEXT_WINDOWS: dict[str, int] = {
|
||||
"gpt-4": 8192,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 200000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4.1": 1047576,
|
||||
"gpt-4.1-mini-2025-04-14": 1047576,
|
||||
"gpt-4.1-nano-2025-04-14": 1047576,
|
||||
"o1-preview": 128000,
|
||||
"o1-mini": 128000,
|
||||
"o3-mini": 200000,
|
||||
"o4-mini": 200000,
|
||||
}
|
||||
|
||||
MIN_CONTEXT_WINDOW: Final[int] = 1024
|
||||
MAX_CONTEXT_WINDOW: Final[int] = 2097152
|
||||
|
||||
|
||||
class OpenAICompletion(BaseLLM):
|
||||
"""OpenAI native completion implementation.
|
||||
|
||||
@@ -37,112 +61,125 @@ class OpenAICompletion(BaseLLM):
|
||||
offering native structured outputs, function calling, and streaming support.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
default_query: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
seed: int | None = None,
|
||||
stream: bool = False,
|
||||
response_format: dict[str, Any] | type[BaseModel] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
provider: str | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize OpenAI chat completion client."""
|
||||
model: str = Field(
|
||||
default="gpt-4o",
|
||||
description="OpenAI model name (e.g., 'gpt-4o')",
|
||||
)
|
||||
organization: str | None = Field(
|
||||
default=None,
|
||||
description="Name of the OpenAI organization",
|
||||
)
|
||||
project: str | None = Field(
|
||||
default=None,
|
||||
description="Name of the OpenAI project",
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=os.getenv("OPENAI_BASE_URL"),
|
||||
description="Base URL for OpenAI API",
|
||||
)
|
||||
default_headers: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
description="Default headers for OpenAI API requests",
|
||||
)
|
||||
default_query: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Default query parameters for OpenAI API requests",
|
||||
)
|
||||
top_p: float | None = Field(
|
||||
default=None,
|
||||
description="Top-p sampling parameter",
|
||||
)
|
||||
frequency_penalty: float | None = Field(
|
||||
default=None,
|
||||
description="Frequency penalty parameter",
|
||||
)
|
||||
presence_penalty: float | None = Field(
|
||||
default=None,
|
||||
description="Presence penalty parameter",
|
||||
)
|
||||
max_completion_tokens: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum tokens for completion",
|
||||
)
|
||||
seed: int | None = Field(
|
||||
default=None,
|
||||
description="Random seed for reproducibility",
|
||||
)
|
||||
response_format: dict[str, Any] | type[BaseModel] | None = Field(
|
||||
default=None,
|
||||
description="Response format for structured output",
|
||||
)
|
||||
logprobs: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to include log probabilities",
|
||||
)
|
||||
top_logprobs: int | None = Field(
|
||||
default=None,
|
||||
description="Number of top log probabilities to return",
|
||||
)
|
||||
reasoning_effort: str | None = Field(
|
||||
default=None,
|
||||
description="Reasoning effort level for o1 models",
|
||||
)
|
||||
supports_function_calling: bool = Field(
|
||||
default=True,
|
||||
description="Whether the model supports function calling",
|
||||
)
|
||||
is_o1_model: bool = Field(
|
||||
default=False,
|
||||
description="Whether the model is an o1 model",
|
||||
)
|
||||
is_gpt4_model: bool = Field(
|
||||
default=False,
|
||||
description="Whether the model is a GPT-4 model",
|
||||
)
|
||||
_client: OpenAI = PrivateAttr(
|
||||
default_factory=OpenAI,
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
provider = kwargs.pop("provider", "openai")
|
||||
|
||||
self.interceptor = interceptor
|
||||
# Client configuration attributes
|
||||
self.organization = organization
|
||||
self.project = project
|
||||
self.max_retries = max_retries
|
||||
self.default_headers = default_headers
|
||||
self.default_query = default_query
|
||||
self.client_params = client_params
|
||||
self.timeout = timeout
|
||||
self.base_url = base_url
|
||||
self.api_base = kwargs.pop("api_base", None)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY"),
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
provider=provider,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
transport = HTTPTransport(interceptor=self.interceptor)
|
||||
http_client = httpx.Client(transport=transport)
|
||||
client_config["http_client"] = http_client
|
||||
|
||||
self.client = OpenAI(**client_config)
|
||||
|
||||
# Completion parameters
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.seed = seed
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.is_o1_model = "o1" in model.lower()
|
||||
self.is_gpt4_model = "gpt-4" in model.lower()
|
||||
|
||||
def _get_client_params(self) -> dict[str, Any]:
|
||||
"""Get OpenAI client parameters."""
|
||||
@model_validator(mode="after")
|
||||
def initialize_client(self) -> Self:
|
||||
"""Initialize the Anthropic client after Pydantic validation.
|
||||
|
||||
This runs after all field validation is complete, ensuring that:
|
||||
- All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
|
||||
- Field validators have run (stop_sequences is normalized to set[str])
|
||||
- API key and other configuration is ready
|
||||
"""
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
if self.api_key is None:
|
||||
raise ValueError("OPENAI_API_KEY is required")
|
||||
|
||||
base_params = {
|
||||
"api_key": self.api_key,
|
||||
"organization": self.organization,
|
||||
"project": self.project,
|
||||
"base_url": self.base_url
|
||||
or self.api_base
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or None,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
"default_headers": self.default_headers,
|
||||
"default_query": self.default_query,
|
||||
}
|
||||
self.is_o1_model = "o1" in self.model.lower()
|
||||
self.supports_function_calling = not self.is_o1_model
|
||||
self.is_gpt4_model = "gpt-4" in self.model.lower()
|
||||
self.supports_stop_words = not self.is_o1_model
|
||||
|
||||
client_params = {k: v for k, v in base_params.items() if v is not None}
|
||||
params = self.model_dump(
|
||||
include={
|
||||
"api_key",
|
||||
"organization",
|
||||
"project",
|
||||
"base_url",
|
||||
"timeout",
|
||||
"max_retries",
|
||||
"default_headers",
|
||||
"default_query",
|
||||
},
|
||||
exclude_none=True,
|
||||
)
|
||||
|
||||
if self.interceptor:
|
||||
transport = HTTPTransport(interceptor=self.interceptor)
|
||||
http_client = httpx.Client(transport=transport)
|
||||
params["http_client"] = http_client
|
||||
|
||||
if self.client_params:
|
||||
client_params.update(self.client_params)
|
||||
params.update(self.client_params)
|
||||
|
||||
return client_params
|
||||
self._client = OpenAI(**params)
|
||||
return self
|
||||
|
||||
def call(
|
||||
self,
|
||||
@@ -213,38 +250,26 @@ class OpenAICompletion(BaseLLM):
|
||||
self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare parameters for OpenAI chat completion."""
|
||||
params: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
}
|
||||
if self.stream:
|
||||
params["stream"] = self.stream
|
||||
|
||||
params = self.model_dump(
|
||||
include={
|
||||
"model",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"max_completion_tokens",
|
||||
"max_tokens",
|
||||
"seed",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"reasoning_effort",
|
||||
},
|
||||
exclude_none=True,
|
||||
)
|
||||
params["messages"] = messages
|
||||
params.update(self.additional_params)
|
||||
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
params["top_p"] = self.top_p
|
||||
if self.frequency_penalty is not None:
|
||||
params["frequency_penalty"] = self.frequency_penalty
|
||||
if self.presence_penalty is not None:
|
||||
params["presence_penalty"] = self.presence_penalty
|
||||
if self.max_completion_tokens is not None:
|
||||
params["max_completion_tokens"] = self.max_completion_tokens
|
||||
elif self.max_tokens is not None:
|
||||
params["max_tokens"] = self.max_tokens
|
||||
if self.seed is not None:
|
||||
params["seed"] = self.seed
|
||||
if self.logprobs is not None:
|
||||
params["logprobs"] = self.logprobs
|
||||
if self.top_logprobs is not None:
|
||||
params["top_logprobs"] = self.top_logprobs
|
||||
|
||||
# Handle o1 model specific parameters
|
||||
if self.is_o1_model and self.reasoning_effort:
|
||||
params["reasoning_effort"] = self.reasoning_effort
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
@@ -296,14 +321,14 @@ class OpenAICompletion(BaseLLM):
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
```diff
         from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Handle non-streaming chat completion."""
         try:
             if response_model:
-                parsed_response = self.client.beta.chat.completions.parse(
+                parsed_response = self._client.beta.chat.completions.parse(
                     **params,
                     response_format=response_model,
                 )
@@ -327,7 +352,7 @@ class OpenAICompletion(BaseLLM):
                 )
                 return structured_json

-            response: ChatCompletion = self.client.chat.completions.create(**params)
+            response: ChatCompletion = self._client.chat.completions.create(**params)

             usage = self._extract_openai_token_usage(response)

@@ -419,8 +444,8 @@ class OpenAICompletion(BaseLLM):
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str:
         """Handle streaming chat completion."""
@@ -429,7 +454,7 @@ class OpenAICompletion(BaseLLM):

         if response_model:
             completion_stream: Iterator[ChatCompletionChunk] = (
-                self.client.chat.completions.create(**params)
+                self._client.chat.completions.create(**params)
             )

             accumulated_content = ""
@@ -472,7 +497,7 @@ class OpenAICompletion(BaseLLM):
             )
             return accumulated_content

-        stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
+        stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create(
             **params
         )

@@ -550,58 +575,31 @@ class OpenAICompletion(BaseLLM):

         return full_response

     def supports_function_calling(self) -> bool:
         """Check if the model supports function calling."""
         return not self.is_o1_model

     def supports_stop_words(self) -> bool:
         """Check if the model supports stop words."""
         return not self.is_o1_model

     def get_context_window_size(self) -> int:
         """Get the context window size for the model."""
         from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES

-        min_context = 1024
-        max_context = 2097152
-
         for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
-            if value < min_context or value > max_context:
+            if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
                 raise ValueError(
-                    f"Context window for {key} must be between {min_context} and {max_context}"
+                    f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
                 )

-        # Context window sizes for OpenAI models
-        context_windows = {
-            "gpt-4": 8192,
-            "gpt-4o": 128000,
-            "gpt-4o-mini": 200000,
-            "gpt-4-turbo": 128000,
-            "gpt-4.1": 1047576,
-            "gpt-4.1-mini-2025-04-14": 1047576,
-            "gpt-4.1-nano-2025-04-14": 1047576,
-            "o1-preview": 128000,
-            "o1-mini": 128000,
-            "o3-mini": 200000,
-            "o4-mini": 200000,
-        }
-
         # Find the best match for the model name
-        for model_prefix, size in context_windows.items():
+        for model_prefix, size in OPENAI_CONTEXT_WINDOWS.items():
             if self.model.startswith(model_prefix):
                 return int(size * CONTEXT_WINDOW_USAGE_RATIO)

         # Default context window size
         return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)

-    def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]:
+    @staticmethod
+    def _extract_openai_token_usage(response: ChatCompletion) -> dict[str, Any]:
         """Extract token usage from OpenAI ChatCompletion response."""
-        if hasattr(response, "usage") and response.usage:
+        if response.usage:
             usage = response.usage
             return {
-                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
-                "completion_tokens": getattr(usage, "completion_tokens", 0),
-                "total_tokens": getattr(usage, "total_tokens", 0),
+                "prompt_tokens": usage.prompt_tokens,
+                "completion_tokens": usage.completion_tokens,
+                "total_tokens": usage.total_tokens,
             }
         return {"total_tokens": 0}
```
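The refactor replaces the inline bounds and the per-call dict with module-level constants. Their definitions are not visible in this hunk, but judging from the removed literals they would presumably look like the following sketch (names taken from the diff, values inferred from the deleted lines):

```python
# Presumed module-level constants replacing the removed inline values.
# Not shown in this hunk; values inferred from the deleted literals.
MIN_CONTEXT_WINDOW = 1024
MAX_CONTEXT_WINDOW = 2097152

OPENAI_CONTEXT_WINDOWS: dict[str, int] = {
    "gpt-4": 8192,
    "gpt-4o": 128000,
    "gpt-4o-mini": 200000,
    "gpt-4-turbo": 128000,
    "gpt-4.1": 1047576,
    "gpt-4.1-mini-2025-04-14": 1047576,
    "gpt-4.1-nano-2025-04-14": 1047576,
    "o1-preview": 128000,
    "o1-mini": 128000,
    "o3-mini": 200000,
    "o4-mini": 200000,
}
```

Hoisting the table to module level avoids rebuilding the dict on every `get_context_window_size()` call and gives the validation bounds a single authoritative home.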
**lib/crewai/src/crewai/mcp/__init__.py** (new file, +37 lines)

```python
"""MCP (Model Context Protocol) client support for CrewAI agents.

This module provides native MCP client functionality, allowing CrewAI agents
to connect to any MCP-compliant server using various transport types.
"""

from crewai.mcp.client import MCPClient
from crewai.mcp.config import (
    MCPServerConfig,
    MCPServerHTTP,
    MCPServerSSE,
    MCPServerStdio,
)
from crewai.mcp.filters import (
    StaticToolFilter,
    ToolFilter,
    ToolFilterContext,
    create_dynamic_tool_filter,
    create_static_tool_filter,
)
from crewai.mcp.transports.base import BaseTransport, TransportType


__all__ = [
    "BaseTransport",
    "MCPClient",
    "MCPServerConfig",
    "MCPServerHTTP",
    "MCPServerSSE",
    "MCPServerStdio",
    "StaticToolFilter",
    "ToolFilter",
    "ToolFilterContext",
    "TransportType",
    "create_dynamic_tool_filter",
    "create_static_tool_filter",
]
```
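With these re-exports in place, callers can import everything from the package root rather than reaching into submodules; a minimal example based on the `__all__` list above:

```python
# Import from the package root (all names re-exported above).
from crewai.mcp import MCPClient, MCPServerStdio, create_static_tool_filter
```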
**lib/crewai/src/crewai/mcp/client.py** (new file, +742 lines)

````python
"""MCP client with session management for CrewAI agents."""

import asyncio
from collections.abc import Callable
from contextlib import AsyncExitStack
from datetime import datetime
import logging
import time
from typing import Any

from typing_extensions import Self


# BaseExceptionGroup is available in Python 3.11+
try:
    from builtins import BaseExceptionGroup
except ImportError:
    # Fallback for Python < 3.11 (shouldn't happen in practice)
    BaseExceptionGroup = Exception

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.mcp_events import (
    MCPConnectionCompletedEvent,
    MCPConnectionFailedEvent,
    MCPConnectionStartedEvent,
    MCPToolExecutionCompletedEvent,
    MCPToolExecutionFailedEvent,
    MCPToolExecutionStartedEvent,
)
from crewai.mcp.transports.base import BaseTransport
from crewai.mcp.transports.http import HTTPTransport
from crewai.mcp.transports.sse import SSETransport
from crewai.mcp.transports.stdio import StdioTransport


# MCP connection timeout constants (in seconds)
MCP_CONNECTION_TIMEOUT = 30  # Increased for slow servers
MCP_TOOL_EXECUTION_TIMEOUT = 30
MCP_DISCOVERY_TIMEOUT = 30  # Increased for slow servers
MCP_MAX_RETRIES = 3

# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
_mcp_schema_cache: dict[str, tuple[list[dict[str, Any]], float]] = {}
_cache_ttl = 300  # 5 minutes


class MCPClient:
    """MCP client with session management.

    This client manages connections to MCP servers and provides a high-level
    interface for interacting with MCP tools, prompts, and resources.

    Example:
        ```python
        transport = StdioTransport(command="python", args=["server.py"])
        client = MCPClient(transport)
        async with client:
            tools = await client.list_tools()
            result = await client.call_tool("tool_name", {"arg": "value"})
        ```
    """

    def __init__(
        self,
        transport: BaseTransport,
        connect_timeout: int = MCP_CONNECTION_TIMEOUT,
        execution_timeout: int = MCP_TOOL_EXECUTION_TIMEOUT,
        discovery_timeout: int = MCP_DISCOVERY_TIMEOUT,
        max_retries: int = MCP_MAX_RETRIES,
        cache_tools_list: bool = False,
        logger: logging.Logger | None = None,
    ) -> None:
        """Initialize MCP client.

        Args:
            transport: Transport instance for MCP server connection.
            connect_timeout: Connection timeout in seconds.
            execution_timeout: Tool execution timeout in seconds.
            discovery_timeout: Tool discovery timeout in seconds.
            max_retries: Maximum retry attempts for operations.
            cache_tools_list: Whether to cache tool list results.
            logger: Optional logger instance.
        """
        self.transport = transport
        self.connect_timeout = connect_timeout
        self.execution_timeout = execution_timeout
        self.discovery_timeout = discovery_timeout
        self.max_retries = max_retries
        self.cache_tools_list = cache_tools_list
        self._logger = logger or logging.getLogger(__name__)
        self._session: Any = None
        self._initialized = False
        self._exit_stack = AsyncExitStack()
        self._was_connected = False

    @property
    def connected(self) -> bool:
        """Check if client is connected to server."""
        return self.transport.connected and self._initialized

    @property
    def session(self) -> Any:
        """Get the MCP session."""
        if self._session is None:
            raise RuntimeError("Client not connected. Call connect() first.")
        return self._session

    def _get_server_info(self) -> tuple[str, str | None, str | None]:
        """Get server information for events.

        Returns:
            Tuple of (server_name, server_url, transport_type).
        """
        if isinstance(self.transport, StdioTransport):
            server_name = f"{self.transport.command} {' '.join(self.transport.args)}"
            server_url = None
            transport_type = self.transport.transport_type.value
        elif isinstance(self.transport, HTTPTransport):
            server_name = self.transport.url
            server_url = self.transport.url
            transport_type = self.transport.transport_type.value
        elif isinstance(self.transport, SSETransport):
            server_name = self.transport.url
            server_url = self.transport.url
            transport_type = self.transport.transport_type.value
        else:
            server_name = "Unknown MCP Server"
            server_url = None
            transport_type = (
                self.transport.transport_type.value
                if hasattr(self.transport, "transport_type")
                else None
            )

        return server_name, server_url, transport_type

    async def connect(self) -> Self:
        """Connect to MCP server and initialize session.

        Returns:
            Self for method chaining.

        Raises:
            ConnectionError: If connection fails.
            ImportError: If MCP SDK not available.
        """
        if self.connected:
            return self

        # Get server info for events
        server_name, server_url, transport_type = self._get_server_info()
        is_reconnect = self._was_connected

        # Emit connection started event
        started_at = datetime.now()
        crewai_event_bus.emit(
            self,
            MCPConnectionStartedEvent(
                server_name=server_name,
                server_url=server_url,
                transport_type=transport_type,
                is_reconnect=is_reconnect,
                connect_timeout=self.connect_timeout,
            ),
        )

        try:
            from mcp import ClientSession

            # Use AsyncExitStack to manage transport and session contexts together.
            # This ensures they're in the same async scope and prevents cancel
            # scope errors. Always enter the transport context via the exit stack
            # (it handles the already-connected state).
            await self._exit_stack.enter_async_context(self.transport)

            # Create ClientSession with transport streams
            self._session = ClientSession(
                self.transport.read_stream,
                self.transport.write_stream,
            )

            # Enter the session's async context manager via exit stack
            await self._exit_stack.enter_async_context(self._session)

            # Initialize the session (required by MCP protocol)
            try:
                await asyncio.wait_for(
                    self._session.initialize(),
                    timeout=self.connect_timeout,
                )
            except asyncio.CancelledError:
                # If initialization was cancelled (e.g., event loop closing),
                # clean up and re-raise - don't suppress cancellation.
                await self._cleanup_on_error()
                raise
            except BaseExceptionGroup as eg:
                # Handle exception groups from anyio task groups.
                # Extract the actual meaningful error (not GeneratorExit).
                actual_error = None
                for exc in eg.exceptions:
                    if isinstance(exc, Exception) and not isinstance(
                        exc, GeneratorExit
                    ):
                        # Check if it's an HTTP error (like 401)
                        error_msg = str(exc).lower()
                        if "401" in error_msg or "unauthorized" in error_msg:
                            actual_error = exc
                            break
                        if "cancel scope" not in error_msg and "task" not in error_msg:
                            actual_error = exc
                            break

                await self._cleanup_on_error()
                if actual_error:
                    raise ConnectionError(
                        f"Failed to connect to MCP server: {actual_error}"
                    ) from actual_error
                raise ConnectionError(f"Failed to connect to MCP server: {eg}") from eg

            self._initialized = True
            self._was_connected = True

            completed_at = datetime.now()
            connection_duration_ms = (completed_at - started_at).total_seconds() * 1000
            crewai_event_bus.emit(
                self,
                MCPConnectionCompletedEvent(
                    server_name=server_name,
                    server_url=server_url,
                    transport_type=transport_type,
                    started_at=started_at,
                    completed_at=completed_at,
                    connection_duration_ms=connection_duration_ms,
                    is_reconnect=is_reconnect,
                ),
            )

            return self
        except ImportError as e:
            await self._cleanup_on_error()
            error_msg = (
                "MCP library not available. Please install with: pip install mcp"
            )
            self._emit_connection_failed(
                server_name,
                server_url,
                transport_type,
                error_msg,
                "import_error",
                started_at,
            )
            raise ImportError(error_msg) from e
        except asyncio.TimeoutError as e:
            await self._cleanup_on_error()
            error_msg = f"MCP connection timed out after {self.connect_timeout} seconds. The server may be slow or unreachable."
            self._emit_connection_failed(
                server_name,
                server_url,
                transport_type,
                error_msg,
                "timeout",
                started_at,
            )
            raise ConnectionError(error_msg) from e
        except asyncio.CancelledError:
            # Re-raise cancellation - don't suppress it
            await self._cleanup_on_error()
            self._emit_connection_failed(
                server_name,
                server_url,
                transport_type,
                "Connection cancelled",
                "cancelled",
                started_at,
            )
            raise
        except BaseExceptionGroup as eg:
            # Handle exception groups from anyio task groups at outer level
            actual_error = None
            for exc in eg.exceptions:
                if isinstance(exc, Exception) and not isinstance(exc, GeneratorExit):
                    error_msg = str(exc).lower()
                    if "401" in error_msg or "unauthorized" in error_msg:
                        actual_error = exc
                        break
                    if "cancel scope" not in error_msg and "task" not in error_msg:
                        actual_error = exc
                        break

            await self._cleanup_on_error()
            error_type = (
                "authentication"
                if actual_error
                and (
                    "401" in str(actual_error).lower()
                    or "unauthorized" in str(actual_error).lower()
                )
                else "network"
            )
            error_msg = str(actual_error) if actual_error else str(eg)
            self._emit_connection_failed(
                server_name,
                server_url,
                transport_type,
                error_msg,
                error_type,
                started_at,
            )
            if actual_error:
                raise ConnectionError(
                    f"Failed to connect to MCP server: {actual_error}"
                ) from actual_error
            raise ConnectionError(f"Failed to connect to MCP server: {eg}") from eg
        except Exception as e:
            await self._cleanup_on_error()
            error_type = (
                "authentication"
                if "401" in str(e).lower() or "unauthorized" in str(e).lower()
                else "network"
            )
            self._emit_connection_failed(
                server_name, server_url, transport_type, str(e), error_type, started_at
            )
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

    def _emit_connection_failed(
        self,
        server_name: str,
        server_url: str | None,
        transport_type: str | None,
        error: str,
        error_type: str,
        started_at: datetime,
    ) -> None:
        """Emit connection failed event."""
        failed_at = datetime.now()
        crewai_event_bus.emit(
            self,
            MCPConnectionFailedEvent(
                server_name=server_name,
                server_url=server_url,
                transport_type=transport_type,
                error=error,
                error_type=error_type,
                started_at=started_at,
                failed_at=failed_at,
            ),
        )

    async def _cleanup_on_error(self) -> None:
        """Cleanup resources when an error occurs during connection."""
        try:
            await self._exit_stack.aclose()
        except Exception as e:
            # Cleanup itself failed - surface the error with context.
            raise RuntimeError(f"Error during MCP client cleanup: {e}") from e
        finally:
            self._session = None
            self._initialized = False
            self._exit_stack = AsyncExitStack()

    async def disconnect(self) -> None:
        """Disconnect from MCP server and cleanup resources."""
        if not self.connected:
            return

        try:
            await self._exit_stack.aclose()
        except Exception as e:
            raise RuntimeError(f"Error during MCP client disconnect: {e}") from e
        finally:
            self._session = None
            self._initialized = False
            self._exit_stack = AsyncExitStack()

    async def list_tools(self, use_cache: bool | None = None) -> list[dict[str, Any]]:
        """List available tools from MCP server.

        Args:
            use_cache: Whether to use cached results. If None, uses
                client's cache_tools_list setting.

        Returns:
            List of tool definitions with name, description, and inputSchema.
        """
        if not self.connected:
            await self.connect()

        # Check cache if enabled
        use_cache = use_cache if use_cache is not None else self.cache_tools_list
        if use_cache:
            cache_key = self._get_cache_key("tools")
            if cache_key in _mcp_schema_cache:
                cached_data, cache_time = _mcp_schema_cache[cache_key]
                if time.time() - cache_time < _cache_ttl:
                    return cached_data

        # List tools with timeout and retries
        tools = await self._retry_operation(
            self._list_tools_impl,
            timeout=self.discovery_timeout,
        )

        # Cache results if enabled
        if use_cache:
            cache_key = self._get_cache_key("tools")
            _mcp_schema_cache[cache_key] = (tools, time.time())

        return tools

    async def _list_tools_impl(self) -> list[dict[str, Any]]:
        """Internal implementation of list_tools."""
        tools_result = await asyncio.wait_for(
            self.session.list_tools(),
            timeout=self.discovery_timeout,
        )

        return [
            {
                "name": tool.name,
                "description": getattr(tool, "description", ""),
                "inputSchema": getattr(tool, "inputSchema", {}),
            }
            for tool in tools_result.tools
        ]

    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> Any:
        """Call a tool on the MCP server.

        Args:
            tool_name: Name of the tool to call.
            arguments: Tool arguments.

        Returns:
            Tool execution result.
        """
        if not self.connected:
            await self.connect()

        arguments = arguments or {}
        cleaned_arguments = self._clean_tool_arguments(arguments)

        # Get server info for events
        server_name, server_url, transport_type = self._get_server_info()

        # Emit tool execution started event
        started_at = datetime.now()
        crewai_event_bus.emit(
            self,
            MCPToolExecutionStartedEvent(
                server_name=server_name,
                server_url=server_url,
                transport_type=transport_type,
                tool_name=tool_name,
                tool_args=cleaned_arguments,
            ),
        )

        try:
            result = await self._retry_operation(
                lambda: self._call_tool_impl(tool_name, cleaned_arguments),
                timeout=self.execution_timeout,
            )

            completed_at = datetime.now()
            execution_duration_ms = (completed_at - started_at).total_seconds() * 1000
            crewai_event_bus.emit(
                self,
                MCPToolExecutionCompletedEvent(
                    server_name=server_name,
                    server_url=server_url,
                    transport_type=transport_type,
                    tool_name=tool_name,
                    tool_args=cleaned_arguments,
                    result=result,
                    started_at=started_at,
                    completed_at=completed_at,
                    execution_duration_ms=execution_duration_ms,
                ),
            )

            return result
        except Exception as e:
            failed_at = datetime.now()
            error_type = (
                "timeout"
                if isinstance(e, (asyncio.TimeoutError, ConnectionError))
                and "timeout" in str(e).lower()
                else "server_error"
            )
            crewai_event_bus.emit(
                self,
                MCPToolExecutionFailedEvent(
                    server_name=server_name,
                    server_url=server_url,
                    transport_type=transport_type,
                    tool_name=tool_name,
                    tool_args=cleaned_arguments,
                    error=str(e),
                    error_type=error_type,
                    started_at=started_at,
                    failed_at=failed_at,
                ),
            )
            raise

    def _clean_tool_arguments(self, arguments: dict[str, Any]) -> dict[str, Any]:
        """Clean tool arguments by removing None values and fixing formats.

        Args:
            arguments: Raw tool arguments.

        Returns:
            Cleaned arguments ready for MCP server.
        """
        cleaned = {}

        for key, value in arguments.items():
            # Skip None values
            if value is None:
                continue

            # Fix sources array format: convert ["web"] to [{"type": "web"}]
            if key == "sources" and isinstance(value, list):
                fixed_sources = []
                for item in value:
                    if isinstance(item, str):
                        # Convert string to object format
                        fixed_sources.append({"type": item})
                    elif isinstance(item, dict):
                        # Already in correct format
                        fixed_sources.append(item)
                    else:
                        # Keep as is if unknown format
                        fixed_sources.append(item)
                if fixed_sources:
                    cleaned[key] = fixed_sources
                continue

            # Recursively clean nested dictionaries
            if isinstance(value, dict):
                nested_cleaned = self._clean_tool_arguments(value)
                if nested_cleaned:  # Only add if not empty
                    cleaned[key] = nested_cleaned
            elif isinstance(value, list):
                # Clean list items
                cleaned_list = []
                for item in value:
                    if isinstance(item, dict):
                        cleaned_item = self._clean_tool_arguments(item)
                        if cleaned_item:
                            cleaned_list.append(cleaned_item)
                    elif item is not None:
                        cleaned_list.append(item)
                if cleaned_list:
                    cleaned[key] = cleaned_list
            else:
                # Keep primitive values
                cleaned[key] = value

        return cleaned

    async def _call_tool_impl(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Internal implementation of call_tool."""
        result = await asyncio.wait_for(
            self.session.call_tool(tool_name, arguments),
            timeout=self.execution_timeout,
        )

        # Extract result content
        if hasattr(result, "content") and result.content:
            if isinstance(result.content, list) and len(result.content) > 0:
                content_item = result.content[0]
                if hasattr(content_item, "text"):
                    return str(content_item.text)
                return str(content_item)
            return str(result.content)

        return str(result)

    async def list_prompts(self) -> list[dict[str, Any]]:
        """List available prompts from MCP server.

        Returns:
            List of prompt definitions.
        """
        if not self.connected:
            await self.connect()

        return await self._retry_operation(
            self._list_prompts_impl,
            timeout=self.discovery_timeout,
        )

    async def _list_prompts_impl(self) -> list[dict[str, Any]]:
        """Internal implementation of list_prompts."""
        prompts_result = await asyncio.wait_for(
            self.session.list_prompts(),
            timeout=self.discovery_timeout,
        )

        return [
            {
                "name": prompt.name,
                "description": getattr(prompt, "description", ""),
                "arguments": getattr(prompt, "arguments", []),
            }
            for prompt in prompts_result.prompts
        ]

    async def get_prompt(
        self, prompt_name: str, arguments: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Get a prompt from the MCP server.

        Args:
            prompt_name: Name of the prompt to get.
            arguments: Optional prompt arguments.

        Returns:
            Prompt content and metadata.
        """
        if not self.connected:
            await self.connect()

        arguments = arguments or {}

        return await self._retry_operation(
            lambda: self._get_prompt_impl(prompt_name, arguments),
            timeout=self.execution_timeout,
        )

    async def _get_prompt_impl(
        self, prompt_name: str, arguments: dict[str, Any]
    ) -> dict[str, Any]:
        """Internal implementation of get_prompt."""
        result = await asyncio.wait_for(
            self.session.get_prompt(prompt_name, arguments),
            timeout=self.execution_timeout,
        )

        return {
            "name": prompt_name,
            "messages": [
                {
                    "role": msg.role,
                    "content": msg.content,
                }
                for msg in result.messages
            ],
            "arguments": arguments,
        }

    async def _retry_operation(
        self,
        operation: Callable[[], Any],
        timeout: int | None = None,
    ) -> Any:
        """Retry an operation with exponential backoff.

        Args:
            operation: Async operation to retry.
            timeout: Operation timeout in seconds.

        Returns:
            Operation result.
        """
        last_error = None
        timeout = timeout or self.execution_timeout

        for attempt in range(self.max_retries):
            try:
                if timeout:
                    return await asyncio.wait_for(operation(), timeout=timeout)
                return await operation()

            except asyncio.TimeoutError as e:  # noqa: PERF203
                last_error = f"Operation timed out after {timeout} seconds"
                if attempt < self.max_retries - 1:
                    wait_time = 2**attempt
                    await asyncio.sleep(wait_time)
                else:
                    raise ConnectionError(last_error) from e

            except Exception as e:
                error_str = str(e).lower()

                # Classify errors as retryable or non-retryable
                if "authentication" in error_str or "unauthorized" in error_str:
                    raise ConnectionError(f"Authentication failed: {e}") from e

                if "not found" in error_str:
                    raise ValueError(f"Resource not found: {e}") from e

                # Retryable errors
                last_error = str(e)
                if attempt < self.max_retries - 1:
                    wait_time = 2**attempt
                    await asyncio.sleep(wait_time)
                else:
                    raise ConnectionError(
                        f"Operation failed after {self.max_retries} attempts: {last_error}"
                    ) from e

        raise ConnectionError(f"Operation failed: {last_error}")

    def _get_cache_key(self, resource_type: str) -> str:
        """Generate cache key for resource.

        Args:
            resource_type: Type of resource (e.g., "tools", "prompts").

        Returns:
            Cache key string.
        """
        # Use transport type and URL/command as cache key
        if isinstance(self.transport, StdioTransport):
            key = f"stdio:{self.transport.command}:{':'.join(self.transport.args)}"
        elif isinstance(self.transport, HTTPTransport):
            key = f"http:{self.transport.url}"
        elif isinstance(self.transport, SSETransport):
            key = f"sse:{self.transport.url}"
        else:
            key = f"{self.transport.transport_type}:unknown"

        return f"mcp:{key}:{resource_type}"

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return await self.connect()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit."""
        await self.disconnect()
````
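A minimal end-to-end usage sketch for this client, following the docstring example above and assuming a hypothetical local MCP server script at `server.py`:

```python
import asyncio

from crewai.mcp import MCPClient
from crewai.mcp.transports.stdio import StdioTransport


async def main() -> None:
    # StdioTransport spawns the server as a subprocess;
    # "server.py" is a hypothetical local MCP server script.
    transport = StdioTransport(command="python", args=["server.py"])
    async with MCPClient(transport, cache_tools_list=True) as client:
        tools = await client.list_tools()
        print([t["name"] for t in tools])
        # Which tool names and arguments are valid depends entirely
        # on the server's published schema.
        result = await client.call_tool(tools[0]["name"], {})
        print(result)


asyncio.run(main())
```

Note that the context manager handles `connect()`/`disconnect()`, and `list_tools()` benefits from the module-level schema cache when `cache_tools_list=True`.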
**lib/crewai/src/crewai/mcp/config.py** (new file, +124 lines)

````python
"""MCP server configuration models for CrewAI agents.

This module provides Pydantic models for configuring MCP servers with
various transport types, similar to OpenAI's Agents SDK.
"""

from pydantic import BaseModel, Field

from crewai.mcp.filters import ToolFilter


class MCPServerStdio(BaseModel):
    """Stdio MCP server configuration.

    This configuration is used for connecting to local MCP servers
    that run as processes and communicate via standard input/output.

    Example:
        ```python
        mcp_server = MCPServerStdio(
            command="python",
            args=["path/to/server.py"],
            env={"API_KEY": "..."},
            tool_filter=create_static_tool_filter(
                allowed_tool_names=["read_file", "write_file"]
            ),
        )
        ```
    """

    command: str = Field(
        ...,
        description="Command to execute (e.g., 'python', 'node', 'npx', 'uvx').",
    )
    args: list[str] = Field(
        default_factory=list,
        description="Command arguments (e.g., ['server.py'] or ['-y', '@mcp/server']).",
    )
    env: dict[str, str] | None = Field(
        default=None,
        description="Environment variables to pass to the process.",
    )
    tool_filter: ToolFilter | None = Field(
        default=None,
        description="Optional tool filter for filtering available tools.",
    )
    cache_tools_list: bool = Field(
        default=False,
        description="Whether to cache the tool list for faster subsequent access.",
    )


class MCPServerHTTP(BaseModel):
    """HTTP/Streamable HTTP MCP server configuration.

    This configuration is used for connecting to remote MCP servers
    over HTTP/HTTPS using streamable HTTP transport.

    Example:
        ```python
        mcp_server = MCPServerHTTP(
            url="https://api.example.com/mcp",
            headers={"Authorization": "Bearer ..."},
            cache_tools_list=True,
        )
        ```
    """

    url: str = Field(
        ..., description="Server URL (e.g., 'https://api.example.com/mcp')."
    )
    headers: dict[str, str] | None = Field(
        default=None,
        description="Optional HTTP headers for authentication or other purposes.",
    )
    streamable: bool = Field(
        default=True,
        description="Whether to use streamable HTTP transport (default: True).",
    )
    tool_filter: ToolFilter | None = Field(
        default=None,
        description="Optional tool filter for filtering available tools.",
    )
    cache_tools_list: bool = Field(
        default=False,
        description="Whether to cache the tool list for faster subsequent access.",
    )


class MCPServerSSE(BaseModel):
    """Server-Sent Events (SSE) MCP server configuration.

    This configuration is used for connecting to remote MCP servers
    using Server-Sent Events for real-time streaming communication.

    Example:
        ```python
        mcp_server = MCPServerSSE(
            url="https://api.example.com/mcp/sse",
            headers={"Authorization": "Bearer ..."},
        )
        ```
    """

    url: str = Field(
        ...,
        description="Server URL (e.g., 'https://api.example.com/mcp/sse').",
    )
    headers: dict[str, str] | None = Field(
        default=None,
        description="Optional HTTP headers for authentication or other purposes.",
    )
    tool_filter: ToolFilter | None = Field(
        default=None,
        description="Optional tool filter for filtering available tools.",
    )
    cache_tools_list: bool = Field(
        default=False,
        description="Whether to cache the tool list for faster subsequent access.",
    )


# Type alias for all MCP server configurations
MCPServerConfig = MCPServerStdio | MCPServerHTTP | MCPServerSSE
````
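Because these are plain Pydantic models, configurations are validated eagerly at construction time and can be passed around as data. A small sketch (URL and token are placeholders):

```python
from crewai.mcp.config import MCPServerHTTP, MCPServerStdio

# Both values are valid MCPServerConfig instances.
servers = [
    MCPServerStdio(command="python", args=["server.py"], cache_tools_list=True),
    MCPServerHTTP(
        url="https://api.example.com/mcp",
        headers={"Authorization": "Bearer your_token"},
    ),
]
for server in servers:
    print(type(server).__name__, server.cache_tools_list)
```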
**lib/crewai/src/crewai/mcp/filters.py** (new file, +166 lines)

````python
"""Tool filtering support for MCP servers.

This module provides utilities for filtering tools from MCP servers,
including static allow/block lists and dynamic context-aware filtering.
"""

from collections.abc import Callable
from typing import Any

from pydantic import BaseModel, Field


class ToolFilterContext(BaseModel):
    """Context for dynamic tool filtering.

    This context is passed to dynamic tool filters to provide
    information about the agent, run context, and server.
    """

    agent: Any = Field(..., description="The agent requesting tools.")
    server_name: str = Field(..., description="Name of the MCP server.")
    run_context: dict[str, Any] | None = Field(
        default=None,
        description="Optional run context for additional filtering logic.",
    )


# Type alias for tool filter functions
ToolFilter = (
    Callable[[ToolFilterContext, dict[str, Any]], bool]
    | Callable[[dict[str, Any]], bool]
)


class StaticToolFilter:
    """Static tool filter with allow/block lists.

    This filter provides simple allow/block list filtering based on
    tool names. Useful for restricting which tools are available
    from an MCP server.

    Example:
        ```python
        filter = StaticToolFilter(
            allowed_tool_names=["read_file", "write_file"],
            blocked_tool_names=["delete_file"],
        )
        ```
    """

    def __init__(
        self,
        allowed_tool_names: list[str] | None = None,
        blocked_tool_names: list[str] | None = None,
    ) -> None:
        """Initialize static tool filter.

        Args:
            allowed_tool_names: List of tool names to allow. If None,
                all tools are allowed (unless blocked).
            blocked_tool_names: List of tool names to block. Blocked tools
                take precedence over allowed tools.
        """
        self.allowed_tool_names = set(allowed_tool_names or [])
        self.blocked_tool_names = set(blocked_tool_names or [])

    def __call__(self, tool: dict[str, Any]) -> bool:
        """Filter tool based on allow/block lists.

        Args:
            tool: Tool definition dictionary with at least 'name' key.

        Returns:
            True if tool should be included, False otherwise.
        """
        tool_name = tool.get("name", "")

        # Blocked tools take precedence
        if self.blocked_tool_names and tool_name in self.blocked_tool_names:
            return False

        # If allow list exists, tool must be in it
        if self.allowed_tool_names:
            return tool_name in self.allowed_tool_names

        # No restrictions - allow all
        return True


def create_static_tool_filter(
    allowed_tool_names: list[str] | None = None,
    blocked_tool_names: list[str] | None = None,
) -> Callable[[dict[str, Any]], bool]:
    """Create a static tool filter function.

    This is a convenience function for creating static tool filters
    with allow/block lists.

    Args:
        allowed_tool_names: List of tool names to allow. If None,
            all tools are allowed (unless blocked).
        blocked_tool_names: List of tool names to block. Blocked tools
            take precedence over allowed tools.

    Returns:
        Tool filter function that returns True for allowed tools.

    Example:
        ```python
        filter_fn = create_static_tool_filter(
            allowed_tool_names=["read_file", "write_file"],
            blocked_tool_names=["delete_file"],
        )

        # Use in MCPServerStdio
        mcp_server = MCPServerStdio(
            command="npx",
            args=["-y", "@modelcontextprotocol/server-filesystem"],
            tool_filter=filter_fn,
        )
        ```
    """
    return StaticToolFilter(
        allowed_tool_names=allowed_tool_names,
        blocked_tool_names=blocked_tool_names,
    )


def create_dynamic_tool_filter(
    filter_func: Callable[[ToolFilterContext, dict[str, Any]], bool],
) -> Callable[[ToolFilterContext, dict[str, Any]], bool]:
    """Create a dynamic tool filter function.

    This function wraps a dynamic filter function that has access
    to the tool filter context (agent, server, run context).

    Args:
        filter_func: Function that takes (context, tool) and returns bool.

    Returns:
        Tool filter function that can be used with MCP server configs.

    Example:
        ```python
        def context_aware_filter(
            context: ToolFilterContext, tool: dict[str, Any]
        ) -> bool:
            # Block dangerous tools for code reviewers
            if context.agent.role == "Code Reviewer":
                if tool["name"].startswith("danger_"):
                    return False
            return True


        filter_fn = create_dynamic_tool_filter(context_aware_filter)

        mcp_server = MCPServerStdio(
            command="python", args=["server.py"], tool_filter=filter_fn
        )
        ```
    """
    return filter_func
````
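Since a static filter is just a callable over tool dicts, it can be applied directly to a discovered tool list; a small demonstration of the precedence rules:

```python
from crewai.mcp.filters import create_static_tool_filter

filter_fn = create_static_tool_filter(
    allowed_tool_names=["read_file", "write_file"],
    blocked_tool_names=["write_file"],
)

tools = [{"name": "read_file"}, {"name": "write_file"}, {"name": "delete_file"}]
# Blocked names win over allowed ones, so only read_file survives.
print([t["name"] for t in tools if filter_fn(t)])  # ['read_file']
```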
**lib/crewai/src/crewai/mcp/transports/__init__.py** (new file, +15 lines)

```python
"""MCP transport implementations for various connection types."""

from crewai.mcp.transports.base import BaseTransport, TransportType
from crewai.mcp.transports.http import HTTPTransport
from crewai.mcp.transports.sse import SSETransport
from crewai.mcp.transports.stdio import StdioTransport


__all__ = [
    "BaseTransport",
    "HTTPTransport",
    "SSETransport",
    "StdioTransport",
    "TransportType",
]
```
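These exports cover the three concrete transports that the configuration models map onto. The mapping itself is plain `isinstance` dispatch; a sketch of how a caller might resolve a config into a transport (the agent layer presumably does something equivalent, but this exact helper is illustrative, not part of the PR):

```python
from crewai.mcp.config import (
    MCPServerConfig,
    MCPServerHTTP,
    MCPServerSSE,
    MCPServerStdio,
)
from crewai.mcp.transports import (
    BaseTransport,
    HTTPTransport,
    SSETransport,
    StdioTransport,
)


def transport_for(config: MCPServerConfig) -> BaseTransport:
    """Resolve a server config into its matching transport (illustrative)."""
    if isinstance(config, MCPServerStdio):
        return StdioTransport(command=config.command, args=config.args, env=config.env)
    if isinstance(config, MCPServerHTTP):
        return HTTPTransport(
            url=config.url, headers=config.headers, streamable=config.streamable
        )
    if isinstance(config, MCPServerSSE):
        return SSETransport(url=config.url, headers=config.headers)
    raise TypeError(f"Unsupported MCP server config: {type(config)!r}")
```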
**lib/crewai/src/crewai/mcp/transports/base.py** (new file, +125 lines)

```python
"""Base transport interface for MCP connections."""

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Protocol

from typing_extensions import Self


class TransportType(str, Enum):
    """MCP transport types."""

    STDIO = "stdio"
    HTTP = "http"
    STREAMABLE_HTTP = "streamable-http"
    SSE = "sse"


class ReadStream(Protocol):
    """Protocol for read streams."""

    async def read(self, n: int = -1) -> bytes:
        """Read bytes from stream."""
        ...


class WriteStream(Protocol):
    """Protocol for write streams."""

    async def write(self, data: bytes) -> None:
        """Write bytes to stream."""
        ...


class BaseTransport(ABC):
    """Base class for MCP transport implementations.

    This abstract base class defines the interface that all transport
    implementations must follow. Transports handle the low-level communication
    with MCP servers.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the transport.

        Args:
            **kwargs: Transport-specific configuration options.
        """
        self._read_stream: ReadStream | None = None
        self._write_stream: WriteStream | None = None
        self._connected = False

    @property
    @abstractmethod
    def transport_type(self) -> TransportType:
        """Return the transport type."""
        ...

    @property
    def connected(self) -> bool:
        """Check if transport is connected."""
        return self._connected

    @property
    def read_stream(self) -> ReadStream:
        """Get the read stream."""
        if self._read_stream is None:
            raise RuntimeError("Transport not connected. Call connect() first.")
        return self._read_stream

    @property
    def write_stream(self) -> WriteStream:
        """Get the write stream."""
        if self._write_stream is None:
            raise RuntimeError("Transport not connected. Call connect() first.")
        return self._write_stream

    @abstractmethod
    async def connect(self) -> Self:
        """Establish connection to MCP server.

        Returns:
            Self for method chaining.

        Raises:
            ConnectionError: If connection fails.
        """
        ...

    @abstractmethod
    async def disconnect(self) -> None:
        """Close connection to MCP server."""
        ...

    @abstractmethod
    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        ...

    @abstractmethod
    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit."""
        ...

    def _set_streams(self, read: ReadStream, write: WriteStream) -> None:
        """Set the read and write streams.

        Args:
            read: Read stream.
            write: Write stream.
        """
        self._read_stream = read
        self._write_stream = write
        self._connected = True

    def _clear_streams(self) -> None:
        """Clear the read and write streams."""
        self._read_stream = None
        self._write_stream = None
        self._connected = False
```
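Any new transport only has to implement the abstract surface above and call `_set_streams`/`_clear_streams` at the right moments. A minimal in-memory sketch, purely illustrative and not a real MCP transport (it assumes pre-built stream objects are supplied by the caller):

```python
from typing import Any

from typing_extensions import Self

from crewai.mcp.transports.base import BaseTransport, TransportType


class LoopbackTransport(BaseTransport):
    """Illustrative transport that wires pre-built streams straight through."""

    def __init__(self, read: Any, write: Any, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._read = read
        self._write = write

    @property
    def transport_type(self) -> TransportType:
        # TransportType has no loopback member; reuse STDIO for this sketch.
        return TransportType.STDIO

    async def connect(self) -> Self:
        self._set_streams(read=self._read, write=self._write)
        return self

    async def disconnect(self) -> None:
        self._clear_streams()

    async def __aenter__(self) -> Self:
        return await self.connect()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        await self.disconnect()
```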
**lib/crewai/src/crewai/mcp/transports/http.py** (new file, +174 lines)

````python
"""HTTP and Streamable HTTP transport for MCP servers."""

import asyncio
from typing import Any

from typing_extensions import Self


# BaseExceptionGroup is available in Python 3.11+
try:
    from builtins import BaseExceptionGroup
except ImportError:
    # Fallback for Python < 3.11 (shouldn't happen in practice)
    BaseExceptionGroup = Exception

from crewai.mcp.transports.base import BaseTransport, TransportType


class HTTPTransport(BaseTransport):
    """HTTP/Streamable HTTP transport for connecting to remote MCP servers.

    This transport connects to MCP servers over HTTP/HTTPS using the
    streamable HTTP client from the MCP SDK.

    Example:
        ```python
        transport = HTTPTransport(
            url="https://api.example.com/mcp",
            headers={"Authorization": "Bearer ..."}
        )
        async with transport:
            # Use transport...
        ```
    """

    def __init__(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        streamable: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize HTTP transport.

        Args:
            url: Server URL (e.g., "https://api.example.com/mcp").
            headers: Optional HTTP headers.
            streamable: Whether to use streamable HTTP (default: True).
            **kwargs: Additional transport options.
        """
        super().__init__(**kwargs)
        self.url = url
        self.headers = headers or {}
        self.streamable = streamable
        self._transport_context: Any = None

    @property
    def transport_type(self) -> TransportType:
        """Return the transport type."""
        return TransportType.STREAMABLE_HTTP if self.streamable else TransportType.HTTP

    async def connect(self) -> Self:
        """Establish HTTP connection to MCP server.

        Returns:
            Self for method chaining.

        Raises:
            ConnectionError: If connection fails.
            ImportError: If MCP SDK not available.
        """
        if self._connected:
            return self

        try:
            from mcp.client.streamable_http import streamablehttp_client

            self._transport_context = streamablehttp_client(
                self.url,
                headers=self.headers if self.headers else None,
                terminate_on_close=True,
            )

            try:
                read, write, _ = await asyncio.wait_for(
                    self._transport_context.__aenter__(), timeout=30.0
                )
            except asyncio.TimeoutError as e:
                self._transport_context = None
                raise ConnectionError(
                    "Transport context entry timed out after 30 seconds. "
                    "Server may be slow or unreachable."
                ) from e
            except Exception as e:
                self._transport_context = None
                raise ConnectionError(f"Failed to enter transport context: {e}") from e
            self._set_streams(read=read, write=write)
            return self

        except ImportError as e:
            raise ImportError(
                "MCP library not available. Please install with: pip install mcp"
            ) from e
        except Exception as e:
            self._clear_streams()
            if self._transport_context is not None:
                self._transport_context = None
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

    async def disconnect(self) -> None:
        """Close HTTP connection."""
        if not self._connected:
            return

        try:
            # Clear streams first
            self._clear_streams()

            # Exit transport context - this will clean up background tasks.
            # Give a small delay to allow background tasks to complete.
            if self._transport_context is not None:
                try:
                    # Wait a tiny bit for any pending operations
                    await asyncio.sleep(0.1)
                    await self._transport_context.__aexit__(None, None, None)
                except (RuntimeError, asyncio.CancelledError) as e:
                    # Ignore "exit cancel scope in a different task" errors and
                    # cancellation - these happen when asyncio.run() closes the
                    # event loop while background tasks are still running.
                    error_msg = str(e).lower()
                    if "cancel scope" not in error_msg and "task" not in error_msg:
                        # Only suppress cancel scope/task errors, re-raise others
                        if isinstance(e, RuntimeError):
                            raise
                        # For CancelledError, just suppress it
                except BaseExceptionGroup as eg:
                    # Handle exception groups from anyio task groups;
                    # suppress if they contain cancel scope errors.
                    should_suppress = False
                    for exc in eg.exceptions:
                        error_msg = str(exc).lower()
                        if "cancel scope" in error_msg or "task" in error_msg:
                            should_suppress = True
                            break
                    if not should_suppress:
                        raise
                except Exception as e:
                    raise RuntimeError(
                        f"Error during HTTP transport disconnect: {e}"
                    ) from e

            self._connected = False

        except Exception as e:
            # Log but don't raise - cleanup should be best effort
            import logging

            logger = logging.getLogger(__name__)
            logger.warning(f"Error during HTTP transport disconnect: {e}")

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return await self.connect()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit."""
        await self.disconnect()
````
**lib/crewai/src/crewai/mcp/transports/sse.py** (new file, +113 lines)

````python
"""Server-Sent Events (SSE) transport for MCP servers."""

from typing import Any

from typing_extensions import Self

from crewai.mcp.transports.base import BaseTransport, TransportType


class SSETransport(BaseTransport):
    """SSE transport for connecting to remote MCP servers.

    This transport connects to MCP servers using Server-Sent Events (SSE)
    for real-time streaming communication.

    Example:
        ```python
        transport = SSETransport(
            url="https://api.example.com/mcp/sse",
            headers={"Authorization": "Bearer ..."}
        )
        async with transport:
            # Use transport...
        ```
    """

    def __init__(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize SSE transport.

        Args:
            url: Server URL (e.g., "https://api.example.com/mcp/sse").
            headers: Optional HTTP headers.
            **kwargs: Additional transport options.
        """
        super().__init__(**kwargs)
        self.url = url
        self.headers = headers or {}
        self._transport_context: Any = None

    @property
    def transport_type(self) -> TransportType:
        """Return the transport type."""
        return TransportType.SSE

    async def connect(self) -> Self:
        """Establish SSE connection to MCP server.

        Returns:
            Self for method chaining.

        Raises:
            ConnectionError: If connection fails.
            ImportError: If MCP SDK not available.
        """
        if self._connected:
            return self

        try:
            from mcp.client.sse import sse_client

            self._transport_context = sse_client(
                self.url,
                headers=self.headers if self.headers else None,
                terminate_on_close=True,
            )

            read, write = await self._transport_context.__aenter__()

            self._set_streams(read=read, write=write)

            return self

        except ImportError as e:
            raise ImportError(
                "MCP library not available. Please install with: pip install mcp"
            ) from e
        except Exception as e:
            self._clear_streams()
            raise ConnectionError(f"Failed to connect to SSE MCP server: {e}") from e

    async def disconnect(self) -> None:
        """Close SSE connection."""
        if not self._connected:
            return

        try:
            self._clear_streams()
            if self._transport_context is not None:
                await self._transport_context.__aexit__(None, None, None)

        except Exception as e:
            import logging

            logger = logging.getLogger(__name__)
            logger.warning(f"Error during SSE transport disconnect: {e}")

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return await self.connect()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit."""
        await self.disconnect()
````
**lib/crewai/src/crewai/mcp/transports/stdio.py** (new file, +153 lines)

````python
"""Stdio transport for MCP servers running as local processes."""

import asyncio
import os
import subprocess
from typing import Any

from typing_extensions import Self

from crewai.mcp.transports.base import BaseTransport, TransportType


class StdioTransport(BaseTransport):
    """Stdio transport for connecting to local MCP servers.

    This transport connects to MCP servers running as local processes,
    communicating via standard input/output streams. Supports Python,
    Node.js, and other command-line servers.

    Example:
        ```python
        transport = StdioTransport(
            command="python",
            args=["path/to/server.py"],
            env={"API_KEY": "..."}
        )
        async with transport:
            # Use transport...
        ```
    """

    def __init__(
        self,
        command: str,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize stdio transport.

        Args:
            command: Command to execute (e.g., "python", "node", "npx").
            args: Command arguments (e.g., ["server.py"] or ["-y", "@mcp/server"]).
            env: Environment variables to pass to the process.
            **kwargs: Additional transport options.
        """
        super().__init__(**kwargs)
        self.command = command
        self.args = args or []
        self.env = env or {}
        self._process: subprocess.Popen[bytes] | None = None
        self._transport_context: Any = None

    @property
    def transport_type(self) -> TransportType:
        """Return the transport type."""
        return TransportType.STDIO

    async def connect(self) -> Self:
        """Start the MCP server process and establish connection.

        Returns:
            Self for method chaining.

        Raises:
            ConnectionError: If process fails to start.
            ImportError: If MCP SDK not available.
        """
        if self._connected:
            return self

        try:
            from mcp import StdioServerParameters
            from mcp.client.stdio import stdio_client

            # Merge user-supplied variables into the inherited environment
            process_env = os.environ.copy()
            process_env.update(self.env)

            server_params = StdioServerParameters(
                command=self.command,
                args=self.args,
                env=process_env if process_env else None,
            )
            self._transport_context = stdio_client(server_params)

            try:
                read, write = await self._transport_context.__aenter__()
            except Exception as e:
                self._transport_context = None
                raise ConnectionError(
                    f"Failed to enter stdio transport context: {e}"
                ) from e

            self._set_streams(read=read, write=write)

            return self

        except ImportError as e:
            raise ImportError(
                "MCP library not available. Please install with: pip install mcp"
            ) from e
        except Exception as e:
            self._clear_streams()
            if self._transport_context is not None:
                self._transport_context = None
            raise ConnectionError(f"Failed to start MCP server process: {e}") from e

    async def disconnect(self) -> None:
        """Terminate the MCP server process and close connection."""
        if not self._connected:
            return

        try:
            self._clear_streams()

            if self._transport_context is not None:
                await self._transport_context.__aexit__(None, None, None)

            if self._process is not None:
                try:
                    self._process.terminate()
                    try:
                        await asyncio.wait_for(self._process.wait(), timeout=5.0)
                    except asyncio.TimeoutError:
                        self._process.kill()
                        await self._process.wait()
                finally:
                    self._process = None

        except Exception as e:
            # Log but don't raise - cleanup should be best effort
            import logging

            logger = logging.getLogger(__name__)
            logger.warning(f"Error during stdio transport disconnect: {e}")

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return await self.connect()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit."""
        await self.disconnect()
````
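Connecting to a packaged Node server works the same way as the Python example in the docstring; here with the filesystem server referenced in the config and filter examples earlier in this PR (the `API_KEY` value is a placeholder):

```python
import asyncio

from crewai.mcp import MCPClient
from crewai.mcp.transports.stdio import StdioTransport


async def main() -> None:
    transport = StdioTransport(
        command="npx",
        args=["-y", "@modelcontextprotocol/server-filesystem"],
        env={"API_KEY": "your_key"},  # merged into os.environ by connect()
    )
    async with MCPClient(transport) as client:
        print(await client.list_tools())


asyncio.run(main())
```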
```diff
@@ -67,31 +67,44 @@ def _prepare_documents_for_chromadb(
     ids: list[str] = []
     texts: list[str] = []
     metadatas: list[Mapping[str, str | int | float | bool]] = []
+    seen_ids: dict[str, int] = {}

-    for doc in documents:
-        if "doc_id" in doc:
-            ids.append(doc["doc_id"])
-        else:
-            content_for_hash = doc["content"]
-            metadata = doc.get("metadata")
-            if metadata:
-                metadata_str = json.dumps(metadata, sort_keys=True)
-                content_for_hash = f"{content_for_hash}|{metadata_str}"
-
-            content_hash = hashlib.blake2b(
-                content_for_hash.encode(), digest_size=32
-            ).hexdigest()
-            ids.append(content_hash)
-
-        texts.append(doc["content"])
-        metadata = doc.get("metadata")
-        if metadata:
-            if isinstance(metadata, list):
-                metadatas.append(metadata[0] if metadata and metadata[0] else {})
-            else:
-                metadatas.append(metadata)
-        else:
-            metadatas.append({})
+    try:
+        for doc in documents:
+            if "doc_id" in doc:
+                doc_id = str(doc["doc_id"])
+            else:
+                metadata = doc.get("metadata")
+                if metadata and isinstance(metadata, dict) and "doc_id" in metadata:
+                    doc_id = str(metadata["doc_id"])
+                else:
+                    content_for_hash = doc["content"]
+                    if metadata:
+                        metadata_str = json.dumps(metadata, sort_keys=True)
+                        content_for_hash = f"{content_for_hash}|{metadata_str}"
+                    doc_id = hashlib.sha256(content_for_hash.encode()).hexdigest()
+
+            metadata = doc.get("metadata")
+            if metadata:
+                if isinstance(metadata, list):
+                    processed_metadata = metadata[0] if metadata and metadata[0] else {}
+                else:
+                    processed_metadata = metadata
+            else:
+                processed_metadata = {}
+
+            if doc_id in seen_ids:
+                idx = seen_ids[doc_id]
+                texts[idx] = doc["content"]
+                metadatas[idx] = processed_metadata
+            else:
+                idx = len(ids)
+                ids.append(doc_id)
+                texts.append(doc["content"])
+                metadatas.append(processed_metadata)
+                seen_ids[doc_id] = idx
+    except Exception as e:
+        raise ValueError(f"Error preparing documents for ChromaDB: {e}") from e

     return PreparedDocuments(ids, texts, metadatas)
```
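The net effect of the new branch structure: an explicit `doc_id` (top-level, or inside metadata) wins over the content hash, and a repeated id updates the existing entry in place instead of appending a duplicate. A standalone sketch that mirrors the id-resolution order from the diff (it does not import the actual helper):

```python
import hashlib
import json
from typing import Any


def resolve_doc_id(doc: dict[str, Any]) -> str:
    """Mirror of the id-resolution order introduced in the diff (illustrative)."""
    if "doc_id" in doc:
        return str(doc["doc_id"])
    metadata = doc.get("metadata")
    if metadata and isinstance(metadata, dict) and "doc_id" in metadata:
        return str(metadata["doc_id"])
    content_for_hash = doc["content"]
    if metadata:
        content_for_hash = f"{content_for_hash}|{json.dumps(metadata, sort_keys=True)}"
    return hashlib.sha256(content_for_hash.encode()).hexdigest()


print(resolve_doc_id({"content": "a", "doc_id": 1}))                # "1"
print(resolve_doc_id({"content": "a", "metadata": {"doc_id": 2}}))  # "2"
print(resolve_doc_id({"content": "a"}))                             # sha256 hex digest
```

Note the hash function also changes from `blake2b` to `sha256` in this revision.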
162
lib/crewai/src/crewai/tools/mcp_native_tool.py
Normal file
162
lib/crewai/src/crewai/tools/mcp_native_tool.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Native MCP tool wrapper for CrewAI agents.
|
||||
|
||||
This module provides a tool wrapper that reuses existing MCP client sessions
|
||||
for better performance and connection management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class MCPNativeTool(BaseTool):
|
||||
"""Native MCP tool that reuses client sessions.
|
||||
|
||||
This tool wrapper is used when agents connect to MCP servers using
|
||||
structured configurations. It reuses existing client sessions for
|
||||
better performance and proper connection lifecycle management.
|
||||
|
||||
Unlike MCPToolWrapper which connects on-demand, this tool uses
|
||||
a shared MCP client instance that maintains a persistent connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_client: Any,
|
||||
tool_name: str,
|
||||
tool_schema: dict[str, Any],
|
||||
server_name: str,
|
||||
) -> None:
|
||||
"""Initialize native MCP tool.
|
||||
|
||||
Args:
|
||||
mcp_client: MCPClient instance with active session.
|
||||
tool_name: Original name of the tool on the MCP server.
|
||||
tool_schema: Schema information for the tool.
|
||||
server_name: Name of the MCP server for prefixing.
|
||||
"""
|
||||
# Create tool name with server prefix to avoid conflicts
|
||||
prefixed_name = f"{server_name}_{tool_name}"
|
||||
|
||||
# Handle args_schema properly - BaseTool expects a BaseModel subclass
|
||||
args_schema = tool_schema.get("args_schema")
|
||||
|
||||
# Only pass args_schema if it's provided
|
||||
kwargs = {
|
||||
"name": prefixed_name,
|
||||
"description": tool_schema.get(
|
||||
"description", f"Tool {tool_name} from {server_name}"
|
||||
),
|
||||
}
|
||||
|
||||
if args_schema is not None:
|
||||
kwargs["args_schema"] = args_schema
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set instance attributes after super().__init__
|
||||
self._mcp_client = mcp_client
|
||||
self._original_tool_name = tool_name
|
||||
self._server_name = server_name
|
||||
# self._logger = logging.getLogger(__name__)
|
||||
|
||||
@property
|
||||
def mcp_client(self) -> Any:
|
||||
"""Get the MCP client instance."""
|
||||
return self._mcp_client
|
||||
|
||||
@property
|
||||
def original_tool_name(self) -> str:
|
||||
"""Get the original tool name."""
|
||||
return self._original_tool_name
|
||||
|
||||
@property
|
||||
def server_name(self) -> str:
|
||||
"""Get the server name."""
|
||||
return self._server_name
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
"""Execute tool using the MCP client session.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments to pass to the MCP tool.
|
||||
|
||||
Returns:
|
||||
Result from the MCP tool execution.
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
coro = self._run_async(**kwargs)
|
||||
future = executor.submit(asyncio.run, coro)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(self._run_async(**kwargs))
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||
) from e
|
||||
|
||||
async def _run_async(self, **kwargs) -> str:
|
||||
"""Async implementation of tool execution.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments to pass to the MCP tool.
|
||||
|
||||
Returns:
|
||||
Result from the MCP tool execution.
|
||||
"""
|
||||
# Note: Since we use asyncio.run() which creates a new event loop each time,
|
||||
# Always reconnect on-demand because asyncio.run() creates new event loops per call
|
||||
# All MCP transport context managers (stdio, streamablehttp_client, sse_client)
|
||||
# use anyio.create_task_group() which can't span different event loops
|
||||
if self._mcp_client.connected:
|
||||
await self._mcp_client.disconnect()
|
||||
|
||||
await self._mcp_client.connect()
|
||||
|
||||
try:
|
||||
result = await self._mcp_client.call_tool(self.original_tool_name, kwargs)
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"not connected" in error_str
|
||||
or "connection" in error_str
|
||||
or "send" in error_str
|
||||
):
|
||||
await self._mcp_client.disconnect()
|
||||
await self._mcp_client.connect()
|
||||
# Retry the call
|
||||
result = await self._mcp_client.call_tool(
|
||||
self.original_tool_name, kwargs
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Always disconnect after tool call to ensure clean context manager lifecycle
|
||||
# This prevents "exit cancel scope in different task" errors
|
||||
# All transport context managers must be exited in the same event loop they were entered
|
||||
await self._mcp_client.disconnect()
|
||||
|
||||
# Extract result content
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
|
||||
# Handle various result formats
|
||||
if hasattr(result, "content") and result.content:
|
||||
if isinstance(result.content, list) and len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
if hasattr(content_item, "text"):
|
||||
return str(content_item.text)
|
||||
return str(content_item)
|
||||
return str(result.content)
|
||||
|
||||
return str(result)
|
||||
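A minimal usage sketch; the client wiring is illustrative, since `MCPNativeTool` only relies on the `connected`/`connect()`/`disconnect()`/`call_tool()` surface used above:

```python
# Illustrative wiring; `client` construction is assumed, not shown in this file.
tool = MCPNativeTool(
    mcp_client=client,  # any object exposing connected/connect/disconnect/call_tool
    tool_name="read_file",
    tool_schema={"description": "Read a file from the MCP server"},
    server_name="filesystem",
)
assert tool.name == "filesystem_read_file"  # server prefix avoids name clashes
result = tool.run(path="README.md")  # BaseTool.run() dispatches into _run above
```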
@@ -2117,15 +2117,14 @@ def test_agent_with_only_crewai_knowledge():
         goal="Provide information based on knowledge sources",
         backstory="You have access to specific knowledge sources.",
         llm=LLM(
-            model="openrouter/openai/gpt-4o-mini",
-            api_key=os.getenv("OPENROUTER_API_KEY"),
+            model="gpt-4o-mini",
         ),
     )

     # Create a task that requires the agent to use the knowledge
     task = Task(
         description="What is Vidit's favorite color?",
-        expected_output="Vidit's favorclearite color.",
+        expected_output="Vidit's favorite color.",
         agent=agent,
     )

@@ -1,35 +1,128 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour
|
||||
body: '{"trace_id": "bf042234-54a3-4fc0-857d-1ae5585a174e", "execution_type":
|
||||
"crew", "user_identifier": null, "execution_context": {"crew_fingerprint": null,
|
||||
"crew_name": "Unknown Crew", "flow_name": null, "crewai_version": "1.3.0", "privacy_level":
|
||||
"standard"}, "execution_metadata": {"expected_duration_estimate": 300, "agent_count":
|
||||
0, "task_count": 0, "flow_method_count": 0, "execution_started_at": "2025-11-06T16:05:14.776800+00:00"}}'
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate, zstd
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '434'
|
||||
Content-Type:
|
||||
- application/json
|
||||
User-Agent:
|
||||
- CrewAI-CLI/1.3.0
|
||||
X-Crewai-Version:
|
||||
- 1.3.0
|
||||
method: POST
|
||||
uri: https://app.crewai.com/crewai_plus/api/v1/tracing/batches
|
||||
response:
|
||||
body:
|
||||
string: '{"error":"bad_credentials","message":"Bad credentials"}'
|
||||
headers:
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '55'
|
||||
Content-Type:
|
||||
- application/json; charset=utf-8
|
||||
Date:
|
||||
- Thu, 06 Nov 2025 16:05:15 GMT
|
||||
cache-control:
|
||||
- no-store
|
||||
content-security-policy:
|
||||
- 'default-src ''self'' *.app.crewai.com app.crewai.com; script-src ''self''
|
||||
''unsafe-inline'' *.app.crewai.com app.crewai.com https://cdn.jsdelivr.net/npm/apexcharts
|
||||
https://www.gstatic.com https://run.pstmn.io https://apis.google.com https://apis.google.com/js/api.js
|
||||
https://accounts.google.com https://accounts.google.com/gsi/client https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.min.css.map
|
||||
https://*.google.com https://docs.google.com https://slides.google.com https://js.hs-scripts.com
|
||||
https://js.sentry-cdn.com https://browser.sentry-cdn.com https://www.googletagmanager.com
|
||||
https://js-na1.hs-scripts.com https://js.hubspot.com http://js-na1.hs-scripts.com
|
||||
https://bat.bing.com https://cdn.amplitude.com https://cdn.segment.com https://d1d3n03t5zntha.cloudfront.net/
|
||||
https://descriptusercontent.com https://edge.fullstory.com https://googleads.g.doubleclick.net
|
||||
https://js.hs-analytics.net https://js.hs-banner.com https://js.hsadspixel.net
|
||||
https://js.hscollectedforms.net https://js.usemessages.com https://snap.licdn.com
|
||||
https://static.cloudflareinsights.com https://static.reo.dev https://www.google-analytics.com
|
||||
https://share.descript.com/; style-src ''self'' ''unsafe-inline'' *.app.crewai.com
|
||||
app.crewai.com https://cdn.jsdelivr.net/npm/apexcharts; img-src ''self'' data:
|
||||
*.app.crewai.com app.crewai.com https://zeus.tools.crewai.com https://dashboard.tools.crewai.com
|
||||
https://cdn.jsdelivr.net https://forms.hsforms.com https://track.hubspot.com
|
||||
https://px.ads.linkedin.com https://px4.ads.linkedin.com https://www.google.com
|
||||
https://www.google.com.br; font-src ''self'' data: *.app.crewai.com app.crewai.com;
|
||||
connect-src ''self'' *.app.crewai.com app.crewai.com https://zeus.tools.crewai.com
|
||||
https://connect.useparagon.com/ https://zeus.useparagon.com/* https://*.useparagon.com/*
|
||||
https://run.pstmn.io https://connect.tools.crewai.com/ https://*.sentry.io
|
||||
https://www.google-analytics.com https://edge.fullstory.com https://rs.fullstory.com
|
||||
https://api.hubspot.com https://forms.hscollectedforms.net https://api.hubapi.com
|
||||
https://px.ads.linkedin.com https://px4.ads.linkedin.com https://google.com/pagead/form-data/16713662509
|
||||
https://google.com/ccm/form-data/16713662509 https://www.google.com/ccm/collect
|
||||
https://worker-actionkit.tools.crewai.com https://api.reo.dev; frame-src ''self''
|
||||
*.app.crewai.com app.crewai.com https://connect.useparagon.com/ https://zeus.tools.crewai.com
|
||||
https://zeus.useparagon.com/* https://connect.tools.crewai.com/ https://docs.google.com
|
||||
https://drive.google.com https://slides.google.com https://accounts.google.com
|
||||
https://*.google.com https://app.hubspot.com/ https://td.doubleclick.net https://www.googletagmanager.com/
|
||||
https://www.youtube.com https://share.descript.com'
|
||||
expires:
|
||||
- '0'
|
||||
permissions-policy:
|
||||
- camera=(), microphone=(self), geolocation=()
|
||||
pragma:
|
||||
- no-cache
|
||||
referrer-policy:
|
||||
- strict-origin-when-cross-origin
|
||||
strict-transport-security:
|
||||
- max-age=63072000; includeSubDomains
|
||||
vary:
|
||||
- Accept
|
||||
x-content-type-options:
|
||||
- nosniff
|
||||
x-frame-options:
|
||||
- SAMEORIGIN
|
||||
x-permitted-cross-domain-policies:
|
||||
- none
|
||||
x-request-id:
|
||||
- 9e528076-59a8-4c21-a999-2367937321ed
|
||||
x-runtime:
|
||||
- '0.070063'
|
||||
x-xss-protection:
|
||||
- 1; mode=block
|
||||
status:
|
||||
code: 401
|
||||
message: Unauthorized
|
||||
- request:
|
||||
body: '{"messages":[{"role":"system","content":"You are test role. test backstory\nYour
|
||||
personal goal is: test goal\nTo give my best complete final answer to the task
|
||||
use the exact following format:\n\nThought: I now can give a great answer\nFinal
|
||||
Answer: Your final answer must be the great and the most complete as possible,
|
||||
it must be outcome described.\n\nI MUST use these formats, my job depends on
|
||||
it!"}, {"role": "user", "content": "\nCurrent Task: Summarize the given context
|
||||
in one sentence\n\nThis is the expect criteria for your final answer: A one-sentence
|
||||
summary\nyou MUST return the actual complete content as the final answer, not
|
||||
a summary.\n\nThis is the context you''re working with:\nThe quick brown fox
|
||||
jumps over the lazy dog. This sentence contains every letter of the alphabet.\n\nBegin!
|
||||
This is VERY important to you, use the tools available and give your best Final
|
||||
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-3.5-turbo"}'
|
||||
respond using the exact following format:\n\nThought: I now can give a great
|
||||
answer\nFinal Answer: Your final answer must be the great and the most complete
|
||||
as possible, it must be outcome described.\n\nI MUST use these formats, my job
|
||||
depends on it!"},{"role":"user","content":"\nCurrent Task: Summarize the given
|
||||
context in one sentence\n\nThis is the expected criteria for your final answer:
|
||||
A one-sentence summary\nyou MUST return the actual complete content as the final
|
||||
answer, not a summary.\n\nThis is the context you''re working with:\nThe quick
|
||||
brown fox jumps over the lazy dog. This sentence contains every letter of the
|
||||
alphabet.\n\nBegin! This is VERY important to you, use the tools available and
|
||||
give your best Final Answer, your job depends on it!\n\nThought:"}],"model":"gpt-3.5-turbo"}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '961'
|
||||
- '963'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- __cf_bm=rb61BZH2ejzD5YPmLaEJqI7km71QqyNJGTVdNxBq6qk-1727213194-1.0.1.1-pJ49onmgX9IugEMuYQMralzD7oj_6W.CHbSu4Su1z3NyjTGYg.rhgJZWng8feFYah._oSnoYlkTjpK1Wd2C9FA;
|
||||
_cfuvid=lbRdAddVWV6W3f5Dm9SaOPWDUOxqtZBSPr_fTW26nEA-1727213194587-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.47.0
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
@@ -39,30 +132,33 @@ interactions:
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.47.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.11.7
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
content: "{\n \"id\": \"chatcmpl-AB7WTXzhDaFVbUrrQKXCo78KID8N9\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1727213889,\n \"model\": \"gpt-3.5-turbo-0125\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"I now can give a great answer\\nFinal
|
||||
Answer: The quick brown fox jumps over the lazy dog. This sentence contains
|
||||
every letter of the alphabet.\",\n \"refusal\": null\n },\n \"logprobs\":
|
||||
null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
190,\n \"completion_tokens\": 30,\n \"total_tokens\": 220,\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0\n }\n },\n \"system_fingerprint\": null\n}\n"
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFPBbtswDL37Kwidk6BOmgbLbRgwYLdtCLAVaxHIEm2rkUVVopOmRf59
|
||||
kJLG6dYBuxgwH9/z4yP9UgAIo8UShGolq87b8afbXTfbr74/89erb2o/axY0f7x1RkX+8VOMEoOq
|
||||
B1T8ypoo6rxFNuSOsAooGZNqubiZXl/Py3KegY402kRrPI9nk/mY+1DR+Kqczk/MlozCKJbwqwAA
|
||||
eMnP5NFpfBJLuBq9VjqMUTYolucmABHIpoqQMZrI0rEYDaAix+iy7S/QO40htWjgFoFl3EB62Rlr
|
||||
wQdSiBqYoDFbzB0VRobaOGlBurjDMLlzd+5zLnzMhSWsWoTH3qgNVIF2Dmp6goe+8xFoiyHLWPm8
|
||||
B03NBFatiRAxeVIIyZw0LgJuMezBIjMGoDqTpPWtrJAnl+MErPsoU5yut/YCkM4Ry7SOHOT9CTmc
|
||||
o7PU+EBV/IMqauNMbNcBZSSXYopMXmT0UADc5xX1b1IXPlDnec20wfy58kN51BPDVQzo7OYEMrG0
|
||||
Q306XYze0VtrZGlsvFiyUFK1qAfqcBGy14YugOJi6r/dvKd9nNy45n/kB0Ap9Ix67QNqo95OPLQF
|
||||
TD/Nv9rOKWfDImLYGoVrNhjSJjTWsrfHcxZxHxm7dW1cg8EHk286bbI4FL8BAAD//wMAHFSnRdID
|
||||
AAA=
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-RAY:
|
||||
- 8c85eb7568111cf3-GRU
|
||||
- 99a5d4d0bb8f7327-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
@@ -70,37 +166,54 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Tue, 24 Sep 2024 21:38:09 GMT
|
||||
- Thu, 06 Nov 2025 16:05:16 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=REDACTED;
|
||||
path=/; expires=Thu, 06-Nov-25 16:35:16 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=REDACTED;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
- user-REDACTED
|
||||
openai-processing-ms:
|
||||
- '662'
|
||||
- '836'
|
||||
openai-project:
|
||||
- proj_REDACTED
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
x-envoy-upstream-service-time:
|
||||
- '983'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '50000000'
|
||||
- '200000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '49999772'
|
||||
- '199785'
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
- 8.64s
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
- 64ms
|
||||
x-request-id:
|
||||
- req_833406276d399714b624a32627fc5b4a
|
||||
http_version: HTTP/1.1
|
||||
status_code: 200
|
||||
- req_c302b31f8f804399ae05fc424215303a
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
|
||||
File diff suppressed because it is too large.
@@ -1,54 +1,165 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"model": "openai/gpt-4o-mini", "messages": [{"role": "system", "content":
|
||||
"Your goal is to rewrite the user query so that it is optimized for retrieval
|
||||
from a vector database. Consider how the query will be used to find relevant
|
||||
documents, and aim to make it more specific and context-aware. \n\n Do not include
|
||||
any other text than the rewritten query, especially any preamble or postamble
|
||||
and only add expected output format if its relevant to the rewritten query.
|
||||
\n\n Focus on the key words of the intended task and to retrieve the most relevant
|
||||
information. \n\n There will be some extra context provided that might need
|
||||
to be removed such as expected_output formats structured_outputs and other instructions."},
|
||||
{"role": "user", "content": "The original query is: What is Vidit''s favorite
|
||||
color?\n\nThis is the expected criteria for your final answer: Vidit''s favorclearite
|
||||
color.\nyou MUST return the actual complete content as the final answer, not
|
||||
a summary.."}], "stream": false, "stop": ["\nObservation:"]}'
|
||||
body: '{"trace_id": "9d9d9d14-e5bc-44bc-8cfc-3df9ba4e6055", "execution_type":
|
||||
"crew", "user_identifier": null, "execution_context": {"crew_fingerprint": null,
|
||||
"crew_name": "crew", "flow_name": null, "crewai_version": "1.3.0", "privacy_level":
|
||||
"standard"}, "execution_metadata": {"expected_duration_estimate": 300, "agent_count":
|
||||
0, "task_count": 0, "flow_method_count": 0, "execution_started_at": "2025-11-06T15:58:15.778396+00:00"},
|
||||
"ephemeral_trace_id": "9d9d9d14-e5bc-44bc-8cfc-3df9ba4e6055"}'
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate, zstd
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '488'
|
||||
Content-Type:
|
||||
- application/json
|
||||
User-Agent:
|
||||
- CrewAI-CLI/1.3.0
|
||||
X-Crewai-Version:
|
||||
- 1.3.0
|
||||
method: POST
|
||||
uri: https://app.crewai.com/crewai_plus/api/v1/tracing/ephemeral/batches
|
||||
response:
|
||||
body:
|
||||
string: '{"id":"f303021e-f1a0-4fd8-9c7d-8ba6779f8ad3","ephemeral_trace_id":"9d9d9d14-e5bc-44bc-8cfc-3df9ba4e6055","execution_type":"crew","crew_name":"crew","flow_name":null,"status":"running","duration_ms":null,"crewai_version":"1.3.0","total_events":0,"execution_context":{"crew_fingerprint":null,"crew_name":"crew","flow_name":null,"crewai_version":"1.3.0","privacy_level":"standard"},"created_at":"2025-11-06T15:58:16.189Z","updated_at":"2025-11-06T15:58:16.189Z","access_code":"TRACE-c2990cd4d4","user_identifier":null}'
|
||||
headers:
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '515'
|
||||
Content-Type:
|
||||
- application/json; charset=utf-8
|
||||
Date:
|
||||
- Thu, 06 Nov 2025 15:58:16 GMT
|
||||
cache-control:
|
||||
- no-store
|
||||
content-security-policy:
|
||||
- 'default-src ''self'' *.app.crewai.com app.crewai.com; script-src ''self''
|
||||
''unsafe-inline'' *.app.crewai.com app.crewai.com https://cdn.jsdelivr.net/npm/apexcharts
|
||||
https://www.gstatic.com https://run.pstmn.io https://apis.google.com https://apis.google.com/js/api.js
|
||||
https://accounts.google.com https://accounts.google.com/gsi/client https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.min.css.map
|
||||
https://*.google.com https://docs.google.com https://slides.google.com https://js.hs-scripts.com
|
||||
https://js.sentry-cdn.com https://browser.sentry-cdn.com https://www.googletagmanager.com
|
||||
https://js-na1.hs-scripts.com https://js.hubspot.com http://js-na1.hs-scripts.com
|
||||
https://bat.bing.com https://cdn.amplitude.com https://cdn.segment.com https://d1d3n03t5zntha.cloudfront.net/
|
||||
https://descriptusercontent.com https://edge.fullstory.com https://googleads.g.doubleclick.net
|
||||
https://js.hs-analytics.net https://js.hs-banner.com https://js.hsadspixel.net
|
||||
https://js.hscollectedforms.net https://js.usemessages.com https://snap.licdn.com
|
||||
https://static.cloudflareinsights.com https://static.reo.dev https://www.google-analytics.com
|
||||
https://share.descript.com/; style-src ''self'' ''unsafe-inline'' *.app.crewai.com
|
||||
app.crewai.com https://cdn.jsdelivr.net/npm/apexcharts; img-src ''self'' data:
|
||||
*.app.crewai.com app.crewai.com https://zeus.tools.crewai.com https://dashboard.tools.crewai.com
|
||||
https://cdn.jsdelivr.net https://forms.hsforms.com https://track.hubspot.com
|
||||
https://px.ads.linkedin.com https://px4.ads.linkedin.com https://www.google.com
|
||||
https://www.google.com.br; font-src ''self'' data: *.app.crewai.com app.crewai.com;
|
||||
connect-src ''self'' *.app.crewai.com app.crewai.com https://zeus.tools.crewai.com
|
||||
https://connect.useparagon.com/ https://zeus.useparagon.com/* https://*.useparagon.com/*
|
||||
https://run.pstmn.io https://connect.tools.crewai.com/ https://*.sentry.io
|
||||
https://www.google-analytics.com https://edge.fullstory.com https://rs.fullstory.com
|
||||
https://api.hubspot.com https://forms.hscollectedforms.net https://api.hubapi.com
|
||||
https://px.ads.linkedin.com https://px4.ads.linkedin.com https://google.com/pagead/form-data/16713662509
|
||||
https://google.com/ccm/form-data/16713662509 https://www.google.com/ccm/collect
|
||||
https://worker-actionkit.tools.crewai.com https://api.reo.dev; frame-src ''self''
|
||||
*.app.crewai.com app.crewai.com https://connect.useparagon.com/ https://zeus.tools.crewai.com
|
||||
https://zeus.useparagon.com/* https://connect.tools.crewai.com/ https://docs.google.com
|
||||
https://drive.google.com https://slides.google.com https://accounts.google.com
|
||||
https://*.google.com https://app.hubspot.com/ https://td.doubleclick.net https://www.googletagmanager.com/
|
||||
https://www.youtube.com https://share.descript.com'
|
||||
etag:
|
||||
- W/"8df0b730688b8bc094b74c66a6293578"
|
||||
expires:
|
||||
- '0'
|
||||
permissions-policy:
|
||||
- camera=(), microphone=(self), geolocation=()
|
||||
pragma:
|
||||
- no-cache
|
||||
referrer-policy:
|
||||
- strict-origin-when-cross-origin
|
||||
strict-transport-security:
|
||||
- max-age=63072000; includeSubDomains
|
||||
vary:
|
||||
- Accept
|
||||
x-content-type-options:
|
||||
- nosniff
|
||||
x-frame-options:
|
||||
- SAMEORIGIN
|
||||
x-permitted-cross-domain-policies:
|
||||
- none
|
||||
x-request-id:
|
||||
- 38352441-7508-4e1e-9bff-77d1689dffdf
|
||||
x-runtime:
|
||||
- '0.085540'
|
||||
x-xss-protection:
|
||||
- 1; mode=block
|
||||
status:
|
||||
code: 201
|
||||
message: Created
|
||||
- request:
|
||||
body: '{"messages":[{"role":"system","content":"Your goal is to rewrite the user
|
||||
query so that it is optimized for retrieval from a vector database. Consider
|
||||
how the query will be used to find relevant documents, and aim to make it more
|
||||
specific and context-aware. \n\n Do not include any other text than the rewritten
|
||||
query, especially any preamble or postamble and only add expected output format
|
||||
if its relevant to the rewritten query. \n\n Focus on the key words of the intended
|
||||
task and to retrieve the most relevant information. \n\n There will be some
|
||||
extra context provided that might need to be removed such as expected_output
|
||||
formats structured_outputs and other instructions."},{"role":"user","content":"The
|
||||
original query is: What is Vidit''s favorite color?\n\nThis is the expected
|
||||
criteria for your final answer: Vidit''s favorite color.\nyou MUST return the
|
||||
actual complete content as the final answer, not a summary.."}],"model":"gpt-4o-mini"}'
|
||||
headers:
|
||||
accept:
|
||||
- '*/*'
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '1017'
|
||||
- '950'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- openrouter.ai
|
||||
http-referer:
|
||||
- https://litellm.ai
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- litellm/1.68.0
|
||||
x-title:
|
||||
- liteLLM
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://openrouter.ai/api/v1/chat/completions
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//4lKAAS4AAAAA//90kE1PIzEMhv8Kei97Sdnplwq5gTgAF8ShcFitRmnG
|
||||
nTFk4ihxq11V899Xs6gFJLja78djH8ANLFqKk+lqsZpP56vpYqJhublfP1eP65v1i79Lt9fdMwxS
|
||||
lj03lGHxkChe3cGgl4YCLCRRdPyzTTpZyKTnyDCQzQt5hYXvnJ576VMgZYkw8JmcUgP7XmvgO2FP
|
||||
BfbXAUHalGVTYOMuBIMtRy5dnckVibAoKgkG0Snvqf5my7GhP7CVQU+luJZgD8gSCBauFC7qoo40
|
||||
EpXiSPrEDeuPcrZ1e8msdOYlSIZBpu2uuHDEeWvi2L4NhuG3QflblPqRpaWcMv8P3Ka6ml/OLmaz
|
||||
6rKCwe6IkbL0SWuVV4pl/MNy5Di+6DRfGqioC+/Ci8p8NtcNqeNQxlTvfEfNSVwNX4R+1J/u+GAZ
|
||||
hn8AAAD//wMAIwJ79CICAAA=
|
||||
H4sIAAAAAAAAAwAAAP//jFJBbtswELzrFQQvvViBLMuy42sObYEWKIoiQFMEAkOu5G0oLkGu0xaB
|
||||
/15QciwlTYFceODsDGeG+5gJIdHInZB6r1j33uZX33+1H78y9tvVt+Lmpv68KrfXX96Xnz7cu61c
|
||||
JAbd/QTNT6wLTb23wEhuhHUAxZBUl5u6rKqqvqwHoCcDNtE6z3lFeY8O87Ioq7zY5MuTuN4Taohy
|
||||
J35kQgjxOJzJpzPwW+5EsXi66SFG1YHcnYeEkIFsupEqRoysHMvFBGpyDG6wfo0G+V0UrXqggAxC
|
||||
k6VwMZ8O0B6iSo7dwdoZoJwjVinx4PP2hBzPzix1PtBdfEGVLTqM+yaAiuSSi8jk5YAeMyFuhwYO
|
||||
z0JJH6j33DDdw/DccrMa9eRU/ISuTxgTKzsnbRevyDUGWKGNswqlVnoPZqJOfauDQZoB2Sz0v2Ze
|
||||
0x6Do+veIj8BWoNnMI0PYFA/DzyNBUhr+b+xc8mDYRkhPKCGhhFC+ggDrTrYcVlk/BMZ+qZF10Hw
|
||||
AceNaX2zrgvV1rBeX8rsmP0FAAD//wMA5SIzeT8DAAA=
|
||||
headers:
|
||||
Access-Control-Allow-Origin:
|
||||
- '*'
|
||||
CF-RAY:
|
||||
- 9402c9db99ec4722-BOM
|
||||
- 99a5ca96bb1443e7-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
@@ -56,73 +167,122 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Thu, 15 May 2025 12:55:14 GMT
|
||||
- Thu, 06 Nov 2025 15:58:16 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=REDACTED;
|
||||
path=/; expires=Thu, 06-Nov-25 16:28:16 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=REDACTED;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
Vary:
|
||||
- Accept-Encoding
|
||||
x-clerk-auth-message:
|
||||
- Invalid JWT form. A JWT consists of three parts separated by dots. (reason=token-invalid,
|
||||
token-carrier=header)
|
||||
x-clerk-auth-reason:
|
||||
- token-invalid
|
||||
x-clerk-auth-status:
|
||||
- signed-out
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- user-REDACTED
|
||||
openai-processing-ms:
|
||||
- '235'
|
||||
openai-project:
|
||||
- proj_REDACTED
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '420'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '200000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '199785'
|
||||
x-ratelimit-reset-requests:
|
||||
- 8.64s
|
||||
x-ratelimit-reset-tokens:
|
||||
- 64ms
|
||||
x-request-id:
|
||||
- req_9810e9721aa9463c930414ab5174ab61
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"model": "openai/gpt-4o-mini", "messages": [{"role": "system", "content":
|
||||
"You are Information Agent. You have access to specific knowledge sources.\nYour
|
||||
personal goal is: Provide information based on knowledge sources\nTo give my
|
||||
best complete final answer to the task respond using the exact following format:\n\nThought:
|
||||
I now can give a great answer\nFinal Answer: Your final answer must be the great
|
||||
and the most complete as possible, it must be outcome described.\n\nI MUST use
|
||||
these formats, my job depends on it!"}, {"role": "user", "content": "\nCurrent
|
||||
Task: What is Vidit''s favorite color?\n\nThis is the expected criteria for
|
||||
your final answer: Vidit''s favorclearite color.\nyou MUST return the actual
|
||||
complete content as the final answer, not a summary.\n\nBegin! This is VERY
|
||||
important to you, use the tools available and give your best Final Answer, your
|
||||
job depends on it!\n\nThought:"}], "stream": false, "stop": ["\nObservation:"]}'
|
||||
body: '{"messages":[{"role":"system","content":"You are Information Agent. You
|
||||
have access to specific knowledge sources.\nYour personal goal is: Provide information
|
||||
based on knowledge sources\nTo give my best complete final answer to the task
|
||||
respond using the exact following format:\n\nThought: I now can give a great
|
||||
answer\nFinal Answer: Your final answer must be the great and the most complete
|
||||
as possible, it must be outcome described.\n\nI MUST use these formats, my job
|
||||
depends on it!"},{"role":"user","content":"\nCurrent Task: What is Vidit''s
|
||||
favorite color?\n\nThis is the expected criteria for your final answer: Vidit''s
|
||||
favorite color.\nyou MUST return the actual complete content as the final answer,
|
||||
not a summary.\n\nBegin! This is VERY important to you, use the tools available
|
||||
and give your best Final Answer, your job depends on it!\n\nThought:"}],"model":"gpt-4o-mini"}'
|
||||
headers:
|
||||
accept:
|
||||
- '*/*'
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '951'
|
||||
- '884'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- __cf_bm=REDACTED;
|
||||
_cfuvid=REDACTED
|
||||
host:
|
||||
- openrouter.ai
|
||||
http-referer:
|
||||
- https://litellm.ai
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- litellm/1.68.0
|
||||
x-title:
|
||||
- liteLLM
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://openrouter.ai/api/v1/chat/completions
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//4lKAAS4AAAAA///iQjABAAAA//90kN1qGzEQRl9FfNdyul4nday73ARy
|
||||
VUpLE2jLIu+O15NoZ4QkOy1moa/R1+uTlE1wnEB7qU/zc84cwB0cepLZfHm+XMwXy/nF7II/3d7V
|
||||
H+tOPvsS3le3d+keFjHpnjtKcPgQSa5uYDFoRwEOGkk8v+tjmZ3rbGBhWOj6ntoCh3bry1mrQwxU
|
||||
WAUWbSJfqIM7rbVot8otZbivBwTtY9J1hpNdCBYbFs7bJpHPKnDIRSMsxBfeU/OfX5aOfsBVFgPl
|
||||
7HuCOyBpIDj4nDkXL2WiUSkkE+mNEX00rRfT856MN/0EarzkR0rGfJNrFh/M1dPbmS/ccfnz63c2
|
||||
G7/XxIVMq0GT4WzWYUdnsEi02WUfjiLPjCz9czCO3y3yz1xomCx6SjHxE8omNtViVV/WdbWqYLE7
|
||||
CsSkQyxN0QeSPF2wmgyOxz3lK4uixYdTcrmyb7ubjornkKexrW+31L0UV+M/pr6ufxF51TKOfwEA
|
||||
AP//AwBybekMaAIAAA==
|
||||
H4sIAAAAAAAAAwAAAP//jFPBahsxEL37KwZderGN7dqO41vaUghtT4FCacIiS7PrSbQaVZq1swT/
|
||||
e9HayTptCr0ING/e6M2b0dMAQJFVa1Bmq8XUwY0+/tiXXy4/2BDN/p7izdY4/vrp29Wvmxbv1TAz
|
||||
eHOPRp5ZY8N1cCjE/gibiFowV51eLGfz+Xx5ueqAmi26TKuCjOY8qsnTaDaZzUeTi9F0dWJvmQwm
|
||||
tYafAwCAp+7MOr3FR7WGyfA5UmNKukK1fkkCUJFdjiidEiXRXtSwBw17Qd9JvwbPezDaQ0U7BA1V
|
||||
lg3apz1GgFv/mbx2cNXd1/CdLMm7BKXecSRBMOw4AiXwLBCajSPjWrBsmhq9oAWOsCeLroUHz3s/
|
||||
husSWm5gq3cIKaChkgx0ih4lZ1sUTS6B3nAjxweHcA21bmGDoDcOQRhC5B3ZLLjmiJApHNFCxBTY
|
||||
Jxyf9xuxbJLOnvvGuTNAe8+i88w6p+9OyOHFW8dViLxJf1BVSZ7StoioE/vsYxIOqkMPA4C7bobN
|
||||
q7GoELkOUgg/YPfcdLk61lP96vTofHEChUW7Pj6bvh++Ua842Xa2Bcpos0XbU/uV0Y0lPgMGZ13/
|
||||
reat2sfOyVf/U74HjMEgaIsQ0ZJ53XGfFjH/rH+lvbjcCVYJ444MFkIY8yQslrpxx31XqU2CdVGS
|
||||
rzCGSMelL0OxWE50ucTF4lINDoPfAAAA//8DAPFGfbMCBAAA
|
||||
headers:
|
||||
Access-Control-Allow-Origin:
|
||||
- '*'
|
||||
CF-RAY:
|
||||
- 9402c9e1b94a4722-BOM
|
||||
- 99a5ca9c5ef543e7-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
@@ -130,20 +290,47 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Thu, 15 May 2025 12:55:15 GMT
|
||||
- Thu, 06 Nov 2025 15:58:19 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
Vary:
|
||||
- Accept-Encoding
|
||||
x-clerk-auth-message:
|
||||
- Invalid JWT form. A JWT consists of three parts separated by dots. (reason=token-invalid,
|
||||
token-carrier=header)
|
||||
x-clerk-auth-reason:
|
||||
- token-invalid
|
||||
x-clerk-auth-status:
|
||||
- signed-out
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- user-REDACTED
|
||||
openai-processing-ms:
|
||||
- '1326'
|
||||
openai-project:
|
||||
- proj_REDACTED
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '1754'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '200000'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9998'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '199803'
|
||||
x-ratelimit-reset-requests:
|
||||
- 15.913s
|
||||
x-ratelimit-reset-tokens:
|
||||
- 59ms
|
||||
x-request-id:
|
||||
- req_f975e16b666e498b8bcfdfab525f71b3
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"trace_id": "1703c4e0-d3be-411c-85e7-48018c2df384", "execution_type":
|
||||
"crew", "user_identifier": null, "execution_context": {"crew_fingerprint": null,
|
||||
"crew_name": "Unknown Crew", "flow_name": null, "crewai_version": "1.3.0", "privacy_level":
|
||||
"standard"}, "execution_metadata": {"expected_duration_estimate": 300, "agent_count":
|
||||
0, "task_count": 0, "flow_method_count": 0, "execution_started_at": "2025-11-07T01:58:22.260309+00:00"}}'
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate, zstd
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '434'
|
||||
Content-Type:
|
||||
- application/json
|
||||
User-Agent:
|
||||
- CrewAI-CLI/1.3.0
|
||||
X-Crewai-Version:
|
||||
- 1.3.0
|
||||
method: POST
|
||||
uri: https://app.crewai.com/crewai_plus/api/v1/tracing/batches
|
||||
response:
|
||||
body:
|
||||
string: '{"error":"bad_credentials","message":"Bad credentials"}'
|
||||
headers:
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '55'
|
||||
Content-Type:
|
||||
- application/json; charset=utf-8
|
||||
Date:
|
||||
- Fri, 07 Nov 2025 01:58:22 GMT
|
||||
cache-control:
|
||||
- no-store
|
||||
content-security-policy:
|
||||
- 'default-src ''self'' *.app.crewai.com app.crewai.com; script-src ''self''
|
||||
''unsafe-inline'' *.app.crewai.com app.crewai.com https://cdn.jsdelivr.net/npm/apexcharts
|
||||
https://www.gstatic.com https://run.pstmn.io https://apis.google.com https://apis.google.com/js/api.js
|
||||
https://accounts.google.com https://accounts.google.com/gsi/client https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.min.css.map
|
||||
https://*.google.com https://docs.google.com https://slides.google.com https://js.hs-scripts.com
|
||||
https://js.sentry-cdn.com https://browser.sentry-cdn.com https://www.googletagmanager.com
|
||||
https://js-na1.hs-scripts.com https://js.hubspot.com http://js-na1.hs-scripts.com
|
||||
https://bat.bing.com https://cdn.amplitude.com https://cdn.segment.com https://d1d3n03t5zntha.cloudfront.net/
|
||||
https://descriptusercontent.com https://edge.fullstory.com https://googleads.g.doubleclick.net
|
||||
https://js.hs-analytics.net https://js.hs-banner.com https://js.hsadspixel.net
|
||||
https://js.hscollectedforms.net https://js.usemessages.com https://snap.licdn.com
|
||||
https://static.cloudflareinsights.com https://static.reo.dev https://www.google-analytics.com
|
||||
https://share.descript.com/; style-src ''self'' ''unsafe-inline'' *.app.crewai.com
|
||||
app.crewai.com https://cdn.jsdelivr.net/npm/apexcharts; img-src ''self'' data:
|
||||
*.app.crewai.com app.crewai.com https://zeus.tools.crewai.com https://dashboard.tools.crewai.com
|
||||
https://cdn.jsdelivr.net https://forms.hsforms.com https://track.hubspot.com
|
||||
https://px.ads.linkedin.com https://px4.ads.linkedin.com https://www.google.com
|
||||
https://www.google.com.br; font-src ''self'' data: *.app.crewai.com app.crewai.com;
|
||||
connect-src ''self'' *.app.crewai.com app.crewai.com https://zeus.tools.crewai.com
|
||||
https://connect.useparagon.com/ https://zeus.useparagon.com/* https://*.useparagon.com/*
|
||||
https://run.pstmn.io https://connect.tools.crewai.com/ https://*.sentry.io
|
||||
https://www.google-analytics.com https://edge.fullstory.com https://rs.fullstory.com
|
||||
https://api.hubspot.com https://forms.hscollectedforms.net https://api.hubapi.com
|
||||
https://px.ads.linkedin.com https://px4.ads.linkedin.com https://google.com/pagead/form-data/16713662509
|
||||
https://google.com/ccm/form-data/16713662509 https://www.google.com/ccm/collect
|
||||
https://worker-actionkit.tools.crewai.com https://api.reo.dev; frame-src ''self''
|
||||
*.app.crewai.com app.crewai.com https://connect.useparagon.com/ https://zeus.tools.crewai.com
|
||||
https://zeus.useparagon.com/* https://connect.tools.crewai.com/ https://docs.google.com
|
||||
https://drive.google.com https://slides.google.com https://accounts.google.com
|
||||
https://*.google.com https://app.hubspot.com/ https://td.doubleclick.net https://www.googletagmanager.com/
|
||||
https://www.youtube.com https://share.descript.com'
|
||||
expires:
|
||||
- '0'
|
||||
permissions-policy:
|
||||
- camera=(), microphone=(self), geolocation=()
|
||||
pragma:
|
||||
- no-cache
|
||||
referrer-policy:
|
||||
- strict-origin-when-cross-origin
|
||||
strict-transport-security:
|
||||
- max-age=63072000; includeSubDomains
|
||||
vary:
|
||||
- Accept
|
||||
x-content-type-options:
|
||||
- nosniff
|
||||
x-frame-options:
|
||||
- SAMEORIGIN
|
||||
x-permitted-cross-domain-policies:
|
||||
- none
|
||||
x-request-id:
|
||||
- 4124c4ce-02cf-4d08-9b0b-8983c2e9da6e
|
||||
x-runtime:
|
||||
- '0.073764'
|
||||
x-xss-protection:
|
||||
- 1; mode=block
|
||||
status:
|
||||
code: 401
|
||||
message: Unauthorized
|
||||
- request:
|
||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in one
|
||||
word"}],"model":"claude-3-5-haiku-20241022","stop_sequences":["\nObservation:","\nThought:"],"stream":false}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
anthropic-version:
|
||||
- '2023-06-01'
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '182'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.anthropic.com
|
||||
user-agent:
|
||||
- Anthropic/Python 0.71.0
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 0.71.0
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
x-stainless-timeout:
|
||||
- NOT_GIVEN
|
||||
method: POST
|
||||
uri: https://api.anthropic.com/v1/messages
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//dJBdS8QwEEX/y31Ope26u5J3dwWfBH0SCTEZtmHTpCYTXSn979LF4hc+
|
||||
DdxzZgbuiD5a8pAwXhdL1apaV512x1K1dXvZ1G0LAWch0eeDqpvN1f3q7bTPz91tc/ewLcfr/Xaz
|
||||
gwC/DzRblLM+EARS9HOgc3aZdWAImBiYAkM+jovPdJrJeUjckPfxAtOTQOY4qEQ6xwAJClZxSQGf
|
||||
INNLoWAIMhTvBcr5qRzhwlBYcTxSyJBNK2C06UiZRJpdDOqnUC88kbb/sWV3vk9DRz0l7dW6/+t/
|
||||
0ab7TSeBWPh7tBbIlF6dIcWOEiTmoqxOFtP0AQAA//8DAM5WvkqaAQAA
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 99a939a5a931556e-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Fri, 07 Nov 2025 01:58:22 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Robots-Tag:
|
||||
- none
|
||||
anthropic-organization-id:
|
||||
- SCRUBBED-ORG-ID
|
||||
anthropic-ratelimit-input-tokens-limit:
|
||||
- '400000'
|
||||
anthropic-ratelimit-input-tokens-remaining:
|
||||
- '400000'
|
||||
anthropic-ratelimit-input-tokens-reset:
|
||||
- '2025-11-07T01:58:22Z'
|
||||
anthropic-ratelimit-output-tokens-limit:
|
||||
- '80000'
|
||||
anthropic-ratelimit-output-tokens-remaining:
|
||||
- '80000'
|
||||
anthropic-ratelimit-output-tokens-reset:
|
||||
- '2025-11-07T01:58:22Z'
|
||||
anthropic-ratelimit-requests-limit:
|
||||
- '4000'
|
||||
anthropic-ratelimit-requests-remaining:
|
||||
- '3999'
|
||||
anthropic-ratelimit-requests-reset:
|
||||
- '2025-11-07T01:58:22Z'
|
||||
anthropic-ratelimit-tokens-limit:
|
||||
- '480000'
|
||||
anthropic-ratelimit-tokens-remaining:
|
||||
- '480000'
|
||||
anthropic-ratelimit-tokens-reset:
|
||||
- '2025-11-07T01:58:22Z'
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
request-id:
|
||||
- req_011CUshbL7CEVoner91hUvxL
|
||||
retry-after:
|
||||
- '41'
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
x-envoy-upstream-service-time:
|
||||
- '390'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
File diff suppressed because it is too large.
@@ -601,3 +601,81 @@ def test_file_path_validation():
         match="file_path/file_paths must be a Path, str, or a list of these types",
     ):
         PDFKnowledgeSource()


def test_hash_based_id_generation_without_doc_id(mock_vector_db):
    """Test that documents without doc_id generate hash-based IDs. Duplicates are deduplicated before upsert."""
    import hashlib
    import json

    from crewai.rag.chromadb.utils import _prepare_documents_for_chromadb
    from crewai.rag.types import BaseRecord

    documents: list[BaseRecord] = [
        {"content": "First document content", "metadata": {"source": "test1", "category": "research"}},
        {"content": "Second document content", "metadata": {"source": "test2", "category": "research"}},
        {"content": "Third document content"},  # No metadata
    ]

    result = _prepare_documents_for_chromadb(documents)

    assert len(result.ids) == 3

    # Unique documents should get 64-character hex hashes (no suffix)
    for doc_id in result.ids:
        assert len(doc_id) == 64, f"ID should be 64 characters: {doc_id}"
        assert all(c in "0123456789abcdef" for c in doc_id), f"ID should be hex: {doc_id}"

    # Different documents should have different hashes
    assert result.ids[0] != result.ids[1] != result.ids[2]

    # Verify hashes match expected values
    expected_hash_1 = hashlib.sha256(
        f"First document content|{json.dumps({'category': 'research', 'source': 'test1'}, sort_keys=True)}".encode()
    ).hexdigest()
    assert result.ids[0] == expected_hash_1, "First document hash should match expected"

    expected_hash_3 = hashlib.sha256("Third document content".encode()).hexdigest()
    assert result.ids[2] == expected_hash_3, "Third document hash should match expected"

    # Test that duplicate documents are deduplicated (same ID, only one sent)
    duplicate_documents: list[BaseRecord] = [
        {"content": "Same content", "metadata": {"source": "test"}},
        {"content": "Same content", "metadata": {"source": "test"}},
        {"content": "Same content", "metadata": {"source": "test"}},
    ]
    duplicate_result = _prepare_documents_for_chromadb(duplicate_documents)
    # Duplicates should be deduplicated - only one ID should remain
    assert len(duplicate_result.ids) == 1, "Duplicate documents should be deduplicated"
    assert len(duplicate_result.ids[0]) == 64, "Deduplicated ID should be a clean hash"
    # Verify it's the expected hash
    expected_hash = hashlib.sha256(
        f"Same content|{json.dumps({'source': 'test'}, sort_keys=True)}".encode()
    ).hexdigest()
    assert duplicate_result.ids[0] == expected_hash, "Deduplicated ID should match expected hash"


def test_hash_based_id_generation_with_doc_id_in_metadata(mock_vector_db):
    """Test that documents with doc_id in metadata use the doc_id directly, not hash-based."""
    from crewai.rag.chromadb.utils import _prepare_documents_for_chromadb
    from crewai.rag.types import BaseRecord

    documents_with_doc_id: list[BaseRecord] = [
        {"content": "First document", "metadata": {"doc_id": "custom-id-1", "source": "test1"}},
        {"content": "Second document", "metadata": {"doc_id": "custom-id-2"}},
    ]

    documents_without_doc_id: list[BaseRecord] = [
        {"content": "First document", "metadata": {"source": "test1"}},
        {"content": "Second document"},
    ]

    result_with_doc_id = _prepare_documents_for_chromadb(documents_with_doc_id)
    result_without_doc_id = _prepare_documents_for_chromadb(documents_without_doc_id)

    assert result_with_doc_id.ids == ["custom-id-1", "custom-id-2"]

    assert len(result_without_doc_id.ids) == 2
    # Unique documents get 64-character hashes
    for doc_id in result_without_doc_id.ids:
        assert len(doc_id) == 64, "ID should be 64 characters"
        assert all(c in "0123456789abcdef" for c in doc_id), "ID should be hex"

@@ -664,3 +664,37 @@ def test_anthropic_token_usage_tracking():
     assert usage["input_tokens"] == 50
     assert usage["output_tokens"] == 25
     assert usage["total_tokens"] == 75


def test_anthropic_stop_sequences_sync():
    """Test that stop and stop_sequences attributes stay synchronized."""
    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

    # Test setting stop as a list
    llm.stop = ["\nObservation:", "\nThought:"]
    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
    assert llm.stop == ["\nObservation:", "\nThought:"]

    # Test setting stop as a string
    llm.stop = "\nFinal Answer:"
    assert llm.stop_sequences == ["\nFinal Answer:"]
    assert llm.stop == ["\nFinal Answer:"]

    # Test setting stop as None
    llm.stop = None
    assert llm.stop_sequences == []
    assert llm.stop == []


@pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])
def test_anthropic_stop_sequences_sent_to_api():
    """Test that stop_sequences are properly sent to the Anthropic API."""
    llm = LLM(model="anthropic/claude-3-5-haiku-20241022")

    llm.stop = ["\nObservation:", "\nThought:"]

    result = llm.call("Say hello in one word")

    assert result is not None
    assert isinstance(result, str)
    assert len(result) > 0

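The sync tests above pin down a small contract: assigning `stop` normalizes `None`, a string, or a list into a list and mirrors it into `stop_sequences`. A minimal sketch of one way to satisfy that contract (illustrative only, not CrewAI's actual implementation):

```python
class StopSyncSketch:
    """Illustrative only: one way to keep stop and stop_sequences in lockstep."""

    def __init__(self) -> None:
        self.stop_sequences: list[str] = []

    @property
    def stop(self) -> list[str]:
        return self.stop_sequences

    @stop.setter
    def stop(self, value: str | list[str] | None) -> None:
        # Normalize: None -> [], str -> [str], iterable -> list
        if value is None:
            self.stop_sequences = []
        elif isinstance(value, str):
            self.stop_sequences = [value]
        else:
            self.stop_sequences = list(value)
```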
@@ -736,3 +736,56 @@ def test_bedrock_client_error_handling():
     with pytest.raises(RuntimeError) as exc_info:
         llm.call("Hello")
     assert "throttled" in str(exc_info.value).lower()


def test_bedrock_stop_sequences_sync():
    """Test that stop and stop_sequences attributes stay synchronized."""
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Test setting stop as a list
    llm.stop = ["\nObservation:", "\nThought:"]
    assert list(llm.stop_sequences) == ["\nObservation:", "\nThought:"]
    assert llm.stop == ["\nObservation:", "\nThought:"]

    # Test setting stop as a string
    llm.stop = "\nFinal Answer:"
    assert list(llm.stop_sequences) == ["\nFinal Answer:"]
    assert llm.stop == ["\nFinal Answer:"]

    # Test setting stop as None
    llm.stop = None
    assert list(llm.stop_sequences) == []
    assert llm.stop == []


def test_bedrock_stop_sequences_sent_to_api():
    """Test that stop_sequences are properly sent to the Bedrock API."""
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Set stop sequences via the stop attribute (simulating CrewAgentExecutor)
    llm.stop = ["\nObservation:", "\nThought:"]

    # Patch the API call to capture parameters without making a real call
    with patch.object(llm.client, 'converse') as mock_converse:
        mock_response = {
            'output': {
                'message': {
                    'role': 'assistant',
                    'content': [{'text': 'Hello'}]
                }
            },
            'usage': {
                'inputTokens': 10,
                'outputTokens': 5,
                'totalTokens': 15
            }
        }
        mock_converse.return_value = mock_response

        llm.call("Say hello in one word")

        # Verify stop_sequences were passed to the API in the inference config
        call_kwargs = mock_converse.call_args[1]
        assert "inferenceConfig" in call_kwargs
        assert "stopSequences" in call_kwargs["inferenceConfig"]
        assert call_kwargs["inferenceConfig"]["stopSequences"] == ["\nObservation:", "\nThought:"]

@@ -648,3 +648,55 @@ def test_gemini_token_usage_tracking():
     assert usage["candidates_token_count"] == 25
     assert usage["total_token_count"] == 75
     assert usage["total_tokens"] == 75


def test_gemini_stop_sequences_sync():
    """Test that stop and stop_sequences attributes stay synchronized."""
    llm = LLM(model="google/gemini-2.0-flash-001")

    # Test setting stop as a list
    llm.stop = ["\nObservation:", "\nThought:"]
    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
    assert llm.stop == ["\nObservation:", "\nThought:"]

    # Test setting stop as a string
    llm.stop = "\nFinal Answer:"
    assert llm.stop_sequences == ["\nFinal Answer:"]
    assert llm.stop == ["\nFinal Answer:"]

    # Test setting stop as None
    llm.stop = None
    assert llm.stop_sequences == []
    assert llm.stop == []


def test_gemini_stop_sequences_sent_to_api():
    """Test that stop_sequences are properly sent to the Gemini API."""
    llm = LLM(model="google/gemini-2.0-flash-001")

    # Set stop sequences via the stop attribute (simulating CrewAgentExecutor)
    llm.stop = ["\nObservation:", "\nThought:"]

    # Patch the API call to capture parameters without making a real call
    with patch.object(llm.client.models, 'generate_content') as mock_generate:
        mock_response = MagicMock()
        mock_response.text = "Hello"
        mock_response.candidates = []
        mock_response.usage_metadata = MagicMock(
            prompt_token_count=10,
            candidates_token_count=5,
            total_token_count=15
        )
        mock_generate.return_value = mock_response

        llm.call("Say hello in one word")

        # Verify stop_sequences were passed to the API in the config
        call_kwargs = mock_generate.call_args[1]
        assert "config" in call_kwargs
        # The config object should have stop_sequences set
        config = call_kwargs["config"]
        assert hasattr(config, 'stop_sequences') or 'stop_sequences' in config.__dict__
        if hasattr(config, 'stop_sequences'):
            assert config.stop_sequences == ["\nObservation:", "\nThought:"]

@@ -6,7 +6,7 @@ import httpx
 import pytest

 from crewai.llms.hooks.base import BaseInterceptor
-from crewai.llms.hooks.transport import AsyncHTTPransport, HTTPTransport
+from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport


 class TrackingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
@@ -128,7 +128,7 @@ class TestAsyncHTTPTransport:
     def test_async_transport_instantiation(self) -> None:
         """Test that async transport can be instantiated with interceptor."""
         interceptor = TrackingInterceptor()
-        transport = AsyncHTTPransport(interceptor=interceptor)
+        transport = AsyncHTTPTransport(interceptor=interceptor)

         assert transport.interceptor is interceptor

@@ -136,13 +136,13 @@ class TestAsyncHTTPTransport:
         """Test that async transport requires interceptor parameter."""
         # AsyncHTTPransport requires an interceptor parameter
         with pytest.raises(TypeError):
-            AsyncHTTPransport()
+            AsyncHTTPTransport()

     @pytest.mark.asyncio
     async def test_async_interceptor_called_on_request(self) -> None:
         """Test that async interceptor hooks are called during request handling."""
         interceptor = TrackingInterceptor()
-        transport = AsyncHTTPransport(interceptor=interceptor)
+        transport = AsyncHTTPTransport(interceptor=interceptor)

         # Create a mock parent transport that returns a response
         mock_response = httpx.Response(200, json={"success": True})
@@ -217,7 +217,7 @@ class TestTransportIntegration:
     async def test_multiple_async_requests_same_interceptor(self) -> None:
         """Test that multiple async requests through same interceptor are tracked."""
         interceptor = TrackingInterceptor()
-        transport = AsyncHTTPransport(interceptor=interceptor)
+        transport = AsyncHTTPTransport(interceptor=interceptor)

         mock_response = httpx.Response(200)

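For context, the renamed classes plug into httpx's standard custom-transport mechanism. A minimal wiring sketch; only the `HTTPTransport(interceptor=...)` constructor and `.interceptor` attribute are taken from the tests above, and the interceptor subclass body is left hypothetical because `BaseInterceptor`'s hook names are not visible in this diff:

```python
import httpx

from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.hooks.transport import HTTPTransport


class PassthroughInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    # Hypothetical body: override BaseInterceptor's request/response hooks here.
    ...


transport = HTTPTransport(interceptor=PassthroughInterceptor())
client = httpx.Client(transport=transport)  # standard httpx custom-transport wiring
assert transport.interceptor is not None
```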
lib/crewai/tests/mcp/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
"""Tests for MCP (Model Context Protocol) integration."""


lib/crewai/tests/mcp/test_mcp_config.py (new file, 200 lines)
@@ -0,0 +1,200 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from crewai.agent.core import Agent
from crewai.mcp.config import MCPServerHTTP, MCPServerSSE, MCPServerStdio
from crewai.tools.base_tool import BaseTool


@pytest.fixture
def mock_tool_definitions():
    """Create mock MCP tool definitions (as returned by list_tools)."""
    return [
        {
            "name": "test_tool_1",
            "description": "Test tool 1 description",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "Search query"}
                },
                "required": ["query"]
            }
        },
        {
            "name": "test_tool_2",
            "description": "Test tool 2 description",
            "inputSchema": {}
        }
    ]

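The fixture mirrors the wire format defined by the MCP specification: every tool advertises a `name`, a `description`, and a JSON Schema `inputSchema`. For comparison, a real listing from a live server looks roughly like this — a sketch assuming the official `mcp` Python SDK, with a placeholder server URL:

```python
# Sketch using the official `mcp` Python SDK (assumed dependency);
# the server URL is a placeholder.
import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def list_remote_tools() -> None:
    async with streamablehttp_client("https://api.example.com/mcp") as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.list_tools()
            for tool in result.tools:
                print(tool.name, tool.description, tool.inputSchema)


asyncio.run(list_remote_tools())
```
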
def test_agent_with_stdio_mcp_config(mock_tool_definitions):
    """Test agent setup with MCPServerStdio configuration."""
    stdio_config = MCPServerStdio(
        command="python",
        args=["server.py"],
        env={"API_KEY": "test_key"},
    )

    agent = Agent(
        role="Test Agent",
        goal="Test goal",
        backstory="Test backstory",
        mcps=[stdio_config],
    )

    with patch("crewai.agent.core.MCPClient") as mock_client_class:
        mock_client = AsyncMock()
        mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
        mock_client.connected = False  # Will trigger connect
        mock_client.connect = AsyncMock()
        mock_client.disconnect = AsyncMock()
        mock_client_class.return_value = mock_client

        tools = agent.get_mcp_tools([stdio_config])

        assert len(tools) == 2
        assert all(isinstance(tool, BaseTool) for tool in tools)

        mock_client_class.assert_called_once()
        call_args = mock_client_class.call_args
        transport = call_args.kwargs["transport"]
        assert transport.command == "python"
        assert transport.args == ["server.py"]
        assert transport.env == {"API_KEY": "test_key"}

def test_agent_with_http_mcp_config(mock_tool_definitions):
    """Test agent setup with MCPServerHTTP configuration."""
    http_config = MCPServerHTTP(
        url="https://api.example.com/mcp",
        headers={"Authorization": "Bearer test_token"},
        streamable=True,
    )

    agent = Agent(
        role="Test Agent",
        goal="Test goal",
        backstory="Test backstory",
        mcps=[http_config],
    )

    with patch("crewai.agent.core.MCPClient") as mock_client_class:
        mock_client = AsyncMock()
        mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
        mock_client.connected = False  # Will trigger connect
        mock_client.connect = AsyncMock()
        mock_client.disconnect = AsyncMock()
        mock_client_class.return_value = mock_client

        tools = agent.get_mcp_tools([http_config])

        assert len(tools) == 2
        assert all(isinstance(tool, BaseTool) for tool in tools)

        mock_client_class.assert_called_once()
        call_args = mock_client_class.call_args
        transport = call_args.kwargs["transport"]
        assert transport.url == "https://api.example.com/mcp"
        assert transport.headers == {"Authorization": "Bearer test_token"}
        assert transport.streamable is True

def test_agent_with_sse_mcp_config(mock_tool_definitions):
    """Test agent setup with MCPServerSSE configuration."""
    sse_config = MCPServerSSE(
        url="https://api.example.com/mcp/sse",
        headers={"Authorization": "Bearer test_token"},
    )

    agent = Agent(
        role="Test Agent",
        goal="Test goal",
        backstory="Test backstory",
        mcps=[sse_config],
    )

    with patch("crewai.agent.core.MCPClient") as mock_client_class:
        mock_client = AsyncMock()
        mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
        mock_client.connected = False
        mock_client.connect = AsyncMock()
        mock_client.disconnect = AsyncMock()
        mock_client_class.return_value = mock_client

        tools = agent.get_mcp_tools([sse_config])

        assert len(tools) == 2
        assert all(isinstance(tool, BaseTool) for tool in tools)

        mock_client_class.assert_called_once()
        call_args = mock_client_class.call_args
        transport = call_args.kwargs["transport"]
        assert transport.url == "https://api.example.com/mcp/sse"
        assert transport.headers == {"Authorization": "Bearer test_token"}

def test_mcp_tool_execution_in_sync_context(mock_tool_definitions):
    """Test MCPNativeTool execution in synchronous context (normal crew execution)."""
    http_config = MCPServerHTTP(url="https://api.example.com/mcp")

    with patch("crewai.agent.core.MCPClient") as mock_client_class:
        mock_client = AsyncMock()
        mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
        mock_client.connected = False
        mock_client.connect = AsyncMock()
        mock_client.disconnect = AsyncMock()
        mock_client.call_tool = AsyncMock(return_value="test result")
        mock_client_class.return_value = mock_client

        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            mcps=[http_config],
        )

        tools = agent.get_mcp_tools([http_config])
        assert len(tools) == 2

        tool = tools[0]
        result = tool.run(query="test query")

        assert result == "test result"
        mock_client.call_tool.assert_called()

@pytest.mark.asyncio
async def test_mcp_tool_execution_in_async_context(mock_tool_definitions):
    """Test MCPNativeTool execution in async context (e.g., from a Flow)."""
    http_config = MCPServerHTTP(url="https://api.example.com/mcp")

    with patch("crewai.agent.core.MCPClient") as mock_client_class:
        mock_client = AsyncMock()
        mock_client.list_tools = AsyncMock(return_value=mock_tool_definitions)
        mock_client.connected = False
        mock_client.connect = AsyncMock()
        mock_client.disconnect = AsyncMock()
        mock_client.call_tool = AsyncMock(return_value="test result")
        mock_client_class.return_value = mock_client

        agent = Agent(
            role="Test Agent",
            goal="Test goal",
            backstory="Test backstory",
            mcps=[http_config],
        )

        tools = agent.get_mcp_tools([http_config])
        assert len(tools) == 2

        tool = tools[0]
        result = tool.run(query="test query")

        assert result == "test result"
        mock_client.call_tool.assert_called()
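Together, the last two tests pin down one behavior from both sides: `tool.run(...)` must work when no event loop exists (plain crew execution) and when a loop is already running (e.g. inside a Flow). A common way to bridge a coroutine into sync code under both conditions looks roughly like the pattern below — illustrative only, not CrewAI's actual implementation:

```python
# Illustrative sync/async bridge; not CrewAI's actual implementation.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Coroutine


def run_coro_from_sync(coro: Coroutine[Any, Any, Any]) -> Any:
    """Run a coroutine from sync code, with or without a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running (normal sync context): asyncio.run is safe.
        return asyncio.run(coro)
    # A loop is already running (async context): run the coroutine on a
    # fresh loop in a worker thread instead of re-entering this one.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```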