mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-24 16:28:29 +00:00
Compare commits
4 Commits
gl/fix/cac
...
docs/train
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e08d210bb8 | ||
|
|
34100d290b | ||
|
|
46f8fa59c1 | ||
|
|
4867dced0e |
17
.github/workflows/linter.yml
vendored
17
.github/workflows/linter.yml
vendored
@@ -15,19 +15,8 @@ jobs:
|
||||
- name: Fetch Target Branch
|
||||
run: git fetch origin $TARGET_BRANCH --depth=1
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
**/pyproject.toml
|
||||
**/uv.lock
|
||||
|
||||
- name: Set up Python
|
||||
run: uv python install 3.11
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev --no-install-project
|
||||
- name: Install Ruff
|
||||
run: pip install ruff
|
||||
|
||||
- name: Get Changed Python Files
|
||||
id: changed-files
|
||||
@@ -44,4 +33,4 @@ jobs:
|
||||
echo "${{ steps.changed-files.outputs.files }}" \
|
||||
| tr ' ' '\n' \
|
||||
| grep -v 'src/crewai/cli/templates/' \
|
||||
| xargs -I{} uv run ruff check "{}"
|
||||
| xargs -I{} ruff check "{}"
|
||||
|
||||
16
.github/workflows/security-checker.yml
vendored
16
.github/workflows/security-checker.yml
vendored
@@ -10,20 +10,14 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
**/pyproject.toml
|
||||
**/uv.lock
|
||||
|
||||
- name: Set up Python
|
||||
run: uv python install 3.11
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11.9"
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev --no-install-project
|
||||
run: pip install bandit
|
||||
|
||||
- name: Run Bandit
|
||||
run: uv run bandit -c pyproject.toml -r src/ -ll
|
||||
run: bandit -c pyproject.toml -r src/ -ll
|
||||
|
||||
|
||||
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
|
||||
75
.github/workflows/type-checker.yml
vendored
75
.github/workflows/type-checker.yml
vendored
@@ -6,78 +6,21 @@ permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
type-checker-matrix:
|
||||
name: type-checker (${{ matrix.python-version }})
|
||||
type-checker:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history for proper diff
|
||||
python-version: "3.11.9"
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
**/pyproject.toml
|
||||
**/uv.lock
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
run: uv python install ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev --no-install-project
|
||||
|
||||
- name: Get changed Python files
|
||||
id: changed-files
|
||||
- name: Install Requirements
|
||||
run: |
|
||||
# Get the list of changed Python files compared to the base branch
|
||||
echo "Fetching changed files..."
|
||||
git diff --name-only --diff-filter=ACMRT origin/${{ github.base_ref }}...HEAD -- '*.py' > changed_files.txt
|
||||
pip install mypy
|
||||
|
||||
# Filter for files in src/ directory only (excluding tests/)
|
||||
grep -E "^src/" changed_files.txt > filtered_changed_files.txt || true
|
||||
|
||||
# Check if there are any changed files
|
||||
if [ -s filtered_changed_files.txt ]; then
|
||||
echo "Changed Python files in src/:"
|
||||
cat filtered_changed_files.txt
|
||||
echo "has_changes=true" >> $GITHUB_OUTPUT
|
||||
# Convert newlines to spaces for mypy command
|
||||
echo "files=$(cat filtered_changed_files.txt | tr '\n' ' ')" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No Python files changed in src/"
|
||||
echo "has_changes=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Run type checks on changed files
|
||||
if: steps.changed-files.outputs.has_changes == 'true'
|
||||
run: |
|
||||
echo "Running mypy on changed files with Python ${{ matrix.python-version }}..."
|
||||
uv run mypy ${{ steps.changed-files.outputs.files }}
|
||||
|
||||
- name: No files to check
|
||||
if: steps.changed-files.outputs.has_changes == 'false'
|
||||
run: echo "No Python files in src/ were modified - skipping type checks"
|
||||
|
||||
# Summary job to provide single status for branch protection
|
||||
type-checker:
|
||||
name: type-checker
|
||||
runs-on: ubuntu-latest
|
||||
needs: type-checker-matrix
|
||||
if: always()
|
||||
steps:
|
||||
- name: Check matrix results
|
||||
run: |
|
||||
if [ "${{ needs.type-checker-matrix.result }}" == "success" ] || [ "${{ needs.type-checker-matrix.result }}" == "skipped" ]; then
|
||||
echo "✅ All type checks passed"
|
||||
else
|
||||
echo "❌ Type checks failed"
|
||||
exit 1
|
||||
fi
|
||||
- name: Run type checks
|
||||
run: mypy src
|
||||
|
||||
@@ -1,16 +1,7 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.11
|
||||
rev: v0.8.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: ["--config", "pyproject.toml"]
|
||||
args: ["--fix"]
|
||||
- id: ruff-format
|
||||
args: ["--config", "pyproject.toml"]
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.17.1
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: ["--strict", "--exclude", "src/crewai/cli/templates"]
|
||||
files: ^src/
|
||||
exclude: ^tests/
|
||||
|
||||
4
.ruff.toml
Normal file
4
.ruff.toml
Normal file
@@ -0,0 +1,4 @@
|
||||
exclude = [
|
||||
"templates",
|
||||
"__init__.py",
|
||||
]
|
||||
12
README.md
12
README.md
@@ -418,10 +418,10 @@ Choose CrewAI to easily build powerful, adaptable, and production-ready AI autom
|
||||
|
||||
You can test different real life examples of AI crews in the [CrewAI-examples repo](https://github.com/crewAIInc/crewAI-examples?tab=readme-ov-file):
|
||||
|
||||
- [Landing Page Generator](https://github.com/crewAIInc/crewAI-examples/tree/main/crews/landing_page_generator)
|
||||
- [Landing Page Generator](https://github.com/crewAIInc/crewAI-examples/tree/main/landing_page_generator)
|
||||
- [Having Human input on the execution](https://docs.crewai.com/how-to/Human-Input-on-Execution)
|
||||
- [Trip Planner](https://github.com/crewAIInc/crewAI-examples/tree/main/crews/trip_planner)
|
||||
- [Stock Analysis](https://github.com/crewAIInc/crewAI-examples/tree/main/crews/stock_analysis)
|
||||
- [Trip Planner](https://github.com/crewAIInc/crewAI-examples/tree/main/trip_planner)
|
||||
- [Stock Analysis](https://github.com/crewAIInc/crewAI-examples/tree/main/stock_analysis)
|
||||
|
||||
### Quick Tutorial
|
||||
|
||||
@@ -429,19 +429,19 @@ You can test different real life examples of AI crews in the [CrewAI-examples re
|
||||
|
||||
### Write Job Descriptions
|
||||
|
||||
[Check out code for this example](https://github.com/crewAIInc/crewAI-examples/tree/main/crews/job-posting) or watch a video below:
|
||||
[Check out code for this example](https://github.com/crewAIInc/crewAI-examples/tree/main/job-posting) or watch a video below:
|
||||
|
||||
[](https://www.youtube.com/watch?v=u98wEMz-9to "Jobs postings")
|
||||
|
||||
### Trip Planner
|
||||
|
||||
[Check out code for this example](https://github.com/crewAIInc/crewAI-examples/tree/main/crews/trip_planner) or watch a video below:
|
||||
[Check out code for this example](https://github.com/crewAIInc/crewAI-examples/tree/main/trip_planner) or watch a video below:
|
||||
|
||||
[](https://www.youtube.com/watch?v=xis7rWp-hjs "Trip Planner")
|
||||
|
||||
### Stock Analysis
|
||||
|
||||
[Check out code for this example](https://github.com/crewAIInc/crewAI-examples/tree/main/crews/stock_analysis) or watch a video below:
|
||||
[Check out code for this example](https://github.com/crewAIInc/crewAI-examples/tree/main/stock_analysis) or watch a video below:
|
||||
|
||||
[](https://www.youtube.com/watch?v=e0Uj4yWdaAg "Stock Analysis")
|
||||
|
||||
|
||||
@@ -320,7 +320,6 @@
|
||||
"en/enterprise/guides/update-crew",
|
||||
"en/enterprise/guides/enable-crew-studio",
|
||||
"en/enterprise/guides/azure-openai-setup",
|
||||
"en/enterprise/guides/automation-triggers",
|
||||
"en/enterprise/guides/hubspot-trigger",
|
||||
"en/enterprise/guides/react-component-export",
|
||||
"en/enterprise/guides/salesforce-trigger",
|
||||
@@ -659,7 +658,6 @@
|
||||
"pt-BR/enterprise/guides/update-crew",
|
||||
"pt-BR/enterprise/guides/enable-crew-studio",
|
||||
"pt-BR/enterprise/guides/azure-openai-setup",
|
||||
"pt-BR/enterprise/guides/automation-triggers",
|
||||
"pt-BR/enterprise/guides/hubspot-trigger",
|
||||
"pt-BR/enterprise/guides/react-component-export",
|
||||
"pt-BR/enterprise/guides/salesforce-trigger",
|
||||
@@ -1009,7 +1007,6 @@
|
||||
"ko/enterprise/guides/update-crew",
|
||||
"ko/enterprise/guides/enable-crew-studio",
|
||||
"ko/enterprise/guides/azure-openai-setup",
|
||||
"ko/enterprise/guides/automation-triggers",
|
||||
"ko/enterprise/guides/hubspot-trigger",
|
||||
"ko/enterprise/guides/react-component-export",
|
||||
"ko/enterprise/guides/salesforce-trigger",
|
||||
|
||||
@@ -282,25 +282,7 @@ Watch this video tutorial for a step-by-step demonstration of deploying your cre
|
||||
allowfullscreen
|
||||
></iframe>
|
||||
|
||||
### 12. Login
|
||||
|
||||
Authenticate with CrewAI Enterprise using a secure device code flow (no email entry required).
|
||||
|
||||
```shell Terminal
|
||||
crewai login
|
||||
```
|
||||
|
||||
What happens:
|
||||
- A verification URL and short code are displayed in your terminal
|
||||
- Your browser opens to the verification URL
|
||||
- Enter/confirm the code to complete authentication
|
||||
|
||||
Notes:
|
||||
- The OAuth2 provider and domain are configured via `crewai config` (defaults use `login.crewai.com`)
|
||||
- After successful login, the CLI also attempts to authenticate to the Tool Repository automatically
|
||||
- If you reset your configuration, run `crewai login` again to re-authenticate
|
||||
|
||||
### 13. API Keys
|
||||
### 11. API Keys
|
||||
|
||||
When running ```crewai create crew``` command, the CLI will show you a list of available LLM providers to choose from, followed by model selection for your chosen provider.
|
||||
|
||||
@@ -328,7 +310,7 @@ See the following link for each provider's key name:
|
||||
|
||||
* [LiteLLM Providers](https://docs.litellm.ai/docs/providers)
|
||||
|
||||
### 14. Configuration Management
|
||||
### 12. Configuration Management
|
||||
|
||||
Manage CLI configuration settings for CrewAI.
|
||||
|
||||
@@ -403,10 +385,6 @@ Reset all configuration to defaults:
|
||||
crewai config reset
|
||||
```
|
||||
|
||||
<Tip>
|
||||
After resetting configuration, re-run `crewai login` to authenticate again.
|
||||
</Tip>
|
||||
|
||||
<Note>
|
||||
Configuration settings are stored in `~/.config/crewai/settings.json`. Some settings like organization name and UUID are read-only and managed through authentication and organization commands. Tool repository related settings are hidden and cannot be set directly by users.
|
||||
</Note>
|
||||
|
||||
@@ -44,12 +44,12 @@ To create a custom event listener, you need to:
|
||||
Here's a simple example of a custom event listener class:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
from crewai.utilities.events import (
|
||||
CrewKickoffStartedEvent,
|
||||
CrewKickoffCompletedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
)
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
def __init__(self):
|
||||
@@ -146,7 +146,7 @@ my_project/
|
||||
|
||||
```python
|
||||
# my_custom_listener.py
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
# ... import events ...
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
@@ -279,7 +279,7 @@ Additional fields vary by event type. For example, `CrewKickoffCompletedEvent` i
|
||||
For temporary event handling (useful for testing or specific operations), you can use the `scoped_handlers` context manager:
|
||||
|
||||
```python
|
||||
from crewai.events import crewai_event_bus, CrewKickoffStartedEvent
|
||||
from crewai.utilities.events import crewai_event_bus, CrewKickoffStartedEvent
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
|
||||
@@ -97,13 +97,7 @@ The state's unique ID and stored data can be useful for tracking flow executions
|
||||
|
||||
### @start()
|
||||
|
||||
The `@start()` decorator marks entry points for a Flow. You can:
|
||||
|
||||
- Declare multiple unconditional starts: `@start()`
|
||||
- Gate a start on a prior method or router label: `@start("method_or_label")`
|
||||
- Provide a callable condition to control when a start should fire
|
||||
|
||||
All satisfied `@start()` methods will execute (often in parallel) when the Flow begins or resumes.
|
||||
The `@start()` decorator is used to mark a method as the starting point of a Flow. When a Flow is started, all the methods decorated with `@start()` are executed in parallel. You can have multiple start methods in a Flow, and they will all be executed when the Flow is started.
|
||||
|
||||
### @listen()
|
||||
|
||||
|
||||
@@ -24,41 +24,6 @@ For file-based Knowledge Sources, make sure to place your files in a `knowledge`
|
||||
Also, use relative paths from the `knowledge` directory when creating the source.
|
||||
</Tip>
|
||||
|
||||
### Vector store (RAG) client configuration
|
||||
|
||||
CrewAI exposes a provider-neutral RAG client abstraction for vector stores. The default provider is ChromaDB, and Qdrant is supported as well. You can switch providers using configuration utilities.
|
||||
|
||||
Supported today:
|
||||
- ChromaDB (default)
|
||||
- Qdrant
|
||||
|
||||
```python Code
|
||||
from crewai.rag.config.utils import set_rag_config, get_rag_client, clear_rag_config
|
||||
|
||||
# ChromaDB (default)
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
set_rag_config(ChromaDBConfig())
|
||||
chromadb_client = get_rag_client()
|
||||
|
||||
# Qdrant
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
set_rag_config(QdrantConfig())
|
||||
qdrant_client = get_rag_client()
|
||||
|
||||
# Example operations (same API for any provider)
|
||||
client = qdrant_client # or chromadb_client
|
||||
client.create_collection(collection_name="docs")
|
||||
client.add_documents(
|
||||
collection_name="docs",
|
||||
documents=[{"id": "1", "content": "CrewAI enables collaborative AI agents."}],
|
||||
)
|
||||
results = client.search(collection_name="docs", query="collaborative agents", limit=3)
|
||||
|
||||
clear_rag_config() # optional reset
|
||||
```
|
||||
|
||||
This RAG client is separate from Knowledge’s built-in storage. Use it when you need direct vector-store control or custom retrieval pipelines.
|
||||
|
||||
### Basic String Knowledge Example
|
||||
|
||||
```python Code
|
||||
@@ -716,11 +681,11 @@ CrewAI emits events during the knowledge retrieval process that you can listen f
|
||||
#### Example: Monitoring Knowledge Retrieval
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
from crewai.utilities.events import (
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
BaseEventListener,
|
||||
)
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class KnowledgeMonitorListener(BaseEventListener):
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
|
||||
@@ -733,10 +733,10 @@ CrewAI supports streaming responses from LLMs, allowing your application to rece
|
||||
CrewAI emits events for each chunk received during streaming:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
from crewai.utilities.events import (
|
||||
LLMStreamChunkEvent
|
||||
)
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
@@ -758,8 +758,8 @@ CrewAI supports streaming responses from LLMs, allowing your application to rece
|
||||
|
||||
```python
|
||||
from crewai import LLM, Agent, Task, Crew
|
||||
from crewai.events import LLMStreamChunkEvent
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events import LLMStreamChunkEvent
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
|
||||
@@ -738,17 +738,6 @@ print(f"OpenAI: {openai_time:.2f}s")
|
||||
print(f"Ollama: {ollama_time:.2f}s")
|
||||
```
|
||||
|
||||
### Entity Memory batching behavior
|
||||
|
||||
Entity Memory supports batching when saving multiple entities at once. When you pass a list of `EntityMemoryItem`, the system:
|
||||
|
||||
- Emits a single MemorySaveStartedEvent with `entity_count`
|
||||
- Saves each entity internally, collecting any partial errors
|
||||
- Emits MemorySaveCompletedEvent with aggregate metadata (saved count, errors)
|
||||
- Raises a partial-save exception if some entities failed (includes counts)
|
||||
|
||||
This improves performance and observability when writing many entities in one operation.
|
||||
|
||||
## 2. External Memory
|
||||
External Memory provides a standalone memory system that operates independently from the crew's built-in memory. This is ideal for specialized memory providers or cross-application memory sharing.
|
||||
|
||||
@@ -1052,8 +1041,8 @@ CrewAI emits the following memory-related events:
|
||||
Track memory operation timing to optimize your application:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
BaseEventListener,
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemorySaveCompletedEvent
|
||||
)
|
||||
@@ -1087,8 +1076,8 @@ memory_monitor = MemoryPerformanceMonitor()
|
||||
Log memory operations for debugging and insights:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
BaseEventListener,
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events import (
|
||||
MemorySaveStartedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryRetrievalCompletedEvent
|
||||
@@ -1128,8 +1117,8 @@ memory_logger = MemoryLogger()
|
||||
Capture and respond to memory errors:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
BaseEventListener,
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events import (
|
||||
MemorySaveFailedEvent,
|
||||
MemoryQueryFailedEvent
|
||||
)
|
||||
@@ -1178,8 +1167,8 @@ error_tracker = MemoryErrorTracker(notify_email="admin@example.com")
|
||||
Memory events can be forwarded to analytics and monitoring platforms to track performance metrics, detect anomalies, and visualize memory usage patterns:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
BaseEventListener,
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemorySaveCompletedEvent
|
||||
)
|
||||
|
||||
@@ -59,12 +59,6 @@ crew = Crew(
|
||||
| **Output Pydantic** _(optional)_ | `output_pydantic` | `Optional[Type[BaseModel]]` | A Pydantic model for task output. |
|
||||
| **Callback** _(optional)_ | `callback` | `Optional[Any]` | Function/object to be executed after task completion. |
|
||||
| **Guardrail** _(optional)_ | `guardrail` | `Optional[Callable]` | Function to validate task output before proceeding to next task. |
|
||||
| **Guardrail Max Retries** _(optional)_ | `guardrail_max_retries` | `Optional[int]` | Maximum number of retries when guardrail validation fails. Defaults to 3. |
|
||||
|
||||
<Note type="warning" title="Deprecated: max_retries">
|
||||
The task attribute `max_retries` is deprecated and will be removed in v1.0.0.
|
||||
Use `guardrail_max_retries` instead to control retry attempts when a guardrail fails.
|
||||
</Note>
|
||||
|
||||
## Creating Tasks
|
||||
|
||||
@@ -437,7 +431,7 @@ When a guardrail returns `(False, error)`:
|
||||
2. The agent attempts to fix the issue
|
||||
3. The process repeats until:
|
||||
- The guardrail returns `(True, result)`
|
||||
- Maximum retries are reached (`guardrail_max_retries`)
|
||||
- Maximum retries are reached
|
||||
|
||||
Example with retry handling:
|
||||
```python Code
|
||||
@@ -458,7 +452,7 @@ task = Task(
|
||||
expected_output="A valid JSON object",
|
||||
agent=analyst,
|
||||
guardrail=validate_json_output,
|
||||
guardrail_max_retries=3 # Limit retry attempts
|
||||
max_retries=3 # Limit retry attempts
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ Before using Authentication Integrations, ensure you have:
|
||||
3. Click **Connect** on your desired service from the Authentication Integrations section
|
||||
4. Complete the OAuth authentication flow
|
||||
5. Grant necessary permissions for your use case
|
||||
6. All set! Get your Enterprise Token from your [CrewAI Enterprise](https://app.crewai.com) in **Integration** tab
|
||||
6. Get your Enterprise Token from your [CrewAI Enterprise](https://app.crewai.com) account page - https://app.crewai.com/crewai_plus/settings/account
|
||||
|
||||
<Frame>
|
||||

|
||||
|
||||
@@ -141,16 +141,6 @@ Traces are invaluable for troubleshooting issues with your crews:
|
||||
</Step>
|
||||
</Steps>
|
||||
|
||||
## Performance and batching
|
||||
|
||||
CrewAI batches trace uploads to reduce overhead on high-volume runs:
|
||||
|
||||
- A TraceBatchManager buffers events and sends them in batches via the Plus API client
|
||||
- Reduces network chatter and improves reliability on flaky connections
|
||||
- Automatically enabled in the default trace listener; no configuration needed
|
||||
|
||||
This yields more stable tracing under load while preserving detailed task/agent telemetry.
|
||||
|
||||
<Card title="Need Help?" icon="headset" href="mailto:support@crewai.com">
|
||||
Contact our support team for assistance with trace analysis or any other CrewAI Enterprise features.
|
||||
</Card>
|
||||
@@ -1,178 +0,0 @@
|
||||
---
|
||||
title: "Automation Triggers"
|
||||
description: "Automatically execute your CrewAI workflows when specific events occur in connected integrations"
|
||||
icon: "bolt"
|
||||
---
|
||||
|
||||
Automation triggers enable you to automatically run your CrewAI deployments when specific events occur in your connected integrations, creating powerful event-driven workflows that respond to real-time changes in your business systems.
|
||||
|
||||
## Overview
|
||||
|
||||
With automation triggers, you can:
|
||||
|
||||
- **Respond to real-time events** - Automatically execute workflows when specific conditions are met
|
||||
- **Integrate with external systems** - Connect with platforms like Gmail, Outlook, OneDrive, JIRA, Slack, Stripe and more
|
||||
- **Scale your automation** - Handle high-volume events without manual intervention
|
||||
- **Maintain context** - Access trigger data within your crews and flows
|
||||
|
||||
## Managing Automation Triggers
|
||||
|
||||
### Viewing Available Triggers
|
||||
|
||||
To access and manage your automation triggers:
|
||||
|
||||
1. Navigate to your deployment in the CrewAI dashboard
|
||||
2. Click on the **Triggers** tab to view all available trigger integrations
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-available-triggers.png" alt="List of available automation triggers" />
|
||||
</Frame>
|
||||
|
||||
This view shows all the trigger integrations available for your deployment, along with their current connection status.
|
||||
|
||||
### Enabling and Disabling Triggers
|
||||
|
||||
Each trigger can be easily enabled or disabled using the toggle switch:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/trigger-selected.png" alt="Enable or disable triggers with toggle" />
|
||||
</Frame>
|
||||
|
||||
- **Enabled (blue toggle)**: The trigger is active and will automatically execute your deployment when the specified events occur
|
||||
- **Disabled (gray toggle)**: The trigger is inactive and will not respond to events
|
||||
|
||||
Simply click the toggle to change the trigger state. Changes take effect immediately.
|
||||
|
||||
### Monitoring Trigger Executions
|
||||
|
||||
Track the performance and history of your triggered executions:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-executions.png" alt="List of executions triggered by automation" />
|
||||
</Frame>
|
||||
|
||||
## Building Automation
|
||||
|
||||
Before building your automation, it's helpful to understand the structure of trigger payloads that your crews and flows will receive.
|
||||
|
||||
### Payload Samples Repository
|
||||
|
||||
We maintain a comprehensive repository with sample payloads from various trigger sources to help you build and test your automations:
|
||||
|
||||
**🔗 [CrewAI Enterprise Trigger Payload Samples](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
|
||||
|
||||
This repository contains:
|
||||
|
||||
- **Real payload examples** from different trigger sources (Gmail, Google Drive, etc.)
|
||||
- **Payload structure documentation** showing the format and available fields
|
||||
|
||||
### Triggers with Crew
|
||||
|
||||
Your existing crew definitions work seamlessly with triggers, you just need to have a task to parse the received payload:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MyAutomatedCrew:
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['researcher'],
|
||||
)
|
||||
|
||||
@task
|
||||
def parse_trigger_payload(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['parse_trigger_payload'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
|
||||
@task
|
||||
def analyze_trigger_content(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analyze_trigger_data'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
```
|
||||
|
||||
The crew will automatically receive and can access the trigger payload through the standard CrewAI context mechanisms.
|
||||
|
||||
<Note>
|
||||
Crew and Flow inputs can include `crewai_trigger_payload`. CrewAI automatically injects this payload:
|
||||
- Tasks: appended to the first task's description by default ("Trigger Payload: {crewai_trigger_payload}")
|
||||
- Control via `allow_crewai_trigger_context`: set `True` to always inject, `False` to never inject
|
||||
- Flows: any `@start()` method that accepts a `crewai_trigger_payload` parameter will receive it
|
||||
</Note>
|
||||
|
||||
### Integration with Flows
|
||||
|
||||
For flows, you have more control over how trigger data is handled:
|
||||
|
||||
#### Accessing Trigger Payload
|
||||
|
||||
All `@start()` methods in your flows will accept an additional parameter called `crewai_trigger_payload`:
|
||||
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen
|
||||
|
||||
class MyAutomatedFlow(Flow):
|
||||
@start()
|
||||
def handle_trigger(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
This start method can receive trigger data
|
||||
"""
|
||||
if crewai_trigger_payload:
|
||||
# Process the trigger data
|
||||
trigger_id = crewai_trigger_payload.get('id')
|
||||
event_data = crewai_trigger_payload.get('payload', {})
|
||||
|
||||
# Store in flow state for use by other methods
|
||||
self.state.trigger_id = trigger_id
|
||||
self.state.trigger_type = event_data
|
||||
|
||||
return event_data
|
||||
|
||||
# Handle manual execution
|
||||
return None
|
||||
|
||||
@listen(handle_trigger)
|
||||
def process_data(self, trigger_data):
|
||||
"""
|
||||
Process the data from the trigger
|
||||
"""
|
||||
# ... process the trigger
|
||||
```
|
||||
|
||||
#### Triggering Crews from Flows
|
||||
|
||||
When kicking off a crew within a flow that was triggered, pass the trigger payload as it:
|
||||
|
||||
```python
|
||||
@start()
|
||||
def delegate_to_crew(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
Delegate processing to a specialized crew
|
||||
"""
|
||||
crew = MySpecializedCrew()
|
||||
|
||||
# Pass the trigger payload to the crew
|
||||
result = crew.crew().kickoff(
|
||||
inputs={
|
||||
'a_custom_parameter': "custom_value",
|
||||
'crewai_trigger_payload': crewai_trigger_payload
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Trigger not firing:**
|
||||
- Verify the trigger is enabled
|
||||
- Check integration connection status
|
||||
|
||||
**Execution failures:**
|
||||
- Check the execution logs for error details
|
||||
- If you are developing, make sure the inputs include the `crewai_trigger_payload` parameter with the correct payload
|
||||
|
||||
Automation triggers transform your CrewAI deployments into responsive, event-driven systems that can seamlessly integrate with your existing business processes and tools.
|
||||
@@ -348,31 +348,6 @@ class SelectivePersistFlow(Flow):
|
||||
|
||||
## Advanced State Patterns
|
||||
|
||||
### Conditional starts and resumable execution
|
||||
|
||||
Flows support conditional `@start()` and resumable execution for HITL/cyclic scenarios:
|
||||
|
||||
```python
|
||||
from crewai.flow.flow import Flow, start, listen, and_, or_
|
||||
|
||||
class ResumableFlow(Flow):
|
||||
@start() # unconditional start
|
||||
def init(self):
|
||||
...
|
||||
|
||||
# Conditional start: run after "init" or external trigger name
|
||||
@start("init")
|
||||
def maybe_begin(self):
|
||||
...
|
||||
|
||||
@listen(and_(init, maybe_begin))
|
||||
def proceed(self):
|
||||
...
|
||||
```
|
||||
|
||||
- Conditional `@start()` accepts a method name, a router label, or a callable condition.
|
||||
- During resume, listeners continue from prior checkpoints; cycle/router branches honor resumption flags.
|
||||
|
||||
### State-Based Conditional Logic
|
||||
|
||||
You can use state to implement complex conditional logic in your flows:
|
||||
|
||||
@@ -30,12 +30,6 @@ Watch this video tutorial for a step-by-step demonstration of the installation p
|
||||
If you need to update Python, visit [python.org/downloads](https://python.org/downloads)
|
||||
</Note>
|
||||
|
||||
<Note>
|
||||
**OpenAI SDK Requirement**
|
||||
|
||||
CrewAI 0.175.0 requires `openai >= 1.13.3`. If you manage dependencies yourself, ensure your environment satisfies this constraint to avoid import/runtime issues.
|
||||
</Note>
|
||||
|
||||
CrewAI uses the `uv` as its dependency management and package handling tool. It simplifies project setup and execution, offering a seamless experience.
|
||||
|
||||
If you haven't installed `uv` yet, follow **step 1** to quickly get it set up on your system, else you can skip to **step 2**.
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
---
|
||||
title: Weaviate Vector Search
|
||||
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents using hybrid search.
|
||||
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents.
|
||||
icon: network-wired
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector and keyword search for more accurate and contextually relevant search results.
|
||||
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector embeddings for more accurate and contextually relevant search results.
|
||||
|
||||
[Weaviate](https://weaviate.io/) is a vector database that stores and queries vector embeddings, enabling semantic search capabilities.
|
||||
|
||||
@@ -39,7 +39,6 @@ from crewai_tools import WeaviateVectorSearchTool
|
||||
tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
weaviate_api_key="your-weaviate-api-key",
|
||||
)
|
||||
@@ -64,7 +63,6 @@ The `WeaviateVectorSearchTool` accepts the following parameters:
|
||||
- **weaviate_cluster_url**: Required. The URL of the Weaviate cluster.
|
||||
- **weaviate_api_key**: Required. The API key for the Weaviate cluster.
|
||||
- **limit**: Optional. The number of results to return. Default is `3`.
|
||||
- **alpha**: Optional. Controls the weighting between vector and keyword (BM25) search. alpha = 0 -> BM25 only, alpha = 1 -> vector search only. Default is `0.75`.
|
||||
- **vectorizer**: Optional. The vectorizer to use. If not provided, it will use `text2vec_openai` with the `nomic-embed-text` model.
|
||||
- **generative_model**: Optional. The generative model to use. If not provided, it will use OpenAI's `gpt-4o`.
|
||||
|
||||
@@ -80,7 +78,6 @@ from weaviate.classes.config import Configure
|
||||
tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
vectorizer=Configure.Vectorizer.text2vec_openai(model="nomic-embed-text"),
|
||||
generative_model=Configure.Generative.openai(model="gpt-4o-mini"),
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
@@ -131,7 +128,6 @@ with test_docs.batch.dynamic() as batch:
|
||||
tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
weaviate_api_key="your-weaviate-api-key",
|
||||
)
|
||||
@@ -149,7 +145,6 @@ from crewai_tools import WeaviateVectorSearchTool
|
||||
weaviate_tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
weaviate_api_key="your-weaviate-api-key",
|
||||
)
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 72 KiB After Width: | Height: | Size: 54 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 142 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 330 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 133 KiB |
@@ -44,12 +44,12 @@ Prompt Tracing을 통해 다음과 같은 작업이 가능합니다:
|
||||
아래는 커스텀 이벤트 리스너 클래스의 간단한 예시입니다:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
from crewai.utilities.events import (
|
||||
CrewKickoffStartedEvent,
|
||||
CrewKickoffCompletedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
)
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
def __init__(self):
|
||||
@@ -146,7 +146,7 @@ my_project/
|
||||
|
||||
```python
|
||||
# my_custom_listener.py
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
# ... import events ...
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
@@ -279,7 +279,7 @@ CrewAI는 여러분이 청취할 수 있는 다양한 이벤트를 제공합니
|
||||
임시 이벤트 처리가 필요한 경우(테스트 또는 특정 작업에 유용함), `scoped_handlers` 컨텍스트 관리자를 사용할 수 있습니다:
|
||||
|
||||
```python
|
||||
from crewai.events import crewai_event_bus, CrewKickoffStartedEvent
|
||||
from crewai.utilities.events import crewai_event_bus, CrewKickoffStartedEvent
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
|
||||
@@ -683,11 +683,11 @@ CrewAI는 knowledge 검색 과정에서 이벤트를 발생시키며, 이벤트
|
||||
#### 예시: Knowledge Retrieval 모니터링
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
from crewai.utilities.events import (
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
BaseEventListener,
|
||||
)
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class KnowledgeMonitorListener(BaseEventListener):
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
|
||||
@@ -731,10 +731,10 @@ CrewAI는 LLM의 스트리밍 응답을 지원하여, 애플리케이션이 출
|
||||
CrewAI는 스트리밍 중 수신되는 각 청크에 대해 이벤트를 발생시킵니다:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
from crewai.utilities.events import (
|
||||
LLMStreamChunkEvent
|
||||
)
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
@@ -756,8 +756,8 @@ CrewAI는 LLM의 스트리밍 응답을 지원하여, 애플리케이션이 출
|
||||
|
||||
```python
|
||||
from crewai import LLM, Agent, Task, Crew
|
||||
from crewai.events import LLMStreamChunkEvent
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events import LLMStreamChunkEvent
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
|
||||
@@ -985,8 +985,8 @@ CrewAI는 다음과 같은 메모리 관련 이벤트를 발생시킵니다:
|
||||
애플리케이션을 최적화하기 위해 메모리 작업 타이밍을 추적하세요:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
BaseEventListener,
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemorySaveCompletedEvent
|
||||
)
|
||||
@@ -1020,8 +1020,8 @@ memory_monitor = MemoryPerformanceMonitor()
|
||||
디버깅 및 인사이트를 위해 메모리 작업을 로깅합니다:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
BaseEventListener,
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events import (
|
||||
MemorySaveStartedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryRetrievalCompletedEvent
|
||||
@@ -1061,8 +1061,8 @@ memory_logger = MemoryLogger()
|
||||
메모리 오류를 캡처하고 대응합니다:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
BaseEventListener,
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events import (
|
||||
MemorySaveFailedEvent,
|
||||
MemoryQueryFailedEvent
|
||||
)
|
||||
@@ -1111,8 +1111,8 @@ error_tracker = MemoryErrorTracker(notify_email="admin@example.com")
|
||||
메모리 이벤트는 분석 및 모니터링 플랫폼으로 전달되어 성능 지표를 추적하고, 이상 징후를 감지하며, 메모리 사용 패턴을 시각화할 수 있습니다:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
BaseEventListener,
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemorySaveCompletedEvent
|
||||
)
|
||||
|
||||
@@ -59,7 +59,6 @@ crew = Crew(
|
||||
| **Pydantic 출력** _(선택 사항)_ | `output_pydantic` | `Optional[Type[BaseModel]]` | 태스크 출력용 Pydantic 모델입니다. |
|
||||
| **콜백** _(선택 사항)_ | `callback` | `Optional[Any]` | 태스크 완료 후 실행할 함수/객체입니다. |
|
||||
| **가드레일** _(선택 사항)_ | `guardrail` | `Optional[Callable]` | 다음 태스크로 진행하기 전에 태스크 출력을 검증하는 함수입니다. |
|
||||
| **가드레일 최대 재시도** _(선택 사항)_ | `guardrail_max_retries` | `Optional[int]` | 가드레일 검증 실패 시 최대 재시도 횟수입니다. 기본값은 3입니다. |
|
||||
|
||||
## 작업 생성하기
|
||||
|
||||
@@ -449,7 +448,7 @@ task = Task(
|
||||
expected_output="A valid JSON object",
|
||||
agent=analyst,
|
||||
guardrail=validate_json_output,
|
||||
guardrail_max_retries=3 # 재시도 횟수 제한
|
||||
max_retries=3 # Limit retry attempts
|
||||
)
|
||||
```
|
||||
|
||||
@@ -900,4 +899,4 @@ except RuntimeError as e:
|
||||
작업(task)은 CrewAI 에이전트의 행동을 이끄는 원동력입니다.
|
||||
작업과 그 결과를 적절하게 정의함으로써, 에이전트가 독립적으로 또는 협업 단위로 효과적으로 작동할 수 있는 기반을 마련할 수 있습니다.
|
||||
작업에 적합한 도구를 장착하고, 실행 과정을 이해하며, 견고한 검증 절차를 따르는 것은 CrewAI의 잠재력을 극대화하는 데 필수적입니다.
|
||||
이를 통해 에이전트가 할당된 작업에 효과적으로 준비되고, 작업이 의도대로 수행될 수 있습니다.
|
||||
이를 통해 에이전트가 할당된 작업에 효과적으로 준비되고, 작업이 의도대로 수행될 수 있습니다.
|
||||
@@ -58,7 +58,7 @@ Authentication Integrations를 사용하기 전에 다음이 준비되어 있는
|
||||
3. Authentication Integrations 섹션에서 원하는 서비스의 **Connect** 버튼을 클릭합니다.
|
||||
4. OAuth 인증 과정을 완료합니다.
|
||||
5. 사용 사례에 필요한 권한을 부여합니다.
|
||||
6. 완료! [CrewAI Enterprise](https://app.crewai.com)의 **Integration** 탭에서 Enterprise Token을 받습니다.
|
||||
6. [CrewAI Enterprise](https://app.crewai.com) 계정 페이지 - https://app.crewai.com/crewai_plus/settings/account 에서 Enterprise Token을 받습니다.
|
||||
|
||||
<Frame>
|
||||

|
||||
@@ -176,4 +176,4 @@ crew를 배포하고 각 통합을 특정 사용자에게 범위 지정할 수
|
||||
|
||||
<Card title="도움이 필요하신가요?" icon="headset" href="mailto:support@crewai.com">
|
||||
통합 설정이나 문제 해결에 대한 지원이 필요하시면 저희 지원팀에 문의하세요.
|
||||
</Card>
|
||||
</Card>
|
||||
@@ -1,171 +0,0 @@
|
||||
---
|
||||
title: "자동화 트리거"
|
||||
description: "연결된 통합에서 특정 이벤트가 발생할 때 CrewAI 워크플로우를 자동으로 실행합니다"
|
||||
icon: "bolt"
|
||||
---
|
||||
|
||||
자동화 트리거를 사용하면 연결된 통합에서 특정 이벤트가 발생할 때 CrewAI 배포를 자동으로 실행할 수 있어, 비즈니스 시스템의 실시간 변화에 반응하는 강력한 이벤트 기반 워크플로우를 만들 수 있습니다.
|
||||
|
||||
## 개요
|
||||
|
||||
자동화 트리거를 사용하면 다음을 수행할 수 있습니다:
|
||||
|
||||
- **실시간 이벤트에 응답** - 특정 조건이 충족될 때 워크플로우를 자동으로 실행
|
||||
- **외부 시스템과 통합** - Gmail, Outlook, OneDrive, JIRA, Slack, Stripe 등의 플랫폼과 연결
|
||||
- **자동화 확장** - 수동 개입 없이 대용량 이벤트 처리
|
||||
- **컨텍스트 유지** - crew와 flow 내에서 트리거 데이터에 액세스
|
||||
|
||||
## 자동화 트리거 관리
|
||||
|
||||
### 사용 가능한 트리거 보기
|
||||
|
||||
자동화 트리거에 액세스하고 관리하려면:
|
||||
|
||||
1. CrewAI 대시보드에서 배포로 이동
|
||||
2. **트리거** 탭을 클릭하여 사용 가능한 모든 트리거 통합 보기
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-available-triggers.png" alt="사용 가능한 자동화 트리거 목록" />
|
||||
</Frame>
|
||||
|
||||
이 보기는 배포에 사용 가능한 모든 트리거 통합과 현재 연결 상태를 보여줍니다.
|
||||
|
||||
### 트리거 활성화 및 비활성화
|
||||
|
||||
각 트리거는 토글 스위치를 사용하여 쉽게 활성화하거나 비활성화할 수 있습니다:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/trigger-selected.png" alt="토글로 트리거 활성화 또는 비활성화" />
|
||||
</Frame>
|
||||
|
||||
- **활성화됨 (파란색 토글)**: 트리거가 활성 상태이며 지정된 이벤트가 발생할 때 배포를 자동으로 실행합니다
|
||||
- **비활성화됨 (회색 토글)**: 트리거가 비활성 상태이며 이벤트에 응답하지 않습니다
|
||||
|
||||
토글을 클릭하기만 하면 트리거 상태를 변경할 수 있습니다. 변경 사항은 즉시 적용됩니다.
|
||||
|
||||
### 트리거 실행 모니터링
|
||||
|
||||
트리거된 실행의 성능과 기록을 추적합니다:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-executions.png" alt="자동화에 의해 트리거된 실행 목록" />
|
||||
</Frame>
|
||||
|
||||
## 자동화 구축
|
||||
|
||||
자동화를 구축하기 전에 crew와 flow가 받을 트리거 페이로드의 구조를 이해하는 것이 도움이 됩니다.
|
||||
|
||||
### 페이로드 샘플 저장소
|
||||
|
||||
자동화를 구축하고 테스트하는 데 도움이 되도록 다양한 트리거 소스의 샘플 페이로드가 포함된 포괄적인 저장소를 유지 관리하고 있습니다:
|
||||
|
||||
**🔗 [CrewAI Enterprise 트리거 페이로드 샘플](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
|
||||
|
||||
이 저장소에는 다음이 포함되어 있습니다:
|
||||
|
||||
- **실제 페이로드 예제** - 다양한 트리거 소스(Gmail, Google Drive 등)에서 가져온 예제
|
||||
- **페이로드 구조 문서** - 형식과 사용 가능한 필드를 보여주는 문서
|
||||
|
||||
### Crew와 트리거
|
||||
|
||||
기존 crew 정의는 트리거와 완벽하게 작동하며, 받은 페이로드를 분석하는 작업만 있으면 됩니다:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MyAutomatedCrew:
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['researcher'],
|
||||
)
|
||||
|
||||
@task
|
||||
def parse_trigger_payload(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['parse_trigger_payload'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
|
||||
@task
|
||||
def analyze_trigger_content(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analyze_trigger_data'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
```
|
||||
|
||||
crew는 자동으로 트리거 페이로드를 받고 표준 CrewAI 컨텍스트 메커니즘을 통해 액세스할 수 있습니다.
|
||||
|
||||
### Flow와의 통합
|
||||
|
||||
flow의 경우 트리거 데이터 처리 방법을 더 세밀하게 제어할 수 있습니다:
|
||||
|
||||
#### 트리거 페이로드 액세스
|
||||
|
||||
flow의 모든 `@start()` 메서드는 `crewai_trigger_payload`라는 추가 매개변수를 허용합니다:
|
||||
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen
|
||||
|
||||
class MyAutomatedFlow(Flow):
|
||||
@start()
|
||||
def handle_trigger(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
이 start 메서드는 트리거 데이터를 받을 수 있습니다
|
||||
"""
|
||||
if crewai_trigger_payload:
|
||||
# 트리거 데이터 처리
|
||||
trigger_id = crewai_trigger_payload.get('id')
|
||||
event_data = crewai_trigger_payload.get('payload', {})
|
||||
|
||||
# 다른 메서드에서 사용할 수 있도록 flow 상태에 저장
|
||||
self.state.trigger_id = trigger_id
|
||||
self.state.trigger_type = event_data
|
||||
|
||||
return event_data
|
||||
|
||||
# 수동 실행 처리
|
||||
return None
|
||||
|
||||
@listen(handle_trigger)
|
||||
def process_data(self, trigger_data):
|
||||
"""
|
||||
트리거 데이터 처리
|
||||
"""
|
||||
# ... 트리거 처리
|
||||
```
|
||||
|
||||
#### Flow에서 Crew 트리거하기
|
||||
|
||||
트리거된 flow 내에서 crew를 시작할 때 트리거 페이로드를 전달합니다:
|
||||
|
||||
```python
|
||||
@start()
|
||||
def delegate_to_crew(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
전문 crew에 처리 위임
|
||||
"""
|
||||
crew = MySpecializedCrew()
|
||||
|
||||
# crew에 트리거 페이로드 전달
|
||||
result = crew.crew().kickoff(
|
||||
inputs={
|
||||
'a_custom_parameter': "custom_value",
|
||||
'crewai_trigger_payload': crewai_trigger_payload
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
## 문제 해결
|
||||
|
||||
**트리거가 작동하지 않는 경우:**
|
||||
- 트리거가 활성화되어 있는지 확인
|
||||
- 통합 연결 상태 확인
|
||||
|
||||
**실행 실패:**
|
||||
- 오류 세부 정보는 실행 로그 확인
|
||||
- 개발 중인 경우 입력에 올바른 페이로드가 포함된 `crewai_trigger_payload` 매개변수가 포함되어 있는지 확인
|
||||
|
||||
자동화 트리거는 CrewAI 배포를 기존 비즈니스 프로세스 및 도구와 완벽하게 통합할 수 있는 반응형 이벤트 기반 시스템으로 변환합니다.
|
||||
@@ -44,12 +44,12 @@ Para criar um listener de evento personalizado, você precisa:
|
||||
Veja um exemplo simples de uma classe de listener de evento personalizado:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
from crewai.utilities.events import (
|
||||
CrewKickoffStartedEvent,
|
||||
CrewKickoffCompletedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
)
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class MeuListenerPersonalizado(BaseEventListener):
|
||||
def __init__(self):
|
||||
@@ -146,7 +146,7 @@ my_project/
|
||||
|
||||
```python
|
||||
# my_custom_listener.py
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
# ... importe events ...
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
@@ -268,7 +268,7 @@ Campos adicionais variam pelo tipo de evento. Por exemplo, `CrewKickoffCompleted
|
||||
Para lidar temporariamente com eventos (útil para testes ou operações específicas), você pode usar o context manager `scoped_handlers`:
|
||||
|
||||
```python
|
||||
from crewai.events import crewai_event_bus, CrewKickoffStartedEvent
|
||||
from crewai.utilities.events import crewai_event_bus, CrewKickoffStartedEvent
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
|
||||
@@ -681,11 +681,11 @@ O CrewAI emite eventos durante o processo de recuperação de knowledge que voc
|
||||
#### Exemplo: Monitorando Recuperação de Knowledge
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
from crewai.utilities.events import (
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
BaseEventListener,
|
||||
)
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class KnowledgeMonitorListener(BaseEventListener):
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
|
||||
@@ -708,10 +708,10 @@ O CrewAI suporta respostas em streaming de LLMs, permitindo que sua aplicação
|
||||
O CrewAI emite eventos para cada chunk recebido durante o streaming:
|
||||
|
||||
```python
|
||||
from crewai.events import (
|
||||
from crewai.utilities.events import (
|
||||
LLMStreamChunkEvent
|
||||
)
|
||||
from crewai.events import BaseEventListener
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
|
||||
class MyCustomListener(BaseEventListener):
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
|
||||
@@ -59,7 +59,6 @@ crew = Crew(
|
||||
| **Output Pydantic** _(opcional)_ | `output_pydantic` | `Optional[Type[BaseModel]]` | Um modelo Pydantic para a saída da tarefa. |
|
||||
| **Callback** _(opcional)_ | `callback` | `Optional[Any]` | Função/objeto a ser executado após a conclusão da tarefa. |
|
||||
| **Guardrail** _(opcional)_ | `guardrail` | `Optional[Callable]` | Função para validar a saída da tarefa antes de prosseguir para a próxima tarefa. |
|
||||
| **Max Tentativas Guardrail** _(opcional)_ | `guardrail_max_retries` | `Optional[int]` | Número máximo de tentativas quando a validação do guardrail falha. Padrão é 3. |
|
||||
|
||||
## Criando Tarefas
|
||||
|
||||
@@ -451,7 +450,7 @@ task = Task(
|
||||
expected_output="Um objeto JSON válido",
|
||||
agent=analyst,
|
||||
guardrail=validate_json_output,
|
||||
guardrail_max_retries=3 # Limite de tentativas
|
||||
max_retries=3 # Limite de tentativas
|
||||
)
|
||||
```
|
||||
|
||||
@@ -936,7 +935,7 @@ task = Task(
|
||||
description="Gerar dados",
|
||||
expected_output="Dados válidos",
|
||||
guardrail=validate_data,
|
||||
guardrail_max_retries=5 # Sobrescreve o limite padrão de tentativas
|
||||
max_retries=5 # Sobrescreve o limite padrão de tentativas
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ Antes de usar as Integrações de Autenticação, certifique-se de que você pos
|
||||
3. Clique em **Conectar** no serviço desejado na seção Integrações de Autenticação
|
||||
4. Complete o fluxo de autenticação OAuth
|
||||
5. Conceda as permissões necessárias para seu caso de uso
|
||||
6. Pronto! Obtenha seu Token Enterprise do [CrewAI Enterprise](https://app.crewai.com) na aba **Integration**
|
||||
6. Obtenha seu Token Enterprise na sua página de conta do [CrewAI Enterprise](https://app.crewai.com) - https://app.crewai.com/crewai_plus/settings/account
|
||||
|
||||
<Frame>
|
||||

|
||||
@@ -176,4 +176,4 @@ Use o `user_bearer_token` para direcionar a integração a um usuário específi
|
||||
|
||||
<Card title="Precisa de ajuda?" icon="headset" href="mailto:support@crewai.com">
|
||||
Entre em contato com nosso time de suporte para assistência com a configuração de integrações ou solução de problemas.
|
||||
</Card>
|
||||
</Card>
|
||||
@@ -1,171 +0,0 @@
|
||||
---
|
||||
title: "Triggers de Automação"
|
||||
description: "Execute automaticamente seus workflows CrewAI quando eventos específicos ocorrem em integrações conectadas"
|
||||
icon: "bolt"
|
||||
---
|
||||
|
||||
Os triggers de automação permitem executar automaticamente suas implantações CrewAI quando eventos específicos ocorrem em suas integrações conectadas, criando workflows poderosos orientados por eventos que respondem a mudanças em tempo real em seus sistemas de negócio.
|
||||
|
||||
## Visão Geral
|
||||
|
||||
Com triggers de automação, você pode:
|
||||
|
||||
- **Responder a eventos em tempo real** - Execute workflows automaticamente quando condições específicas forem atendidas
|
||||
- **Integrar com sistemas externos** - Conecte com plataformas como Gmail, Outlook, OneDrive, JIRA, Slack, Stripe e muito mais
|
||||
- **Escalar sua automação** - Lide com eventos de alto volume sem intervenção manual
|
||||
- **Manter contexto** - Acesse dados do trigger dentro de suas crews e flows
|
||||
|
||||
## Gerenciando Triggers de Automação
|
||||
|
||||
### Visualizando Triggers Disponíveis
|
||||
|
||||
Para acessar e gerenciar seus triggers de automação:
|
||||
|
||||
1. Navegue até sua implantação no painel do CrewAI
|
||||
2. Clique na aba **Triggers** para visualizar todas as integrações de trigger disponíveis
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-available-triggers.png" alt="Lista de triggers de automação disponíveis" />
|
||||
</Frame>
|
||||
|
||||
Esta visualização mostra todas as integrações de trigger disponíveis para sua implantação, junto com seus status de conexão atuais.
|
||||
|
||||
### Habilitando e Desabilitando Triggers
|
||||
|
||||
Cada trigger pode ser facilmente habilitado ou desabilitado usando o botão de alternância:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/trigger-selected.png" alt="Habilitar ou desabilitar triggers com alternância" />
|
||||
</Frame>
|
||||
|
||||
- **Habilitado (alternância azul)**: O trigger está ativo e executará automaticamente sua implantação quando os eventos especificados ocorrerem
|
||||
- **Desabilitado (alternância cinza)**: O trigger está inativo e não responderá a eventos
|
||||
|
||||
Simplesmente clique na alternância para mudar o estado do trigger. As alterações entram em vigor imediatamente.
|
||||
|
||||
### Monitorando Execuções de Trigger
|
||||
|
||||
Acompanhe o desempenho e histórico de suas execuções acionadas:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-executions.png" alt="Lista de execuções acionadas por automação" />
|
||||
</Frame>
|
||||
|
||||
## Construindo Automação
|
||||
|
||||
Antes de construir sua automação, é útil entender a estrutura dos payloads de trigger que suas crews e flows receberão.
|
||||
|
||||
### Repositório de Amostras de Payload
|
||||
|
||||
Mantemos um repositório abrangente com amostras de payload de várias fontes de trigger para ajudá-lo a construir e testar suas automações:
|
||||
|
||||
**🔗 [Amostras de Payload de Trigger CrewAI Enterprise](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
|
||||
|
||||
Este repositório contém:
|
||||
|
||||
- **Exemplos reais de payload** de diferentes fontes de trigger (Gmail, Google Drive, etc.)
|
||||
- **Documentação da estrutura de payload** mostrando o formato e campos disponíveis
|
||||
|
||||
### Triggers com Crew
|
||||
|
||||
Suas definições de crew existentes funcionam perfeitamente com triggers, você só precisa ter uma tarefa para analisar o payload recebido:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MinhaCrewAutomatizada:
|
||||
@agent
|
||||
def pesquisador(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['pesquisador'],
|
||||
)
|
||||
|
||||
@task
|
||||
def analisar_payload_trigger(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analisar_payload_trigger'],
|
||||
agent=self.pesquisador(),
|
||||
)
|
||||
|
||||
@task
|
||||
def analisar_conteudo_trigger(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analisar_dados_trigger'],
|
||||
agent=self.pesquisador(),
|
||||
)
|
||||
```
|
||||
|
||||
A crew receberá automaticamente e pode acessar o payload do trigger através dos mecanismos de contexto padrão do CrewAI.
|
||||
|
||||
### Integração com Flows
|
||||
|
||||
Para flows, você tem mais controle sobre como os dados do trigger são tratados:
|
||||
|
||||
#### Acessando Payload do Trigger
|
||||
|
||||
Todos os métodos `@start()` em seus flows aceitarão um parâmetro adicional chamado `crewai_trigger_payload`:
|
||||
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen
|
||||
|
||||
class MeuFlowAutomatizado(Flow):
|
||||
@start()
|
||||
def lidar_com_trigger(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
Este método start pode receber dados do trigger
|
||||
"""
|
||||
if crewai_trigger_payload:
|
||||
# Processa os dados do trigger
|
||||
trigger_id = crewai_trigger_payload.get('id')
|
||||
dados_evento = crewai_trigger_payload.get('payload', {})
|
||||
|
||||
# Armazena no estado do flow para uso por outros métodos
|
||||
self.state.trigger_id = trigger_id
|
||||
self.state.trigger_type = dados_evento
|
||||
|
||||
return dados_evento
|
||||
|
||||
# Lida com execução manual
|
||||
return None
|
||||
|
||||
@listen(lidar_com_trigger)
|
||||
def processar_dados(self, dados_trigger):
|
||||
"""
|
||||
Processa os dados do trigger
|
||||
"""
|
||||
# ... processa o trigger
|
||||
```
|
||||
|
||||
#### Acionando Crews a partir de Flows
|
||||
|
||||
Ao iniciar uma crew dentro de um flow que foi acionado, passe o payload do trigger conforme ele:
|
||||
|
||||
```python
|
||||
@start()
|
||||
def delegar_para_crew(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
Delega processamento para uma crew especializada
|
||||
"""
|
||||
crew = MinhaCrewEspecializada()
|
||||
|
||||
# Passa o payload do trigger para a crew
|
||||
resultado = crew.crew().kickoff(
|
||||
inputs={
|
||||
'parametro_personalizado': "valor_personalizado",
|
||||
'crewai_trigger_payload': crewai_trigger_payload
|
||||
},
|
||||
)
|
||||
|
||||
return resultado
|
||||
```
|
||||
|
||||
## Solução de Problemas
|
||||
|
||||
**Trigger não está sendo disparado:**
|
||||
- Verifique se o trigger está habilitado
|
||||
- Verifique o status de conexão da integração
|
||||
|
||||
**Falhas de execução:**
|
||||
- Verifique os logs de execução para detalhes do erro
|
||||
- Se você está desenvolvendo, certifique-se de que as entradas incluem o parâmetro `crewai_trigger_payload` com o payload correto
|
||||
|
||||
Os triggers de automação transformam suas implantações CrewAI em sistemas responsivos orientados por eventos que podem se integrar perfeitamente com seus processos de negócio e ferramentas existentes.
|
||||
@@ -48,7 +48,7 @@ Documentation = "https://docs.crewai.com"
|
||||
Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = ["crewai-tools~=0.69.0"]
|
||||
tools = ["crewai-tools~=0.62.1"]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
]
|
||||
@@ -68,16 +68,12 @@ docling = [
|
||||
aisuite = [
|
||||
"aisuite>=0.1.10",
|
||||
]
|
||||
qdrant = [
|
||||
"qdrant-client[fastembed]>=1.14.3",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"ruff>=0.12.11",
|
||||
"mypy>=1.17.1",
|
||||
"pre-commit>=4.3.0",
|
||||
"bandit>=1.8.6",
|
||||
"ruff>=0.8.2",
|
||||
"mypy>=1.10.0",
|
||||
"pre-commit>=3.6.0",
|
||||
"pillow>=10.2.0",
|
||||
"cairosvg>=2.7.1",
|
||||
"pytest>=8.0.0",
|
||||
@@ -89,40 +85,15 @@ dev-dependencies = [
|
||||
"pytest-timeout>=2.3.1",
|
||||
"pytest-xdist>=3.6.1",
|
||||
"pytest-split>=0.9.0",
|
||||
"types-requests==2.32.*",
|
||||
"types-pyyaml==6.0.*",
|
||||
"types-regex==2024.11.6.*",
|
||||
"types-appdirs==1.4.*",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
crewai = "crewai.cli.cli:crewai"
|
||||
|
||||
[tool.ruff]
|
||||
exclude = [
|
||||
"src/crewai/cli/templates",
|
||||
]
|
||||
fix = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"B006",
|
||||
"UP006",
|
||||
"UP007",
|
||||
"UP035",
|
||||
"UP037",
|
||||
"UP004",
|
||||
"UP008",
|
||||
"UP010",
|
||||
"UP018",
|
||||
"UP031",
|
||||
"UP032",
|
||||
"I001",
|
||||
"I002",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
exclude = ["src/crewai/cli/templates", "tests"]
|
||||
ignore_missing_imports = true
|
||||
disable_error_code = 'import-untyped'
|
||||
exclude = ["cli/templates"]
|
||||
|
||||
[tool.bandit]
|
||||
exclude_dirs = ["src/crewai/cli/templates"]
|
||||
|
||||
@@ -1,30 +1,4 @@
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
"""Suppress Pydantic deprecation warnings using targeted monkey patch."""
|
||||
original_warn = warnings.warn
|
||||
|
||||
def filtered_warn(
|
||||
message: Any,
|
||||
category: type | None = None,
|
||||
stacklevel: int = 1,
|
||||
source: Any = None,
|
||||
) -> Any:
|
||||
if (
|
||||
category
|
||||
and hasattr(category, "__module__")
|
||||
and category.__module__ == "pydantic.warnings"
|
||||
):
|
||||
return None
|
||||
return original_warn(message, category, stacklevel + 1, source)
|
||||
|
||||
setattr(warnings, "warn", filtered_warn)
|
||||
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
import threading
|
||||
import urllib.request
|
||||
|
||||
@@ -41,10 +15,17 @@ from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="Pydantic serializer warnings:",
|
||||
category=UserWarning,
|
||||
module="pydantic.main",
|
||||
)
|
||||
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
def _track_install() -> None:
|
||||
def _track_install():
|
||||
"""Track package installation/first-use via Scarf analytics."""
|
||||
global _telemetry_submitted
|
||||
|
||||
@@ -55,7 +36,7 @@ def _track_install() -> None:
|
||||
pixel_url = "https://api.scarf.sh/v2/packages/CrewAI/crewai/docs/00f2dad1-8334-4a39-934e-003b2e1146db"
|
||||
|
||||
req = urllib.request.Request(pixel_url)
|
||||
req.add_header("User-Agent", f"CrewAI-Python/{__version__}")
|
||||
req.add_header('User-Agent', f'CrewAI-Python/{__version__}')
|
||||
|
||||
with urllib.request.urlopen(req, timeout=2): # nosec B310
|
||||
_telemetry_submitted = True
|
||||
@@ -64,7 +45,7 @@ def _track_install() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _track_install_async() -> None:
|
||||
def _track_install_async():
|
||||
"""Track installation in background thread to avoid blocking imports."""
|
||||
if not Telemetry._is_telemetry_disabled():
|
||||
thread = threading.Thread(target=_track_install, daemon=True)
|
||||
@@ -73,7 +54,7 @@ def _track_install_async() -> None:
|
||||
|
||||
_track_install_async()
|
||||
|
||||
__version__ = "0.177.0"
|
||||
__version__ = "0.165.1"
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"Crew",
|
||||
|
||||
@@ -1,50 +1,18 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from pydantic import (
|
||||
BeforeValidator,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
computed_field,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
)
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.security import Fingerprint
|
||||
from crewai.task import Task
|
||||
@@ -59,7 +27,25 @@ from crewai.utilities.agent_utils import (
|
||||
)
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.llm_utils import create_default_llm, create_llm
|
||||
from crewai.utilities.events.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.memory_events import (
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
)
|
||||
from crewai.utilities.events.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
@@ -90,8 +76,6 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
_llm: BaseLLM = PrivateAttr()
|
||||
_function_calling_llm: BaseLLM | None = PrivateAttr(default=None)
|
||||
max_execution_time: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Maximum execution time for an agent to execute a task",
|
||||
@@ -106,11 +90,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.",
|
||||
default_factory=create_default_llm,
|
||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
@@ -157,7 +140,7 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
|
||||
)
|
||||
embedder: Optional[dict[str, Any]] = Field(
|
||||
embedder: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Embedder configuration for the agent.",
|
||||
)
|
||||
@@ -177,45 +160,29 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="The Agent's role to be used from your repository.",
|
||||
)
|
||||
guardrail: Optional[Callable[[Any], tuple[bool, Any]] | str] = Field(
|
||||
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
description="Function or string description of a guardrail to validate agent output"
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_from_repository(cls, v: Any) -> Any:
|
||||
def validate_from_repository(cls, v):
|
||||
if v is not None and (from_repository := v.get("from_repository")):
|
||||
return load_agent_from_repository(from_repository) | v
|
||||
return v
|
||||
|
||||
@field_validator("function_calling_llm", mode="after")
|
||||
@classmethod
|
||||
def validate_function_calling_llm(cls, v: Any) -> BaseLLM | None:
|
||||
if not v or isinstance(v, BaseLLM):
|
||||
return v
|
||||
return create_llm(v)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self) -> Self:
|
||||
def post_init_setup(self):
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
# Validate and set the private LLM attributes
|
||||
if isinstance(self.llm, BaseLLM):
|
||||
self._llm = self.llm
|
||||
elif self.llm is None:
|
||||
self._llm = create_default_llm()
|
||||
else:
|
||||
self._llm = create_llm(self.llm)
|
||||
|
||||
if self.function_calling_llm:
|
||||
if isinstance(self.function_calling_llm, BaseLLM):
|
||||
self._function_calling_llm = self.function_calling_llm
|
||||
else:
|
||||
self._function_calling_llm = create_llm(self.function_calling_llm)
|
||||
self.llm = create_llm(self.llm)
|
||||
if self.function_calling_llm and not isinstance(
|
||||
self.function_calling_llm, BaseLLM
|
||||
):
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
|
||||
if not self.agent_executor:
|
||||
self._setup_agent_executor()
|
||||
@@ -225,12 +192,12 @@ class Agent(BaseAgent):
|
||||
|
||||
return self
|
||||
|
||||
def _setup_agent_executor(self) -> None:
|
||||
def _setup_agent_executor(self):
|
||||
if not self.cache_handler:
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[dict[str, Any]] = None) -> None:
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
try:
|
||||
if self.embedder is None and crew_embedder:
|
||||
self.embedder = crew_embedder
|
||||
@@ -267,8 +234,8 @@ class Agent(BaseAgent):
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
) -> Any:
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""Execute a task with the agent.
|
||||
|
||||
Args:
|
||||
@@ -309,7 +276,7 @@ class Agent(BaseAgent):
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = None
|
||||
self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
|
||||
|
||||
task_prompt = task.prompt()
|
||||
|
||||
@@ -342,20 +309,15 @@ class Agent(BaseAgent):
|
||||
event=MemoryRetrievalStartedEvent(
|
||||
task_id=str(task.id) if task else None,
|
||||
source_type="agent",
|
||||
from_agent=self,
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
self.crew._external_memory,
|
||||
agent=self,
|
||||
task=task,
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context)
|
||||
if memory.strip() != "":
|
||||
@@ -368,14 +330,13 @@ class Agent(BaseAgent):
|
||||
memory_content=memory,
|
||||
retrieval_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="agent",
|
||||
from_agent=self,
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
knowledge_config = (
|
||||
self.knowledge_config.model_dump() if self.knowledge_config else {}
|
||||
)
|
||||
|
||||
|
||||
if self.knowledge or (self.crew and self.crew.knowledge):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -439,7 +400,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
tools = tools or self.tools or []
|
||||
self.create_agent_executor(task=task, tools=tools)
|
||||
self.create_agent_executor(tools=tools, task=task)
|
||||
|
||||
if self.crew and self.crew._train:
|
||||
task_prompt = self._training_handler(task_prompt=task_prompt)
|
||||
@@ -514,7 +475,7 @@ class Agent(BaseAgent):
|
||||
# If there was any tool in self.tools_results that had result_as_answer
|
||||
# set to True, return the results of the last tool that had
|
||||
# result_as_answer set to True
|
||||
for tool_result in self.tools_results:
|
||||
for tool_result in self.tools_results: # type: ignore # Item "None" of "list[Any] | None" has no attribute "__iter__" (not iterable)
|
||||
if tool_result.get("result_as_answer", False):
|
||||
result = tool_result["result"]
|
||||
crewai_event_bus.emit(
|
||||
@@ -523,7 +484,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return result
|
||||
|
||||
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any:
|
||||
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> str:
|
||||
"""Execute a task with a timeout.
|
||||
|
||||
Args:
|
||||
@@ -556,7 +517,7 @@ class Agent(BaseAgent):
|
||||
future.cancel()
|
||||
raise RuntimeError(f"Task execution failed: {str(e)}")
|
||||
|
||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> Any:
|
||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
|
||||
"""Execute a task without a timeout.
|
||||
|
||||
Args:
|
||||
@@ -566,9 +527,6 @@ class Agent(BaseAgent):
|
||||
Returns:
|
||||
The output of the agent.
|
||||
"""
|
||||
assert self.agent_executor is not None, (
|
||||
"Agent executor must be created before execution"
|
||||
)
|
||||
return self.agent_executor.invoke(
|
||||
{
|
||||
"input": task_prompt,
|
||||
@@ -579,15 +537,14 @@ class Agent(BaseAgent):
|
||||
)["output"]
|
||||
|
||||
def create_agent_executor(
|
||||
self, task: Task, tools: Optional[list[BaseTool]] = None
|
||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||
) -> None:
|
||||
"""Create an agent executor for the agent.
|
||||
|
||||
Args:
|
||||
task: Task to execute.
|
||||
tools: Optional list of tools to use.
|
||||
Returns:
|
||||
An instance of the CrewAgentExecutor class.
|
||||
"""
|
||||
raw_tools: list[BaseTool] = tools or self.tools or []
|
||||
raw_tools: List[BaseTool] = tools or self.tools or []
|
||||
parsed_tools = parse_tools(raw_tools)
|
||||
|
||||
prompt = Prompts(
|
||||
@@ -608,7 +565,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
self.agent_executor = CrewAgentExecutor(
|
||||
llm=self._llm,
|
||||
llm=self.llm,
|
||||
task=task,
|
||||
agent=self,
|
||||
crew=self.crew,
|
||||
@@ -621,15 +578,15 @@ class Agent(BaseAgent):
|
||||
tools_names=get_tool_names(parsed_tools),
|
||||
tools_description=render_text_description_and_args(parsed_tools),
|
||||
step_callback=self.step_callback,
|
||||
function_calling_llm=self._function_calling_llm,
|
||||
function_calling_llm=self.function_calling_llm,
|
||||
respect_context_window=self.respect_context_window,
|
||||
request_within_rpm_limit=(
|
||||
self._rpm_controller.check_or_wait if self._rpm_controller else None
|
||||
),
|
||||
litellm_callbacks=[TokenCalcHandler(self._token_process)],
|
||||
callbacks=[TokenCalcHandler(self._token_process)],
|
||||
)
|
||||
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]):
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
@@ -639,7 +596,7 @@ class Agent(BaseAgent):
|
||||
|
||||
return [AddImageTool()]
|
||||
|
||||
def get_code_execution_tools(self) -> list[BaseTool]:
|
||||
def get_code_execution_tools(self):
|
||||
try:
|
||||
from crewai_tools import CodeInterpreterTool # type: ignore
|
||||
|
||||
@@ -650,11 +607,8 @@ class Agent(BaseAgent):
|
||||
self._logger.log(
|
||||
"info", "Coding tools not available. Install crewai_tools. "
|
||||
)
|
||||
return []
|
||||
|
||||
def get_output_converter(
|
||||
self, llm: BaseLLM, text: str, model: str, instructions: str
|
||||
) -> Converter:
|
||||
def get_output_converter(self, llm, text, model, instructions):
|
||||
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||
|
||||
def _training_handler(self, task_prompt: str) -> str:
|
||||
@@ -683,7 +637,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def _render_text_description(self, tools: list[Any]) -> str:
|
||||
def _render_text_description(self, tools: List[Any]) -> str:
|
||||
"""Render the tool name and description in plain text.
|
||||
|
||||
Output will be in the format of:
|
||||
@@ -702,7 +656,7 @@ class Agent(BaseAgent):
|
||||
|
||||
return description
|
||||
|
||||
def _inject_date_to_task(self, task: Task) -> None:
|
||||
def _inject_date_to_task(self, task):
|
||||
"""Inject the current date into the task description if inject_date is enabled."""
|
||||
if self.inject_date:
|
||||
from datetime import datetime
|
||||
@@ -752,7 +706,7 @@ class Agent(BaseAgent):
|
||||
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self):
|
||||
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
|
||||
|
||||
@property
|
||||
@@ -765,7 +719,7 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
return self.security_config.fingerprint
|
||||
|
||||
def set_fingerprint(self, fingerprint: Fingerprint) -> None:
|
||||
def set_fingerprint(self, fingerprint: Fingerprint):
|
||||
self.security_config.fingerprint = fingerprint
|
||||
|
||||
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
|
||||
@@ -781,8 +735,22 @@ class Agent(BaseAgent):
|
||||
task_prompt=task_prompt
|
||||
)
|
||||
rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
|
||||
if not isinstance(self.llm, BaseLLM):
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Knowledge search query failed: LLM for agent '{self.role}' is not an instance of BaseLLM",
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeQueryFailedEvent(
|
||||
agent=self,
|
||||
error="LLM is not compatible with knowledge search queries",
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
rewritten_query = self._llm.call(
|
||||
rewritten_query = self.llm.call(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
@@ -811,8 +779,8 @@ class Agent(BaseAgent):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: Optional[type[Any]] = None,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages using a LiteAgent instance.
|
||||
@@ -834,7 +802,7 @@ class Agent(BaseAgent):
|
||||
role=self.role,
|
||||
goal=self.goal,
|
||||
backstory=self.backstory,
|
||||
llm=self._llm,
|
||||
llm=self.llm,
|
||||
tools=self.tools or [],
|
||||
max_iterations=self.max_iter,
|
||||
max_execution_time=self.max_execution_time,
|
||||
@@ -851,8 +819,8 @@ class Agent(BaseAgent):
|
||||
|
||||
async def kickoff_async(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: Optional[type[Any]] = None,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages using a LiteAgent instance.
|
||||
@@ -872,7 +840,7 @@ class Agent(BaseAgent):
|
||||
role=self.role,
|
||||
goal=self.goal,
|
||||
backstory=self.backstory,
|
||||
llm=self._llm,
|
||||
llm=self.llm,
|
||||
tools=self.tools or [],
|
||||
max_iterations=self.max_iter,
|
||||
max_execution_time=self.max_execution_time,
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
from crewai.agents.parser import parse, AgentAction, AgentFinish, OutputParserException
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from .cache.cache_handler import CacheHandler
|
||||
from .parser import CrewAgentParser
|
||||
from .tools_handler import ToolsHandler
|
||||
|
||||
__all__ = [
|
||||
"parse",
|
||||
"AgentAction",
|
||||
"AgentFinish",
|
||||
"OutputParserException",
|
||||
"ToolsHandler",
|
||||
]
|
||||
__all__ = ["CacheHandler", "CrewAgentParser", "ToolsHandler"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any, AsyncIterable, Dict, List, Optional
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
@@ -10,22 +10,21 @@ from crewai.agents.agent_adapters.langgraph.structured_output_converter import (
|
||||
LangGraphConverterAdapter,
|
||||
)
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
|
||||
try:
|
||||
from langgraph.checkpoint.memory import ( # type: ignore
|
||||
MemorySaver,
|
||||
)
|
||||
from langgraph.prebuilt import create_react_agent # type: ignore
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
LANGGRAPH_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -53,11 +52,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
role: str,
|
||||
goal: str,
|
||||
backstory: str,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
llm: Any = None,
|
||||
max_iterations: int = 10,
|
||||
agent_config: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
agent_config: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the LangGraph agent adapter."""
|
||||
if not LANGGRAPH_AVAILABLE:
|
||||
@@ -83,7 +82,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
try:
|
||||
self._memory = MemorySaver()
|
||||
|
||||
converted_tools: list[Any] = self._tool_adapter.tools()
|
||||
converted_tools: List[Any] = self._tool_adapter.tools()
|
||||
if self._agent_config:
|
||||
self._graph = create_react_agent(
|
||||
model=self.llm,
|
||||
@@ -113,7 +112,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
"""Build a system prompt for the LangGraph agent."""
|
||||
base_prompt = f"""
|
||||
You are {self.role}.
|
||||
|
||||
|
||||
Your goal is: {self.goal}
|
||||
|
||||
Your backstory: {self.backstory}
|
||||
@@ -126,10 +125,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""Execute a task using the LangGraph workflow."""
|
||||
self.create_agent_executor(task, tools)
|
||||
self.create_agent_executor(tools)
|
||||
|
||||
self.configure_structured_output(task)
|
||||
|
||||
@@ -199,13 +198,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
)
|
||||
raise
|
||||
|
||||
def create_agent_executor(
|
||||
self, task: Any = None, tools: Optional[list[BaseTool]] = None
|
||||
) -> None:
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""Configure the LangGraph agent for execution."""
|
||||
self.configure_tools(tools)
|
||||
|
||||
def configure_tools(self, tools: Optional[list[BaseTool]] = None) -> None:
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""Configure tools for the LangGraph agent."""
|
||||
if tools:
|
||||
all_tools = list(self.tools or []) + list(tools or [])
|
||||
@@ -213,7 +210,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
available_tools = self._tool_adapter.tools()
|
||||
self._graph.tools = available_tools
|
||||
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
|
||||
"""Implement delegation tools support for LangGraph."""
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
return agent_tools.tools()
|
||||
@@ -224,6 +221,6 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
"""Convert output format if needed."""
|
||||
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
def configure_structured_output(self, task) -> None:
|
||||
"""Configure the structured output for LangGraph."""
|
||||
self._converter_adapter.configure_structured_output(task)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
@@ -8,15 +7,14 @@ from crewai.utilities.converter import generate_model_description
|
||||
class LangGraphConverterAdapter(BaseConverterAdapter):
|
||||
"""Adapter for handling structured output conversion in LangGraph agents"""
|
||||
|
||||
def __init__(self, agent_adapter: Any) -> None:
|
||||
def __init__(self, agent_adapter):
|
||||
"""Initialize the converter adapter with a reference to the agent adapter"""
|
||||
super().__init__(agent_adapter) # type: ignore
|
||||
self.agent_adapter = agent_adapter
|
||||
self._output_format: str | None = None
|
||||
self._schema: str | None = None
|
||||
self._system_prompt_appendix: str | None = None
|
||||
self._output_format = None
|
||||
self._schema = None
|
||||
self._system_prompt_appendix = None
|
||||
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
def configure_structured_output(self, task) -> None:
|
||||
"""Configure the structured output for LangGraph."""
|
||||
if not (task.output_json or task.output_pydantic):
|
||||
self._output_format = None
|
||||
@@ -43,7 +41,7 @@ Important: Your final answer MUST be provided in the following structured format
|
||||
|
||||
{self._schema}
|
||||
|
||||
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
|
||||
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
|
||||
The output should be raw JSON that exactly matches the specified schema.
|
||||
"""
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
@@ -7,19 +7,19 @@ from crewai.agents.agent_adapters.openai_agents.structured_output_converter impo
|
||||
OpenAIConverterAdapter,
|
||||
)
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.utilities import Logger
|
||||
|
||||
try:
|
||||
from agents import Agent as OpenAIAgent # type: ignore[import-not-found]
|
||||
from agents import Runner, enable_verbose_stdout_logging
|
||||
from agents import Agent as OpenAIAgent # type: ignore
|
||||
from agents import Runner, enable_verbose_stdout_logging # type: ignore
|
||||
|
||||
from .openai_agent_tool_adapter import OpenAIAgentToolAdapter
|
||||
|
||||
@@ -40,14 +40,13 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
step_callback: Any = Field(default=None)
|
||||
_tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
|
||||
_converter_adapter: OpenAIConverterAdapter = PrivateAttr()
|
||||
agent_executor: Any = Field(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o-mini",
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
agent_config: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
agent_config: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if not OPENAI_AVAILABLE:
|
||||
raise ImportError(
|
||||
@@ -73,7 +72,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
"""Build a system prompt for the OpenAI agent."""
|
||||
base_prompt = f"""
|
||||
You are {self.role}.
|
||||
|
||||
|
||||
Your goal is: {self.goal}
|
||||
|
||||
Your backstory: {self.backstory}
|
||||
@@ -86,11 +85,11 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
) -> Any:
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""Execute a task using the OpenAI Assistant"""
|
||||
self._converter_adapter.configure_structured_output(task)
|
||||
self.create_agent_executor(task, tools)
|
||||
self.create_agent_executor(tools)
|
||||
|
||||
if self.verbose:
|
||||
enable_verbose_stdout_logging()
|
||||
@@ -110,7 +109,6 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
task=task,
|
||||
),
|
||||
)
|
||||
assert hasattr(self, "agent_executor"), "agent_executor not initialized"
|
||||
result = self.agent_executor.run_sync(self._openai_agent, task_prompt)
|
||||
final_answer = self.handle_execution_result(result)
|
||||
crewai_event_bus.emit(
|
||||
@@ -133,9 +131,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
)
|
||||
raise
|
||||
|
||||
def create_agent_executor(
|
||||
self, task: Any = None, tools: Optional[list[BaseTool]] = None
|
||||
) -> None:
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""
|
||||
Configure the OpenAI agent for execution.
|
||||
While OpenAI handles execution differently through Runner,
|
||||
@@ -156,24 +152,24 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
|
||||
self.agent_executor = Runner
|
||||
|
||||
def configure_tools(self, tools: Optional[list[BaseTool]] = None) -> None:
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""Configure tools for the OpenAI Assistant"""
|
||||
if tools:
|
||||
self._tool_adapter.configure_tools(tools)
|
||||
if self._tool_adapter.converted_tools:
|
||||
self._openai_agent.tools = self._tool_adapter.converted_tools
|
||||
|
||||
def handle_execution_result(self, result: Any) -> Any:
|
||||
def handle_execution_result(self, result: Any) -> str:
|
||||
"""Process OpenAI Assistant execution result converting any structured output to a string"""
|
||||
return self._converter_adapter.post_process_result(result.final_output)
|
||||
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
|
||||
"""Implement delegation tools support"""
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
def configure_structured_output(self, task) -> None:
|
||||
"""Configure the structured output for the specific agent implementation.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
@@ -20,15 +19,14 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
_output_model: The Pydantic model for the output
|
||||
"""
|
||||
|
||||
def __init__(self, agent_adapter: Any) -> None:
|
||||
def __init__(self, agent_adapter):
|
||||
"""Initialize the converter adapter with a reference to the agent adapter"""
|
||||
super().__init__(agent_adapter) # type: ignore
|
||||
self.agent_adapter = agent_adapter
|
||||
self._output_format: str | None = None
|
||||
self._schema: str | None = None
|
||||
self._output_model: Any = None
|
||||
self._output_format = None
|
||||
self._schema = None
|
||||
self._output_model = None
|
||||
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
def configure_structured_output(self, task) -> None:
|
||||
"""
|
||||
Configure the structured output for OpenAI agent based on task requirements.
|
||||
|
||||
@@ -77,7 +75,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
|
||||
return f"{base_prompt}\n\n{output_schema}"
|
||||
|
||||
def post_process_result(self, result: str) -> Any:
|
||||
def post_process_result(self, result: str) -> str:
|
||||
"""
|
||||
Post-process the result to ensure it matches the expected format.
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import Any, Optional, TypeVar
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
@@ -15,7 +14,6 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
@@ -63,7 +61,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
Methods:
|
||||
execute_task(task: Any, context: Optional[str] = None, tools: Optional[List[BaseTool]] = None) -> str:
|
||||
Abstract method to execute a task.
|
||||
create_agent_executor(task, tools=None) -> None:
|
||||
create_agent_executor(tools=None) -> None:
|
||||
Abstract method to create an agent executor.
|
||||
get_delegation_tools(agents: List["BaseAgent"]):
|
||||
Abstract method to set the agents task tools for handling delegation and question asking to other agents in crew.
|
||||
@@ -81,7 +79,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
Set private attributes.
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
||||
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
|
||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||
@@ -93,7 +91,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Objective of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
config: Optional[dict[str, Any]] = Field(
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
description="Configuration for the agent", default=None, exclude=True
|
||||
)
|
||||
cache: bool = Field(
|
||||
@@ -110,14 +108,14 @@ class BaseAgent(ABC, BaseModel):
|
||||
default=False,
|
||||
description="Enable agent to delegate and ask questions among each other.",
|
||||
)
|
||||
tools: Optional[list[BaseTool]] = Field(
|
||||
tools: Optional[List[BaseTool]] = Field(
|
||||
default_factory=list, description="Tools at agents' disposal"
|
||||
)
|
||||
max_iter: int = Field(
|
||||
default=25, description="Maximum iterations for an agent to execute a task"
|
||||
)
|
||||
agent_executor: Optional[Any] = Field(
|
||||
default=None, description="An instance of the agent executor class."
|
||||
agent_executor: InstanceOf = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
)
|
||||
llm: Any = Field(
|
||||
default=None, description="Language model that will run the agent."
|
||||
@@ -131,7 +129,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
default_factory=ToolsHandler,
|
||||
description="An instance of the ToolsHandler class.",
|
||||
)
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
)
|
||||
max_tokens: Optional[int] = Field(
|
||||
@@ -140,7 +138,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
knowledge: Optional[Knowledge] = Field(
|
||||
default=None, description="Knowledge for the agent."
|
||||
)
|
||||
knowledge_sources: Optional[list[BaseKnowledgeSource]] = Field(
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
@@ -152,7 +150,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
default_factory=SecurityConfig,
|
||||
description="Security configuration for the agent, including fingerprinting.",
|
||||
)
|
||||
callbacks: list[Callable[..., Any]] = Field(
|
||||
callbacks: List[Callable] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
)
|
||||
adapted_agent: bool = Field(
|
||||
@@ -165,12 +163,12 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def process_model_config(cls, values: Any) -> Any:
|
||||
def process_model_config(cls, values):
|
||||
return process_config(values, cls)
|
||||
|
||||
@field_validator("tools")
|
||||
@classmethod
|
||||
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
|
||||
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
|
||||
"""Validate and process the tools provided to the agent.
|
||||
|
||||
This method ensures that each tool is either an instance of BaseTool
|
||||
@@ -198,7 +196,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
return processed_tools
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_and_set_attributes(self) -> Self:
|
||||
def validate_and_set_attributes(self):
|
||||
# Validate required fields
|
||||
for field in ["role", "goal", "backstory"]:
|
||||
if getattr(self, field) is None:
|
||||
@@ -230,7 +228,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self) -> Self:
|
||||
def set_private_attrs(self):
|
||||
"""Set private attributes."""
|
||||
self._logger = Logger(verbose=self.verbose)
|
||||
if self.max_rpm and not self._rpm_controller:
|
||||
@@ -242,7 +240,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
return self
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
def key(self):
|
||||
source = [
|
||||
self._original_role or self.role,
|
||||
self._original_goal or self.goal,
|
||||
@@ -255,18 +253,16 @@ class BaseAgent(ABC, BaseModel):
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_agent_executor(
|
||||
self, task: Any, tools: Optional[list[BaseTool]] = None
|
||||
) -> None:
|
||||
def create_agent_executor(self, tools=None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
|
||||
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
|
||||
"""Set the task tools that init BaseAgenTools class."""
|
||||
pass
|
||||
|
||||
@@ -324,7 +320,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
return copied_agent
|
||||
|
||||
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Interpolate inputs into the agent description and backstory."""
|
||||
if self._original_role is None:
|
||||
self._original_role = self.role
|
||||
@@ -354,7 +350,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
if self.cache:
|
||||
self.cache_handler = cache_handler
|
||||
self.tools_handler.cache = cache_handler
|
||||
# Executor will be created when a task is executed
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_rpm_controller(self, rpm_controller: RPMController) -> None:
|
||||
"""Set the rpm controller for the agent.
|
||||
@@ -364,7 +360,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
"""
|
||||
if not self._rpm_controller:
|
||||
self._rpm_controller = rpm_controller
|
||||
# Executor will be created when a task is executed
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[dict[str, Any]] = None) -> None:
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
pass
|
||||
|
||||
@@ -1,32 +1,31 @@
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.converter import ConverterError
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.events.event_listener import event_listener
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.parser import AgentFinish
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class CrewAgentExecutorMixin:
|
||||
crew: "Crew | None"
|
||||
crew: "Crew"
|
||||
agent: "BaseAgent"
|
||||
task: "Task"
|
||||
iterations: int
|
||||
max_iter: int
|
||||
messages: list[dict[str, str]]
|
||||
messages: List[Dict[str, str]]
|
||||
_i18n: I18N
|
||||
_printer: Printer = Printer()
|
||||
|
||||
def _create_short_term_memory(self, output: "AgentFinish") -> None:
|
||||
def _create_short_term_memory(self, output) -> None:
|
||||
"""Create and save a short-term memory item if conditions are met."""
|
||||
if (
|
||||
self.crew
|
||||
@@ -36,8 +35,7 @@ class CrewAgentExecutorMixin:
|
||||
):
|
||||
try:
|
||||
if (
|
||||
self.crew
|
||||
and hasattr(self.crew, "_short_term_memory")
|
||||
hasattr(self.crew, "_short_term_memory")
|
||||
and self.crew._short_term_memory
|
||||
):
|
||||
self.crew._short_term_memory.save(
|
||||
@@ -45,12 +43,13 @@ class CrewAgentExecutorMixin:
|
||||
metadata={
|
||||
"observation": self.task.description,
|
||||
},
|
||||
agent=self.agent.role,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to add to short term memory: {e}")
|
||||
pass
|
||||
|
||||
def _create_external_memory(self, output: "AgentFinish") -> None:
|
||||
def _create_external_memory(self, output) -> None:
|
||||
"""Create and save a external-term memory item if conditions are met."""
|
||||
if (
|
||||
self.crew
|
||||
@@ -66,12 +65,13 @@ class CrewAgentExecutorMixin:
|
||||
"description": self.task.description,
|
||||
"messages": self.messages,
|
||||
},
|
||||
agent=self.agent.role,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to add to external memory: {e}")
|
||||
pass
|
||||
|
||||
def _create_long_term_memory(self, output: "AgentFinish") -> None:
|
||||
def _create_long_term_memory(self, output) -> None:
|
||||
"""Create and save long-term and entity memory items based on evaluation."""
|
||||
if (
|
||||
self.crew
|
||||
@@ -100,8 +100,8 @@ class CrewAgentExecutorMixin:
|
||||
)
|
||||
self.crew._long_term_memory.save(long_term_memory)
|
||||
|
||||
entity_memories = [
|
||||
EntityMemoryItem(
|
||||
for entity in evaluation.entities:
|
||||
entity_memory = EntityMemoryItem(
|
||||
name=entity.name,
|
||||
type=entity.type,
|
||||
description=entity.description,
|
||||
@@ -109,10 +109,7 @@ class CrewAgentExecutorMixin:
|
||||
[f"- {r}" for r in entity.relationships]
|
||||
),
|
||||
)
|
||||
for entity in evaluation.entities
|
||||
]
|
||||
if entity_memories:
|
||||
self.crew._entity_memory.save(entity_memories)
|
||||
self.crew._entity_memory.save(entity_memory)
|
||||
except AttributeError as e:
|
||||
print(f"Missing attributes for long term memory: {e}")
|
||||
pass
|
||||
@@ -161,9 +158,7 @@ class CrewAgentExecutorMixin:
|
||||
self._printer.print(content=prompt, color="bold_yellow")
|
||||
response = input()
|
||||
if response.strip() != "":
|
||||
self._printer.print(
|
||||
content="\nProcessing your feedback...", color="cyan"
|
||||
)
|
||||
self._printer.print(content="\nProcessing your feedback...", color="cyan")
|
||||
return response
|
||||
finally:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
|
||||
6
src/crewai/agents/cache/__init__.py
vendored
6
src/crewai/agents/cache/__init__.py
vendored
@@ -1,5 +1,3 @@
|
||||
"""Internal caching utilities for agent tool execution.
|
||||
from .cache_handler import CacheHandler
|
||||
|
||||
This package provides caching mechanisms for storing and retrieving
|
||||
tool execution results to avoid redundant operations.
|
||||
"""
|
||||
__all__ = ["CacheHandler"]
|
||||
|
||||
49
src/crewai/agents/cache/cache_handler.py
vendored
49
src/crewai/agents/cache/cache_handler.py
vendored
@@ -1,50 +1,15 @@
|
||||
"""Cache handler for storing and retrieving tool execution results.
|
||||
|
||||
This module provides a caching mechanism for tool outputs in the CrewAI framework,
|
||||
allowing agents to reuse previous tool execution results when the same tool is
|
||||
called with identical arguments.
|
||||
|
||||
Classes:
|
||||
CacheHandler: Manages the caching of tool execution results using an in-memory
|
||||
dictionary with serialized tool arguments as keys.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
|
||||
class CacheHandler(BaseModel):
|
||||
"""Callback handler for tool usage.
|
||||
"""Callback handler for tool usage."""
|
||||
|
||||
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
|
||||
Notes:
|
||||
TODO: Make thread-safe, currently not thread-safe.
|
||||
"""
|
||||
def add(self, tool, input, output):
|
||||
self._cache[f"{tool}-{input}"] = output
|
||||
|
||||
_cache: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def add(self, tool: str, input_data: dict[str, Any] | None, output: str) -> None:
|
||||
"""Add a tool execution result to the cache.
|
||||
|
||||
Args:
|
||||
tool: The name of the tool.
|
||||
input_data: The input arguments for the tool.
|
||||
output: The output from the tool execution.
|
||||
"""
|
||||
cache_key = json.dumps(input_data, sort_keys=True) if input_data else ""
|
||||
self._cache[f"{tool}-{cache_key}"] = output
|
||||
|
||||
def read(self, tool: str, input_data: dict[str, Any] | None) -> str | None:
|
||||
"""Read a tool execution result from the cache.
|
||||
|
||||
Args:
|
||||
tool: The name of the tool.
|
||||
input_data: The input arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The cached output if found, None otherwise.
|
||||
"""
|
||||
cache_key = json.dumps(input_data, sort_keys=True) if input_data else ""
|
||||
return self._cache.get(f"{tool}-{cache_key}")
|
||||
def read(self, tool, input) -> Optional[str]:
|
||||
return self._cache.get(f"{tool}-{input}")
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Constants for agent-related modules."""
|
||||
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
# crewai.agents.parser constants
|
||||
|
||||
FINAL_ANSWER_ACTION: Final[str] = "Final Answer:"
|
||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE: Final[str] = (
|
||||
"I did it wrong. Invalid Format: I missed the 'Action:' after 'Thought:'. I will do right next, and don't use a tool I have already used.\n"
|
||||
)
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE: Final[str] = (
|
||||
"I did it wrong. Invalid Format: I missed the 'Action Input:' after 'Action:'. I will do right next, and don't use a tool I have already used.\n"
|
||||
)
|
||||
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE: Final[str] = (
|
||||
"I did it wrong. Tried to both perform Action and give a Final Answer at the same time, I must do one or the other"
|
||||
)
|
||||
UNABLE_TO_REPAIR_JSON_RESULTS: Final[list[str]] = ['""', "{}"]
|
||||
ACTION_INPUT_REGEX: Final[re.Pattern[str]] = re.compile(
|
||||
r"Action\s*\d*\s*:\s*(.*?)\s*Action\s*\d*\s*Input\s*\d*\s*:\s*(.*)", re.DOTALL
|
||||
)
|
||||
ACTION_REGEX: Final[re.Pattern[str]] = re.compile(
|
||||
r"Action\s*\d*\s*:\s*(.*?)", re.DOTALL
|
||||
)
|
||||
ACTION_INPUT_ONLY_REGEX: Final[re.Pattern[str]] = re.compile(
|
||||
r"\s*Action\s*\d*\s*Input\s*\d*\s*:\s*(.*)", re.DOTALL
|
||||
)
|
||||
@@ -1,17 +1,4 @@
|
||||
"""Agent executor for crew AI agents.
|
||||
|
||||
Handles agent execution flow including LLM interactions, tool execution,
|
||||
and memory management.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
@@ -21,12 +8,7 @@ from crewai.agents.parser import (
|
||||
OutputParserException,
|
||||
)
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.logging_events import (
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
@@ -44,73 +26,54 @@ from crewai.utilities.agent_utils import (
|
||||
is_context_length_exceeded,
|
||||
process_llm_response,
|
||||
)
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.constants import MAX_LLM_RETRY, TRAINING_DATA_FILE
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
from crewai.utilities.events.agent_events import (
|
||||
AgentLogsStartedEvent,
|
||||
AgentLogsExecutionEvent,
|
||||
)
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
|
||||
|
||||
class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
"""Executor for crew agents.
|
||||
|
||||
Manages the execution lifecycle of an agent including prompt formatting,
|
||||
LLM interactions, tool execution, and feedback handling.
|
||||
"""
|
||||
_logger: Logger = Logger()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
task: Task,
|
||||
crew: Crew | None,
|
||||
llm: Any,
|
||||
task: Any,
|
||||
crew: Any,
|
||||
agent: BaseAgent,
|
||||
prompt: dict[str, str],
|
||||
max_iter: int,
|
||||
tools: list[CrewStructuredTool],
|
||||
tools: List[CrewStructuredTool],
|
||||
tools_names: str,
|
||||
stop_words: list[str],
|
||||
stop_words: List[str],
|
||||
tools_description: str,
|
||||
tools_handler: ToolsHandler,
|
||||
step_callback: Callable[[AgentAction | AgentFinish], None] | None = None,
|
||||
original_tools: list[BaseTool] | None = None,
|
||||
function_calling_llm: BaseLLM | None = None,
|
||||
step_callback: Any = None,
|
||||
original_tools: List[Any] = [],
|
||||
function_calling_llm: Any = None,
|
||||
respect_context_window: bool = False,
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
litellm_callbacks: list[Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize executor.
|
||||
|
||||
Args:
|
||||
llm: Language model instance.
|
||||
task: Task to execute.
|
||||
crew: Optional Crew instance.
|
||||
agent: Agent to execute.
|
||||
prompt: Prompt templates.
|
||||
max_iter: Maximum iterations.
|
||||
tools: Available tools.
|
||||
tools_names: Tool names string.
|
||||
stop_words: Stop word list.
|
||||
tools_description: Tool descriptions.
|
||||
tools_handler: Tool handler instance.
|
||||
step_callback: Optional step callback.
|
||||
original_tools: Original tool list.
|
||||
function_calling_llm: Optional function calling LLM.
|
||||
respect_context_window: Respect context limits.
|
||||
request_within_rpm_limit: RPM limit check function.
|
||||
litellm_callbacks: Optional litellm callbacks list.
|
||||
"""
|
||||
request_within_rpm_limit: Optional[Callable[[], bool]] = None,
|
||||
callbacks: List[Any] = [],
|
||||
):
|
||||
self._i18n: I18N = I18N()
|
||||
self.llm = llm
|
||||
self.llm: BaseLLM = llm
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.crew: Crew | None = crew
|
||||
self.crew = crew
|
||||
self.prompt = prompt
|
||||
self.tools = tools
|
||||
self.tools_names = tools_names
|
||||
self.stop = stop_words
|
||||
self.max_iter = max_iter
|
||||
self.litellm_callbacks = litellm_callbacks or []
|
||||
self.callbacks = callbacks
|
||||
self._printer: Printer = Printer()
|
||||
self.tools_handler = tools_handler
|
||||
self.original_tools = original_tools or []
|
||||
self.original_tools = original_tools
|
||||
self.step_callback = step_callback
|
||||
self.use_stop_words = self.llm.supports_stop_words()
|
||||
self.tools_description = tools_description
|
||||
@@ -118,9 +81,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.respect_context_window = respect_context_window
|
||||
self.request_within_rpm_limit = request_within_rpm_limit
|
||||
self.ask_for_human_input = False
|
||||
self.messages: list[dict[str, str]] = []
|
||||
self.messages: List[Dict[str, str]] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.tool_name_to_tool_map: Dict[str, Union[CrewStructuredTool, BaseTool]] = {
|
||||
tool.name: tool for tool in self.tools
|
||||
}
|
||||
existing_stop = self.llm.stop or []
|
||||
self.llm.stop = list(
|
||||
set(
|
||||
@@ -130,19 +96,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
)
|
||||
)
|
||||
|
||||
def invoke(self, inputs: dict[str, str]) -> dict[str, str]:
|
||||
"""Execute the agent with given inputs.
|
||||
|
||||
Args:
|
||||
inputs: Input dictionary containing prompt variables.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent output.
|
||||
|
||||
Raises:
|
||||
AssertionError: If agent fails to reach final answer.
|
||||
Exception: If unknown error occurs during execution.
|
||||
"""
|
||||
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
if "system" in self.prompt:
|
||||
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
|
||||
user_prompt = self._format_prompt(self.prompt.get("user", ""), inputs)
|
||||
@@ -168,6 +122,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise
|
||||
|
||||
|
||||
if self.ask_for_human_input:
|
||||
formatted_answer = self._handle_human_feedback(formatted_answer)
|
||||
|
||||
@@ -177,13 +132,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return {"output": formatted_answer.output}
|
||||
|
||||
def _invoke_loop(self) -> AgentFinish:
|
||||
"""Execute agent loop until completion.
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
|
||||
Raises:
|
||||
Exception: If litellm error or unknown error occurs.
|
||||
"""
|
||||
Main loop to invoke the agent's thought process until it reaches a conclusion
|
||||
or the maximum number of iterations is reached.
|
||||
"""
|
||||
formatted_answer = None
|
||||
while not isinstance(formatted_answer, AgentFinish):
|
||||
@@ -195,7 +146,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.litellm_callbacks,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
@@ -203,17 +154,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
answer = get_llm_response(
|
||||
llm=self.llm,
|
||||
messages=self.messages,
|
||||
callbacks=self.litellm_callbacks,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
from_task=self.task,
|
||||
from_task=self.task
|
||||
)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
# Extract agent fingerprint if available
|
||||
fingerprint_context = {}
|
||||
if hasattr(self.agent, "security_config") and hasattr(
|
||||
self.agent.security_config, "fingerprint"
|
||||
if (
|
||||
self.agent
|
||||
and hasattr(self.agent, "security_config")
|
||||
and hasattr(self.agent.security_config, "fingerprint")
|
||||
):
|
||||
fingerprint_context = {
|
||||
"agent_fingerprint": str(
|
||||
@@ -226,8 +179,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
fingerprint_context=fingerprint_context,
|
||||
tools=self.tools,
|
||||
i18n=self._i18n,
|
||||
agent_key=self.agent.key,
|
||||
agent_role=self.agent.role,
|
||||
agent_key=self.agent.key if self.agent else None,
|
||||
agent_role=self.agent.role if self.agent else None,
|
||||
tools_handler=self.tools_handler,
|
||||
task=self.task,
|
||||
agent=self.agent,
|
||||
@@ -238,7 +191,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
)
|
||||
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(formatted_answer.text)
|
||||
self._append_message(formatted_answer.text, role="assistant")
|
||||
|
||||
except OutputParserException as e:
|
||||
formatted_answer = handle_output_parser_exception(
|
||||
@@ -259,7 +212,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.litellm_callbacks,
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
)
|
||||
continue
|
||||
@@ -279,16 +232,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
def _handle_agent_action(
|
||||
self, formatted_answer: AgentAction, tool_result: ToolResult
|
||||
) -> AgentAction | AgentFinish:
|
||||
"""Process agent action and tool execution.
|
||||
|
||||
Args:
|
||||
formatted_answer: Agent's action to execute.
|
||||
tool_result: Result from tool execution.
|
||||
|
||||
Returns:
|
||||
Updated action or final answer.
|
||||
"""
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Handle the AgentAction, execute tools, and process the results."""
|
||||
# Special case for add_image_tool
|
||||
add_image_tool = self._i18n.tools("add_image")
|
||||
if (
|
||||
@@ -307,65 +252,88 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
show_logs=self._show_logs,
|
||||
)
|
||||
|
||||
def _invoke_step_callback(
|
||||
self, formatted_answer: AgentAction | AgentFinish
|
||||
) -> None:
|
||||
"""Invoke step callback.
|
||||
|
||||
Args:
|
||||
formatted_answer: Current agent response.
|
||||
"""
|
||||
def _invoke_step_callback(self, formatted_answer) -> None:
|
||||
"""Invoke the step callback if it exists."""
|
||||
if self.step_callback:
|
||||
self.step_callback(formatted_answer)
|
||||
|
||||
def _append_message(self, text: str, role: str = "assistant") -> None:
|
||||
"""Add message to conversation history.
|
||||
|
||||
Args:
|
||||
text: Message content.
|
||||
role: Message role (default: assistant).
|
||||
"""
|
||||
"""Append a message to the message list with the given role."""
|
||||
self.messages.append(format_message_for_llm(text, role=role))
|
||||
|
||||
def _show_start_logs(self) -> None:
|
||||
"""Emit agent start event."""
|
||||
def _show_start_logs(self):
|
||||
"""Show logs for the start of agent execution."""
|
||||
if self.agent is None:
|
||||
raise ValueError("Agent cannot be None")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
AgentLogsStartedEvent(
|
||||
agent_role=self.agent.role,
|
||||
task_description=self.task.description,
|
||||
task_description=(
|
||||
getattr(self.task, "description") if self.task else "Not Found"
|
||||
),
|
||||
verbose=self.agent.verbose
|
||||
or (self.crew.verbose if self.crew else False),
|
||||
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
||||
),
|
||||
)
|
||||
|
||||
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
|
||||
"""Emit agent execution event.
|
||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||
"""Show logs for the agent's execution."""
|
||||
if self.agent is None:
|
||||
raise ValueError("Agent cannot be None")
|
||||
|
||||
Args:
|
||||
formatted_answer: Agent's response to log.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
AgentLogsExecutionEvent(
|
||||
agent_role=self.agent.role,
|
||||
formatted_answer=formatted_answer,
|
||||
verbose=self.agent.verbose
|
||||
or (self.crew.verbose if self.crew else False),
|
||||
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_crew_training_output(
|
||||
self, result: AgentFinish, human_feedback: str | None = None
|
||||
) -> None:
|
||||
"""Save training data.
|
||||
def _summarize_messages(self) -> None:
|
||||
messages_groups = []
|
||||
for message in self.messages:
|
||||
content = message["content"]
|
||||
cut_size = self.llm.get_context_window_size()
|
||||
for i in range(0, len(content), cut_size):
|
||||
messages_groups.append({"content": content[i : i + cut_size]})
|
||||
|
||||
Args:
|
||||
result: Agent's final output.
|
||||
human_feedback: Optional feedback from human.
|
||||
"""
|
||||
agent_id = str(self.agent.id)
|
||||
train_iteration = getattr(self.crew, "_train_iteration", None)
|
||||
summarized_contents = []
|
||||
for group in messages_groups:
|
||||
summary = self.llm.call(
|
||||
[
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("summarizer_system_message"), role="system"
|
||||
),
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("summarize_instruction").format(
|
||||
group=group["content"]
|
||||
),
|
||||
),
|
||||
],
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
summarized_contents.append({"content": str(summary)})
|
||||
|
||||
merged_summary = " ".join(content["content"] for content in summarized_contents)
|
||||
|
||||
self.messages = [
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("summary").format(merged_summary=merged_summary)
|
||||
)
|
||||
]
|
||||
|
||||
def _handle_crew_training_output(
|
||||
self, result: AgentFinish, human_feedback: Optional[str] = None
|
||||
) -> None:
|
||||
"""Handle the process of saving training data."""
|
||||
agent_id = str(self.agent.id) # type: ignore
|
||||
train_iteration = (
|
||||
getattr(self.crew, "_train_iteration", None) if self.crew else None
|
||||
)
|
||||
|
||||
if train_iteration is None or not isinstance(train_iteration, int):
|
||||
self._printer.print(
|
||||
@@ -404,30 +372,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
training_data[agent_id] = agent_training_data
|
||||
training_handler.save(training_data)
|
||||
|
||||
@staticmethod
|
||||
def _format_prompt(prompt: str, inputs: dict[str, str]) -> str:
|
||||
"""Format prompt with input values.
|
||||
|
||||
Args:
|
||||
prompt: Template string.
|
||||
inputs: Values to substitute.
|
||||
|
||||
Returns:
|
||||
Formatted prompt.
|
||||
"""
|
||||
def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str:
|
||||
prompt = prompt.replace("{input}", inputs["input"])
|
||||
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
|
||||
prompt = prompt.replace("{tools}", inputs["tools"])
|
||||
return prompt
|
||||
|
||||
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
|
||||
"""Process human feedback.
|
||||
"""Handle human feedback with different flows for training vs regular use.
|
||||
|
||||
Args:
|
||||
formatted_answer: Initial agent result.
|
||||
formatted_answer: The initial AgentFinish result to get feedback on
|
||||
|
||||
Returns:
|
||||
Final answer after feedback.
|
||||
AgentFinish: The final answer after processing feedback
|
||||
"""
|
||||
human_feedback = self._ask_human_input(formatted_answer.output)
|
||||
|
||||
@@ -437,25 +395,13 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return self._handle_regular_feedback(formatted_answer, human_feedback)
|
||||
|
||||
def _is_training_mode(self) -> bool:
|
||||
"""Check if training mode is active.
|
||||
|
||||
Returns:
|
||||
True if in training mode.
|
||||
"""
|
||||
"""Check if crew is in training mode."""
|
||||
return bool(self.crew and self.crew._train)
|
||||
|
||||
def _handle_training_feedback(
|
||||
self, initial_answer: AgentFinish, feedback: str
|
||||
) -> AgentFinish:
|
||||
"""Process training feedback.
|
||||
|
||||
Args:
|
||||
initial_answer: Initial agent output.
|
||||
feedback: Training feedback.
|
||||
|
||||
Returns:
|
||||
Improved answer.
|
||||
"""
|
||||
"""Process feedback for training scenarios with single iteration."""
|
||||
self._handle_crew_training_output(initial_answer, feedback)
|
||||
self.messages.append(
|
||||
format_message_for_llm(
|
||||
@@ -470,15 +416,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
def _handle_regular_feedback(
|
||||
self, current_answer: AgentFinish, initial_feedback: str
|
||||
) -> AgentFinish:
|
||||
"""Process regular feedback iteratively.
|
||||
|
||||
Args:
|
||||
current_answer: Current agent output.
|
||||
initial_feedback: Initial user feedback.
|
||||
|
||||
Returns:
|
||||
Final answer after iterations.
|
||||
"""
|
||||
"""Process feedback for regular use with potential multiple iterations."""
|
||||
feedback = initial_feedback
|
||||
answer = current_answer
|
||||
|
||||
@@ -493,17 +431,30 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return answer
|
||||
|
||||
def _process_feedback_iteration(self, feedback: str) -> AgentFinish:
|
||||
"""Process single feedback iteration.
|
||||
|
||||
Args:
|
||||
feedback: User feedback.
|
||||
|
||||
Returns:
|
||||
Updated agent response.
|
||||
"""
|
||||
"""Process a single feedback iteration."""
|
||||
self.messages.append(
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
)
|
||||
return self._invoke_loop()
|
||||
|
||||
def _log_feedback_error(self, retry_count: int, error: Exception) -> None:
|
||||
"""Log feedback processing errors."""
|
||||
self._printer.print(
|
||||
content=(
|
||||
f"Error processing feedback: {error}. "
|
||||
f"Retrying... ({retry_count + 1}/{MAX_LLM_RETRY})"
|
||||
),
|
||||
color="red",
|
||||
)
|
||||
|
||||
def _log_max_retries_exceeded(self) -> None:
|
||||
"""Log when max retries for feedback processing are exceeded."""
|
||||
self._printer.print(
|
||||
content=(
|
||||
f"Failed to process feedback after {MAX_LLM_RETRY} attempts. "
|
||||
"Ending feedback loop."
|
||||
),
|
||||
color="red",
|
||||
)
|
||||
|
||||
@@ -1,67 +1,50 @@
|
||||
"""Agent output parsing module for ReAct-style LLM responses.
|
||||
import re
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
This module provides parsing functionality for agent outputs that follow
|
||||
the ReAct (Reasoning and Acting) format, converting them into structured
|
||||
AgentAction or AgentFinish objects.
|
||||
"""
|
||||
from json_repair import repair_json
|
||||
|
||||
from dataclasses import dataclass
|
||||
from crewai.utilities import I18N
|
||||
|
||||
from json_repair import repair_json # type: ignore[import-untyped]
|
||||
|
||||
from crewai.agents.constants import (
|
||||
ACTION_INPUT_ONLY_REGEX,
|
||||
ACTION_INPUT_REGEX,
|
||||
ACTION_REGEX,
|
||||
FINAL_ANSWER_ACTION,
|
||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
UNABLE_TO_REPAIR_JSON_RESULTS,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
_I18N = I18N()
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = "I did it wrong. Invalid Format: I missed the 'Action:' after 'Thought:'. I will do right next, and don't use a tool I have already used.\n"
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = "I did it wrong. Invalid Format: I missed the 'Action Input:' after 'Action:'. I will do right next, and don't use a tool I have already used.\n"
|
||||
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = "I did it wrong. Tried to both perform Action and give a Final Answer at the same time, I must do one or the other"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""Represents an action to be taken by an agent."""
|
||||
|
||||
thought: str
|
||||
tool: str
|
||||
tool_input: str
|
||||
text: str
|
||||
result: str | None = None
|
||||
result: str
|
||||
|
||||
def __init__(self, thought: str, tool: str, tool_input: str, text: str):
|
||||
self.thought = thought
|
||||
self.tool = tool
|
||||
self.tool_input = tool_input
|
||||
self.text = text
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentFinish:
|
||||
"""Represents the final answer from an agent."""
|
||||
|
||||
thought: str
|
||||
output: str
|
||||
text: str
|
||||
|
||||
def __init__(self, thought: str, output: str, text: str):
|
||||
self.thought = thought
|
||||
self.output = output
|
||||
self.text = text
|
||||
|
||||
|
||||
class OutputParserException(Exception):
|
||||
"""Exception raised when output parsing fails.
|
||||
error: str
|
||||
|
||||
Attributes:
|
||||
error: The error message.
|
||||
"""
|
||||
|
||||
def __init__(self, error: str) -> None:
|
||||
"""Initialize OutputParserException.
|
||||
|
||||
Args:
|
||||
error: The error message.
|
||||
"""
|
||||
def __init__(self, error: str):
|
||||
self.error = error
|
||||
super().__init__(error)
|
||||
|
||||
|
||||
def parse(text: str) -> AgentAction | AgentFinish:
|
||||
"""Parse agent output text into AgentAction or AgentFinish.
|
||||
class CrewAgentParser:
|
||||
"""Parses ReAct-style LLM calls that have a single tool input.
|
||||
|
||||
Expects output to be in one of two formats.
|
||||
|
||||
@@ -79,117 +62,108 @@ def parse(text: str) -> AgentAction | AgentFinish:
|
||||
|
||||
Thought: agent thought here
|
||||
Final Answer: The temperature is 100 degrees
|
||||
|
||||
Args:
|
||||
text: The agent output text to parse.
|
||||
|
||||
Returns:
|
||||
AgentAction or AgentFinish based on the content.
|
||||
|
||||
Raises:
|
||||
OutputParserException: If the text format is invalid.
|
||||
"""
|
||||
thought = _extract_thought(text)
|
||||
includes_answer = FINAL_ANSWER_ACTION in text
|
||||
action_match = ACTION_INPUT_REGEX.search(text)
|
||||
|
||||
if includes_answer:
|
||||
final_answer = text.split(FINAL_ANSWER_ACTION)[-1].strip()
|
||||
# Check whether the final answer ends with triple backticks.
|
||||
if final_answer.endswith("```"):
|
||||
# Count occurrences of triple backticks in the final answer.
|
||||
count = final_answer.count("```")
|
||||
# If count is odd then it's an unmatched trailing set; remove it.
|
||||
if count % 2 != 0:
|
||||
final_answer = final_answer[:-3].rstrip()
|
||||
return AgentFinish(thought=thought, output=final_answer, text=text)
|
||||
_i18n: I18N = I18N()
|
||||
agent: Any = None
|
||||
|
||||
elif action_match:
|
||||
action = action_match.group(1)
|
||||
clean_action = _clean_action(action)
|
||||
def __init__(self, agent: Optional[Any] = None):
|
||||
self.agent = agent
|
||||
|
||||
action_input = action_match.group(2).strip()
|
||||
@staticmethod
|
||||
def parse_text(text: str) -> Union[AgentAction, AgentFinish]:
|
||||
"""
|
||||
Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
|
||||
|
||||
tool_input = action_input.strip(" ").strip('"')
|
||||
safe_tool_input = _safe_repair_json(tool_input)
|
||||
Args:
|
||||
text: The text to parse.
|
||||
|
||||
return AgentAction(
|
||||
thought=thought, tool=clean_action, tool_input=safe_tool_input, text=text
|
||||
Returns:
|
||||
Either an AgentAction or AgentFinish based on the parsed content.
|
||||
"""
|
||||
parser = CrewAgentParser()
|
||||
return parser.parse(text)
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
thought = self._extract_thought(text)
|
||||
includes_answer = FINAL_ANSWER_ACTION in text
|
||||
regex = (
|
||||
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
)
|
||||
action_match = re.search(regex, text, re.DOTALL)
|
||||
if includes_answer:
|
||||
final_answer = text.split(FINAL_ANSWER_ACTION)[-1].strip()
|
||||
# Check whether the final answer ends with triple backticks.
|
||||
if final_answer.endswith("```"):
|
||||
# Count occurrences of triple backticks in the final answer.
|
||||
count = final_answer.count("```")
|
||||
# If count is odd then it's an unmatched trailing set; remove it.
|
||||
if count % 2 != 0:
|
||||
final_answer = final_answer[:-3].rstrip()
|
||||
return AgentFinish(thought, final_answer, text)
|
||||
|
||||
if not ACTION_REGEX.search(text):
|
||||
raise OutputParserException(
|
||||
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{_I18N.slice('final_answer_format')}",
|
||||
)
|
||||
elif not ACTION_INPUT_ONLY_REGEX.search(text):
|
||||
raise OutputParserException(
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
)
|
||||
else:
|
||||
err_format = _I18N.slice("format_without_tools")
|
||||
error = f"{err_format}"
|
||||
raise OutputParserException(
|
||||
error,
|
||||
)
|
||||
elif action_match:
|
||||
action = action_match.group(1)
|
||||
clean_action = self._clean_action(action)
|
||||
|
||||
action_input = action_match.group(2).strip()
|
||||
|
||||
def _extract_thought(text: str) -> str:
|
||||
"""Extract the thought portion from the text.
|
||||
tool_input = action_input.strip(" ").strip('"')
|
||||
safe_tool_input = self._safe_repair_json(tool_input)
|
||||
|
||||
Args:
|
||||
text: The full agent output text.
|
||||
return AgentAction(thought, clean_action, safe_tool_input, text)
|
||||
|
||||
Returns:
|
||||
The extracted thought string.
|
||||
"""
|
||||
thought_index = text.find("\nAction")
|
||||
if thought_index == -1:
|
||||
thought_index = text.find("\nFinal Answer")
|
||||
if thought_index == -1:
|
||||
return ""
|
||||
thought = text[:thought_index].strip()
|
||||
# Remove any triple backticks from the thought string
|
||||
thought = thought.replace("```", "").strip()
|
||||
return thought
|
||||
if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
|
||||
raise OutputParserException(
|
||||
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}",
|
||||
)
|
||||
elif not re.search(
|
||||
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
|
||||
):
|
||||
raise OutputParserException(
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
)
|
||||
else:
|
||||
format = self._i18n.slice("format_without_tools")
|
||||
error = f"{format}"
|
||||
raise OutputParserException(
|
||||
error,
|
||||
)
|
||||
|
||||
def _extract_thought(self, text: str) -> str:
|
||||
thought_index = text.find("\nAction")
|
||||
if thought_index == -1:
|
||||
thought_index = text.find("\nFinal Answer")
|
||||
if thought_index == -1:
|
||||
return ""
|
||||
thought = text[:thought_index].strip()
|
||||
# Remove any triple backticks from the thought string
|
||||
thought = thought.replace("```", "").strip()
|
||||
return thought
|
||||
|
||||
def _clean_action(text: str) -> str:
|
||||
"""Clean action string by removing non-essential formatting characters.
|
||||
def _clean_action(self, text: str) -> str:
|
||||
"""Clean action string by removing non-essential formatting characters."""
|
||||
return text.strip().strip("*").strip()
|
||||
|
||||
Args:
|
||||
text: The action text to clean.
|
||||
def _safe_repair_json(self, tool_input: str) -> str:
|
||||
UNABLE_TO_REPAIR_JSON_RESULTS = ['""', "{}"]
|
||||
|
||||
Returns:
|
||||
The cleaned action string.
|
||||
"""
|
||||
return text.strip().strip("*").strip()
|
||||
# Skip repair if the input starts and ends with square brackets
|
||||
# Explanation: The JSON parser has issues handling inputs that are enclosed in square brackets ('[]').
|
||||
# These are typically valid JSON arrays or strings that do not require repair. Attempting to repair such inputs
|
||||
# might lead to unintended alterations, such as wrapping the entire input in additional layers or modifying
|
||||
# the structure in a way that changes its meaning. By skipping the repair for inputs that start and end with
|
||||
# square brackets, we preserve the integrity of these valid JSON structures and avoid unnecessary modifications.
|
||||
if tool_input.startswith("[") and tool_input.endswith("]"):
|
||||
return tool_input
|
||||
|
||||
# Before repair, handle common LLM issues:
|
||||
# 1. Replace """ with " to avoid JSON parser errors
|
||||
|
||||
def _safe_repair_json(tool_input: str) -> str:
|
||||
"""Safely repair JSON input.
|
||||
tool_input = tool_input.replace('"""', '"')
|
||||
|
||||
Args:
|
||||
tool_input: The tool input string to repair.
|
||||
result = repair_json(tool_input)
|
||||
if result in UNABLE_TO_REPAIR_JSON_RESULTS:
|
||||
return tool_input
|
||||
|
||||
Returns:
|
||||
The repaired JSON string or original if repair fails.
|
||||
"""
|
||||
# Skip repair if the input starts and ends with square brackets
|
||||
# Explanation: The JSON parser has issues handling inputs that are enclosed in square brackets ('[]').
|
||||
# These are typically valid JSON arrays or strings that do not require repair. Attempting to repair such inputs
|
||||
# might lead to unintended alterations, such as wrapping the entire input in additional layers or modifying
|
||||
# the structure in a way that changes its meaning. By skipping the repair for inputs that start and end with
|
||||
# square brackets, we preserve the integrity of these valid JSON structures and avoid unnecessary modifications.
|
||||
if tool_input.startswith("[") and tool_input.endswith("]"):
|
||||
return tool_input
|
||||
|
||||
# Before repair, handle common LLM issues:
|
||||
# 1. Replace """ with " to avoid JSON parser errors
|
||||
|
||||
tool_input = tool_input.replace('"""', '"')
|
||||
|
||||
result = repair_json(tool_input)
|
||||
if result in UNABLE_TO_REPAIR_JSON_RESULTS:
|
||||
return tool_input
|
||||
|
||||
return str(result)
|
||||
return str(result)
|
||||
|
||||
@@ -1,44 +1,32 @@
|
||||
"""Tools handler for managing tool execution and caching."""
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.tools.cache_tools.cache_tools import CacheTools
|
||||
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||
from ..tools.cache_tools.cache_tools import CacheTools
|
||||
from ..tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||
from .cache.cache_handler import CacheHandler
|
||||
|
||||
|
||||
class ToolsHandler:
|
||||
"""Callback handler for tool usage.
|
||||
"""Callback handler for tool usage."""
|
||||
|
||||
Attributes:
|
||||
last_used_tool: The most recently used tool calling instance.
|
||||
cache: Optional cache handler for storing tool outputs.
|
||||
"""
|
||||
last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling")
|
||||
cache: Optional[CacheHandler]
|
||||
|
||||
def __init__(self, cache: CacheHandler | None = None) -> None:
|
||||
"""Initialize the callback handler.
|
||||
|
||||
Args:
|
||||
cache: Optional cache handler for storing tool outputs.
|
||||
"""
|
||||
self.cache: CacheHandler | None = cache
|
||||
self.last_used_tool: ToolCalling | InstructorToolCalling | None = None
|
||||
def __init__(self, cache: Optional[CacheHandler] = None):
|
||||
"""Initialize the callback handler."""
|
||||
self.cache = cache
|
||||
self.last_used_tool = {} # type: ignore # BUG?: same as above
|
||||
|
||||
def on_tool_use(
|
||||
self,
|
||||
calling: ToolCalling | InstructorToolCalling,
|
||||
calling: Union[ToolCalling, InstructorToolCalling],
|
||||
output: str,
|
||||
should_cache: bool = True,
|
||||
) -> None:
|
||||
"""Run when tool ends running.
|
||||
|
||||
Args:
|
||||
calling: The tool calling instance.
|
||||
output: The output from the tool execution.
|
||||
should_cache: Whether to cache the tool output.
|
||||
"""
|
||||
self.last_used_tool = calling
|
||||
) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
self.last_used_tool = calling # type: ignore # BUG?: Incompatible types in assignment (expression has type "Union[ToolCalling, InstructorToolCalling]", variable has type "ToolCalling")
|
||||
if self.cache and should_cache and calling.tool_name != CacheTools().name:
|
||||
self.cache.add(
|
||||
tool=calling.tool_name,
|
||||
input_data=calling.arguments,
|
||||
input=calling.arguments,
|
||||
output=output,
|
||||
)
|
||||
|
||||
@@ -1 +1,6 @@
|
||||
ALGORITHMS = ["RS256"]
|
||||
|
||||
#TODO: The AUTH0 constants should be removed after WorkOS migration is completed
|
||||
AUTH0_DOMAIN = "crewai.us.auth0.com"
|
||||
AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
|
||||
AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"
|
||||
|
||||
@@ -7,27 +7,24 @@ from rich.console import Console
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from .utils import validate_jwt_token
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from .utils import TokenManager, validate_jwt_token
|
||||
from urllib.parse import quote
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.authentication.constants import (
|
||||
AUTH0_AUDIENCE,
|
||||
AUTH0_CLIENT_ID,
|
||||
AUTH0_DOMAIN,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class Oauth2Settings(BaseModel):
|
||||
provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
|
||||
)
|
||||
client_id: str = Field(
|
||||
description="OAuth2 client ID issued by the provider, used during authentication requests."
|
||||
)
|
||||
domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
|
||||
)
|
||||
audience: Optional[str] = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=None,
|
||||
)
|
||||
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
|
||||
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
|
||||
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
|
||||
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls):
|
||||
@@ -47,15 +44,11 @@ class ProviderFactory:
|
||||
settings = settings or Oauth2Settings.from_settings()
|
||||
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(
|
||||
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
||||
)
|
||||
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
|
||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
||||
|
||||
return provider(settings)
|
||||
|
||||
|
||||
class AuthenticationCommand:
|
||||
def __init__(self):
|
||||
self.token_manager = TokenManager()
|
||||
@@ -65,12 +58,26 @@ class AuthenticationCommand:
|
||||
"""Sign up to CrewAI+"""
|
||||
console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")
|
||||
|
||||
# TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete.
|
||||
user_provider = self._determine_user_provider()
|
||||
if user_provider == "auth0":
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
client_id=AUTH0_CLIENT_ID,
|
||||
domain=AUTH0_DOMAIN,
|
||||
audience=AUTH0_AUDIENCE
|
||||
)
|
||||
self.oauth2_provider = ProviderFactory.from_settings(settings)
|
||||
# End of temporary code.
|
||||
|
||||
device_code_data = self._get_device_code()
|
||||
self._display_auth_instructions(device_code_data)
|
||||
|
||||
return self._poll_for_token(device_code_data)
|
||||
|
||||
def _get_device_code(self) -> Dict[str, Any]:
|
||||
def _get_device_code(
|
||||
self
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the device code to authenticate the user."""
|
||||
|
||||
device_code_payload = {
|
||||
@@ -79,9 +86,7 @@ class AuthenticationCommand:
|
||||
"audience": self.oauth2_provider.get_audience(),
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.oauth2_provider.get_authorize_url(),
|
||||
data=device_code_payload,
|
||||
timeout=20,
|
||||
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -92,7 +97,9 @@ class AuthenticationCommand:
|
||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||
|
||||
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
|
||||
def _poll_for_token(
|
||||
self, device_code_data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||
|
||||
token_payload = {
|
||||
@@ -105,9 +112,7 @@ class AuthenticationCommand:
|
||||
|
||||
attempts = 0
|
||||
while True and attempts < 10:
|
||||
response = requests.post(
|
||||
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
|
||||
)
|
||||
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
|
||||
token_data = response.json()
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -187,3 +192,30 @@ class AuthenticationCommand:
|
||||
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
|
||||
style="yellow",
|
||||
)
|
||||
|
||||
# TODO: WORKOS - This method is temporary until migration to WorkOS is complete.
|
||||
def _determine_user_provider(self) -> str:
|
||||
"""Determine which provider to use for authentication."""
|
||||
|
||||
console.print(
|
||||
"Enter your CrewAI Enterprise account email: ", style="bold blue", end=""
|
||||
)
|
||||
email = input()
|
||||
email_encoded = quote(email)
|
||||
|
||||
# It's not correct to call this method directly, but it's temporary until migration is complete.
|
||||
response = PlusAPI("")._make_request(
|
||||
"GET", f"/crewai_plus/api/v1/me/provider?email={email_encoded}"
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
if response.json().get("provider") == "auth0":
|
||||
return "auth0"
|
||||
else:
|
||||
return "workos"
|
||||
else:
|
||||
console.print(
|
||||
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
|
||||
style="red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from .utils import TokenManager
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
def validate_jwt_token(
|
||||
@@ -60,3 +67,118 @@ def validate_jwt_token(
|
||||
raise Exception(f"JWKS or key processing error: {str(e)}")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise Exception(f"Invalid token: {str(e)}")
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
Initialize the TokenManager class.
|
||||
|
||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key.
|
||||
|
||||
:return: The encryption key.
|
||||
"""
|
||||
key_filename = "secret.key"
|
||||
key = self.read_secure_file(key_filename)
|
||||
|
||||
if key is not None:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""
|
||||
Save the access token and its expiration time.
|
||||
|
||||
:param access_token: The access token to save.
|
||||
:param expires_at: The UNIX timestamp of the expiration time.
|
||||
"""
|
||||
expiration_time = datetime.fromtimestamp(expires_at)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
|
||||
:return: The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
encrypted_data = self.read_secure_file(self.file_path)
|
||||
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
||||
data = json.loads(decrypted_data)
|
||||
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return data["access_token"]
|
||||
|
||||
def get_secure_storage_path(self) -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
# Windows: Use %LOCALAPPDATA%
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
elif sys.platform == "darwin":
|
||||
# macOS: Use ~/Library/Application Support
|
||||
base_path = os.path.expanduser("~/Library/Application Support")
|
||||
else:
|
||||
# Linux and other Unix-like: Use ~/.local/share
|
||||
base_path = os.path.expanduser("~/.local/share")
|
||||
|
||||
app_name = "crewai/credentials"
|
||||
storage_path = Path(base_path) / app_name
|
||||
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return storage_path
|
||||
|
||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
Save the content to a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:param content: The content to save.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Set appropriate permissions (read/write for owner only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:return: The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
@@ -11,7 +11,6 @@ from crewai.cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
)
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
@@ -54,7 +53,6 @@ HIDDEN_SETTINGS_KEYS = [
|
||||
"tool_repository_password",
|
||||
]
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
enterprise_base_url: Optional[str] = Field(
|
||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||
@@ -76,12 +74,12 @@ class Settings(BaseModel):
|
||||
|
||||
oauth2_provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"]
|
||||
)
|
||||
|
||||
oauth2_audience: Optional[str] = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"]
|
||||
)
|
||||
|
||||
oauth2_client_id: str = Field(
|
||||
@@ -91,7 +89,7 @@ class Settings(BaseModel):
|
||||
|
||||
oauth2_domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"]
|
||||
)
|
||||
|
||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
||||
@@ -118,7 +116,6 @@ class Settings(BaseModel):
|
||||
"""Reset all settings to default values"""
|
||||
self._reset_user_settings()
|
||||
self._reset_cli_settings()
|
||||
self._clear_auth_tokens()
|
||||
self.dump()
|
||||
|
||||
def dump(self) -> None:
|
||||
@@ -142,7 +139,3 @@ class Settings(BaseModel):
|
||||
"""Reset all CLI settings to default values"""
|
||||
for key in CLI_SETTINGS_KEYS:
|
||||
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
|
||||
|
||||
def _clear_auth_tokens(self) -> None:
|
||||
"""Clear all authentication tokens"""
|
||||
TokenManager().clear_tokens()
|
||||
|
||||
@@ -117,19 +117,17 @@ class PlusAPI:
|
||||
def get_organizations(self) -> requests.Response:
|
||||
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
||||
|
||||
def send_trace_batch(self, payload) -> requests.Response:
|
||||
return self._make_request("POST", self.TRACING_RESOURCE, json=payload)
|
||||
|
||||
def initialize_trace_batch(self, payload) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.TRACING_RESOURCE}/batches",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
"POST", f"{self.TRACING_RESOURCE}/batches", json=payload
|
||||
)
|
||||
|
||||
def initialize_ephemeral_trace_batch(self, payload) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
|
||||
json=payload,
|
||||
"POST", f"{self.EPHEMERAL_TRACING_RESOURCE}/batches", json=payload
|
||||
)
|
||||
|
||||
def send_trace_events(self, trace_batch_id: str, payload) -> requests.Response:
|
||||
@@ -137,7 +135,6 @@ class PlusAPI:
|
||||
"POST",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def send_ephemeral_trace_events(
|
||||
@@ -147,7 +144,6 @@ class PlusAPI:
|
||||
"POST",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Response:
|
||||
@@ -155,7 +151,6 @@ class PlusAPI:
|
||||
"PATCH",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def finalize_ephemeral_trace_batch(
|
||||
@@ -165,5 +160,4 @@ class PlusAPI:
|
||||
"PATCH",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
@@ -10,9 +10,8 @@ console = Console()
|
||||
class SettingsCommand(BaseCommand):
|
||||
"""A class to handle CLI configuration commands."""
|
||||
|
||||
def __init__(self, settings_kwargs: dict[str, Any] | None = None):
|
||||
def __init__(self, settings_kwargs: dict[str, Any] = {}):
|
||||
super().__init__()
|
||||
settings_kwargs = settings_kwargs or {}
|
||||
self.settings = Settings(**settings_kwargs)
|
||||
|
||||
def list(self) -> None:
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
Initialize the TokenManager class.
|
||||
|
||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key.
|
||||
|
||||
:return: The encryption key.
|
||||
"""
|
||||
key_filename = "secret.key"
|
||||
key = self.read_secure_file(key_filename)
|
||||
|
||||
if key is not None:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""
|
||||
Save the access token and its expiration time.
|
||||
|
||||
:param access_token: The access token to save.
|
||||
:param expires_at: The UNIX timestamp of the expiration time.
|
||||
"""
|
||||
expiration_time = datetime.fromtimestamp(expires_at)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
|
||||
:return: The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
encrypted_data = self.read_secure_file(self.file_path)
|
||||
if encrypted_data is None:
|
||||
return None
|
||||
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
||||
data = json.loads(decrypted_data)
|
||||
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return data["access_token"]
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
"""
|
||||
Clear the tokens.
|
||||
"""
|
||||
self.delete_secure_file(self.file_path)
|
||||
|
||||
def get_secure_storage_path(self) -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
# Windows: Use %LOCALAPPDATA%
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
elif sys.platform == "darwin":
|
||||
# macOS: Use ~/Library/Application Support
|
||||
base_path = os.path.expanduser("~/Library/Application Support")
|
||||
else:
|
||||
# Linux and other Unix-like: Use ~/.local/share
|
||||
base_path = os.path.expanduser("~/.local/share")
|
||||
|
||||
app_name = "crewai/credentials"
|
||||
storage_path = Path(base_path) / app_name
|
||||
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return storage_path
|
||||
|
||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
Save the content to a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:param content: The content to save.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Set appropriate permissions (read/write for owner only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:return: The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def delete_secure_file(self, filename: str) -> None:
|
||||
"""
|
||||
Delete the secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
if file_path.exists():
|
||||
file_path.unlink(missing_ok=True)
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.177.0,<1.0.0"
|
||||
"crewai[tools]>=0.165.1,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.177.0,<1.0.0",
|
||||
"crewai[tools]>=0.165.1,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.177.0"
|
||||
"crewai[tools]>=0.165.1"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -3,18 +3,26 @@ import json
|
||||
import re
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable, Mapping, Set
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
@@ -26,36 +34,15 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
@@ -70,9 +57,29 @@ from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities.events.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.event_listener import EventListener
|
||||
from crewai.utilities.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
|
||||
|
||||
from crewai.utilities.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
)
|
||||
from crewai.utilities.formatter import (
|
||||
aggregate_raw_outputs_from_task_outputs,
|
||||
aggregate_raw_outputs_from_tasks,
|
||||
@@ -109,12 +116,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
planning: Plan the crew execution and add the plan to the crew.
|
||||
chat_llm: The language model used for orchestrating chat interactions with the crew.
|
||||
security_config: Security configuration for the crew, including fingerprinting.
|
||||
|
||||
Notes:
|
||||
TODO: Improve the embedder type from dict[str, Any] to a more specific TypedDict or dataclass.
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
_execution_span: Any = PrivateAttr()
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
@@ -126,7 +130,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
|
||||
_train: Optional[bool] = PrivateAttr(default=False)
|
||||
_train_iteration: Optional[int] = PrivateAttr()
|
||||
_inputs: Optional[dict[str, Any]] = PrivateAttr(default=None)
|
||||
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
|
||||
_logging_color: str = PrivateAttr(
|
||||
default="bold_purple",
|
||||
)
|
||||
@@ -136,8 +140,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
name: Optional[str] = Field(default="crew")
|
||||
cache: bool = Field(default=True)
|
||||
tasks: list[Task] = Field(default_factory=list)
|
||||
agents: list[BaseAgent] = Field(default_factory=list)
|
||||
tasks: List[Task] = Field(default_factory=list)
|
||||
agents: List[BaseAgent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool = Field(
|
||||
@@ -160,7 +164,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="An Instance of the ExternalMemory to be used by the Crew",
|
||||
)
|
||||
embedder: Optional[dict[str, Any]] = Field(
|
||||
embedder: Optional[dict] = Field(
|
||||
default=None,
|
||||
description="Configuration for the embedder to be used for the crew.",
|
||||
)
|
||||
@@ -168,16 +172,16 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: Optional[BaseAgent] = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[str | InstanceOf[LLM] | Any] = Field(
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Optional[Json[dict[str, Any]] | dict[str, Any]] = Field(default=None)
|
||||
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
share_crew: Optional[bool] = Field(default=False)
|
||||
step_callback: Optional[Any] = Field(
|
||||
@@ -188,13 +192,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Callback to be executed after each task for all agents execution.",
|
||||
)
|
||||
before_kickoff_callbacks: list[
|
||||
Callable[[Optional[dict[str, Any]]], Optional[dict[str, Any]]]
|
||||
before_kickoff_callbacks: List[
|
||||
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
|
||||
] = Field(
|
||||
default_factory=list,
|
||||
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
|
||||
)
|
||||
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
|
||||
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
|
||||
default_factory=list,
|
||||
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
|
||||
)
|
||||
@@ -206,7 +210,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Path to the prompt json file to be used for the crew.",
|
||||
)
|
||||
output_log_file: Optional[bool | str] = Field(
|
||||
output_log_file: Optional[Union[bool, str]] = Field(
|
||||
default=None,
|
||||
description="Path to the log file to be saved",
|
||||
)
|
||||
@@ -214,23 +218,23 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
default=None,
|
||||
description="Language model that will run the AgentPlanner if planning is True.",
|
||||
)
|
||||
task_execution_output_json_files: Optional[list[str]] = Field(
|
||||
task_execution_output_json_files: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="List of file paths for task execution JSON files.",
|
||||
)
|
||||
execution_logs: list[dict[str, Any]] = Field(
|
||||
execution_logs: List[Dict[str, Any]] = Field(
|
||||
default=[],
|
||||
description="List of execution logs for tasks",
|
||||
)
|
||||
knowledge_sources: Optional[list[BaseKnowledgeSource]] = Field(
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
default=None,
|
||||
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
||||
)
|
||||
chat_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -263,8 +267,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def check_config_type(
|
||||
cls, v: Json[dict[str, Any]] | dict[str, Any]
|
||||
) -> Json[dict[str, Any]] | dict[str, Any]:
|
||||
cls, v: Union[Json, Dict[str, Any]]
|
||||
) -> Union[Json, Dict[str, Any]]:
|
||||
"""Validates that the config is a valid type.
|
||||
Args:
|
||||
v: The config to be validated.
|
||||
@@ -273,10 +277,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
|
||||
# TODO: Improve typing
|
||||
return json.loads(v) if isinstance(v, str) else v
|
||||
return json.loads(v) if isinstance(v, Json) else v # type: ignore
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self) -> Self:
|
||||
def set_private_attrs(self) -> "Crew":
|
||||
"""Set private attributes."""
|
||||
|
||||
self._cache_handler = CacheHandler()
|
||||
@@ -296,7 +300,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
def _initialize_default_memories(self) -> None:
|
||||
def _initialize_default_memories(self):
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory()
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory(
|
||||
crew=self,
|
||||
@@ -307,7 +311,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def create_crew_memory(self) -> Self:
|
||||
def create_crew_memory(self) -> "Crew":
|
||||
"""Initialize private memory attributes."""
|
||||
self._external_memory = (
|
||||
# External memory doesn’t support a default value since it was designed to be managed entirely externally
|
||||
@@ -324,7 +328,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def create_crew_knowledge(self) -> Self:
|
||||
def create_crew_knowledge(self) -> "Crew":
|
||||
"""Create the knowledge for the crew."""
|
||||
if self.knowledge_sources:
|
||||
try:
|
||||
@@ -345,7 +349,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_manager_llm(self) -> Self:
|
||||
def check_manager_llm(self):
|
||||
"""Validates that the language model is set when using hierarchical process."""
|
||||
if self.process == Process.hierarchical:
|
||||
if not self.manager_llm and not self.manager_agent:
|
||||
@@ -367,7 +371,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_config(self) -> Self:
|
||||
def check_config(self):
|
||||
"""Validates that the crew is properly configured with agents and tasks."""
|
||||
if not self.config and not self.tasks and not self.agents:
|
||||
raise PydanticCustomError(
|
||||
@@ -388,20 +392,20 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tasks(self) -> Self:
|
||||
def validate_tasks(self):
|
||||
if self.process == Process.sequential:
|
||||
for task in self.tasks:
|
||||
if task.agent is None:
|
||||
raise PydanticCustomError(
|
||||
"missing_agent_in_task",
|
||||
f"Sequential process error: Agent is missing in the task with the following description: {task.description}",
|
||||
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
|
||||
{},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_end_with_at_most_one_async_task(self) -> Self:
|
||||
def validate_end_with_at_most_one_async_task(self):
|
||||
"""Validates that the crew ends with at most one asynchronous task."""
|
||||
final_async_task_count = 0
|
||||
|
||||
@@ -422,7 +426,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_must_have_non_conditional_task(self) -> Self:
|
||||
def validate_must_have_non_conditional_task(self) -> "Crew":
|
||||
"""Ensure that a crew has at least one non-conditional task."""
|
||||
if not self.tasks:
|
||||
return self
|
||||
@@ -438,7 +442,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_first_task(self) -> Self:
|
||||
def validate_first_task(self) -> "Crew":
|
||||
"""Ensure the first task is not a ConditionalTask."""
|
||||
if self.tasks and isinstance(self.tasks[0], ConditionalTask):
|
||||
raise PydanticCustomError(
|
||||
@@ -449,21 +453,19 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_async_tasks_not_async(self) -> Self:
|
||||
def validate_async_tasks_not_async(self) -> "Crew":
|
||||
"""Ensure that ConditionalTask is not async."""
|
||||
for task in self.tasks:
|
||||
if task.async_execution and isinstance(task, ConditionalTask):
|
||||
raise PydanticCustomError(
|
||||
"invalid_async_conditional_task",
|
||||
f"Conditional Task: {task.description} , cannot be executed asynchronously.",
|
||||
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
|
||||
{},
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_async_task_cannot_include_sequential_async_tasks_in_context(
|
||||
self,
|
||||
) -> Self:
|
||||
def validate_async_task_cannot_include_sequential_async_tasks_in_context(self):
|
||||
"""
|
||||
Validates that if a task is set to be executed asynchronously,
|
||||
it cannot include other asynchronous tasks in its context unless
|
||||
@@ -483,7 +485,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_context_no_future_tasks(self) -> Self:
|
||||
def validate_context_no_future_tasks(self):
|
||||
"""Validates that a task's context does not include future tasks."""
|
||||
task_indices = {id(task): i for i, task in enumerate(self.tasks)}
|
||||
|
||||
@@ -500,7 +502,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
source: list[str] = [agent.key for agent in self.agents] + [
|
||||
source: List[str] = [agent.key for agent in self.agents] + [
|
||||
task.key for task in self.tasks
|
||||
]
|
||||
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
|
||||
@@ -515,7 +517,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
return self.security_config.fingerprint
|
||||
|
||||
def _setup_from_config(self) -> None:
|
||||
def _setup_from_config(self):
|
||||
assert self.config is not None, "Config should not be None."
|
||||
|
||||
"""Initializes agents and tasks from the provided config."""
|
||||
@@ -528,7 +530,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self.agents = [Agent(**agent) for agent in self.config["agents"]]
|
||||
self.tasks = [self._create_task(task) for task in self.config["tasks"]]
|
||||
|
||||
def _create_task(self, task_config: dict[str, Any]) -> Task:
|
||||
def _create_task(self, task_config: Dict[str, Any]) -> Task:
|
||||
"""Creates a task instance from its configuration.
|
||||
|
||||
Args:
|
||||
@@ -557,10 +559,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
CrewTrainingHandler(filename).initialize_file()
|
||||
|
||||
def train(
|
||||
self, n_iterations: int, filename: str, inputs: Optional[dict[str, Any]] = None
|
||||
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
|
||||
) -> None:
|
||||
"""Trains the crew for a given number of iterations."""
|
||||
inputs = inputs or {}
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -609,7 +610,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> CrewOutput:
|
||||
ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
@@ -641,7 +642,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
for agent in self.agents:
|
||||
agent.i18n = i18n
|
||||
agent.crew = self
|
||||
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
|
||||
agent.crew = self # type: ignore[attr-defined]
|
||||
agent.set_knowledge(crew_embedder=self.embedder)
|
||||
# TODO: Create an AgentFunctionCalling protocol for future refactoring
|
||||
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
@@ -650,7 +652,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
|
||||
# Agent executor will be created when tasks are executed
|
||||
agent.create_agent_executor()
|
||||
|
||||
if self.planning:
|
||||
self._handle_crew_planning()
|
||||
@@ -679,9 +681,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
|
||||
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
|
||||
"""Executes the Crew's workflow for each input in the list and aggregates results."""
|
||||
results: list[CrewOutput] = []
|
||||
results: List[CrewOutput] = []
|
||||
|
||||
# Initialize the parent crew's usage metrics
|
||||
total_usage_metrics = UsageMetrics()
|
||||
@@ -700,19 +702,14 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._task_output_handler.reset()
|
||||
return results
|
||||
|
||||
async def kickoff_async(
|
||||
self, inputs: Optional[dict[str, Any]] = None
|
||||
) -> CrewOutput:
|
||||
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = {}) -> CrewOutput:
|
||||
"""Asynchronous kickoff method to start the crew execution."""
|
||||
inputs = inputs or {}
|
||||
return await asyncio.to_thread(self.kickoff, inputs)
|
||||
|
||||
async def kickoff_for_each_async(
|
||||
self, inputs: list[dict[str, Any]]
|
||||
) -> list[CrewOutput]:
|
||||
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
|
||||
crew_copies = [self.copy() for _ in inputs]
|
||||
|
||||
async def run_crew(crew: Self, input_data: dict[str, Any]) -> CrewOutput:
|
||||
async def run_crew(crew, input_data):
|
||||
return await crew.kickoff_async(inputs=input_data)
|
||||
|
||||
tasks = [
|
||||
@@ -731,7 +728,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._task_output_handler.reset()
|
||||
return results
|
||||
|
||||
def _handle_crew_planning(self) -> None:
|
||||
def _handle_crew_planning(self):
|
||||
"""Handles the Crew planning."""
|
||||
self._logger.log("info", "Planning the crew execution")
|
||||
result = CrewPlanner(
|
||||
@@ -747,7 +744,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
output: TaskOutput,
|
||||
task_index: int,
|
||||
was_replayed: bool = False,
|
||||
) -> None:
|
||||
):
|
||||
if self._inputs:
|
||||
inputs = self._inputs
|
||||
else:
|
||||
@@ -779,7 +776,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._create_manager_agent()
|
||||
return self._execute_tasks(self.tasks)
|
||||
|
||||
def _create_manager_agent(self) -> None:
|
||||
def _create_manager_agent(self):
|
||||
i18n = I18N(prompt_file=self.prompt_file)
|
||||
if self.manager_agent is not None:
|
||||
self.manager_agent.allow_delegation = True
|
||||
@@ -791,12 +788,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
manager.tools = []
|
||||
raise Exception("Manager agent should not have tools")
|
||||
else:
|
||||
if self.manager_llm is None:
|
||||
from crewai.utilities.llm_utils import create_default_llm
|
||||
|
||||
self.manager_llm = create_default_llm()
|
||||
else:
|
||||
self.manager_llm = create_llm(self.manager_llm)
|
||||
self.manager_llm = create_llm(self.manager_llm)
|
||||
manager = Agent(
|
||||
role=i18n.retrieve("hierarchical_manager_agent", "role"),
|
||||
goal=i18n.retrieve("hierarchical_manager_agent", "goal"),
|
||||
@@ -811,7 +803,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _execute_tasks(
|
||||
self,
|
||||
tasks: list[Task],
|
||||
tasks: List[Task],
|
||||
start_index: Optional[int] = 0,
|
||||
was_replayed: bool = False,
|
||||
) -> CrewOutput:
|
||||
@@ -825,8 +817,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
CrewOutput: Final output of the crew
|
||||
"""
|
||||
|
||||
task_outputs: list[TaskOutput] = []
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]] = []
|
||||
task_outputs: List[TaskOutput] = []
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
|
||||
last_sync_output: Optional[TaskOutput] = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
@@ -851,7 +843,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
cast(list[Tool] | list[BaseTool], tools_for_task),
|
||||
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
|
||||
)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
@@ -871,7 +863,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
@@ -883,7 +875,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
@@ -897,8 +889,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def _handle_conditional_task(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: list[TaskOutput],
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]],
|
||||
task_outputs: List[TaskOutput],
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> Optional[TaskOutput]:
|
||||
@@ -921,8 +913,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return None
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
# Add delegation tools if agent allows delegation
|
||||
if hasattr(agent, "allow_delegation") and getattr(
|
||||
agent, "allow_delegation", False
|
||||
@@ -952,7 +944,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._add_multimodal_tools(agent, tools)
|
||||
|
||||
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
|
||||
return cast(list[BaseTool], tools)
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
|
||||
if self.process == Process.hierarchical:
|
||||
@@ -961,12 +953,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _merge_tools(
|
||||
self,
|
||||
existing_tools: list[Tool] | list[BaseTool],
|
||||
new_tools: list[Tool] | list[BaseTool],
|
||||
) -> list[BaseTool]:
|
||||
existing_tools: Union[List[Tool], List[BaseTool]],
|
||||
new_tools: Union[List[Tool], List[BaseTool]],
|
||||
) -> List[BaseTool]:
|
||||
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
||||
if not new_tools:
|
||||
return cast(list[BaseTool], existing_tools)
|
||||
return cast(List[BaseTool], existing_tools)
|
||||
|
||||
# Create mapping of tool names to new tools
|
||||
new_tool_map = {tool.name: tool for tool in new_tools}
|
||||
@@ -977,41 +969,41 @@ class Crew(FlowTrackable, BaseModel):
|
||||
# Add all new tools
|
||||
tools.extend(new_tools)
|
||||
|
||||
return tools
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _inject_delegation_tools(
|
||||
self,
|
||||
tools: list[Tool] | list[BaseTool],
|
||||
tools: Union[List[Tool], List[BaseTool]],
|
||||
task_agent: BaseAgent,
|
||||
agents: list[BaseAgent],
|
||||
) -> list[BaseTool]:
|
||||
agents: List[BaseAgent],
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
# Cast delegation_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, delegation_tools)
|
||||
return cast(list[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_multimodal_tools(
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_multimodal_tools"):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
# Cast multimodal_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_code_execution_tools(
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_code_execution_tools"):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
# Cast code_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_delegation_tools(
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
||||
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
||||
if not tools:
|
||||
@@ -1019,17 +1011,17 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, task.agent, agents_for_delegation
|
||||
)
|
||||
return cast(list[BaseTool], tools)
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _log_task_start(self, task: Task, role: str = "None") -> None:
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
if self.output_log_file:
|
||||
self._file_handler.log(
|
||||
task_name=task.name, task=task.description, agent=role, status="started"
|
||||
)
|
||||
|
||||
def _update_manager_tools(
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
if self.manager_agent:
|
||||
if task.agent:
|
||||
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
|
||||
@@ -1037,9 +1029,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, self.manager_agent, self.agents
|
||||
)
|
||||
return cast(list[BaseTool], tools)
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
|
||||
if not task.context:
|
||||
return ""
|
||||
|
||||
@@ -1061,7 +1053,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
output=output.raw,
|
||||
)
|
||||
|
||||
def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
|
||||
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
|
||||
if not task_outputs:
|
||||
raise ValueError("No task outputs available to create crew output.")
|
||||
|
||||
@@ -1092,10 +1084,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _process_async_tasks(
|
||||
self,
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]],
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
was_replayed: bool = False,
|
||||
) -> list[TaskOutput]:
|
||||
task_outputs: list[TaskOutput] = []
|
||||
) -> List[TaskOutput]:
|
||||
task_outputs: List[TaskOutput] = []
|
||||
for future_task, future, task_index in futures:
|
||||
task_output = future.result()
|
||||
task_outputs.append(task_output)
|
||||
@@ -1106,7 +1098,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return task_outputs
|
||||
|
||||
def _find_task_index(
|
||||
self, task_id: str, stored_outputs: list[Any]
|
||||
self, task_id: str, stored_outputs: List[Any]
|
||||
) -> Optional[int]:
|
||||
return next(
|
||||
(
|
||||
@@ -1118,7 +1110,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
def replay(
|
||||
self, task_id: str, inputs: Optional[dict[str, Any]] = None
|
||||
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
|
||||
) -> CrewOutput:
|
||||
stored_outputs = self._task_output_handler.load()
|
||||
if not stored_outputs:
|
||||
@@ -1159,15 +1151,15 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return result
|
||||
|
||||
def query_knowledge(
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[dict[str, Any]] | None:
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> Union[List[Dict[str, Any]], None]:
|
||||
if self.knowledge:
|
||||
return self.knowledge.query(
|
||||
query, results_limit=results_limit, score_threshold=score_threshold
|
||||
)
|
||||
return None
|
||||
|
||||
def fetch_inputs(self) -> set[str]:
|
||||
def fetch_inputs(self) -> Set[str]:
|
||||
"""
|
||||
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
||||
Scans each task's 'description' + 'expected_output', and each agent's
|
||||
@@ -1176,7 +1168,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
Returns a set of all discovered placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
required_inputs: set[str] = set()
|
||||
required_inputs: Set[str] = set()
|
||||
|
||||
# Scan tasks for inputs
|
||||
for task in self.tasks:
|
||||
@@ -1192,18 +1184,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return required_inputs
|
||||
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
include: Optional[
|
||||
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
|
||||
] = None,
|
||||
exclude: Optional[
|
||||
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
|
||||
] = None,
|
||||
update: Optional[dict[str, Any]] = None,
|
||||
deep: bool = True,
|
||||
) -> "Crew":
|
||||
def copy(self):
|
||||
"""
|
||||
Creates a deep copy of the Crew instance.
|
||||
|
||||
@@ -1234,7 +1215,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
manager_agent = self.manager_agent.copy() if self.manager_agent else None
|
||||
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None
|
||||
|
||||
task_mapping: dict[str, Task] = {}
|
||||
task_mapping = {}
|
||||
|
||||
cloned_tasks = []
|
||||
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
|
||||
@@ -1289,10 +1270,16 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if not task.callback:
|
||||
task.callback = self.task_callback
|
||||
|
||||
def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Interpolates the inputs in the tasks and agents."""
|
||||
for task in self.tasks:
|
||||
task.interpolate_inputs_and_add_conversation_history(inputs)
|
||||
[
|
||||
task.interpolate_inputs_and_add_conversation_history(
|
||||
# type: ignore # "interpolate_inputs" of "Task" does not return a value (it only ever returns None)
|
||||
inputs
|
||||
)
|
||||
for task in self.tasks
|
||||
]
|
||||
# type: ignore # "interpolate_inputs" of "Agent" does not return a value (it only ever returns None)
|
||||
for agent in self.agents:
|
||||
agent.interpolate_inputs(inputs)
|
||||
|
||||
@@ -1316,8 +1303,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
eval_llm: str | InstanceOf[BaseLLM],
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
eval_llm: Union[str, InstanceOf[BaseLLM]],
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
try:
|
||||
@@ -1358,7 +1345,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self):
|
||||
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
|
||||
|
||||
def reset_memories(self, command_type: str) -> None:
|
||||
@@ -1410,9 +1397,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if (system := config.get("system")) is not None:
|
||||
name = config.get("name")
|
||||
try:
|
||||
reset_fn: Callable[..., None] = cast(
|
||||
Callable[..., None], config.get("reset")
|
||||
)
|
||||
reset_fn: Callable = cast(Callable, config.get("reset"))
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
@@ -1441,9 +1426,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise RuntimeError(f"{name} memory system is not initialized")
|
||||
|
||||
try:
|
||||
reset_fn: Callable[..., None] = cast(
|
||||
Callable[..., None], config.get("reset")
|
||||
)
|
||||
reset_fn: Callable = cast(Callable, config.get("reset"))
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
@@ -1454,18 +1437,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _get_memory_systems(self) -> dict[str, dict[str, Any]]:
|
||||
def _get_memory_systems(self):
|
||||
"""Get all available memory systems with their configuration.
|
||||
|
||||
Returns:
|
||||
Dict containing all memory systems with their reset functions and display names.
|
||||
"""
|
||||
|
||||
def default_reset(memory: Any) -> None:
|
||||
memory.reset()
|
||||
def default_reset(memory):
|
||||
return memory.reset()
|
||||
|
||||
def knowledge_reset(memory: Any) -> None:
|
||||
self.reset_knowledge(memory)
|
||||
def knowledge_reset(memory):
|
||||
return self.reset_knowledge(memory)
|
||||
|
||||
# Get knowledge for agents
|
||||
agent_knowledges = [
|
||||
@@ -1519,12 +1502,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
},
|
||||
}
|
||||
|
||||
def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
|
||||
def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
|
||||
"""Reset crew and agent knowledge storage."""
|
||||
for ks in knowledges:
|
||||
ks.reset()
|
||||
|
||||
def _set_allow_crewai_trigger_context_for_first_task(self) -> None:
|
||||
def _set_allow_crewai_trigger_context_for_first_task(self):
|
||||
crewai_trigger_payload = self._inputs and self._inputs.get(
|
||||
"crewai_trigger_payload"
|
||||
)
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
"""CrewAI events system for monitoring and extending agent behavior.
|
||||
|
||||
This module provides the event infrastructure that allows users to:
|
||||
- Monitor agent, task, and crew execution
|
||||
- Track memory operations and performance
|
||||
- Build custom logging and analytics
|
||||
- Extend CrewAI with custom event handlers
|
||||
"""
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
)
|
||||
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
)
|
||||
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffStartedEvent,
|
||||
CrewKickoffCompletedEvent,
|
||||
)
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
)
|
||||
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseEventListener",
|
||||
"crewai_event_bus",
|
||||
"MemoryQueryCompletedEvent",
|
||||
"MemorySaveCompletedEvent",
|
||||
"MemorySaveStartedEvent",
|
||||
"MemoryQueryStartedEvent",
|
||||
"MemoryRetrievalCompletedEvent",
|
||||
"MemorySaveFailedEvent",
|
||||
"MemoryQueryFailedEvent",
|
||||
"KnowledgeRetrievalStartedEvent",
|
||||
"KnowledgeRetrievalCompletedEvent",
|
||||
"CrewKickoffStartedEvent",
|
||||
"CrewKickoffCompletedEvent",
|
||||
"AgentExecutionCompletedEvent",
|
||||
"LLMStreamChunkEvent",
|
||||
]
|
||||
@@ -1,15 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
|
||||
|
||||
|
||||
class BaseEventListener(ABC):
|
||||
verbose: bool = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.setup_listeners(crewai_event_bus)
|
||||
|
||||
@abstractmethod
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
pass
|
||||
@@ -1,115 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from blinker import Signal
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
EventT = TypeVar("EventT", bound=BaseEvent)
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class CrewAIEventsBus:
|
||||
"""
|
||||
A singleton event bus that uses blinker signals for event handling.
|
||||
Allows both internal (Flow/Crew) and external event handling.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> Self:
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None: # prevent race condition
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialize()
|
||||
return cls._instance
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""Initialize the event bus internal state"""
|
||||
self._signal = Signal("crewai_event_bus")
|
||||
self._handlers: dict[type[BaseEvent], list[Callable[[Any, Any], None]]] = {}
|
||||
|
||||
def on(
|
||||
self, event_type: type[EventT]
|
||||
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
|
||||
"""
|
||||
Decorator to register an event handler for a specific event type.
|
||||
|
||||
Usage:
|
||||
@crewai_event_bus.on(AgentExecutionCompletedEvent)
|
||||
def on_agent_execution_completed(
|
||||
source: Any, event: AgentExecutionCompletedEvent
|
||||
):
|
||||
print(f"👍 Agent '{event.agent}' completed task")
|
||||
print(f" Output: {event.output}")
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
handler: Callable[[Any, EventT], None],
|
||||
) -> Callable[[Any, EventT], None]:
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
self._handlers[event_type].append(cast(Callable[[Any, Any], None], handler))
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def emit(self, source: Any, event: BaseEvent) -> None:
|
||||
"""
|
||||
Emit an event to all registered handlers
|
||||
|
||||
Args:
|
||||
source: The object emitting the event
|
||||
event: The event instance to emit
|
||||
"""
|
||||
for event_type, handlers in self._handlers.items():
|
||||
if isinstance(event, event_type):
|
||||
for handler in handlers:
|
||||
try:
|
||||
handler(source, event)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}"
|
||||
)
|
||||
|
||||
self._signal.send(source, event=event)
|
||||
|
||||
def register_handler(
|
||||
self, event_type: type[BaseEvent], handler: Callable[[Any, Any], None]
|
||||
) -> None:
|
||||
"""Register an event handler for a specific event type"""
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
self._handlers[event_type].append(handler)
|
||||
|
||||
@contextmanager
|
||||
def scoped_handlers(self) -> Iterator[None]:
|
||||
"""
|
||||
Context manager for temporary event handling scope.
|
||||
Useful for testing or temporary event handling.
|
||||
|
||||
Usage:
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(CrewKickoffStarted)
|
||||
def temp_handler(source, event):
|
||||
print("Temporary handler")
|
||||
# Do stuff...
|
||||
# Handlers are cleared after the context
|
||||
"""
|
||||
previous_handlers = self._handlers.copy()
|
||||
self._handlers.clear()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._handlers = previous_handlers
|
||||
|
||||
|
||||
# Global instance
|
||||
crewai_event_bus = CrewAIEventsBus()
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Event listener implementations for CrewAI.
|
||||
|
||||
This module contains various event listener implementations
|
||||
for handling memory, tracing, and other event-driven functionality.
|
||||
"""
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Event type definitions for CrewAI.
|
||||
|
||||
This module contains all event types used throughout the CrewAI system
|
||||
for monitoring and extending agent, crew, task, and tool execution.
|
||||
"""
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Agent logging events that don't reference BaseAgent to avoid circular imports."""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class AgentLogsStartedEvent(BaseEvent):
|
||||
"""Event emitted when agent logs should be shown at start"""
|
||||
|
||||
agent_role: str
|
||||
task_description: Optional[str] = None
|
||||
verbose: bool = False
|
||||
type: str = "agent_logs_started"
|
||||
|
||||
|
||||
class AgentLogsExecutionEvent(BaseEvent):
|
||||
"""Event emitted when agent logs should be shown during execution"""
|
||||
|
||||
agent_role: str
|
||||
formatted_answer: Any
|
||||
verbose: bool = False
|
||||
type: str = "agent_logs_execution"
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
@@ -1,47 +0,0 @@
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class ReasoningEvent(BaseEvent):
|
||||
"""Base event for reasoning events."""
|
||||
|
||||
type: str
|
||||
attempt: int = 1
|
||||
agent_role: str
|
||||
task_id: str
|
||||
task_name: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
agent_id: Optional[str] = None
|
||||
from_agent: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_task_params(data)
|
||||
self._set_agent_params(data)
|
||||
|
||||
|
||||
class AgentReasoningStartedEvent(ReasoningEvent):
|
||||
"""Event emitted when an agent starts reasoning about a task."""
|
||||
|
||||
type: str = "agent_reasoning_started"
|
||||
agent_role: str
|
||||
task_id: str
|
||||
|
||||
|
||||
class AgentReasoningCompletedEvent(ReasoningEvent):
|
||||
"""Event emitted when an agent finishes its reasoning process."""
|
||||
|
||||
type: str = "agent_reasoning_completed"
|
||||
agent_role: str
|
||||
task_id: str
|
||||
plan: str
|
||||
ready: bool
|
||||
|
||||
|
||||
class AgentReasoningFailedEvent(ReasoningEvent):
|
||||
"""Event emitted when the reasoning process fails."""
|
||||
|
||||
type: str = "agent_reasoning_failed"
|
||||
agent_role: str
|
||||
task_id: str
|
||||
error: str
|
||||
@@ -1,42 +1,28 @@
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentEvaluationResult,
|
||||
AggregationStrategy,
|
||||
)
|
||||
from crewai.experimental.evaluation.base_evaluator import AgentEvaluationResult, AggregationStrategy
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentEvaluationStartedEvent,
|
||||
AgentEvaluationCompletedEvent,
|
||||
AgentEvaluationFailedEvent,
|
||||
)
|
||||
from crewai.utilities.events.agent_events import AgentEvaluationStartedEvent, AgentEvaluationCompletedEvent, AgentEvaluationFailedEvent
|
||||
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
|
||||
from collections.abc import Sequence
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.events.types.task_events import TaskCompletedEvent
|
||||
from crewai.events.types.agent_events import LiteAgentExecutionCompletedEvent
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.utilities.events.task_events import TaskCompletedEvent
|
||||
from crewai.utilities.events.agent_events import LiteAgentExecutionCompletedEvent
|
||||
from crewai.experimental.evaluation.base_evaluator import AgentAggregatedEvaluationResult, EvaluationScore, MetricCategory
|
||||
|
||||
class ExecutionState:
|
||||
current_agent_id: Optional[str] = None
|
||||
current_task_id: Optional[str] = None
|
||||
|
||||
def __init__(self):
|
||||
self.traces = {}
|
||||
self.current_agent_id: str | None = None
|
||||
self.current_task_id: str | None = None
|
||||
self.iteration = 1
|
||||
self.iterations_results = {}
|
||||
self.agent_evaluators = {}
|
||||
|
||||
|
||||
class AgentEvaluator:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -59,45 +45,27 @@ class AgentEvaluator:
|
||||
|
||||
@property
|
||||
def _execution_state(self) -> ExecutionState:
|
||||
if not hasattr(self._thread_local, "execution_state"):
|
||||
if not hasattr(self._thread_local, 'execution_state'):
|
||||
self._thread_local.execution_state = ExecutionState()
|
||||
return self._thread_local.execution_state
|
||||
|
||||
def _subscribe_to_events(self) -> None:
|
||||
from typing import cast
|
||||
|
||||
crewai_event_bus.register_handler(
|
||||
TaskCompletedEvent, cast(Any, self._handle_task_completed)
|
||||
)
|
||||
crewai_event_bus.register_handler(
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
cast(Any, self._handle_lite_agent_completed),
|
||||
)
|
||||
crewai_event_bus.register_handler(TaskCompletedEvent, cast(Any, self._handle_task_completed))
|
||||
crewai_event_bus.register_handler(LiteAgentExecutionCompletedEvent, cast(Any, self._handle_lite_agent_completed))
|
||||
|
||||
def _handle_task_completed(self, source: Any, event: TaskCompletedEvent) -> None:
|
||||
assert event.task is not None
|
||||
agent = event.task.agent
|
||||
if (
|
||||
agent
|
||||
and str(getattr(agent, "id", "unknown"))
|
||||
in self._execution_state.agent_evaluators
|
||||
):
|
||||
self.emit_evaluation_started_event(
|
||||
agent_role=agent.role,
|
||||
agent_id=str(agent.id),
|
||||
task_id=str(event.task.id),
|
||||
)
|
||||
if agent and str(getattr(agent, 'id', 'unknown')) in self._execution_state.agent_evaluators:
|
||||
self.emit_evaluation_started_event(agent_role=agent.role, agent_id=str(agent.id), task_id=str(event.task.id))
|
||||
|
||||
state = ExecutionState()
|
||||
state.current_agent_id = str(agent.id)
|
||||
state.current_task_id = str(event.task.id)
|
||||
|
||||
assert (
|
||||
state.current_agent_id is not None and state.current_task_id is not None
|
||||
)
|
||||
trace = self.callback.get_trace(
|
||||
state.current_agent_id, state.current_task_id
|
||||
)
|
||||
assert state.current_agent_id is not None and state.current_task_id is not None
|
||||
trace = self.callback.get_trace(state.current_agent_id, state.current_task_id)
|
||||
|
||||
if not trace:
|
||||
return
|
||||
@@ -107,28 +75,19 @@ class AgentEvaluator:
|
||||
task=event.task,
|
||||
execution_trace=trace,
|
||||
final_output=event.output,
|
||||
state=state,
|
||||
state=state
|
||||
)
|
||||
|
||||
current_iteration = self._execution_state.iteration
|
||||
if current_iteration not in self._execution_state.iterations_results:
|
||||
self._execution_state.iterations_results[current_iteration] = {}
|
||||
|
||||
if (
|
||||
agent.role
|
||||
not in self._execution_state.iterations_results[current_iteration]
|
||||
):
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent.role
|
||||
] = []
|
||||
if agent.role not in self._execution_state.iterations_results[current_iteration]:
|
||||
self._execution_state.iterations_results[current_iteration][agent.role] = []
|
||||
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent.role
|
||||
].append(result)
|
||||
self._execution_state.iterations_results[current_iteration][agent.role].append(result)
|
||||
|
||||
def _handle_lite_agent_completed(
|
||||
self, source: object, event: LiteAgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
def _handle_lite_agent_completed(self, source: object, event: LiteAgentExecutionCompletedEvent) -> None:
|
||||
agent_info = event.agent_info
|
||||
agent_id = str(agent_info["id"])
|
||||
|
||||
@@ -146,12 +105,8 @@ class AgentEvaluator:
|
||||
if not target_agent:
|
||||
return
|
||||
|
||||
assert (
|
||||
state.current_agent_id is not None and state.current_task_id is not None
|
||||
)
|
||||
trace = self.callback.get_trace(
|
||||
state.current_agent_id, state.current_task_id
|
||||
)
|
||||
assert state.current_agent_id is not None and state.current_task_id is not None
|
||||
trace = self.callback.get_trace(state.current_agent_id, state.current_task_id)
|
||||
|
||||
if not trace:
|
||||
return
|
||||
@@ -160,7 +115,7 @@ class AgentEvaluator:
|
||||
agent=target_agent,
|
||||
execution_trace=trace,
|
||||
final_output=event.output,
|
||||
state=state,
|
||||
state=state
|
||||
)
|
||||
|
||||
current_iteration = self._execution_state.iteration
|
||||
@@ -168,17 +123,10 @@ class AgentEvaluator:
|
||||
self._execution_state.iterations_results[current_iteration] = {}
|
||||
|
||||
agent_role = target_agent.role
|
||||
if (
|
||||
agent_role
|
||||
not in self._execution_state.iterations_results[current_iteration]
|
||||
):
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent_role
|
||||
] = []
|
||||
if agent_role not in self._execution_state.iterations_results[current_iteration]:
|
||||
self._execution_state.iterations_results[current_iteration][agent_role] = []
|
||||
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent_role
|
||||
].append(result)
|
||||
self._execution_state.iterations_results[current_iteration][agent_role].append(result)
|
||||
|
||||
def set_iteration(self, iteration: int) -> None:
|
||||
self._execution_state.iteration = iteration
|
||||
@@ -187,26 +135,14 @@ class AgentEvaluator:
|
||||
self._execution_state.iterations_results = {}
|
||||
|
||||
def get_evaluation_results(self) -> dict[str, list[AgentEvaluationResult]]:
|
||||
if (
|
||||
self._execution_state.iterations_results
|
||||
and self._execution_state.iteration
|
||||
in self._execution_state.iterations_results
|
||||
):
|
||||
return self._execution_state.iterations_results[
|
||||
self._execution_state.iteration
|
||||
]
|
||||
if self._execution_state.iterations_results and self._execution_state.iteration in self._execution_state.iterations_results:
|
||||
return self._execution_state.iterations_results[self._execution_state.iteration]
|
||||
return {}
|
||||
|
||||
def display_results_with_iterations(self) -> None:
|
||||
self.display_formatter.display_summary_results(
|
||||
self._execution_state.iterations_results
|
||||
)
|
||||
self.display_formatter.display_summary_results(self._execution_state.iterations_results)
|
||||
|
||||
def get_agent_evaluation(
|
||||
self,
|
||||
strategy: AggregationStrategy = AggregationStrategy.SIMPLE_AVERAGE,
|
||||
include_evaluation_feedback: bool = True,
|
||||
) -> dict[str, AgentAggregatedEvaluationResult]:
|
||||
def get_agent_evaluation(self, strategy: AggregationStrategy = AggregationStrategy.SIMPLE_AVERAGE, include_evaluation_feedback: bool = True) -> dict[str, AgentAggregatedEvaluationResult]:
|
||||
agent_results = {}
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
task_results = self.get_evaluation_results()
|
||||
@@ -220,16 +156,13 @@ class AgentEvaluator:
|
||||
agent_id=agent_id,
|
||||
agent_role=agent_role,
|
||||
results=results,
|
||||
strategy=strategy,
|
||||
strategy=strategy
|
||||
)
|
||||
|
||||
agent_results[agent_role] = aggregated_result
|
||||
|
||||
if (
|
||||
self._execution_state.iterations_results
|
||||
and self._execution_state.iteration
|
||||
== max(self._execution_state.iterations_results.keys(), default=0)
|
||||
):
|
||||
|
||||
if self._execution_state.iterations_results and self._execution_state.iteration == max(self._execution_state.iterations_results.keys(), default=0):
|
||||
self.display_results_with_iterations()
|
||||
|
||||
if include_evaluation_feedback:
|
||||
@@ -238,9 +171,7 @@ class AgentEvaluator:
|
||||
return agent_results
|
||||
|
||||
def display_evaluation_with_feedback(self) -> None:
|
||||
self.display_formatter.display_evaluation_with_feedback(
|
||||
self._execution_state.iterations_results
|
||||
)
|
||||
self.display_formatter.display_evaluation_with_feedback(self._execution_state.iterations_results)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
@@ -252,91 +183,46 @@ class AgentEvaluator:
|
||||
) -> AgentEvaluationResult:
|
||||
result = AgentEvaluationResult(
|
||||
agent_id=state.current_agent_id or str(agent.id),
|
||||
task_id=state.current_task_id or (str(task.id) if task else "unknown_task"),
|
||||
task_id=state.current_task_id or (str(task.id) if task else "unknown_task")
|
||||
)
|
||||
|
||||
assert self.evaluators is not None
|
||||
task_id = str(task.id) if task else None
|
||||
for evaluator in self.evaluators:
|
||||
try:
|
||||
self.emit_evaluation_started_event(
|
||||
agent_role=agent.role, agent_id=str(agent.id), task_id=task_id
|
||||
)
|
||||
self.emit_evaluation_started_event(agent_role=agent.role, agent_id=str(agent.id), task_id=task_id)
|
||||
score = evaluator.evaluate(
|
||||
agent=agent,
|
||||
task=task,
|
||||
execution_trace=execution_trace,
|
||||
final_output=final_output,
|
||||
final_output=final_output
|
||||
)
|
||||
result.metrics[evaluator.metric_category] = score
|
||||
self.emit_evaluation_completed_event(
|
||||
agent_role=agent.role,
|
||||
agent_id=str(agent.id),
|
||||
task_id=task_id,
|
||||
metric_category=evaluator.metric_category,
|
||||
score=score,
|
||||
)
|
||||
self.emit_evaluation_completed_event(agent_role=agent.role, agent_id=str(agent.id), task_id=task_id, metric_category=evaluator.metric_category, score=score)
|
||||
except Exception as e:
|
||||
self.emit_evaluation_failed_event(
|
||||
agent_role=agent.role,
|
||||
agent_id=str(agent.id),
|
||||
task_id=task_id,
|
||||
error=str(e),
|
||||
)
|
||||
self.console_formatter.print(
|
||||
f"Error in {evaluator.metric_category.value} evaluator: {str(e)}"
|
||||
)
|
||||
self.emit_evaluation_failed_event(agent_role=agent.role, agent_id=str(agent.id), task_id=task_id, error=str(e))
|
||||
self.console_formatter.print(f"Error in {evaluator.metric_category.value} evaluator: {str(e)}")
|
||||
|
||||
return result
|
||||
|
||||
def emit_evaluation_started_event(
|
||||
self, agent_role: str, agent_id: str, task_id: str | None = None
|
||||
):
|
||||
def emit_evaluation_started_event(self, agent_role: str, agent_id: str, task_id: str | None = None):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
AgentEvaluationStartedEvent(
|
||||
agent_role=agent_role,
|
||||
agent_id=agent_id,
|
||||
task_id=task_id,
|
||||
iteration=self._execution_state.iteration,
|
||||
),
|
||||
AgentEvaluationStartedEvent(agent_role=agent_role, agent_id=agent_id, task_id=task_id, iteration=self._execution_state.iteration)
|
||||
)
|
||||
|
||||
def emit_evaluation_completed_event(
|
||||
self,
|
||||
agent_role: str,
|
||||
agent_id: str,
|
||||
task_id: str | None = None,
|
||||
metric_category: MetricCategory | None = None,
|
||||
score: EvaluationScore | None = None,
|
||||
):
|
||||
def emit_evaluation_completed_event(self, agent_role: str, agent_id: str, task_id: str | None = None, metric_category: MetricCategory | None = None, score: EvaluationScore | None = None):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
AgentEvaluationCompletedEvent(
|
||||
agent_role=agent_role,
|
||||
agent_id=agent_id,
|
||||
task_id=task_id,
|
||||
iteration=self._execution_state.iteration,
|
||||
metric_category=metric_category,
|
||||
score=score,
|
||||
),
|
||||
AgentEvaluationCompletedEvent(agent_role=agent_role, agent_id=agent_id, task_id=task_id, iteration=self._execution_state.iteration, metric_category=metric_category, score=score)
|
||||
)
|
||||
|
||||
def emit_evaluation_failed_event(
|
||||
self, agent_role: str, agent_id: str, error: str, task_id: str | None = None
|
||||
):
|
||||
def emit_evaluation_failed_event(self, agent_role: str, agent_id: str, error: str, task_id: str | None = None):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
AgentEvaluationFailedEvent(
|
||||
agent_role=agent_role,
|
||||
agent_id=agent_id,
|
||||
task_id=task_id,
|
||||
iteration=self._execution_state.iteration,
|
||||
error=error,
|
||||
),
|
||||
AgentEvaluationFailedEvent(agent_role=agent_role, agent_id=agent_id, task_id=task_id, iteration=self._execution_state.iteration, error=error)
|
||||
)
|
||||
|
||||
|
||||
def create_default_evaluator(agents: list[Agent], llm: None = None):
|
||||
from crewai.experimental.evaluation import (
|
||||
GoalAlignmentEvaluator,
|
||||
@@ -344,7 +230,7 @@ def create_default_evaluator(agents: list[Agent], llm: None = None):
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
ReasoningEfficiencyEvaluator
|
||||
)
|
||||
|
||||
evaluators = [
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import abc
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.llm_utils import create_default_llm, create_llm
|
||||
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
class MetricCategory(enum.Enum):
|
||||
GOAL_ALIGNMENT = "goal_alignment"
|
||||
@@ -19,8 +18,8 @@ class MetricCategory(enum.Enum):
|
||||
PARAMETER_EXTRACTION = "parameter_extraction"
|
||||
TOOL_INVOCATION = "tool_invocation"
|
||||
|
||||
def title(self) -> str:
|
||||
return self.value.replace("_", " ").title()
|
||||
def title(self):
|
||||
return self.value.replace('_', ' ').title()
|
||||
|
||||
|
||||
class EvaluationScore(BaseModel):
|
||||
@@ -28,13 +27,15 @@ class EvaluationScore(BaseModel):
|
||||
default=5.0,
|
||||
description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
|
||||
ge=0.0,
|
||||
le=10.0,
|
||||
le=10.0
|
||||
)
|
||||
feedback: str = Field(
|
||||
default="", description="Detailed feedback explaining the evaluation score"
|
||||
default="",
|
||||
description="Detailed feedback explaining the evaluation score"
|
||||
)
|
||||
raw_response: str | None = Field(
|
||||
default=None, description="Raw response from the evaluator (e.g., LLM)"
|
||||
default=None,
|
||||
description="Raw response from the evaluator (e.g., LLM)"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -45,9 +46,7 @@ class EvaluationScore(BaseModel):
|
||||
|
||||
class BaseEvaluator(abc.ABC):
|
||||
def __init__(self, llm: BaseLLM | None = None):
|
||||
self.llm: BaseLLM | None = (
|
||||
create_llm(llm) if llm is not None else create_default_llm()
|
||||
)
|
||||
self.llm: BaseLLM | None = create_llm(llm)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@@ -58,7 +57,7 @@ class BaseEvaluator(abc.ABC):
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: dict[str, Any],
|
||||
execution_trace: Dict[str, Any],
|
||||
final_output: Any,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -68,8 +67,9 @@ class BaseEvaluator(abc.ABC):
|
||||
class AgentEvaluationResult(BaseModel):
|
||||
agent_id: str = Field(description="ID of the evaluated agent")
|
||||
task_id: str = Field(description="ID of the task that was executed")
|
||||
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict, description="Evaluation scores for each metric category"
|
||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict,
|
||||
description="Evaluation scores for each metric category"
|
||||
)
|
||||
|
||||
|
||||
@@ -81,23 +81,33 @@ class AggregationStrategy(Enum):
|
||||
|
||||
|
||||
class AgentAggregatedEvaluationResult(BaseModel):
|
||||
agent_id: str = Field(default="", description="ID of the agent")
|
||||
agent_role: str = Field(default="", description="Role of the agent")
|
||||
agent_id: str = Field(
|
||||
default="",
|
||||
description="ID of the agent"
|
||||
)
|
||||
agent_role: str = Field(
|
||||
default="",
|
||||
description="Role of the agent"
|
||||
)
|
||||
task_count: int = Field(
|
||||
default=0, description="Number of tasks included in this aggregation"
|
||||
default=0,
|
||||
description="Number of tasks included in this aggregation"
|
||||
)
|
||||
aggregation_strategy: AggregationStrategy = Field(
|
||||
default=AggregationStrategy.SIMPLE_AVERAGE,
|
||||
description="Strategy used for aggregation",
|
||||
description="Strategy used for aggregation"
|
||||
)
|
||||
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict, description="Aggregated metrics across all tasks"
|
||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict,
|
||||
description="Aggregated metrics across all tasks"
|
||||
)
|
||||
task_results: list[str] = Field(
|
||||
default_factory=list, description="IDs of tasks included in this aggregation"
|
||||
task_results: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="IDs of tasks included in this aggregation"
|
||||
)
|
||||
overall_score: Optional[float] = Field(
|
||||
default=None, description="Overall score for this agent"
|
||||
default=None,
|
||||
description="Overall score for this agent"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -109,7 +119,7 @@ class AgentAggregatedEvaluationResult(BaseModel):
|
||||
result += f"\n\n- {category.value.upper()}: {score.score}/10\n"
|
||||
|
||||
if score.feedback:
|
||||
detailed_feedback = "\n ".join(score.feedback.split("\n"))
|
||||
detailed_feedback = "\n ".join(score.feedback.split('\n'))
|
||||
result += f" {detailed_feedback}\n"
|
||||
|
||||
return result
|
||||
return result
|
||||
@@ -3,28 +3,18 @@ from typing import Dict, Any, List
|
||||
from rich.table import Table
|
||||
from rich.box import HEAVY_EDGE, ROUNDED
|
||||
from collections.abc import Sequence
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
AggregationStrategy,
|
||||
AgentEvaluationResult,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.base_evaluator import AgentAggregatedEvaluationResult, AggregationStrategy, AgentEvaluationResult, MetricCategory
|
||||
from crewai.experimental.evaluation import EvaluationScore
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.utilities.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
|
||||
class EvaluationDisplayFormatter:
|
||||
def __init__(self):
|
||||
self.console_formatter = ConsoleFormatter()
|
||||
|
||||
def display_evaluation_with_feedback(
|
||||
self, iterations_results: Dict[int, Dict[str, List[Any]]]
|
||||
):
|
||||
def display_evaluation_with_feedback(self, iterations_results: Dict[int, Dict[str, List[Any]]]):
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
"[yellow]No evaluation results to display[/yellow]"
|
||||
)
|
||||
self.console_formatter.print("[yellow]No evaluation results to display[/yellow]")
|
||||
return
|
||||
|
||||
all_agent_roles: set[str] = set()
|
||||
@@ -32,9 +22,7 @@ class EvaluationDisplayFormatter:
|
||||
all_agent_roles.update(iter_results.keys())
|
||||
|
||||
for agent_role in sorted(all_agent_roles):
|
||||
self.console_formatter.print(
|
||||
f"\n[bold cyan]Agent: {agent_role}[/bold cyan]"
|
||||
)
|
||||
self.console_formatter.print(f"\n[bold cyan]Agent: {agent_role}[/bold cyan]")
|
||||
|
||||
for iter_num, results in sorted(iterations_results.items()):
|
||||
if agent_role not in results or not results[agent_role]:
|
||||
@@ -74,7 +62,9 @@ class EvaluationDisplayFormatter:
|
||||
|
||||
table.add_section()
|
||||
table.add_row(
|
||||
metric.title(), score_text, evaluation_score.feedback or ""
|
||||
metric.title(),
|
||||
score_text,
|
||||
evaluation_score.feedback or ""
|
||||
)
|
||||
|
||||
if aggregated_result.overall_score is not None:
|
||||
@@ -92,26 +82,19 @@ class EvaluationDisplayFormatter:
|
||||
table.add_row(
|
||||
"Overall Score",
|
||||
f"[{overall_color}]{overall_score:.1f}[/]",
|
||||
"Overall agent evaluation score",
|
||||
"Overall agent evaluation score"
|
||||
)
|
||||
|
||||
self.console_formatter.print(table)
|
||||
|
||||
def display_summary_results(
|
||||
self,
|
||||
iterations_results: Dict[int, Dict[str, List[AgentAggregatedEvaluationResult]]],
|
||||
):
|
||||
def display_summary_results(self, iterations_results: Dict[int, Dict[str, List[AgentAggregatedEvaluationResult]]]):
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
"[yellow]No evaluation results to display[/yellow]"
|
||||
)
|
||||
self.console_formatter.print("[yellow]No evaluation results to display[/yellow]")
|
||||
return
|
||||
|
||||
self.console_formatter.print("\n")
|
||||
|
||||
table = Table(
|
||||
title="Agent Performance Scores \n (1-10 Higher is better)", box=HEAVY_EDGE
|
||||
)
|
||||
table = Table(title="Agent Performance Scores \n (1-10 Higher is better)", box=HEAVY_EDGE)
|
||||
|
||||
table.add_column("Agent/Metric", style="cyan")
|
||||
|
||||
@@ -140,14 +123,11 @@ class EvaluationDisplayFormatter:
|
||||
agent_id=agent_id,
|
||||
agent_role=agent_role,
|
||||
results=agent_results,
|
||||
strategy=AggregationStrategy.SIMPLE_AVERAGE,
|
||||
strategy=AggregationStrategy.SIMPLE_AVERAGE
|
||||
)
|
||||
|
||||
valid_scores = [
|
||||
score.score
|
||||
for score in aggregated_result.metrics.values()
|
||||
if score.score is not None
|
||||
]
|
||||
valid_scores = [score.score for score in aggregated_result.metrics.values()
|
||||
if score.score is not None]
|
||||
if valid_scores:
|
||||
avg_score = sum(valid_scores) / len(valid_scores)
|
||||
agent_scores_by_iteration[iter_num] = avg_score
|
||||
@@ -157,9 +137,7 @@ class EvaluationDisplayFormatter:
|
||||
if not agent_scores_by_iteration:
|
||||
continue
|
||||
|
||||
avg_across_iterations = sum(agent_scores_by_iteration.values()) / len(
|
||||
agent_scores_by_iteration
|
||||
)
|
||||
avg_across_iterations = sum(agent_scores_by_iteration.values()) / len(agent_scores_by_iteration)
|
||||
|
||||
row = [f"[bold]{agent_role}[/bold]"]
|
||||
|
||||
@@ -200,13 +178,9 @@ class EvaluationDisplayFormatter:
|
||||
row = [f" - {metric.title()}"]
|
||||
|
||||
for iter_num in sorted(iterations_results.keys()):
|
||||
if (
|
||||
iter_num in agent_metrics_by_iteration
|
||||
and metric in agent_metrics_by_iteration[iter_num]
|
||||
):
|
||||
metric_score = agent_metrics_by_iteration[iter_num][
|
||||
metric
|
||||
].score
|
||||
if (iter_num in agent_metrics_by_iteration and
|
||||
metric in agent_metrics_by_iteration[iter_num]):
|
||||
metric_score = agent_metrics_by_iteration[iter_num][metric].score
|
||||
if metric_score is not None:
|
||||
metric_scores.append(metric_score)
|
||||
if metric_score >= 8.0:
|
||||
@@ -251,9 +225,7 @@ class EvaluationDisplayFormatter:
|
||||
results: Sequence[AgentEvaluationResult],
|
||||
strategy: AggregationStrategy = AggregationStrategy.SIMPLE_AVERAGE,
|
||||
) -> AgentAggregatedEvaluationResult:
|
||||
metrics_by_category: dict[MetricCategory, list[EvaluationScore]] = defaultdict(
|
||||
list
|
||||
)
|
||||
metrics_by_category: dict[MetricCategory, list[EvaluationScore]] = defaultdict(list)
|
||||
|
||||
for result in results:
|
||||
for metric_name, evaluation_score in result.metrics.items():
|
||||
@@ -274,20 +246,19 @@ class EvaluationDisplayFormatter:
|
||||
metric=category.title(),
|
||||
feedbacks=feedbacks,
|
||||
scores=[s.score for s in scores],
|
||||
strategy=strategy,
|
||||
strategy=strategy
|
||||
)
|
||||
else:
|
||||
feedback_summary = feedbacks[0]
|
||||
|
||||
aggregated_metrics[category] = EvaluationScore(
|
||||
score=avg_score, feedback=feedback_summary
|
||||
score=avg_score,
|
||||
feedback=feedback_summary
|
||||
)
|
||||
|
||||
overall_score = None
|
||||
if aggregated_metrics:
|
||||
valid_scores = [
|
||||
m.score for m in aggregated_metrics.values() if m.score is not None
|
||||
]
|
||||
valid_scores = [m.score for m in aggregated_metrics.values() if m.score is not None]
|
||||
if valid_scores:
|
||||
overall_score = sum(valid_scores) / len(valid_scores)
|
||||
|
||||
@@ -297,7 +268,7 @@ class EvaluationDisplayFormatter:
|
||||
metrics=aggregated_metrics,
|
||||
overall_score=overall_score,
|
||||
task_count=len(results),
|
||||
aggregation_strategy=strategy,
|
||||
aggregation_strategy=strategy
|
||||
)
|
||||
|
||||
def _summarize_feedbacks(
|
||||
@@ -306,12 +277,10 @@ class EvaluationDisplayFormatter:
|
||||
metric: str,
|
||||
feedbacks: List[str],
|
||||
scores: List[float | None],
|
||||
strategy: AggregationStrategy,
|
||||
strategy: AggregationStrategy
|
||||
) -> str:
|
||||
if len(feedbacks) <= 2 and all(len(fb) < 200 for fb in feedbacks):
|
||||
return "\n\n".join(
|
||||
[f"Feedback {i+1}: {fb}" for i, fb in enumerate(feedbacks)]
|
||||
)
|
||||
return "\n\n".join([f"Feedback {i+1}: {fb}" for i, fb in enumerate(feedbacks)])
|
||||
|
||||
try:
|
||||
llm = create_llm()
|
||||
@@ -321,26 +290,20 @@ class EvaluationDisplayFormatter:
|
||||
if len(feedback) > 500:
|
||||
feedback = feedback[:500] + "..."
|
||||
score_text = f"{score:.1f}" if score is not None else "N/A"
|
||||
formatted_feedbacks.append(
|
||||
f"Feedback #{i+1} (Score: {score_text}):\n{feedback}"
|
||||
)
|
||||
formatted_feedbacks.append(f"Feedback #{i+1} (Score: {score_text}):\n{feedback}")
|
||||
|
||||
all_feedbacks = "\n\n" + "\n\n---\n\n".join(formatted_feedbacks)
|
||||
|
||||
strategy_guidance = ""
|
||||
if strategy == AggregationStrategy.BEST_PERFORMANCE:
|
||||
strategy_guidance = (
|
||||
"Focus on the highest-scoring aspects and strengths demonstrated."
|
||||
)
|
||||
strategy_guidance = "Focus on the highest-scoring aspects and strengths demonstrated."
|
||||
elif strategy == AggregationStrategy.WORST_PERFORMANCE:
|
||||
strategy_guidance = "Focus on areas that need improvement and common issues across tasks."
|
||||
else:
|
||||
strategy_guidance = "Provide a balanced analysis of strengths and weaknesses across all tasks."
|
||||
|
||||
prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""You are an expert evaluator creating a comprehensive summary of agent performance feedback.
|
||||
{"role": "system", "content": f"""You are an expert evaluator creating a comprehensive summary of agent performance feedback.
|
||||
Your job is to synthesize multiple feedback points about the same metric across different tasks.
|
||||
|
||||
Create a concise, insightful summary that captures the key patterns and themes from all feedback.
|
||||
@@ -352,18 +315,14 @@ class EvaluationDisplayFormatter:
|
||||
3. Highlighting patterns across tasks
|
||||
4. 150-250 words in length
|
||||
|
||||
The summary should be directly usable as final feedback for the agent's performance on this metric.""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""I need a synthesized summary of the following feedback for:
|
||||
The summary should be directly usable as final feedback for the agent's performance on this metric."""},
|
||||
{"role": "user", "content": f"""I need a synthesized summary of the following feedback for:
|
||||
|
||||
Agent Role: {agent_role}
|
||||
Metric: {metric.title()}
|
||||
|
||||
{all_feedbacks}
|
||||
""",
|
||||
},
|
||||
"""}
|
||||
]
|
||||
assert llm is not None
|
||||
response = llm.call(prompt)
|
||||
@@ -371,6 +330,4 @@ class EvaluationDisplayFormatter:
|
||||
return response
|
||||
|
||||
except Exception:
|
||||
return "Synthesized from multiple tasks: " + "\n\n".join(
|
||||
[f"- {fb[:500]}..." for fb in feedbacks]
|
||||
)
|
||||
return "Synthesized from multiple tasks: " + "\n\n".join([f"- {fb[:500]}..." for fb in feedbacks])
|
||||
|
||||
@@ -5,23 +5,25 @@ from collections.abc import Sequence
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.types.agent_events import (
|
||||
from crewai.utilities.events.base_event_listener import BaseEventListener
|
||||
from crewai.utilities.events.crewai_event_bus import CrewAIEventsBus
|
||||
from crewai.utilities.events.agent_events import (
|
||||
AgentExecutionStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionCompletedEvent
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolExecutionErrorEvent,
|
||||
ToolSelectionErrorEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
ToolValidateInputErrorEvent
|
||||
)
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallStartedEvent,
|
||||
LLMCallCompletedEvent
|
||||
)
|
||||
from crewai.events.types.llm_events import LLMCallStartedEvent, LLMCallCompletedEvent
|
||||
|
||||
|
||||
class EvaluationTraceCallback(BaseEventListener):
|
||||
"""Event listener for collecting execution traces for evaluation.
|
||||
@@ -66,49 +68,27 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
|
||||
@event_bus.on(ToolUsageFinishedEvent)
|
||||
def on_tool_completed(source, event: ToolUsageFinishedEvent):
|
||||
self.on_tool_use(
|
||||
event.tool_name, event.tool_args, event.output, success=True
|
||||
)
|
||||
self.on_tool_use(event.tool_name, event.tool_args, event.output, success=True)
|
||||
|
||||
@event_bus.on(ToolUsageErrorEvent)
|
||||
def on_tool_usage_error(source, event: ToolUsageErrorEvent):
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
event.error,
|
||||
success=False,
|
||||
error_type="usage_error",
|
||||
)
|
||||
self.on_tool_use(event.tool_name, event.tool_args, event.error,
|
||||
success=False, error_type="usage_error")
|
||||
|
||||
@event_bus.on(ToolExecutionErrorEvent)
|
||||
def on_tool_execution_error(source, event: ToolExecutionErrorEvent):
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
event.error,
|
||||
success=False,
|
||||
error_type="execution_error",
|
||||
)
|
||||
self.on_tool_use(event.tool_name, event.tool_args, event.error,
|
||||
success=False, error_type="execution_error")
|
||||
|
||||
@event_bus.on(ToolSelectionErrorEvent)
|
||||
def on_tool_selection_error(source, event: ToolSelectionErrorEvent):
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
event.error,
|
||||
success=False,
|
||||
error_type="selection_error",
|
||||
)
|
||||
self.on_tool_use(event.tool_name, event.tool_args, event.error,
|
||||
success=False, error_type="selection_error")
|
||||
|
||||
@event_bus.on(ToolValidateInputErrorEvent)
|
||||
def on_tool_validate_input_error(source, event: ToolValidateInputErrorEvent):
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
event.error,
|
||||
success=False,
|
||||
error_type="validation_error",
|
||||
)
|
||||
self.on_tool_use(event.tool_name, event.tool_args, event.error,
|
||||
success=False, error_type="validation_error")
|
||||
|
||||
@event_bus.on(LLMCallStartedEvent)
|
||||
def on_llm_call_started(source, event: LLMCallStartedEvent):
|
||||
@@ -119,7 +99,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
self.on_llm_call_end(event.messages, event.response)
|
||||
|
||||
def on_lite_agent_start(self, agent_info: dict[str, Any]):
|
||||
self.current_agent_id = agent_info["id"]
|
||||
self.current_agent_id = agent_info['id']
|
||||
self.current_task_id = "lite_task"
|
||||
|
||||
trace_key = f"{self.current_agent_id}_{self.current_task_id}"
|
||||
@@ -130,7 +110,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
tool_uses=[],
|
||||
llm_calls=[],
|
||||
start_time=datetime.now(),
|
||||
final_output=None,
|
||||
final_output=None
|
||||
)
|
||||
|
||||
def _init_trace(self, trace_key: str, **kwargs: Any):
|
||||
@@ -148,7 +128,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
tool_uses=[],
|
||||
llm_calls=[],
|
||||
start_time=datetime.now(),
|
||||
final_output=None,
|
||||
final_output=None
|
||||
)
|
||||
|
||||
def on_agent_finish(self, agent: Agent, task: Task, output: Any):
|
||||
@@ -171,14 +151,8 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
|
||||
self._reset_current()
|
||||
|
||||
def on_tool_use(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any] | str,
|
||||
result: Any,
|
||||
success: bool = True,
|
||||
error_type: str | None = None,
|
||||
):
|
||||
def on_tool_use(self, tool_name: str, tool_args: dict[str, Any] | str, result: Any,
|
||||
success: bool = True, error_type: str | None = None):
|
||||
if not self.current_agent_id or not self.current_task_id:
|
||||
return
|
||||
|
||||
@@ -189,7 +163,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
"args": tool_args,
|
||||
"result": result,
|
||||
"success": success,
|
||||
"timestamp": datetime.now(),
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
|
||||
# Add error information if applicable
|
||||
@@ -199,11 +173,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
|
||||
self.traces[trace_key]["tool_uses"].append(tool_use)
|
||||
|
||||
def on_llm_call_start(
|
||||
self,
|
||||
messages: str | Sequence[dict[str, Any]] | None,
|
||||
tools: Sequence[dict[str, Any]] | None = None,
|
||||
):
|
||||
def on_llm_call_start(self, messages: str | Sequence[dict[str, Any]] | None, tools: Sequence[dict[str, Any]] | None = None):
|
||||
if not self.current_agent_id or not self.current_task_id:
|
||||
return
|
||||
|
||||
@@ -216,12 +186,10 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
"tools": tools,
|
||||
"start_time": datetime.now(),
|
||||
"response": None,
|
||||
"end_time": None,
|
||||
"end_time": None
|
||||
}
|
||||
|
||||
def on_llm_call_end(
|
||||
self, messages: str | list[dict[str, Any]] | None, response: Any
|
||||
):
|
||||
def on_llm_call_end(self, messages: str | list[dict[str, Any]] | None, response: Any):
|
||||
if not self.current_agent_id or not self.current_task_id:
|
||||
return
|
||||
|
||||
@@ -245,7 +213,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
"response": response,
|
||||
"start_time": start_time,
|
||||
"end_time": current_time,
|
||||
"total_tokens": total_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
|
||||
self.traces[trace_key]["llm_calls"].append(llm_call)
|
||||
@@ -259,7 +227,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
|
||||
|
||||
def create_evaluation_callbacks() -> EvaluationTraceCallback:
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
|
||||
callback = EvaluationTraceCallback()
|
||||
callback.setup_listeners(crewai_event_bus)
|
||||
|
||||
@@ -25,8 +25,8 @@ from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import FlowExecutionData
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.flow_events import (
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
FlowPlotEvent,
|
||||
@@ -35,10 +35,10 @@ from crewai.events.types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
from crewai.utilities.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
from crewai.utilities.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
@@ -474,7 +474,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
self._completed_methods: Set[str] = set() # Track completed methods for reload
|
||||
self._persistence: Optional[FlowPersistence] = persistence
|
||||
self._is_execution_resuming: bool = False
|
||||
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
@@ -830,9 +829,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Clear completed methods and outputs for a fresh start
|
||||
self._completed_methods.clear()
|
||||
self._method_outputs.clear()
|
||||
else:
|
||||
# We're restoring from persistence, set the flag
|
||||
self._is_execution_resuming = True
|
||||
|
||||
if inputs:
|
||||
# Override the id in the state if it exists in inputs
|
||||
@@ -884,9 +880,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Clear the resumption flag after initial execution completes
|
||||
self._is_execution_resuming = False
|
||||
|
||||
final_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -923,23 +916,19 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
- Automatically injects crewai_trigger_payload if available in flow inputs
|
||||
"""
|
||||
if start_method_name in self._completed_methods:
|
||||
if self._is_execution_resuming:
|
||||
# During resumption, skip execution but continue listeners
|
||||
last_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
await self._execute_listeners(start_method_name, last_output)
|
||||
return
|
||||
# For cyclic flows, clear from completed to allow re-execution
|
||||
self._completed_methods.discard(start_method_name)
|
||||
last_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
await self._execute_listeners(start_method_name, last_output)
|
||||
return
|
||||
|
||||
method = self._methods[start_method_name]
|
||||
enhanced_method = self._inject_trigger_payload_for_start_method(method)
|
||||
|
||||
result = await self._execute_method(start_method_name, enhanced_method)
|
||||
result = await self._execute_method(
|
||||
start_method_name, enhanced_method
|
||||
)
|
||||
await self._execute_listeners(start_method_name, result)
|
||||
|
||||
def _inject_trigger_payload_for_start_method(
|
||||
self, original_method: Callable
|
||||
) -> Callable:
|
||||
def _inject_trigger_payload_for_start_method(self, original_method: Callable) -> Callable:
|
||||
def prepare_kwargs(*args, **kwargs):
|
||||
inputs = baggage.get_baggage("flow_inputs") or {}
|
||||
trigger_payload = inputs.get("crewai_trigger_payload")
|
||||
@@ -952,17 +941,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
elif trigger_payload is not None:
|
||||
self._log_flow_event(
|
||||
f"Trigger payload available but {original_method.__name__} doesn't accept crewai_trigger_payload parameter",
|
||||
color="yellow",
|
||||
color="yellow"
|
||||
)
|
||||
return args, kwargs
|
||||
|
||||
if asyncio.iscoroutinefunction(original_method):
|
||||
|
||||
async def enhanced_method(*args, **kwargs):
|
||||
args, kwargs = prepare_kwargs(*args, **kwargs)
|
||||
return await original_method(*args, **kwargs)
|
||||
else:
|
||||
|
||||
def enhanced_method(*args, **kwargs):
|
||||
args, kwargs = prepare_kwargs(*args, **kwargs)
|
||||
return original_method(*args, **kwargs)
|
||||
@@ -1063,15 +1050,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
for router_name in routers_triggered:
|
||||
await self._execute_single_listener(router_name, result)
|
||||
# After executing router, the router's result is the path
|
||||
router_result = (
|
||||
self._method_outputs[-1] if self._method_outputs else None
|
||||
)
|
||||
router_result = self._method_outputs[-1]
|
||||
if router_result: # Only add non-None results
|
||||
router_results.append(router_result)
|
||||
current_trigger = (
|
||||
str(router_result)
|
||||
if router_result is not None
|
||||
else "" # Update for next iteration of router chain
|
||||
router_result # Update for next iteration of router chain
|
||||
)
|
||||
|
||||
# Now execute normal listeners for all router results and the original trigger
|
||||
@@ -1089,24 +1072,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
if current_trigger in router_results:
|
||||
# Find start methods triggered by this router result
|
||||
for method_name in self._start_methods:
|
||||
# Check if this start method is triggered by the current trigger
|
||||
if method_name in self._listeners:
|
||||
condition_type, trigger_methods = self._listeners[
|
||||
method_name
|
||||
]
|
||||
if current_trigger in trigger_methods:
|
||||
# Only execute if this is a cycle (method was already completed)
|
||||
if method_name in self._completed_methods:
|
||||
# For router-triggered start methods in cycles, temporarily clear resumption flag
|
||||
# to allow cyclic execution
|
||||
was_resuming = self._is_execution_resuming
|
||||
self._is_execution_resuming = False
|
||||
await self._execute_start_method(method_name)
|
||||
self._is_execution_resuming = was_resuming
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
) -> List[str]:
|
||||
@@ -1144,9 +1109,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if router_only != is_router:
|
||||
continue
|
||||
|
||||
if not router_only and listener_name in self._start_methods:
|
||||
continue
|
||||
|
||||
if condition_type == "OR":
|
||||
# If the trigger_method matches any in methods, run this
|
||||
if trigger_method in methods:
|
||||
@@ -1196,13 +1158,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Catches and logs any exceptions during execution, preventing
|
||||
individual listener failures from breaking the entire flow.
|
||||
"""
|
||||
if listener_name in self._completed_methods:
|
||||
if self._is_execution_resuming:
|
||||
# During resumption, skip execution but continue listeners
|
||||
await self._execute_listeners(listener_name, None)
|
||||
return
|
||||
# For cyclic flows, clear from completed to allow re-execution
|
||||
self._completed_methods.discard(listener_name)
|
||||
# TODO: greyson fix
|
||||
# if listener_name in self._completed_methods:
|
||||
# await self._execute_listeners(listener_name, None)
|
||||
# return
|
||||
|
||||
try:
|
||||
method = self._methods[listener_name]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -18,20 +18,20 @@ class Knowledge(BaseModel):
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
"""
|
||||
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder: Optional[dict[str, Any]] = None
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: Optional[dict[str, Any]] = None,
|
||||
sources: List[BaseKnowledgeSource],
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
storage: Optional[KnowledgeStorage] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
**data,
|
||||
):
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
self.storage = storage
|
||||
@@ -43,8 +43,8 @@ class Knowledge(BaseModel):
|
||||
self.storage.initialize_knowledge_storage()
|
||||
|
||||
def query(
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[dict[str, Any]]:
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
Returns the top_k most relevant chunks.
|
||||
@@ -62,7 +62,7 @@ class Knowledge(BaseModel):
|
||||
)
|
||||
return results
|
||||
|
||||
def add_sources(self) -> None:
|
||||
def add_sources(self):
|
||||
try:
|
||||
for source in self.sources:
|
||||
source.storage = self.storage
|
||||
|
||||
@@ -1,57 +1,42 @@
|
||||
import contextlib
|
||||
import hashlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import chromadb
|
||||
import chromadb.errors
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import OneOrMany
|
||||
from chromadb.config import Settings
|
||||
import warnings
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.utilities.chromadb import create_persistent_client, sanitize_collection_name
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
|
||||
|
||||
def _extract_chromadb_response_item(
|
||||
response_data: Any,
|
||||
index: int,
|
||||
expected_type: type[Any] | tuple[type[Any], ...],
|
||||
) -> Any | None:
|
||||
"""Extract an item from ChromaDB response data at the given index.
|
||||
|
||||
Args:
|
||||
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
|
||||
index: The index of the item to extract.
|
||||
expected_type: The expected type(s) of the item.
|
||||
|
||||
Returns:
|
||||
The extracted item if it exists and matches the expected type, otherwise None.
|
||||
"""
|
||||
if response_data is None or not response_data:
|
||||
return None
|
||||
|
||||
# ChromaDB sometimes returns nested lists, handle both cases
|
||||
data_list = (
|
||||
response_data[0]
|
||||
if response_data and isinstance(response_data[0], list)
|
||||
else response_data
|
||||
)
|
||||
|
||||
if index < len(data_list):
|
||||
item = data_list[index]
|
||||
if isinstance(item, expected_type):
|
||||
return item
|
||||
return None
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(
|
||||
logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
|
||||
level=logging.ERROR,
|
||||
):
|
||||
logger = logging.getLogger(logger_name)
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(level)
|
||||
with (
|
||||
contextlib.redirect_stdout(io.StringIO()),
|
||||
contextlib.redirect_stderr(io.StringIO()),
|
||||
contextlib.suppress(UserWarning),
|
||||
):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
|
||||
|
||||
class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
@@ -63,11 +48,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
collection: Optional[chromadb.Collection] = None
|
||||
collection_name: Optional[str] = "knowledge"
|
||||
app: Optional[ClientAPI] = None
|
||||
embedder: Optional[EmbeddingFunction[Any]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: Optional[dict[str, Any]] = None,
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
@@ -75,14 +59,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: list[str],
|
||||
query: List[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict[str, Any]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> list[dict[str, Any]]:
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
) -> List[Dict[str, Any]]:
|
||||
with suppress_logging():
|
||||
if self.collection:
|
||||
fetched = self.collection.query(
|
||||
query_texts=query,
|
||||
@@ -90,51 +72,20 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
where=filter,
|
||||
)
|
||||
results = []
|
||||
if (
|
||||
fetched
|
||||
and "ids" in fetched
|
||||
and fetched["ids"]
|
||||
and len(fetched["ids"]) > 0
|
||||
):
|
||||
ids_list = (
|
||||
fetched["ids"][0]
|
||||
if isinstance(fetched["ids"][0], list)
|
||||
else fetched["ids"]
|
||||
)
|
||||
for i in range(len(ids_list)):
|
||||
# Handle metadatas
|
||||
meta_item = _extract_chromadb_response_item(
|
||||
fetched.get("metadatas"), i, dict
|
||||
)
|
||||
metadata: dict[str, Any] = meta_item if meta_item else {}
|
||||
|
||||
# Handle documents
|
||||
doc_item = _extract_chromadb_response_item(
|
||||
fetched.get("documents"), i, str
|
||||
)
|
||||
context = doc_item if doc_item else ""
|
||||
|
||||
# Handle distances
|
||||
dist_item = _extract_chromadb_response_item(
|
||||
fetched.get("distances"), i, (int, float)
|
||||
)
|
||||
score = dist_item if dist_item is not None else 1.0
|
||||
|
||||
result = {
|
||||
"id": ids_list[i],
|
||||
"metadata": metadata,
|
||||
"context": context,
|
||||
"score": score,
|
||||
}
|
||||
|
||||
# Check score threshold - distances are smaller when more similar
|
||||
if isinstance(score, (int, float)) and score <= score_threshold:
|
||||
results.append(result)
|
||||
for i in range(len(fetched["ids"][0])): # type: ignore
|
||||
result = {
|
||||
"id": fetched["ids"][0][i], # type: ignore
|
||||
"metadata": fetched["metadatas"][0][i], # type: ignore
|
||||
"context": fetched["documents"][0][i], # type: ignore
|
||||
"score": fetched["distances"][0][i], # type: ignore
|
||||
}
|
||||
if result["score"] >= score_threshold:
|
||||
results.append(result)
|
||||
return results
|
||||
else:
|
||||
raise Exception("Collection not initialized")
|
||||
|
||||
def initialize_knowledge_storage(self) -> None:
|
||||
def initialize_knowledge_storage(self):
|
||||
# Suppress deprecation warnings from chromadb, which are not relevant to us
|
||||
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
|
||||
warnings.filterwarnings(
|
||||
@@ -164,7 +115,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
except Exception:
|
||||
raise Exception("Failed to create or get collection")
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self):
|
||||
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
||||
if not self.app:
|
||||
self.app = create_persistent_client(
|
||||
@@ -178,9 +129,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def save(
|
||||
self,
|
||||
documents: list[str],
|
||||
metadata: Optional[dict[str, Any] | list[dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
documents: List[str],
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
):
|
||||
if not self.collection:
|
||||
raise Exception("Collection not initialized")
|
||||
|
||||
@@ -212,7 +163,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
# If we have no metadata at all, set it to None
|
||||
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
|
||||
None if all(m is None for m in filtered_metadata) else filtered_metadata # type: ignore[assignment]
|
||||
None if all(m is None for m in filtered_metadata) else filtered_metadata
|
||||
)
|
||||
|
||||
self.collection.upsert(
|
||||
@@ -235,7 +186,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||
raise
|
||||
|
||||
def _create_default_embedding_function(self) -> Any:
|
||||
def _create_default_embedding_function(self):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
@@ -244,18 +195,15 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
def _set_embedder_config(self, embedder: Optional[dict[str, Any]] = None) -> None:
|
||||
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Set the embedding configuration for the knowledge storage.
|
||||
|
||||
Args:
|
||||
embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||
If None or empty, defaults to the default embedding function.
|
||||
|
||||
Notes:
|
||||
- TODO: Improve typing for embedder configuration, remove type: ignore
|
||||
"""
|
||||
self.embedder = (
|
||||
EmbeddingConfigurator().configure_embedder(embedder) # type: ignore
|
||||
EmbeddingConfigurator().configure_embedder(embedder)
|
||||
if embedder
|
||||
else self._create_default_embedding_function()
|
||||
)
|
||||
|
||||
@@ -1,49 +1,50 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
field_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.parser import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
OutputParserException,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.tasks import TaskOutput
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.agent_utils import (
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
@@ -61,8 +62,20 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.llm_utils import create_default_llm, create_llm
|
||||
from crewai.utilities.events.agent_events import (
|
||||
AgentLogsExecutionEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
)
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
@@ -78,11 +91,11 @@ class LiteAgentOutput(BaseModel):
|
||||
description="Pydantic output of the agent", default=None
|
||||
)
|
||||
agent_role: str = Field(description="Role of the agent that produced this output")
|
||||
usage_metrics: Optional[dict[str, Any]] = Field(
|
||||
usage_metrics: Optional[Dict[str, Any]] = Field(
|
||||
description="Token usage metrics for this execution", default=None
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert pydantic_output to a dictionary."""
|
||||
if self.pydantic:
|
||||
return self.pydantic.model_dump()
|
||||
@@ -122,10 +135,10 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Goal of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
default=None, description="Language model that will run the agent"
|
||||
)
|
||||
tools: list[BaseTool] = Field(
|
||||
tools: List[BaseTool] = Field(
|
||||
default_factory=list, description="Tools at agent's disposal"
|
||||
)
|
||||
|
||||
@@ -151,27 +164,29 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||
|
||||
# Output and Formatting Properties
|
||||
response_format: Optional[type[BaseModel]] = Field(
|
||||
response_format: Optional[Type[BaseModel]] = Field(
|
||||
default=None, description="Pydantic model for structured output"
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default=False, description="Whether to print execution details"
|
||||
)
|
||||
callbacks: list[Callable[..., Any]] = Field(
|
||||
callbacks: List[Callable] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
)
|
||||
|
||||
# Guardrail Properties
|
||||
guardrail: Optional[Callable[[LiteAgentOutput], tuple[bool, Any]] | str] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
|
||||
Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
|
||||
# State and Results
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
)
|
||||
|
||||
@@ -180,25 +195,20 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
default=None, description="Reference to the agent that created this LiteAgent"
|
||||
)
|
||||
# Private Attributes
|
||||
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
||||
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
_guardrail: Optional[Callable[[LiteAgentOutput | TaskOutput], tuple[bool, Any]]] = (
|
||||
PrivateAttr(default=None)
|
||||
)
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self) -> Self:
|
||||
def setup_llm(self):
|
||||
"""Set up the LLM and other components after initialization."""
|
||||
if self.llm is None:
|
||||
self.llm = create_default_llm()
|
||||
else:
|
||||
self.llm = create_llm(self.llm)
|
||||
self.llm = create_llm(self.llm)
|
||||
if not isinstance(self.llm, BaseLLM):
|
||||
raise ValueError(
|
||||
f"Expected LLM instance of type BaseLLM, got {type(self.llm).__name__}"
|
||||
@@ -211,7 +221,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def parse_tools(self) -> Self:
|
||||
def parse_tools(self):
|
||||
"""Parse the tools and convert them to CrewStructuredTool instances."""
|
||||
self._parsed_tools = parse_tools(self.tools)
|
||||
|
||||
@@ -220,10 +230,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def ensure_guardrail_is_callable(self) -> Self:
|
||||
if callable(self.guardrail):
|
||||
self._guardrail = cast(
|
||||
Callable[[LiteAgentOutput | TaskOutput], tuple[bool, Any]],
|
||||
self.guardrail,
|
||||
)
|
||||
self._guardrail = self.guardrail
|
||||
elif isinstance(self.guardrail, str):
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
|
||||
@@ -232,18 +239,15 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
f"Guardrail requires LLM instance of type BaseLLM, got {type(self.llm).__name__}"
|
||||
)
|
||||
|
||||
self._guardrail = cast(
|
||||
Callable[[LiteAgentOutput | TaskOutput], tuple[bool, Any]],
|
||||
LLMGuardrail(description=self.guardrail, llm=self.llm),
|
||||
)
|
||||
self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)
|
||||
|
||||
return self
|
||||
|
||||
@field_validator("guardrail", mode="before")
|
||||
@classmethod
|
||||
def validate_guardrail_function(
|
||||
cls, v: Optional[Callable[[Any], tuple[bool, Any]] | str]
|
||||
) -> Optional[Callable[[Any], tuple[bool, Any]] | str]:
|
||||
cls, v: Optional[Union[Callable, str]]
|
||||
) -> Optional[Union[Callable, str]]:
|
||||
"""Validate that the guardrail function has the correct signature.
|
||||
|
||||
If v is a callable, validate that it has the correct signature.
|
||||
@@ -268,7 +272,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
# Check return annotation if present
|
||||
if sig.return_annotation is not sig.empty:
|
||||
if sig.return_annotation == tuple[bool, Any]:
|
||||
if sig.return_annotation == Tuple[bool, Any]:
|
||||
return v
|
||||
|
||||
origin = get_origin(sig.return_annotation)
|
||||
@@ -291,7 +295,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""Return the original role for compatibility with tool interfaces."""
|
||||
return self.role
|
||||
|
||||
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
|
||||
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages.
|
||||
|
||||
@@ -339,7 +343,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise e
|
||||
|
||||
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
|
||||
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
|
||||
# Emit event for agent execution start
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -429,7 +433,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return output
|
||||
|
||||
async def kickoff_async(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages.
|
||||
@@ -476,8 +480,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return base_prompt
|
||||
|
||||
def _format_messages(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Format messages for the LLM."""
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
@@ -515,6 +519,19 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
llm = cast(LLM, self.llm)
|
||||
model = llm.model if hasattr(llm, "model") else "unknown"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=self._messages,
|
||||
tools=None,
|
||||
callbacks=self._callbacks,
|
||||
from_agent=self,
|
||||
model=model,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
answer = get_llm_response(
|
||||
llm=cast(LLM, self.llm),
|
||||
@@ -524,7 +541,23 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
from_agent=self,
|
||||
)
|
||||
|
||||
# Emit LLM call completed event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(
|
||||
messages=self._messages,
|
||||
response=answer,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_agent=self,
|
||||
model=model,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
# Emit LLM call failed event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e), from_agent=self),
|
||||
)
|
||||
raise e
|
||||
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||
@@ -583,7 +616,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
|
||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||
"""Show logs for the agent's execution."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -23,14 +23,14 @@ from dotenv import load_dotenv
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.events.types.llm_events import (
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
@@ -52,7 +52,7 @@ import io
|
||||
from typing import TextIO
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
@@ -311,7 +311,7 @@ class LLM(BaseLLM):
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] | None = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
@@ -351,7 +351,7 @@ class LLM(BaseLLM):
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
self.set_callbacks(callbacks or [])
|
||||
self.set_callbacks(callbacks)
|
||||
self.set_env_callbacks()
|
||||
|
||||
def _is_anthropic_model(self, model: str) -> bool:
|
||||
@@ -851,9 +851,7 @@ class LLM(BaseLLM):
|
||||
return tool_calls
|
||||
|
||||
# --- 7) Handle tool calls if present
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
)
|
||||
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
# --- 8) If tool call handling didn't return a result, emit completion event and return text response
|
||||
@@ -870,8 +868,6 @@ class LLM(BaseLLM):
|
||||
self,
|
||||
tool_calls: List[Any],
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Optional[str]:
|
||||
"""Handle a tool call from the LLM.
|
||||
|
||||
@@ -906,8 +902,6 @@ class LLM(BaseLLM):
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
from_agent=from_agent,
|
||||
from_task=from_task,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -920,17 +914,12 @@ class LLM(BaseLLM):
|
||||
tool_args=function_args,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
# --- 3.3) Emit success event
|
||||
self._handle_emit_call_events(
|
||||
response=result,
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response=result, call_type=LLMCallType.TOOL_CALL
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -950,8 +939,6 @@ class LLM(BaseLLM):
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
error=f"Tool execution error: {str(e)}",
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
return None
|
||||
@@ -1152,11 +1139,7 @@ class LLM(BaseLLM):
|
||||
|
||||
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
|
||||
# Ollama doesn't supports last message to be 'assistant'
|
||||
if (
|
||||
"ollama" in self.model.lower()
|
||||
and messages
|
||||
and messages[-1]["role"] == "assistant"
|
||||
):
|
||||
if "ollama" in self.model.lower() and messages and messages[-1]["role"] == "assistant":
|
||||
return messages + [{"role": "user", "content": ""}]
|
||||
|
||||
# Handle Anthropic models
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from crewai.memory import (
|
||||
EntityMemory,
|
||||
@@ -9,10 +7,6 @@ from crewai.memory import (
|
||||
ShortTermMemory,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class ContextualMemory:
|
||||
def __init__(
|
||||
@@ -21,30 +15,13 @@ class ContextualMemory:
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
exm: ExternalMemory,
|
||||
agent: Optional[Agent] = None,
|
||||
task: Optional[Task] = None,
|
||||
):
|
||||
self.stm = stm
|
||||
self.ltm = ltm
|
||||
self.em = em
|
||||
self.exm = exm
|
||||
self.agent = agent
|
||||
self.task = task
|
||||
|
||||
if self.stm is not None:
|
||||
self.stm.agent = self.agent
|
||||
self.stm.task = self.task
|
||||
if self.ltm is not None:
|
||||
self.ltm.agent = self.agent
|
||||
self.ltm.task = self.task
|
||||
if self.em is not None:
|
||||
self.em.agent = self.agent
|
||||
self.em.task = self.task
|
||||
if self.exm is not None:
|
||||
self.exm.agent = self.agent
|
||||
self.exm.task = self.task
|
||||
|
||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
@@ -54,14 +31,14 @@ class ContextualMemory:
|
||||
if query == "":
|
||||
return ""
|
||||
|
||||
context_parts = []
|
||||
context_parts.append(self._fetch_ltm_context(task.description))
|
||||
context_parts.append(self._fetch_stm_context(query))
|
||||
context_parts.append(self._fetch_entity_context(query))
|
||||
context_parts.append(self._fetch_external_context(query))
|
||||
return "\n".join(filter(None, context_parts))
|
||||
context = []
|
||||
context.append(self._fetch_ltm_context(task.description))
|
||||
context.append(self._fetch_stm_context(query))
|
||||
context.append(self._fetch_entity_context(query))
|
||||
context.append(self._fetch_external_context(query))
|
||||
return "\n".join(filter(None, context))
|
||||
|
||||
def _fetch_stm_context(self, query: str) -> str:
|
||||
def _fetch_stm_context(self, query) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -72,11 +49,14 @@ class ContextualMemory:
|
||||
|
||||
stm_results = self.stm.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in stm_results]
|
||||
[
|
||||
f"- {result['context']}"
|
||||
for result in stm_results
|
||||
]
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task: str) -> Optional[str]:
|
||||
def _fetch_ltm_context(self, task) -> Optional[str]:
|
||||
"""
|
||||
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -85,23 +65,21 @@ class ContextualMemory:
|
||||
if self.ltm is None:
|
||||
return ""
|
||||
|
||||
ltm_results = self.ltm.search(task, limit=2)
|
||||
ltm_results = self.ltm.search(task, latest_n=2)
|
||||
if not ltm_results:
|
||||
return None
|
||||
|
||||
formatted_results = [
|
||||
suggestion
|
||||
for result in ltm_results
|
||||
for suggestion in result["metadata"]["suggestions"]
|
||||
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
]
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results_str = "\n".join(
|
||||
[f"- {result}" for result in formatted_results]
|
||||
)
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
|
||||
return f"Historical Data:\n{formatted_results_str}" if ltm_results else ""
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query: str) -> str:
|
||||
def _fetch_entity_context(self, query) -> str:
|
||||
"""
|
||||
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -111,7 +89,10 @@ class ContextualMemory:
|
||||
|
||||
em_results = self.em.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in em_results]
|
||||
[
|
||||
f"- {result['context']}"
|
||||
for result in em_results
|
||||
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Optional
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class EntityMemory(Memory):
|
||||
@@ -24,15 +24,9 @@ class EntityMemory(Memory):
|
||||
Inherits from the Memory class.
|
||||
"""
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: Any = None,
|
||||
) -> None:
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
@@ -41,7 +35,7 @@ class EntityMemory(Memory):
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
config = embedder_config.get("config")
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
storage = (
|
||||
@@ -59,99 +53,47 @@ class EntityMemory(Memory):
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Saves one or more entity items into the SQLite storage.
|
||||
|
||||
Args:
|
||||
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||
metadata: Optional metadata dict (included for supertype compatibility but not used).
|
||||
|
||||
Notes:
|
||||
The metadata parameter is included to satisfy the supertype signature but is not
|
||||
used - entity metadata is extracted from the EntityMemoryItem objects themselves.
|
||||
"""
|
||||
|
||||
if not value:
|
||||
return
|
||||
|
||||
items = value if isinstance(value, list) else [value]
|
||||
is_batch = len(items) > 1
|
||||
|
||||
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
"""Saves an entity item into the SQLite storage."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
metadata=metadata,
|
||||
metadata=item.metadata,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
saved_count = 0
|
||||
errors = []
|
||||
|
||||
try:
|
||||
for item in items:
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
super().save(data, item.metadata)
|
||||
saved_count += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{item.name}: {str(e)}")
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
metadata = {"entity_count": saved_count, "errors": errors}
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||
metadata = items[0].metadata
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
super().save(data, item.metadata)
|
||||
|
||||
# Emit memory save completed event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=emit_value,
|
||||
metadata=metadata,
|
||||
value=data,
|
||||
metadata=item.metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
if errors:
|
||||
raise Exception(
|
||||
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
fail_metadata = (
|
||||
{"entity_count": len(items), "saved": saved_count}
|
||||
if is_batch
|
||||
else items[0].metadata
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
metadata=fail_metadata,
|
||||
metadata=item.metadata,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
@@ -161,7 +103,7 @@ class EntityMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> Any:
|
||||
):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -169,8 +111,6 @@ class EntityMemory(Memory):
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -189,8 +129,6 @@ class EntityMemory(Memory):
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
26
src/crewai/memory/external/external_memory.py
vendored
26
src/crewai/memory/external/external_memory.py
vendored
@@ -4,8 +4,8 @@ import time
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
@@ -53,6 +53,7 @@ class ExternalMemory(Memory):
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Saves a value into the external storage."""
|
||||
crewai_event_bus.emit(
|
||||
@@ -60,30 +61,24 @@ class ExternalMemory(Memory):
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ExternalMemoryItem(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
super().save(value=item.value, metadata=item.metadata)
|
||||
item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)
|
||||
super().save(value=item.value, metadata=item.metadata, agent=item.agent)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -92,10 +87,9 @@ class ExternalMemory(Memory):
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
@@ -113,8 +107,6 @@ class ExternalMemory(Memory):
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -133,8 +125,6 @@ class ExternalMemory(Memory):
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
from typing import Any, Dict, List
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
|
||||
|
||||
@@ -24,120 +24,90 @@ class LongTermMemory(Memory):
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, storage: Any = None, path: Any = None) -> None:
|
||||
def __init__(self, storage=None, path=None):
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
# Handle both LongTermMemoryItem and regular save calls
|
||||
if isinstance(value, LongTermMemoryItem):
|
||||
item = value
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata.copy()
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
else:
|
||||
# Regular save for compatibility with parent class
|
||||
metadata = metadata or {}
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> list[Any]:
|
||||
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# The storage.load method uses different parameter names
|
||||
# but we'll call it with the aligned names
|
||||
results = self.storage.load(query, limit)
|
||||
metadata = item.metadata
|
||||
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
|
||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
query=task,
|
||||
results=results,
|
||||
limit=limit,
|
||||
limit=latest_n,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results if results is not None else []
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
"""
|
||||
@@ -16,38 +12,19 @@ class Memory(BaseModel):
|
||||
crew: Optional[Any] = None
|
||||
|
||||
storage: Any
|
||||
_agent: Optional["Agent"] = None
|
||||
_task: Optional["Task"] = None
|
||||
|
||||
def __init__(self, storage: Any, **data: Any):
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
@property
|
||||
def task(self) -> Optional["Task"]:
|
||||
"""Get the current task associated with this memory."""
|
||||
return self._task
|
||||
|
||||
@task.setter
|
||||
def task(self, task: Optional["Task"]) -> None:
|
||||
"""Set the current task associated with this memory."""
|
||||
self._task = task
|
||||
|
||||
@property
|
||||
def agent(self) -> Optional["Agent"]:
|
||||
"""Get the current agent associated with this memory."""
|
||||
return self._agent
|
||||
|
||||
@agent.setter
|
||||
def agent(self, agent: Optional["Agent"]) -> None:
|
||||
"""Set the current agent associated with this memory."""
|
||||
self._agent = agent
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
if agent:
|
||||
metadata["agent"] = agent
|
||||
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class ShortTermMemory(Memory):
|
||||
@@ -28,13 +28,7 @@ class ShortTermMemory(Memory):
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: Any = None,
|
||||
) -> None:
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
@@ -43,7 +37,7 @@ class ShortTermMemory(Memory):
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
config = embedder_config.get("config")
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
storage = (
|
||||
@@ -62,43 +56,35 @@ class ShortTermMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ShortTermMemoryItem(
|
||||
data=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||
if self._memory_provider == "mem0":
|
||||
item.data = (
|
||||
f"Remember the following insights from Agent run: {item.data}"
|
||||
)
|
||||
item.data = f"Remember the following insights from Agent run: {item.data}"
|
||||
|
||||
super().save(value=item.data, metadata=item.metadata)
|
||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
# agent_role=agent,
|
||||
agent_role=agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -107,10 +93,9 @@ class ShortTermMemory(Memory):
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
@@ -120,7 +105,7 @@ class ShortTermMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> Any:
|
||||
):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -128,8 +113,6 @@ class ShortTermMemory(Memory):
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -137,7 +120,7 @@ class ShortTermMemory(Memory):
|
||||
try:
|
||||
results = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -148,8 +131,6 @@ class ShortTermMemory(Memory):
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -18,7 +18,9 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
An updated SQLite storage class for kickoff task outputs storage.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self, db_path: Optional[str] = None
|
||||
) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||
@@ -65,7 +67,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
output: Dict[str, Any],
|
||||
task_index: int,
|
||||
was_replayed: bool = False,
|
||||
inputs: Dict[str, Any] | None = None,
|
||||
inputs: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
"""Add a new task output record to the database.
|
||||
|
||||
@@ -79,7 +81,6 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
Raises:
|
||||
DatabaseOperationError: If saving the task output fails due to SQLite errors.
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
@@ -145,9 +146,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
conn.commit()
|
||||
|
||||
if cursor.rowcount == 0:
|
||||
logger.warning(
|
||||
f"No row found with task_index {task_index}. No update performed."
|
||||
)
|
||||
logger.warning(f"No row found with task_index {task_index}. No update performed.")
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from mem0 import Memory, MemoryClient # type: ignore[import-not-found]
|
||||
from mem0 import Memory, MemoryClient
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
|
||||
MAX_AGENT_ID_LENGTH_MEM0 = 255
|
||||
|
||||
@@ -14,10 +13,7 @@ class Mem0Storage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embedding and searching across entities using Mem0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, type: str, crew: Any = None, config: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
def __init__(self, type, crew=None, config=None):
|
||||
super().__init__()
|
||||
|
||||
self._validate_type(type)
|
||||
@@ -28,21 +24,21 @@ class Mem0Storage(Storage):
|
||||
self._extract_config_values()
|
||||
self._initialize_memory()
|
||||
|
||||
def _validate_type(self, type: str) -> None:
|
||||
def _validate_type(self, type):
|
||||
supported_types = {"short_term", "long_term", "entities", "external"}
|
||||
if type not in supported_types:
|
||||
raise ValueError(
|
||||
f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}"
|
||||
)
|
||||
|
||||
def _extract_config_values(self) -> None:
|
||||
def _extract_config_values(self):
|
||||
self.mem0_run_id = self.config.get("run_id")
|
||||
self.includes = self.config.get("includes")
|
||||
self.excludes = self.config.get("excludes")
|
||||
self.custom_categories = self.config.get("custom_categories")
|
||||
self.infer = self.config.get("infer", True)
|
||||
|
||||
def _initialize_memory(self) -> None:
|
||||
def _initialize_memory(self):
|
||||
api_key = self.config.get("api_key") or os.getenv("MEM0_API_KEY")
|
||||
org_id = self.config.get("org_id")
|
||||
project_id = self.config.get("project_id")
|
||||
@@ -63,7 +59,7 @@ class Mem0Storage(Storage):
|
||||
else Memory()
|
||||
)
|
||||
|
||||
def _create_filter_for_search(self) -> dict[str, Any]:
|
||||
def _create_filter_for_search(self):
|
||||
"""
|
||||
Returns:
|
||||
dict: A filter dictionary containing AND conditions for querying data.
|
||||
@@ -90,21 +86,21 @@ class Mem0Storage(Storage):
|
||||
|
||||
return filter
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
user_id = self.config.get("user_id", "")
|
||||
assistant_message = [{"role": "assistant", "content": value}]
|
||||
assistant_message = [{"role" : "assistant","content" : value}]
|
||||
|
||||
base_metadata = {
|
||||
"short_term": "short_term",
|
||||
"long_term": "long_term",
|
||||
"entities": "entity",
|
||||
"external": "external",
|
||||
"external": "external"
|
||||
}
|
||||
|
||||
# Shared base params
|
||||
params: dict[str, Any] = {
|
||||
"metadata": {"type": base_metadata[self.memory_type], **metadata},
|
||||
"infer": self.infer,
|
||||
"infer": self.infer
|
||||
}
|
||||
|
||||
# MemoryClient-specific overrides
|
||||
@@ -125,15 +121,13 @@ class Mem0Storage(Storage):
|
||||
|
||||
self.memory.add(assistant_message, **params)
|
||||
|
||||
def search(
|
||||
self, query: str, limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[Any]:
|
||||
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
|
||||
params = {
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"version": "v2",
|
||||
"output_format": "v1.1",
|
||||
}
|
||||
"output_format": "v1.1"
|
||||
}
|
||||
|
||||
if user_id := self.config.get("user_id", ""):
|
||||
params["user_id"] = user_id
|
||||
@@ -154,10 +148,10 @@ class Mem0Storage(Storage):
|
||||
# automatically when the crew is created.
|
||||
|
||||
params["filters"] = self._create_filter_for_search()
|
||||
params["threshold"] = score_threshold
|
||||
params['threshold'] = score_threshold
|
||||
|
||||
if isinstance(self.memory, Memory):
|
||||
del params["metadata"], params["version"], params["output_format"]
|
||||
del params["metadata"], params["version"], params['output_format']
|
||||
if params.get("run_id"):
|
||||
del params["run_id"]
|
||||
|
||||
@@ -166,10 +160,10 @@ class Mem0Storage(Storage):
|
||||
# This makes it compatible for Contextual Memory to retrieve
|
||||
for result in results["results"]:
|
||||
result["context"] = result["memory"]
|
||||
|
||||
|
||||
return [r for r in results["results"]]
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self):
|
||||
if self.memory:
|
||||
self.memory.reset()
|
||||
|
||||
@@ -186,6 +180,4 @@ class Mem0Storage(Storage):
|
||||
agents = self.crew.agents
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return sanitize_collection_name(
|
||||
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
|
||||
)
|
||||
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
|
||||
|
||||
@@ -1,73 +1,48 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from chromadb import EmbeddingFunction
|
||||
from typing import Any, Dict, List, Optional
|
||||
from chromadb.api import ClientAPI
|
||||
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
import warnings
|
||||
|
||||
|
||||
def _extract_chromadb_response_item(
|
||||
response_data: Any,
|
||||
index: int,
|
||||
expected_type: type[Any] | tuple[type[Any], ...],
|
||||
) -> Any | None:
|
||||
"""Extract an item from ChromaDB response data at the given index.
|
||||
|
||||
Args:
|
||||
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
|
||||
index: The index of the item to extract.
|
||||
expected_type: The expected type(s) of the item.
|
||||
|
||||
Returns:
|
||||
The extracted item if it exists and matches the expected type, otherwise None.
|
||||
"""
|
||||
if response_data is None or not response_data:
|
||||
return None
|
||||
|
||||
# ChromaDB sometimes returns nested lists, handle both cases
|
||||
data_list = (
|
||||
response_data[0]
|
||||
if response_data and isinstance(response_data[0], list)
|
||||
else response_data
|
||||
)
|
||||
|
||||
if index < len(data_list):
|
||||
item = data_list[index]
|
||||
if isinstance(item, expected_type):
|
||||
return item
|
||||
return None
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(
|
||||
logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
|
||||
level=logging.ERROR,
|
||||
):
|
||||
logger = logging.getLogger(logger_name)
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(level)
|
||||
with (
|
||||
contextlib.redirect_stdout(io.StringIO()),
|
||||
contextlib.redirect_stderr(io.StringIO()),
|
||||
contextlib.suppress(UserWarning),
|
||||
):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
|
||||
Notes:
|
||||
- TODO: Add type hints to EmbeddingFunction in next typing PR.
|
||||
"""
|
||||
|
||||
app: ClientAPI | None = None
|
||||
embedder_config: EmbeddingFunction[Any] | None = None # type: ignore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Any = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
):
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
@@ -76,29 +51,16 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
self._original_embedder_config = (
|
||||
embedder_config # Store for later use in _set_embedder_config
|
||||
)
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
self._initialize_app()
|
||||
|
||||
def _set_embedder_config(self) -> None:
|
||||
"""Sets the embedder_config using EmbeddingConfigurator.
|
||||
def _set_embedder_config(self):
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
Notes:
|
||||
- TODO: remove the type: ignore on next typing pr.
|
||||
"""
|
||||
configurator = EmbeddingConfigurator() # type: ignore
|
||||
# Pass the original embedder_config from __init__, not self.embedder_config
|
||||
if hasattr(self, "_original_embedder_config"):
|
||||
self.embedder_config = configurator.configure_embedder(
|
||||
self._original_embedder_config
|
||||
)
|
||||
else:
|
||||
self.embedder_config = configurator.configure_embedder()
|
||||
|
||||
def _initialize_app(self) -> None:
|
||||
def _initialize_app(self):
|
||||
from chromadb.config import Settings
|
||||
|
||||
# Suppress deprecation warnings from chromadb, which are not relevant to us
|
||||
@@ -127,8 +89,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
@staticmethod
|
||||
def _build_storage_file_name(type: str, file_name: str) -> str:
|
||||
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
@@ -142,7 +103,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
try:
|
||||
@@ -154,81 +115,37 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: dict[str, Any] | None = None,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> list[Any]:
|
||||
) -> List[Any]:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
|
||||
try:
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
with suppress_logging():
|
||||
response = self.collection.query(query_texts=query, n_results=limit)
|
||||
|
||||
results = []
|
||||
if (
|
||||
response
|
||||
and "ids" in response
|
||||
and response["ids"]
|
||||
and len(response["ids"]) > 0
|
||||
):
|
||||
ids_list = (
|
||||
response["ids"][0]
|
||||
if isinstance(response["ids"][0], list)
|
||||
else response["ids"]
|
||||
)
|
||||
for i in range(len(ids_list)):
|
||||
# Handle metadatas
|
||||
meta_item = _extract_chromadb_response_item(
|
||||
response.get("metadatas"), i, dict
|
||||
)
|
||||
metadata: dict[str, Any] = meta_item if meta_item else {}
|
||||
|
||||
# Handle documents
|
||||
doc_item = _extract_chromadb_response_item(
|
||||
response.get("documents"), i, str
|
||||
)
|
||||
context = doc_item if doc_item else ""
|
||||
|
||||
# Handle distances
|
||||
dist_item = _extract_chromadb_response_item(
|
||||
response.get("distances"), i, (int, float)
|
||||
)
|
||||
score = dist_item if dist_item is not None else 1.0
|
||||
|
||||
result = {
|
||||
"id": ids_list[i],
|
||||
"metadata": metadata,
|
||||
"context": context,
|
||||
"score": score,
|
||||
}
|
||||
|
||||
# Check score threshold - distances are smaller when more similar
|
||||
if isinstance(score, (int, float)) and score <= score_threshold:
|
||||
results.append(result)
|
||||
for i in range(len(response["ids"][0])):
|
||||
result = {
|
||||
"id": response["ids"][0][i],
|
||||
"metadata": response["metadatas"][0][i],
|
||||
"context": response["documents"][0][i],
|
||||
"score": response["distances"][0][i],
|
||||
}
|
||||
if result["score"] >= score_threshold:
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
return []
|
||||
|
||||
def _generate_embedding(
|
||||
self, text: str, metadata: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Generates and stores the embedding for the given text and metadata.
|
||||
|
||||
Args:
|
||||
text: The text to generate an embedding for.
|
||||
metadata: Optional metadata associated with the text.
|
||||
|
||||
Notes:
|
||||
- Need to constrain the typing in the base class, this result isn't used
|
||||
"""
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
|
||||
return self.collection.add(
|
||||
self.collection.add(
|
||||
documents=[text],
|
||||
metadatas=[metadata or {}],
|
||||
ids=[str(uuid.uuid4())],
|
||||
@@ -250,8 +167,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_default_embedding_function() -> EmbeddingFunction[Any]:
|
||||
def _create_default_embedding_function(self):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
@@ -1,58 +1 @@
|
||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||
|
||||
import sys
|
||||
import importlib
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import set_rag_config
|
||||
|
||||
|
||||
_module_path = __path__
|
||||
_module_file = __file__
|
||||
|
||||
class _RagModule(ModuleType):
|
||||
"""Module wrapper to intercept attribute setting for config."""
|
||||
|
||||
__path__ = _module_path
|
||||
__file__ = _module_file
|
||||
|
||||
def __init__(self, module_name: str):
|
||||
"""Initialize the module wrapper.
|
||||
|
||||
Args:
|
||||
module_name: Name of the module.
|
||||
"""
|
||||
super().__init__(module_name)
|
||||
|
||||
def __setattr__(self, name: str, value: RagConfigType) -> None:
|
||||
"""Set module attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
value: Attribute value.
|
||||
"""
|
||||
if name == "config":
|
||||
return set_rag_config(value)
|
||||
raise AttributeError(f"Setting attribute '{name}' is not allowed.")
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Get module attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
|
||||
Returns:
|
||||
The requested attribute.
|
||||
|
||||
Raises:
|
||||
AttributeError: If attribute doesn't exist.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(f"{self.__name__}.{name}")
|
||||
except ImportError:
|
||||
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
sys.modules[__name__] = _RagModule(__name__)
|
||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||
@@ -1,578 +0,0 @@
|
||||
"""ChromaDB client implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from chromadb.api.types import (
|
||||
Embeddable,
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
QueryResult,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionCreateParams,
|
||||
ChromaDBCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.chromadb.utils import (
|
||||
_extract_search_params,
|
||||
_is_async_client,
|
||||
_is_sync_client,
|
||||
_prepare_documents_for_chromadb,
|
||||
_process_query_results,
|
||||
_sanitize_collection_name,
|
||||
)
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class ChromaDBClient(BaseClient):
|
||||
"""ChromaDB implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for ChromaDB, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable],
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured ChromaDB client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in ChromaDB.
|
||||
|
||||
Uses the client's default embedding function if none provided.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
get_or_create: If True, returns existing collection if it already exists
|
||||
instead of raising an error. Defaults to False.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection with the same name already exists and get_or_create
|
||||
is False.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"},
|
||||
... get_or_create=True
|
||||
... )
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method create_collection() requires a ClientAPI. "
|
||||
"Use acreate_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
self.client.create_collection(
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
get_or_create=kwargs.get("get_or_create", False),
|
||||
)
|
||||
|
||||
async def acreate_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in ChromaDB asynchronously.
|
||||
|
||||
Creates a new collection with the specified name and optional configuration.
|
||||
If an embedding function is not provided, uses the client's default embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
get_or_create: If True, returns existing collection if it already exists
|
||||
instead of raising an error. Defaults to False.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection with the same name already exists and get_or_create
|
||||
is False.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.acreate_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"},
|
||||
... get_or_create=True
|
||||
... )
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method acreate_collection() requires an AsyncClientAPI. "
|
||||
"Use create_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
await self.client.create_collection(
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
get_or_create=kwargs.get("get_or_create", False),
|
||||
)
|
||||
|
||||
def get_or_create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
Returns existing collection if found, otherwise creates a new one.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
|
||||
Returns:
|
||||
A ChromaDB Collection object.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> collection = client.get_or_create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"}
|
||||
... )
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method get_or_create_collection() requires a ClientAPI. "
|
||||
"Use aget_or_create_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
)
|
||||
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Returns existing collection if found, otherwise creates a new one.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
|
||||
Returns:
|
||||
A ChromaDB AsyncCollection object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... collection = await client.aget_or_create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"}
|
||||
... )
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method aget_or_create_collection() requires an AsyncClientAPI. "
|
||||
"Use get_or_create_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
Performs an upsert operation - documents with existing IDs are updated.
|
||||
Generates embeddings automatically using the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method add_documents() requires a ClientAPI. "
|
||||
"Use aadd_documents() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=prepared.metadatas,
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Performs an upsert operation - documents with existing IDs are updated.
|
||||
Generates embeddings automatically using the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method aadd_documents() requires an AsyncClientAPI. "
|
||||
"Use add_documents() for ClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
await collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=prepared.metadatas,
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Performs semantic search to find documents similar to the query text.
|
||||
Uses the configured embedding function to generate query embeddings.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
where: Optional ChromaDB where clause for metadata filtering.
|
||||
where_document: Optional ChromaDB where clause for document content filtering.
|
||||
include: Optional list of fields to include in results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method search() requires a ClientAPI. "
|
||||
"Use asearch() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = self.client.get_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Performs semantic search to find documents similar to the query text.
|
||||
Uses the configured embedding function to generate query embeddings.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
where: Optional ChromaDB where clause for metadata filtering.
|
||||
where_document: Optional ChromaDB where clause for document content filtering.
|
||||
include: Optional list of fields to include in results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method asearch() requires an AsyncClientAPI. "
|
||||
"Use search() for ClientAPI."
|
||||
)
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = await collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
Permanently removes a collection and all documents, embeddings, and metadata it contains.
|
||||
This operation cannot be undone.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.delete_collection(collection_name="old_documents")
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method delete_collection() requires a ClientAPI. "
|
||||
"Use adelete_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Permanently removes a collection and all documents, embeddings, and metadata it contains.
|
||||
This operation cannot be undone.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.adelete_collection(collection_name="old_documents")
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method adelete_collection() requires an AsyncClientAPI. "
|
||||
"Use delete_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
await self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
Completely clears the ChromaDB instance, removing all collections,
|
||||
documents, embeddings, and metadata. This operation cannot be undone.
|
||||
Use with extreme caution in production environments.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.reset() # Removes ALL data from ChromaDB
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method reset() requires a ClientAPI. "
|
||||
"Use areset() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
self.client.reset()
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Completely clears the ChromaDB instance, removing all collections,
|
||||
documents, embeddings, and metadata. This operation cannot be undone.
|
||||
Use with extreme caution in production environments.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.areset() # Removes ALL data from ChromaDB
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method areset() requires an AsyncClientAPI. "
|
||||
"Use reset() for ClientAPI."
|
||||
)
|
||||
|
||||
await self.client.reset()
|
||||
@@ -1,65 +0,0 @@
|
||||
"""ChromaDB configuration model."""
|
||||
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_TENANT,
|
||||
DEFAULT_DATABASE,
|
||||
DEFAULT_STORAGE_PATH,
|
||||
)
|
||||
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=".*Mixing V1 models and V2 models.*",
|
||||
category=UserWarning,
|
||||
module="pydantic._internal._generate_schema",
|
||||
)
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
|
||||
def _default_settings() -> Settings:
|
||||
"""Create default ChromaDB settings.
|
||||
|
||||
Returns:
|
||||
Settings with persistent storage and reset enabled.
|
||||
"""
|
||||
return Settings(
|
||||
persist_directory=DEFAULT_STORAGE_PATH,
|
||||
allow_reset=True,
|
||||
is_persistent=True,
|
||||
)
|
||||
|
||||
|
||||
def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
|
||||
"""Create default ChromaDB embedding function.
|
||||
|
||||
Returns:
|
||||
Default embedding function using all-MiniLM-L6-v2 via ONNX.
|
||||
"""
|
||||
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class ChromaDBConfig(BaseRagConfig):
|
||||
"""Configuration for ChromaDB client."""
|
||||
|
||||
provider: Literal["chromadb"] = field(default="chromadb", init=False)
|
||||
tenant: str = DEFAULT_TENANT
|
||||
database: str = DEFAULT_DATABASE
|
||||
settings: Settings = field(default_factory=_default_settings)
|
||||
embedding_function: ChromaEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Constants for ChromaDB configuration."""
|
||||
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
DEFAULT_TENANT: Final[str] = "default_tenant"
|
||||
DEFAULT_DATABASE: Final[str] = "default_database"
|
||||
DEFAULT_STORAGE_PATH: Final[str] = db_storage_path()
|
||||
|
||||
MIN_COLLECTION_LENGTH: Final[int] = 3
|
||||
MAX_COLLECTION_LENGTH: Final[int] = 63
|
||||
DEFAULT_COLLECTION: Final[str] = "default_collection"
|
||||
|
||||
INVALID_CHARS_PATTERN: Final[re.Pattern[str]] = re.compile(r"[^a-zA-Z0-9_-]")
|
||||
IPV4_PATTERN: Final[re.Pattern[str]] = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
|
||||
@@ -1,40 +0,0 @@
|
||||
"""Factory functions for creating ChromaDB clients."""
|
||||
|
||||
import os
|
||||
from hashlib import md5
|
||||
import portalocker
|
||||
from chromadb import PersistentClient
|
||||
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
|
||||
|
||||
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient from configuration.
|
||||
|
||||
Args:
|
||||
config: ChromaDB configuration object.
|
||||
|
||||
Returns:
|
||||
Configured ChromaDBClient instance.
|
||||
|
||||
Notes:
|
||||
Need to update to use chromadb.Client to support more client types in the near future.
|
||||
"""
|
||||
|
||||
persist_dir = config.settings.persist_directory
|
||||
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||
|
||||
with portalocker.Lock(lockfile):
|
||||
client = PersistentClient(
|
||||
path=persist_dir,
|
||||
settings=config.settings,
|
||||
tenant=config.tenant,
|
||||
database=config.database,
|
||||
)
|
||||
|
||||
return ChromaDBClient(
|
||||
client=client,
|
||||
embedding_function=config.embedding_function,
|
||||
)
|
||||
@@ -1,102 +0,0 @@
|
||||
"""Type definitions specific to ChromaDB implementation."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from chromadb.api import ClientAPI, AsyncClientAPI
|
||||
from chromadb.api.configuration import CollectionConfigurationInterface
|
||||
from chromadb.api.types import (
|
||||
CollectionMetadata,
|
||||
DataLoader,
|
||||
Embeddable,
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
Include,
|
||||
Loadable,
|
||||
Where,
|
||||
WhereDocument,
|
||||
)
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
|
||||
|
||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||
|
||||
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
|
||||
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for ChromaDB EmbeddingFunction.
|
||||
|
||||
This allows Pydantic to handle ChromaDB's EmbeddingFunction type
|
||||
without requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
|
||||
class PreparedDocuments(NamedTuple):
|
||||
"""Prepared documents ready for ChromaDB insertion.
|
||||
|
||||
Attributes:
|
||||
ids: List of document IDs
|
||||
texts: List of document texts
|
||||
metadatas: List of document metadata mappings
|
||||
"""
|
||||
|
||||
ids: list[str]
|
||||
texts: list[str]
|
||||
metadatas: list[Mapping[str, str | int | float | bool]]
|
||||
|
||||
|
||||
class ExtractedSearchParams(NamedTuple):
|
||||
"""Extracted search parameters for ChromaDB queries.
|
||||
|
||||
Attributes:
|
||||
collection_name: Name of the collection to search
|
||||
query: Search query text
|
||||
limit: Maximum number of results
|
||||
metadata_filter: Optional metadata filter
|
||||
score_threshold: Optional minimum similarity score
|
||||
where: Optional ChromaDB where clause
|
||||
where_document: Optional ChromaDB document filter
|
||||
include: Fields to include in results
|
||||
"""
|
||||
|
||||
collection_name: str
|
||||
query: str
|
||||
limit: int
|
||||
metadata_filter: dict[str, Any] | None
|
||||
score_threshold: float | None
|
||||
where: Where | None
|
||||
where_document: WhereDocument | None
|
||||
include: Include
|
||||
|
||||
|
||||
class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
|
||||
"""Parameters for creating a ChromaDB collection.
|
||||
|
||||
This class extends BaseCollectionParams to include any additional
|
||||
parameters specific to ChromaDB collection creation.
|
||||
"""
|
||||
|
||||
configuration: CollectionConfigurationInterface
|
||||
metadata: CollectionMetadata
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable]
|
||||
data_loader: DataLoader[Loadable]
|
||||
get_or_create: bool
|
||||
|
||||
|
||||
class ChromaDBCollectionSearchParams(BaseCollectionSearchParams, total=False):
|
||||
"""Parameters for searching a ChromaDB collection.
|
||||
|
||||
This class extends BaseCollectionSearchParams to include ChromaDB-specific
|
||||
search parameters like where clauses and include options.
|
||||
"""
|
||||
|
||||
where: Where
|
||||
where_document: WhereDocument
|
||||
include: Include
|
||||
@@ -1,280 +0,0 @@
|
||||
"""Utility functions for ChromaDB client implementation."""
|
||||
|
||||
import hashlib
|
||||
from collections.abc import Mapping
|
||||
from typing import Literal, TypeGuard, cast
|
||||
|
||||
from chromadb.api import AsyncClientAPI, ClientAPI
|
||||
from chromadb.api.types import (
|
||||
Include,
|
||||
IncludeEnum,
|
||||
QueryResult,
|
||||
)
|
||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_COLLECTION,
|
||||
INVALID_CHARS_PATTERN,
|
||||
IPV4_PATTERN,
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
)
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionSearchParams,
|
||||
ExtractedSearchParams,
|
||||
PreparedDocuments,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
|
||||
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
|
||||
"""Type guard to check if the client is a synchronous ClientAPI.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is a ClientAPI, False otherwise.
|
||||
"""
|
||||
return isinstance(client, ClientAPI)
|
||||
|
||||
|
||||
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
|
||||
"""Type guard to check if the client is an asynchronous AsyncClientAPI.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is an AsyncClientAPI, False otherwise.
|
||||
"""
|
||||
return isinstance(client, AsyncClientAPI)
|
||||
|
||||
|
||||
def _prepare_documents_for_chromadb(
|
||||
documents: list[BaseRecord],
|
||||
) -> PreparedDocuments:
|
||||
"""Prepare documents for ChromaDB by extracting IDs, texts, and metadata.
|
||||
|
||||
Args:
|
||||
documents: List of BaseRecord documents to prepare.
|
||||
|
||||
Returns:
|
||||
PreparedDocuments with ids, texts, and metadatas ready for ChromaDB.
|
||||
"""
|
||||
ids: list[str] = []
|
||||
texts: list[str] = []
|
||||
metadatas: list[Mapping[str, str | int | float | bool]] = []
|
||||
|
||||
for doc in documents:
|
||||
if "doc_id" in doc:
|
||||
ids.append(doc["doc_id"])
|
||||
else:
|
||||
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
|
||||
ids.append(content_hash)
|
||||
|
||||
texts.append(doc["content"])
|
||||
metadata = doc.get("metadata")
|
||||
if metadata:
|
||||
if isinstance(metadata, list):
|
||||
metadatas.append(metadata[0] if metadata else {})
|
||||
else:
|
||||
metadatas.append(metadata)
|
||||
else:
|
||||
metadatas.append({})
|
||||
|
||||
return PreparedDocuments(ids, texts, metadatas)
|
||||
|
||||
|
||||
def _extract_search_params(
|
||||
kwargs: ChromaDBCollectionSearchParams,
|
||||
) -> ExtractedSearchParams:
|
||||
"""Extract search parameters from kwargs.
|
||||
|
||||
Args:
|
||||
kwargs: Keyword arguments containing search parameters.
|
||||
|
||||
Returns:
|
||||
ExtractedSearchParams with all extracted parameters.
|
||||
"""
|
||||
return ExtractedSearchParams(
|
||||
collection_name=kwargs["collection_name"],
|
||||
query=kwargs["query"],
|
||||
limit=kwargs.get("limit", 10),
|
||||
metadata_filter=kwargs.get("metadata_filter"),
|
||||
score_threshold=kwargs.get("score_threshold"),
|
||||
where=kwargs.get("where"),
|
||||
where_document=kwargs.get("where_document"),
|
||||
include=kwargs.get(
|
||||
"include",
|
||||
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _convert_distance_to_score(
|
||||
distance: float,
|
||||
distance_metric: Literal["l2", "cosine", "ip"],
|
||||
) -> float:
|
||||
"""Convert ChromaDB distance to similarity score.
|
||||
|
||||
Notes:
|
||||
Assuming all embedding are unit-normalized for now, including custom embeddings.
|
||||
|
||||
Args:
|
||||
distance: The distance value from ChromaDB.
|
||||
distance_metric: The distance metric used ("l2", "cosine", or "ip").
|
||||
|
||||
Returns:
|
||||
Similarity score in range [0, 1] where 1 is most similar.
|
||||
"""
|
||||
if distance_metric == "cosine":
|
||||
score = 1.0 - 0.5 * distance
|
||||
return max(0.0, min(1.0, score))
|
||||
raise ValueError(f"Unsupported distance metric: {distance_metric}")
|
||||
|
||||
|
||||
def _convert_chromadb_results_to_search_results(
|
||||
results: QueryResult,
|
||||
include: Include,
|
||||
distance_metric: Literal["l2", "cosine", "ip"],
|
||||
score_threshold: float | None = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Convert ChromaDB query results to SearchResult format.
|
||||
|
||||
Args:
|
||||
results: ChromaDB query results.
|
||||
include: List of fields that were included in the query.
|
||||
distance_metric: The distance metric used by the collection.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
"""
|
||||
search_results: list[SearchResult] = []
|
||||
|
||||
include_strings = [item.value for item in include]
|
||||
|
||||
ids = results["ids"][0] if results.get("ids") else []
|
||||
|
||||
documents_list = results.get("documents")
|
||||
documents = (
|
||||
documents_list[0] if documents_list and "documents" in include_strings else []
|
||||
)
|
||||
|
||||
metadatas_list = results.get("metadatas")
|
||||
metadatas = (
|
||||
metadatas_list[0] if metadatas_list and "metadatas" in include_strings else []
|
||||
)
|
||||
|
||||
distances_list = results.get("distances")
|
||||
distances = (
|
||||
distances_list[0] if distances_list and "distances" in include_strings else []
|
||||
)
|
||||
|
||||
for i, doc_id in enumerate(ids):
|
||||
if not distances or i >= len(distances):
|
||||
continue
|
||||
|
||||
distance = distances[i]
|
||||
score = _convert_distance_to_score(
|
||||
distance=distance, distance_metric=distance_metric
|
||||
)
|
||||
|
||||
if score_threshold and score < score_threshold:
|
||||
continue
|
||||
|
||||
result: SearchResult = {
|
||||
"id": doc_id,
|
||||
"content": documents[i] if documents and i < len(documents) else "",
|
||||
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
|
||||
"score": score,
|
||||
}
|
||||
search_results.append(result)
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
def _process_query_results(
|
||||
collection: Collection | AsyncCollection,
|
||||
results: QueryResult,
|
||||
params: ExtractedSearchParams,
|
||||
) -> list[SearchResult]:
|
||||
"""Process ChromaDB query results and convert to SearchResult format.
|
||||
|
||||
Args:
|
||||
collection: The ChromaDB collection (sync or async) that was queried.
|
||||
results: Raw query results from ChromaDB.
|
||||
params: The search parameters used for the query.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
"""
|
||||
|
||||
distance_metric = cast(
|
||||
Literal["l2", "cosine", "ip"],
|
||||
collection.metadata.get("hnsw:space", "l2") if collection.metadata else "l2",
|
||||
)
|
||||
|
||||
return _convert_chromadb_results_to_search_results(
|
||||
results=results,
|
||||
include=params.include,
|
||||
distance_metric=distance_metric,
|
||||
score_threshold=params.score_threshold,
|
||||
)
|
||||
|
||||
|
||||
def _is_ipv4_pattern(name: str) -> bool:
|
||||
"""Check if a string matches an IPv4 address pattern.
|
||||
|
||||
Args:
|
||||
name: The string to check
|
||||
|
||||
Returns:
|
||||
True if the string matches an IPv4 pattern, False otherwise
|
||||
"""
|
||||
return bool(IPV4_PATTERN.match(name))
|
||||
|
||||
|
||||
def _sanitize_collection_name(
|
||||
name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH
|
||||
) -> str:
|
||||
"""Sanitize a collection name to meet ChromaDB requirements.
|
||||
|
||||
Requirements:
|
||||
1. 3-63 characters long
|
||||
2. Starts and ends with alphanumeric character
|
||||
3. Contains only alphanumeric characters, underscores, or hyphens
|
||||
4. No consecutive periods
|
||||
5. Not a valid IPv4 address
|
||||
|
||||
Args:
|
||||
name: The original collection name to sanitize
|
||||
max_collection_length: Maximum allowed length for the collection name
|
||||
|
||||
Returns:
|
||||
A sanitized collection name that meets ChromaDB requirements
|
||||
"""
|
||||
if not name:
|
||||
return DEFAULT_COLLECTION
|
||||
|
||||
if _is_ipv4_pattern(name):
|
||||
name = f"ip_{name}"
|
||||
|
||||
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
|
||||
|
||||
if not sanitized[0].isalnum():
|
||||
sanitized = "a" + sanitized
|
||||
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
if len(sanitized) < MIN_COLLECTION_LENGTH:
|
||||
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
|
||||
if len(sanitized) > max_collection_length:
|
||||
sanitized = sanitized[:max_collection_length]
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
return sanitized
|
||||
@@ -1 +0,0 @@
|
||||
"""RAG client configuration management using ContextVars for thread-safe provider switching."""
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user