mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-01 04:08:30 +00:00
Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
14a8926214 | ||
|
|
09a390d50f | ||
|
|
d7ac90c9e2 | ||
|
|
b680065c45 | ||
|
|
436f1b4639 | ||
|
|
fa53a995e4 | ||
|
|
c35a84de82 | ||
|
|
6b9a797000 | ||
|
|
6515c7faeb | ||
|
|
3b32793e78 | ||
|
|
02d7ce7621 | ||
|
|
06a45b29db | ||
|
|
7351e4b0ef | ||
|
|
d9b68ddd85 | ||
|
|
2d5ad7a187 | ||
|
|
a9aff87db3 | ||
|
|
c9ff264e8e | ||
|
|
0816b810c7 | ||
|
|
96d142e353 | ||
|
|
53b239c6df | ||
|
|
cec4e4c2e9 | ||
|
|
6c5ac13242 | ||
|
|
52b2f07c9f | ||
|
|
da331ce422 | ||
|
|
0648e88f22 | ||
|
|
abe1f40bc2 | ||
|
|
06f7d224c0 | ||
|
|
faddcd0de7 | ||
|
|
2f4fdf9a90 | ||
|
|
28a8a7e6fa | ||
|
|
51e8fb1f90 | ||
|
|
f094df6015 | ||
|
|
458f56fb33 | ||
|
|
11f6b34aa3 | ||
|
|
47b6baee01 | ||
|
|
f9992d8d7a | ||
|
|
79d4e42e62 | ||
|
|
8b9186311f | ||
|
|
29a0ac483f | ||
|
|
38bc5a9dc4 | ||
|
|
0b305dabc9 | ||
|
|
ebeed0b752 | ||
|
|
2a0018a99b | ||
|
|
5865d39137 | ||
|
|
e529ebff2b | ||
|
|
126b91eab3 | ||
|
|
428810bd6f | ||
|
|
610bc4b3f5 | ||
|
|
e73c5887d9 | ||
|
|
c5ac5fa78a | ||
|
|
5456c80556 | ||
|
|
df754dbcc8 | ||
|
|
e8356b777c | ||
|
|
ade425a543 | ||
|
|
d7f6f07a5d | ||
|
|
9e1dae0746 | ||
|
|
b5161c320d | ||
|
|
c793c829ea | ||
|
|
0fe9352149 | ||
|
|
548170e989 | ||
|
|
417a4e3d91 | ||
|
|
68dce92003 | ||
|
|
289b90f00a | ||
|
|
c591c1ac87 | ||
|
|
86f0dfc2d7 | ||
|
|
74b5c88834 | ||
|
|
13e5ec711d |
4
.github/workflows/publish.yml
vendored
4
.github/workflows/publish.yml
vendored
@@ -7,6 +7,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
if: github.event.release.prerelease == true
|
||||
name: Build packages
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
@@ -24,7 +25,7 @@ jobs:
|
||||
|
||||
- name: Build packages
|
||||
run: |
|
||||
uv build --all-packages
|
||||
uv build --prerelease="allow" --all-packages
|
||||
rm dist/.gitignore
|
||||
|
||||
- name: Upload artifacts
|
||||
@@ -34,6 +35,7 @@ jobs:
|
||||
path: dist/
|
||||
|
||||
publish:
|
||||
if: github.event.release.prerelease == true
|
||||
name: Publish to PyPI
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -3,24 +3,23 @@ repos:
|
||||
hooks:
|
||||
- id: ruff
|
||||
name: ruff
|
||||
entry: bash -c 'source .venv/bin/activate && uv run ruff check --config pyproject.toml "$@"' --
|
||||
entry: uv run ruff check
|
||||
args: ["--config", "pyproject.toml", "."]
|
||||
language: system
|
||||
pass_filenames: true
|
||||
pass_filenames: false
|
||||
types: [python]
|
||||
- id: ruff-format
|
||||
name: ruff-format
|
||||
entry: bash -c 'source .venv/bin/activate && uv run ruff format --config pyproject.toml "$@"' --
|
||||
entry: uv run ruff format
|
||||
args: ["--config", "pyproject.toml", "."]
|
||||
language: system
|
||||
pass_filenames: true
|
||||
pass_filenames: false
|
||||
types: [python]
|
||||
- id: mypy
|
||||
name: mypy
|
||||
entry: bash -c 'source .venv/bin/activate && uv run mypy --config-file pyproject.toml "$@"' --
|
||||
entry: uv run mypy
|
||||
args: ["--config-file", "pyproject.toml", "."]
|
||||
language: system
|
||||
pass_filenames: true
|
||||
pass_filenames: false
|
||||
types: [python]
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
rev: 0.9.3
|
||||
hooks:
|
||||
- id: uv-lock
|
||||
|
||||
|
||||
@@ -134,7 +134,6 @@
|
||||
"group": "MCP Integration",
|
||||
"pages": [
|
||||
"en/mcp/overview",
|
||||
"en/mcp/dsl-integration",
|
||||
"en/mcp/stdio",
|
||||
"en/mcp/sse",
|
||||
"en/mcp/streamable-http",
|
||||
@@ -571,7 +570,6 @@
|
||||
"group": "Integração MCP",
|
||||
"pages": [
|
||||
"pt-BR/mcp/overview",
|
||||
"pt-BR/mcp/dsl-integration",
|
||||
"pt-BR/mcp/stdio",
|
||||
"pt-BR/mcp/sse",
|
||||
"pt-BR/mcp/streamable-http",
|
||||
@@ -991,7 +989,6 @@
|
||||
"group": "MCP 통합",
|
||||
"pages": [
|
||||
"ko/mcp/overview",
|
||||
"ko/mcp/dsl-integration",
|
||||
"ko/mcp/stdio",
|
||||
"ko/mcp/sse",
|
||||
"ko/mcp/streamable-http",
|
||||
|
||||
@@ -1,344 +0,0 @@
|
||||
---
|
||||
title: MCP DSL Integration
|
||||
description: Learn how to use CrewAI's simple DSL syntax to integrate MCP servers directly with your agents using the mcps field.
|
||||
icon: code
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
CrewAI's MCP DSL (Domain Specific Language) integration provides the **simplest way** to connect your agents to MCP (Model Context Protocol) servers. Just add an `mcps` field to your agent and CrewAI handles all the complexity automatically.
|
||||
|
||||
<Info>
|
||||
This is the **recommended approach** for most MCP use cases. For advanced scenarios requiring manual connection management, see [MCPServerAdapter](/en/mcp/overview#advanced-mcpserveradapter).
|
||||
</Info>
|
||||
|
||||
## Basic Usage
|
||||
|
||||
Add MCP servers to your agent using the `mcps` field:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Help with research and analysis tasks",
|
||||
backstory="Expert assistant with access to advanced research tools",
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key&profile=research"
|
||||
]
|
||||
)
|
||||
|
||||
# MCP tools are now automatically available!
|
||||
# No need for manual connection management or tool configuration
|
||||
```
|
||||
|
||||
## Supported Reference Formats
|
||||
|
||||
### External MCP Remote Servers
|
||||
|
||||
```python
|
||||
# Basic HTTPS server
|
||||
"https://api.example.com/mcp"
|
||||
|
||||
# Server with authentication
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key&profile=your_profile"
|
||||
|
||||
# Server with custom path
|
||||
"https://services.company.com/api/v1/mcp"
|
||||
```
|
||||
|
||||
### Specific Tool Selection
|
||||
|
||||
Use the `#` syntax to select specific tools from a server:
|
||||
|
||||
```python
|
||||
# Get only the forecast tool from weather server
|
||||
"https://weather.api.com/mcp#get_forecast"
|
||||
|
||||
# Get only the search tool from Exa
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key#web_search_exa"
|
||||
```
|
||||
|
||||
### CrewAI AMP Marketplace
|
||||
|
||||
Access tools from the CrewAI AMP marketplace:
|
||||
|
||||
```python
|
||||
# Full service with all tools
|
||||
"crewai-amp:financial-data"
|
||||
|
||||
# Specific tool from AMP service
|
||||
"crewai-amp:research-tools#pubmed_search"
|
||||
|
||||
# Multiple AMP services
|
||||
mcps=[
|
||||
"crewai-amp:weather-insights",
|
||||
"crewai-amp:market-analysis",
|
||||
"crewai-amp:social-media-monitoring"
|
||||
]
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
Here's a complete example using multiple MCP servers:
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew, Process
|
||||
|
||||
# Create agent with multiple MCP sources
|
||||
multi_source_agent = Agent(
|
||||
role="Multi-Source Research Analyst",
|
||||
goal="Conduct comprehensive research using multiple data sources",
|
||||
backstory="""Expert researcher with access to web search, weather data,
|
||||
financial information, and academic research tools""",
|
||||
mcps=[
|
||||
# External MCP servers
|
||||
"https://mcp.exa.ai/mcp?api_key=your_exa_key&profile=research",
|
||||
"https://weather.api.com/mcp#get_current_conditions",
|
||||
|
||||
# CrewAI AMP marketplace
|
||||
"crewai-amp:financial-insights",
|
||||
"crewai-amp:academic-research#pubmed_search",
|
||||
"crewai-amp:market-intelligence#competitor_analysis"
|
||||
]
|
||||
)
|
||||
|
||||
# Create comprehensive research task
|
||||
research_task = Task(
|
||||
description="""Research the impact of AI agents on business productivity.
|
||||
Include current weather impacts on remote work, financial market trends,
|
||||
and recent academic publications on AI agent frameworks.""",
|
||||
expected_output="""Comprehensive report covering:
|
||||
1. AI agent business impact analysis
|
||||
2. Weather considerations for remote work
|
||||
3. Financial market trends related to AI
|
||||
4. Academic research citations and insights
|
||||
5. Competitive landscape analysis""",
|
||||
agent=multi_source_agent
|
||||
)
|
||||
|
||||
# Create and execute crew
|
||||
research_crew = Crew(
|
||||
agents=[multi_source_agent],
|
||||
tasks=[research_task],
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
result = research_crew.kickoff()
|
||||
print(f"Research completed with {len(multi_source_agent.mcps)} MCP data sources")
|
||||
```
|
||||
|
||||
## Tool Naming and Organization
|
||||
|
||||
CrewAI automatically handles tool naming to prevent conflicts:
|
||||
|
||||
```python
|
||||
# Original MCP server has tools: "search", "analyze"
|
||||
# CrewAI creates tools: "mcp_exa_ai_search", "mcp_exa_ai_analyze"
|
||||
|
||||
agent = Agent(
|
||||
role="Tool Organization Demo",
|
||||
goal="Show how tool naming works",
|
||||
backstory="Demonstrates automatic tool organization",
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=key", # Tools: mcp_exa_ai_*
|
||||
"https://weather.service.com/mcp", # Tools: weather_service_com_*
|
||||
"crewai-amp:financial-data" # Tools: financial_data_*
|
||||
]
|
||||
)
|
||||
|
||||
# Each server's tools get unique prefixes based on the server name
|
||||
# This prevents naming conflicts between different MCP servers
|
||||
```
|
||||
|
||||
## Error Handling and Resilience
|
||||
|
||||
The MCP DSL is designed to be robust and user-friendly:
|
||||
|
||||
### Graceful Server Failures
|
||||
|
||||
```python
|
||||
agent = Agent(
|
||||
role="Resilient Researcher",
|
||||
goal="Research despite server issues",
|
||||
backstory="Experienced researcher who adapts to available tools",
|
||||
mcps=[
|
||||
"https://primary-server.com/mcp", # Primary data source
|
||||
"https://backup-server.com/mcp", # Backup if primary fails
|
||||
"https://unreachable-server.com/mcp", # Will be skipped with warning
|
||||
"crewai-amp:reliable-service" # Reliable AMP service
|
||||
]
|
||||
)
|
||||
|
||||
# Agent will:
|
||||
# 1. Successfully connect to working servers
|
||||
# 2. Log warnings for failing servers
|
||||
# 3. Continue with available tools
|
||||
# 4. Not crash or hang on server failures
|
||||
```
|
||||
|
||||
### Timeout Protection
|
||||
|
||||
All MCP operations have built-in timeouts:
|
||||
|
||||
- **Connection timeout**: 10 seconds
|
||||
- **Tool execution timeout**: 30 seconds
|
||||
- **Discovery timeout**: 15 seconds
|
||||
|
||||
```python
|
||||
# These servers will timeout gracefully if unresponsive
|
||||
mcps=[
|
||||
"https://slow-server.com/mcp", # Will timeout after 10s if unresponsive
|
||||
"https://overloaded-api.com/mcp" # Will timeout if discovery takes > 15s
|
||||
]
|
||||
```
|
||||
|
||||
## Performance Features
|
||||
|
||||
### Automatic Caching
|
||||
|
||||
Tool schemas are cached for 5 minutes to improve performance:
|
||||
|
||||
```python
|
||||
# First agent creation - discovers tools from server
|
||||
agent1 = Agent(role="First", goal="Test", backstory="Test",
|
||||
mcps=["https://api.example.com/mcp"])
|
||||
|
||||
# Second agent creation (within 5 minutes) - uses cached tool schemas
|
||||
agent2 = Agent(role="Second", goal="Test", backstory="Test",
|
||||
mcps=["https://api.example.com/mcp"]) # Much faster!
|
||||
```
|
||||
|
||||
### On-Demand Connections
|
||||
|
||||
Tool connections are established only when tools are actually used:
|
||||
|
||||
```python
|
||||
# Agent creation is fast - no MCP connections made yet
|
||||
agent = Agent(
|
||||
role="On-Demand Agent",
|
||||
goal="Use tools efficiently",
|
||||
backstory="Efficient agent that connects only when needed",
|
||||
mcps=["https://api.example.com/mcp"]
|
||||
)
|
||||
|
||||
# MCP connection is made only when a tool is actually executed
|
||||
# This minimizes connection overhead and improves startup performance
|
||||
```
|
||||
|
||||
## Integration with Existing Features
|
||||
|
||||
MCP tools work seamlessly with other CrewAI features:
|
||||
|
||||
```python
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
class CustomTool(BaseTool):
|
||||
name: str = "custom_analysis"
|
||||
description: str = "Custom analysis tool"
|
||||
|
||||
def _run(self, **kwargs):
|
||||
return "Custom analysis result"
|
||||
|
||||
agent = Agent(
|
||||
role="Full-Featured Agent",
|
||||
goal="Use all available tool types",
|
||||
backstory="Agent with comprehensive tool access",
|
||||
|
||||
# All tool types work together
|
||||
tools=[CustomTool()], # Custom tools
|
||||
apps=["gmail", "slack"], # Platform integrations
|
||||
mcps=[ # MCP servers
|
||||
"https://mcp.exa.ai/mcp?api_key=key",
|
||||
"crewai-amp:research-tools"
|
||||
],
|
||||
|
||||
verbose=True,
|
||||
max_iter=15
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Specific Tools When Possible
|
||||
|
||||
```python
|
||||
# Good - only get the tools you need
|
||||
mcps=["https://weather.api.com/mcp#get_forecast"]
|
||||
|
||||
# Less efficient - gets all tools from server
|
||||
mcps=["https://weather.api.com/mcp"]
|
||||
```
|
||||
|
||||
### 2. Handle Authentication Securely
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Store API keys in environment variables
|
||||
exa_key = os.getenv("EXA_API_KEY")
|
||||
exa_profile = os.getenv("EXA_PROFILE")
|
||||
|
||||
agent = Agent(
|
||||
role="Secure Agent",
|
||||
goal="Use MCP tools securely",
|
||||
backstory="Security-conscious agent",
|
||||
mcps=[f"https://mcp.exa.ai/mcp?api_key={exa_key}&profile={exa_profile}"]
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Plan for Server Failures
|
||||
|
||||
```python
|
||||
# Always include backup options
|
||||
mcps=[
|
||||
"https://primary-api.com/mcp", # Primary choice
|
||||
"https://backup-api.com/mcp", # Backup option
|
||||
"crewai-amp:reliable-service" # AMP fallback
|
||||
]
|
||||
```
|
||||
|
||||
### 4. Use Descriptive Agent Roles
|
||||
|
||||
```python
|
||||
agent = Agent(
|
||||
role="Weather-Enhanced Market Analyst",
|
||||
goal="Analyze markets considering weather impacts",
|
||||
backstory="Financial analyst with access to weather data for agricultural market insights",
|
||||
mcps=[
|
||||
"https://weather.service.com/mcp#get_forecast",
|
||||
"crewai-amp:financial-data#stock_analysis"
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**No tools discovered:**
|
||||
```python
|
||||
# Check your MCP server URL and authentication
|
||||
# Verify the server is running and accessible
|
||||
mcps=["https://mcp.example.com/mcp?api_key=valid_key"]
|
||||
```
|
||||
|
||||
**Connection timeouts:**
|
||||
```python
|
||||
# Server may be slow or overloaded
|
||||
# CrewAI will log warnings and continue with other servers
|
||||
# Check server status or try backup servers
|
||||
```
|
||||
|
||||
**Authentication failures:**
|
||||
```python
|
||||
# Verify API keys and credentials
|
||||
# Check server documentation for required parameters
|
||||
# Ensure query parameters are properly URL encoded
|
||||
```
|
||||
|
||||
## Advanced: MCPServerAdapter
|
||||
|
||||
For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server.
|
||||
@@ -8,39 +8,14 @@ mode: "wide"
|
||||
## Overview
|
||||
|
||||
The [Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP) provides a standardized way for AI agents to provide context to LLMs by communicating with external services, known as MCP Servers.
|
||||
|
||||
CrewAI offers **two approaches** for MCP integration:
|
||||
|
||||
### **Simple DSL Integration** (Recommended)
|
||||
|
||||
Use the `mcps` field directly on agents for seamless MCP tool integration:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
|
||||
agent = Agent(
|
||||
role="Research Analyst",
|
||||
goal="Research and analyze information",
|
||||
backstory="Expert researcher with access to external tools",
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key", # External MCP server
|
||||
"https://api.weather.com/mcp#get_forecast", # Specific tool from server
|
||||
"crewai-amp:financial-data", # CrewAI AMP marketplace
|
||||
"crewai-amp:research-tools#pubmed_search" # Specific AMP tool
|
||||
]
|
||||
)
|
||||
# MCP tools are now automatically available to your agent!
|
||||
```
|
||||
|
||||
### 🔧 **Advanced: MCPServerAdapter** (For Complex Scenarios)
|
||||
|
||||
For advanced use cases requiring manual connection management, the `crewai-tools` library provides the `MCPServerAdapter` class.
|
||||
The `crewai-tools` library extends CrewAI's capabilities by allowing you to seamlessly integrate tools from these MCP servers into your agents.
|
||||
This gives your crews access to a vast ecosystem of functionalities.
|
||||
|
||||
We currently support the following transport mechanisms:
|
||||
|
||||
- **Stdio**: for local servers (communication via standard input/output between processes on the same machine)
|
||||
- **Server-Sent Events (SSE)**: for remote servers (unidirectional, real-time data streaming from server to client over HTTP)
|
||||
- **Streamable HTTPS**: for remote servers (flexible, potentially bi-directional communication over HTTPS, often utilizing SSE for server-to-client streams)
|
||||
- **Streamable HTTP**: for remote servers (flexible, potentially bi-directional communication over HTTP, often utilizing SSE for server-to-client streams)
|
||||
|
||||
## Video Tutorial
|
||||
Watch this video tutorial for a comprehensive guide on MCP integration with CrewAI:
|
||||
@@ -56,125 +31,17 @@ Watch this video tutorial for a comprehensive guide on MCP integration with Crew
|
||||
|
||||
## Installation
|
||||
|
||||
CrewAI MCP integration requires the `mcp` library:
|
||||
Before you start using MCP with `crewai-tools`, you need to install the `mcp` extra `crewai-tools` dependency with the following command:
|
||||
|
||||
```shell
|
||||
# For Simple DSL Integration (Recommended)
|
||||
uv add mcp
|
||||
|
||||
# For Advanced MCPServerAdapter usage
|
||||
uv pip install 'crewai-tools[mcp]'
|
||||
```
|
||||
|
||||
## Quick Start: Simple DSL Integration
|
||||
## Key Concepts & Getting Started
|
||||
|
||||
The easiest way to integrate MCP servers is using the `mcps` field on your agents:
|
||||
The `MCPServerAdapter` class from `crewai-tools` is the primary way to connect to an MCP server and make its tools available to your CrewAI agents. It supports different transport mechanisms and simplifies connection management.
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew
|
||||
|
||||
# Create agent with MCP tools
|
||||
research_agent = Agent(
|
||||
role="Research Analyst",
|
||||
goal="Find and analyze information using advanced search tools",
|
||||
backstory="Expert researcher with access to multiple data sources",
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key&profile=your_profile",
|
||||
"crewai-amp:weather-service#current_conditions"
|
||||
]
|
||||
)
|
||||
|
||||
# Create task
|
||||
research_task = Task(
|
||||
description="Research the latest developments in AI agent frameworks",
|
||||
expected_output="Comprehensive research report with citations",
|
||||
agent=research_agent
|
||||
)
|
||||
|
||||
# Create and run crew
|
||||
crew = Crew(agents=[research_agent], tasks=[research_task])
|
||||
result = crew.kickoff()
|
||||
```
|
||||
|
||||
That's it! The MCP tools are automatically discovered and available to your agent.
|
||||
|
||||
## MCP Reference Formats
|
||||
|
||||
The `mcps` field supports various reference formats for maximum flexibility:
|
||||
|
||||
### External MCP Servers
|
||||
|
||||
```python
|
||||
mcps=[
|
||||
# Full server - get all available tools
|
||||
"https://mcp.example.com/api",
|
||||
|
||||
# Specific tool from server using # syntax
|
||||
"https://api.weather.com/mcp#get_current_weather",
|
||||
|
||||
# Server with authentication parameters
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key&profile=your_profile"
|
||||
]
|
||||
```
|
||||
|
||||
### CrewAI AMP Marketplace
|
||||
|
||||
```python
|
||||
mcps=[
|
||||
# Full AMP MCP service - get all available tools
|
||||
"crewai-amp:financial-data",
|
||||
|
||||
# Specific tool from AMP service using # syntax
|
||||
"crewai-amp:research-tools#pubmed_search",
|
||||
|
||||
# Multiple AMP services
|
||||
"crewai-amp:weather-service",
|
||||
"crewai-amp:market-analysis"
|
||||
]
|
||||
```
|
||||
|
||||
### Mixed References
|
||||
|
||||
```python
|
||||
mcps=[
|
||||
"https://external-api.com/mcp", # External server
|
||||
"https://weather.service.com/mcp#forecast", # Specific external tool
|
||||
"crewai-amp:financial-insights", # AMP service
|
||||
"crewai-amp:data-analysis#sentiment_tool" # Specific AMP tool
|
||||
]
|
||||
```
|
||||
|
||||
## Key Features
|
||||
|
||||
- 🔄 **Automatic Tool Discovery**: Tools are automatically discovered and integrated
|
||||
- 🏷️ **Name Collision Prevention**: Server names are prefixed to tool names
|
||||
- ⚡ **Performance Optimized**: On-demand connections with schema caching
|
||||
- 🛡️ **Error Resilience**: Graceful handling of unavailable servers
|
||||
- ⏱️ **Timeout Protection**: Built-in timeouts prevent hanging connections
|
||||
- 📊 **Transparent Integration**: Works seamlessly with existing CrewAI features
|
||||
|
||||
## Error Handling
|
||||
|
||||
The MCP DSL integration is designed to be resilient:
|
||||
|
||||
```python
|
||||
agent = Agent(
|
||||
role="Resilient Agent",
|
||||
goal="Continue working despite server issues",
|
||||
backstory="Agent that handles failures gracefully",
|
||||
mcps=[
|
||||
"https://reliable-server.com/mcp", # Will work
|
||||
"https://unreachable-server.com/mcp", # Will be skipped gracefully
|
||||
"https://slow-server.com/mcp", # Will timeout gracefully
|
||||
"crewai-amp:working-service" # Will work
|
||||
]
|
||||
)
|
||||
# Agent will use tools from working servers and log warnings for failing ones
|
||||
```
|
||||
|
||||
## Advanced: MCPServerAdapter
|
||||
|
||||
For complex scenarios requiring manual connection management, use the `MCPServerAdapter` class from `crewai-tools`. Using a Python context manager (`with` statement) is the recommended approach as it automatically handles starting and stopping the connection to the MCP server.
|
||||
Using a Python context manager (`with` statement) is the **recommended approach** for `MCPServerAdapter`. It automatically handles starting and stopping the connection to the MCP server.
|
||||
|
||||
## Connection Configuration
|
||||
|
||||
@@ -374,19 +241,11 @@ class CrewWithCustomTimeout:
|
||||
## Explore MCP Integrations
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card
|
||||
title="Simple DSL Integration"
|
||||
icon="code"
|
||||
href="/en/mcp/dsl-integration"
|
||||
color="#3B82F6"
|
||||
>
|
||||
**Recommended**: Use the simple `mcps=[]` field syntax for effortless MCP integration.
|
||||
</Card>
|
||||
<Card
|
||||
title="Stdio Transport"
|
||||
icon="server"
|
||||
href="/en/mcp/stdio"
|
||||
color="#10B981"
|
||||
color="#3B82F6"
|
||||
>
|
||||
Connect to local MCP servers via standard input/output. Ideal for scripts and local executables.
|
||||
</Card>
|
||||
@@ -394,7 +253,7 @@ class CrewWithCustomTimeout:
|
||||
title="SSE Transport"
|
||||
icon="wifi"
|
||||
href="/en/mcp/sse"
|
||||
color="#F59E0B"
|
||||
color="#10B981"
|
||||
>
|
||||
Integrate with remote MCP servers using Server-Sent Events for real-time data streaming.
|
||||
</Card>
|
||||
@@ -402,7 +261,7 @@ class CrewWithCustomTimeout:
|
||||
title="Streamable HTTP Transport"
|
||||
icon="globe"
|
||||
href="/en/mcp/streamable-http"
|
||||
color="#8B5CF6"
|
||||
color="#F59E0B"
|
||||
>
|
||||
Utilize flexible Streamable HTTP for robust communication with remote MCP servers.
|
||||
</Card>
|
||||
@@ -410,7 +269,7 @@ class CrewWithCustomTimeout:
|
||||
title="Connecting to Multiple Servers"
|
||||
icon="layer-group"
|
||||
href="/en/mcp/multiple-servers"
|
||||
color="#EF4444"
|
||||
color="#8B5CF6"
|
||||
>
|
||||
Aggregate tools from several MCP servers simultaneously using a single adapter.
|
||||
</Card>
|
||||
@@ -418,7 +277,7 @@ class CrewWithCustomTimeout:
|
||||
title="Security Considerations"
|
||||
icon="lock"
|
||||
href="/en/mcp/security"
|
||||
color="#DC2626"
|
||||
color="#EF4444"
|
||||
>
|
||||
Review important security best practices for MCP integration to keep your agents safe.
|
||||
</Card>
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
---
|
||||
title: MCP DSL 통합
|
||||
description: CrewAI의 간단한 DSL 구문을 사용하여 mcps 필드로 MCP 서버를 에이전트와 직접 통합하는 방법을 알아보세요.
|
||||
icon: code
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
## 개요
|
||||
|
||||
CrewAI의 MCP DSL(Domain Specific Language) 통합은 에이전트를 MCP(Model Context Protocol) 서버에 연결하는 **가장 간단한 방법**을 제공합니다. 에이전트에 `mcps` 필드만 추가하면 CrewAI가 모든 복잡성을 자동으로 처리합니다.
|
||||
|
||||
<Info>
|
||||
이는 대부분의 MCP 사용 사례에 **권장되는 접근 방식**입니다. 수동 연결 관리가 필요한 고급 시나리오의 경우 [MCPServerAdapter](/ko/mcp/overview#advanced-mcpserveradapter)를 참조하세요.
|
||||
</Info>
|
||||
|
||||
## 기본 사용법
|
||||
|
||||
`mcps` 필드를 사용하여 에이전트에 MCP 서버를 추가하세요:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
|
||||
agent = Agent(
|
||||
role="연구 보조원",
|
||||
goal="연구 및 분석 업무 지원",
|
||||
backstory="고급 연구 도구에 접근할 수 있는 전문가 보조원",
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key&profile=research"
|
||||
]
|
||||
)
|
||||
|
||||
# MCP 도구들이 이제 자동으로 사용 가능합니다!
|
||||
# 수동 연결 관리나 도구 구성이 필요 없습니다
|
||||
```
|
||||
|
||||
## 지원되는 참조 형식
|
||||
|
||||
### 외부 MCP 원격 서버
|
||||
|
||||
```python
|
||||
# 기본 HTTPS 서버
|
||||
"https://api.example.com/mcp"
|
||||
|
||||
# 인증이 포함된 서버
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key&profile=your_profile"
|
||||
|
||||
# 사용자 정의 경로가 있는 서버
|
||||
"https://services.company.com/api/v1/mcp"
|
||||
```
|
||||
|
||||
### 특정 도구 선택
|
||||
|
||||
`#` 구문을 사용하여 서버에서 특정 도구를 선택하세요:
|
||||
|
||||
```python
|
||||
# 날씨 서버에서 예보 도구만 가져오기
|
||||
"https://weather.api.com/mcp#get_forecast"
|
||||
|
||||
# Exa에서 검색 도구만 가져오기
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key#web_search_exa"
|
||||
```
|
||||
|
||||
### CrewAI AMP 마켓플레이스
|
||||
|
||||
CrewAI AMP 마켓플레이스의 도구에 액세스하세요:
|
||||
|
||||
```python
|
||||
# 모든 도구가 포함된 전체 서비스
|
||||
"crewai-amp:financial-data"
|
||||
|
||||
# AMP 서비스의 특정 도구
|
||||
"crewai-amp:research-tools#pubmed_search"
|
||||
|
||||
# 다중 AMP 서비스
|
||||
mcps=[
|
||||
"crewai-amp:weather-insights",
|
||||
"crewai-amp:market-analysis",
|
||||
"crewai-amp:social-media-monitoring"
|
||||
]
|
||||
```
|
||||
|
||||
## 완전한 예제
|
||||
|
||||
다음은 여러 MCP 서버를 사용하는 완전한 예제입니다:
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew, Process
|
||||
|
||||
# 다중 MCP 소스를 가진 에이전트 생성
|
||||
multi_source_agent = Agent(
|
||||
role="다중 소스 연구 분석가",
|
||||
goal="다중 데이터 소스를 사용한 종합적인 연구 수행",
|
||||
backstory="""웹 검색, 날씨 데이터, 금융 정보,
|
||||
학술 연구 도구에 접근할 수 있는 전문가 연구원""",
|
||||
mcps=[
|
||||
# 외부 MCP 서버
|
||||
"https://mcp.exa.ai/mcp?api_key=your_exa_key&profile=research",
|
||||
"https://weather.api.com/mcp#get_current_conditions",
|
||||
|
||||
# CrewAI AMP 마켓플레이스
|
||||
"crewai-amp:financial-insights",
|
||||
"crewai-amp:academic-research#pubmed_search",
|
||||
"crewai-amp:market-intelligence#competitor_analysis"
|
||||
]
|
||||
)
|
||||
|
||||
# 종합적인 연구 작업 생성
|
||||
research_task = Task(
|
||||
description="""AI 에이전트가 비즈니스 생산성에 미치는 영향을 연구하세요.
|
||||
원격 근무에 대한 현재 날씨 영향, 금융 시장 트렌드,
|
||||
AI 에이전트 프레임워크에 대한 최근 학술 발표를 포함하세요.""",
|
||||
expected_output="""다음을 다루는 종합 보고서:
|
||||
1. AI 에이전트 비즈니스 영향 분석
|
||||
2. 원격 근무를 위한 날씨 고려사항
|
||||
3. AI 관련 금융 시장 트렌드
|
||||
4. 학술 연구 인용 및 통찰
|
||||
5. 경쟁 환경 분석""",
|
||||
agent=multi_source_agent
|
||||
)
|
||||
|
||||
# crew 생성 및 실행
|
||||
research_crew = Crew(
|
||||
agents=[multi_source_agent],
|
||||
tasks=[research_task],
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
result = research_crew.kickoff()
|
||||
print(f"{len(multi_source_agent.mcps)}개의 MCP 데이터 소스로 연구 완료")
|
||||
```
|
||||
|
||||
## 주요 기능
|
||||
|
||||
- 🔄 **자동 도구 발견**: 도구들이 자동으로 발견되고 통합됩니다
|
||||
- 🏷️ **이름 충돌 방지**: 서버 이름이 도구 이름에 접두사로 붙습니다
|
||||
- ⚡ **성능 최적화**: 스키마 캐싱과 온디맨드 연결
|
||||
- 🛡️ **오류 복원력**: 사용할 수 없는 서버의 우아한 처리
|
||||
- ⏱️ **타임아웃 보호**: 내장 타임아웃으로 연결 중단 방지
|
||||
- 📊 **투명한 통합**: 기존 CrewAI 기능과 완벽한 연동
|
||||
|
||||
## 오류 처리
|
||||
|
||||
MCP DSL 통합은 복원력 있게 설계되었습니다:
|
||||
|
||||
```python
|
||||
agent = Agent(
|
||||
role="복원력 있는 에이전트",
|
||||
goal="서버 문제에도 불구하고 작업 계속",
|
||||
backstory="장애를 우아하게 처리하는 에이전트",
|
||||
mcps=[
|
||||
"https://reliable-server.com/mcp", # 작동할 것
|
||||
"https://unreachable-server.com/mcp", # 우아하게 건너뛸 것
|
||||
"https://slow-server.com/mcp", # 우아하게 타임아웃될 것
|
||||
"crewai-amp:working-service" # 작동할 것
|
||||
]
|
||||
)
|
||||
# 에이전트는 작동하는 서버의 도구를 사용하고 실패한 서버에 대한 경고를 로그에 남깁니다
|
||||
```
|
||||
|
||||
## 성능 기능
|
||||
|
||||
### 자동 캐싱
|
||||
|
||||
도구 스키마는 성능 향상을 위해 5분간 캐시됩니다:
|
||||
|
||||
```python
|
||||
# 첫 번째 에이전트 생성 - 서버에서 도구 발견
|
||||
agent1 = Agent(role="첫 번째", goal="테스트", backstory="테스트",
|
||||
mcps=["https://api.example.com/mcp"])
|
||||
|
||||
# 두 번째 에이전트 생성 (5분 이내) - 캐시된 도구 스키마 사용
|
||||
agent2 = Agent(role="두 번째", goal="테스트", backstory="테스트",
|
||||
mcps=["https://api.example.com/mcp"]) # 훨씬 빠릅니다!
|
||||
```
|
||||
|
||||
### 온디맨드 연결
|
||||
|
||||
도구 연결은 실제로 사용될 때만 설정됩니다:
|
||||
|
||||
```python
|
||||
# 에이전트 생성은 빠름 - 아직 MCP 연결을 만들지 않음
|
||||
agent = Agent(
|
||||
role="온디맨드 에이전트",
|
||||
goal="도구를 효율적으로 사용",
|
||||
backstory="필요할 때만 연결하는 효율적인 에이전트",
|
||||
mcps=["https://api.example.com/mcp"]
|
||||
)
|
||||
|
||||
# MCP 연결은 도구가 실제로 실행될 때만 만들어집니다
|
||||
# 이는 연결 오버헤드를 최소화하고 시작 성능을 개선합니다
|
||||
```
|
||||
|
||||
## 모범 사례
|
||||
|
||||
### 1. 가능하면 특정 도구 사용
|
||||
|
||||
```python
|
||||
# 좋음 - 필요한 도구만 가져오기
|
||||
mcps=["https://weather.api.com/mcp#get_forecast"]
|
||||
|
||||
# 덜 효율적 - 서버의 모든 도구 가져오기
|
||||
mcps=["https://weather.api.com/mcp"]
|
||||
```
|
||||
|
||||
### 2. 인증을 안전하게 처리
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# 환경 변수에 API 키 저장
|
||||
exa_key = os.getenv("EXA_API_KEY")
|
||||
exa_profile = os.getenv("EXA_PROFILE")
|
||||
|
||||
agent = Agent(
|
||||
role="안전한 에이전트",
|
||||
goal="MCP 도구를 안전하게 사용",
|
||||
backstory="보안을 고려하는 에이전트",
|
||||
mcps=[f"https://mcp.exa.ai/mcp?api_key={exa_key}&profile={exa_profile}"]
|
||||
)
|
||||
```
|
||||
|
||||
### 3. 서버 장애 계획
|
||||
|
||||
```python
|
||||
# 항상 백업 옵션 포함
|
||||
mcps=[
|
||||
"https://primary-api.com/mcp", # 주요 선택
|
||||
"https://backup-api.com/mcp", # 백업 옵션
|
||||
"crewai-amp:reliable-service" # AMP 폴백
|
||||
]
|
||||
```
|
||||
@@ -8,37 +8,12 @@ mode: "wide"
|
||||
## 개요
|
||||
|
||||
[Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP)는 AI 에이전트가 MCP 서버로 알려진 외부 서비스와 통신함으로써 LLM에 컨텍스트를 제공할 수 있도록 표준화된 방식을 제공합니다.
|
||||
|
||||
CrewAI는 MCP 통합을 위한 **두 가지 접근 방식**을 제공합니다:
|
||||
|
||||
### 🚀 **새로운 기능: 간단한 DSL 통합** (권장)
|
||||
|
||||
에이전트에 `mcps` 필드를 직접 사용하여 완벽한 MCP 도구 통합을 구현하세요:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
|
||||
agent = Agent(
|
||||
role="연구 분석가",
|
||||
goal="정보를 연구하고 분석",
|
||||
backstory="외부 도구에 접근할 수 있는 전문가 연구원",
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=your_key", # 외부 MCP 서버
|
||||
"https://api.weather.com/mcp#get_forecast", # 서버의 특정 도구
|
||||
"crewai-amp:financial-data", # CrewAI AMP 마켓플레이스
|
||||
"crewai-amp:research-tools#pubmed_search" # 특정 AMP 도구
|
||||
]
|
||||
)
|
||||
# MCP 도구들이 이제 자동으로 에이전트에서 사용 가능합니다!
|
||||
```
|
||||
|
||||
### 🔧 **고급: MCPServerAdapter** (복잡한 시나리오용)
|
||||
|
||||
수동 연결 관리가 필요한 고급 사용 사례의 경우 `crewai-tools` 라이브러리는 `MCPServerAdapter` 클래스를 제공합니다.
|
||||
`crewai-tools` 라이브러리는 CrewAI의 기능을 확장하여, 이러한 MCP 서버에서 제공하는 툴을 에이전트에 원활하게 통합할 수 있도록 해줍니다.
|
||||
이를 통해 여러분의 crew는 방대한 기능 에코시스템에 접근할 수 있습니다.
|
||||
|
||||
현재 다음과 같은 전송 메커니즘을 지원합니다:
|
||||
|
||||
- **HTTPS**: 원격 서버용 (HTTPS를 통한 보안 통신)
|
||||
- **Stdio**: 로컬 서버용 (동일 머신 내 프로세스 간 표준 입력/출력을 통한 통신)
|
||||
- **Server-Sent Events (SSE)**: 원격 서버용 (서버에서 클라이언트로의 일방향, 실시간 데이터 스트리밍, HTTP 기반)
|
||||
- **Streamable HTTP**: 원격 서버용 (유연하며 잠재적으로 양방향 통신이 가능, 주로 SSE를 활용한 서버-클라이언트 스트림 제공, HTTP 기반)
|
||||
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
---
|
||||
title: Integração DSL MCP
|
||||
description: Aprenda a usar a sintaxe DSL simples do CrewAI para integrar servidores MCP diretamente com seus agentes usando o campo mcps.
|
||||
icon: code
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
## Visão Geral
|
||||
|
||||
A integração DSL (Domain Specific Language) MCP do CrewAI oferece a **forma mais simples** de conectar seus agentes aos servidores MCP (Model Context Protocol). Basta adicionar um campo `mcps` ao seu agente e o CrewAI cuida de toda a complexidade automaticamente.
|
||||
|
||||
<Info>
|
||||
Esta é a **abordagem recomendada** para a maioria dos casos de uso de MCP. Para cenários avançados que requerem gerenciamento manual de conexão, veja [MCPServerAdapter](/pt-BR/mcp/overview#advanced-mcpserveradapter).
|
||||
</Info>
|
||||
|
||||
## Uso Básico
|
||||
|
||||
Adicione servidores MCP ao seu agente usando o campo `mcps`:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
|
||||
agent = Agent(
|
||||
role="Assistente de Pesquisa",
|
||||
goal="Ajudar com tarefas de pesquisa e análise",
|
||||
backstory="Assistente especialista com acesso a ferramentas avançadas de pesquisa",
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=sua_chave&profile=pesquisa"
|
||||
]
|
||||
)
|
||||
|
||||
# As ferramentas MCP agora estão automaticamente disponíveis!
|
||||
# Não é necessário gerenciamento manual de conexão ou configuração de ferramentas
|
||||
```
|
||||
|
||||
## Formatos de Referência Suportados
|
||||
|
||||
### Servidores MCP Remotos Externos
|
||||
|
||||
```python
|
||||
# Servidor HTTPS básico
|
||||
"https://api.example.com/mcp"
|
||||
|
||||
# Servidor com autenticação
|
||||
"https://mcp.exa.ai/mcp?api_key=sua_chave&profile=seu_perfil"
|
||||
|
||||
# Servidor com caminho personalizado
|
||||
"https://services.company.com/api/v1/mcp"
|
||||
```
|
||||
|
||||
### Seleção de Ferramentas Específicas
|
||||
|
||||
Use a sintaxe `#` para selecionar ferramentas específicas de um servidor:
|
||||
|
||||
```python
|
||||
# Obter apenas a ferramenta de previsão do servidor meteorológico
|
||||
"https://weather.api.com/mcp#get_forecast"
|
||||
|
||||
# Obter apenas a ferramenta de busca do Exa
|
||||
"https://mcp.exa.ai/mcp?api_key=sua_chave#web_search_exa"
|
||||
```
|
||||
|
||||
### Marketplace CrewAI AMP
|
||||
|
||||
Acesse ferramentas do marketplace CrewAI AMP:
|
||||
|
||||
```python
|
||||
# Serviço completo com todas as ferramentas
|
||||
"crewai-amp:financial-data"
|
||||
|
||||
# Ferramenta específica do serviço AMP
|
||||
"crewai-amp:research-tools#pubmed_search"
|
||||
|
||||
# Múltiplos serviços AMP
|
||||
mcps=[
|
||||
"crewai-amp:weather-insights",
|
||||
"crewai-amp:market-analysis",
|
||||
"crewai-amp:social-media-monitoring"
|
||||
]
|
||||
```
|
||||
|
||||
## Exemplo Completo
|
||||
|
||||
Aqui está um exemplo completo usando múltiplos servidores MCP:
|
||||
|
||||
```python
|
||||
from crewai import Agent, Task, Crew, Process
|
||||
|
||||
# Criar agente com múltiplas fontes MCP
|
||||
agente_multi_fonte = Agent(
|
||||
role="Analista de Pesquisa Multi-Fonte",
|
||||
goal="Conduzir pesquisa abrangente usando múltiplas fontes de dados",
|
||||
backstory="""Pesquisador especialista com acesso a busca web, dados meteorológicos,
|
||||
informações financeiras e ferramentas de pesquisa acadêmica""",
|
||||
mcps=[
|
||||
# Servidores MCP externos
|
||||
"https://mcp.exa.ai/mcp?api_key=sua_chave_exa&profile=pesquisa",
|
||||
"https://weather.api.com/mcp#get_current_conditions",
|
||||
|
||||
# Marketplace CrewAI AMP
|
||||
"crewai-amp:financial-insights",
|
||||
"crewai-amp:academic-research#pubmed_search",
|
||||
"crewai-amp:market-intelligence#competitor_analysis"
|
||||
]
|
||||
)
|
||||
|
||||
# Criar tarefa de pesquisa abrangente
|
||||
tarefa_pesquisa = Task(
|
||||
description="""Pesquisar o impacto dos agentes de IA na produtividade empresarial.
|
||||
Incluir impactos climáticos atuais no trabalho remoto, tendências do mercado financeiro,
|
||||
e publicações acadêmicas recentes sobre frameworks de agentes de IA.""",
|
||||
expected_output="""Relatório abrangente cobrindo:
|
||||
1. Análise do impacto dos agentes de IA nos negócios
|
||||
2. Considerações climáticas para trabalho remoto
|
||||
3. Tendências do mercado financeiro relacionadas à IA
|
||||
4. Citações e insights de pesquisa acadêmica
|
||||
5. Análise do cenário competitivo""",
|
||||
agent=agente_multi_fonte
|
||||
)
|
||||
|
||||
# Criar e executar crew
|
||||
crew_pesquisa = Crew(
|
||||
agents=[agente_multi_fonte],
|
||||
tasks=[tarefa_pesquisa],
|
||||
process=Process.sequential,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
resultado = crew_pesquisa.kickoff()
|
||||
print(f"Pesquisa concluída com {len(agente_multi_fonte.mcps)} fontes de dados MCP")
|
||||
```
|
||||
|
||||
## Recursos Principais
|
||||
|
||||
- 🔄 **Descoberta Automática de Ferramentas**: Ferramentas são descobertas e integradas automaticamente
|
||||
- 🏷️ **Prevenção de Colisão de Nomes**: Nomes de servidor são prefixados aos nomes das ferramentas
|
||||
- ⚡ **Otimizado para Performance**: Conexões sob demanda com cache de esquemas
|
||||
- 🛡️ **Resiliência a Erros**: Tratamento gracioso de servidores indisponíveis
|
||||
- ⏱️ **Proteção por Timeout**: Timeouts integrados previnem conexões travadas
|
||||
- 📊 **Integração Transparente**: Funciona perfeitamente com recursos existentes do CrewAI
|
||||
|
||||
## Tratamento de Erros
|
||||
|
||||
A integração DSL MCP é projetada para ser resiliente:
|
||||
|
||||
```python
|
||||
agente = Agent(
|
||||
role="Agente Resiliente",
|
||||
goal="Continuar trabalhando apesar de problemas no servidor",
|
||||
backstory="Agente que lida graciosamente com falhas",
|
||||
mcps=[
|
||||
"https://servidor-confiavel.com/mcp", # Vai funcionar
|
||||
"https://servidor-inalcancavel.com/mcp", # Será ignorado graciosamente
|
||||
"https://servidor-lento.com/mcp", # Timeout gracioso
|
||||
"crewai-amp:servico-funcionando" # Vai funcionar
|
||||
]
|
||||
)
|
||||
# O agente usará ferramentas de servidores funcionais e registrará avisos para os que falharem
|
||||
```
|
||||
|
||||
## Recursos de Performance
|
||||
|
||||
### Cache Automático
|
||||
|
||||
Esquemas de ferramentas são cacheados por 5 minutos para melhorar a performance:
|
||||
|
||||
```python
|
||||
# Primeira criação de agente - descobre ferramentas do servidor
|
||||
agente1 = Agent(role="Primeiro", goal="Teste", backstory="Teste",
|
||||
mcps=["https://api.example.com/mcp"])
|
||||
|
||||
# Segunda criação de agente (dentro de 5 minutos) - usa esquemas cacheados
|
||||
agente2 = Agent(role="Segundo", goal="Teste", backstory="Teste",
|
||||
mcps=["https://api.example.com/mcp"]) # Muito mais rápido!
|
||||
```
|
||||
|
||||
### Conexões Sob Demanda
|
||||
|
||||
Conexões de ferramentas são estabelecidas apenas quando as ferramentas são realmente usadas:
|
||||
|
||||
```python
|
||||
# Criação do agente é rápida - nenhuma conexão MCP feita ainda
|
||||
agente = Agent(
|
||||
role="Agente Sob Demanda",
|
||||
goal="Usar ferramentas eficientemente",
|
||||
backstory="Agente eficiente que conecta apenas quando necessário",
|
||||
mcps=["https://api.example.com/mcp"]
|
||||
)
|
||||
|
||||
# Conexão MCP é feita apenas quando uma ferramenta é realmente executada
|
||||
# Isso minimiza o overhead de conexão e melhora a performance de inicialização
|
||||
```
|
||||
|
||||
## Melhores Práticas
|
||||
|
||||
### 1. Use Ferramentas Específicas Quando Possível
|
||||
|
||||
```python
|
||||
# Bom - obter apenas as ferramentas necessárias
|
||||
mcps=["https://weather.api.com/mcp#get_forecast"]
|
||||
|
||||
# Menos eficiente - obter todas as ferramentas do servidor
|
||||
mcps=["https://weather.api.com/mcp"]
|
||||
```
|
||||
|
||||
### 2. Lidar com Autenticação de Forma Segura
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Armazenar chaves API em variáveis de ambiente
|
||||
exa_key = os.getenv("EXA_API_KEY")
|
||||
exa_profile = os.getenv("EXA_PROFILE")
|
||||
|
||||
agente = Agent(
|
||||
role="Agente Seguro",
|
||||
goal="Usar ferramentas MCP com segurança",
|
||||
backstory="Agente consciente da segurança",
|
||||
mcps=[f"https://mcp.exa.ai/mcp?api_key={exa_key}&profile={exa_profile}"]
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Planejar para Falhas de Servidor
|
||||
|
||||
```python
|
||||
# Sempre incluir opções de backup
|
||||
mcps=[
|
||||
"https://api-principal.com/mcp", # Escolha principal
|
||||
"https://api-backup.com/mcp", # Opção de backup
|
||||
"crewai-amp:servico-confiavel" # Fallback AMP
|
||||
]
|
||||
```
|
||||
@@ -8,37 +8,12 @@ mode: "wide"
|
||||
## Visão Geral
|
||||
|
||||
O [Model Context Protocol](https://modelcontextprotocol.io/introduction) (MCP) fornece uma maneira padronizada para agentes de IA fornecerem contexto para LLMs comunicando-se com serviços externos, conhecidos como Servidores MCP.
|
||||
|
||||
O CrewAI oferece **duas abordagens** para integração MCP:
|
||||
|
||||
### 🚀 **Novo: Integração DSL Simples** (Recomendado)
|
||||
|
||||
Use o campo `mcps` diretamente nos agentes para integração perfeita de ferramentas MCP:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
|
||||
agent = Agent(
|
||||
role="Analista de Pesquisa",
|
||||
goal="Pesquisar e analisar informações",
|
||||
backstory="Pesquisador especialista com acesso a ferramentas externas",
|
||||
mcps=[
|
||||
"https://mcp.exa.ai/mcp?api_key=sua_chave", # Servidor MCP externo
|
||||
"https://api.weather.com/mcp#get_forecast", # Ferramenta específica do servidor
|
||||
"crewai-amp:financial-data", # Marketplace CrewAI AMP
|
||||
"crewai-amp:research-tools#pubmed_search" # Ferramenta AMP específica
|
||||
]
|
||||
)
|
||||
# Ferramentas MCP agora estão automaticamente disponíveis para seu agente!
|
||||
```
|
||||
|
||||
### 🔧 **Avançado: MCPServerAdapter** (Para Cenários Complexos)
|
||||
|
||||
Para casos de uso avançados que requerem gerenciamento manual de conexão, a biblioteca `crewai-tools` fornece a classe `MCPServerAdapter`.
|
||||
A biblioteca `crewai-tools` expande as capacidades do CrewAI permitindo que você integre facilmente ferramentas desses servidores MCP em seus agentes.
|
||||
Isso oferece às suas crews acesso a um vasto ecossistema de funcionalidades.
|
||||
|
||||
Atualmente, suportamos os seguintes mecanismos de transporte:
|
||||
|
||||
- **HTTPS**: para servidores remotos (comunicação segura via HTTPS)
|
||||
- **Stdio**: para servidores locais (comunicação via entrada/saída padrão entre processos na mesma máquina)
|
||||
- **Server-Sent Events (SSE)**: para servidores remotos (transmissão de dados unidirecional em tempo real do servidor para o cliente via HTTP)
|
||||
- **Streamable HTTP**: para servidores remotos (comunicação flexível e potencialmente bidirecional via HTTP, geralmente utilizando SSE para streams do servidor para o cliente)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ dependencies = [
|
||||
"pytube>=15.0.0",
|
||||
"requests>=2.32.5",
|
||||
"docker>=7.1.0",
|
||||
"crewai==1.0.0",
|
||||
"crewai==1.0.0b3",
|
||||
"lancedb>=0.5.4",
|
||||
"tiktoken>=0.8.0",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
|
||||
@@ -43,6 +43,9 @@ from crewai_tools.tools.contextualai_rerank_tool.contextual_rerank_tool import (
|
||||
from crewai_tools.tools.couchbase_tool.couchbase_tool import (
|
||||
CouchbaseFTSVectorSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.crewai_enterprise_tools.crewai_enterprise_tools import (
|
||||
CrewaiEnterpriseTools,
|
||||
)
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tools import (
|
||||
CrewaiPlatformTools,
|
||||
)
|
||||
@@ -211,6 +214,7 @@ __all__ = [
|
||||
"ContextualAIQueryTool",
|
||||
"ContextualAIRerankTool",
|
||||
"CouchbaseFTSVectorSearchTool",
|
||||
"CrewaiEnterpriseTools",
|
||||
"CrewaiPlatformTools",
|
||||
"DOCXSearchTool",
|
||||
"DallETool",
|
||||
@@ -287,4 +291,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__version__ = "1.0.0b3"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Literal, Optional, Union, _SpecialForm, cast, get_origin
|
||||
from typing import Any, Literal, Optional, Union, cast, get_origin
|
||||
import warnings
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -41,7 +41,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
action_schema: dict[str, Any],
|
||||
enterprise_api_base_url: str | None = None,
|
||||
):
|
||||
self._model_registry = {} # type: ignore[var-annotated]
|
||||
self._model_registry = {}
|
||||
self._base_name = self._sanitize_name(name)
|
||||
|
||||
schema_props, required = self._extract_schema_info(action_schema)
|
||||
@@ -67,7 +67,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
# Create the model
|
||||
if field_definitions:
|
||||
try:
|
||||
args_schema = create_model( # type: ignore[call-overload]
|
||||
args_schema = create_model(
|
||||
f"{self._base_name}Schema", **field_definitions
|
||||
)
|
||||
except Exception:
|
||||
@@ -110,9 +110,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
)
|
||||
return schema_props, required
|
||||
|
||||
def _process_schema_type(
|
||||
self, schema: dict[str, Any], type_name: str
|
||||
) -> type[Any] | _SpecialForm:
|
||||
def _process_schema_type(self, schema: dict[str, Any], type_name: str) -> type[Any]:
|
||||
"""Process a JSON schema and return appropriate Python type."""
|
||||
if "anyOf" in schema:
|
||||
any_of_types = schema["anyOf"]
|
||||
@@ -141,7 +139,7 @@ class EnterpriseActionTool(BaseTool):
|
||||
if json_type == "array":
|
||||
items_schema = schema.get("items", {"type": "string"})
|
||||
item_type = self._process_schema_type(items_schema, f"{type_name}Item")
|
||||
return list[item_type] # type: ignore[valid-type]
|
||||
return list[item_type]
|
||||
|
||||
if json_type == "object":
|
||||
return self._create_nested_model(schema, type_name)
|
||||
@@ -176,20 +174,18 @@ class EnterpriseActionTool(BaseTool):
|
||||
prop_type = str
|
||||
|
||||
field_definitions[prop_name] = self._create_field_definition(
|
||||
prop_type,
|
||||
is_required,
|
||||
prop_desc, # type: ignore[arg-type]
|
||||
prop_type, is_required, prop_desc
|
||||
)
|
||||
|
||||
try:
|
||||
nested_model = create_model(full_model_name, **field_definitions) # type: ignore[call-overload]
|
||||
nested_model = create_model(full_model_name, **field_definitions)
|
||||
self._model_registry[full_model_name] = nested_model
|
||||
return nested_model
|
||||
except Exception:
|
||||
return dict
|
||||
|
||||
def _create_field_definition(
|
||||
self, field_type: type[Any] | _SpecialForm, is_required: bool, description: str
|
||||
self, field_type: type[Any], is_required: bool, description: str
|
||||
) -> tuple:
|
||||
"""Create Pydantic field definition based on type and requirement."""
|
||||
if is_required:
|
||||
@@ -280,7 +276,7 @@ class EnterpriseActionKitToolAdapter:
|
||||
):
|
||||
"""Initialize the adapter with an enterprise action token."""
|
||||
self._set_enterprise_action_token(enterprise_action_token)
|
||||
self._actions_schema = {} # type: ignore[var-annotated]
|
||||
self._actions_schema = {}
|
||||
self._tools = None
|
||||
self.enterprise_api_base_url = (
|
||||
enterprise_api_base_url or get_enterprise_api_base_url()
|
||||
|
||||
@@ -2,11 +2,8 @@ from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lancedb import ( # type: ignore[import-untyped]
|
||||
DBConnection as LanceDBConnection,
|
||||
connect as lancedb_connect,
|
||||
)
|
||||
from lancedb.table import Table as LanceDBTable # type: ignore[import-untyped]
|
||||
from lancedb import DBConnection as LanceDBConnection, connect as lancedb_connect
|
||||
from lancedb.table import Table as LanceDBTable
|
||||
from openai import Client as OpenAIClient
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
@@ -40,7 +37,7 @@ class LanceDBAdapter(Adapter):
|
||||
|
||||
super().model_post_init(__context)
|
||||
|
||||
def query(self, question: str) -> str: # type: ignore[override]
|
||||
def query(self, question: str) -> str:
|
||||
query = self.embedding_function([question])[0]
|
||||
results = (
|
||||
self._table.search(query, vector_column_name=self.vector_column_name)
|
||||
|
||||
@@ -27,7 +27,7 @@ class RAGAdapter(Adapter):
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
|
||||
def query(self, question: str) -> str: # type: ignore[override]
|
||||
def query(self, question: str) -> str:
|
||||
return self._adapter.query(question)
|
||||
|
||||
def add(
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
@@ -31,7 +29,7 @@ class ToolCollection(list, Generic[T]):
|
||||
def _build_name_cache(self) -> None:
|
||||
self._name_cache = {tool.name.lower(): tool for tool in self}
|
||||
|
||||
def __getitem__(self, key: int | str) -> T: # type: ignore[override]
|
||||
def __getitem__(self, key: int | str) -> T:
|
||||
if isinstance(key, str):
|
||||
return self._name_cache[key.lower()]
|
||||
return super().__getitem__(key)
|
||||
@@ -40,11 +38,11 @@ class ToolCollection(list, Generic[T]):
|
||||
super().append(tool)
|
||||
self._name_cache[tool.name.lower()] = tool
|
||||
|
||||
def extend(self, tools: list[T]) -> None: # type: ignore[override]
|
||||
def extend(self, tools: list[T]) -> None:
|
||||
super().extend(tools)
|
||||
self._build_name_cache()
|
||||
|
||||
def insert(self, index: int, tool: T) -> None: # type: ignore[override]
|
||||
def insert(self, index: int, tool: T) -> None:
|
||||
super().insert(index, tool)
|
||||
self._name_cache[tool.name.lower()] = tool
|
||||
|
||||
@@ -53,13 +51,13 @@ class ToolCollection(list, Generic[T]):
|
||||
if tool.name.lower() in self._name_cache:
|
||||
del self._name_cache[tool.name.lower()]
|
||||
|
||||
def pop(self, index: int = -1) -> T: # type: ignore[override]
|
||||
def pop(self, index: int = -1) -> T:
|
||||
tool = super().pop(index)
|
||||
if tool.name.lower() in self._name_cache:
|
||||
del self._name_cache[tool.name.lower()]
|
||||
return tool
|
||||
|
||||
def filter_by_names(self, names: list[str] | None = None) -> ToolCollection[T]:
|
||||
def filter_by_names(self, names: list[str] | None = None) -> "ToolCollection[T]":
|
||||
if names is None:
|
||||
return self
|
||||
|
||||
@@ -71,7 +69,7 @@ class ToolCollection(list, Generic[T]):
|
||||
]
|
||||
)
|
||||
|
||||
def filter_where(self, func: Callable[[T], bool]) -> ToolCollection[T]:
|
||||
def filter_where(self, func: Callable[[T], bool]) -> "ToolCollection[T]":
|
||||
return ToolCollection([tool for tool in self if func(tool)])
|
||||
|
||||
def clear(self) -> None:
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Final, Literal
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import Field, create_model
|
||||
import requests
|
||||
|
||||
|
||||
ACTIONS_URL: Final[Literal["https://actions.zapier.com/api/v2/ai-actions"]] = (
|
||||
"https://actions.zapier.com/api/v2/ai-actions"
|
||||
)
|
||||
ACTIONS_URL = "https://actions.zapier.com/api/v2/ai-actions"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -58,6 +55,8 @@ class ZapierActionTool(BaseTool):
|
||||
class ZapierActionsAdapter:
|
||||
"""Adapter for Zapier Actions."""
|
||||
|
||||
api_key: str
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or os.getenv("ZAPIER_API_KEY")
|
||||
if not self.api_key:
|
||||
@@ -78,7 +77,7 @@ class ZapierActionsAdapter:
|
||||
|
||||
return response.json()
|
||||
|
||||
def tools(self) -> list[ZapierActionTool]:
|
||||
def tools(self) -> list[BaseTool]:
|
||||
"""Convert Zapier actions to BaseTool instances."""
|
||||
actions_response = self.get_zapier_actions()
|
||||
tools = []
|
||||
@@ -92,12 +91,12 @@ class ZapierActionsAdapter:
|
||||
)
|
||||
|
||||
params = action.get("params", {})
|
||||
args_fields = {
|
||||
"instructions": (
|
||||
str,
|
||||
Field(description="Instructions for how to execute this action"),
|
||||
)
|
||||
}
|
||||
args_fields = {}
|
||||
|
||||
args_fields["instructions"] = (
|
||||
str,
|
||||
Field(description="Instructions for how to execute this action"),
|
||||
)
|
||||
|
||||
for param_name, param_info in params.items():
|
||||
field_type = (
|
||||
@@ -113,7 +112,7 @@ class ZapierActionsAdapter:
|
||||
Field(description=field_description),
|
||||
)
|
||||
|
||||
args_schema = create_model(f"{tool_name.title()}Schema", **args_fields) # type: ignore[call-overload]
|
||||
args_schema = create_model(f"{tool_name.title()}Schema", **args_fields)
|
||||
|
||||
tool = ZapierActionTool(
|
||||
name=tool_name,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from crewai_tools.aws.bedrock import (
|
||||
from .bedrock import (
|
||||
BedrockInvokeAgentTool,
|
||||
BedrockKBRetrieverTool,
|
||||
create_browser_toolkit,
|
||||
create_code_interpreter_toolkit,
|
||||
)
|
||||
from crewai_tools.aws.s3 import S3ReaderTool, S3WriterTool
|
||||
from .s3 import S3ReaderTool, S3WriterTool
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool
|
||||
from crewai_tools.aws.bedrock.browser import create_browser_toolkit
|
||||
from crewai_tools.aws.bedrock.code_interpreter import create_code_interpreter_toolkit
|
||||
from crewai_tools.aws.bedrock.knowledge_base.retriever_tool import (
|
||||
BedrockKBRetrieverTool,
|
||||
)
|
||||
from .agents.invoke_agent_tool import BedrockInvokeAgentTool
|
||||
from .browser import create_browser_toolkit
|
||||
from .code_interpreter import create_code_interpreter_toolkit
|
||||
from .knowledge_base.retriever_tool import BedrockKBRetrieverTool
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool
|
||||
from .invoke_agent_tool import BedrockInvokeAgentTool
|
||||
|
||||
|
||||
__all__ = ["BedrockInvokeAgentTool"]
|
||||
|
||||
@@ -7,10 +7,7 @@ from crewai.tools import BaseTool
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.aws.bedrock.exceptions import (
|
||||
BedrockAgentError,
|
||||
BedrockValidationError,
|
||||
)
|
||||
from ..exceptions import BedrockAgentError, BedrockValidationError
|
||||
|
||||
|
||||
# Load environment variables from .env file
|
||||
@@ -27,9 +24,9 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
name: str = "Bedrock Agent Invoke Tool"
|
||||
description: str = "An agent responsible for policy analysis."
|
||||
args_schema: type[BaseModel] = BedrockInvokeAgentToolInput
|
||||
agent_id: str | None = None
|
||||
agent_alias_id: str | None = None
|
||||
session_id: str | None = None
|
||||
agent_id: str = None
|
||||
agent_alias_id: str = None
|
||||
session_id: str = None
|
||||
enable_trace: bool = False
|
||||
end_session: bool = False
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["boto3"])
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from crewai_tools.aws.bedrock.browser.browser_toolkit import (
|
||||
BrowserToolkit,
|
||||
create_browser_toolkit,
|
||||
)
|
||||
from .browser_toolkit import BrowserToolkit, create_browser_toolkit
|
||||
|
||||
|
||||
__all__ = ["BrowserToolkit", "create_browser_toolkit"]
|
||||
|
||||
@@ -9,10 +9,8 @@ from urllib.parse import urlparse
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.aws.bedrock.browser.browser_session_manager import (
|
||||
BrowserSessionManager,
|
||||
)
|
||||
from crewai_tools.aws.bedrock.browser.utils import aget_current_page, get_current_page
|
||||
from .browser_session_manager import BrowserSessionManager
|
||||
from .utils import aget_current_page, get_current_page
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -82,9 +80,9 @@ class CurrentWebPageToolInput(BaseModel):
|
||||
class BrowserBaseTool(BaseTool):
|
||||
"""Base class for browser tools."""
|
||||
|
||||
def __init__(self, session_manager: BrowserSessionManager): # type: ignore[call-arg]
|
||||
def __init__(self, session_manager: BrowserSessionManager):
|
||||
"""Initialize with a session manager."""
|
||||
super().__init__() # type: ignore[call-arg]
|
||||
super().__init__()
|
||||
self._session_manager = session_manager
|
||||
|
||||
if self._is_in_asyncio_loop() and hasattr(self, "_arun"):
|
||||
@@ -93,7 +91,7 @@ class BrowserBaseTool(BaseTool):
|
||||
# Override _run to use _arun when in an asyncio loop
|
||||
def patched_run(*args, **kwargs):
|
||||
try:
|
||||
import nest_asyncio # type: ignore[import-untyped]
|
||||
import nest_asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
nest_asyncio.apply(loop)
|
||||
@@ -103,7 +101,7 @@ class BrowserBaseTool(BaseTool):
|
||||
except Exception as e:
|
||||
return f"Error in patched _run: {e!s}"
|
||||
|
||||
self._run = patched_run # type: ignore[method-assign]
|
||||
self._run = patched_run
|
||||
|
||||
async def get_async_page(self, thread_id: str) -> Any:
|
||||
"""Get or create a page for the specified thread."""
|
||||
@@ -358,7 +356,7 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
for link in soup.find_all("a", href=True):
|
||||
text = link.get_text().strip()
|
||||
href = link["href"]
|
||||
if href.startswith(("http", "https")): # type: ignore[union-attr]
|
||||
if href.startswith(("http", "https")):
|
||||
links.append({"text": text, "url": href})
|
||||
|
||||
if not links:
|
||||
@@ -390,7 +388,7 @@ class ExtractHyperlinksTool(BrowserBaseTool):
|
||||
for link in soup.find_all("a", href=True):
|
||||
text = link.get_text().strip()
|
||||
href = link["href"]
|
||||
if href.startswith(("http", "https")): # type: ignore[union-attr]
|
||||
if href.startswith(("http", "https")):
|
||||
links.append({"text": text, "url": href})
|
||||
|
||||
if not links:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from crewai_tools.aws.bedrock.code_interpreter.code_interpreter_toolkit import (
|
||||
from .code_interpreter_toolkit import (
|
||||
CodeInterpreterToolkit,
|
||||
create_code_interpreter_toolkit,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from crewai_tools.aws.bedrock.knowledge_base.retriever_tool import (
|
||||
BedrockKBRetrieverTool,
|
||||
)
|
||||
from .retriever_tool import BedrockKBRetrieverTool
|
||||
|
||||
|
||||
__all__ = ["BedrockKBRetrieverTool"]
|
||||
|
||||
@@ -6,10 +6,7 @@ from crewai.tools import BaseTool
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.aws.bedrock.exceptions import (
|
||||
BedrockKnowledgeBaseError,
|
||||
BedrockValidationError,
|
||||
)
|
||||
from ..exceptions import BedrockKnowledgeBaseError, BedrockValidationError
|
||||
|
||||
|
||||
# Load environment variables from .env file
|
||||
@@ -30,7 +27,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
"Retrieves information from an Amazon Bedrock Knowledge Base given a query"
|
||||
)
|
||||
args_schema: type[BaseModel] = BedrockKBRetrieverToolInput
|
||||
knowledge_base_id: str = None # type: ignore[assignment]
|
||||
knowledge_base_id: str = None
|
||||
number_of_results: int | None = 5
|
||||
retrieval_configuration: dict[str, Any] | None = None
|
||||
guardrail_configuration: dict[str, Any] | None = None
|
||||
@@ -58,7 +55,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Get knowledge_base_id from environment variable if not provided
|
||||
self.knowledge_base_id = knowledge_base_id or os.getenv("BEDROCK_KB_ID") # type: ignore[assignment]
|
||||
self.knowledge_base_id = knowledge_base_id or os.getenv("BEDROCK_KB_ID")
|
||||
self.number_of_results = number_of_results
|
||||
self.guardrail_configuration = guardrail_configuration
|
||||
self.next_token = next_token
|
||||
@@ -242,7 +239,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
if results:
|
||||
response_object["results"] = results
|
||||
else:
|
||||
response_object["message"] = "No results found for the given query." # type: ignore[assignment]
|
||||
response_object["message"] = "No results found for the given query."
|
||||
|
||||
if "nextToken" in response:
|
||||
response_object["nextToken"] = response["nextToken"]
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from crewai_tools.aws.s3.reader_tool import S3ReaderTool as S3ReaderTool
|
||||
from crewai_tools.aws.s3.writer_tool import S3WriterTool as S3WriterTool
|
||||
from .reader_tool import S3ReaderTool as S3ReaderTool
|
||||
from .writer_tool import S3WriterTool as S3WriterTool
|
||||
|
||||
@@ -17,15 +17,14 @@ class LoaderResult(BaseModel):
|
||||
|
||||
|
||||
class BaseLoader(ABC):
|
||||
def __init__(self, config: dict[str, Any] | None = None) -> None:
|
||||
def __init__(self, config: dict[str, Any] | None = None):
|
||||
self.config = config or {}
|
||||
|
||||
@abstractmethod
|
||||
def load(self, content: SourceContent, **kwargs) -> LoaderResult: ...
|
||||
|
||||
@staticmethod
|
||||
def generate_doc_id(
|
||||
source_ref: str | None = None, content: str | None = None
|
||||
self, source_ref: str | None = None, content: str | None = None
|
||||
) -> str:
|
||||
"""Generate a unique document id based on the source reference and content.
|
||||
If the source reference is not provided, the content is used as the source reference.
|
||||
|
||||
@@ -10,7 +10,7 @@ class RecursiveCharacterTextSplitter:
|
||||
chunk_overlap: int = 200,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
) -> None:
|
||||
):
|
||||
"""Initialize the RecursiveCharacterTextSplitter.
|
||||
|
||||
Args:
|
||||
@@ -36,14 +36,6 @@ class RecursiveCharacterTextSplitter:
|
||||
]
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
"""Split the input text into chunks.
|
||||
|
||||
Args:
|
||||
text: The text to split.
|
||||
|
||||
Returns:
|
||||
A list of text chunks.
|
||||
"""
|
||||
return self._split_text(text, self._separators)
|
||||
|
||||
def _split_text(self, text: str, separators: list[str]) -> list[str]:
|
||||
@@ -107,8 +99,8 @@ class RecursiveCharacterTextSplitter:
|
||||
|
||||
def _merge_splits(self, splits: list[str], separator: str) -> list[str]:
|
||||
"""Merge splits into chunks with proper overlap."""
|
||||
docs: list[str] = []
|
||||
current_doc: list[str] = []
|
||||
docs = []
|
||||
current_doc = []
|
||||
total = 0
|
||||
|
||||
for split in splits:
|
||||
@@ -160,7 +152,7 @@ class BaseChunker:
|
||||
chunk_overlap: int = 200,
|
||||
separators: list[str] | None = None,
|
||||
keep_separator: bool = True,
|
||||
) -> None:
|
||||
):
|
||||
"""Initialize the Chunker.
|
||||
|
||||
Args:
|
||||
@@ -177,14 +169,6 @@ class BaseChunker:
|
||||
)
|
||||
|
||||
def chunk(self, text: str) -> list[str]:
|
||||
"""Chunk the input text into smaller pieces.
|
||||
|
||||
Args:
|
||||
text: The text to chunk.
|
||||
|
||||
Returns:
|
||||
A list of text chunks.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
|
||||
|
||||
@@ -218,9 +218,8 @@ class RAG(Adapter):
|
||||
logger.error(f"Failed to get collection info: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
def _get_data_type(
|
||||
content: SourceContent, data_type: str | DataType | None = None
|
||||
self, content: SourceContent, data_type: str | DataType | None = None
|
||||
) -> DataType:
|
||||
try:
|
||||
if isinstance(data_type, str):
|
||||
|
||||
@@ -116,7 +116,7 @@ class DataTypes:
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
url = urlparse(content)
|
||||
is_url = bool(url.scheme and url.netloc) or url.scheme == "file"
|
||||
is_url = (url.scheme and url.netloc) or url.scheme == "file"
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@ Enhanced embedding service that leverages CrewAI's existing embedding providers.
|
||||
This replaces the litellm-based EmbeddingService with a more flexible architecture.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
@@ -84,8 +82,7 @@ class EmbeddingService:
|
||||
self._embedding_function = None
|
||||
self._initialize_embedding_function()
|
||||
|
||||
@staticmethod
|
||||
def _get_default_api_key(provider: str) -> str | None:
|
||||
def _get_default_api_key(self, provider: str) -> str | None:
|
||||
"""Get default API key from environment variables."""
|
||||
env_key_map = {
|
||||
"azure": "AZURE_OPENAI_API_KEY",
|
||||
@@ -386,14 +383,14 @@ class EmbeddingService:
|
||||
model: str = "text-embedding-3-small",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create an OpenAI embedding service."""
|
||||
return cls(provider="openai", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_voyage_service(
|
||||
cls, model: str = "voyage-2", api_key: str | None = None, **kwargs: Any
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Voyage AI embedding service."""
|
||||
return cls(provider="voyageai", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@@ -403,7 +400,7 @@ class EmbeddingService:
|
||||
model: str = "embed-english-v3.0",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Cohere embedding service."""
|
||||
return cls(provider="cohere", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@@ -413,7 +410,7 @@ class EmbeddingService:
|
||||
model: str = "models/embedding-001",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Google Gemini embedding service."""
|
||||
return cls(
|
||||
provider="google-generativeai", model=model, api_key=api_key, **kwargs
|
||||
@@ -425,7 +422,7 @@ class EmbeddingService:
|
||||
model: str = "text-embedding-ada-002",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Azure OpenAI embedding service."""
|
||||
return cls(provider="azure", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@@ -435,7 +432,7 @@ class EmbeddingService:
|
||||
model: str = "amazon.titan-embed-text-v1",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Amazon Bedrock embedding service."""
|
||||
return cls(provider="amazon-bedrock", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@@ -445,7 +442,7 @@ class EmbeddingService:
|
||||
model: str = "sentence-transformers/all-MiniLM-L6-v2",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Hugging Face embedding service."""
|
||||
return cls(provider="huggingface", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@@ -454,7 +451,7 @@ class EmbeddingService:
|
||||
cls,
|
||||
model: str = "all-MiniLM-L6-v2",
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Sentence Transformers embedding service (local)."""
|
||||
return cls(provider="sentence-transformer", model=model, **kwargs)
|
||||
|
||||
@@ -463,7 +460,7 @@ class EmbeddingService:
|
||||
cls,
|
||||
model: str = "nomic-embed-text",
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Ollama embedding service (local)."""
|
||||
return cls(provider="ollama", model=model, **kwargs)
|
||||
|
||||
@@ -473,7 +470,7 @@ class EmbeddingService:
|
||||
model: str = "jina-embeddings-v2-base-en",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Jina AI embedding service."""
|
||||
return cls(provider="jina", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@@ -482,7 +479,7 @@ class EmbeddingService:
|
||||
cls,
|
||||
model: str = "hkunlp/instructor-large",
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Instructor embedding service."""
|
||||
return cls(provider="instructor", model=model, **kwargs)
|
||||
|
||||
@@ -492,7 +489,7 @@ class EmbeddingService:
|
||||
model: str = "ibm/slate-125m-english-rtrvr",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Watson X embedding service."""
|
||||
return cls(provider="watsonx", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@@ -501,7 +498,7 @@ class EmbeddingService:
|
||||
cls,
|
||||
embedding_callable: Any,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingService:
|
||||
) -> "EmbeddingService":
|
||||
"""Create a custom embedding service with your own embedding function."""
|
||||
return cls(
|
||||
provider="custom",
|
||||
|
||||
@@ -2,30 +2,41 @@ import csv
|
||||
from io import StringIO
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.loaders.utils import load_from_url
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class CSVLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
source_ref = source_content.source_ref
|
||||
|
||||
content_str = source_content.source
|
||||
if source_content.is_url():
|
||||
content_str = load_from_url(
|
||||
content_str,
|
||||
kwargs,
|
||||
accept_header="text/csv, application/csv, text/plain",
|
||||
loader_name="CSVLoader",
|
||||
)
|
||||
content_str = self._load_from_url(content_str, kwargs)
|
||||
elif source_content.path_exists():
|
||||
content_str = self._load_from_file(content_str)
|
||||
|
||||
return self._parse_csv(content_str, source_ref)
|
||||
|
||||
@staticmethod
|
||||
def _load_from_file(path: str) -> str:
|
||||
with open(path, encoding="utf-8") as file:
|
||||
def _load_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": "text/csv, application/csv, text/plain",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools CSVLoader)",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching CSV from URL {url}: {e!s}") from e
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
def _parse_csv(self, content: str, source_ref: str) -> LoaderResult:
|
||||
|
||||
@@ -6,11 +6,11 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class DirectoryLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and process all files from a directory recursively.
|
||||
|
||||
Args:
|
||||
source_content: Directory path or URL to a directory listing
|
||||
source: Directory path or URL to a directory listing
|
||||
**kwargs: Additional options:
|
||||
- recursive: bool (default True) - Whether to search recursively
|
||||
- include_extensions: list - Only include files with these extensions
|
||||
@@ -33,16 +33,16 @@ class DirectoryLoader(BaseLoader):
|
||||
return self._process_directory(source_ref, kwargs)
|
||||
|
||||
def _process_directory(self, dir_path: str, kwargs: dict) -> LoaderResult:
|
||||
recursive: bool = kwargs.get("recursive", True)
|
||||
include_extensions: list[str] | None = kwargs.get("include_extensions", None)
|
||||
exclude_extensions: list[str] | None = kwargs.get("exclude_extensions", None)
|
||||
max_files: int | None = kwargs.get("max_files", None)
|
||||
recursive = kwargs.get("recursive", True)
|
||||
include_extensions = kwargs.get("include_extensions", None)
|
||||
exclude_extensions = kwargs.get("exclude_extensions", None)
|
||||
max_files = kwargs.get("max_files", None)
|
||||
|
||||
files = self._find_files(
|
||||
dir_path, recursive, include_extensions, exclude_extensions
|
||||
)
|
||||
|
||||
if max_files is not None and len(files) > max_files:
|
||||
if max_files and len(files) > max_files:
|
||||
files = files[:max_files]
|
||||
|
||||
all_contents = []
|
||||
@@ -115,8 +115,8 @@ class DirectoryLoader(BaseLoader):
|
||||
|
||||
return sorted(files)
|
||||
|
||||
@staticmethod
|
||||
def _should_include_file(
|
||||
self,
|
||||
filename: str,
|
||||
include_ext: list[str] | None = None,
|
||||
exclude_ext: list[str] | None = None,
|
||||
@@ -141,8 +141,7 @@ class DirectoryLoader(BaseLoader):
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _process_single_file(file_path: str) -> LoaderResult:
|
||||
def _process_single_file(self, file_path: str) -> LoaderResult:
|
||||
from crewai_tools.rag.data_types import DataTypes
|
||||
|
||||
data_type = DataTypes.from_content(Path(file_path))
|
||||
|
||||
@@ -12,7 +12,7 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
class DocsSiteLoader(BaseLoader):
|
||||
"""Loader for documentation websites."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a documentation site.
|
||||
|
||||
Args:
|
||||
@@ -40,6 +40,7 @@ class DocsSiteLoader(BaseLoader):
|
||||
title = soup.find("title")
|
||||
title_text = title.get_text(strip=True) if title else "Documentation"
|
||||
|
||||
main_content = None
|
||||
for selector in [
|
||||
"main",
|
||||
"article",
|
||||
@@ -81,10 +82,8 @@ class DocsSiteLoader(BaseLoader):
|
||||
if nav:
|
||||
links = nav.find_all("a", href=True)
|
||||
for link in links[:20]:
|
||||
href = link.get("href", "")
|
||||
if isinstance(href, str) and not href.startswith(
|
||||
("http://", "https://", "mailto:", "#")
|
||||
):
|
||||
href = link["href"]
|
||||
if not href.startswith(("http://", "https://", "mailto:", "#")):
|
||||
full_url = urljoin(docs_url, href)
|
||||
nav_links.append(f"- {link.get_text(strip=True)}: {full_url}")
|
||||
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class DOCXLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
try:
|
||||
from docx import Document as DocxDocument
|
||||
except ImportError as e:
|
||||
@@ -32,8 +29,9 @@ class DOCXLoader(BaseLoader):
|
||||
f"Source must be a valid file path or URL, got: {source_content.source}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _download_from_url(url: str, kwargs: dict) -> str:
|
||||
def _download_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
@@ -51,13 +49,13 @@ class DOCXLoader(BaseLoader):
|
||||
temp_file.write(response.content)
|
||||
return temp_file.name
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching content from URL {url}: {e!s}") from e
|
||||
raise ValueError(f"Error fetching DOCX from URL {url}: {e!s}") from e
|
||||
|
||||
def _load_from_file(
|
||||
self,
|
||||
file_path: str,
|
||||
source_ref: str,
|
||||
DocxDocument: Any, # noqa: N803
|
||||
DocxDocument, # noqa: N803
|
||||
) -> LoaderResult:
|
||||
try:
|
||||
doc = DocxDocument(file_path)
|
||||
|
||||
@@ -9,7 +9,7 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
class GithubLoader(BaseLoader):
|
||||
"""Loader for GitHub repository content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a GitHub repository.
|
||||
|
||||
Args:
|
||||
@@ -54,7 +54,9 @@ class GithubLoader(BaseLoader):
|
||||
try:
|
||||
readme = repo.get_readme()
|
||||
all_content.append("README:")
|
||||
all_content.append(readme.decoded_content.decode(errors="ignore"))
|
||||
all_content.append(
|
||||
readme.decoded_content.decode("utf-8", errors="ignore")
|
||||
)
|
||||
all_content.append("")
|
||||
except GithubException:
|
||||
pass
|
||||
|
||||
@@ -1,30 +1,52 @@
|
||||
import json
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.loaders.utils import load_from_url
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class JSONLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
source_ref = source_content.source_ref
|
||||
content = source_content.source
|
||||
|
||||
if source_content.is_url():
|
||||
content = load_from_url(
|
||||
source_ref,
|
||||
kwargs,
|
||||
accept_header="application/json",
|
||||
loader_name="JSONLoader",
|
||||
)
|
||||
content = self._load_from_url(source_ref, kwargs)
|
||||
elif source_content.path_exists():
|
||||
content = self._load_from_file(source_ref)
|
||||
|
||||
return self._parse_json(content, source_ref)
|
||||
|
||||
@staticmethod
|
||||
def _load_from_file(path: str) -> str:
|
||||
with open(path, encoding="utf-8") as file:
|
||||
def _load_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools JSONLoader)",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
return (
|
||||
response.text
|
||||
if not self._is_json_response(response)
|
||||
else json.dumps(response.json(), indent=2)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching JSON from URL {url}: {e!s}") from e
|
||||
|
||||
def _is_json_response(self, response) -> bool:
|
||||
try:
|
||||
response.json()
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
def _parse_json(self, content: str, source_ref: str) -> LoaderResult:
|
||||
|
||||
@@ -1,55 +1,61 @@
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.loaders.utils import load_from_url
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
_IMPORT_PATTERN: Final[re.Pattern[str]] = re.compile(r"^import\s+.*?\n", re.MULTILINE)
|
||||
_EXPORT_PATTERN: Final[re.Pattern[str]] = re.compile(
|
||||
r"^export\s+.*?(?:\n|$)", re.MULTILINE
|
||||
)
|
||||
_JSX_TAG_PATTERN: Final[re.Pattern[str]] = re.compile(r"<[^>]+>")
|
||||
_EXTRA_NEWLINES_PATTERN: Final[re.Pattern[str]] = re.compile(r"\n\s*\n\s*\n")
|
||||
|
||||
|
||||
class MDXLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
source_ref = source_content.source_ref
|
||||
content = source_content.source
|
||||
|
||||
if source_content.is_url():
|
||||
content = load_from_url(
|
||||
source_ref,
|
||||
kwargs,
|
||||
accept_header="text/markdown, text/x-markdown, text/plain",
|
||||
loader_name="MDXLoader",
|
||||
)
|
||||
content = self._load_from_url(source_ref, kwargs)
|
||||
elif source_content.path_exists():
|
||||
content = self._load_from_file(source_ref)
|
||||
|
||||
return self._parse_mdx(content, source_ref)
|
||||
|
||||
@staticmethod
|
||||
def _load_from_file(path: str) -> str:
|
||||
with open(path, encoding="utf-8") as file:
|
||||
def _load_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": "text/markdown, text/x-markdown, text/plain",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools MDXLoader)",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching MDX from URL {url}: {e!s}") from e
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
def _parse_mdx(self, content: str, source_ref: str) -> LoaderResult:
|
||||
cleaned_content = content
|
||||
|
||||
# Remove import statements
|
||||
cleaned_content = _IMPORT_PATTERN.sub("", cleaned_content)
|
||||
cleaned_content = re.sub(
|
||||
r"^import\s+.*?\n", "", cleaned_content, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
# Remove export statements
|
||||
cleaned_content = _EXPORT_PATTERN.sub("", cleaned_content)
|
||||
cleaned_content = re.sub(
|
||||
r"^export\s+.*?(?:\n|$)", "", cleaned_content, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
# Remove JSX tags (simple approach)
|
||||
cleaned_content = _JSX_TAG_PATTERN.sub("", cleaned_content)
|
||||
cleaned_content = re.sub(r"<[^>]+>", "", cleaned_content)
|
||||
|
||||
# Clean up extra whitespace
|
||||
cleaned_content = _EXTRA_NEWLINES_PATTERN.sub("\n\n", cleaned_content)
|
||||
cleaned_content = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned_content)
|
||||
cleaned_content = cleaned_content.strip()
|
||||
|
||||
metadata = {"format": "mdx"}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
"""MySQL database loader."""
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pymysql import Error, connect
|
||||
from pymysql.cursors import DictCursor
|
||||
import pymysql
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
@@ -13,7 +11,7 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
class MySQLLoader(BaseLoader):
|
||||
"""Loader for MySQL database content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a MySQL database table.
|
||||
|
||||
Args:
|
||||
@@ -42,14 +40,14 @@ class MySQLLoader(BaseLoader):
|
||||
"password": parsed.password,
|
||||
"database": parsed.path.lstrip("/") if parsed.path else None,
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": DictCursor,
|
||||
"cursorclass": pymysql.cursors.DictCursor,
|
||||
}
|
||||
|
||||
if not connection_params["database"]:
|
||||
raise ValueError("Database name is required in the URI")
|
||||
|
||||
try:
|
||||
connection = connect(**connection_params)
|
||||
connection = pymysql.connect(**connection_params)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
@@ -96,7 +94,7 @@ class MySQLLoader(BaseLoader):
|
||||
)
|
||||
finally:
|
||||
connection.close()
|
||||
except Error as e:
|
||||
except pymysql.Error as e:
|
||||
raise ValueError(f"MySQL database error: {e}") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from MySQL: {e}") from e
|
||||
|
||||
@@ -11,7 +11,7 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
class PDFLoader(BaseLoader):
|
||||
"""Loader for PDF files."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and extract text from a PDF file.
|
||||
|
||||
Args:
|
||||
@@ -28,7 +28,7 @@ class PDFLoader(BaseLoader):
|
||||
import pypdf
|
||||
except ImportError:
|
||||
try:
|
||||
import PyPDF2 as pypdf # type: ignore[import-not-found,no-redef] # noqa: N813
|
||||
import PyPDF2 as pypdf # noqa: N813
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from psycopg2 import Error, connect
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
@@ -12,7 +12,7 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
class PostgresLoader(BaseLoader):
|
||||
"""Loader for PostgreSQL database content."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load content from a PostgreSQL database table.
|
||||
|
||||
Args:
|
||||
@@ -47,7 +47,7 @@ class PostgresLoader(BaseLoader):
|
||||
raise ValueError("Database name is required in the URI")
|
||||
|
||||
try:
|
||||
connection = connect(**connection_params)
|
||||
connection = psycopg2.connect(**connection_params)
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
@@ -94,7 +94,7 @@ class PostgresLoader(BaseLoader):
|
||||
)
|
||||
finally:
|
||||
connection.close()
|
||||
except Error as e:
|
||||
except psycopg2.Error as e:
|
||||
raise ValueError(f"PostgreSQL database error: {e}") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load data from PostgreSQL: {e}") from e
|
||||
|
||||
@@ -3,14 +3,14 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class TextFileLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
source_ref = source_content.source_ref
|
||||
if not source_content.path_exists():
|
||||
raise FileNotFoundError(
|
||||
f"The following file does not exist: {source_content.source}"
|
||||
)
|
||||
|
||||
with open(source_content.source, encoding="utf-8") as file:
|
||||
with open(source_content.source, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
|
||||
return LoaderResult(
|
||||
@@ -21,7 +21,7 @@ class TextFileLoader(BaseLoader):
|
||||
|
||||
|
||||
class TextLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
return LoaderResult(
|
||||
content=source_content.source,
|
||||
source=source_content.source_ref,
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Utility functions for RAG loaders."""
|
||||
|
||||
|
||||
def load_from_url(
|
||||
url: str, kwargs: dict, accept_header: str = "*/*", loader_name: str = "Loader"
|
||||
) -> str:
|
||||
"""Load content from a URL.
|
||||
|
||||
Args:
|
||||
url: The URL to fetch content from
|
||||
kwargs: Additional keyword arguments (can include 'headers' override)
|
||||
accept_header: The Accept header value for the request
|
||||
loader_name: The name of the loader for the User-Agent header
|
||||
|
||||
Returns:
|
||||
The text content from the URL
|
||||
|
||||
Raises:
|
||||
ValueError: If there's an error fetching the URL
|
||||
"""
|
||||
import requests
|
||||
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": accept_header,
|
||||
"User-Agent": f"Mozilla/5.0 (compatible; crewai-tools {loader_name})",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching content from URL {url}: {e!s}") from e
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
import requests
|
||||
@@ -8,12 +7,8 @@ from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
_SPACES_PATTERN: Final[re.Pattern[str]] = re.compile(r"[ \t]+")
|
||||
_NEWLINE_PATTERN: Final[re.Pattern[str]] = re.compile(r"\s+\n\s+")
|
||||
|
||||
|
||||
class WebPageLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
url = source_content.source
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
@@ -34,8 +29,8 @@ class WebPageLoader(BaseLoader):
|
||||
script.decompose()
|
||||
|
||||
text = soup.get_text(" ")
|
||||
text = _SPACES_PATTERN.sub(" ", text)
|
||||
text = _NEWLINE_PATTERN.sub("\n", text)
|
||||
text = re.sub("[ \t]+", " ", text)
|
||||
text = re.sub("\\s+\n\\s+", "\n", text)
|
||||
text = text.strip()
|
||||
|
||||
title = (
|
||||
|
||||
@@ -1,48 +1,49 @@
|
||||
from typing import Any
|
||||
from xml.etree.ElementTree import ParseError, fromstring, parse
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.loaders.utils import load_from_url
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class XMLLoader(BaseLoader):
|
||||
def load(self, source_content: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
|
||||
"""Load and parse XML content from various sources.
|
||||
|
||||
Args:
|
||||
source_content: SourceContent: The source content to load.
|
||||
**kwargs: Additional keyword arguments for loading from URL.
|
||||
|
||||
Returns:
|
||||
LoaderResult: The result of loading and parsing the XML content.
|
||||
"""
|
||||
def load(self, source_content: SourceContent, **kwargs) -> LoaderResult:
|
||||
source_ref = source_content.source_ref
|
||||
content = source_content.source
|
||||
|
||||
if source_content.is_url():
|
||||
content = load_from_url(
|
||||
source_ref,
|
||||
kwargs,
|
||||
accept_header="application/xml, text/xml, text/plain",
|
||||
loader_name="XMLLoader",
|
||||
)
|
||||
content = self._load_from_url(source_ref, kwargs)
|
||||
elif source_content.path_exists():
|
||||
content = self._load_from_file(source_ref)
|
||||
|
||||
return self._parse_xml(content, source_ref)
|
||||
|
||||
@staticmethod
|
||||
def _load_from_file(path: str) -> str:
|
||||
with open(path, encoding="utf-8") as file:
|
||||
def _load_from_url(self, url: str, kwargs: dict) -> str:
|
||||
import requests
|
||||
|
||||
headers = kwargs.get(
|
||||
"headers",
|
||||
{
|
||||
"Accept": "application/xml, text/xml, text/plain",
|
||||
"User-Agent": "Mozilla/5.0 (compatible; crewai-tools XMLLoader)",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching XML from URL {url}: {e!s}") from e
|
||||
|
||||
def _load_from_file(self, path: str) -> str:
|
||||
with open(path, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
def _parse_xml(self, content: str, source_ref: str) -> LoaderResult:
|
||||
try:
|
||||
if content.strip().startswith("<"):
|
||||
root = fromstring(content) # noqa: S314
|
||||
root = ET.fromstring(content) # noqa: S314
|
||||
else:
|
||||
root = parse(source_ref).getroot() # noqa: S314
|
||||
root = ET.parse(source_ref).getroot() # noqa: S314
|
||||
|
||||
text_parts = []
|
||||
for text_content in root.itertext():
|
||||
@@ -51,7 +52,7 @@ class XMLLoader(BaseLoader):
|
||||
|
||||
text = "\n".join(text_parts)
|
||||
metadata = {"format": "xml", "root_tag": root.tag}
|
||||
except ParseError as e:
|
||||
except ET.ParseError as e:
|
||||
text = content
|
||||
metadata = {"format": "xml", "parse_error": str(e)}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
class YoutubeChannelLoader(BaseLoader):
|
||||
"""Loader for YouTube channels."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and extract content from a YouTube channel.
|
||||
|
||||
Args:
|
||||
@@ -24,7 +24,7 @@ class YoutubeChannelLoader(BaseLoader):
|
||||
ValueError: If the URL is not a valid YouTube channel URL
|
||||
"""
|
||||
try:
|
||||
from pytube import Channel # type: ignore[import-untyped]
|
||||
from pytube import Channel
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"YouTube channel support requires pytube. Install with: uv add pytube"
|
||||
@@ -89,6 +89,7 @@ class YoutubeChannelLoader(BaseLoader):
|
||||
try:
|
||||
api = YouTubeTranscriptApi()
|
||||
transcript_list = api.list(video_id)
|
||||
transcript = None
|
||||
|
||||
try:
|
||||
transcript = transcript_list.find_transcript(["en"])
|
||||
@@ -100,7 +101,7 @@ class YoutubeChannelLoader(BaseLoader):
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
transcript = next(iter(transcript_list))
|
||||
transcript = next(iter(transcript_list), None)
|
||||
|
||||
if transcript:
|
||||
transcript_data = transcript.fetch()
|
||||
@@ -147,8 +148,7 @@ class YoutubeChannelLoader(BaseLoader):
|
||||
doc_id=self.generate_doc_id(source_ref=channel_url, content=content),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_video_id(url: str) -> str | None:
|
||||
def _extract_video_id(self, url: str) -> str | None:
|
||||
"""Extract video ID from YouTube URL."""
|
||||
patterns = [
|
||||
r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
|
||||
|
||||
@@ -11,7 +11,7 @@ from crewai_tools.rag.source_content import SourceContent
|
||||
class YoutubeVideoLoader(BaseLoader):
|
||||
"""Loader for YouTube videos."""
|
||||
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult:
|
||||
"""Load and extract transcript from a YouTube video.
|
||||
|
||||
Args:
|
||||
@@ -48,6 +48,7 @@ class YoutubeVideoLoader(BaseLoader):
|
||||
api = YouTubeTranscriptApi()
|
||||
transcript_list = api.list(video_id)
|
||||
|
||||
transcript = None
|
||||
try:
|
||||
transcript = transcript_list.find_transcript(["en"])
|
||||
except Exception:
|
||||
@@ -71,7 +72,7 @@ class YoutubeVideoLoader(BaseLoader):
|
||||
content = " ".join(text_content)
|
||||
|
||||
try:
|
||||
from pytube import YouTube # type: ignore[import-untyped]
|
||||
from pytube import YouTube
|
||||
|
||||
yt = YouTube(video_url)
|
||||
metadata["title"] = yt.title
|
||||
@@ -102,8 +103,7 @@ class YoutubeVideoLoader(BaseLoader):
|
||||
doc_id=self.generate_doc_id(source_ref=video_url, content=content),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_video_id(url: str) -> str | None:
|
||||
def _extract_video_id(self, url: str) -> str | None:
|
||||
"""Extract video ID from various YouTube URL formats."""
|
||||
patterns = [
|
||||
r"(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/|youtube\.com\/v\/)([^&\n?#]+)",
|
||||
|
||||
@@ -3,15 +3,7 @@ from typing import Any
|
||||
|
||||
|
||||
def compute_sha256(content: str) -> str:
|
||||
"""Compute the SHA-256 hash of the given content.
|
||||
|
||||
Args:
|
||||
content: The content to hash.
|
||||
|
||||
Returns:
|
||||
The SHA-256 hash of the content as a hexadecimal string.
|
||||
"""
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
return hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def sanitize_metadata_for_chromadb(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -30,7 +28,7 @@ class SourceContent:
|
||||
return os.path.exists(self.source)
|
||||
|
||||
@cached_property
|
||||
def data_type(self) -> DataType:
|
||||
def data_type(self) -> "DataType":
|
||||
from crewai_tools.rag.data_types import DataTypes
|
||||
|
||||
return DataTypes.from_content(self.source)
|
||||
|
||||
@@ -32,6 +32,9 @@ from crewai_tools.tools.contextualai_rerank_tool.contextual_rerank_tool import (
|
||||
from crewai_tools.tools.couchbase_tool.couchbase_tool import (
|
||||
CouchbaseFTSVectorSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.crewai_enterprise_tools.crewai_enterprise_tools import (
|
||||
CrewaiEnterpriseTools,
|
||||
)
|
||||
from crewai_tools.tools.crewai_platform_tools.crewai_platform_tools import (
|
||||
CrewaiPlatformTools,
|
||||
)
|
||||
@@ -196,6 +199,7 @@ __all__ = [
|
||||
"ContextualAIQueryTool",
|
||||
"ContextualAIRerankTool",
|
||||
"CouchbaseFTSVectorSearchTool",
|
||||
"CrewaiEnterpriseTools",
|
||||
"CrewaiPlatformTools",
|
||||
"DOCXSearchTool",
|
||||
"DallETool",
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from openai import OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -31,7 +30,7 @@ class AIMindTool(BaseTool):
|
||||
)
|
||||
args_schema: type[BaseModel] = AIMindToolInputSchema
|
||||
api_key: str | None = None
|
||||
datasources: list[dict[str, Any]] = Field(default_factory=list)
|
||||
datasources: list[dict[str, Any]] | None = None
|
||||
mind_name: str | None = None
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["minds-sdk"])
|
||||
env_vars: list[EnvVar] = Field(
|
||||
@@ -88,15 +87,10 @@ class AIMindTool(BaseTool):
|
||||
base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key
|
||||
)
|
||||
|
||||
if self.mind_name is None:
|
||||
raise ValueError("Mind name is not set.")
|
||||
|
||||
completion = openai_client.chat.completions.create(
|
||||
model=self.mind_name,
|
||||
messages=[{"role": "user", "content": query}],
|
||||
stream=False,
|
||||
)
|
||||
if not isinstance(completion, ChatCompletion):
|
||||
raise ValueError("Invalid response from AI-Mind")
|
||||
|
||||
return completion.choices[0].message.content
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -51,7 +49,7 @@ class ApifyActorsTool(BaseTool):
|
||||
print(f"URL: {result['metadata']['url']}")
|
||||
print(f"Content: {result.get('markdown', 'N/A')[:100]}...")
|
||||
"""
|
||||
actor_tool: _ApifyActorsTool = Field(description="Apify Actor Tool")
|
||||
actor_tool: "_ApifyActorsTool" = Field(description="Apify Actor Tool")
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["langchain-apify"])
|
||||
|
||||
def __init__(self, actor_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
|
||||
@@ -36,9 +36,14 @@ class ArxivPaperTool(BaseTool):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["pydantic"])
|
||||
env_vars: list[EnvVar] = Field(default_factory=list)
|
||||
download_pdfs: bool = False
|
||||
save_dir: str = "./arxiv_pdfs"
|
||||
use_title_as_filename: bool = False
|
||||
|
||||
def __init__(
|
||||
self, download_pdfs=False, save_dir="./arxiv_pdfs", use_title_as_filename=False
|
||||
):
|
||||
super().__init__()
|
||||
self.download_pdfs = download_pdfs
|
||||
self.save_dir = save_dir
|
||||
self.use_title_as_filename = use_title_as_filename
|
||||
|
||||
def _run(self, search_query: str, max_results: int = 5) -> str:
|
||||
try:
|
||||
@@ -65,7 +70,7 @@ class ArxivPaperTool(BaseTool):
|
||||
filename = f"{filename_base[:500]}.pdf"
|
||||
save_path = Path(save_dir) / filename
|
||||
|
||||
self.download_pdf(paper["pdf_url"], save_path) # type: ignore[arg-type]
|
||||
self.download_pdf(paper["pdf_url"], save_path)
|
||||
time.sleep(self.SLEEP_DURATION)
|
||||
|
||||
results = [self._format_paper_result(p) for p in papers]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
from typing import Any, ClassVar
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from crewai_tools.tools.brightdata_tool.brightdata_dataset import BrightDataDatasetTool
|
||||
from crewai_tools.tools.brightdata_tool.brightdata_serp import BrightDataSearchTool
|
||||
from crewai_tools.tools.brightdata_tool.brightdata_unlocker import (
|
||||
BrightDataWebUnlockerTool,
|
||||
)
|
||||
from .brightdata_dataset import BrightDataDatasetTool
|
||||
from .brightdata_serp import BrightDataSearchTool
|
||||
from .brightdata_unlocker import BrightDataWebUnlockerTool
|
||||
|
||||
|
||||
__all__ = ["BrightDataDatasetTool", "BrightDataSearchTool", "BrightDataWebUnlockerTool"]
|
||||
|
||||
@@ -72,6 +72,6 @@ class BrowserbaseLoadTool(BaseTool):
|
||||
self.proxy = proxy
|
||||
|
||||
def _run(self, url: str):
|
||||
return self.browserbase.load_url( # type: ignore[union-attr]
|
||||
return self.browserbase.load_url(
|
||||
url, self.text_content, self.session_id, self.proxy
|
||||
)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedCodeDocsSearchToolSchema(BaseModel):
|
||||
@@ -37,7 +38,7 @@ class CodeDocsSearchTool(RagTool):
|
||||
def add(self, docs_url: str) -> None:
|
||||
super().add(docs_url, data_type=DataType.DOCS_SITE)
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
docs_url: str | None = None,
|
||||
|
||||
@@ -7,30 +7,18 @@ potentially unsafe operations and importing restricted modules.
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
from types import ModuleType
|
||||
from typing import Any, ClassVar, TypedDict
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from docker import ( # type: ignore[import-untyped]
|
||||
DockerClient,
|
||||
from_env as docker_from_env,
|
||||
)
|
||||
from docker.errors import ImageNotFound, NotFound # type: ignore[import-untyped]
|
||||
from docker.models.containers import Container # type: ignore[import-untyped]
|
||||
from docker import DockerClient, from_env as docker_from_env
|
||||
from docker.errors import ImageNotFound, NotFound
|
||||
from docker.models.containers import Container
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai_tools.printer import Printer
|
||||
|
||||
|
||||
class RunKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for the _run method."""
|
||||
|
||||
code: str
|
||||
libraries_used: list[str]
|
||||
|
||||
|
||||
class CodeInterpreterSchema(BaseModel):
|
||||
"""Schema for defining inputs to the CodeInterpreterTool.
|
||||
|
||||
@@ -127,14 +115,14 @@ class SandboxPython:
|
||||
return safe_builtins
|
||||
|
||||
@staticmethod
|
||||
def exec(code: str, locals_: dict[str, Any]) -> None:
|
||||
def exec(code: str, locals: dict[str, Any]) -> None:
|
||||
"""Executes Python code in a restricted environment.
|
||||
|
||||
Args:
|
||||
code: The Python code to execute as a string.
|
||||
locals_: A dictionary that will be used for local variable storage.
|
||||
locals: A dictionary that will be used for local variable storage.
|
||||
"""
|
||||
exec(code, {"__builtins__": SandboxPython.safe_builtins()}, locals_) # noqa: S102
|
||||
exec(code, {"__builtins__": SandboxPython.safe_builtins()}, locals) # noqa: S102
|
||||
|
||||
|
||||
class CodeInterpreterTool(BaseTool):
|
||||
@@ -160,13 +148,8 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
Returns:
|
||||
The directory path where the package is installed.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the package cannot be found.
|
||||
"""
|
||||
spec = importlib.util.find_spec("crewai_tools")
|
||||
if spec is None or spec.origin is None:
|
||||
raise RuntimeError("Cannot find crewai_tools package installation path")
|
||||
return os.path.dirname(spec.origin)
|
||||
|
||||
def _verify_docker_image(self) -> None:
|
||||
@@ -206,7 +189,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
rm=True,
|
||||
)
|
||||
|
||||
def _run(self, **kwargs: Unpack[RunKwargs]) -> str:
|
||||
def _run(self, **kwargs) -> str:
|
||||
"""Runs the code interpreter tool with the provided arguments.
|
||||
|
||||
Args:
|
||||
@@ -215,18 +198,14 @@ class CodeInterpreterTool(BaseTool):
|
||||
Returns:
|
||||
The output of the executed code as a string.
|
||||
"""
|
||||
code: str | None = kwargs.get("code", self.code)
|
||||
libraries_used: list[str] = kwargs.get("libraries_used", [])
|
||||
|
||||
if not code:
|
||||
return "No code provided to execute."
|
||||
code = kwargs.get("code", self.code)
|
||||
libraries_used = kwargs.get("libraries_used", [])
|
||||
|
||||
if self.unsafe_mode:
|
||||
return self.run_code_unsafe(code, libraries_used)
|
||||
return self.run_code_safety(code, libraries_used)
|
||||
|
||||
@staticmethod
|
||||
def _install_libraries(container: Container, libraries: list[str]) -> None:
|
||||
def _install_libraries(self, container: Container, libraries: list[str]) -> None:
|
||||
"""Installs required Python libraries in the Docker container.
|
||||
|
||||
Args:
|
||||
@@ -266,8 +245,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
volumes={current_path: {"bind": "/workspace", "mode": "rw"}}, # type: ignore
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _check_docker_available() -> bool:
|
||||
def _check_docker_available(self) -> bool:
|
||||
"""Checks if Docker is available and running on the system.
|
||||
|
||||
Attempts to run the 'docker info' command to verify Docker availability.
|
||||
@@ -276,6 +254,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
Returns:
|
||||
True if Docker is available and running, False otherwise.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
@@ -340,8 +319,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
return f"Something went wrong while running the code: \n{exec_result.output.decode('utf-8')}"
|
||||
return exec_result.output.decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def run_code_in_restricted_sandbox(code: str) -> str:
|
||||
def run_code_in_restricted_sandbox(self, code: str) -> str:
|
||||
"""Runs Python code in a restricted sandbox environment.
|
||||
|
||||
Executes the code with restricted access to potentially dangerous modules and
|
||||
@@ -355,15 +333,14 @@ class CodeInterpreterTool(BaseTool):
|
||||
or an error message if execution failed.
|
||||
"""
|
||||
Printer.print("Running code in restricted sandbox", color="yellow")
|
||||
exec_locals: dict[str, Any] = {}
|
||||
exec_locals = {}
|
||||
try:
|
||||
SandboxPython.exec(code=code, locals_=exec_locals)
|
||||
SandboxPython.exec(code=code, locals=exec_locals)
|
||||
return exec_locals.get("result", "No result variable found.")
|
||||
except Exception as e:
|
||||
return f"An error occurred: {e!s}"
|
||||
|
||||
@staticmethod
|
||||
def run_code_unsafe(code: str, libraries_used: list[str]) -> str:
|
||||
def run_code_unsafe(self, code: str, libraries_used: list[str]) -> str:
|
||||
"""Runs code directly on the host machine without any safety restrictions.
|
||||
|
||||
WARNING: This mode is unsafe and should only be used in trusted environments
|
||||
@@ -384,7 +361,7 @@ class CodeInterpreterTool(BaseTool):
|
||||
|
||||
# Execute the code
|
||||
try:
|
||||
exec_locals: dict[str, Any] = {}
|
||||
exec_locals = {}
|
||||
exec(code, {}, exec_locals) # noqa: S102
|
||||
return exec_locals.get("result", "No result variable found.")
|
||||
except Exception as e:
|
||||
|
||||
@@ -124,5 +124,5 @@ class ComposioTool(BaseTool):
|
||||
|
||||
return [
|
||||
cls.from_action(action=action, **kwargs)
|
||||
for action in toolset.find_actions_by_tags(*apps, tags=tags) # type: ignore[arg-type]
|
||||
for action in toolset.find_actions_by_tags(*apps, tags=tags)
|
||||
]
|
||||
|
||||
@@ -82,7 +82,7 @@ class ContextualAIQueryTool(BaseTool):
|
||||
if loop and loop.is_running():
|
||||
# Already inside an event loop
|
||||
try:
|
||||
import nest_asyncio # type: ignore[import-untyped]
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply(loop)
|
||||
loop.run_until_complete(
|
||||
|
||||
@@ -4,13 +4,10 @@ from typing import Any
|
||||
|
||||
|
||||
try:
|
||||
from couchbase.cluster import Cluster # type: ignore[import-untyped]
|
||||
from couchbase.options import SearchOptions # type: ignore[import-untyped]
|
||||
import couchbase.search as search # type: ignore[import-untyped]
|
||||
from couchbase.vector_search import ( # type: ignore[import-untyped]
|
||||
VectorQuery,
|
||||
VectorSearch,
|
||||
)
|
||||
from couchbase.cluster import Cluster
|
||||
from couchbase.options import SearchOptions
|
||||
import couchbase.search as search
|
||||
from couchbase.vector_search import VectorQuery, VectorSearch
|
||||
|
||||
COUCHBASE_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -41,31 +38,24 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
name: str = "CouchbaseFTSVectorSearchTool"
|
||||
description: str = "A tool to search the Couchbase database for relevant information on internal documents."
|
||||
args_schema: type[BaseModel] = CouchbaseToolSchema
|
||||
cluster: SkipValidation[Cluster] = Field(
|
||||
description="An instance of the Couchbase Cluster connected to the desired Couchbase server.",
|
||||
)
|
||||
collection_name: str = Field(
|
||||
description="The name of the Couchbase collection to search",
|
||||
)
|
||||
scope_name: str = Field(
|
||||
description="The name of the Couchbase scope containing the collection to search.",
|
||||
)
|
||||
bucket_name: str = Field(
|
||||
description="The name of the Couchbase bucket to search",
|
||||
)
|
||||
index_name: str = Field(
|
||||
description="The name of the Couchbase index to search",
|
||||
)
|
||||
cluster: SkipValidation[Cluster | None] = None
|
||||
collection_name: str | None = (None,)
|
||||
scope_name: str | None = (None,)
|
||||
bucket_name: str | None = (None,)
|
||||
index_name: str | None = (None,)
|
||||
embedding_key: str | None = Field(
|
||||
default="embedding",
|
||||
description="Name of the field in the search index that stores the vector",
|
||||
)
|
||||
scoped_index: bool = Field(
|
||||
default=True,
|
||||
description="Specify whether the index is scoped. Is True by default.",
|
||||
scoped_index: bool | None = (
|
||||
Field(
|
||||
default=True,
|
||||
description="Specify whether the index is scoped. Is True by default.",
|
||||
),
|
||||
)
|
||||
limit: int | None = Field(default=3)
|
||||
embedding_function: SkipValidation[Callable[[str], list[float]]] = Field(
|
||||
default=None,
|
||||
description="A function that takes a string and returns a list of floats. This is used to embed the query before searching the database.",
|
||||
)
|
||||
|
||||
@@ -122,9 +112,6 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
" Please create the index before searching."
|
||||
)
|
||||
else:
|
||||
if not self.cluster:
|
||||
raise ValueError("Cluster instance must be provided")
|
||||
|
||||
all_indexes = [
|
||||
index.name for index in self.cluster.search_indexes().get_all_indexes()
|
||||
]
|
||||
@@ -153,6 +140,24 @@ class CouchbaseFTSVectorSearchTool(BaseTool):
|
||||
super().__init__(**kwargs)
|
||||
if COUCHBASE_AVAILABLE:
|
||||
try:
|
||||
if not self.cluster:
|
||||
raise ValueError("Cluster instance must be provided")
|
||||
|
||||
if not self.bucket_name:
|
||||
raise ValueError("Bucket name must be provided")
|
||||
|
||||
if not self.scope_name:
|
||||
raise ValueError("Scope name must be provided")
|
||||
|
||||
if not self.collection_name:
|
||||
raise ValueError("Collection name must be provided")
|
||||
|
||||
if not self.index_name:
|
||||
raise ValueError("Index name must be provided")
|
||||
|
||||
if not self.embedding_function:
|
||||
raise ValueError("Embedding function must be provided")
|
||||
|
||||
self._bucket = self.cluster.bucket(self.bucket_name)
|
||||
self._scope = self._bucket.scope(self.scope_name)
|
||||
self._collection = self._scope.collection(self.collection_name)
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Crewai Enterprise Tools."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
from crewai_tools.adapters.enterprise_adapter import EnterpriseActionKitToolAdapter
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def CrewaiEnterpriseTools( # noqa: N802
|
||||
enterprise_token: str | None = None,
|
||||
actions_list: list[str] | None = None,
|
||||
enterprise_action_kit_project_id: str | None = None,
|
||||
enterprise_action_kit_project_url: str | None = None,
|
||||
) -> ToolCollection[BaseTool]:
|
||||
"""Factory function that returns crewai enterprise tools.
|
||||
|
||||
Args:
|
||||
enterprise_token: The token for accessing enterprise actions.
|
||||
If not provided, will try to use CREWAI_ENTERPRISE_TOOLS_TOKEN env var.
|
||||
actions_list: Optional list of specific tool names to include.
|
||||
If provided, only tools with these names will be returned.
|
||||
enterprise_action_kit_project_id: Optional ID of the Enterprise Action Kit project.
|
||||
enterprise_action_kit_project_url: Optional URL of the Enterprise Action Kit project.
|
||||
|
||||
Returns:
|
||||
A ToolCollection of BaseTool instances for enterprise actions
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"CrewaiEnterpriseTools will be removed in v1.0.0. Considering use `Agent(apps=[...])` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if enterprise_token is None or enterprise_token == "":
|
||||
enterprise_token = os.environ.get("CREWAI_ENTERPRISE_TOOLS_TOKEN")
|
||||
if not enterprise_token:
|
||||
logger.warning("No enterprise token provided")
|
||||
|
||||
adapter_kwargs = {"enterprise_action_token": enterprise_token}
|
||||
|
||||
if enterprise_action_kit_project_id is not None:
|
||||
adapter_kwargs["enterprise_action_kit_project_id"] = (
|
||||
enterprise_action_kit_project_id
|
||||
)
|
||||
if enterprise_action_kit_project_url is not None:
|
||||
adapter_kwargs["enterprise_action_kit_project_url"] = (
|
||||
enterprise_action_kit_project_url
|
||||
)
|
||||
|
||||
adapter = EnterpriseActionKitToolAdapter(**adapter_kwargs)
|
||||
all_tools = adapter.tools()
|
||||
parsed_actions_list = _parse_actions_list(actions_list)
|
||||
|
||||
# Filter tools based on the provided list
|
||||
return ToolCollection(all_tools).filter_by_names(parsed_actions_list)
|
||||
|
||||
|
||||
# ENTERPRISE INJECTION ONLY
|
||||
def _parse_actions_list(actions_list: list[str] | None) -> list[str] | None:
|
||||
"""Parse a string representation of a list of tool names to a list of tool names.
|
||||
|
||||
Args:
|
||||
actions_list: A string representation of a list of tool names.
|
||||
|
||||
Returns:
|
||||
A list of tool names.
|
||||
"""
|
||||
if actions_list is not None:
|
||||
return actions_list
|
||||
|
||||
actions_list_from_env = os.environ.get("CREWAI_ENTERPRISE_TOOLS_ACTIONS_LIST")
|
||||
if actions_list_from_env is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(actions_list_from_env)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse actions_list as JSON: {actions_list_from_env}")
|
||||
return None
|
||||
@@ -18,7 +18,7 @@ class CrewaiPlatformToolBuilder:
|
||||
apps: list[str],
|
||||
):
|
||||
self._apps = apps
|
||||
self._actions_schema = {} # type: ignore[var-annotated]
|
||||
self._actions_schema = {}
|
||||
self._tools = None
|
||||
|
||||
def tools(self) -> list[BaseTool]:
|
||||
|
||||
@@ -24,4 +24,4 @@ def CrewaiPlatformTools( # noqa: N802
|
||||
"""
|
||||
builder = CrewaiPlatformToolBuilder(apps=apps)
|
||||
|
||||
return builder.tools() # type: ignore
|
||||
return builder.tools()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedCSVSearchToolSchema(BaseModel):
|
||||
@@ -37,7 +38,7 @@ class CSVSearchTool(RagTool):
|
||||
def add(self, csv: str) -> None:
|
||||
super().add(csv, data_type=DataType.CSV)
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
csv: str | None = None,
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
from typing import Literal
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from openai import Omit, OpenAI
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -20,22 +19,8 @@ class DallETool(BaseTool):
|
||||
args_schema: type[BaseModel] = ImagePromptSchema
|
||||
|
||||
model: str = "dall-e-3"
|
||||
size: (
|
||||
Literal[
|
||||
"auto",
|
||||
"1024x1024",
|
||||
"1536x1024",
|
||||
"1024x1536",
|
||||
"256x256",
|
||||
"512x512",
|
||||
"1792x1024",
|
||||
"1024x1792",
|
||||
]
|
||||
| None
|
||||
) = "1024x1024"
|
||||
quality: (
|
||||
Literal["standard", "hd", "low", "medium", "high", "auto"] | None | Omit
|
||||
) = "standard"
|
||||
size: str = "1024x1024"
|
||||
quality: str = "standard"
|
||||
n: int = 1
|
||||
|
||||
env_vars: list[EnvVar] = Field(
|
||||
@@ -64,9 +49,6 @@ class DallETool(BaseTool):
|
||||
n=self.n,
|
||||
)
|
||||
|
||||
if not response or not response.data:
|
||||
return "Failed to generate image."
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"image_url": response.data[0].url,
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, TypeGuard, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
@@ -12,28 +9,6 @@ if TYPE_CHECKING:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
|
||||
class ExecutionContext(TypedDict, total=False):
|
||||
catalog: str
|
||||
schema: str
|
||||
|
||||
|
||||
def _has_data_array(result: Any) -> TypeGuard[Any]:
|
||||
"""Type guard to check if result has data_array attribute.
|
||||
|
||||
Args:
|
||||
result: The result object to check.
|
||||
|
||||
Returns:
|
||||
True if result.result.data_array exists and is not None.
|
||||
"""
|
||||
return (
|
||||
hasattr(result, "result")
|
||||
and result.result is not None
|
||||
and hasattr(result.result, "data_array")
|
||||
and result.result.data_array is not None
|
||||
)
|
||||
|
||||
|
||||
class DatabricksQueryToolSchema(BaseModel):
|
||||
"""Input schema for DatabricksQueryTool."""
|
||||
|
||||
@@ -57,7 +32,7 @@ class DatabricksQueryToolSchema(BaseModel):
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_input(self) -> DatabricksQueryToolSchema:
|
||||
def validate_input(self) -> "DatabricksQueryToolSchema":
|
||||
"""Validate the input parameters."""
|
||||
# Ensure the query is not empty
|
||||
if not self.query or not self.query.strip():
|
||||
@@ -97,7 +72,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
default_schema: str | None = None
|
||||
default_warehouse_id: str | None = None
|
||||
|
||||
_workspace_client: WorkspaceClient | None = None
|
||||
_workspace_client: Optional["WorkspaceClient"] = None
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["databricks-sdk"])
|
||||
|
||||
def __init__(
|
||||
@@ -135,7 +110,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
)
|
||||
|
||||
@property
|
||||
def workspace_client(self) -> WorkspaceClient:
|
||||
def workspace_client(self) -> "WorkspaceClient":
|
||||
"""Get or create a Databricks WorkspaceClient instance."""
|
||||
if self._workspace_client is None:
|
||||
try:
|
||||
@@ -234,12 +209,8 @@ class DatabricksQueryTool(BaseTool):
|
||||
db_schema = validated_input.db_schema
|
||||
warehouse_id = validated_input.warehouse_id
|
||||
|
||||
if warehouse_id is None:
|
||||
return "SQL warehouse ID must be provided either as a parameter or as a default."
|
||||
|
||||
# Setup SQL context with catalog/schema if provided
|
||||
|
||||
context: ExecutionContext = {}
|
||||
context = {}
|
||||
if catalog:
|
||||
context["catalog"] = catalog
|
||||
if db_schema:
|
||||
@@ -260,6 +231,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
return f"Error starting query execution: {execute_error!s}"
|
||||
|
||||
# Poll for results with better error handling
|
||||
import time
|
||||
|
||||
result = None
|
||||
timeout = 300 # 5 minutes timeout
|
||||
@@ -267,9 +239,6 @@ class DatabricksQueryTool(BaseTool):
|
||||
poll_count = 0
|
||||
previous_state = None # Track previous state to detect changes
|
||||
|
||||
if statement_id is None:
|
||||
return "Failed to retrieve statement ID after execution."
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
poll_count += 1
|
||||
try:
|
||||
@@ -279,7 +248,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
# Check if finished - be very explicit about state checking
|
||||
if hasattr(result, "status") and hasattr(result.status, "state"):
|
||||
state_value = str(
|
||||
result.status.state # type: ignore[union-attr]
|
||||
result.status.state
|
||||
) # Convert to string to handle both string and enum
|
||||
|
||||
# Track state changes for debugging
|
||||
@@ -296,16 +265,16 @@ class DatabricksQueryTool(BaseTool):
|
||||
# First try direct access to error.message
|
||||
if (
|
||||
hasattr(result.status, "error")
|
||||
and result.status.error # type: ignore[union-attr]
|
||||
and result.status.error
|
||||
):
|
||||
if hasattr(result.status.error, "message"): # type: ignore[union-attr]
|
||||
error_info = result.status.error.message # type: ignore[union-attr,assignment]
|
||||
if hasattr(result.status.error, "message"):
|
||||
error_info = result.status.error.message
|
||||
# Some APIs may have a different structure
|
||||
elif hasattr(result.status.error, "error_message"): # type: ignore[union-attr]
|
||||
error_info = result.status.error.error_message # type: ignore[union-attr]
|
||||
elif hasattr(result.status.error, "error_message"):
|
||||
error_info = result.status.error.error_message
|
||||
# Last resort, try to convert the whole error object to string
|
||||
else:
|
||||
error_info = str(result.status.error) # type: ignore[union-attr]
|
||||
error_info = str(result.status.error)
|
||||
except Exception as err_extract_error:
|
||||
# If all else fails, try to get any info we can
|
||||
error_info = (
|
||||
@@ -333,7 +302,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
return "Query completed but returned an invalid result structure"
|
||||
|
||||
# Convert state to string for comparison
|
||||
state_value = str(result.status.state) # type: ignore[union-attr]
|
||||
state_value = str(result.status.state)
|
||||
if not any(
|
||||
state in state_value for state in ["SUCCEEDED", "FAILED", "CANCELED"]
|
||||
):
|
||||
@@ -354,7 +323,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
if has_schema and has_result:
|
||||
try:
|
||||
# Get schema for column names
|
||||
columns = [col.name for col in result.manifest.schema.columns] # type: ignore[union-attr]
|
||||
columns = [col.name for col in result.manifest.schema.columns]
|
||||
|
||||
# Debug info for schema
|
||||
|
||||
@@ -362,7 +331,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
all_columns = set(columns)
|
||||
|
||||
# Dump the raw structure of result data to help troubleshoot
|
||||
if _has_data_array(result):
|
||||
if hasattr(result.result, "data_array"):
|
||||
# Add defensive check for None data_array
|
||||
if result.result.data_array is None:
|
||||
# Return empty result handling rather than trying to process null data
|
||||
@@ -374,7 +343,8 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
# Only try to analyze sample if data_array exists and has content
|
||||
if (
|
||||
_has_data_array(result)
|
||||
hasattr(result.result, "data_array")
|
||||
and result.result.data_array
|
||||
and len(result.result.data_array) > 0
|
||||
and len(result.result.data_array[0]) > 0
|
||||
):
|
||||
@@ -415,17 +385,17 @@ class DatabricksQueryTool(BaseTool):
|
||||
rows_with_single_item = 0
|
||||
if (
|
||||
hasattr(result.result, "data_array")
|
||||
and result.result.data_array # type: ignore[union-attr]
|
||||
and len(result.result.data_array) > 0 # type: ignore[union-attr]
|
||||
and result.result.data_array
|
||||
and len(result.result.data_array) > 0
|
||||
):
|
||||
sample_size_for_rows = (
|
||||
min(sample_size, len(result.result.data_array[0])) # type: ignore[union-attr]
|
||||
min(sample_size, len(result.result.data_array[0]))
|
||||
if "sample_size" in locals()
|
||||
else min(20, len(result.result.data_array[0])) # type: ignore[union-attr]
|
||||
else min(20, len(result.result.data_array[0]))
|
||||
)
|
||||
rows_with_single_item = sum(
|
||||
1 # type: ignore[misc]
|
||||
for row in result.result.data_array[0][ # type: ignore[union-attr]
|
||||
1
|
||||
for row in result.result.data_array[0][
|
||||
:sample_size_for_rows
|
||||
]
|
||||
if isinstance(row, list) and len(row) == 1
|
||||
@@ -454,13 +424,13 @@ class DatabricksQueryTool(BaseTool):
|
||||
# We're dealing with data where the rows may be incorrectly structured
|
||||
|
||||
# Collect all values into a flat list
|
||||
all_values: list[Any] = []
|
||||
all_values = []
|
||||
if (
|
||||
hasattr(result.result, "data_array")
|
||||
and result.result.data_array # type: ignore[union-attr]
|
||||
and result.result.data_array
|
||||
):
|
||||
# Flatten all values into a single list
|
||||
for chunk in result.result.data_array: # type: ignore[union-attr]
|
||||
for chunk in result.result.data_array:
|
||||
for item in chunk:
|
||||
if isinstance(item, (list, tuple)):
|
||||
all_values.extend(item)
|
||||
@@ -659,7 +629,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
# Fix titles that might still have issues
|
||||
if (
|
||||
isinstance(row.get("Title"), str)
|
||||
and len(row.get("Title")) <= 1 # type: ignore[arg-type]
|
||||
and len(row.get("Title")) <= 1
|
||||
):
|
||||
# This is likely still a fragmented title - mark as potentially incomplete
|
||||
row["Title"] = f"[INCOMPLETE] {row.get('Title')}"
|
||||
@@ -675,11 +645,11 @@ class DatabricksQueryTool(BaseTool):
|
||||
# Check different result structures
|
||||
if (
|
||||
hasattr(result.result, "data_array")
|
||||
and result.result.data_array # type: ignore[union-attr]
|
||||
and result.result.data_array
|
||||
):
|
||||
# Check if data appears to be malformed within chunks
|
||||
for _chunk_idx, chunk in enumerate(
|
||||
result.result.data_array # type: ignore[union-attr]
|
||||
result.result.data_array
|
||||
):
|
||||
# Check if chunk might actually contain individual columns of a single row
|
||||
# This is another way data might be malformed - check the first few values
|
||||
@@ -786,10 +756,10 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
chunk_results.append(row_dict)
|
||||
|
||||
elif hasattr(result.result, "data") and result.result.data: # type: ignore[union-attr]
|
||||
elif hasattr(result.result, "data") and result.result.data:
|
||||
# Alternative data structure
|
||||
|
||||
for _row_idx, row in enumerate(result.result.data): # type: ignore[union-attr]
|
||||
for _row_idx, row in enumerate(result.result.data):
|
||||
# Debug info
|
||||
|
||||
# Safely create dictionary matching column names to values
|
||||
@@ -833,12 +803,12 @@ class DatabricksQueryTool(BaseTool):
|
||||
|
||||
# If we have no results but the query succeeded (e.g., for DDL statements)
|
||||
if not chunk_results and hasattr(result, "status"):
|
||||
state_value = str(result.status.state) # type: ignore[union-attr]
|
||||
state_value = str(result.status.state)
|
||||
if "SUCCEEDED" in state_value:
|
||||
return "Query executed successfully (no results to display)"
|
||||
|
||||
# Format and return results
|
||||
return self._format_results(chunk_results) # type: ignore[arg-type]
|
||||
return self._format_results(chunk_results)
|
||||
|
||||
except Exception as e:
|
||||
# Include more details in the error message to help with debugging
|
||||
|
||||
@@ -35,10 +35,7 @@ class DirectoryReadTool(BaseTool):
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
directory: str | None = kwargs.get("directory", self.directory)
|
||||
if directory is None:
|
||||
raise ValueError("Directory must be provided.")
|
||||
|
||||
directory = kwargs.get("directory", self.directory)
|
||||
if directory[-1] == "/":
|
||||
directory = directory[:-1]
|
||||
files_list = [
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedDirectorySearchToolSchema(BaseModel):
|
||||
@@ -37,7 +38,7 @@ class DirectorySearchTool(RagTool):
|
||||
def add(self, directory: str) -> None:
|
||||
super().add(directory, data_type=DataType.DIRECTORY)
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
directory: str | None = None,
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedDOCXSearchToolSchema(BaseModel):
|
||||
@@ -45,7 +46,7 @@ class DOCXSearchTool(RagTool):
|
||||
def add(self, docx: str) -> None:
|
||||
super().add(docx, data_type=DataType.DOCX)
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
docx: str | None = None,
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from builtins import type as type_
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Required
|
||||
|
||||
|
||||
class SearchParams(TypedDict, total=False):
|
||||
"""Parameters for Exa search API."""
|
||||
try:
|
||||
from exa_py import Exa
|
||||
|
||||
type: Required[str | None]
|
||||
start_published_date: str
|
||||
end_published_date: str
|
||||
include_domains: list[str]
|
||||
EXA_INSTALLED = True
|
||||
except ImportError:
|
||||
Exa = Any
|
||||
EXA_INSTALLED = False
|
||||
|
||||
|
||||
class EXABaseToolSchema(BaseModel):
|
||||
@@ -35,8 +31,8 @@ class EXASearchTool(BaseTool):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
name: str = "EXASearchTool"
|
||||
description: str = "Search the internet using Exa"
|
||||
args_schema: type_[BaseModel] = EXABaseToolSchema
|
||||
client: Any | None = None
|
||||
args_schema: type[BaseModel] = EXABaseToolSchema
|
||||
client: Optional["Exa"] = None
|
||||
content: bool | None = False
|
||||
summary: bool | None = False
|
||||
type: str | None = "auto"
|
||||
@@ -76,9 +72,7 @@ class EXASearchTool(BaseTool):
|
||||
super().__init__(
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
from exa_py import Exa
|
||||
except ImportError as e:
|
||||
if not EXA_INSTALLED:
|
||||
import click
|
||||
|
||||
if click.confirm(
|
||||
@@ -88,16 +82,11 @@ class EXASearchTool(BaseTool):
|
||||
|
||||
subprocess.run(["uv", "add", "exa_py"], check=True) # noqa: S607
|
||||
|
||||
# Re-import after installation
|
||||
from exa_py import Exa
|
||||
else:
|
||||
raise ImportError(
|
||||
"You are missing the 'exa_py' package. Would you like to install it?"
|
||||
) from e
|
||||
|
||||
client_kwargs: dict[str, str] = {}
|
||||
if self.api_key:
|
||||
client_kwargs["api_key"] = self.api_key
|
||||
)
|
||||
client_kwargs = {"api_key": self.api_key}
|
||||
if self.base_url:
|
||||
client_kwargs["base_url"] = self.base_url
|
||||
self.client = Exa(**client_kwargs)
|
||||
@@ -115,7 +104,7 @@ class EXASearchTool(BaseTool):
|
||||
if self.client is None:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
search_params: SearchParams = {
|
||||
search_params = {
|
||||
"type": self.type,
|
||||
}
|
||||
|
||||
|
||||
@@ -72,9 +72,9 @@ class FileCompressorTool(BaseTool):
|
||||
"tar.xz": self._compress_tar,
|
||||
}
|
||||
if format == "zip":
|
||||
format_compression[format](input_path, output_path) # type: ignore[operator]
|
||||
format_compression[format](input_path, output_path)
|
||||
else:
|
||||
format_compression[format](input_path, output_path, format) # type: ignore[operator]
|
||||
format_compression[format](input_path, output_path, format)
|
||||
|
||||
return f"Successfully compressed '{input_path}' into '{output_path}'"
|
||||
except FileNotFoundError:
|
||||
@@ -84,8 +84,7 @@ class FileCompressorTool(BaseTool):
|
||||
except Exception as e:
|
||||
return f"An unexpected error occurred during compression: {e!s}"
|
||||
|
||||
@staticmethod
|
||||
def _generate_output_path(input_path: str, format: str) -> str:
|
||||
def _generate_output_path(self, input_path: str, format: str) -> str:
|
||||
"""Generates output path based on input path and format."""
|
||||
if os.path.isfile(input_path):
|
||||
base_name = os.path.splitext(os.path.basename(input_path))[
|
||||
@@ -95,8 +94,7 @@ class FileCompressorTool(BaseTool):
|
||||
base_name = os.path.basename(os.path.normpath(input_path)) # Directory name
|
||||
return os.path.join(os.getcwd(), f"{base_name}.{format}")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_output(output_path: str, overwrite: bool) -> bool:
|
||||
def _prepare_output(self, output_path: str, overwrite: bool) -> bool:
|
||||
"""Ensures output path is ready for writing."""
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
@@ -105,8 +103,7 @@ class FileCompressorTool(BaseTool):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _compress_zip(input_path: str, output_path: str):
|
||||
def _compress_zip(self, input_path: str, output_path: str):
|
||||
"""Compresses input into a zip archive."""
|
||||
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
if os.path.isfile(input_path):
|
||||
@@ -118,8 +115,7 @@ class FileCompressorTool(BaseTool):
|
||||
arcname = os.path.relpath(full_path, start=input_path)
|
||||
zipf.write(full_path, arcname)
|
||||
|
||||
@staticmethod
|
||||
def _compress_tar(input_path: str, output_path: str, format: str):
|
||||
def _compress_tar(self, input_path: str, output_path: str, format: str):
|
||||
"""Compresses input into a tar archive with the given format."""
|
||||
format_mode = {
|
||||
"tar": "w",
|
||||
@@ -133,6 +129,6 @@ class FileCompressorTool(BaseTool):
|
||||
|
||||
mode = format_mode[format]
|
||||
|
||||
with tarfile.open(output_path, mode) as tarf: # type: ignore[call-overload]
|
||||
with tarfile.open(output_path, mode) as tarf:
|
||||
arcname = os.path.basename(input_path)
|
||||
tarf.add(input_path, arcname=arcname)
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
try:
|
||||
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
FIRECRAWL_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -61,7 +59,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
|
||||
},
|
||||
}
|
||||
)
|
||||
_firecrawl: FirecrawlApp | None = PrivateAttr(None)
|
||||
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
@@ -116,7 +114,7 @@ try:
|
||||
# Only rebuild if the class hasn't been initialized yet
|
||||
if not hasattr(FirecrawlCrawlWebsiteTool, "_model_rebuilt"):
|
||||
FirecrawlCrawlWebsiteTool.model_rebuild()
|
||||
FirecrawlCrawlWebsiteTool._model_rebuilt = True # type: ignore[attr-defined]
|
||||
FirecrawlCrawlWebsiteTool._model_rebuilt = True
|
||||
except ImportError:
|
||||
"""
|
||||
When this tool is not used, then exception can be ignored.
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
try:
|
||||
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
FIRECRAWL_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -56,7 +54,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
|
||||
}
|
||||
)
|
||||
|
||||
_firecrawl: FirecrawlApp | None = PrivateAttr(None)
|
||||
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
@@ -104,7 +102,7 @@ try:
|
||||
# Must rebuild model after class is defined
|
||||
if not hasattr(FirecrawlScrapeWebsiteTool, "_model_rebuilt"):
|
||||
FirecrawlScrapeWebsiteTool.model_rebuild()
|
||||
FirecrawlScrapeWebsiteTool._model_rebuilt = True # type: ignore[attr-defined]
|
||||
FirecrawlScrapeWebsiteTool._model_rebuilt = True
|
||||
except ImportError:
|
||||
"""
|
||||
When this tool is not used, then exception can be ignored.
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
|
||||
try:
|
||||
from firecrawl import FirecrawlApp # type: ignore[import-untyped]
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
FIRECRAWL_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -55,7 +53,7 @@ class FirecrawlSearchTool(BaseTool):
|
||||
"timeout": 60000,
|
||||
}
|
||||
)
|
||||
_firecrawl: FirecrawlApp | None = PrivateAttr(None)
|
||||
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
|
||||
package_dependencies: list[str] = Field(default_factory=lambda: ["firecrawl-py"])
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
@@ -116,7 +114,7 @@ try:
|
||||
# Only rebuild if the class hasn't been initialized yet
|
||||
if not hasattr(FirecrawlSearchTool, "_model_rebuilt"):
|
||||
FirecrawlSearchTool.model_rebuild()
|
||||
FirecrawlSearchTool._model_rebuilt = True # type: ignore[attr-defined]
|
||||
FirecrawlSearchTool._model_rebuilt = True
|
||||
except ImportError:
|
||||
"""
|
||||
When this tool is not used, then exception can be ignored.
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedGithubSearchToolSchema(BaseModel):
|
||||
@@ -60,7 +61,7 @@ class GithubSearchTool(RagTool):
|
||||
metadata={"content_types": content_types, "gh_token": self.gh_token},
|
||||
)
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
github_repo: str | None = None,
|
||||
|
||||
@@ -51,7 +51,7 @@ class HyperbrowserLoadTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
from hyperbrowser import Hyperbrowser # type: ignore[import-untyped]
|
||||
from hyperbrowser import Hyperbrowser
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
|
||||
@@ -64,16 +64,11 @@ class HyperbrowserLoadTool(BaseTool):
|
||||
|
||||
self.hyperbrowser = Hyperbrowser(api_key=self.api_key)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_params(params: dict) -> dict:
|
||||
def _prepare_params(self, params: dict) -> dict:
|
||||
"""Prepare session and scrape options parameters."""
|
||||
try:
|
||||
from hyperbrowser.models.scrape import ( # type: ignore[import-untyped]
|
||||
ScrapeOptions,
|
||||
)
|
||||
from hyperbrowser.models.session import ( # type: ignore[import-untyped]
|
||||
CreateSessionParams,
|
||||
)
|
||||
from hyperbrowser.models.scrape import ScrapeOptions
|
||||
from hyperbrowser.models.session import CreateSessionParams
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
|
||||
@@ -107,12 +102,8 @@ class HyperbrowserLoadTool(BaseTool):
|
||||
if params is None:
|
||||
params = {}
|
||||
try:
|
||||
from hyperbrowser.models.crawl import ( # type: ignore[import-untyped]
|
||||
StartCrawlJobParams,
|
||||
)
|
||||
from hyperbrowser.models.scrape import ( # type: ignore[import-untyped]
|
||||
StartScrapeJobParams,
|
||||
)
|
||||
from hyperbrowser.models.crawl import StartCrawlJobParams
|
||||
from hyperbrowser.models.scrape import StartScrapeJobParams
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
|
||||
@@ -122,10 +113,10 @@ class HyperbrowserLoadTool(BaseTool):
|
||||
|
||||
if operation == "scrape":
|
||||
scrape_params = StartScrapeJobParams(url=url, **params)
|
||||
scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params) # type: ignore[union-attr]
|
||||
scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params)
|
||||
return self._extract_content(scrape_resp.data)
|
||||
crawl_params = StartCrawlJobParams(url=url, **params)
|
||||
crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params) # type: ignore[union-attr]
|
||||
crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params)
|
||||
content = ""
|
||||
if crawl_resp.data:
|
||||
for page in crawl_resp.data:
|
||||
|
||||
@@ -102,7 +102,7 @@ class InvokeCrewAIAutomationTool(BaseTool):
|
||||
fields[field_name] = (str, field_def)
|
||||
|
||||
# Create dynamic model
|
||||
args_schema = create_model("DynamicInvokeCrewAIAutomationInput", **fields) # type: ignore[call-overload]
|
||||
args_schema = create_model("DynamicInvokeCrewAIAutomationInput", **fields)
|
||||
else:
|
||||
args_schema = InvokeCrewAIAutomationInput
|
||||
|
||||
@@ -162,11 +162,12 @@ class InvokeCrewAIAutomationTool(BaseTool):
|
||||
|
||||
# Start the crew
|
||||
response = self._kickoff_crew(inputs=kwargs)
|
||||
kickoff_id: str | None = response.get("kickoff_id")
|
||||
|
||||
if kickoff_id is None:
|
||||
if response.get("kickoff_id") is None:
|
||||
return f"Error: Failed to kickoff crew. Response: {response}"
|
||||
|
||||
kickoff_id = response.get("kickoff_id")
|
||||
|
||||
# Poll for completion
|
||||
for i in range(self.max_polling_time):
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedJSONSearchToolSchema(BaseModel):
|
||||
@@ -35,7 +35,7 @@ class JSONSearchTool(RagTool):
|
||||
self.args_schema = FixedJSONSearchToolSchema
|
||||
self._generate_description()
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
json_path: str | None = None,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
|
||||
@@ -10,7 +10,7 @@ try:
|
||||
LINKUP_AVAILABLE = True
|
||||
except ImportError:
|
||||
LINKUP_AVAILABLE = False
|
||||
LinkupClient = Any # type: ignore[misc,assignment] # type placeholder when package is not available
|
||||
LinkupClient = Any # type placeholder when package is not available
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
@@ -32,7 +32,7 @@ class LinkupSearchTool(BaseTool):
|
||||
|
||||
def __init__(self, api_key: str | None = None) -> None:
|
||||
"""Initialize the tool with an API key."""
|
||||
super().__init__() # type: ignore[call-arg]
|
||||
super().__init__()
|
||||
try:
|
||||
from linkup import LinkupClient
|
||||
except ImportError:
|
||||
@@ -54,12 +54,7 @@ class LinkupSearchTool(BaseTool):
|
||||
self._client = LinkupClient(api_key=api_key or os.getenv("LINKUP_API_KEY"))
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
depth: Literal["standard", "deep"] = "standard",
|
||||
output_type: Literal[
|
||||
"searchResults", "sourcedAnswer", "structured"
|
||||
] = "searchResults",
|
||||
self, query: str, depth: str = "standard", output_type: str = "searchResults"
|
||||
) -> dict:
|
||||
"""Executes a search using the Linkup API.
|
||||
|
||||
|
||||
@@ -17,9 +17,7 @@ class LlamaIndexTool(BaseTool):
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run tool."""
|
||||
from llama_index.core.tools import ( # type: ignore[import-not-found]
|
||||
BaseTool as LlamaBaseTool,
|
||||
)
|
||||
from llama_index.core.tools import BaseTool as LlamaBaseTool
|
||||
|
||||
tool = cast(LlamaBaseTool, self.llama_index_tool)
|
||||
|
||||
@@ -30,9 +28,7 @@ class LlamaIndexTool(BaseTool):
|
||||
|
||||
@classmethod
|
||||
def from_tool(cls, tool: Any, **kwargs: Any) -> LlamaIndexTool:
|
||||
from llama_index.core.tools import ( # type: ignore[import-not-found]
|
||||
BaseTool as LlamaBaseTool,
|
||||
)
|
||||
from llama_index.core.tools import BaseTool as LlamaBaseTool
|
||||
|
||||
if not isinstance(tool, LlamaBaseTool):
|
||||
raise ValueError(f"Expected a LlamaBaseTool, got {type(tool)}")
|
||||
@@ -61,12 +57,8 @@ class LlamaIndexTool(BaseTool):
|
||||
return_direct: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> LlamaIndexTool:
|
||||
from llama_index.core.query_engine import ( # type: ignore[import-not-found]
|
||||
BaseQueryEngine,
|
||||
)
|
||||
from llama_index.core.tools import ( # type: ignore[import-not-found]
|
||||
QueryEngineTool,
|
||||
)
|
||||
from llama_index.core.query_engine import BaseQueryEngine
|
||||
from llama_index.core.tools import QueryEngineTool
|
||||
|
||||
if not isinstance(query_engine, BaseQueryEngine):
|
||||
raise ValueError(f"Expected a BaseQueryEngine, got {type(query_engine)}")
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class FixedMDXSearchToolSchema(BaseModel):
|
||||
@@ -37,7 +38,7 @@ class MDXSearchTool(RagTool):
|
||||
def add(self, mdx: str) -> None:
|
||||
super().add(mdx, data_type=DataType.MDX)
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
mdx: str | None = None,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from crewai_tools.tools.mongodb_vector_search_tool.vector_search import (
|
||||
from .vector_search import (
|
||||
MongoDBToolSchema,
|
||||
MongoDBVectorSearchConfig,
|
||||
MongoDBVectorSearchTool,
|
||||
|
||||
@@ -197,6 +197,7 @@ class MongoDBVectorSearchTool(BaseTool):
|
||||
|
||||
_metadatas = metadatas or [{} for _ in texts]
|
||||
ids = [str(ObjectId()) for _ in range(len(list(texts)))]
|
||||
metadatas_batch = _metadatas
|
||||
|
||||
result_ids = []
|
||||
texts_batch = []
|
||||
@@ -284,7 +285,7 @@ class MongoDBVectorSearchTool(BaseTool):
|
||||
"index": self.vector_index_name,
|
||||
"path": self.embedding_key,
|
||||
"queryVector": query_vector,
|
||||
"numCandidates": limit * oversampling_factor, # type: ignore[operator]
|
||||
"numCandidates": limit * oversampling_factor,
|
||||
"limit": limit,
|
||||
}
|
||||
if pre_filter:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from multion_tool import MultiOnTool # type: ignore[import-not-found]
|
||||
from multion_tool import MultiOnTool
|
||||
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "Your Key"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Multion tool spec."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
@@ -31,6 +30,8 @@ class MultiOnTool(BaseTool):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
local: bool = False,
|
||||
max_steps: int = 3,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -42,6 +43,8 @@ class MultiOnTool(BaseTool):
|
||||
if click.confirm(
|
||||
"You are missing the 'multion' package. Would you like to install it?"
|
||||
):
|
||||
import subprocess
|
||||
|
||||
subprocess.run(["uv", "add", "multion"], check=True) # noqa: S607
|
||||
from multion.client import MultiOn
|
||||
else:
|
||||
@@ -49,7 +52,9 @@ class MultiOnTool(BaseTool):
|
||||
"`multion` package not found, please run `uv add multion`"
|
||||
) from None
|
||||
self.session_id = None
|
||||
self.local = local
|
||||
self.multion = MultiOn(api_key=api_key or os.getenv("MULTION_API_KEY"))
|
||||
self.max_steps = max_steps
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -65,9 +70,6 @@ class MultiOnTool(BaseTool):
|
||||
*args (Any): Additional arguments to pass to the Multion client
|
||||
**kwargs (Any): Additional keyword arguments to pass to the Multion client
|
||||
"""
|
||||
if self.multion is None:
|
||||
raise ValueError("Multion client is not initialized.")
|
||||
|
||||
browse = self.multion.browse(
|
||||
cmd=cmd,
|
||||
session_id=self.session_id,
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
|
||||
|
||||
class MySQLSearchToolSchema(BaseModel):
|
||||
@@ -34,7 +35,7 @@ class MySQLSearchTool(RagTool):
|
||||
) -> None:
|
||||
super().add(f"SELECT * FROM {table_name};", **kwargs) # noqa: S608
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
similarity_threshold: float | None = None,
|
||||
|
||||
@@ -82,7 +82,7 @@ class NL2SQLTool(BaseTool):
|
||||
result = session.execute(text(sql_query))
|
||||
session.commit()
|
||||
|
||||
if result.returns_rows: # type: ignore[attr-defined]
|
||||
if result.returns_rows:
|
||||
columns = result.keys()
|
||||
return [
|
||||
dict(zip(columns, row, strict=False)) for row in result.fetchall()
|
||||
|
||||
@@ -5,10 +5,9 @@ This tool provides functionality for extracting text from images using supported
|
||||
|
||||
import base64
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai import LLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.types import LLMMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
|
||||
class OCRToolSchema(BaseModel):
|
||||
@@ -20,7 +19,7 @@ class OCRToolSchema(BaseModel):
|
||||
For remote images, provide the complete URL starting with 'http' or 'https'.
|
||||
"""
|
||||
|
||||
image_path_url: str = Field(description="The image path or URL.")
|
||||
image_path_url: str = "The image path or URL."
|
||||
|
||||
|
||||
class OCRTool(BaseTool):
|
||||
@@ -40,9 +39,29 @@ class OCRTool(BaseTool):
|
||||
|
||||
name: str = "Optical Character Recognition Tool"
|
||||
description: str = "This tool uses an LLM's API to extract text from an image file."
|
||||
llm: LLM = Field(default_factory=lambda: LLM(model="gpt-4o", temperature=0.7))
|
||||
_llm: LLM | None = PrivateAttr(default=None)
|
||||
|
||||
args_schema: type[BaseModel] = OCRToolSchema
|
||||
|
||||
def __init__(self, llm: LLM = None, **kwargs):
|
||||
"""Initialize the OCR tool.
|
||||
|
||||
Args:
|
||||
llm (LLM, optional): Language model instance to use for API calls.
|
||||
If not provided, a default LLM with gpt-4o model will be used.
|
||||
**kwargs: Additional arguments passed to the parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if llm is None:
|
||||
# Use the default LLM
|
||||
llm = LLM(
|
||||
model="gpt-4o",
|
||||
temperature=0.7,
|
||||
)
|
||||
|
||||
self._llm = llm
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
"""Execute the OCR operation on the provided image.
|
||||
|
||||
@@ -69,7 +88,7 @@ class OCRTool(BaseTool):
|
||||
base64_image = self._encode_image(image_path_url)
|
||||
image_data = f"data:image/jpeg;base64,{base64_image}"
|
||||
|
||||
messages: list[LLMMessage] = [
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an expert OCR specialist. Extract complete text from the provided image. Provide the result as a raw text.",
|
||||
@@ -85,10 +104,9 @@ class OCRTool(BaseTool):
|
||||
},
|
||||
]
|
||||
|
||||
return self.llm.call(messages=messages)
|
||||
return self._llm.call(messages=messages)
|
||||
|
||||
@staticmethod
|
||||
def _encode_image(image_path: str):
|
||||
def _encode_image(self, image_path: str):
|
||||
"""Encode an image file to base64 format.
|
||||
|
||||
Args:
|
||||
@@ -98,4 +116,4 @@ class OCRTool(BaseTool):
|
||||
str: Base64-encoded image data as a UTF-8 string.
|
||||
"""
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode()
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
@@ -9,10 +9,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
try:
|
||||
from oxylabs import RealtimeClient # type: ignore[import-untyped]
|
||||
from oxylabs.sources.response import ( # type: ignore[import-untyped]
|
||||
Response as OxylabsResponse,
|
||||
)
|
||||
from oxylabs import RealtimeClient
|
||||
from oxylabs.sources.response import Response as OxylabsResponse
|
||||
|
||||
OXYLABS_AVAILABLE = True
|
||||
except ImportError:
|
||||
|
||||
@@ -9,10 +9,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
try:
|
||||
from oxylabs import RealtimeClient # type: ignore[import-untyped]
|
||||
from oxylabs.sources.response import ( # type: ignore[import-untyped]
|
||||
Response as OxylabsResponse,
|
||||
)
|
||||
from oxylabs import RealtimeClient
|
||||
from oxylabs.sources.response import Response as OxylabsResponse
|
||||
|
||||
OXYLABS_AVAILABLE = True
|
||||
except ImportError:
|
||||
|
||||
@@ -9,10 +9,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
try:
|
||||
from oxylabs import RealtimeClient # type: ignore[import-untyped]
|
||||
from oxylabs.sources.response import ( # type: ignore[import-untyped]
|
||||
Response as OxylabsResponse,
|
||||
)
|
||||
from oxylabs import RealtimeClient
|
||||
from oxylabs.sources.response import Response as OxylabsResponse
|
||||
|
||||
OXYLABS_AVAILABLE = True
|
||||
except ImportError:
|
||||
|
||||
@@ -9,10 +9,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
try:
|
||||
from oxylabs import RealtimeClient # type: ignore[import-untyped]
|
||||
from oxylabs.sources.response import ( # type: ignore[import-untyped]
|
||||
Response as OxylabsResponse,
|
||||
)
|
||||
from oxylabs import RealtimeClient
|
||||
from oxylabs.sources.response import Response as OxylabsResponse
|
||||
|
||||
OXYLABS_AVAILABLE = True
|
||||
except ImportError:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from crewai_tools.tools.parallel_tools.parallel_search_tool import ParallelSearchTool
|
||||
from .parallel_search_tool import ParallelSearchTool
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from crewai_tools.tools.patronus_eval_tool.patronus_eval_tool import (
|
||||
PatronusEvalTool as PatronusEvalTool,
|
||||
)
|
||||
from crewai_tools.tools.patronus_eval_tool.patronus_local_evaluator_tool import (
|
||||
from .patronus_eval_tool import PatronusEvalTool as PatronusEvalTool
|
||||
from .patronus_local_evaluator_tool import (
|
||||
PatronusLocalEvaluatorTool as PatronusLocalEvaluatorTool,
|
||||
)
|
||||
from crewai_tools.tools.patronus_eval_tool.patronus_predefined_criteria_eval_tool import (
|
||||
from .patronus_predefined_criteria_eval_tool import (
|
||||
PatronusPredefinedCriteriaEvalTool as PatronusPredefinedCriteriaEvalTool,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import random
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from patronus import ( # type: ignore[import-not-found,import-untyped]
|
||||
Client,
|
||||
EvaluationResult,
|
||||
)
|
||||
from patronus_local_evaluator_tool import ( # type: ignore[import-not-found,import-untyped]
|
||||
from patronus import Client, EvaluationResult # type: ignore[import-not-found]
|
||||
from patronus_local_evaluator_tool import ( # type: ignore[import-not-found]
|
||||
PatronusLocalEvaluatorTool,
|
||||
)
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
|
||||
# Test the PatronusLocalEvaluatorTool where agent uses the local evaluator
|
||||
client = Client()
|
||||
|
||||
@@ -29,6 +29,7 @@ class PatronusEvalTool(BaseTool):
|
||||
temp_evaluators, temp_criteria = self._init_run()
|
||||
self.evaluators = temp_evaluators
|
||||
self.criteria = temp_criteria
|
||||
self.description = self._generate_description()
|
||||
warnings.warn(
|
||||
"You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead.",
|
||||
stacklevel=2,
|
||||
@@ -99,9 +100,9 @@ class PatronusEvalTool(BaseTool):
|
||||
|
||||
return evaluators, criteria
|
||||
|
||||
def _generate_description(self) -> None:
|
||||
def _generate_description(self) -> str:
|
||||
criteria = "\n".join([json.dumps(i) for i in self.criteria])
|
||||
self.description = f"""This tool calls the Patronus Evaluation API that takes the following arguments:
|
||||
return f"""This tool calls the Patronus Evaluation API that takes the following arguments:
|
||||
1. evaluated_model_input: str: The agent's task description in simple text
|
||||
2. evaluated_model_output: str: The agent's output of the task
|
||||
3. evaluated_model_retrieved_context: str: The agent's context
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -7,7 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from patronus import Client, EvaluationResult # type: ignore[import-untyped]
|
||||
from patronus import Client, EvaluationResult
|
||||
|
||||
try:
|
||||
import patronus # noqa: F401
|
||||
@@ -37,7 +35,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
||||
name: str = "Patronus Local Evaluator Tool"
|
||||
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
|
||||
args_schema: type[BaseModel] = FixedLocalEvaluatorToolSchema
|
||||
client: Client = None
|
||||
client: "Client" = None
|
||||
evaluator: str
|
||||
evaluated_model_gold_answer: str
|
||||
|
||||
@@ -46,7 +44,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patronus_client: Client = None,
|
||||
patronus_client: "Client" = None,
|
||||
evaluator: str = "",
|
||||
evaluated_model_gold_answer: str = "",
|
||||
**kwargs: Any,
|
||||
@@ -56,7 +54,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
||||
self.evaluated_model_gold_answer = evaluated_model_gold_answer
|
||||
self._initialize_patronus(patronus_client)
|
||||
|
||||
def _initialize_patronus(self, patronus_client: Client) -> None:
|
||||
def _initialize_patronus(self, patronus_client: "Client") -> None:
|
||||
try:
|
||||
if PYPATRONUS_AVAILABLE:
|
||||
self.client = patronus_client
|
||||
@@ -109,6 +107,6 @@ try:
|
||||
# Only rebuild if the class hasn't been initialized yet
|
||||
if not hasattr(PatronusLocalEvaluatorTool, "_model_rebuilt"):
|
||||
PatronusLocalEvaluatorTool.model_rebuild()
|
||||
PatronusLocalEvaluatorTool._model_rebuilt = True # type: ignore[attr-defined]
|
||||
PatronusLocalEvaluatorTool._model_rebuilt = True
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@@ -67,22 +67,22 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
|
||||
"evaluated_model_input": (
|
||||
evaluated_model_input
|
||||
if isinstance(evaluated_model_input, str)
|
||||
else evaluated_model_input.get("description") # type: ignore[union-attr]
|
||||
else evaluated_model_input.get("description")
|
||||
),
|
||||
"evaluated_model_output": (
|
||||
evaluated_model_output
|
||||
if isinstance(evaluated_model_output, str)
|
||||
else evaluated_model_output.get("description") # type: ignore[union-attr]
|
||||
else evaluated_model_output.get("description")
|
||||
),
|
||||
"evaluated_model_retrieved_context": (
|
||||
evaluated_model_retrieved_context
|
||||
if isinstance(evaluated_model_retrieved_context, str)
|
||||
else evaluated_model_retrieved_context.get("description") # type: ignore[union-attr]
|
||||
else evaluated_model_retrieved_context.get("description")
|
||||
),
|
||||
"evaluated_model_gold_answer": (
|
||||
evaluated_model_gold_answer
|
||||
if isinstance(evaluated_model_gold_answer, str)
|
||||
else evaluated_model_gold_answer.get("description") # type: ignore[union-attr]
|
||||
else evaluated_model_gold_answer.get("description")
|
||||
),
|
||||
"evaluators": (
|
||||
evaluators
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user