mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-06 22:58:30 +00:00
Compare commits
24 Commits
0.165.1
...
devin/1756
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b42d2e8cf0 | ||
|
|
36de68ecd4 | ||
|
|
109de91d08 | ||
|
|
92b70e652d | ||
|
|
fc3f2c49d2 | ||
|
|
88d2968fd5 | ||
|
|
7addda9398 | ||
|
|
4b4a119a9f | ||
|
|
869bb115c8 | ||
|
|
7ac482c7c9 | ||
|
|
2e4bd3f49d | ||
|
|
c02997d956 | ||
|
|
f96b779df5 | ||
|
|
842bed4e9c | ||
|
|
1217935b31 | ||
|
|
641c156c17 | ||
|
|
7fdf9f9290 | ||
|
|
c0d2bf4c12 | ||
|
|
ed187b495b | ||
|
|
2773996b49 | ||
|
|
95923b78c6 | ||
|
|
7065ad4336 | ||
|
|
d6254918fd | ||
|
|
95e3d6db7a |
@@ -320,6 +320,7 @@
|
||||
"en/enterprise/guides/update-crew",
|
||||
"en/enterprise/guides/enable-crew-studio",
|
||||
"en/enterprise/guides/azure-openai-setup",
|
||||
"en/enterprise/guides/automation-triggers",
|
||||
"en/enterprise/guides/hubspot-trigger",
|
||||
"en/enterprise/guides/react-component-export",
|
||||
"en/enterprise/guides/salesforce-trigger",
|
||||
@@ -341,11 +342,12 @@
|
||||
"groups": [
|
||||
{
|
||||
"group": "Getting Started",
|
||||
"pages": ["en/api-reference/introduction"]
|
||||
},
|
||||
{
|
||||
"group": "Endpoints",
|
||||
"openapi": "https://raw.githubusercontent.com/crewAIInc/crewAI/main/docs/enterprise-api.en.yaml"
|
||||
"pages": [
|
||||
"en/api-reference/introduction",
|
||||
"en/api-reference/inputs",
|
||||
"en/api-reference/kickoff",
|
||||
"en/api-reference/status"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -657,6 +659,7 @@
|
||||
"pt-BR/enterprise/guides/update-crew",
|
||||
"pt-BR/enterprise/guides/enable-crew-studio",
|
||||
"pt-BR/enterprise/guides/azure-openai-setup",
|
||||
"pt-BR/enterprise/guides/automation-triggers",
|
||||
"pt-BR/enterprise/guides/hubspot-trigger",
|
||||
"pt-BR/enterprise/guides/react-component-export",
|
||||
"pt-BR/enterprise/guides/salesforce-trigger",
|
||||
@@ -680,11 +683,12 @@
|
||||
"groups": [
|
||||
{
|
||||
"group": "Começando",
|
||||
"pages": ["pt-BR/api-reference/introduction"]
|
||||
},
|
||||
{
|
||||
"group": "Endpoints",
|
||||
"openapi": "https://raw.githubusercontent.com/crewAIInc/crewAI/main/docs/enterprise-api.pt-BR.yaml"
|
||||
"pages": [
|
||||
"pt-BR/api-reference/introduction",
|
||||
"pt-BR/api-reference/inputs",
|
||||
"pt-BR/api-reference/kickoff",
|
||||
"pt-BR/api-reference/status"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -1005,6 +1009,7 @@
|
||||
"ko/enterprise/guides/update-crew",
|
||||
"ko/enterprise/guides/enable-crew-studio",
|
||||
"ko/enterprise/guides/azure-openai-setup",
|
||||
"ko/enterprise/guides/automation-triggers",
|
||||
"ko/enterprise/guides/hubspot-trigger",
|
||||
"ko/enterprise/guides/react-component-export",
|
||||
"ko/enterprise/guides/salesforce-trigger",
|
||||
@@ -1026,11 +1031,12 @@
|
||||
"groups": [
|
||||
{
|
||||
"group": "시작 안내",
|
||||
"pages": ["ko/api-reference/introduction"]
|
||||
},
|
||||
{
|
||||
"group": "Endpoints",
|
||||
"openapi": "https://raw.githubusercontent.com/crewAIInc/crewAI/main/docs/enterprise-api.ko.yaml"
|
||||
"pages": [
|
||||
"ko/api-reference/introduction",
|
||||
"ko/api-reference/inputs",
|
||||
"ko/api-reference/kickoff",
|
||||
"ko/api-reference/status"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -1081,6 +1087,10 @@
|
||||
"indexing": "all"
|
||||
},
|
||||
"redirects": [
|
||||
{
|
||||
"source": "/api-reference",
|
||||
"destination": "/en/api-reference/introduction"
|
||||
},
|
||||
{
|
||||
"source": "/introduction",
|
||||
"destination": "/en/introduction"
|
||||
@@ -1133,6 +1143,18 @@
|
||||
"source": "/api-reference/:path*",
|
||||
"destination": "/en/api-reference/:path*"
|
||||
},
|
||||
{
|
||||
"source": "/en/api-reference",
|
||||
"destination": "/en/api-reference/introduction"
|
||||
},
|
||||
{
|
||||
"source": "/pt-BR/api-reference",
|
||||
"destination": "/pt-BR/api-reference/introduction"
|
||||
},
|
||||
{
|
||||
"source": "/ko/api-reference",
|
||||
"destination": "/ko/api-reference/introduction"
|
||||
},
|
||||
{
|
||||
"source": "/examples/:path*",
|
||||
"destination": "/en/examples/:path*"
|
||||
|
||||
7
docs/en/api-reference/inputs.mdx
Normal file
7
docs/en/api-reference/inputs.mdx
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "GET /inputs"
|
||||
description: "Get required inputs for your crew"
|
||||
openapi: "/enterprise-api.en.yaml GET /inputs"
|
||||
---
|
||||
|
||||
|
||||
7
docs/en/api-reference/kickoff.mdx
Normal file
7
docs/en/api-reference/kickoff.mdx
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "POST /kickoff"
|
||||
description: "Start a crew execution"
|
||||
openapi: "/enterprise-api.en.yaml POST /kickoff"
|
||||
---
|
||||
|
||||
|
||||
7
docs/en/api-reference/status.mdx
Normal file
7
docs/en/api-reference/status.mdx
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "GET /status/{kickoff_id}"
|
||||
description: "Get execution status"
|
||||
openapi: "/enterprise-api.en.yaml GET /status/{kickoff_id}"
|
||||
---
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ crew = Crew(
|
||||
| **Output Pydantic** _(optional)_ | `output_pydantic` | `Optional[Type[BaseModel]]` | A Pydantic model for task output. |
|
||||
| **Callback** _(optional)_ | `callback` | `Optional[Any]` | Function/object to be executed after task completion. |
|
||||
| **Guardrail** _(optional)_ | `guardrail` | `Optional[Callable]` | Function to validate task output before proceeding to next task. |
|
||||
| **Guardrail Max Retries** _(optional)_ | `guardrail_max_retries` | `Optional[int]` | Maximum number of retries when guardrail validation fails. Defaults to 3. |
|
||||
|
||||
## Creating Tasks
|
||||
|
||||
@@ -452,7 +453,7 @@ task = Task(
|
||||
expected_output="A valid JSON object",
|
||||
agent=analyst,
|
||||
guardrail=validate_json_output,
|
||||
max_retries=3 # Limit retry attempts
|
||||
guardrail_max_retries=3 # Limit retry attempts
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
@@ -21,13 +21,17 @@ To use the training feature, follow these steps:
|
||||
3. Run the following command:
|
||||
|
||||
```shell
|
||||
crewai train -n <n_iterations> <filename> (optional)
|
||||
crewai train -n <n_iterations> -f <filename.pkl>
|
||||
```
|
||||
<Tip>
|
||||
Replace `<n_iterations>` with the desired number of training iterations and `<filename>` with the appropriate filename ending with `.pkl`.
|
||||
</Tip>
|
||||
|
||||
### Training Your Crew Programmatically
|
||||
<Note>
|
||||
If you omit `-f`, the output defaults to `trained_agents_data.pkl` in the current working directory. You can pass an absolute path to control where the file is written.
|
||||
</Note>
|
||||
|
||||
### Training your Crew programmatically
|
||||
|
||||
To train your crew programmatically, use the following steps:
|
||||
|
||||
@@ -51,19 +55,65 @@ except Exception as e:
|
||||
raise Exception(f"An error occurred while training the crew: {e}")
|
||||
```
|
||||
|
||||
### Key Points to Note
|
||||
## How trained data is used by agents
|
||||
|
||||
- **Positive Integer Requirement:** Ensure that the number of iterations (`n_iterations`) is a positive integer. The code will raise a `ValueError` if this condition is not met.
|
||||
- **Filename Requirement:** Ensure that the filename ends with `.pkl`. The code will raise a `ValueError` if this condition is not met.
|
||||
- **Error Handling:** The code handles subprocess errors and unexpected exceptions, providing error messages to the user.
|
||||
CrewAI uses the training artifacts in two ways: during training to incorporate your human feedback, and after training to guide agents with consolidated suggestions.
|
||||
|
||||
It is important to note that the training process may take some time, depending on the complexity of your agents and will also require your feedback on each iteration.
|
||||
### Training data flow
|
||||
|
||||
Once the training is complete, your agents will be equipped with enhanced capabilities and knowledge, ready to tackle complex tasks and provide more consistent and valuable insights.
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A["Start training<br/>CLI: crewai train -n -f<br/>or Python: crew.train(...)"] --> B["Setup training mode<br/>- task.human_input = true<br/>- disable delegation<br/>- init training_data.pkl + trained file"]
|
||||
|
||||
Remember to regularly update and retrain your agents to ensure they stay up-to-date with the latest information and advancements in the field.
|
||||
subgraph "Iterations"
|
||||
direction LR
|
||||
C["Iteration i<br/>initial_output"] --> D["User human_feedback"]
|
||||
D --> E["improved_output"]
|
||||
E --> F["Append to training_data.pkl<br/>by agent_id and iteration"]
|
||||
end
|
||||
|
||||
Happy training with CrewAI! 🚀
|
||||
B --> C
|
||||
F --> G{"More iterations?"}
|
||||
G -- "Yes" --> C
|
||||
G -- "No" --> H["Evaluate per agent<br/>aggregate iterations"]
|
||||
|
||||
H --> I["Consolidate<br/>suggestions[] + quality + final_summary"]
|
||||
I --> J["Save by agent role to trained file<br/>(default: trained_agents_data.pkl)"]
|
||||
|
||||
J --> K["Normal (non-training) runs"]
|
||||
K --> L["Auto-load suggestions<br/>from trained_agents_data.pkl"]
|
||||
L --> M["Append to prompt<br/>for consistent improvements"]
|
||||
```
|
||||
|
||||
### During training runs
|
||||
|
||||
- On each iteration, the system records for every agent:
|
||||
- `initial_output`: the agent’s first answer
|
||||
- `human_feedback`: your inline feedback when prompted
|
||||
- `improved_output`: the agent’s follow-up answer after feedback
|
||||
- This data is stored in a working file named `training_data.pkl` keyed by the agent’s internal ID and iteration.
|
||||
- While training is active, the agent automatically appends your prior human feedback to its prompt to enforce those instructions on subsequent attempts within the training session.
|
||||
Training is interactive: tasks set `human_input = true`, so running in a non-interactive environment will block on user input.
|
||||
|
||||
### After training completes
|
||||
|
||||
- When `train(...)` finishes, CrewAI evaluates the collected training data per agent and produces a consolidated result containing:
|
||||
- `suggestions`: clear, actionable instructions distilled from your feedback and the difference between initial/improved outputs
|
||||
- `quality`: a 0–10 score capturing improvement
|
||||
- `final_summary`: a step-by-step set of action items for future tasks
|
||||
- These consolidated results are saved to the filename you pass to `train(...)` (default via CLI is `trained_agents_data.pkl`). Entries are keyed by the agent’s `role` so they can be applied across sessions.
|
||||
- During normal (non-training) execution, each agent automatically loads its consolidated `suggestions` and appends them to the task prompt as mandatory instructions. This gives you consistent improvements without changing your agent definitions.
|
||||
|
||||
### File summary
|
||||
|
||||
- `training_data.pkl` (ephemeral, per-session):
|
||||
- Structure: `agent_id -> { iteration_number: { initial_output, human_feedback, improved_output } }`
|
||||
- Purpose: capture raw data and human feedback during training
|
||||
- Location: saved in the current working directory (CWD)
|
||||
- `trained_agents_data.pkl` (or your custom filename):
|
||||
- Structure: `agent_role -> { suggestions: string[], quality: number, final_summary: string }`
|
||||
- Purpose: persist consolidated guidance for future runs
|
||||
- Location: written to the CWD by default; use `-f` to set a custom (including absolute) path
|
||||
|
||||
## Small Language Model Considerations
|
||||
|
||||
@@ -129,3 +179,18 @@ Happy training with CrewAI! 🚀
|
||||
</Warning>
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
### Key Points to Note
|
||||
|
||||
- **Positive Integer Requirement:** Ensure that the number of iterations (`n_iterations`) is a positive integer. The code will raise a `ValueError` if this condition is not met.
|
||||
- **Filename Requirement:** Ensure that the filename ends with `.pkl`. The code will raise a `ValueError` if this condition is not met.
|
||||
- **Error Handling:** The code handles subprocess errors and unexpected exceptions, providing error messages to the user.
|
||||
- Trained guidance is applied at prompt time; it does not modify your Python/YAML agent configuration.
|
||||
- Agents automatically load trained suggestions from a file named `trained_agents_data.pkl` located in the current working directory. If you trained to a different filename, either rename it to `trained_agents_data.pkl` before running, or adjust the loader in code.
|
||||
- You can change the output filename when calling `crewai train` with `-f/--filename`. Absolute paths are supported if you want to save outside the CWD.
|
||||
|
||||
It is important to note that the training process may take some time, depending on the complexity of your agents and will also require your feedback on each iteration.
|
||||
|
||||
Once the training is complete, your agents will be equipped with enhanced capabilities and knowledge, ready to tackle complex tasks and provide more consistent and valuable insights.
|
||||
|
||||
Remember to regularly update and retrain your agents to ensure they stay up-to-date with the latest information and advancements in the field.
|
||||
|
||||
171
docs/en/enterprise/guides/automation-triggers.mdx
Normal file
171
docs/en/enterprise/guides/automation-triggers.mdx
Normal file
@@ -0,0 +1,171 @@
|
||||
---
|
||||
title: "Automation Triggers"
|
||||
description: "Automatically execute your CrewAI workflows when specific events occur in connected integrations"
|
||||
icon: "bolt"
|
||||
---
|
||||
|
||||
Automation triggers enable you to automatically run your CrewAI deployments when specific events occur in your connected integrations, creating powerful event-driven workflows that respond to real-time changes in your business systems.
|
||||
|
||||
## Overview
|
||||
|
||||
With automation triggers, you can:
|
||||
|
||||
- **Respond to real-time events** - Automatically execute workflows when specific conditions are met
|
||||
- **Integrate with external systems** - Connect with platforms like Gmail, Outlook, OneDrive, JIRA, Slack, Stripe and more
|
||||
- **Scale your automation** - Handle high-volume events without manual intervention
|
||||
- **Maintain context** - Access trigger data within your crews and flows
|
||||
|
||||
## Managing Automation Triggers
|
||||
|
||||
### Viewing Available Triggers
|
||||
|
||||
To access and manage your automation triggers:
|
||||
|
||||
1. Navigate to your deployment in the CrewAI dashboard
|
||||
2. Click on the **Triggers** tab to view all available trigger integrations
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-available-triggers.png" alt="List of available automation triggers" />
|
||||
</Frame>
|
||||
|
||||
This view shows all the trigger integrations available for your deployment, along with their current connection status.
|
||||
|
||||
### Enabling and Disabling Triggers
|
||||
|
||||
Each trigger can be easily enabled or disabled using the toggle switch:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/trigger-selected.png" alt="Enable or disable triggers with toggle" />
|
||||
</Frame>
|
||||
|
||||
- **Enabled (blue toggle)**: The trigger is active and will automatically execute your deployment when the specified events occur
|
||||
- **Disabled (gray toggle)**: The trigger is inactive and will not respond to events
|
||||
|
||||
Simply click the toggle to change the trigger state. Changes take effect immediately.
|
||||
|
||||
### Monitoring Trigger Executions
|
||||
|
||||
Track the performance and history of your triggered executions:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-executions.png" alt="List of executions triggered by automation" />
|
||||
</Frame>
|
||||
|
||||
## Building Automation
|
||||
|
||||
Before building your automation, it's helpful to understand the structure of trigger payloads that your crews and flows will receive.
|
||||
|
||||
### Payload Samples Repository
|
||||
|
||||
We maintain a comprehensive repository with sample payloads from various trigger sources to help you build and test your automations:
|
||||
|
||||
**🔗 [CrewAI Enterprise Trigger Payload Samples](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
|
||||
|
||||
This repository contains:
|
||||
|
||||
- **Real payload examples** from different trigger sources (Gmail, Google Drive, etc.)
|
||||
- **Payload structure documentation** showing the format and available fields
|
||||
|
||||
### Triggers with Crew
|
||||
|
||||
Your existing crew definitions work seamlessly with triggers, you just need to have a task to parse the received payload:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MyAutomatedCrew:
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['researcher'],
|
||||
)
|
||||
|
||||
@task
|
||||
def parse_trigger_payload(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['parse_trigger_payload'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
|
||||
@task
|
||||
def analyze_trigger_content(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analyze_trigger_data'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
```
|
||||
|
||||
The crew will automatically receive and can access the trigger payload through the standard CrewAI context mechanisms.
|
||||
|
||||
### Integration with Flows
|
||||
|
||||
For flows, you have more control over how trigger data is handled:
|
||||
|
||||
#### Accessing Trigger Payload
|
||||
|
||||
All `@start()` methods in your flows will accept an additional parameter called `crewai_trigger_payload`:
|
||||
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen
|
||||
|
||||
class MyAutomatedFlow(Flow):
|
||||
@start()
|
||||
def handle_trigger(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
This start method can receive trigger data
|
||||
"""
|
||||
if crewai_trigger_payload:
|
||||
# Process the trigger data
|
||||
trigger_id = crewai_trigger_payload.get('id')
|
||||
event_data = crewai_trigger_payload.get('payload', {})
|
||||
|
||||
# Store in flow state for use by other methods
|
||||
self.state.trigger_id = trigger_id
|
||||
self.state.trigger_type = event_data
|
||||
|
||||
return event_data
|
||||
|
||||
# Handle manual execution
|
||||
return None
|
||||
|
||||
@listen(handle_trigger)
|
||||
def process_data(self, trigger_data):
|
||||
"""
|
||||
Process the data from the trigger
|
||||
"""
|
||||
# ... process the trigger
|
||||
```
|
||||
|
||||
#### Triggering Crews from Flows
|
||||
|
||||
When kicking off a crew within a flow that was triggered, pass the trigger payload as it:
|
||||
|
||||
```python
|
||||
@start()
|
||||
def delegate_to_crew(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
Delegate processing to a specialized crew
|
||||
"""
|
||||
crew = MySpecializedCrew()
|
||||
|
||||
# Pass the trigger payload to the crew
|
||||
result = crew.crew().kickoff(
|
||||
inputs={
|
||||
'a_custom_parameter': "custom_value",
|
||||
'crewai_trigger_payload': crewai_trigger_payload
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Trigger not firing:**
|
||||
- Verify the trigger is enabled
|
||||
- Check integration connection status
|
||||
|
||||
**Execution failures:**
|
||||
- Check the execution logs for error details
|
||||
- If you are developing, make sure the inputs include the `crewai_trigger_payload` parameter with the correct payload
|
||||
|
||||
Automation triggers transform your CrewAI deployments into responsive, event-driven systems that can seamlessly integrate with your existing business processes and tools.
|
||||
@@ -1,13 +1,13 @@
|
||||
---
|
||||
title: Weaviate Vector Search
|
||||
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents.
|
||||
description: The `WeaviateVectorSearchTool` is designed to search a Weaviate vector database for semantically similar documents using hybrid search.
|
||||
icon: network-wired
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector embeddings for more accurate and contextually relevant search results.
|
||||
The `WeaviateVectorSearchTool` is specifically crafted for conducting semantic searches within documents stored in a Weaviate vector database. This tool allows you to find semantically similar documents to a given query, leveraging the power of vector and keyword search for more accurate and contextually relevant search results.
|
||||
|
||||
[Weaviate](https://weaviate.io/) is a vector database that stores and queries vector embeddings, enabling semantic search capabilities.
|
||||
|
||||
@@ -39,6 +39,7 @@ from crewai_tools import WeaviateVectorSearchTool
|
||||
tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
weaviate_api_key="your-weaviate-api-key",
|
||||
)
|
||||
@@ -63,6 +64,7 @@ The `WeaviateVectorSearchTool` accepts the following parameters:
|
||||
- **weaviate_cluster_url**: Required. The URL of the Weaviate cluster.
|
||||
- **weaviate_api_key**: Required. The API key for the Weaviate cluster.
|
||||
- **limit**: Optional. The number of results to return. Default is `3`.
|
||||
- **alpha**: Optional. Controls the weighting between vector and keyword (BM25) search. alpha = 0 -> BM25 only, alpha = 1 -> vector search only. Default is `0.75`.
|
||||
- **vectorizer**: Optional. The vectorizer to use. If not provided, it will use `text2vec_openai` with the `nomic-embed-text` model.
|
||||
- **generative_model**: Optional. The generative model to use. If not provided, it will use OpenAI's `gpt-4o`.
|
||||
|
||||
@@ -78,6 +80,7 @@ from weaviate.classes.config import Configure
|
||||
tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
vectorizer=Configure.Vectorizer.text2vec_openai(model="nomic-embed-text"),
|
||||
generative_model=Configure.Generative.openai(model="gpt-4o-mini"),
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
@@ -128,6 +131,7 @@ with test_docs.batch.dynamic() as batch:
|
||||
tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
weaviate_api_key="your-weaviate-api-key",
|
||||
)
|
||||
@@ -145,6 +149,7 @@ from crewai_tools import WeaviateVectorSearchTool
|
||||
weaviate_tool = WeaviateVectorSearchTool(
|
||||
collection_name='example_collections',
|
||||
limit=3,
|
||||
alpha=0.75,
|
||||
weaviate_cluster_url="https://your-weaviate-cluster-url.com",
|
||||
weaviate_api_key="your-weaviate-api-key",
|
||||
)
|
||||
|
||||
@@ -117,4 +117,19 @@ agent = Agent(
|
||||
)
|
||||
```
|
||||
|
||||
## **Max Usage Count**
|
||||
|
||||
You can set a maximum usage count for a tool to prevent it from being used more than a certain number of times.
|
||||
By default, the max usage count is unlimited.
|
||||
|
||||
|
||||
|
||||
```python
|
||||
from crewai_tools import FileReadTool
|
||||
|
||||
tool = FileReadTool(max_usage_count=5, ...)
|
||||
```
|
||||
|
||||
|
||||
|
||||
Ready to explore? Pick a category above to discover tools that fit your use case!
|
||||
|
||||
BIN
docs/images/enterprise/list-available-triggers.png
Normal file
BIN
docs/images/enterprise/list-available-triggers.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 142 KiB |
BIN
docs/images/enterprise/list-executions.png
Normal file
BIN
docs/images/enterprise/list-executions.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 330 KiB |
BIN
docs/images/enterprise/trigger-selected.png
Normal file
BIN
docs/images/enterprise/trigger-selected.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 133 KiB |
7
docs/ko/api-reference/inputs.mdx
Normal file
7
docs/ko/api-reference/inputs.mdx
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "GET /inputs"
|
||||
description: "크루가 필요로 하는 입력 확인"
|
||||
openapi: "/enterprise-api.ko.yaml GET /inputs"
|
||||
---
|
||||
|
||||
|
||||
7
docs/ko/api-reference/kickoff.mdx
Normal file
7
docs/ko/api-reference/kickoff.mdx
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "POST /kickoff"
|
||||
description: "크루 실행 시작"
|
||||
openapi: "/enterprise-api.ko.yaml POST /kickoff"
|
||||
---
|
||||
|
||||
|
||||
7
docs/ko/api-reference/status.mdx
Normal file
7
docs/ko/api-reference/status.mdx
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "GET /status/{kickoff_id}"
|
||||
description: "실행 상태 조회"
|
||||
openapi: "/enterprise-api.ko.yaml GET /status/{kickoff_id}"
|
||||
---
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ crew = Crew(
|
||||
| **Pydantic 출력** _(선택 사항)_ | `output_pydantic` | `Optional[Type[BaseModel]]` | 태스크 출력용 Pydantic 모델입니다. |
|
||||
| **콜백** _(선택 사항)_ | `callback` | `Optional[Any]` | 태스크 완료 후 실행할 함수/객체입니다. |
|
||||
| **가드레일** _(선택 사항)_ | `guardrail` | `Optional[Callable]` | 다음 태스크로 진행하기 전에 태스크 출력을 검증하는 함수입니다. |
|
||||
| **가드레일 최대 재시도** _(선택 사항)_ | `guardrail_max_retries` | `Optional[int]` | 가드레일 검증 실패 시 최대 재시도 횟수입니다. 기본값은 3입니다. |
|
||||
|
||||
## 작업 생성하기
|
||||
|
||||
@@ -448,7 +449,7 @@ task = Task(
|
||||
expected_output="A valid JSON object",
|
||||
agent=analyst,
|
||||
guardrail=validate_json_output,
|
||||
max_retries=3 # Limit retry attempts
|
||||
guardrail_max_retries=3 # 재시도 횟수 제한
|
||||
)
|
||||
```
|
||||
|
||||
@@ -899,4 +900,4 @@ except RuntimeError as e:
|
||||
작업(task)은 CrewAI 에이전트의 행동을 이끄는 원동력입니다.
|
||||
작업과 그 결과를 적절하게 정의함으로써, 에이전트가 독립적으로 또는 협업 단위로 효과적으로 작동할 수 있는 기반을 마련할 수 있습니다.
|
||||
작업에 적합한 도구를 장착하고, 실행 과정을 이해하며, 견고한 검증 절차를 따르는 것은 CrewAI의 잠재력을 극대화하는 데 필수적입니다.
|
||||
이를 통해 에이전트가 할당된 작업에 효과적으로 준비되고, 작업이 의도대로 수행될 수 있습니다.
|
||||
이를 통해 에이전트가 할당된 작업에 효과적으로 준비되고, 작업이 의도대로 수행될 수 있습니다.
|
||||
|
||||
171
docs/ko/enterprise/guides/automation-triggers.mdx
Normal file
171
docs/ko/enterprise/guides/automation-triggers.mdx
Normal file
@@ -0,0 +1,171 @@
|
||||
---
|
||||
title: "자동화 트리거"
|
||||
description: "연결된 통합에서 특정 이벤트가 발생할 때 CrewAI 워크플로우를 자동으로 실행합니다"
|
||||
icon: "bolt"
|
||||
---
|
||||
|
||||
자동화 트리거를 사용하면 연결된 통합에서 특정 이벤트가 발생할 때 CrewAI 배포를 자동으로 실행할 수 있어, 비즈니스 시스템의 실시간 변화에 반응하는 강력한 이벤트 기반 워크플로우를 만들 수 있습니다.
|
||||
|
||||
## 개요
|
||||
|
||||
자동화 트리거를 사용하면 다음을 수행할 수 있습니다:
|
||||
|
||||
- **실시간 이벤트에 응답** - 특정 조건이 충족될 때 워크플로우를 자동으로 실행
|
||||
- **외부 시스템과 통합** - Gmail, Outlook, OneDrive, JIRA, Slack, Stripe 등의 플랫폼과 연결
|
||||
- **자동화 확장** - 수동 개입 없이 대용량 이벤트 처리
|
||||
- **컨텍스트 유지** - crew와 flow 내에서 트리거 데이터에 액세스
|
||||
|
||||
## 자동화 트리거 관리
|
||||
|
||||
### 사용 가능한 트리거 보기
|
||||
|
||||
자동화 트리거에 액세스하고 관리하려면:
|
||||
|
||||
1. CrewAI 대시보드에서 배포로 이동
|
||||
2. **트리거** 탭을 클릭하여 사용 가능한 모든 트리거 통합 보기
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-available-triggers.png" alt="사용 가능한 자동화 트리거 목록" />
|
||||
</Frame>
|
||||
|
||||
이 보기는 배포에 사용 가능한 모든 트리거 통합과 현재 연결 상태를 보여줍니다.
|
||||
|
||||
### 트리거 활성화 및 비활성화
|
||||
|
||||
각 트리거는 토글 스위치를 사용하여 쉽게 활성화하거나 비활성화할 수 있습니다:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/trigger-selected.png" alt="토글로 트리거 활성화 또는 비활성화" />
|
||||
</Frame>
|
||||
|
||||
- **활성화됨 (파란색 토글)**: 트리거가 활성 상태이며 지정된 이벤트가 발생할 때 배포를 자동으로 실행합니다
|
||||
- **비활성화됨 (회색 토글)**: 트리거가 비활성 상태이며 이벤트에 응답하지 않습니다
|
||||
|
||||
토글을 클릭하기만 하면 트리거 상태를 변경할 수 있습니다. 변경 사항은 즉시 적용됩니다.
|
||||
|
||||
### 트리거 실행 모니터링
|
||||
|
||||
트리거된 실행의 성능과 기록을 추적합니다:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-executions.png" alt="자동화에 의해 트리거된 실행 목록" />
|
||||
</Frame>
|
||||
|
||||
## 자동화 구축
|
||||
|
||||
자동화를 구축하기 전에 crew와 flow가 받을 트리거 페이로드의 구조를 이해하는 것이 도움이 됩니다.
|
||||
|
||||
### 페이로드 샘플 저장소
|
||||
|
||||
자동화를 구축하고 테스트하는 데 도움이 되도록 다양한 트리거 소스의 샘플 페이로드가 포함된 포괄적인 저장소를 유지 관리하고 있습니다:
|
||||
|
||||
**🔗 [CrewAI Enterprise 트리거 페이로드 샘플](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
|
||||
|
||||
이 저장소에는 다음이 포함되어 있습니다:
|
||||
|
||||
- **실제 페이로드 예제** - 다양한 트리거 소스(Gmail, Google Drive 등)에서 가져온 예제
|
||||
- **페이로드 구조 문서** - 형식과 사용 가능한 필드를 보여주는 문서
|
||||
|
||||
### Crew와 트리거
|
||||
|
||||
기존 crew 정의는 트리거와 완벽하게 작동하며, 받은 페이로드를 분석하는 작업만 있으면 됩니다:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MyAutomatedCrew:
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['researcher'],
|
||||
)
|
||||
|
||||
@task
|
||||
def parse_trigger_payload(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['parse_trigger_payload'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
|
||||
@task
|
||||
def analyze_trigger_content(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analyze_trigger_data'],
|
||||
agent=self.researcher(),
|
||||
)
|
||||
```
|
||||
|
||||
crew는 자동으로 트리거 페이로드를 받고 표준 CrewAI 컨텍스트 메커니즘을 통해 액세스할 수 있습니다.
|
||||
|
||||
### Flow와의 통합
|
||||
|
||||
flow의 경우 트리거 데이터 처리 방법을 더 세밀하게 제어할 수 있습니다:
|
||||
|
||||
#### 트리거 페이로드 액세스
|
||||
|
||||
flow의 모든 `@start()` 메서드는 `crewai_trigger_payload`라는 추가 매개변수를 허용합니다:
|
||||
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen
|
||||
|
||||
class MyAutomatedFlow(Flow):
|
||||
@start()
|
||||
def handle_trigger(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
이 start 메서드는 트리거 데이터를 받을 수 있습니다
|
||||
"""
|
||||
if crewai_trigger_payload:
|
||||
# 트리거 데이터 처리
|
||||
trigger_id = crewai_trigger_payload.get('id')
|
||||
event_data = crewai_trigger_payload.get('payload', {})
|
||||
|
||||
# 다른 메서드에서 사용할 수 있도록 flow 상태에 저장
|
||||
self.state.trigger_id = trigger_id
|
||||
self.state.trigger_type = event_data
|
||||
|
||||
return event_data
|
||||
|
||||
# 수동 실행 처리
|
||||
return None
|
||||
|
||||
@listen(handle_trigger)
|
||||
def process_data(self, trigger_data):
|
||||
"""
|
||||
트리거 데이터 처리
|
||||
"""
|
||||
# ... 트리거 처리
|
||||
```
|
||||
|
||||
#### Flow에서 Crew 트리거하기
|
||||
|
||||
트리거된 flow 내에서 crew를 시작할 때 트리거 페이로드를 전달합니다:
|
||||
|
||||
```python
|
||||
@start()
|
||||
def delegate_to_crew(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
전문 crew에 처리 위임
|
||||
"""
|
||||
crew = MySpecializedCrew()
|
||||
|
||||
# crew에 트리거 페이로드 전달
|
||||
result = crew.crew().kickoff(
|
||||
inputs={
|
||||
'a_custom_parameter': "custom_value",
|
||||
'crewai_trigger_payload': crewai_trigger_payload
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
## 문제 해결
|
||||
|
||||
**트리거가 작동하지 않는 경우:**
|
||||
- 트리거가 활성화되어 있는지 확인
|
||||
- 통합 연결 상태 확인
|
||||
|
||||
**실행 실패:**
|
||||
- 오류 세부 정보는 실행 로그 확인
|
||||
- 개발 중인 경우 입력에 올바른 페이로드가 포함된 `crewai_trigger_payload` 매개변수가 포함되어 있는지 확인
|
||||
|
||||
자동화 트리거는 CrewAI 배포를 기존 비즈니스 프로세스 및 도구와 완벽하게 통합할 수 있는 반응형 이벤트 기반 시스템으로 변환합니다.
|
||||
7
docs/pt-BR/api-reference/inputs.mdx
Normal file
7
docs/pt-BR/api-reference/inputs.mdx
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "GET /inputs"
|
||||
description: "Obter entradas necessárias para sua crew"
|
||||
openapi: "/enterprise-api.pt-BR.yaml GET /inputs"
|
||||
---
|
||||
|
||||
|
||||
7
docs/pt-BR/api-reference/kickoff.mdx
Normal file
7
docs/pt-BR/api-reference/kickoff.mdx
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "POST /kickoff"
|
||||
description: "Iniciar a execução da crew"
|
||||
openapi: "/enterprise-api.pt-BR.yaml POST /kickoff"
|
||||
---
|
||||
|
||||
|
||||
7
docs/pt-BR/api-reference/status.mdx
Normal file
7
docs/pt-BR/api-reference/status.mdx
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "GET /status/{kickoff_id}"
|
||||
description: "Obter o status da execução"
|
||||
openapi: "/enterprise-api.pt-BR.yaml GET /status/{kickoff_id}"
|
||||
---
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ crew = Crew(
|
||||
| **Output Pydantic** _(opcional)_ | `output_pydantic` | `Optional[Type[BaseModel]]` | Um modelo Pydantic para a saída da tarefa. |
|
||||
| **Callback** _(opcional)_ | `callback` | `Optional[Any]` | Função/objeto a ser executado após a conclusão da tarefa. |
|
||||
| **Guardrail** _(opcional)_ | `guardrail` | `Optional[Callable]` | Função para validar a saída da tarefa antes de prosseguir para a próxima tarefa. |
|
||||
| **Max Tentativas Guardrail** _(opcional)_ | `guardrail_max_retries` | `Optional[int]` | Número máximo de tentativas quando a validação do guardrail falha. Padrão é 3. |
|
||||
|
||||
## Criando Tarefas
|
||||
|
||||
@@ -450,7 +451,7 @@ task = Task(
|
||||
expected_output="Um objeto JSON válido",
|
||||
agent=analyst,
|
||||
guardrail=validate_json_output,
|
||||
max_retries=3 # Limite de tentativas
|
||||
guardrail_max_retries=3 # Limite de tentativas
|
||||
)
|
||||
```
|
||||
|
||||
@@ -935,7 +936,7 @@ task = Task(
|
||||
description="Gerar dados",
|
||||
expected_output="Dados válidos",
|
||||
guardrail=validate_data,
|
||||
max_retries=5 # Sobrescreve o limite padrão de tentativas
|
||||
guardrail_max_retries=5 # Sobrescreve o limite padrão de tentativas
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
171
docs/pt-BR/enterprise/guides/automation-triggers.mdx
Normal file
171
docs/pt-BR/enterprise/guides/automation-triggers.mdx
Normal file
@@ -0,0 +1,171 @@
|
||||
---
|
||||
title: "Triggers de Automação"
|
||||
description: "Execute automaticamente seus workflows CrewAI quando eventos específicos ocorrem em integrações conectadas"
|
||||
icon: "bolt"
|
||||
---
|
||||
|
||||
Os triggers de automação permitem executar automaticamente suas implantações CrewAI quando eventos específicos ocorrem em suas integrações conectadas, criando workflows poderosos orientados por eventos que respondem a mudanças em tempo real em seus sistemas de negócio.
|
||||
|
||||
## Visão Geral
|
||||
|
||||
Com triggers de automação, você pode:
|
||||
|
||||
- **Responder a eventos em tempo real** - Execute workflows automaticamente quando condições específicas forem atendidas
|
||||
- **Integrar com sistemas externos** - Conecte com plataformas como Gmail, Outlook, OneDrive, JIRA, Slack, Stripe e muito mais
|
||||
- **Escalar sua automação** - Lide com eventos de alto volume sem intervenção manual
|
||||
- **Manter contexto** - Acesse dados do trigger dentro de suas crews e flows
|
||||
|
||||
## Gerenciando Triggers de Automação
|
||||
|
||||
### Visualizando Triggers Disponíveis
|
||||
|
||||
Para acessar e gerenciar seus triggers de automação:
|
||||
|
||||
1. Navegue até sua implantação no painel do CrewAI
|
||||
2. Clique na aba **Triggers** para visualizar todas as integrações de trigger disponíveis
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-available-triggers.png" alt="Lista de triggers de automação disponíveis" />
|
||||
</Frame>
|
||||
|
||||
Esta visualização mostra todas as integrações de trigger disponíveis para sua implantação, junto com seus status de conexão atuais.
|
||||
|
||||
### Habilitando e Desabilitando Triggers
|
||||
|
||||
Cada trigger pode ser facilmente habilitado ou desabilitado usando o botão de alternância:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/trigger-selected.png" alt="Habilitar ou desabilitar triggers com alternância" />
|
||||
</Frame>
|
||||
|
||||
- **Habilitado (alternância azul)**: O trigger está ativo e executará automaticamente sua implantação quando os eventos especificados ocorrerem
|
||||
- **Desabilitado (alternância cinza)**: O trigger está inativo e não responderá a eventos
|
||||
|
||||
Simplesmente clique na alternância para mudar o estado do trigger. As alterações entram em vigor imediatamente.
|
||||
|
||||
### Monitorando Execuções de Trigger
|
||||
|
||||
Acompanhe o desempenho e histórico de suas execuções acionadas:
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/list-executions.png" alt="Lista de execuções acionadas por automação" />
|
||||
</Frame>
|
||||
|
||||
## Construindo Automação
|
||||
|
||||
Antes de construir sua automação, é útil entender a estrutura dos payloads de trigger que suas crews e flows receberão.
|
||||
|
||||
### Repositório de Amostras de Payload
|
||||
|
||||
Mantemos um repositório abrangente com amostras de payload de várias fontes de trigger para ajudá-lo a construir e testar suas automações:
|
||||
|
||||
**🔗 [Amostras de Payload de Trigger CrewAI Enterprise](https://github.com/crewAIInc/crewai-enterprise-trigger-payload-samples)**
|
||||
|
||||
Este repositório contém:
|
||||
|
||||
- **Exemplos reais de payload** de diferentes fontes de trigger (Gmail, Google Drive, etc.)
|
||||
- **Documentação da estrutura de payload** mostrando o formato e campos disponíveis
|
||||
|
||||
### Triggers com Crew
|
||||
|
||||
Suas definições de crew existentes funcionam perfeitamente com triggers, você só precisa ter uma tarefa para analisar o payload recebido:
|
||||
|
||||
```python
|
||||
@CrewBase
|
||||
class MinhaCrewAutomatizada:
|
||||
@agent
|
||||
def pesquisador(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['pesquisador'],
|
||||
)
|
||||
|
||||
@task
|
||||
def analisar_payload_trigger(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analisar_payload_trigger'],
|
||||
agent=self.pesquisador(),
|
||||
)
|
||||
|
||||
@task
|
||||
def analisar_conteudo_trigger(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['analisar_dados_trigger'],
|
||||
agent=self.pesquisador(),
|
||||
)
|
||||
```
|
||||
|
||||
A crew receberá automaticamente e pode acessar o payload do trigger através dos mecanismos de contexto padrão do CrewAI.
|
||||
|
||||
### Integração com Flows
|
||||
|
||||
Para flows, você tem mais controle sobre como os dados do trigger são tratados:
|
||||
|
||||
#### Acessando Payload do Trigger
|
||||
|
||||
Todos os métodos `@start()` em seus flows aceitarão um parâmetro adicional chamado `crewai_trigger_payload`:
|
||||
|
||||
```python
|
||||
from crewai.flow import Flow, start, listen
|
||||
|
||||
class MeuFlowAutomatizado(Flow):
|
||||
@start()
|
||||
def lidar_com_trigger(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
Este método start pode receber dados do trigger
|
||||
"""
|
||||
if crewai_trigger_payload:
|
||||
# Processa os dados do trigger
|
||||
trigger_id = crewai_trigger_payload.get('id')
|
||||
dados_evento = crewai_trigger_payload.get('payload', {})
|
||||
|
||||
# Armazena no estado do flow para uso por outros métodos
|
||||
self.state.trigger_id = trigger_id
|
||||
self.state.trigger_type = dados_evento
|
||||
|
||||
return dados_evento
|
||||
|
||||
# Lida com execução manual
|
||||
return None
|
||||
|
||||
@listen(lidar_com_trigger)
|
||||
def processar_dados(self, dados_trigger):
|
||||
"""
|
||||
Processa os dados do trigger
|
||||
"""
|
||||
# ... processa o trigger
|
||||
```
|
||||
|
||||
#### Acionando Crews a partir de Flows
|
||||
|
||||
Ao iniciar uma crew dentro de um flow que foi acionado, passe o payload do trigger conforme ele:
|
||||
|
||||
```python
|
||||
@start()
|
||||
def delegar_para_crew(self, crewai_trigger_payload: dict = None):
|
||||
"""
|
||||
Delega processamento para uma crew especializada
|
||||
"""
|
||||
crew = MinhaCrewEspecializada()
|
||||
|
||||
# Passa o payload do trigger para a crew
|
||||
resultado = crew.crew().kickoff(
|
||||
inputs={
|
||||
'parametro_personalizado': "valor_personalizado",
|
||||
'crewai_trigger_payload': crewai_trigger_payload
|
||||
},
|
||||
)
|
||||
|
||||
return resultado
|
||||
```
|
||||
|
||||
## Solução de Problemas
|
||||
|
||||
**Trigger não está sendo disparado:**
|
||||
- Verifique se o trigger está habilitado
|
||||
- Verifique o status de conexão da integração
|
||||
|
||||
**Falhas de execução:**
|
||||
- Verifique os logs de execução para detalhes do erro
|
||||
- Se você está desenvolvendo, certifique-se de que as entradas incluem o parâmetro `crewai_trigger_payload` com o payload correto
|
||||
|
||||
Os triggers de automação transformam suas implantações CrewAI em sistemas responsivos orientados por eventos que podem se integrar perfeitamente com seus processos de negócio e ferramentas existentes.
|
||||
@@ -10,16 +10,16 @@ authors = [
|
||||
dependencies = [
|
||||
# Core Dependencies
|
||||
"pydantic>=2.4.2",
|
||||
"openai<1.100.0", # TODO: Temporarily pin due to https://github.com/openai/openai-python/issues/2564. Can revert once the issue is fixed.
|
||||
"openai>=1.13.3",
|
||||
"litellm==1.74.9",
|
||||
"instructor>=1.3.3",
|
||||
# Text Processing
|
||||
"pdfplumber>=0.11.4",
|
||||
"regex>=2024.9.11",
|
||||
# Telemetry and Monitoring
|
||||
"opentelemetry-api>=1.30.0",
|
||||
"opentelemetry-sdk>=1.30.0",
|
||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
|
||||
"opentelemetry-api>=1.27.0,<1.28.0",
|
||||
"opentelemetry-sdk>=1.27.0,<1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-http>=1.27.0,<1.28.0",
|
||||
# Data Handling
|
||||
"chromadb>=0.5.23",
|
||||
"tokenizers>=0.20.3",
|
||||
@@ -68,6 +68,9 @@ docling = [
|
||||
aisuite = [
|
||||
"aisuite>=0.1.10",
|
||||
]
|
||||
qdrant = [
|
||||
"qdrant-client[fastembed]>=1.14.3",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
|
||||
@@ -1,7 +1,18 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
@@ -162,7 +173,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output"
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
@@ -276,7 +287,7 @@ class Agent(BaseAgent):
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
|
||||
self.tools_handler.last_used_tool = None
|
||||
|
||||
task_prompt = task.prompt()
|
||||
|
||||
@@ -309,15 +320,20 @@ class Agent(BaseAgent):
|
||||
event=MemoryRetrievalStartedEvent(
|
||||
task_id=str(task.id) if task else None,
|
||||
source_type="agent",
|
||||
from_agent=self,
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
self.crew._external_memory,
|
||||
agent=self,
|
||||
task=task,
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context)
|
||||
if memory.strip() != "":
|
||||
@@ -330,13 +346,14 @@ class Agent(BaseAgent):
|
||||
memory_content=memory,
|
||||
retrieval_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="agent",
|
||||
from_agent=self,
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
knowledge_config = (
|
||||
self.knowledge_config.model_dump() if self.knowledge_config else {}
|
||||
)
|
||||
|
||||
|
||||
if self.knowledge or (self.crew and self.crew.knowledge):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -43,7 +43,6 @@ class CrewAgentExecutorMixin:
|
||||
metadata={
|
||||
"observation": self.task.description,
|
||||
},
|
||||
agent=self.agent.role,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to add to short term memory: {e}")
|
||||
@@ -65,7 +64,6 @@ class CrewAgentExecutorMixin:
|
||||
"description": self.task.description,
|
||||
"messages": self.messages,
|
||||
},
|
||||
agent=self.agent.role,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to add to external memory: {e}")
|
||||
@@ -100,8 +98,8 @@ class CrewAgentExecutorMixin:
|
||||
)
|
||||
self.crew._long_term_memory.save(long_term_memory)
|
||||
|
||||
for entity in evaluation.entities:
|
||||
entity_memory = EntityMemoryItem(
|
||||
entity_memories = [
|
||||
EntityMemoryItem(
|
||||
name=entity.name,
|
||||
type=entity.type,
|
||||
description=entity.description,
|
||||
@@ -109,7 +107,10 @@ class CrewAgentExecutorMixin:
|
||||
[f"- {r}" for r in entity.relationships]
|
||||
),
|
||||
)
|
||||
self.crew._entity_memory.save(entity_memory)
|
||||
for entity in evaluation.entities
|
||||
]
|
||||
if entity_memories:
|
||||
self.crew._entity_memory.save(entity_memories)
|
||||
except AttributeError as e:
|
||||
print(f"Missing attributes for long term memory: {e}")
|
||||
pass
|
||||
@@ -158,7 +159,9 @@ class CrewAgentExecutorMixin:
|
||||
self._printer.print(content=prompt, color="bold_yellow")
|
||||
response = input()
|
||||
if response.strip() != "":
|
||||
self._printer.print(content="\nProcessing your feedback...", color="cyan")
|
||||
self._printer.print(
|
||||
content="\nProcessing your feedback...", color="cyan"
|
||||
)
|
||||
return response
|
||||
finally:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
|
||||
@@ -8,13 +8,13 @@ from .cache.cache_handler import CacheHandler
|
||||
class ToolsHandler:
|
||||
"""Callback handler for tool usage."""
|
||||
|
||||
last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling")
|
||||
last_used_tool: Optional[ToolCalling] = None
|
||||
cache: Optional[CacheHandler]
|
||||
|
||||
def __init__(self, cache: Optional[CacheHandler] = None):
|
||||
"""Initialize the callback handler."""
|
||||
self.cache = cache
|
||||
self.last_used_tool = {} # type: ignore # BUG?: same as above
|
||||
self.last_used_tool = None
|
||||
|
||||
def on_tool_use(
|
||||
self,
|
||||
|
||||
@@ -1,6 +1 @@
|
||||
ALGORITHMS = ["RS256"]
|
||||
|
||||
#TODO: The AUTH0 constants should be removed after WorkOS migration is completed
|
||||
AUTH0_DOMAIN = "crewai.us.auth0.com"
|
||||
AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
|
||||
AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"
|
||||
|
||||
@@ -7,24 +7,27 @@ from rich.console import Console
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from .utils import TokenManager, validate_jwt_token
|
||||
from urllib.parse import quote
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from .utils import validate_jwt_token
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.authentication.constants import (
|
||||
AUTH0_AUDIENCE,
|
||||
AUTH0_CLIENT_ID,
|
||||
AUTH0_DOMAIN,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class Oauth2Settings(BaseModel):
|
||||
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
|
||||
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
|
||||
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
|
||||
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
|
||||
provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
|
||||
)
|
||||
client_id: str = Field(
|
||||
description="OAuth2 client ID issued by the provider, used during authentication requests."
|
||||
)
|
||||
domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
|
||||
)
|
||||
audience: Optional[str] = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls):
|
||||
@@ -44,11 +47,15 @@ class ProviderFactory:
|
||||
settings = settings or Oauth2Settings.from_settings()
|
||||
|
||||
import importlib
|
||||
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
|
||||
|
||||
module = importlib.import_module(
|
||||
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
||||
)
|
||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
||||
|
||||
return provider(settings)
|
||||
|
||||
|
||||
class AuthenticationCommand:
|
||||
def __init__(self):
|
||||
self.token_manager = TokenManager()
|
||||
@@ -58,26 +65,12 @@ class AuthenticationCommand:
|
||||
"""Sign up to CrewAI+"""
|
||||
console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")
|
||||
|
||||
# TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete.
|
||||
user_provider = self._determine_user_provider()
|
||||
if user_provider == "auth0":
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
client_id=AUTH0_CLIENT_ID,
|
||||
domain=AUTH0_DOMAIN,
|
||||
audience=AUTH0_AUDIENCE
|
||||
)
|
||||
self.oauth2_provider = ProviderFactory.from_settings(settings)
|
||||
# End of temporary code.
|
||||
|
||||
device_code_data = self._get_device_code()
|
||||
self._display_auth_instructions(device_code_data)
|
||||
|
||||
return self._poll_for_token(device_code_data)
|
||||
|
||||
def _get_device_code(
|
||||
self
|
||||
) -> Dict[str, Any]:
|
||||
def _get_device_code(self) -> Dict[str, Any]:
|
||||
"""Get the device code to authenticate the user."""
|
||||
|
||||
device_code_payload = {
|
||||
@@ -86,7 +79,9 @@ class AuthenticationCommand:
|
||||
"audience": self.oauth2_provider.get_audience(),
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
|
||||
url=self.oauth2_provider.get_authorize_url(),
|
||||
data=device_code_payload,
|
||||
timeout=20,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -97,9 +92,7 @@ class AuthenticationCommand:
|
||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||
|
||||
def _poll_for_token(
|
||||
self, device_code_data: Dict[str, Any]
|
||||
) -> None:
|
||||
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
|
||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||
|
||||
token_payload = {
|
||||
@@ -112,7 +105,9 @@ class AuthenticationCommand:
|
||||
|
||||
attempts = 0
|
||||
while True and attempts < 10:
|
||||
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
|
||||
response = requests.post(
|
||||
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
|
||||
)
|
||||
token_data = response.json()
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -192,30 +187,3 @@ class AuthenticationCommand:
|
||||
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
|
||||
style="yellow",
|
||||
)
|
||||
|
||||
# TODO: WORKOS - This method is temporary until migration to WorkOS is complete.
|
||||
def _determine_user_provider(self) -> str:
|
||||
"""Determine which provider to use for authentication."""
|
||||
|
||||
console.print(
|
||||
"Enter your CrewAI Enterprise account email: ", style="bold blue", end=""
|
||||
)
|
||||
email = input()
|
||||
email_encoded = quote(email)
|
||||
|
||||
# It's not correct to call this method directly, but it's temporary until migration is complete.
|
||||
response = PlusAPI("")._make_request(
|
||||
"GET", f"/crewai_plus/api/v1/me/provider?email={email_encoded}"
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
if response.json().get("provider") == "auth0":
|
||||
return "auth0"
|
||||
else:
|
||||
return "workos"
|
||||
else:
|
||||
console.print(
|
||||
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
|
||||
style="red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .utils import TokenManager
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
|
||||
@@ -1,12 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
def validate_jwt_token(
|
||||
@@ -67,118 +60,3 @@ def validate_jwt_token(
|
||||
raise Exception(f"JWKS or key processing error: {str(e)}")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise Exception(f"Invalid token: {str(e)}")
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
Initialize the TokenManager class.
|
||||
|
||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key.
|
||||
|
||||
:return: The encryption key.
|
||||
"""
|
||||
key_filename = "secret.key"
|
||||
key = self.read_secure_file(key_filename)
|
||||
|
||||
if key is not None:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""
|
||||
Save the access token and its expiration time.
|
||||
|
||||
:param access_token: The access token to save.
|
||||
:param expires_at: The UNIX timestamp of the expiration time.
|
||||
"""
|
||||
expiration_time = datetime.fromtimestamp(expires_at)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
|
||||
:return: The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
encrypted_data = self.read_secure_file(self.file_path)
|
||||
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
||||
data = json.loads(decrypted_data)
|
||||
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return data["access_token"]
|
||||
|
||||
def get_secure_storage_path(self) -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
# Windows: Use %LOCALAPPDATA%
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
elif sys.platform == "darwin":
|
||||
# macOS: Use ~/Library/Application Support
|
||||
base_path = os.path.expanduser("~/Library/Application Support")
|
||||
else:
|
||||
# Linux and other Unix-like: Use ~/.local/share
|
||||
base_path = os.path.expanduser("~/.local/share")
|
||||
|
||||
app_name = "crewai/credentials"
|
||||
storage_path = Path(base_path) / app_name
|
||||
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return storage_path
|
||||
|
||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
Save the content to a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:param content: The content to save.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Set appropriate permissions (read/write for owner only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:return: The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
@@ -11,6 +11,7 @@ from crewai.cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
)
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
@@ -53,6 +54,7 @@ HIDDEN_SETTINGS_KEYS = [
|
||||
"tool_repository_password",
|
||||
]
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
enterprise_base_url: Optional[str] = Field(
|
||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||
@@ -74,12 +76,12 @@ class Settings(BaseModel):
|
||||
|
||||
oauth2_provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"]
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
||||
)
|
||||
|
||||
oauth2_audience: Optional[str] = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"]
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
||||
)
|
||||
|
||||
oauth2_client_id: str = Field(
|
||||
@@ -89,7 +91,7 @@ class Settings(BaseModel):
|
||||
|
||||
oauth2_domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"]
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||
)
|
||||
|
||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
||||
@@ -116,6 +118,7 @@ class Settings(BaseModel):
|
||||
"""Reset all settings to default values"""
|
||||
self._reset_user_settings()
|
||||
self._reset_cli_settings()
|
||||
self._clear_auth_tokens()
|
||||
self.dump()
|
||||
|
||||
def dump(self) -> None:
|
||||
@@ -139,3 +142,7 @@ class Settings(BaseModel):
|
||||
"""Reset all CLI settings to default values"""
|
||||
for key in CLI_SETTINGS_KEYS:
|
||||
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
|
||||
|
||||
def _clear_auth_tokens(self) -> None:
|
||||
"""Clear all authentication tokens"""
|
||||
TokenManager().clear_tokens()
|
||||
|
||||
@@ -117,9 +117,6 @@ class PlusAPI:
|
||||
def get_organizations(self) -> requests.Response:
|
||||
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
||||
|
||||
def send_trace_batch(self, payload) -> requests.Response:
|
||||
return self._make_request("POST", self.TRACING_RESOURCE, json=payload)
|
||||
|
||||
def initialize_trace_batch(self, payload) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST", f"{self.TRACING_RESOURCE}/batches", json=payload
|
||||
@@ -135,6 +132,7 @@ class PlusAPI:
|
||||
"POST",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def send_ephemeral_trace_events(
|
||||
@@ -144,6 +142,7 @@ class PlusAPI:
|
||||
"POST",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Response:
|
||||
|
||||
0
src/crewai/cli/shared/__init__.py
Normal file
0
src/crewai/cli/shared/__init__.py
Normal file
139
src/crewai/cli/shared/token_manager.py
Normal file
139
src/crewai/cli/shared/token_manager.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
Initialize the TokenManager class.
|
||||
|
||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key.
|
||||
|
||||
:return: The encryption key.
|
||||
"""
|
||||
key_filename = "secret.key"
|
||||
key = self.read_secure_file(key_filename)
|
||||
|
||||
if key is not None:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""
|
||||
Save the access token and its expiration time.
|
||||
|
||||
:param access_token: The access token to save.
|
||||
:param expires_at: The UNIX timestamp of the expiration time.
|
||||
"""
|
||||
expiration_time = datetime.fromtimestamp(expires_at)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
|
||||
:return: The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
encrypted_data = self.read_secure_file(self.file_path)
|
||||
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
||||
data = json.loads(decrypted_data)
|
||||
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return data["access_token"]
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
"""
|
||||
Clear the tokens.
|
||||
"""
|
||||
self.delete_secure_file(self.file_path)
|
||||
|
||||
def get_secure_storage_path(self) -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
# Windows: Use %LOCALAPPDATA%
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
elif sys.platform == "darwin":
|
||||
# macOS: Use ~/Library/Application Support
|
||||
base_path = os.path.expanduser("~/Library/Application Support")
|
||||
else:
|
||||
# Linux and other Unix-like: Use ~/.local/share
|
||||
base_path = os.path.expanduser("~/.local/share")
|
||||
|
||||
app_name = "crewai/credentials"
|
||||
storage_path = Path(base_path) / app_name
|
||||
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return storage_path
|
||||
|
||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
Save the content to a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:param content: The content to save.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Set appropriate permissions (read/write for owner only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:return: The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def delete_secure_file(self, filename: str) -> None:
|
||||
"""
|
||||
Delete the secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
if file_path.exists():
|
||||
file_path.unlink(missing_ok=True)
|
||||
@@ -44,8 +44,9 @@ def migrate_pyproject(input_file, output_file):
|
||||
]
|
||||
new_pyproject["project"]["requires-python"] = poetry_data.get("python")
|
||||
else:
|
||||
# If it's already in the new format, just copy the project section
|
||||
# If it's already in the new format, just copy the project and tool sections
|
||||
new_pyproject["project"] = pyproject_data.get("project", {})
|
||||
new_pyproject["tool"] = pyproject_data.get("tool", {})
|
||||
|
||||
# Migrate or copy dependencies
|
||||
if "dependencies" in new_pyproject["project"]:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import AgentEvaluationResult, AggregationStrategy
|
||||
from crewai.agent import Agent
|
||||
@@ -15,10 +15,11 @@ from crewai.utilities.events.agent_events import LiteAgentExecutionCompletedEven
|
||||
from crewai.experimental.evaluation.base_evaluator import AgentAggregatedEvaluationResult, EvaluationScore, MetricCategory
|
||||
|
||||
class ExecutionState:
|
||||
current_agent_id: Optional[str] = None
|
||||
current_task_id: Optional[str] = None
|
||||
|
||||
def __init__(self):
|
||||
self.traces = {}
|
||||
self.current_agent_id: str | None = None
|
||||
self.current_task_id: str | None = None
|
||||
self.iteration = 1
|
||||
self.iterations_results = {}
|
||||
self.agent_evaluators = {}
|
||||
|
||||
@@ -474,6 +474,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
self._completed_methods: Set[str] = set() # Track completed methods for reload
|
||||
self._persistence: Optional[FlowPersistence] = persistence
|
||||
self._is_execution_resuming: bool = False
|
||||
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
@@ -829,6 +830,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Clear completed methods and outputs for a fresh start
|
||||
self._completed_methods.clear()
|
||||
self._method_outputs.clear()
|
||||
else:
|
||||
# We're restoring from persistence, set the flag
|
||||
self._is_execution_resuming = True
|
||||
|
||||
if inputs:
|
||||
# Override the id in the state if it exists in inputs
|
||||
@@ -880,6 +884,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Clear the resumption flag after initial execution completes
|
||||
self._is_execution_resuming = False
|
||||
|
||||
final_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -913,17 +920,56 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
- Triggers execution of any listeners waiting on this start method
|
||||
- Part of the flow's initialization sequence
|
||||
- Skips execution if method was already completed (e.g., after reload)
|
||||
- Automatically injects crewai_trigger_payload if available in flow inputs
|
||||
"""
|
||||
if start_method_name in self._completed_methods:
|
||||
last_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
await self._execute_listeners(start_method_name, last_output)
|
||||
return
|
||||
if self._is_execution_resuming:
|
||||
# During resumption, skip execution but continue listeners
|
||||
last_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
await self._execute_listeners(start_method_name, last_output)
|
||||
return
|
||||
# For cyclic flows, clear from completed to allow re-execution
|
||||
self._completed_methods.discard(start_method_name)
|
||||
|
||||
method = self._methods[start_method_name]
|
||||
enhanced_method = self._inject_trigger_payload_for_start_method(method)
|
||||
|
||||
result = await self._execute_method(
|
||||
start_method_name, self._methods[start_method_name]
|
||||
start_method_name, enhanced_method
|
||||
)
|
||||
await self._execute_listeners(start_method_name, result)
|
||||
|
||||
def _inject_trigger_payload_for_start_method(self, original_method: Callable) -> Callable:
|
||||
def prepare_kwargs(*args, **kwargs):
|
||||
inputs = baggage.get_baggage("flow_inputs") or {}
|
||||
trigger_payload = inputs.get("crewai_trigger_payload")
|
||||
|
||||
sig = inspect.signature(original_method)
|
||||
accepts_trigger_payload = "crewai_trigger_payload" in sig.parameters
|
||||
|
||||
if trigger_payload is not None and accepts_trigger_payload:
|
||||
kwargs["crewai_trigger_payload"] = trigger_payload
|
||||
elif trigger_payload is not None:
|
||||
self._log_flow_event(
|
||||
f"Trigger payload available but {original_method.__name__} doesn't accept crewai_trigger_payload parameter",
|
||||
color="yellow"
|
||||
)
|
||||
return args, kwargs
|
||||
|
||||
if asyncio.iscoroutinefunction(original_method):
|
||||
async def enhanced_method(*args, **kwargs):
|
||||
args, kwargs = prepare_kwargs(*args, **kwargs)
|
||||
return await original_method(*args, **kwargs)
|
||||
else:
|
||||
def enhanced_method(*args, **kwargs):
|
||||
args, kwargs = prepare_kwargs(*args, **kwargs)
|
||||
return original_method(*args, **kwargs)
|
||||
|
||||
enhanced_method.__name__ = original_method.__name__
|
||||
enhanced_method.__doc__ = original_method.__doc__
|
||||
|
||||
return enhanced_method
|
||||
|
||||
async def _execute_method(
|
||||
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
@@ -1015,11 +1061,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
for router_name in routers_triggered:
|
||||
await self._execute_single_listener(router_name, result)
|
||||
# After executing router, the router's result is the path
|
||||
router_result = self._method_outputs[-1]
|
||||
router_result = (
|
||||
self._method_outputs[-1] if self._method_outputs else None
|
||||
)
|
||||
if router_result: # Only add non-None results
|
||||
router_results.append(router_result)
|
||||
current_trigger = (
|
||||
router_result # Update for next iteration of router chain
|
||||
str(router_result)
|
||||
if router_result is not None
|
||||
else "" # Update for next iteration of router chain
|
||||
)
|
||||
|
||||
# Now execute normal listeners for all router results and the original trigger
|
||||
@@ -1037,6 +1087,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
if current_trigger in router_results:
|
||||
# Find start methods triggered by this router result
|
||||
for method_name in self._start_methods:
|
||||
# Check if this start method is triggered by the current trigger
|
||||
if method_name in self._listeners:
|
||||
condition_type, trigger_methods = self._listeners[
|
||||
method_name
|
||||
]
|
||||
if current_trigger in trigger_methods:
|
||||
# Only execute if this is a cycle (method was already completed)
|
||||
if method_name in self._completed_methods:
|
||||
# For router-triggered start methods in cycles, temporarily clear resumption flag
|
||||
# to allow cyclic execution
|
||||
was_resuming = self._is_execution_resuming
|
||||
self._is_execution_resuming = False
|
||||
await self._execute_start_method(method_name)
|
||||
self._is_execution_resuming = was_resuming
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
) -> List[str]:
|
||||
@@ -1074,6 +1142,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if router_only != is_router:
|
||||
continue
|
||||
|
||||
if not router_only and listener_name in self._start_methods:
|
||||
continue
|
||||
|
||||
if condition_type == "OR":
|
||||
# If the trigger_method matches any in methods, run this
|
||||
if trigger_method in methods:
|
||||
@@ -1123,10 +1194,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Catches and logs any exceptions during execution, preventing
|
||||
individual listener failures from breaking the entire flow.
|
||||
"""
|
||||
# TODO: greyson fix
|
||||
# if listener_name in self._completed_methods:
|
||||
# await self._execute_listeners(listener_name, None)
|
||||
# return
|
||||
if listener_name in self._completed_methods:
|
||||
if self._is_execution_resuming:
|
||||
# During resumption, skip execution but continue listeners
|
||||
await self._execute_listeners(listener_name, None)
|
||||
return
|
||||
# For cyclic flows, clear from completed to allow re-execution
|
||||
self._completed_methods.discard(listener_name)
|
||||
|
||||
try:
|
||||
method = self._methods[listener_name]
|
||||
|
||||
@@ -69,12 +69,7 @@ from crewai.utilities.events.agent_events import (
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
)
|
||||
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
@@ -519,19 +514,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
llm = cast(LLM, self.llm)
|
||||
model = llm.model if hasattr(llm, "model") else "unknown"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=self._messages,
|
||||
tools=None,
|
||||
callbacks=self._callbacks,
|
||||
from_agent=self,
|
||||
model=model,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
answer = get_llm_response(
|
||||
llm=cast(LLM, self.llm),
|
||||
@@ -541,23 +523,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
from_agent=self,
|
||||
)
|
||||
|
||||
# Emit LLM call completed event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(
|
||||
messages=self._messages,
|
||||
response=answer,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_agent=self,
|
||||
model=model,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
# Emit LLM call failed event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e), from_agent=self),
|
||||
)
|
||||
raise e
|
||||
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||
|
||||
@@ -851,7 +851,9 @@ class LLM(BaseLLM):
|
||||
return tool_calls
|
||||
|
||||
# --- 7) Handle tool calls if present
|
||||
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
# --- 8) If tool call handling didn't return a result, emit completion event and return text response
|
||||
@@ -868,6 +870,8 @@ class LLM(BaseLLM):
|
||||
self,
|
||||
tool_calls: List[Any],
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Optional[str]:
|
||||
"""Handle a tool call from the LLM.
|
||||
|
||||
@@ -902,6 +906,8 @@ class LLM(BaseLLM):
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
from_agent=from_agent,
|
||||
from_task=from_task,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -914,12 +920,17 @@ class LLM(BaseLLM):
|
||||
tool_args=function_args,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
# --- 3.3) Emit success event
|
||||
self._handle_emit_call_events(
|
||||
response=result, call_type=LLMCallType.TOOL_CALL
|
||||
response=result,
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -1139,7 +1150,11 @@ class LLM(BaseLLM):
|
||||
|
||||
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
|
||||
# Ollama doesn't supports last message to be 'assistant'
|
||||
if "ollama" in self.model.lower() and messages and messages[-1]["role"] == "assistant":
|
||||
if (
|
||||
"ollama" in self.model.lower()
|
||||
and messages
|
||||
and messages[-1]["role"] == "assistant"
|
||||
):
|
||||
return messages + [{"role": "user", "content": ""}]
|
||||
|
||||
# Handle Anthropic models
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from crewai.memory import (
|
||||
EntityMemory,
|
||||
@@ -7,6 +7,10 @@ from crewai.memory import (
|
||||
ShortTermMemory,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class ContextualMemory:
|
||||
def __init__(
|
||||
@@ -15,11 +19,28 @@ class ContextualMemory:
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
exm: ExternalMemory,
|
||||
agent: Optional["Agent"] = None,
|
||||
task: Optional["Task"] = None,
|
||||
):
|
||||
self.stm = stm
|
||||
self.ltm = ltm
|
||||
self.em = em
|
||||
self.exm = exm
|
||||
self.agent = agent
|
||||
self.task = task
|
||||
|
||||
if self.stm is not None:
|
||||
self.stm.agent = self.agent
|
||||
self.stm.task = self.task
|
||||
if self.ltm is not None:
|
||||
self.ltm.agent = self.agent
|
||||
self.ltm.task = self.task
|
||||
if self.em is not None:
|
||||
self.em.agent = self.agent
|
||||
self.em.task = self.task
|
||||
if self.exm is not None:
|
||||
self.exm.agent = self.agent
|
||||
self.exm.task = self.task
|
||||
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
"""
|
||||
@@ -49,10 +70,7 @@ class ContextualMemory:
|
||||
|
||||
stm_results = self.stm.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[
|
||||
f"- {result['context']}"
|
||||
for result in stm_results
|
||||
]
|
||||
[f"- {result['context']}" for result in stm_results]
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
@@ -89,10 +107,7 @@ class ContextualMemory:
|
||||
|
||||
em_results = self.em.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[
|
||||
f"- {result['context']}"
|
||||
for result in em_results
|
||||
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Any
|
||||
import time
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
@@ -24,7 +24,7 @@ class EntityMemory(Memory):
|
||||
Inherits from the Memory class.
|
||||
"""
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
@@ -35,7 +35,7 @@ class EntityMemory(Memory):
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
config = embedder_config.get("config")
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
storage = (
|
||||
@@ -53,47 +53,99 @@ class EntityMemory(Memory):
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
|
||||
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
"""Saves an entity item into the SQLite storage."""
|
||||
def save(
|
||||
self,
|
||||
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Saves one or more entity items into the SQLite storage.
|
||||
|
||||
Args:
|
||||
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||
metadata: Optional metadata dict (included for supertype compatibility but not used).
|
||||
|
||||
Notes:
|
||||
The metadata parameter is included to satisfy the supertype signature but is not
|
||||
used - entity metadata is extracted from the EntityMemoryItem objects themselves.
|
||||
"""
|
||||
|
||||
if not value:
|
||||
return
|
||||
|
||||
items = value if isinstance(value, list) else [value]
|
||||
is_batch = len(items) > 1
|
||||
|
||||
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
metadata=item.metadata,
|
||||
metadata=metadata,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
saved_count = 0
|
||||
errors = []
|
||||
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
for item in items:
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
super().save(data, item.metadata)
|
||||
saved_count += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{item.name}: {str(e)}")
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
metadata = {"entity_count": saved_count, "errors": errors}
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||
metadata = items[0].metadata
|
||||
|
||||
super().save(data, item.metadata)
|
||||
|
||||
# Emit memory save completed event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=data,
|
||||
metadata=item.metadata,
|
||||
value=emit_value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
if errors:
|
||||
raise Exception(
|
||||
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
fail_metadata = (
|
||||
{"entity_count": len(items), "saved": saved_count}
|
||||
if is_batch
|
||||
else items[0].metadata
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
metadata=item.metadata,
|
||||
metadata=fail_metadata,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
@@ -111,6 +163,8 @@ class EntityMemory(Memory):
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -129,6 +183,8 @@ class EntityMemory(Memory):
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
22
src/crewai/memory/external/external_memory.py
vendored
22
src/crewai/memory/external/external_memory.py
vendored
@@ -53,7 +53,6 @@ class ExternalMemory(Memory):
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Saves a value into the external storage."""
|
||||
crewai_event_bus.emit(
|
||||
@@ -61,24 +60,30 @@ class ExternalMemory(Memory):
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)
|
||||
super().save(value=item.value, metadata=item.metadata, agent=item.agent)
|
||||
item = ExternalMemoryItem(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
super().save(value=item.value, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -87,9 +92,10 @@ class ExternalMemory(Memory):
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
@@ -107,6 +113,8 @@ class ExternalMemory(Memory):
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -125,6 +133,8 @@ class ExternalMemory(Memory):
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -37,13 +37,17 @@ class LongTermMemory(Memory):
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
@@ -59,6 +63,8 @@ class LongTermMemory(Memory):
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -74,13 +80,19 @@ class LongTermMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -96,6 +108,8 @@ class LongTermMemory(Memory):
|
||||
limit=latest_n,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
"""
|
||||
@@ -12,19 +16,38 @@ class Memory(BaseModel):
|
||||
crew: Optional[Any] = None
|
||||
|
||||
storage: Any
|
||||
_agent: Optional["Agent"] = None
|
||||
_task: Optional["Task"] = None
|
||||
|
||||
def __init__(self, storage: Any, **data: Any):
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
@property
|
||||
def task(self) -> Optional["Task"]:
|
||||
"""Get the current task associated with this memory."""
|
||||
return self._task
|
||||
|
||||
@task.setter
|
||||
def task(self, task: Optional["Task"]) -> None:
|
||||
"""Set the current task associated with this memory."""
|
||||
self._task = task
|
||||
|
||||
@property
|
||||
def agent(self) -> Optional["Agent"]:
|
||||
"""Get the current agent associated with this memory."""
|
||||
return self._agent
|
||||
|
||||
@agent.setter
|
||||
def agent(self, agent: Optional["Agent"]) -> None:
|
||||
"""Set the current agent associated with this memory."""
|
||||
self._agent = agent
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
if agent:
|
||||
metadata["agent"] = agent
|
||||
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class ShortTermMemory(Memory):
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
)
|
||||
config = embedder_config.get("config")
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
storage = (
|
||||
@@ -57,34 +57,42 @@ class ShortTermMemory(Memory):
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||
item = ShortTermMemoryItem(
|
||||
data=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
if self._memory_provider == "mem0":
|
||||
item.data = f"Remember the following insights from Agent run: {item.data}"
|
||||
item.data = (
|
||||
f"Remember the following insights from Agent run: {item.data}"
|
||||
)
|
||||
|
||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
||||
super().save(value=item.data, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
# agent_role=agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -93,9 +101,10 @@ class ShortTermMemory(Memory):
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent_role=agent,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
@@ -113,6 +122,8 @@ class ShortTermMemory(Memory):
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -131,6 +142,8 @@ class ShortTermMemory(Memory):
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1 +1,58 @@
|
||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||
|
||||
import sys
|
||||
import importlib
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import set_rag_config
|
||||
|
||||
|
||||
_module_path = __path__
|
||||
_module_file = __file__
|
||||
|
||||
class _RagModule(ModuleType):
|
||||
"""Module wrapper to intercept attribute setting for config."""
|
||||
|
||||
__path__ = _module_path
|
||||
__file__ = _module_file
|
||||
|
||||
def __init__(self, module_name: str):
|
||||
"""Initialize the module wrapper.
|
||||
|
||||
Args:
|
||||
module_name: Name of the module.
|
||||
"""
|
||||
super().__init__(module_name)
|
||||
|
||||
def __setattr__(self, name: str, value: RagConfigType) -> None:
|
||||
"""Set module attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
value: Attribute value.
|
||||
"""
|
||||
if name == "config":
|
||||
return set_rag_config(value)
|
||||
raise AttributeError(f"Setting attribute '{name}' is not allowed.")
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Get module attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
|
||||
Returns:
|
||||
The requested attribute.
|
||||
|
||||
Raises:
|
||||
AttributeError: If attribute doesn't exist.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(f"{self.__name__}.{name}")
|
||||
except ImportError:
|
||||
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
sys.modules[__name__] = _RagModule(__name__)
|
||||
|
||||
0
src/crewai/rag/chromadb/__init__.py
Normal file
0
src/crewai/rag/chromadb/__init__.py
Normal file
567
src/crewai/rag/chromadb/client.py
Normal file
567
src/crewai/rag/chromadb/client.py
Normal file
@@ -0,0 +1,567 @@
|
||||
"""ChromaDB client implementation."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.api.types import (
|
||||
Embeddable,
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
QueryResult,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionCreateParams,
|
||||
ChromaDBCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.chromadb.utils import (
|
||||
_extract_search_params,
|
||||
_is_async_client,
|
||||
_is_sync_client,
|
||||
_prepare_documents_for_chromadb,
|
||||
_process_query_results,
|
||||
)
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class ChromaDBClient(BaseClient):
|
||||
"""ChromaDB implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for ChromaDB, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable],
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured ChromaDB client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in ChromaDB.
|
||||
|
||||
Uses the client's default embedding function if none provided.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
get_or_create: If True, returns existing collection if it already exists
|
||||
instead of raising an error. Defaults to False.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection with the same name already exists and get_or_create
|
||||
is False.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"},
|
||||
... get_or_create=True
|
||||
... )
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method create_collection() requires a ClientAPI. "
|
||||
"Use acreate_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
self.client.create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
get_or_create=kwargs.get("get_or_create", False),
|
||||
)
|
||||
|
||||
async def acreate_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in ChromaDB asynchronously.
|
||||
|
||||
Creates a new collection with the specified name and optional configuration.
|
||||
If an embedding function is not provided, uses the client's default embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
get_or_create: If True, returns existing collection if it already exists
|
||||
instead of raising an error. Defaults to False.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection with the same name already exists and get_or_create
|
||||
is False.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.acreate_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"},
|
||||
... get_or_create=True
|
||||
... )
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method acreate_collection() requires an AsyncClientAPI. "
|
||||
"Use create_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
await self.client.create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
get_or_create=kwargs.get("get_or_create", False),
|
||||
)
|
||||
|
||||
def get_or_create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
Returns existing collection if found, otherwise creates a new one.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
|
||||
Returns:
|
||||
A ChromaDB Collection object.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> collection = client.get_or_create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"}
|
||||
... )
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method get_or_create_collection() requires a ClientAPI. "
|
||||
"Use aget_or_create_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return self.client.get_or_create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
)
|
||||
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Returns existing collection if found, otherwise creates a new one.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
|
||||
Returns:
|
||||
A ChromaDB AsyncCollection object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... collection = await client.aget_or_create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"}
|
||||
... )
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method aget_or_create_collection() requires an AsyncClientAPI. "
|
||||
"Use get_or_create_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return await self.client.get_or_create_collection(
|
||||
name=kwargs["collection_name"],
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
Performs an upsert operation - documents with existing IDs are updated.
|
||||
Generates embeddings automatically using the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method add_documents() requires a ClientAPI. "
|
||||
"Use aadd_documents() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
collection.add(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=prepared.metadatas,
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Performs an upsert operation - documents with existing IDs are updated.
|
||||
Generates embeddings automatically using the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method aadd_documents() requires an AsyncClientAPI. "
|
||||
"Use add_documents() for ClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
await collection.add(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=prepared.metadatas,
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Performs semantic search to find documents similar to the query text.
|
||||
Uses the configured embedding function to generate query embeddings.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
where: Optional ChromaDB where clause for metadata filtering.
|
||||
where_document: Optional ChromaDB where clause for document content filtering.
|
||||
include: Optional list of fields to include in results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method search() requires a ClientAPI. "
|
||||
"Use asearch() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = self.client.get_collection(
|
||||
name=params.collection_name,
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
results: QueryResult = collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Performs semantic search to find documents similar to the query text.
|
||||
Uses the configured embedding function to generate query embeddings.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
where: Optional ChromaDB where clause for metadata filtering.
|
||||
where_document: Optional ChromaDB where clause for document content filtering.
|
||||
include: Optional list of fields to include in results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method asearch() requires an AsyncClientAPI. "
|
||||
"Use search() for ClientAPI."
|
||||
)
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = await self.client.get_collection(
|
||||
name=params.collection_name,
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
results: QueryResult = await collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
Permanently removes a collection and all documents, embeddings, and metadata it contains.
|
||||
This operation cannot be undone.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.delete_collection(collection_name="old_documents")
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method delete_collection() requires a ClientAPI. "
|
||||
"Use adelete_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
self.client.delete_collection(name=collection_name)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Permanently removes a collection and all documents, embeddings, and metadata it contains.
|
||||
This operation cannot be undone.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.adelete_collection(collection_name="old_documents")
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method adelete_collection() requires an AsyncClientAPI. "
|
||||
"Use delete_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
await self.client.delete_collection(name=collection_name)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
Completely clears the ChromaDB instance, removing all collections,
|
||||
documents, embeddings, and metadata. This operation cannot be undone.
|
||||
Use with extreme caution in production environments.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.reset() # Removes ALL data from ChromaDB
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method reset() requires a ClientAPI. "
|
||||
"Use areset() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
self.client.reset()
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Completely clears the ChromaDB instance, removing all collections,
|
||||
documents, embeddings, and metadata. This operation cannot be undone.
|
||||
Use with extreme caution in production environments.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.areset() # Removes ALL data from ChromaDB
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method areset() requires an AsyncClientAPI. "
|
||||
"Use reset() for ClientAPI."
|
||||
)
|
||||
|
||||
await self.client.reset()
|
||||
59
src/crewai/rag/chromadb/config.py
Normal file
59
src/crewai/rag/chromadb/config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""ChromaDB configuration model."""
|
||||
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_TENANT,
|
||||
DEFAULT_DATABASE,
|
||||
DEFAULT_STORAGE_PATH,
|
||||
)
|
||||
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=".*Mixing V1 models and V2 models.*",
|
||||
category=UserWarning,
|
||||
module="pydantic._internal._generate_schema",
|
||||
)
|
||||
|
||||
|
||||
def _default_settings() -> Settings:
|
||||
"""Create default ChromaDB settings.
|
||||
|
||||
Returns:
|
||||
Settings with persistent storage and reset enabled.
|
||||
"""
|
||||
return Settings(
|
||||
persist_directory=DEFAULT_STORAGE_PATH,
|
||||
allow_reset=True,
|
||||
is_persistent=True,
|
||||
)
|
||||
|
||||
|
||||
def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
|
||||
"""Create default ChromaDB embedding function.
|
||||
|
||||
Returns:
|
||||
Default embedding function using all-MiniLM-L6-v2 via ONNX.
|
||||
"""
|
||||
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class ChromaDBConfig(BaseRagConfig):
|
||||
"""Configuration for ChromaDB client."""
|
||||
|
||||
provider: Literal["chromadb"] = field(default="chromadb", init=False)
|
||||
tenant: str = DEFAULT_TENANT
|
||||
database: str = DEFAULT_DATABASE
|
||||
settings: Settings = field(default_factory=_default_settings)
|
||||
embedding_function: ChromaEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
10
src/crewai/rag/chromadb/constants.py
Normal file
10
src/crewai/rag/chromadb/constants.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Constants for ChromaDB configuration."""
|
||||
|
||||
import os
|
||||
from typing import Final
|
||||
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
DEFAULT_TENANT: Final[str] = "default_tenant"
|
||||
DEFAULT_DATABASE: Final[str] = "default_database"
|
||||
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb")
|
||||
24
src/crewai/rag/chromadb/factory.py
Normal file
24
src/crewai/rag/chromadb/factory.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Factory functions for creating ChromaDB clients."""
|
||||
|
||||
from chromadb import Client
|
||||
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
|
||||
|
||||
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient from configuration.
|
||||
|
||||
Args:
|
||||
config: ChromaDB configuration object.
|
||||
|
||||
Returns:
|
||||
Configured ChromaDBClient instance.
|
||||
"""
|
||||
|
||||
return ChromaDBClient(
|
||||
client=Client(
|
||||
settings=config.settings, tenant=config.tenant, database=config.database
|
||||
),
|
||||
embedding_function=config.embedding_function,
|
||||
)
|
||||
102
src/crewai/rag/chromadb/types.py
Normal file
102
src/crewai/rag/chromadb/types.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Type definitions specific to ChromaDB implementation."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from chromadb.api import ClientAPI, AsyncClientAPI
|
||||
from chromadb.api.configuration import CollectionConfigurationInterface
|
||||
from chromadb.api.types import (
|
||||
CollectionMetadata,
|
||||
DataLoader,
|
||||
Embeddable,
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
Include,
|
||||
Loadable,
|
||||
Where,
|
||||
WhereDocument,
|
||||
)
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
|
||||
|
||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||
|
||||
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
|
||||
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for ChromaDB EmbeddingFunction.
|
||||
|
||||
This allows Pydantic to handle ChromaDB's EmbeddingFunction type
|
||||
without requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
|
||||
class PreparedDocuments(NamedTuple):
|
||||
"""Prepared documents ready for ChromaDB insertion.
|
||||
|
||||
Attributes:
|
||||
ids: List of document IDs
|
||||
texts: List of document texts
|
||||
metadatas: List of document metadata mappings
|
||||
"""
|
||||
|
||||
ids: list[str]
|
||||
texts: list[str]
|
||||
metadatas: list[Mapping[str, str | int | float | bool]]
|
||||
|
||||
|
||||
class ExtractedSearchParams(NamedTuple):
|
||||
"""Extracted search parameters for ChromaDB queries.
|
||||
|
||||
Attributes:
|
||||
collection_name: Name of the collection to search
|
||||
query: Search query text
|
||||
limit: Maximum number of results
|
||||
metadata_filter: Optional metadata filter
|
||||
score_threshold: Optional minimum similarity score
|
||||
where: Optional ChromaDB where clause
|
||||
where_document: Optional ChromaDB document filter
|
||||
include: Fields to include in results
|
||||
"""
|
||||
|
||||
collection_name: str
|
||||
query: str
|
||||
limit: int
|
||||
metadata_filter: dict[str, Any] | None
|
||||
score_threshold: float | None
|
||||
where: Where | None
|
||||
where_document: WhereDocument | None
|
||||
include: Include
|
||||
|
||||
|
||||
class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
|
||||
"""Parameters for creating a ChromaDB collection.
|
||||
|
||||
This class extends BaseCollectionParams to include any additional
|
||||
parameters specific to ChromaDB collection creation.
|
||||
"""
|
||||
|
||||
configuration: CollectionConfigurationInterface
|
||||
metadata: CollectionMetadata
|
||||
embedding_function: ChromaEmbeddingFunction[Embeddable]
|
||||
data_loader: DataLoader[Loadable]
|
||||
get_or_create: bool
|
||||
|
||||
|
||||
class ChromaDBCollectionSearchParams(BaseCollectionSearchParams, total=False):
|
||||
"""Parameters for searching a ChromaDB collection.
|
||||
|
||||
This class extends BaseCollectionSearchParams to include ChromaDB-specific
|
||||
search parameters like where clauses and include options.
|
||||
"""
|
||||
|
||||
where: Where
|
||||
where_document: WhereDocument
|
||||
include: Include
|
||||
218
src/crewai/rag/chromadb/utils.py
Normal file
218
src/crewai/rag/chromadb/utils.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Utility functions for ChromaDB client implementation."""
|
||||
|
||||
import hashlib
|
||||
from collections.abc import Mapping
|
||||
from typing import Literal, TypeGuard, cast
|
||||
|
||||
from chromadb.api import AsyncClientAPI, ClientAPI
|
||||
from chromadb.api.types import (
|
||||
Include,
|
||||
IncludeEnum,
|
||||
QueryResult,
|
||||
)
|
||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionSearchParams,
|
||||
ExtractedSearchParams,
|
||||
PreparedDocuments,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
|
||||
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
|
||||
"""Type guard to check if the client is a synchronous ClientAPI.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is a ClientAPI, False otherwise.
|
||||
"""
|
||||
return isinstance(client, ClientAPI)
|
||||
|
||||
|
||||
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
|
||||
"""Type guard to check if the client is an asynchronous AsyncClientAPI.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is an AsyncClientAPI, False otherwise.
|
||||
"""
|
||||
return isinstance(client, AsyncClientAPI)
|
||||
|
||||
|
||||
def _prepare_documents_for_chromadb(
|
||||
documents: list[BaseRecord],
|
||||
) -> PreparedDocuments:
|
||||
"""Prepare documents for ChromaDB by extracting IDs, texts, and metadata.
|
||||
|
||||
Args:
|
||||
documents: List of BaseRecord documents to prepare.
|
||||
|
||||
Returns:
|
||||
PreparedDocuments with ids, texts, and metadatas ready for ChromaDB.
|
||||
"""
|
||||
ids: list[str] = []
|
||||
texts: list[str] = []
|
||||
metadatas: list[Mapping[str, str | int | float | bool]] = []
|
||||
|
||||
for doc in documents:
|
||||
if "doc_id" in doc:
|
||||
ids.append(doc["doc_id"])
|
||||
else:
|
||||
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
|
||||
ids.append(content_hash)
|
||||
|
||||
texts.append(doc["content"])
|
||||
metadata = doc.get("metadata")
|
||||
if metadata:
|
||||
if isinstance(metadata, list):
|
||||
metadatas.append(metadata[0] if metadata else {})
|
||||
else:
|
||||
metadatas.append(metadata)
|
||||
else:
|
||||
metadatas.append({})
|
||||
|
||||
return PreparedDocuments(ids, texts, metadatas)
|
||||
|
||||
|
||||
def _extract_search_params(
|
||||
kwargs: ChromaDBCollectionSearchParams,
|
||||
) -> ExtractedSearchParams:
|
||||
"""Extract search parameters from kwargs.
|
||||
|
||||
Args:
|
||||
kwargs: Keyword arguments containing search parameters.
|
||||
|
||||
Returns:
|
||||
ExtractedSearchParams with all extracted parameters.
|
||||
"""
|
||||
return ExtractedSearchParams(
|
||||
collection_name=kwargs["collection_name"],
|
||||
query=kwargs["query"],
|
||||
limit=kwargs.get("limit", 10),
|
||||
metadata_filter=kwargs.get("metadata_filter"),
|
||||
score_threshold=kwargs.get("score_threshold"),
|
||||
where=kwargs.get("where"),
|
||||
where_document=kwargs.get("where_document"),
|
||||
include=kwargs.get(
|
||||
"include",
|
||||
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _convert_distance_to_score(
|
||||
distance: float,
|
||||
distance_metric: Literal["l2", "cosine", "ip"],
|
||||
) -> float:
|
||||
"""Convert ChromaDB distance to similarity score.
|
||||
|
||||
Notes:
|
||||
Assuming all embedding are unit-normalized for now, including custom embeddings.
|
||||
|
||||
Args:
|
||||
distance: The distance value from ChromaDB.
|
||||
distance_metric: The distance metric used ("l2", "cosine", or "ip").
|
||||
|
||||
Returns:
|
||||
Similarity score in range [0, 1] where 1 is most similar.
|
||||
"""
|
||||
if distance_metric == "cosine":
|
||||
score = 1.0 - 0.5 * distance
|
||||
return max(0.0, min(1.0, score))
|
||||
raise ValueError(f"Unsupported distance metric: {distance_metric}")
|
||||
|
||||
|
||||
def _convert_chromadb_results_to_search_results(
|
||||
results: QueryResult,
|
||||
include: Include,
|
||||
distance_metric: Literal["l2", "cosine", "ip"],
|
||||
score_threshold: float | None = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Convert ChromaDB query results to SearchResult format.
|
||||
|
||||
Args:
|
||||
results: ChromaDB query results.
|
||||
include: List of fields that were included in the query.
|
||||
distance_metric: The distance metric used by the collection.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
"""
|
||||
search_results: list[SearchResult] = []
|
||||
|
||||
include_strings = [item.value for item in include]
|
||||
|
||||
ids = results["ids"][0] if results.get("ids") else []
|
||||
|
||||
documents_list = results.get("documents")
|
||||
documents = (
|
||||
documents_list[0] if documents_list and "documents" in include_strings else []
|
||||
)
|
||||
|
||||
metadatas_list = results.get("metadatas")
|
||||
metadatas = (
|
||||
metadatas_list[0] if metadatas_list and "metadatas" in include_strings else []
|
||||
)
|
||||
|
||||
distances_list = results.get("distances")
|
||||
distances = (
|
||||
distances_list[0] if distances_list and "distances" in include_strings else []
|
||||
)
|
||||
|
||||
for i, doc_id in enumerate(ids):
|
||||
if not distances or i >= len(distances):
|
||||
continue
|
||||
|
||||
distance = distances[i]
|
||||
score = _convert_distance_to_score(
|
||||
distance=distance, distance_metric=distance_metric
|
||||
)
|
||||
|
||||
if score_threshold and score < score_threshold:
|
||||
continue
|
||||
|
||||
result: SearchResult = {
|
||||
"id": doc_id,
|
||||
"content": documents[i] if documents and i < len(documents) else "",
|
||||
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
|
||||
"score": score,
|
||||
}
|
||||
search_results.append(result)
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
def _process_query_results(
|
||||
collection: Collection | AsyncCollection,
|
||||
results: QueryResult,
|
||||
params: ExtractedSearchParams,
|
||||
) -> list[SearchResult]:
|
||||
"""Process ChromaDB query results and convert to SearchResult format.
|
||||
|
||||
Args:
|
||||
collection: The ChromaDB collection (sync or async) that was queried.
|
||||
results: Raw query results from ChromaDB.
|
||||
params: The search parameters used for the query.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
"""
|
||||
|
||||
distance_metric = cast(
|
||||
Literal["l2", "cosine", "ip"],
|
||||
collection.metadata.get("hnsw:space", "l2") if collection.metadata else "l2",
|
||||
)
|
||||
|
||||
return _convert_chromadb_results_to_search_results(
|
||||
results=results,
|
||||
include=params.include,
|
||||
distance_metric=distance_metric,
|
||||
score_threshold=params.score_threshold,
|
||||
)
|
||||
1
src/crewai/rag/config/__init__.py
Normal file
1
src/crewai/rag/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""RAG client configuration management using ContextVars for thread-safe provider switching."""
|
||||
16
src/crewai/rag/config/base.py
Normal file
16
src/crewai/rag/config/base.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Base configuration class for RAG providers."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Any
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.optional_imports.types import SupportedProvider
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class BaseRagConfig:
|
||||
"""Base class for RAG configuration with Pydantic serialization support."""
|
||||
|
||||
provider: SupportedProvider = field(init=False)
|
||||
embedding_function: Any | None = field(default=None)
|
||||
8
src/crewai/rag/config/constants.py
Normal file
8
src/crewai/rag/config/constants.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Constants for RAG configuration."""
|
||||
|
||||
from typing import Final
|
||||
|
||||
DISCRIMINATOR: Final[str] = "provider"
|
||||
|
||||
DEFAULT_RAG_CONFIG_PATH: Final[str] = "crewai.rag.chromadb.config"
|
||||
DEFAULT_RAG_CONFIG_CLASS: Final[str] = "ChromaDBConfig"
|
||||
1
src/crewai/rag/config/optional_imports/__init__.py
Normal file
1
src/crewai/rag/config/optional_imports/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Optional imports for RAG configuration providers."""
|
||||
26
src/crewai/rag/config/optional_imports/base.py
Normal file
26
src/crewai/rag/config/optional_imports/base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Base classes for missing provider configurations."""
|
||||
|
||||
from typing import Literal
|
||||
from dataclasses import field
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class _MissingProvider:
|
||||
"""Base class for missing provider configurations.
|
||||
|
||||
Raises RuntimeError when instantiated to indicate missing dependencies.
|
||||
"""
|
||||
|
||||
provider: Literal["chromadb", "qdrant", "__missing__"] = field(
|
||||
default="__missing__"
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Raises error indicating the provider is not installed."""
|
||||
raise RuntimeError(
|
||||
f"provider '{self.provider}' requested but not installed. "
|
||||
f"Install the extra: `uv add crewai'[{self.provider}]'`."
|
||||
)
|
||||
27
src/crewai/rag/config/optional_imports/protocols.py
Normal file
27
src/crewai/rag/config/optional_imports/protocols.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Protocol definitions for RAG factory modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
class ChromaFactoryModule(Protocol):
|
||||
"""Protocol for ChromaDB factory module."""
|
||||
|
||||
def create_client(self, config: ChromaDBConfig) -> ChromaDBClient:
|
||||
"""Creates a ChromaDB client from configuration."""
|
||||
...
|
||||
|
||||
|
||||
class QdrantFactoryModule(Protocol):
|
||||
"""Protocol for Qdrant factory module."""
|
||||
|
||||
def create_client(self, config: QdrantConfig) -> QdrantClient:
|
||||
"""Creates a Qdrant client from configuration."""
|
||||
...
|
||||
22
src/crewai/rag/config/optional_imports/providers.py
Normal file
22
src/crewai/rag/config/optional_imports/providers.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Provider-specific missing configuration classes."""
|
||||
|
||||
from typing import Literal
|
||||
from dataclasses import field
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.optional_imports.base import _MissingProvider
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class MissingChromaDBConfig(_MissingProvider):
|
||||
"""Placeholder for missing ChromaDB configuration."""
|
||||
|
||||
provider: Literal["chromadb"] = field(default="chromadb")
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class MissingQdrantConfig(_MissingProvider):
|
||||
"""Placeholder for missing Qdrant configuration."""
|
||||
|
||||
provider: Literal["qdrant"] = field(default="qdrant")
|
||||
8
src/crewai/rag/config/optional_imports/types.py
Normal file
8
src/crewai/rag/config/optional_imports/types.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Type definitions for optional imports."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
SupportedProvider = Annotated[
|
||||
Literal["chromadb", "qdrant"],
|
||||
"Supported RAG provider types, add providers here as they become available",
|
||||
]
|
||||
34
src/crewai/rag/config/types.py
Normal file
34
src/crewai/rag/config/types.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Type definitions for RAG configuration."""
|
||||
|
||||
from typing import Annotated, TypeAlias, TYPE_CHECKING
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.config.constants import DISCRIMINATOR
|
||||
|
||||
# Linter freaks out on conditional imports, assigning in the type checking fixes it
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig as ChromaDBConfig_
|
||||
|
||||
ChromaDBConfig = ChromaDBConfig_
|
||||
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
|
||||
|
||||
QdrantConfig = QdrantConfig_
|
||||
else:
|
||||
try:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
except ImportError:
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingChromaDBConfig as ChromaDBConfig,
|
||||
)
|
||||
|
||||
try:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
except ImportError:
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingQdrantConfig as QdrantConfig,
|
||||
)
|
||||
|
||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
|
||||
RagConfigType: TypeAlias = Annotated[
|
||||
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
|
||||
]
|
||||
86
src/crewai/rag/config/utils.py
Normal file
86
src/crewai/rag/config/utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""RAG client configuration utilities."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.import_utils import require
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.constants import (
|
||||
DEFAULT_RAG_CONFIG_PATH,
|
||||
DEFAULT_RAG_CONFIG_CLASS,
|
||||
)
|
||||
from crewai.rag.factory import create_client
|
||||
|
||||
|
||||
class RagContext(BaseModel):
|
||||
"""Context holding RAG configuration and client instance."""
|
||||
|
||||
config: RagConfigType = Field(..., description="RAG provider configuration")
|
||||
client: BaseClient | None = Field(
|
||||
default=None, description="Instantiated RAG client"
|
||||
)
|
||||
|
||||
|
||||
_rag_context: ContextVar[RagContext | None] = ContextVar("_rag_context", default=None)
|
||||
|
||||
|
||||
def set_rag_config(config: RagConfigType) -> None:
|
||||
"""Set global RAG client configuration and instantiate the client.
|
||||
|
||||
Args:
|
||||
config: The RAG client configuration (ChromaDBConfig).
|
||||
"""
|
||||
client = create_client(config)
|
||||
context = RagContext(config=config, client=client)
|
||||
_rag_context.set(context)
|
||||
|
||||
|
||||
def get_rag_config() -> RagConfigType:
|
||||
"""Get current RAG configuration.
|
||||
|
||||
Returns:
|
||||
The current RAG configuration object.
|
||||
"""
|
||||
context = _rag_context.get()
|
||||
if context is None:
|
||||
module = require(DEFAULT_RAG_CONFIG_PATH, purpose="RAG configuration")
|
||||
config_class = getattr(module, DEFAULT_RAG_CONFIG_CLASS)
|
||||
default_config = config_class()
|
||||
set_rag_config(default_config)
|
||||
context = _rag_context.get()
|
||||
|
||||
if context is None or context.config is None:
|
||||
raise ValueError(
|
||||
"RAG configuration is not set. Please set the RAG config first."
|
||||
)
|
||||
|
||||
return context.config
|
||||
|
||||
|
||||
def get_rag_client() -> BaseClient:
|
||||
"""Get the current RAG client instance.
|
||||
|
||||
Returns:
|
||||
The current RAG client, creating one if needed.
|
||||
"""
|
||||
context = _rag_context.get()
|
||||
if context is None:
|
||||
get_rag_config()
|
||||
context = _rag_context.get()
|
||||
|
||||
if context and context.client is None:
|
||||
context.client = create_client(context.config)
|
||||
|
||||
if context is None or context.client is None:
|
||||
raise ValueError(
|
||||
"RAG client is not configured. Please set the RAG config first."
|
||||
)
|
||||
|
||||
return context.client
|
||||
|
||||
|
||||
def clear_rag_config() -> None:
|
||||
"""Clear the current RAG configuration and client, reverting to defaults."""
|
||||
_rag_context.set(None)
|
||||
1
src/crewai/rag/core/__init__.py
Normal file
1
src/crewai/rag/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core abstract base classes and protocols for RAG systems."""
|
||||
446
src/crewai/rag/core/base_client.py
Normal file
446
src/crewai/rag/core/base_client.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""Protocol for vector database client implementations."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Protocol, runtime_checkable, Annotated
|
||||
from typing_extensions import Unpack, Required, TypedDict
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
from crewai.rag.types import (
|
||||
EmbeddingFunction,
|
||||
BaseRecord,
|
||||
SearchResult,
|
||||
)
|
||||
|
||||
|
||||
class BaseCollectionParams(TypedDict):
|
||||
"""Base parameters for collection operations.
|
||||
|
||||
Attributes:
|
||||
collection_name: The name of the collection/index to operate on.
|
||||
"""
|
||||
|
||||
collection_name: Required[
|
||||
Annotated[
|
||||
str,
|
||||
"Name of the collection/index. Implementations may have specific constraints (e.g., character limits, allowed characters, case sensitivity).",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class BaseCollectionAddParams(BaseCollectionParams):
|
||||
"""Parameters for adding documents to a collection.
|
||||
|
||||
Extends BaseCollectionParams with document-specific fields.
|
||||
|
||||
Attributes:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dictionaries containing document data.
|
||||
"""
|
||||
|
||||
documents: list[BaseRecord]
|
||||
|
||||
|
||||
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
|
||||
"""Parameters for searching within a collection.
|
||||
|
||||
Extends BaseCollectionParams with search-specific optional fields.
|
||||
All fields except collection_name and query are optional.
|
||||
|
||||
Attributes:
|
||||
query: The text query to search for (required).
|
||||
limit: Maximum number of results to return.
|
||||
metadata_filter: Filter results by metadata fields.
|
||||
score_threshold: Minimum similarity score for results (0-1).
|
||||
"""
|
||||
|
||||
query: Required[str]
|
||||
limit: int
|
||||
metadata_filter: dict[str, Any]
|
||||
score_threshold: float
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BaseClient(Protocol):
|
||||
"""Protocol for vector store client implementations.
|
||||
|
||||
This protocol defines the interface that all vector store client implementations
|
||||
must follow. It provides a consistent API for storing and retrieving
|
||||
documents with their vector embeddings across different vector database
|
||||
backends (e.g., Qdrant, ChromaDB, Weaviate). Implementing classes should
|
||||
handle connection management, data persistence, and vector similarity
|
||||
search operations specific to their backend.
|
||||
|
||||
Implementation Guidelines:
|
||||
Implementations should accept BaseClientParams in their constructor to allow
|
||||
passing pre-configured client instances:
|
||||
|
||||
class MyVectorClient:
|
||||
def __init__(self, client: Any | None = None, **kwargs):
|
||||
if client:
|
||||
self.client = client
|
||||
else:
|
||||
self.client = self._create_default_client(**kwargs)
|
||||
|
||||
Notes:
|
||||
This protocol replaces the former BaseRAGStorage abstraction,
|
||||
providing a cleaner interface for vector store operations.
|
||||
|
||||
Attributes:
|
||||
embedding_function: Callable that takes a list of text strings
|
||||
and returns a list of embedding vectors. Implementations
|
||||
should always provide a default embedding function.
|
||||
client: The underlying vector database client instance. This could be
|
||||
passed via BaseClientParams during initialization or created internally.
|
||||
"""
|
||||
|
||||
client: Any
|
||||
embedding_function: EmbeddingFunction
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for BaseClient Protocol.
|
||||
|
||||
This allows the Protocol to be used in Pydantic models without
|
||||
requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
@abstractmethod
|
||||
def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Create a new collection/index in the vector database.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to create. Must be unique within
|
||||
the vector database instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection name already exists.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def acreate_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Create a new collection/index in the vector database asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to create. Must be unique within
|
||||
the vector database instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection name already exists.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_or_create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
This method provides a convenient way to ensure a collection exists
|
||||
without having to check for its existence first.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to get or create.
|
||||
|
||||
Returns:
|
||||
A collection object whose type depends on the backend implementation.
|
||||
This could be a collection reference, ID, or client object.
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to create the collection.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[BaseCollectionParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to get or create.
|
||||
|
||||
Returns:
|
||||
A collection object whose type depends on the backend implementation.
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to create the collection.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
This method performs an upsert operation - if a document with the same ID
|
||||
already exists, it will be updated with the new content and metadata.
|
||||
|
||||
Implementations should handle embedding generation internally based on
|
||||
the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated from content hash if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
Embeddings will be generated automatically.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
TypeError: If documents are not BaseRecord dict instances.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> from crewai.rag.types import BaseRecord
|
||||
>>> client = ChromaDBClient()
|
||||
>>>
|
||||
>>> records: list[BaseRecord] = [
|
||||
... {
|
||||
... "content": "Machine learning basics",
|
||||
... "metadata": {"source": "file3", "topic": "ML"}
|
||||
... },
|
||||
... {
|
||||
... "doc_id": "custom_id",
|
||||
... "content": "Deep learning fundamentals",
|
||||
... "metadata": {"source": "file4", "topic": "DL"}
|
||||
... }
|
||||
... ]
|
||||
>>> client.add_documents(collection_name="my_docs", documents=records)
|
||||
>>>
|
||||
>>> records_with_id: list[BaseRecord] = [
|
||||
... {
|
||||
... "doc_id": "nlp_001",
|
||||
... "content": "Advanced NLP techniques",
|
||||
... "metadata": {"source": "file5", "topic": "NLP"}
|
||||
... }
|
||||
... ]
|
||||
>>> client.add_documents(collection_name="my_docs", documents=records_with_id)
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Implementations should handle embedding generation internally based on
|
||||
the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated from content hash if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
Embeddings will be generated automatically.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
TypeError: If documents are not BaseRecord dict instances.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> from crewai.rag.types import BaseRecord
|
||||
>>>
|
||||
>>> async def add_documents():
|
||||
... client = ChromaDBClient()
|
||||
...
|
||||
... records: list[BaseRecord] = [
|
||||
... {
|
||||
... "doc_id": "doc2",
|
||||
... "content": "Async operations in Python",
|
||||
... "metadata": {"source": "file2", "topic": "async"}
|
||||
... }
|
||||
... ]
|
||||
... await client.aadd_documents(collection_name="my_docs", documents=records)
|
||||
...
|
||||
>>> asyncio.run(add_documents())
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Performs a vector similarity search to find the most similar documents
|
||||
to the provided query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to search in.
|
||||
query: The text query to search for. The implementation handles
|
||||
embedding generation internally.
|
||||
limit: Maximum number of results to return. Defaults to 10.
|
||||
metadata_filter: Optional metadata filter to apply to the search. The exact
|
||||
format depends on the backend, but typically supports equality
|
||||
and range queries on metadata fields.
|
||||
score_threshold: Optional minimum similarity score threshold. Only
|
||||
results with scores >= this threshold will be returned. The
|
||||
score interpretation depends on the distance metric used.
|
||||
|
||||
Returns:
|
||||
A list of SearchResult dictionaries ordered by similarity score in
|
||||
descending order. Each result contains:
|
||||
- id: Document ID
|
||||
- content: Document text content
|
||||
- metadata: Document metadata
|
||||
- score: Similarity score (0-1, higher is better)
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> client = ChromaDBClient()
|
||||
>>>
|
||||
>>> results = client.search(
|
||||
... collection_name="my_docs",
|
||||
... query="What is machine learning?",
|
||||
... limit=5,
|
||||
... metadata_filter={"source": "file1"},
|
||||
... score_threshold=0.7
|
||||
... )
|
||||
>>> for result in results:
|
||||
... print(f"{result['id']}: {result['score']:.2f}")
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to search in.
|
||||
query: The text query to search for. The implementation handles
|
||||
embedding generation internally.
|
||||
limit: Maximum number of results to return. Defaults to 10.
|
||||
metadata_filter: Optional metadata filter to apply to the search.
|
||||
score_threshold: Optional minimum similarity score threshold.
|
||||
|
||||
Returns:
|
||||
A list of SearchResult dictionaries ordered by similarity score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>>
|
||||
>>> async def search_documents():
|
||||
... client = ChromaDBClient()
|
||||
... results = await client.asearch(
|
||||
... collection_name="my_docs",
|
||||
... query="Python programming best practices",
|
||||
... limit=5,
|
||||
... metadata_filter={"source": "file1"},
|
||||
... score_threshold=0.7
|
||||
... )
|
||||
... for result in results:
|
||||
... print(f"{result['id']}: {result['score']:.2f}")
|
||||
...
|
||||
>>> asyncio.run(search_documents())
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
This operation is irreversible and will permanently remove all documents,
|
||||
embeddings, and metadata associated with the collection.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.delete_collection(collection_name="old_docs")
|
||||
>>> print("Collection 'old_docs' deleted successfully")
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>>
|
||||
>>> async def delete_old_collection():
|
||||
... client = ChromaDBClient()
|
||||
... await client.adelete_collection(collection_name="old_docs")
|
||||
... print("Collection 'old_docs' deleted successfully")
|
||||
...
|
||||
>>> asyncio.run(delete_old_collection())
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
This method provides a way to completely clear the vector database,
|
||||
removing all collections and their contents. Use with caution as
|
||||
this operation is irreversible.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
PermissionError: If the operation is not allowed by the backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.reset()
|
||||
>>> print("Vector database completely reset - all data deleted")
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
PermissionError: If the operation is not allowed by the backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>>
|
||||
>>> async def reset_database():
|
||||
... client = ChromaDBClient()
|
||||
... await client.areset()
|
||||
... print("Vector database completely reset - all data deleted")
|
||||
...
|
||||
>>> asyncio.run(reset_database())
|
||||
"""
|
||||
...
|
||||
26
src/crewai/rag/core/exceptions.py
Normal file
26
src/crewai/rag/core/exceptions.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Core exceptions for RAG module."""
|
||||
|
||||
|
||||
class ClientMethodMismatchError(TypeError):
|
||||
"""Raised when a method is called with the wrong client type.
|
||||
|
||||
Typically used when a sync method is called with an async client,
|
||||
or vice versa.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, method_name: str, expected_client: str, alt_method: str, alt_client: str
|
||||
) -> None:
|
||||
"""Create a ClientMethodMismatchError.
|
||||
|
||||
Args:
|
||||
method_name: Method that was called incorrectly.
|
||||
expected_client: Required client type.
|
||||
alt_method: Suggested alternative method.
|
||||
alt_client: Client type for the alternative method.
|
||||
"""
|
||||
message = (
|
||||
f"Method {method_name}() requires a {expected_client}. "
|
||||
f"Use {alt_method}() for {alt_client}."
|
||||
)
|
||||
super().__init__(message)
|
||||
148
src/crewai/rag/embeddings/factory.py
Normal file
148
src/crewai/rag/embeddings/factory.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Minimal embedding function factory for CrewAI."""
|
||||
|
||||
import os
|
||||
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GooglePalmEmbeddingFunction,
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
config: EmbeddingOptions | dict | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Get embedding function - delegates to ChromaDB.
|
||||
|
||||
Args:
|
||||
config: Optional configuration - either an EmbeddingOptions object or a dict with:
|
||||
- provider: The embedding provider to use (default: "openai")
|
||||
- Any other provider-specific parameters
|
||||
|
||||
Returns:
|
||||
EmbeddingFunction instance ready for use with ChromaDB
|
||||
|
||||
Supported providers:
|
||||
- openai: OpenAI embeddings (default)
|
||||
- cohere: Cohere embeddings
|
||||
- ollama: Ollama local embeddings
|
||||
- huggingface: HuggingFace embeddings
|
||||
- sentence-transformer: Local sentence transformers
|
||||
- instructor: Instructor embeddings for specialized tasks
|
||||
- google-palm: Google PaLM embeddings
|
||||
- google-generativeai: Google Generative AI embeddings
|
||||
- google-vertex: Google Vertex AI embeddings
|
||||
- amazon-bedrock: AWS Bedrock embeddings
|
||||
- jina: Jina AI embeddings
|
||||
- roboflow: Roboflow embeddings for vision tasks
|
||||
- openclip: OpenCLIP embeddings for multimodal tasks
|
||||
- text2vec: Text2Vec embeddings
|
||||
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
|
||||
|
||||
Examples:
|
||||
# Use default OpenAI with retry logic
|
||||
>>> embedder = get_embedding_function()
|
||||
|
||||
# Use Cohere with dict
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "cohere",
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... })
|
||||
|
||||
# Use with EmbeddingOptions
|
||||
>>> embedder = get_embedding_function(
|
||||
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
|
||||
... )
|
||||
|
||||
# Use local sentence transformers (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "sentence-transformer",
|
||||
... "model_name": "all-MiniLM-L6-v2"
|
||||
... })
|
||||
|
||||
# Use Ollama for local embeddings
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "ollama",
|
||||
... "model_name": "nomic-embed-text"
|
||||
... })
|
||||
|
||||
# Use ONNX (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "onnx"
|
||||
... })
|
||||
"""
|
||||
if config is None:
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
# Handle EmbeddingOptions object
|
||||
if isinstance(config, EmbeddingOptions):
|
||||
config_dict = config.model_dump(exclude_none=True)
|
||||
else:
|
||||
config_dict = config.copy()
|
||||
|
||||
provider = config_dict.pop("provider", "openai")
|
||||
|
||||
embedding_functions = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
}
|
||||
|
||||
if provider not in embedding_functions:
|
||||
raise ValueError(
|
||||
f"Unsupported provider: {provider}. "
|
||||
f"Available providers: {list(embedding_functions.keys())}"
|
||||
)
|
||||
return embedding_functions[provider](**config_dict)
|
||||
62
src/crewai/rag/embeddings/types.py
Normal file
62
src/crewai/rag/embeddings/types.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Type definitions for the embeddings module."""
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from crewai.rag.types import EmbeddingFunction
|
||||
|
||||
|
||||
EmbeddingProvider = Literal[
|
||||
"openai",
|
||||
"cohere",
|
||||
"ollama",
|
||||
"huggingface",
|
||||
"sentence-transformer",
|
||||
"instructor",
|
||||
"google-palm",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"amazon-bedrock",
|
||||
"jina",
|
||||
"roboflow",
|
||||
"openclip",
|
||||
"text2vec",
|
||||
"onnx",
|
||||
]
|
||||
"""Supported embedding providers.
|
||||
|
||||
These correspond to the embedding functions available in ChromaDB's
|
||||
embedding_functions module. Each provider has specific requirements
|
||||
and configuration options.
|
||||
"""
|
||||
|
||||
|
||||
class EmbeddingOptions(BaseModel):
|
||||
"""Configuration options for embedding providers.
|
||||
|
||||
Generic attributes that can be passed to get_embedding_function
|
||||
to configure various embedding providers.
|
||||
"""
|
||||
|
||||
provider: EmbeddingProvider = Field(
|
||||
..., description="Embedding provider name (e.g., 'openai', 'cohere', 'onnx')"
|
||||
)
|
||||
model_name: str | None = Field(
|
||||
default=None, description="Model name for the embedding provider"
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None, description="API key for the embedding provider"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseModel):
|
||||
"""Configuration wrapper for embedding functions.
|
||||
|
||||
Accepts either a pre-configured EmbeddingFunction or EmbeddingOptions
|
||||
to create one. This provides flexibility in how embeddings are configured.
|
||||
|
||||
Attributes:
|
||||
function: Either a callable EmbeddingFunction or EmbeddingOptions to create one
|
||||
"""
|
||||
|
||||
function: EmbeddingFunction | EmbeddingOptions
|
||||
45
src/crewai/rag/factory.py
Normal file
45
src/crewai/rag/factory.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Factory functions for creating RAG clients from configuration."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from crewai.rag.config.optional_imports.protocols import (
|
||||
ChromaFactoryModule,
|
||||
QdrantFactoryModule,
|
||||
)
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
def create_client(config: RagConfigType) -> BaseClient:
|
||||
"""Create a client from configuration using the appropriate factory.
|
||||
|
||||
Args:
|
||||
config: The RAG client configuration.
|
||||
|
||||
Returns:
|
||||
The created client instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration provider is not supported.
|
||||
"""
|
||||
|
||||
if config.provider == "chromadb":
|
||||
chromadb_mod = cast(
|
||||
ChromaFactoryModule,
|
||||
require(
|
||||
"crewai.rag.chromadb.factory",
|
||||
purpose="The 'chromadb' provider",
|
||||
),
|
||||
)
|
||||
return chromadb_mod.create_client(config)
|
||||
|
||||
if config.provider == "qdrant":
|
||||
qdrant_mod = cast(
|
||||
QdrantFactoryModule,
|
||||
require(
|
||||
"crewai.rag.qdrant.factory",
|
||||
purpose="The 'qdrant' provider",
|
||||
),
|
||||
)
|
||||
return qdrant_mod.create_client(config)
|
||||
1
src/crewai/rag/qdrant/__init__.py
Normal file
1
src/crewai/rag/qdrant/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Qdrant vector database client implementation."""
|
||||
501
src/crewai/rag/qdrant/client.py
Normal file
501
src/crewai/rag/qdrant/client.py
Normal file
@@ -0,0 +1,501 @@
|
||||
"""Qdrant client implementation."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.qdrant.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
QdrantClientType,
|
||||
QdrantCollectionCreateParams,
|
||||
)
|
||||
from crewai.rag.qdrant.utils import (
|
||||
_is_async_client,
|
||||
_is_async_embedding_function,
|
||||
_is_sync_client,
|
||||
_create_point_from_document,
|
||||
_get_collection_params,
|
||||
_prepare_search_params,
|
||||
_process_search_results,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class QdrantClient(BaseClient):
|
||||
"""Qdrant implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for Qdrant, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: QdrantClientType,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
) -> None:
|
||||
"""Initialize QdrantClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured Qdrant client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||
"""Create a new collection in Qdrant.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection with the same name already exists.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="create_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="acreate_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' already exists")
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
self.client.create_collection(**params)
|
||||
|
||||
async def acreate_collection(
|
||||
self, **kwargs: Unpack[QdrantCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in Qdrant asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection with the same name already exists.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="acreate_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="create_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' already exists")
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
await self.client.create_collection(**params)
|
||||
|
||||
def get_or_create_collection(
|
||||
self, **kwargs: Unpack[QdrantCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Returns:
|
||||
Collection info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="get_or_create_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="aget_or_create_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.collection_exists(collection_name):
|
||||
return self.client.get_collection(collection_name)
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
self.client.create_collection(**params)
|
||||
|
||||
return self.client.get_collection(collection_name)
|
||||
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[QdrantCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Returns:
|
||||
Collection info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aget_or_create_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="get_or_create_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.collection_exists(collection_name):
|
||||
return await self.client.get_collection(collection_name)
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
await self.client.create_collection(**params)
|
||||
|
||||
return await self.client.get_collection(collection_name)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="add_documents",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="aadd_documents",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
points = []
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync add_documents. "
|
||||
"Use aadd_documents instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
self.client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aadd_documents",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="add_documents",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
points = []
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
await self.client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="search",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="asearch",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync search. "
|
||||
"Use asearch instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_kwargs = _prepare_search_params(
|
||||
collection_name=collection_name,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
response = self.client.query_points(**search_kwargs)
|
||||
return _process_search_results(response)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="asearch",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="search",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
query_embedding = await async_fn(query)
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_kwargs = _prepare_search_params(
|
||||
collection_name=collection_name,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
response = await self.client.query_points(**search_kwargs)
|
||||
return _process_search_results(response)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="delete_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="adelete_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
self.client.delete_collection(collection_name=collection_name)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="adelete_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="delete_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
await self.client.delete_collection(collection_name=collection_name)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="reset",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="areset",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collections_response = self.client.get_collections()
|
||||
|
||||
for collection in collections_response.collections:
|
||||
self.client.delete_collection(collection_name=collection.name)
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="areset",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="reset",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collections_response = await self.client.get_collections()
|
||||
|
||||
for collection in collections_response.collections:
|
||||
await self.client.delete_collection(collection_name=collection.name)
|
||||
54
src/crewai/rag/qdrant/config.py
Normal file
54
src/crewai/rag/qdrant/config.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Qdrant configuration model."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
|
||||
|
||||
def _default_options() -> QdrantClientParams:
|
||||
"""Create default Qdrant client options.
|
||||
|
||||
Returns:
|
||||
Default options with file-based storage.
|
||||
"""
|
||||
return QdrantClientParams(path=DEFAULT_STORAGE_PATH)
|
||||
|
||||
|
||||
def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
||||
"""Create default Qdrant embedding function.
|
||||
|
||||
Returns:
|
||||
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||
"""
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
def embed_fn(text: str) -> list[float]:
|
||||
"""Embed a single text string.
|
||||
|
||||
Args:
|
||||
text: Text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats.
|
||||
"""
|
||||
embeddings = list(model.embed([text]))
|
||||
return embeddings[0].tolist() if embeddings else []
|
||||
|
||||
return cast(QdrantEmbeddingFunctionWrapper, embed_fn)
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class QdrantConfig(BaseRagConfig):
|
||||
"""Configuration for Qdrant client."""
|
||||
|
||||
provider: Literal["qdrant"] = field(default="qdrant", init=False)
|
||||
options: QdrantClientParams = field(default_factory=_default_options)
|
||||
embedding_function: QdrantEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
12
src/crewai/rag/qdrant/constants.py
Normal file
12
src/crewai/rag/qdrant/constants.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Constants for Qdrant implementation."""
|
||||
|
||||
import os
|
||||
from typing import Final
|
||||
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
|
||||
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "qdrant")
|
||||
21
src/crewai/rag/qdrant/factory.py
Normal file
21
src/crewai/rag/qdrant/factory.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Factory functions for creating Qdrant clients from configuration."""
|
||||
|
||||
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
def create_client(config: QdrantConfig) -> QdrantClient:
|
||||
"""Create a Qdrant client from configuration.
|
||||
|
||||
Args:
|
||||
config: The Qdrant configuration.
|
||||
|
||||
Returns:
|
||||
A configured QdrantClient instance.
|
||||
"""
|
||||
|
||||
qdrant_client = SyncQdrantClientBase(**config.options)
|
||||
return QdrantClient(
|
||||
client=qdrant_client, embedding_function=config.embedding_function
|
||||
)
|
||||
155
src/crewai/rag/qdrant/types.py
Normal file
155
src/crewai/rag/qdrant/types.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Type definitions specific to Qdrant implementation."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, Protocol, TypeAlias
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
import numpy as np
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
HasIdCondition,
|
||||
HasVectorCondition,
|
||||
HnswConfigDiff,
|
||||
InitFrom,
|
||||
IsEmptyCondition,
|
||||
IsNullCondition,
|
||||
NestedCondition,
|
||||
OptimizersConfigDiff,
|
||||
QuantizationConfig,
|
||||
ShardingMethod,
|
||||
SparseVectorsConfig,
|
||||
VectorsConfig,
|
||||
WalConfigDiff,
|
||||
)
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams
|
||||
|
||||
QdrantClientType = SyncQdrantClient | AsyncQdrantClient
|
||||
|
||||
QueryEmbedding: TypeAlias = list[float] | np.ndarray[Any, np.dtype[np.floating[Any]]]
|
||||
|
||||
BasicConditions = FieldCondition | IsEmptyCondition | IsNullCondition
|
||||
StructuralConditions = HasIdCondition | HasVectorCondition | NestedCondition
|
||||
FilterCondition = BasicConditions | StructuralConditions | Filter
|
||||
|
||||
MetadataFilterValue = bool | int | str
|
||||
MetadataFilter = dict[str, MetadataFilterValue]
|
||||
|
||||
|
||||
class EmbeddingFunction(Protocol):
|
||||
"""Protocol for embedding functions that convert text to vectors."""
|
||||
|
||||
def __call__(self, text: str) -> QueryEmbedding:
|
||||
"""Convert text to embedding vector.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats or numpy array.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class QdrantEmbeddingFunctionWrapper(EmbeddingFunction):
|
||||
"""Base class for Qdrant EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for Qdrant EmbeddingFunction.
|
||||
|
||||
This allows Pydantic to handle Qdrant's EmbeddingFunction type
|
||||
without requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
|
||||
class AsyncEmbeddingFunction(Protocol):
|
||||
"""Protocol for async embedding functions that convert text to vectors."""
|
||||
|
||||
async def __call__(self, text: str) -> QueryEmbedding:
|
||||
"""Convert text to embedding vector asynchronously.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats or numpy array.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class QdrantClientParams(TypedDict, total=False):
|
||||
"""Parameters for QdrantClient initialization.
|
||||
|
||||
Notes:
|
||||
Need to implement in factory or remove.
|
||||
"""
|
||||
|
||||
location: str | None
|
||||
url: str | None
|
||||
port: int
|
||||
grpc_port: int
|
||||
prefer_grpc: bool
|
||||
https: bool | None
|
||||
api_key: str | None
|
||||
prefix: str | None
|
||||
timeout: int | None
|
||||
host: str | None
|
||||
path: str | None
|
||||
force_disable_check_same_thread: bool
|
||||
grpc_options: dict[str, Any] | None
|
||||
auth_token_provider: Callable[[], str] | Callable[[], Awaitable[str]] | None
|
||||
cloud_inference: bool
|
||||
local_inference_batch_size: int | None
|
||||
check_compatibility: bool
|
||||
|
||||
|
||||
class CommonCreateFields(TypedDict, total=False):
|
||||
"""Fields shared between high-level and direct create_collection params."""
|
||||
|
||||
vectors_config: VectorsConfig
|
||||
sparse_vectors_config: SparseVectorsConfig
|
||||
shard_number: Annotated[int, "Number of shards (default: 1)"]
|
||||
sharding_method: ShardingMethod
|
||||
replication_factor: Annotated[int, "Number of replicas per shard (default: 1)"]
|
||||
write_consistency_factor: Annotated[int, "Await N replicas on write (default: 1)"]
|
||||
on_disk_payload: Annotated[bool, "Store payload on disk instead of RAM"]
|
||||
hnsw_config: HnswConfigDiff
|
||||
optimizers_config: OptimizersConfigDiff
|
||||
wal_config: WalConfigDiff
|
||||
quantization_config: QuantizationConfig
|
||||
init_from: InitFrom | str
|
||||
timeout: Annotated[int, "Operation timeout in seconds"]
|
||||
|
||||
|
||||
class QdrantCollectionCreateParams(
|
||||
BaseCollectionParams, CommonCreateFields, total=False
|
||||
):
|
||||
"""High-level parameters for creating a Qdrant collection."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreateCollectionParams(CommonCreateFields, total=False):
|
||||
"""Parameters for qdrant_client.create_collection."""
|
||||
|
||||
collection_name: str
|
||||
|
||||
|
||||
class PreparedSearchParams(TypedDict):
|
||||
"""Type definition for prepared Qdrant search parameters."""
|
||||
|
||||
collection_name: str
|
||||
query: list[float]
|
||||
limit: Annotated[int, "Max results to return"]
|
||||
with_payload: Annotated[bool, "Include payload in results"]
|
||||
with_vectors: Annotated[bool, "Include vectors in results"]
|
||||
score_threshold: NotRequired[Annotated[float, "Min similarity score (0-1)"]]
|
||||
query_filter: NotRequired[Filter]
|
||||
228
src/crewai/rag/qdrant/utils.py
Normal file
228
src/crewai/rag/qdrant/utils.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Utility functions for Qdrant operations."""
|
||||
|
||||
import asyncio
|
||||
from typing import TypeGuard
|
||||
from uuid import uuid4
|
||||
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
PointStruct,
|
||||
QueryResponse,
|
||||
)
|
||||
|
||||
from crewai.rag.qdrant.constants import DEFAULT_VECTOR_PARAMS
|
||||
from crewai.rag.qdrant.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
CreateCollectionParams,
|
||||
EmbeddingFunction,
|
||||
FilterCondition,
|
||||
MetadataFilter,
|
||||
PreparedSearchParams,
|
||||
QdrantClientType,
|
||||
QdrantCollectionCreateParams,
|
||||
QueryEmbedding,
|
||||
)
|
||||
from crewai.rag.types import SearchResult, BaseRecord
|
||||
|
||||
|
||||
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
|
||||
"""Convert embedding to list[float] format if needed.
|
||||
|
||||
Args:
|
||||
embedding: Embedding vector as list or numpy array.
|
||||
|
||||
Returns:
|
||||
Embedding as list[float].
|
||||
"""
|
||||
if not isinstance(embedding, list):
|
||||
return embedding.tolist()
|
||||
return embedding
|
||||
|
||||
|
||||
def _is_sync_client(client: QdrantClientType) -> TypeGuard[SyncQdrantClient]:
|
||||
"""Type guard to check if the client is a synchronous QdrantClient.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is a QdrantClient, False otherwise.
|
||||
"""
|
||||
return isinstance(client, SyncQdrantClient)
|
||||
|
||||
|
||||
def _is_async_client(client: QdrantClientType) -> TypeGuard[AsyncQdrantClient]:
|
||||
"""Type guard to check if the client is an asynchronous AsyncQdrantClient.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is an AsyncQdrantClient, False otherwise.
|
||||
"""
|
||||
return isinstance(client, AsyncQdrantClient)
|
||||
|
||||
|
||||
def _is_async_embedding_function(
|
||||
func: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
) -> TypeGuard[AsyncEmbeddingFunction]:
|
||||
"""Type guard to check if the embedding function is async.
|
||||
|
||||
Args:
|
||||
func: The embedding function to check.
|
||||
|
||||
Returns:
|
||||
True if the function is async, False otherwise.
|
||||
"""
|
||||
return asyncio.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def _get_collection_params(
|
||||
kwargs: QdrantCollectionCreateParams,
|
||||
) -> CreateCollectionParams:
|
||||
"""Extract collection creation parameters from kwargs."""
|
||||
params: CreateCollectionParams = {
|
||||
"collection_name": kwargs["collection_name"],
|
||||
"vectors_config": kwargs.get("vectors_config", DEFAULT_VECTOR_PARAMS),
|
||||
}
|
||||
|
||||
if "sparse_vectors_config" in kwargs:
|
||||
params["sparse_vectors_config"] = kwargs["sparse_vectors_config"]
|
||||
if "shard_number" in kwargs:
|
||||
params["shard_number"] = kwargs["shard_number"]
|
||||
if "sharding_method" in kwargs:
|
||||
params["sharding_method"] = kwargs["sharding_method"]
|
||||
if "replication_factor" in kwargs:
|
||||
params["replication_factor"] = kwargs["replication_factor"]
|
||||
if "write_consistency_factor" in kwargs:
|
||||
params["write_consistency_factor"] = kwargs["write_consistency_factor"]
|
||||
if "on_disk_payload" in kwargs:
|
||||
params["on_disk_payload"] = kwargs["on_disk_payload"]
|
||||
if "hnsw_config" in kwargs:
|
||||
params["hnsw_config"] = kwargs["hnsw_config"]
|
||||
if "optimizers_config" in kwargs:
|
||||
params["optimizers_config"] = kwargs["optimizers_config"]
|
||||
if "wal_config" in kwargs:
|
||||
params["wal_config"] = kwargs["wal_config"]
|
||||
if "quantization_config" in kwargs:
|
||||
params["quantization_config"] = kwargs["quantization_config"]
|
||||
if "init_from" in kwargs:
|
||||
params["init_from"] = kwargs["init_from"]
|
||||
if "timeout" in kwargs:
|
||||
params["timeout"] = kwargs["timeout"]
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _prepare_search_params(
|
||||
collection_name: str,
|
||||
query_embedding: QueryEmbedding,
|
||||
limit: int,
|
||||
score_threshold: float | None,
|
||||
metadata_filter: MetadataFilter | None,
|
||||
) -> PreparedSearchParams:
|
||||
"""Prepare search parameters for Qdrant query_points.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection to search.
|
||||
query_embedding: Embedding vector for the query.
|
||||
limit: Maximum number of results.
|
||||
score_threshold: Optional minimum similarity score.
|
||||
metadata_filter: Optional metadata filters.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for query_points method.
|
||||
"""
|
||||
query_vector = _ensure_list_embedding(query_embedding)
|
||||
|
||||
search_kwargs: PreparedSearchParams = {
|
||||
"collection_name": collection_name,
|
||||
"query": query_vector,
|
||||
"limit": limit,
|
||||
"with_payload": True,
|
||||
"with_vectors": False,
|
||||
}
|
||||
|
||||
if score_threshold is not None:
|
||||
search_kwargs["score_threshold"] = score_threshold
|
||||
|
||||
if metadata_filter:
|
||||
filter_conditions: list[FilterCondition] = []
|
||||
for key, value in metadata_filter.items():
|
||||
filter_conditions.append(
|
||||
FieldCondition(key=key, match=MatchValue(value=value))
|
||||
)
|
||||
|
||||
search_kwargs["query_filter"] = Filter(must=filter_conditions)
|
||||
|
||||
return search_kwargs
|
||||
|
||||
|
||||
def _normalize_qdrant_score(score: float) -> float:
|
||||
"""Normalize Qdrant cosine similarity score to [0, 1] range.
|
||||
|
||||
Converts from Qdrant's [-1, 1] cosine similarity range to [0, 1] range for standardization across clients.
|
||||
|
||||
Args:
|
||||
score: Raw cosine similarity score from Qdrant [-1, 1].
|
||||
|
||||
Returns:
|
||||
Normalized score in [0, 1] range where 1 is most similar.
|
||||
"""
|
||||
normalized = (score + 1.0) / 2.0
|
||||
return max(0.0, min(1.0, normalized))
|
||||
|
||||
|
||||
def _process_search_results(response: QueryResponse) -> list[SearchResult]:
|
||||
"""Process Qdrant search response into SearchResult format.
|
||||
|
||||
Args:
|
||||
response: Response from Qdrant query_points method.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dictionaries.
|
||||
"""
|
||||
results: list[SearchResult] = []
|
||||
for point in response.points:
|
||||
payload = point.payload or {}
|
||||
score = _normalize_qdrant_score(score=point.score)
|
||||
result: SearchResult = {
|
||||
"id": str(point.id),
|
||||
"content": payload.get("content", ""),
|
||||
"metadata": {k: v for k, v in payload.items() if k != "content"},
|
||||
"score": score,
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _create_point_from_document(
|
||||
doc: BaseRecord, embedding: QueryEmbedding
|
||||
) -> PointStruct:
|
||||
"""Create a PointStruct from a document and its embedding.
|
||||
|
||||
Args:
|
||||
doc: Document dictionary containing content, metadata, and optional doc_id.
|
||||
embedding: The embedding vector for the document content.
|
||||
|
||||
Returns:
|
||||
PointStruct ready to be upserted to Qdrant.
|
||||
"""
|
||||
doc_id = doc.get("doc_id", str(uuid4()))
|
||||
vector = _ensure_list_embedding(embedding)
|
||||
|
||||
metadata = doc.get("metadata", {})
|
||||
if isinstance(metadata, list):
|
||||
metadata = metadata[0] if metadata else {}
|
||||
elif not isinstance(metadata, dict):
|
||||
metadata = dict(metadata) if metadata else {}
|
||||
|
||||
return PointStruct(
|
||||
id=doc_id,
|
||||
vector=vector,
|
||||
payload={"content": doc["content"], **metadata},
|
||||
)
|
||||
50
src/crewai/rag/types.py
Normal file
50
src/crewai/rag/types.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TypeAlias, Any
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class BaseRecord(TypedDict, total=False):
|
||||
"""A typed dictionary representing a document record.
|
||||
|
||||
Attributes:
|
||||
doc_id: Optional unique identifier for the document. If not provided,
|
||||
a content-based ID will be generated using SHA256 hash.
|
||||
content: The text content of the document (required)
|
||||
metadata: Optional metadata associated with the document
|
||||
"""
|
||||
|
||||
doc_id: str
|
||||
content: Required[str]
|
||||
metadata: (
|
||||
Mapping[str, str | int | float | bool]
|
||||
| list[Mapping[str, str | int | float | bool]]
|
||||
)
|
||||
|
||||
|
||||
DenseVector: TypeAlias = list[float]
|
||||
IntVector: TypeAlias = list[int]
|
||||
|
||||
EmbeddingFunction: TypeAlias = Callable[..., Any]
|
||||
|
||||
|
||||
class SearchResult(TypedDict):
|
||||
"""Standard search result format for vector store queries.
|
||||
|
||||
This provides a consistent interface for search results across different
|
||||
vector store implementations. Each implementation should convert their
|
||||
native result format to this standard format.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier of the document
|
||||
content: The text content of the document
|
||||
metadata: Optional metadata associated with the document
|
||||
score: Similarity score (higher is better, typically between 0 and 1)
|
||||
"""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
metadata: dict[str, Any]
|
||||
score: float
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from concurrent.futures import Future
|
||||
from copy import copy
|
||||
from hashlib import md5
|
||||
@@ -157,8 +158,13 @@ class Task(BaseModel):
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate task output before proceeding to next task",
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
max_retries: Optional[int] = Field(
|
||||
default=None,
|
||||
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0"
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3,
|
||||
description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
retry_count: int = Field(default=0, description="Current number of retries")
|
||||
start_time: Optional[datetime.datetime] = Field(
|
||||
@@ -354,6 +360,18 @@ class Task(BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def handle_max_retries_deprecation(self):
|
||||
if self.max_retries is not None:
|
||||
warnings.warn(
|
||||
"The 'max_retries' parameter is deprecated and will be removed in CrewAI v1.0.0. "
|
||||
"Please use 'guardrail_max_retries' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
self.guardrail_max_retries = self.max_retries
|
||||
return self
|
||||
|
||||
def execute_sync(
|
||||
self,
|
||||
agent: Optional[BaseAgent] = None,
|
||||
@@ -433,7 +451,7 @@ class Task(BaseModel):
|
||||
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
task_output = TaskOutput(
|
||||
name=self.name,
|
||||
name=self.name or self.description,
|
||||
description=self.description,
|
||||
expected_output=self.expected_output,
|
||||
raw=result,
|
||||
@@ -450,9 +468,9 @@ class Task(BaseModel):
|
||||
retry_count=self.retry_count,
|
||||
)
|
||||
if not guardrail_result.success:
|
||||
if self.retry_count >= self.max_retries:
|
||||
if self.retry_count >= self.guardrail_max_retries:
|
||||
raise Exception(
|
||||
f"Task failed guardrail validation after {self.max_retries} retries. "
|
||||
f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
|
||||
f"Last error: {guardrail_result.error}"
|
||||
)
|
||||
|
||||
@@ -561,8 +579,8 @@ class Task(BaseModel):
|
||||
should_inject = self.allow_crewai_trigger_context
|
||||
|
||||
if should_inject and self.agent:
|
||||
crew = getattr(self.agent, 'crew', None)
|
||||
if crew and hasattr(crew, '_inputs') and crew._inputs:
|
||||
crew = getattr(self.agent, "crew", None)
|
||||
if crew and hasattr(crew, "_inputs") and crew._inputs:
|
||||
trigger_payload = crew._inputs.get("crewai_trigger_payload")
|
||||
if trigger_payload is not None:
|
||||
description += f"\n\nTrigger Payload: {trigger_payload}"
|
||||
@@ -780,7 +798,9 @@ Follow these guidelines:
|
||||
if self.create_directory and not directory.exists():
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
elif not self.create_directory and not directory.exists():
|
||||
raise RuntimeError(f"Directory {directory} does not exist and create_directory is False")
|
||||
raise RuntimeError(
|
||||
f"Directory {directory} does not exist and create_directory is False"
|
||||
)
|
||||
|
||||
with resolved_path.open("w", encoding="utf-8") as file:
|
||||
if isinstance(result, dict):
|
||||
|
||||
@@ -14,12 +14,14 @@ from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
|
||||
class EnvVar(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
required: bool = True
|
||||
default: Optional[str] = None
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||
pass
|
||||
@@ -108,7 +110,7 @@ class BaseTool(BaseModel, ABC):
|
||||
def to_structured_tool(self) -> CrewStructuredTool:
|
||||
"""Convert this tool to a CrewStructuredTool instance."""
|
||||
self._set_args_schema()
|
||||
return CrewStructuredTool(
|
||||
structured_tool = CrewStructuredTool(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
args_schema=self.args_schema,
|
||||
@@ -117,6 +119,8 @@ class BaseTool(BaseModel, ABC):
|
||||
max_usage_count=self.max_usage_count,
|
||||
current_usage_count=self.current_usage_count,
|
||||
)
|
||||
structured_tool._original_tool = self
|
||||
return structured_tool
|
||||
|
||||
@classmethod
|
||||
def from_langchain(cls, tool: Any) -> "BaseTool":
|
||||
@@ -276,7 +280,9 @@ def to_langchain(
|
||||
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
|
||||
|
||||
|
||||
def tool(*args, result_as_answer: bool = False, max_usage_count: int | None = None) -> Callable:
|
||||
def tool(
|
||||
*args, result_as_answer: bool = False, max_usage_count: int | None = None
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to create a tool from a function.
|
||||
|
||||
|
||||
@@ -10,6 +10,17 @@ from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class ToolUsageLimitExceeded(Exception):
|
||||
"""Exception raised when a tool has reached its maximum usage limit."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CrewStructuredTool:
|
||||
"""A structured tool that can operate on any number of inputs.
|
||||
@@ -18,6 +29,8 @@ class CrewStructuredTool:
|
||||
that integrates better with CrewAI's ecosystem.
|
||||
"""
|
||||
|
||||
_original_tool: BaseTool | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
@@ -47,6 +60,7 @@ class CrewStructuredTool:
|
||||
self.result_as_answer = result_as_answer
|
||||
self.max_usage_count = max_usage_count
|
||||
self.current_usage_count = current_usage_count
|
||||
self._original_tool = None
|
||||
|
||||
# Validate the function signature matches the schema
|
||||
self._validate_function_signature()
|
||||
@@ -219,16 +233,26 @@ class CrewStructuredTool:
|
||||
"""
|
||||
parsed_args = self._parse_args(input)
|
||||
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
return await self.func(**parsed_args, **kwargs)
|
||||
else:
|
||||
# Run sync functions in a thread pool
|
||||
import asyncio
|
||||
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: self.func(**parsed_args, **kwargs)
|
||||
if self.has_reached_max_usage_count():
|
||||
raise ToolUsageLimitExceeded(
|
||||
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
|
||||
)
|
||||
|
||||
self._increment_usage_count()
|
||||
|
||||
try:
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
return await self.func(**parsed_args, **kwargs)
|
||||
else:
|
||||
# Run sync functions in a thread pool
|
||||
import asyncio
|
||||
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: self.func(**parsed_args, **kwargs)
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _run(self, *args, **kwargs) -> Any:
|
||||
"""Legacy method for compatibility."""
|
||||
# Convert args/kwargs to our expected format
|
||||
@@ -242,10 +266,22 @@ class CrewStructuredTool:
|
||||
"""Main method for tool execution."""
|
||||
parsed_args = self._parse_args(input)
|
||||
|
||||
if self.has_reached_max_usage_count():
|
||||
raise ToolUsageLimitExceeded(
|
||||
f"Tool '{self.name}' has reached its maximum usage limit of {self.max_usage_count}. You should not use the {self.name} tool again."
|
||||
)
|
||||
|
||||
self._increment_usage_count()
|
||||
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
result = asyncio.run(self.func(**parsed_args, **kwargs))
|
||||
return result
|
||||
|
||||
try:
|
||||
result = self.func(**parsed_args, **kwargs)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
result = self.func(**parsed_args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
@@ -253,6 +289,19 @@ class CrewStructuredTool:
|
||||
|
||||
return result
|
||||
|
||||
def has_reached_max_usage_count(self) -> bool:
|
||||
"""Check if the tool has reached its maximum usage count."""
|
||||
return (
|
||||
self.max_usage_count is not None
|
||||
and self.current_usage_count >= self.max_usage_count
|
||||
)
|
||||
|
||||
def _increment_usage_count(self) -> None:
|
||||
"""Increment the usage count."""
|
||||
self.current_usage_count += 1
|
||||
if self._original_tool is not None:
|
||||
self._original_tool.current_usage_count = self.current_usage_count
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
"""Get the tool's input arguments schema."""
|
||||
|
||||
@@ -178,9 +178,11 @@ class ToolUsage:
|
||||
|
||||
if self.agent.fingerprint:
|
||||
event_data.update(self.agent.fingerprint)
|
||||
if self.task:
|
||||
event_data["task_name"] = self.task.name or self.task.description
|
||||
event_data["task_id"] = str(self.task.id)
|
||||
crewai_event_bus.emit(self, ToolUsageStartedEvent(**event_data))
|
||||
|
||||
crewai_event_bus.emit(self,ToolUsageStartedEvent(**event_data))
|
||||
|
||||
started_at = time.time()
|
||||
from_cache = False
|
||||
result = None # type: ignore
|
||||
@@ -311,12 +313,15 @@ class ToolUsage:
|
||||
if self.agent and hasattr(self.agent, "tools_results"):
|
||||
self.agent.tools_results.append(data)
|
||||
|
||||
if available_tool and hasattr(available_tool, 'current_usage_count'):
|
||||
if available_tool and hasattr(available_tool, "current_usage_count"):
|
||||
available_tool.current_usage_count += 1
|
||||
if hasattr(available_tool, 'max_usage_count') and available_tool.max_usage_count is not None:
|
||||
if (
|
||||
hasattr(available_tool, "max_usage_count")
|
||||
and available_tool.max_usage_count is not None
|
||||
):
|
||||
self._printer.print(
|
||||
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
|
||||
color="blue"
|
||||
color="blue",
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -350,20 +355,20 @@ class ToolUsage:
|
||||
calling.arguments == last_tool_usage.arguments
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _check_usage_limit(self, tool: Any, tool_name: str) -> str | None:
|
||||
"""Check if tool has reached its usage limit.
|
||||
|
||||
|
||||
Args:
|
||||
tool: The tool to check
|
||||
tool_name: The name of the tool (used for error message)
|
||||
|
||||
|
||||
Returns:
|
||||
Error message if limit reached, None otherwise
|
||||
"""
|
||||
if (
|
||||
hasattr(tool, 'max_usage_count')
|
||||
and tool.max_usage_count is not None
|
||||
hasattr(tool, "max_usage_count")
|
||||
and tool.max_usage_count is not None
|
||||
and tool.current_usage_count >= tool.max_usage_count
|
||||
):
|
||||
return f"Tool '{tool_name}' has reached its usage limit of {tool.max_usage_count} times and cannot be used anymore."
|
||||
@@ -605,6 +610,9 @@ class ToolUsage:
|
||||
"output": result,
|
||||
}
|
||||
)
|
||||
if self.task:
|
||||
event_data["task_id"] = str(self.task.id)
|
||||
event_data["task_name"] = self.task.name or self.task.description
|
||||
crewai_event_bus.emit(self, ToolUsageFinishedEvent(**event_data))
|
||||
|
||||
def _prepare_event_data(
|
||||
|
||||
@@ -11,7 +11,9 @@ class BaseEvent(BaseModel):
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
type: str
|
||||
source_fingerprint: Optional[str] = None # UUID string of the source entity
|
||||
source_type: Optional[str] = None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"
|
||||
source_type: Optional[str] = (
|
||||
None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"
|
||||
)
|
||||
fingerprint_metadata: Optional[Dict[str, Any]] = None # Any relevant metadata
|
||||
|
||||
def to_json(self, exclude: set[str] | None = None):
|
||||
@@ -25,3 +27,20 @@ class BaseEvent(BaseModel):
|
||||
dict: A JSON-serializable dictionary.
|
||||
"""
|
||||
return to_serializable(self, exclude=exclude)
|
||||
|
||||
def _set_task_params(self, data: Dict[str, Any]):
|
||||
if "from_task" in data and (task := data["from_task"]):
|
||||
self.task_id = task.id
|
||||
self.task_name = task.name or task.description
|
||||
self.from_task = None
|
||||
|
||||
def _set_agent_params(self, data: Dict[str, Any]):
|
||||
task = data.get("from_task", None)
|
||||
agent = task.agent if task else data.get("from_agent", None)
|
||||
|
||||
if not agent:
|
||||
return
|
||||
|
||||
self.agent_id = agent.id
|
||||
self.agent_role = agent.role
|
||||
self.from_agent = None
|
||||
|
||||
@@ -161,8 +161,10 @@ class EventListener(BaseEventListener):
|
||||
def on_task_started(source, event: TaskStartedEvent):
|
||||
span = self._telemetry.task_started(crew=source.agent.crew, task=source)
|
||||
self.execution_spans[source] = span
|
||||
# Pass both task ID and task name (if set)
|
||||
task_name = source.name if hasattr(source, 'name') and source.name else None
|
||||
self.formatter.create_task_branch(
|
||||
self.formatter.current_crew_tree, source.id
|
||||
self.formatter.current_crew_tree, source.id, task_name
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
@@ -173,11 +175,14 @@ class EventListener(BaseEventListener):
|
||||
self._telemetry.task_ended(span, source, source.agent.crew)
|
||||
self.execution_spans[source] = None
|
||||
|
||||
# Pass task name if it exists
|
||||
task_name = source.name if hasattr(source, 'name') and source.name else None
|
||||
self.formatter.update_task_status(
|
||||
self.formatter.current_crew_tree,
|
||||
source.id,
|
||||
source.agent.role,
|
||||
"completed",
|
||||
task_name
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(TaskFailedEvent)
|
||||
@@ -188,11 +193,14 @@ class EventListener(BaseEventListener):
|
||||
self._telemetry.task_ended(span, source, source.agent.crew)
|
||||
self.execution_spans[source] = None
|
||||
|
||||
# Pass task name if it exists
|
||||
task_name = source.name if hasattr(source, 'name') and source.name else None
|
||||
self.formatter.update_task_status(
|
||||
self.formatter.current_crew_tree,
|
||||
source.id,
|
||||
source.agent.role,
|
||||
"failed",
|
||||
task_name
|
||||
)
|
||||
|
||||
# ----------- AGENT EVENTS -----------
|
||||
|
||||
@@ -41,6 +41,12 @@ class TraceBatchManager:
|
||||
"""Single responsibility: Manage batches and event buffering"""
|
||||
|
||||
is_current_batch_ephemeral: bool = False
|
||||
trace_batch_id: Optional[str] = None
|
||||
current_batch: Optional[TraceBatch] = None
|
||||
event_buffer: List[TraceEvent] = []
|
||||
execution_start_times: Dict[str, datetime] = {}
|
||||
batch_owner_type: Optional[str] = None
|
||||
batch_owner_id: Optional[str] = None
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
@@ -48,11 +54,6 @@ class TraceBatchManager:
|
||||
except AuthError:
|
||||
self.plus_api = PlusAPI(api_key="")
|
||||
|
||||
self.trace_batch_id: Optional[str] = None # Backend ID
|
||||
self.current_batch: Optional[TraceBatch] = None
|
||||
self.event_buffer: List[TraceEvent] = []
|
||||
self.execution_start_times: Dict[str, datetime] = {}
|
||||
|
||||
def initialize_batch(
|
||||
self,
|
||||
user_context: Dict[str, str],
|
||||
@@ -178,14 +179,16 @@ class TraceBatchManager:
|
||||
if not self.current_batch:
|
||||
return None
|
||||
|
||||
self.current_batch.events = self.event_buffer.copy()
|
||||
if self.event_buffer:
|
||||
self._send_events_to_backend()
|
||||
self._finalize_backend_batch()
|
||||
|
||||
self.current_batch.events = self.event_buffer.copy()
|
||||
|
||||
finalized_batch = self.current_batch
|
||||
|
||||
self.batch_owner_type = None
|
||||
self.batch_owner_id = None
|
||||
|
||||
self.current_batch = None
|
||||
self.event_buffer.clear()
|
||||
self.trace_batch_id = None
|
||||
|
||||
@@ -75,10 +75,18 @@ class TraceCollectionListener(BaseEventListener):
|
||||
Trace collection listener that orchestrates trace collection
|
||||
"""
|
||||
|
||||
complex_events = ["task_started", "llm_call_started", "llm_call_completed"]
|
||||
complex_events = [
|
||||
"task_started",
|
||||
"task_completed",
|
||||
"llm_call_started",
|
||||
"llm_call_completed",
|
||||
"agent_execution_started",
|
||||
"agent_execution_completed",
|
||||
]
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
_listeners_setup = False
|
||||
|
||||
def __new__(cls, batch_manager=None):
|
||||
if cls._instance is None:
|
||||
@@ -116,10 +124,15 @@ class TraceCollectionListener(BaseEventListener):
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
"""Setup event listeners - delegates to specific handlers"""
|
||||
|
||||
if self._listeners_setup:
|
||||
return
|
||||
|
||||
self._register_flow_event_handlers(crewai_event_bus)
|
||||
self._register_context_event_handlers(crewai_event_bus)
|
||||
self._register_action_event_handlers(crewai_event_bus)
|
||||
|
||||
self._listeners_setup = True
|
||||
|
||||
def _register_flow_event_handlers(self, event_bus):
|
||||
"""Register handlers for flow events"""
|
||||
|
||||
@@ -148,7 +161,8 @@ class TraceCollectionListener(BaseEventListener):
|
||||
@event_bus.on(FlowFinishedEvent)
|
||||
def on_flow_finished(source, event):
|
||||
self._handle_trace_event("flow_finished", source, event)
|
||||
self.batch_manager.finalize_batch()
|
||||
if self.batch_manager.batch_owner_type == "flow":
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(FlowPlotEvent)
|
||||
def on_flow_plot(source, event):
|
||||
@@ -166,7 +180,8 @@ class TraceCollectionListener(BaseEventListener):
|
||||
@event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source, event):
|
||||
self._handle_trace_event("crew_kickoff_completed", source, event)
|
||||
self.batch_manager.finalize_batch()
|
||||
if self.batch_manager.batch_owner_type == "crew":
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(CrewKickoffFailedEvent)
|
||||
def on_crew_failed(source, event):
|
||||
@@ -218,7 +233,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self._handle_trace_event("llm_guardrail_completed", source, event)
|
||||
|
||||
def _register_action_event_handlers(self, event_bus):
|
||||
"""Register handlers for action events (LLM calls, tool usage, memory)"""
|
||||
"""Register handlers for action events (LLM calls, tool usage)"""
|
||||
|
||||
@event_bus.on(LLMCallStartedEvent)
|
||||
def on_llm_call_started(source, event):
|
||||
@@ -289,6 +304,9 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"crewai_version": get_crewai_version(),
|
||||
}
|
||||
|
||||
self.batch_manager.batch_owner_type = "crew"
|
||||
self.batch_manager.batch_owner_id = getattr(source, "id", str(uuid.uuid4()))
|
||||
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _initialize_flow_batch(self, source: Any, event: Any):
|
||||
@@ -301,6 +319,9 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"execution_type": "flow",
|
||||
}
|
||||
|
||||
self.batch_manager.batch_owner_type = "flow"
|
||||
self.batch_manager.batch_owner_id = getattr(source, "id", str(uuid.uuid4()))
|
||||
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _initialize_batch(
|
||||
@@ -358,12 +379,44 @@ class TraceCollectionListener(BaseEventListener):
|
||||
return {
|
||||
"task_description": event.task.description,
|
||||
"expected_output": event.task.expected_output,
|
||||
"task_name": event.task.name,
|
||||
"task_name": event.task.name or event.task.description,
|
||||
"context": event.context,
|
||||
"agent": source.agent.role,
|
||||
"agent_role": source.agent.role,
|
||||
"task_id": str(event.task.id),
|
||||
}
|
||||
elif event_type == "task_completed":
|
||||
return {
|
||||
"task_description": event.task.description if event.task else None,
|
||||
"task_name": event.task.name or event.task.description
|
||||
if event.task
|
||||
else None,
|
||||
"task_id": str(event.task.id) if event.task else None,
|
||||
"output_raw": event.output.raw if event.output else None,
|
||||
"output_format": str(event.output.output_format)
|
||||
if event.output
|
||||
else None,
|
||||
"agent_role": event.output.agent if event.output else None,
|
||||
}
|
||||
elif event_type == "agent_execution_started":
|
||||
return {
|
||||
"agent_role": event.agent.role,
|
||||
"agent_goal": event.agent.goal,
|
||||
"agent_backstory": event.agent.backstory,
|
||||
}
|
||||
elif event_type == "agent_execution_completed":
|
||||
return {
|
||||
"agent_role": event.agent.role,
|
||||
"agent_goal": event.agent.goal,
|
||||
"agent_backstory": event.agent.backstory,
|
||||
}
|
||||
elif event_type == "llm_call_started":
|
||||
return self._safe_serialize_to_dict(event)
|
||||
event_data = self._safe_serialize_to_dict(event)
|
||||
event_data["task_name"] = (
|
||||
event.task_name or event.task_description
|
||||
if hasattr(event, "task_name") and event.task_name
|
||||
else None
|
||||
)
|
||||
return event_data
|
||||
elif event_type == "llm_call_completed":
|
||||
return self._safe_serialize_to_dict(event)
|
||||
else:
|
||||
|
||||
@@ -13,26 +13,14 @@ class LLMEventBase(BaseEvent):
|
||||
agent_id: Optional[str] = None
|
||||
agent_role: Optional[str] = None
|
||||
|
||||
from_task: Optional[Any] = None
|
||||
from_agent: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
def _set_agent_params(self, data: Dict[str, Any]):
|
||||
task = data.get("from_task", None)
|
||||
agent = task.agent if task else data.get("from_agent", None)
|
||||
|
||||
if not agent:
|
||||
return
|
||||
|
||||
self.agent_id = agent.id
|
||||
self.agent_role = agent.role
|
||||
|
||||
def _set_task_params(self, data: Dict[str, Any]):
|
||||
if "from_task" in data and (task := data["from_task"]):
|
||||
self.task_id = task.id
|
||||
self.task_name = task.name
|
||||
|
||||
|
||||
class LLMCallType(Enum):
|
||||
"""Type of LLM call being made"""
|
||||
|
||||
@@ -3,7 +3,24 @@ from typing import Any, Dict, Optional
|
||||
from crewai.utilities.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class MemoryQueryStartedEvent(BaseEvent):
|
||||
class MemoryBaseEvent(BaseEvent):
|
||||
"""Base event for memory operations"""
|
||||
|
||||
type: str
|
||||
task_id: Optional[str] = None
|
||||
task_name: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
from_agent: Optional[Any] = None
|
||||
agent_role: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
|
||||
class MemoryQueryStartedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory query is started"""
|
||||
|
||||
type: str = "memory_query_started"
|
||||
@@ -12,7 +29,7 @@ class MemoryQueryStartedEvent(BaseEvent):
|
||||
score_threshold: Optional[float] = None
|
||||
|
||||
|
||||
class MemoryQueryCompletedEvent(BaseEvent):
|
||||
class MemoryQueryCompletedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory query is completed successfully"""
|
||||
|
||||
type: str = "memory_query_completed"
|
||||
@@ -23,7 +40,7 @@ class MemoryQueryCompletedEvent(BaseEvent):
|
||||
query_time_ms: float
|
||||
|
||||
|
||||
class MemoryQueryFailedEvent(BaseEvent):
|
||||
class MemoryQueryFailedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory query fails"""
|
||||
|
||||
type: str = "memory_query_failed"
|
||||
@@ -33,7 +50,7 @@ class MemoryQueryFailedEvent(BaseEvent):
|
||||
error: str
|
||||
|
||||
|
||||
class MemorySaveStartedEvent(BaseEvent):
|
||||
class MemorySaveStartedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory save operation is started"""
|
||||
|
||||
type: str = "memory_save_started"
|
||||
@@ -42,7 +59,7 @@ class MemorySaveStartedEvent(BaseEvent):
|
||||
agent_role: Optional[str] = None
|
||||
|
||||
|
||||
class MemorySaveCompletedEvent(BaseEvent):
|
||||
class MemorySaveCompletedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory save operation is completed successfully"""
|
||||
|
||||
type: str = "memory_save_completed"
|
||||
@@ -52,7 +69,7 @@ class MemorySaveCompletedEvent(BaseEvent):
|
||||
save_time_ms: float
|
||||
|
||||
|
||||
class MemorySaveFailedEvent(BaseEvent):
|
||||
class MemorySaveFailedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory save operation fails"""
|
||||
|
||||
type: str = "memory_save_failed"
|
||||
@@ -62,14 +79,14 @@ class MemorySaveFailedEvent(BaseEvent):
|
||||
error: str
|
||||
|
||||
|
||||
class MemoryRetrievalStartedEvent(BaseEvent):
|
||||
class MemoryRetrievalStartedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when memory retrieval for a task prompt starts"""
|
||||
|
||||
type: str = "memory_retrieval_started"
|
||||
task_id: Optional[str] = None
|
||||
|
||||
|
||||
class MemoryRetrievalCompletedEvent(BaseEvent):
|
||||
class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when memory retrieval for a task prompt completes successfully"""
|
||||
|
||||
type: str = "memory_retrieval_completed"
|
||||
|
||||
@@ -1,16 +1,34 @@
|
||||
from crewai.utilities.events.base_events import BaseEvent
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class AgentReasoningStartedEvent(BaseEvent):
|
||||
class ReasoningEvent(BaseEvent):
|
||||
"""Base event for reasoning events."""
|
||||
|
||||
type: str
|
||||
attempt: int = 1
|
||||
agent_role: str
|
||||
task_id: str
|
||||
task_name: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
agent_id: Optional[str] = None
|
||||
from_agent: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_task_params(data)
|
||||
self._set_agent_params(data)
|
||||
|
||||
|
||||
class AgentReasoningStartedEvent(ReasoningEvent):
|
||||
"""Event emitted when an agent starts reasoning about a task."""
|
||||
|
||||
type: str = "agent_reasoning_started"
|
||||
agent_role: str
|
||||
task_id: str
|
||||
attempt: int = 1 # The current reasoning/refinement attempt
|
||||
|
||||
|
||||
class AgentReasoningCompletedEvent(BaseEvent):
|
||||
class AgentReasoningCompletedEvent(ReasoningEvent):
|
||||
"""Event emitted when an agent finishes its reasoning process."""
|
||||
|
||||
type: str = "agent_reasoning_completed"
|
||||
@@ -18,14 +36,12 @@ class AgentReasoningCompletedEvent(BaseEvent):
|
||||
task_id: str
|
||||
plan: str
|
||||
ready: bool
|
||||
attempt: int = 1
|
||||
|
||||
|
||||
class AgentReasoningFailedEvent(BaseEvent):
|
||||
class AgentReasoningFailedEvent(ReasoningEvent):
|
||||
"""Event emitted when the reasoning process fails."""
|
||||
|
||||
type: str = "agent_reasoning_failed"
|
||||
agent_role: str
|
||||
task_id: str
|
||||
error: str
|
||||
attempt: int = 1
|
||||
@@ -9,17 +9,24 @@ class ToolUsageEvent(BaseEvent):
|
||||
|
||||
agent_key: Optional[str] = None
|
||||
agent_role: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
tool_name: str
|
||||
tool_args: Dict[str, Any] | str
|
||||
tool_class: Optional[str] = None
|
||||
run_attempts: int | None = None
|
||||
delegations: int | None = None
|
||||
agent: Optional[Any] = None
|
||||
task_name: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
from_agent: Optional[Any] = None
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
# Set fingerprint data from the agent
|
||||
if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
|
||||
@@ -220,14 +220,22 @@ class ConsoleFormatter:
|
||||
return tree
|
||||
|
||||
def create_task_branch(
|
||||
self, crew_tree: Optional[Tree], task_id: str
|
||||
self, crew_tree: Optional[Tree], task_id: str, task_name: Optional[str] = None
|
||||
) -> Optional[Tree]:
|
||||
"""Create and initialize a task branch."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
|
||||
task_content = Text()
|
||||
task_content.append(f"📋 Task: {task_id}", style="yellow bold")
|
||||
|
||||
# Display task name if available, otherwise just the ID
|
||||
if task_name:
|
||||
task_content.append("📋 Task: ", style="yellow bold")
|
||||
task_content.append(f"{task_name}", style="yellow bold")
|
||||
task_content.append(f" (ID: {task_id})", style="yellow dim")
|
||||
else:
|
||||
task_content.append(f"📋 Task: {task_id}", style="yellow bold")
|
||||
|
||||
task_content.append("\nStatus: ", style="white")
|
||||
task_content.append("Executing Task...", style="yellow dim")
|
||||
|
||||
@@ -251,6 +259,7 @@ class ConsoleFormatter:
|
||||
task_id: str,
|
||||
agent_role: str,
|
||||
status: str = "completed",
|
||||
task_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update task status in the tree."""
|
||||
if not self.verbose or crew_tree is None:
|
||||
@@ -270,8 +279,13 @@ class ConsoleFormatter:
|
||||
if str(task_id) in str(branch.label):
|
||||
# Build label without introducing stray blank lines
|
||||
task_content = Text()
|
||||
# First line: Task ID
|
||||
task_content.append(f"📋 Task: {task_id}", style=f"{style} bold")
|
||||
# First line: Task ID/name
|
||||
if task_name:
|
||||
task_content.append("📋 Task: ", style=f"{style} bold")
|
||||
task_content.append(f"{task_name}", style=f"{style} bold")
|
||||
task_content.append(f" (ID: {task_id})", style=f"{style} dim")
|
||||
else:
|
||||
task_content.append(f"📋 Task: {task_id}", style=f"{style} bold")
|
||||
|
||||
# Second line: Assigned to
|
||||
task_content.append("\nAssigned to: ", style="white")
|
||||
@@ -285,8 +299,9 @@ class ConsoleFormatter:
|
||||
break
|
||||
|
||||
# Show status panel
|
||||
display_name = task_name if task_name else str(task_id)
|
||||
content = self.create_status_content(
|
||||
f"Task {status.title()}", str(task_id), style, Agent=agent_role
|
||||
f"Task {status.title()}", display_name, style, Agent=agent_role
|
||||
)
|
||||
self.print_panel(content, panel_title, style)
|
||||
|
||||
|
||||
32
src/crewai/utilities/import_utils.py
Normal file
32
src/crewai/utilities/import_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Import utilities for optional dependencies."""
|
||||
|
||||
import importlib
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
class OptionalDependencyError(ImportError):
|
||||
"""Exception raised when an optional dependency is not installed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def require(name: str, *, purpose: str) -> ModuleType:
|
||||
"""Import a module, raising a helpful error if it's not installed.
|
||||
|
||||
Args:
|
||||
name: The module name to import.
|
||||
purpose: Description of what requires this dependency.
|
||||
|
||||
Returns:
|
||||
The imported module.
|
||||
|
||||
Raises:
|
||||
OptionalDependencyError: If the module is not installed.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(name)
|
||||
except ImportError as exc:
|
||||
raise OptionalDependencyError(
|
||||
f"{purpose} requires the optional dependency '{name}'.\n"
|
||||
f"Install it with: uv add {name}"
|
||||
) from exc
|
||||
@@ -18,17 +18,20 @@ from crewai.utilities.events.reasoning_events import (
|
||||
|
||||
class ReasoningPlan(BaseModel):
|
||||
"""Model representing a reasoning plan for a task."""
|
||||
|
||||
plan: str = Field(description="The detailed reasoning plan for the task.")
|
||||
ready: bool = Field(description="Whether the agent is ready to execute the task.")
|
||||
|
||||
|
||||
class AgentReasoningOutput(BaseModel):
|
||||
"""Model representing the output of the agent reasoning process."""
|
||||
|
||||
plan: ReasoningPlan = Field(description="The reasoning plan for the task.")
|
||||
|
||||
|
||||
class ReasoningFunction(BaseModel):
|
||||
"""Model for function calling with reasoning."""
|
||||
|
||||
plan: str = Field(description="The detailed reasoning plan for the task.")
|
||||
ready: bool = Field(description="Whether the agent is ready to execute the task.")
|
||||
|
||||
@@ -38,6 +41,7 @@ class AgentReasoning:
|
||||
Handles the agent reasoning process, enabling an agent to reflect and create a plan
|
||||
before executing a task.
|
||||
"""
|
||||
|
||||
def __init__(self, task: Task, agent: Agent):
|
||||
if not task or not agent:
|
||||
raise ValueError("Both task and agent must be provided.")
|
||||
@@ -63,6 +67,7 @@ class AgentReasoning:
|
||||
agent_role=self.agent.role,
|
||||
task_id=str(self.task.id),
|
||||
attempt=1,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
@@ -82,6 +87,7 @@ class AgentReasoning:
|
||||
plan=output.plan.plan,
|
||||
ready=output.plan.ready,
|
||||
attempt=1,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
@@ -98,6 +104,7 @@ class AgentReasoning:
|
||||
task_id=str(self.task.id),
|
||||
error=str(e),
|
||||
attempt=1,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
@@ -135,14 +142,16 @@ class AgentReasoning:
|
||||
system_prompt = self.i18n.retrieve("reasoning", "initial_plan").format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
backstory=self.__get_agent_backstory()
|
||||
backstory=self.__get_agent_backstory(),
|
||||
)
|
||||
|
||||
response = self.llm.call(
|
||||
[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": reasoning_prompt}
|
||||
]
|
||||
{"role": "user", "content": reasoning_prompt},
|
||||
],
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
)
|
||||
|
||||
return self.__parse_reasoning_response(str(response))
|
||||
@@ -170,6 +179,7 @@ class AgentReasoning:
|
||||
agent_role=self.agent.role,
|
||||
task_id=str(self.task.id),
|
||||
attempt=attempt + 1,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
@@ -183,14 +193,16 @@ class AgentReasoning:
|
||||
system_prompt = self.i18n.retrieve("reasoning", "refine_plan").format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
backstory=self.__get_agent_backstory()
|
||||
backstory=self.__get_agent_backstory(),
|
||||
)
|
||||
|
||||
response = self.llm.call(
|
||||
[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": refine_prompt}
|
||||
]
|
||||
{"role": "user", "content": refine_prompt},
|
||||
],
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
)
|
||||
plan, ready = self.__parse_reasoning_response(str(response))
|
||||
|
||||
@@ -227,23 +239,23 @@ class AgentReasoning:
|
||||
"properties": {
|
||||
"plan": {
|
||||
"type": "string",
|
||||
"description": "The detailed reasoning plan for the task."
|
||||
"description": "The detailed reasoning plan for the task.",
|
||||
},
|
||||
"ready": {
|
||||
"type": "boolean",
|
||||
"description": "Whether the agent is ready to execute the task."
|
||||
}
|
||||
"description": "Whether the agent is ready to execute the task.",
|
||||
},
|
||||
},
|
||||
"required": ["plan", "ready"]
|
||||
}
|
||||
}
|
||||
"required": ["plan", "ready"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
system_prompt = self.i18n.retrieve("reasoning", prompt_type).format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
backstory=self.__get_agent_backstory()
|
||||
backstory=self.__get_agent_backstory(),
|
||||
)
|
||||
|
||||
# Prepare a simple callable that just returns the tool arguments as JSON
|
||||
@@ -254,10 +266,12 @@ class AgentReasoning:
|
||||
response = self.llm.call(
|
||||
[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
tools=[function_schema],
|
||||
available_functions={"create_reasoning_plan": _create_reasoning_plan},
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
)
|
||||
|
||||
self.logger.debug(f"Function calling response: {response[:100]}...")
|
||||
@@ -270,30 +284,43 @@ class AgentReasoning:
|
||||
pass
|
||||
|
||||
response_str = str(response)
|
||||
return response_str, "READY: I am ready to execute the task." in response_str
|
||||
return (
|
||||
response_str,
|
||||
"READY: I am ready to execute the task." in response_str,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error during function calling: {str(e)}. Falling back to text parsing.")
|
||||
self.logger.warning(
|
||||
f"Error during function calling: {str(e)}. Falling back to text parsing."
|
||||
)
|
||||
|
||||
try:
|
||||
system_prompt = self.i18n.retrieve("reasoning", prompt_type).format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
backstory=self.__get_agent_backstory()
|
||||
backstory=self.__get_agent_backstory(),
|
||||
)
|
||||
|
||||
fallback_response = self.llm.call(
|
||||
[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
)
|
||||
|
||||
fallback_str = str(fallback_response)
|
||||
return fallback_str, "READY: I am ready to execute the task." in fallback_str
|
||||
return (
|
||||
fallback_str,
|
||||
"READY: I am ready to execute the task." in fallback_str,
|
||||
)
|
||||
except Exception as inner_e:
|
||||
self.logger.error(f"Error during fallback text parsing: {str(inner_e)}")
|
||||
return "Failed to generate a plan due to an error.", True # Default to ready to avoid getting stuck
|
||||
return (
|
||||
"Failed to generate a plan due to an error.",
|
||||
True,
|
||||
) # Default to ready to avoid getting stuck
|
||||
|
||||
def __get_agent_backstory(self) -> str:
|
||||
"""
|
||||
@@ -319,7 +346,7 @@ class AgentReasoning:
|
||||
backstory=self.__get_agent_backstory(),
|
||||
description=self.task.description,
|
||||
expected_output=self.task.expected_output,
|
||||
tools=available_tools
|
||||
tools=available_tools,
|
||||
)
|
||||
|
||||
def __format_available_tools(self) -> str:
|
||||
@@ -330,7 +357,7 @@ class AgentReasoning:
|
||||
str: Comma-separated list of tool names.
|
||||
"""
|
||||
try:
|
||||
return ', '.join([tool.name for tool in (self.task.tools or [])])
|
||||
return ", ".join([tool.name for tool in (self.task.tools or [])])
|
||||
except (AttributeError, TypeError):
|
||||
return "No tools available"
|
||||
|
||||
@@ -348,7 +375,7 @@ class AgentReasoning:
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
backstory=self.__get_agent_backstory(),
|
||||
current_plan=current_plan
|
||||
current_plan=current_plan,
|
||||
)
|
||||
|
||||
def __parse_reasoning_response(self, response: str) -> Tuple[str, bool]:
|
||||
|
||||
@@ -23,6 +23,7 @@ from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.process import Process
|
||||
|
||||
|
||||
def test_agent_llm_creation_with_env_vars():
|
||||
# Store original environment variables
|
||||
original_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
@@ -235,7 +236,7 @@ def test_logging_tool_usage():
|
||||
)
|
||||
|
||||
assert agent.llm.model == "gpt-4o-mini"
|
||||
assert agent.tools_handler.last_used_tool == {}
|
||||
assert agent.tools_handler.last_used_tool is None
|
||||
task = Task(
|
||||
description="What is 3 times 4?",
|
||||
agent=agent,
|
||||
@@ -593,42 +594,17 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = (
|
||||
captured.out.replace("\n", " ")
|
||||
.replace(" ", " ")
|
||||
.strip()
|
||||
.replace("╭", "")
|
||||
.replace("╮", "")
|
||||
.replace("╯", "")
|
||||
.replace("╰", "")
|
||||
.replace("│", "")
|
||||
.replace("─", "")
|
||||
.replace("[", "")
|
||||
.replace("]", "")
|
||||
.replace("bold", "")
|
||||
.replace("blue", "")
|
||||
.replace("yellow", "")
|
||||
.replace("green", "")
|
||||
.replace("red", "")
|
||||
.replace("dim", "")
|
||||
.replace("🤖", "")
|
||||
.replace("🔧", "")
|
||||
.replace("✅", "")
|
||||
.replace("\x1b[93m", "")
|
||||
.replace("\x1b[00m", "")
|
||||
.replace("\\", "")
|
||||
.replace('"', "")
|
||||
.replace("'", "")
|
||||
)
|
||||
|
||||
# Look for the message in the normalized output, handling the apostrophe difference
|
||||
expected_message = (
|
||||
"I tried reusing the same input, I must stop using this action input"
|
||||
)
|
||||
# More flexible check, look for either the repeated usage message or verification that max iterations was reached
|
||||
output_lower = captured.out.lower()
|
||||
|
||||
has_repeated_usage_message = "tried reusing the same input" in output_lower
|
||||
has_max_iterations = "maximum iterations reached" in output_lower
|
||||
has_final_answer = "final answer" in output_lower or "42" in captured.out
|
||||
|
||||
assert (
|
||||
expected_message in output
|
||||
), f"Expected message not found in output. Output was: {output}"
|
||||
has_repeated_usage_message or (has_max_iterations and has_final_answer)
|
||||
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -783,10 +759,10 @@ def test_agent_without_max_rpm_respects_crew_rpm(capsys):
|
||||
|
||||
with patch.object(RPMController, "_wait_for_next_minute") as moveon:
|
||||
moveon.return_value = True
|
||||
crew.kickoff()
|
||||
captured = capsys.readouterr()
|
||||
assert "get_final_answer" in captured.out
|
||||
assert "Max RPM reached, waiting for next minute to start." in captured.out
|
||||
result = crew.kickoff()
|
||||
# Verify the crew executed and RPM limit was triggered
|
||||
assert result is not None
|
||||
assert moveon.called
|
||||
moveon.assert_called_once()
|
||||
|
||||
|
||||
@@ -1213,17 +1189,13 @@ Thought:<|eot_id|>
|
||||
def test_task_allow_crewai_trigger_context():
|
||||
from crewai import Crew
|
||||
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory"
|
||||
)
|
||||
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
|
||||
|
||||
task = Task(
|
||||
description="Analyze the data",
|
||||
expected_output="Analysis report",
|
||||
agent=agent,
|
||||
allow_crewai_trigger_context=True
|
||||
allow_crewai_trigger_context=True,
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff({"crewai_trigger_payload": "Important context data"})
|
||||
@@ -1238,17 +1210,13 @@ def test_task_allow_crewai_trigger_context():
|
||||
def test_task_without_allow_crewai_trigger_context():
|
||||
from crewai import Crew
|
||||
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory"
|
||||
)
|
||||
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
|
||||
|
||||
task = Task(
|
||||
description="Analyze the data",
|
||||
expected_output="Analysis report",
|
||||
agent=agent,
|
||||
allow_crewai_trigger_context=False
|
||||
allow_crewai_trigger_context=False,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
@@ -1265,23 +1233,18 @@ def test_task_without_allow_crewai_trigger_context():
|
||||
def test_task_allow_crewai_trigger_context_no_payload():
|
||||
from crewai import Crew
|
||||
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory"
|
||||
)
|
||||
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
|
||||
|
||||
task = Task(
|
||||
description="Analyze the data",
|
||||
expected_output="Analysis report",
|
||||
agent=agent,
|
||||
allow_crewai_trigger_context=True
|
||||
allow_crewai_trigger_context=True,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff({"other_input": "other data"})
|
||||
|
||||
|
||||
prompt = task.prompt()
|
||||
|
||||
assert "Analyze the data" in prompt
|
||||
@@ -1293,7 +1256,9 @@ def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
|
||||
from crewai import Crew
|
||||
|
||||
agent1 = Agent(role="First Agent", goal="First goal", backstory="First backstory")
|
||||
agent2 = Agent(role="Second Agent", goal="Second goal", backstory="Second backstory")
|
||||
agent2 = Agent(
|
||||
role="Second Agent", goal="Second goal", backstory="Second backstory"
|
||||
)
|
||||
|
||||
first_task = Task(
|
||||
description="Process initial data",
|
||||
@@ -1301,12 +1266,11 @@ def test_do_not_allow_crewai_trigger_context_for_first_task_hierarchical():
|
||||
agent=agent1,
|
||||
)
|
||||
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent1, agent2],
|
||||
tasks=[first_task],
|
||||
process=Process.hierarchical,
|
||||
manager_llm="gpt-4o"
|
||||
manager_llm="gpt-4o",
|
||||
)
|
||||
|
||||
crew.kickoff({"crewai_trigger_payload": "Initial context data"})
|
||||
@@ -1321,7 +1285,9 @@ def test_first_task_auto_inject_trigger():
|
||||
from crewai import Crew
|
||||
|
||||
agent1 = Agent(role="First Agent", goal="First goal", backstory="First backstory")
|
||||
agent2 = Agent(role="Second Agent", goal="Second goal", backstory="Second backstory")
|
||||
agent2 = Agent(
|
||||
role="Second Agent", goal="Second goal", backstory="Second backstory"
|
||||
)
|
||||
|
||||
first_task = Task(
|
||||
description="Process initial data",
|
||||
@@ -1335,10 +1301,7 @@ def test_first_task_auto_inject_trigger():
|
||||
agent=agent2,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent1, agent2],
|
||||
tasks=[first_task, second_task]
|
||||
)
|
||||
crew = Crew(agents=[agent1, agent2], tasks=[first_task, second_task])
|
||||
crew.kickoff({"crewai_trigger_payload": "Initial context data"})
|
||||
|
||||
first_prompt = first_task.prompt()
|
||||
@@ -1349,31 +1312,31 @@ def test_first_task_auto_inject_trigger():
|
||||
assert "Process secondary data" in second_prompt
|
||||
assert "Trigger Payload:" not in second_prompt
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_ensure_first_task_allow_crewai_trigger_context_is_false_does_not_inject():
|
||||
from crewai import Crew
|
||||
|
||||
agent1 = Agent(role="First Agent", goal="First goal", backstory="First backstory")
|
||||
agent2 = Agent(role="Second Agent", goal="Second goal", backstory="Second backstory")
|
||||
agent2 = Agent(
|
||||
role="Second Agent", goal="Second goal", backstory="Second backstory"
|
||||
)
|
||||
|
||||
first_task = Task(
|
||||
description="Process initial data",
|
||||
expected_output="Initial analysis",
|
||||
agent=agent1,
|
||||
allow_crewai_trigger_context=False
|
||||
allow_crewai_trigger_context=False,
|
||||
)
|
||||
|
||||
second_task = Task(
|
||||
description="Process secondary data",
|
||||
expected_output="Secondary analysis",
|
||||
agent=agent2,
|
||||
allow_crewai_trigger_context=True
|
||||
allow_crewai_trigger_context=True,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent1, agent2],
|
||||
tasks=[first_task, second_task]
|
||||
)
|
||||
crew = Crew(agents=[agent1, agent2], tasks=[first_task, second_task])
|
||||
crew.kickoff({"crewai_trigger_payload": "Context data"})
|
||||
|
||||
first_prompt = first_task.prompt()
|
||||
@@ -1383,7 +1346,6 @@ def test_ensure_first_task_allow_crewai_trigger_context_is_false_does_not_inject
|
||||
assert "Trigger Payload: Context data" in second_prompt
|
||||
|
||||
|
||||
|
||||
@patch("crewai.agent.CrewTrainingHandler")
|
||||
def test_agent_training_handler(crew_training_handler):
|
||||
task_prompt = "What is 1 + 1?"
|
||||
@@ -2347,12 +2309,13 @@ def mock_get_auth_token():
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
from crewai_tools import (
|
||||
SerperDevTool,
|
||||
XMLSearchTool,
|
||||
CSVSearchTool,
|
||||
EnterpriseActionTool,
|
||||
)
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
from crewai_tools import (
|
||||
SerperDevTool,
|
||||
FileReadTool,
|
||||
EnterpriseActionTool,
|
||||
)
|
||||
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
@@ -2368,10 +2331,9 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
},
|
||||
{
|
||||
"module": "crewai_tools",
|
||||
"name": "XMLSearchTool",
|
||||
"init_params": {"summarize": "true"},
|
||||
"name": "FileReadTool",
|
||||
"init_params": {"file_path": "test.txt"},
|
||||
},
|
||||
{"module": "crewai_tools", "name": "CSVSearchTool", "init_params": {}},
|
||||
# using a tools that returns a list of BaseTools
|
||||
{
|
||||
"module": "crewai_tools",
|
||||
@@ -2396,23 +2358,22 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
assert agent.role == "test role"
|
||||
assert agent.goal == "test goal"
|
||||
assert agent.backstory == "test backstory"
|
||||
assert len(agent.tools) == 4
|
||||
assert len(agent.tools) == 3
|
||||
|
||||
assert isinstance(agent.tools[0], SerperDevTool)
|
||||
assert agent.tools[0].n_results == 30
|
||||
assert isinstance(agent.tools[1], XMLSearchTool)
|
||||
assert agent.tools[1].summarize
|
||||
assert isinstance(agent.tools[1], FileReadTool)
|
||||
assert agent.tools[1].file_path == "test.txt"
|
||||
|
||||
assert isinstance(agent.tools[2], CSVSearchTool)
|
||||
assert not agent.tools[2].summarize
|
||||
|
||||
assert isinstance(agent.tools[3], EnterpriseActionTool)
|
||||
assert agent.tools[3].name == "test_name"
|
||||
assert isinstance(agent.tools[2], EnterpriseActionTool)
|
||||
assert agent.tools[2].name == "test_name"
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||
def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token):
|
||||
from crewai_tools import SerperDevTool
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
from crewai_tools import SerperDevTool
|
||||
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
@@ -7,37 +7,37 @@ from crewai.task import Task
|
||||
|
||||
def test_agent_inject_date():
|
||||
"""Test that the inject_date flag injects the current date into the task.
|
||||
|
||||
|
||||
Tests that when inject_date=True, the current date is added to the task description.
|
||||
"""
|
||||
with patch('datetime.datetime') as mock_datetime:
|
||||
with patch("datetime.datetime") as mock_datetime:
|
||||
mock_datetime.now.return_value = datetime(2025, 1, 1)
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="test_agent",
|
||||
goal="test_goal",
|
||||
backstory="test_backstory",
|
||||
inject_date=True,
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
# Store original description
|
||||
original_description = task.description
|
||||
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
|
||||
|
||||
assert "Current Date: 2025-01-01" in task.description
|
||||
assert task.description != original_description
|
||||
|
||||
|
||||
def test_agent_without_inject_date():
|
||||
"""Test that without inject_date flag, no date is injected.
|
||||
|
||||
|
||||
Tests that when inject_date=False (default), no date is added to the task description.
|
||||
"""
|
||||
agent = Agent(
|
||||
@@ -46,28 +46,28 @@ def test_agent_without_inject_date():
|
||||
backstory="test_backstory",
|
||||
# inject_date is False by default
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
original_description = task.description
|
||||
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
|
||||
|
||||
assert task.description == original_description
|
||||
|
||||
|
||||
def test_agent_inject_date_custom_format():
|
||||
"""Test that the inject_date flag with custom date_format works correctly.
|
||||
|
||||
|
||||
Tests that when inject_date=True with a custom date_format, the date is formatted correctly.
|
||||
"""
|
||||
with patch('datetime.datetime') as mock_datetime:
|
||||
with patch("datetime.datetime") as mock_datetime:
|
||||
mock_datetime.now.return_value = datetime(2025, 1, 1)
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="test_agent",
|
||||
goal="test_goal",
|
||||
@@ -75,25 +75,25 @@ def test_agent_inject_date_custom_format():
|
||||
inject_date=True,
|
||||
date_format="%d/%m/%Y",
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
# Store original description
|
||||
original_description = task.description
|
||||
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
|
||||
|
||||
assert "Current Date: 01/01/2025" in task.description
|
||||
assert task.description != original_description
|
||||
|
||||
|
||||
def test_agent_inject_date_invalid_format():
|
||||
"""Test error handling with invalid date format.
|
||||
|
||||
|
||||
Tests that when an invalid date_format is provided, the task description remains unchanged.
|
||||
"""
|
||||
agent = Agent(
|
||||
@@ -103,15 +103,15 @@ def test_agent_inject_date_invalid_format():
|
||||
inject_date=True,
|
||||
date_format="invalid",
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
original_description = task.description
|
||||
|
||||
|
||||
agent._inject_date_to_task(task)
|
||||
|
||||
|
||||
assert task.description == original_description
|
||||
@@ -15,37 +15,37 @@ def mock_llm_responses():
|
||||
"ready": "I'll solve this simple math problem.\n\nREADY: I am ready to execute the task.\n\n",
|
||||
"not_ready": "I need to think about derivatives.\n\nNOT READY: I need to refine my plan because I'm not sure about the derivative rules.",
|
||||
"ready_after_refine": "I'll use the power rule for derivatives where d/dx(x^n) = n*x^(n-1).\n\nREADY: I am ready to execute the task.",
|
||||
"execution": "4"
|
||||
"execution": "4",
|
||||
}
|
||||
|
||||
|
||||
def test_agent_with_reasoning(mock_llm_responses):
|
||||
"""Test agent with reasoning."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
verbose=True
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Simple math task: What's 2+2?",
|
||||
expected_output="The answer should be a number.",
|
||||
agent=agent
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
agent.llm.call = lambda messages, *args, **kwargs: (
|
||||
mock_llm_responses["ready"]
|
||||
if any("create a detailed plan" in msg.get("content", "") for msg in messages)
|
||||
else mock_llm_responses["execution"]
|
||||
)
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == mock_llm_responses["execution"]
|
||||
assert "Reasoning Plan:" in task.description
|
||||
|
||||
@@ -53,7 +53,7 @@ def test_agent_with_reasoning(mock_llm_responses):
|
||||
def test_agent_with_reasoning_not_ready_initially(mock_llm_responses):
|
||||
"""Test agent with reasoning that requires refinement."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
@@ -61,19 +61,21 @@ def test_agent_with_reasoning_not_ready_initially(mock_llm_responses):
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
max_reasoning_attempts=2,
|
||||
verbose=True
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Complex math task: What's the derivative of x²?",
|
||||
expected_output="The answer should be a mathematical expression.",
|
||||
agent=agent
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_llm_call(messages, *args, **kwargs):
|
||||
if any("create a detailed plan" in msg.get("content", "") for msg in messages) or any("refine your plan" in msg.get("content", "") for msg in messages):
|
||||
if any(
|
||||
"create a detailed plan" in msg.get("content", "") for msg in messages
|
||||
) or any("refine your plan" in msg.get("content", "") for msg in messages):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return mock_llm_responses["not_ready"]
|
||||
@@ -81,11 +83,11 @@ def test_agent_with_reasoning_not_ready_initially(mock_llm_responses):
|
||||
return mock_llm_responses["ready_after_refine"]
|
||||
else:
|
||||
return "2x"
|
||||
|
||||
|
||||
agent.llm.call = mock_llm_call
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "2x"
|
||||
assert call_count[0] == 2 # Should have made 2 reasoning calls
|
||||
assert "Reasoning Plan:" in task.description
|
||||
@@ -94,7 +96,7 @@ def test_agent_with_reasoning_not_ready_initially(mock_llm_responses):
|
||||
def test_agent_with_reasoning_max_attempts_reached():
|
||||
"""Test agent with reasoning that reaches max attempts without being ready."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
@@ -102,52 +104,53 @@ def test_agent_with_reasoning_max_attempts_reached():
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
max_reasoning_attempts=2,
|
||||
verbose=True
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Complex math task: Solve the Riemann hypothesis.",
|
||||
expected_output="A proof or disproof of the hypothesis.",
|
||||
agent=agent
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_llm_call(messages, *args, **kwargs):
|
||||
if any("create a detailed plan" in msg.get("content", "") for msg in messages) or any("refine your plan" in msg.get("content", "") for msg in messages):
|
||||
if any(
|
||||
"create a detailed plan" in msg.get("content", "") for msg in messages
|
||||
) or any("refine your plan" in msg.get("content", "") for msg in messages):
|
||||
call_count[0] += 1
|
||||
return f"Attempt {call_count[0]}: I need more time to think.\n\nNOT READY: I need to refine my plan further."
|
||||
else:
|
||||
return "This is an unsolved problem in mathematics."
|
||||
|
||||
|
||||
agent.llm.call = mock_llm_call
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "This is an unsolved problem in mathematics."
|
||||
assert call_count[0] == 2 # Should have made exactly 2 reasoning calls (max_attempts)
|
||||
assert (
|
||||
call_count[0] == 2
|
||||
) # Should have made exactly 2 reasoning calls (max_attempts)
|
||||
assert "Reasoning Plan:" in task.description
|
||||
|
||||
|
||||
def test_agent_reasoning_input_validation():
|
||||
"""Test input validation in AgentReasoning."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True
|
||||
reasoning=True,
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(ValueError, match="Both task and agent must be provided"):
|
||||
AgentReasoning(task=None, agent=agent)
|
||||
|
||||
task = Task(
|
||||
description="Simple task",
|
||||
expected_output="Simple output"
|
||||
)
|
||||
|
||||
task = Task(description="Simple task", expected_output="Simple output")
|
||||
with pytest.raises(ValueError, match="Both task and agent must be provided"):
|
||||
AgentReasoning(task=task, agent=None)
|
||||
|
||||
@@ -155,33 +158,33 @@ def test_agent_reasoning_input_validation():
|
||||
def test_agent_reasoning_error_handling():
|
||||
"""Test error handling during the reasoning process."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True
|
||||
reasoning=True,
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Task that will cause an error",
|
||||
expected_output="Output that will never be generated",
|
||||
agent=agent
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
call_count = [0]
|
||||
|
||||
|
||||
def mock_llm_call_error(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 2: # First calls are for reasoning
|
||||
raise Exception("LLM error during reasoning")
|
||||
return "Fallback execution result" # Return a value for task execution
|
||||
|
||||
|
||||
agent.llm.call = mock_llm_call_error
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "Fallback execution result"
|
||||
assert call_count[0] > 2 # Ensure we called the mock multiple times
|
||||
|
||||
@@ -189,37 +192,36 @@ def test_agent_reasoning_error_handling():
|
||||
def test_agent_with_function_calling():
|
||||
"""Test agent with reasoning using function calling."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
verbose=True
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Simple math task: What's 2+2?",
|
||||
expected_output="The answer should be a number.",
|
||||
agent=agent
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
agent.llm.supports_function_calling = lambda: True
|
||||
|
||||
|
||||
def mock_function_call(messages, *args, **kwargs):
|
||||
if "tools" in kwargs:
|
||||
return json.dumps({
|
||||
"plan": "I'll solve this simple math problem: 2+2=4.",
|
||||
"ready": True
|
||||
})
|
||||
return json.dumps(
|
||||
{"plan": "I'll solve this simple math problem: 2+2=4.", "ready": True}
|
||||
)
|
||||
else:
|
||||
return "4"
|
||||
|
||||
|
||||
agent.llm.call = mock_function_call
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "4"
|
||||
assert "Reasoning Plan:" in task.description
|
||||
assert "I'll solve this simple math problem: 2+2=4." in task.description
|
||||
@@ -228,34 +230,34 @@ def test_agent_with_function_calling():
|
||||
def test_agent_with_function_calling_fallback():
|
||||
"""Test agent with reasoning using function calling that falls back to text parsing."""
|
||||
llm = LLM("gpt-3.5-turbo")
|
||||
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="To test the reasoning feature",
|
||||
backstory="I am a test agent created to verify the reasoning feature works correctly.",
|
||||
llm=llm,
|
||||
reasoning=True,
|
||||
verbose=True
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
task = Task(
|
||||
description="Simple math task: What's 2+2?",
|
||||
expected_output="The answer should be a number.",
|
||||
agent=agent
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
|
||||
agent.llm.supports_function_calling = lambda: True
|
||||
|
||||
|
||||
def mock_function_call(messages, *args, **kwargs):
|
||||
if "tools" in kwargs:
|
||||
return "Invalid JSON that will trigger fallback. READY: I am ready to execute the task."
|
||||
else:
|
||||
return "4"
|
||||
|
||||
|
||||
agent.llm.call = mock_function_call
|
||||
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
|
||||
assert result == "4"
|
||||
assert "Reasoning Plan:" in task.description
|
||||
assert "Invalid JSON that will trigger fallback" in task.description
|
||||
@@ -318,11 +318,17 @@ def test_sets_parent_flow_when_inside_flow():
|
||||
flow.kickoff()
|
||||
assert captured_agent.parent_flow is flow
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_guardrail_is_called_using_string():
|
||||
guardrail_events = defaultdict(list)
|
||||
from crewai.utilities.events import LLMGuardrailCompletedEvent, LLMGuardrailStartedEvent
|
||||
from crewai.utilities.events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
@@ -340,17 +346,26 @@ def test_guardrail_is_called_using_string():
|
||||
|
||||
result = agent.kickoff(messages="Top 10 best players in the world?")
|
||||
|
||||
assert len(guardrail_events['started']) == 2
|
||||
assert len(guardrail_events['completed']) == 2
|
||||
assert not guardrail_events['completed'][0].success
|
||||
assert guardrail_events['completed'][1].success
|
||||
assert "Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players" in result.raw
|
||||
assert len(guardrail_events["started"]) == 2
|
||||
assert len(guardrail_events["completed"]) == 2
|
||||
assert not guardrail_events["completed"][0].success
|
||||
assert guardrail_events["completed"][1].success
|
||||
assert (
|
||||
"Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players"
|
||||
in result.raw
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_guardrail_is_called_using_callable():
|
||||
guardrail_events = defaultdict(list)
|
||||
from crewai.utilities.events import LLMGuardrailCompletedEvent, LLMGuardrailStartedEvent
|
||||
from crewai.utilities.events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
@@ -368,16 +383,22 @@ def test_guardrail_is_called_using_callable():
|
||||
|
||||
result = agent.kickoff(messages="Top 1 best players in the world?")
|
||||
|
||||
assert len(guardrail_events['started']) == 1
|
||||
assert len(guardrail_events['completed']) == 1
|
||||
assert guardrail_events['completed'][0].success
|
||||
assert len(guardrail_events["started"]) == 1
|
||||
assert len(guardrail_events["completed"]) == 1
|
||||
assert guardrail_events["completed"][0].success
|
||||
assert "Pelé - Santos, 1958" in result.raw
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_guardrail_reached_attempt_limit():
|
||||
guardrail_events = defaultdict(list)
|
||||
from crewai.utilities.events import LLMGuardrailCompletedEvent, LLMGuardrailStartedEvent
|
||||
from crewai.utilities.events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
@@ -390,18 +411,23 @@ def test_guardrail_reached_attempt_limit():
|
||||
role="Sports Analyst",
|
||||
goal="Gather information about the best soccer players",
|
||||
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
|
||||
guardrail=lambda output: (False, "You are not allowed to include Brazilian players"),
|
||||
guardrail=lambda output: (
|
||||
False,
|
||||
"You are not allowed to include Brazilian players",
|
||||
),
|
||||
guardrail_max_retries=2,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Agent's guardrail failed validation after 2 retries"):
|
||||
with pytest.raises(
|
||||
Exception, match="Agent's guardrail failed validation after 2 retries"
|
||||
):
|
||||
agent.kickoff(messages="Top 10 best players in the world?")
|
||||
|
||||
assert len(guardrail_events['started']) == 3 # 2 retries + 1 initial call
|
||||
assert len(guardrail_events['completed']) == 3 # 2 retries + 1 initial call
|
||||
assert not guardrail_events['completed'][0].success
|
||||
assert not guardrail_events['completed'][1].success
|
||||
assert not guardrail_events['completed'][2].success
|
||||
assert len(guardrail_events["started"]) == 3 # 2 retries + 1 initial call
|
||||
assert len(guardrail_events["completed"]) == 3 # 2 retries + 1 initial call
|
||||
assert not guardrail_events["completed"][0].success
|
||||
assert not guardrail_events["completed"][1].success
|
||||
assert not guardrail_events["completed"][2].success
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -414,22 +440,35 @@ def test_agent_output_when_guardrail_returns_base_model():
|
||||
role="Sports Analyst",
|
||||
goal="Gather information about the best soccer players",
|
||||
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
|
||||
guardrail=lambda output: (True, Player(name="Lionel Messi", country="Argentina")),
|
||||
guardrail=lambda output: (
|
||||
True,
|
||||
Player(name="Lionel Messi", country="Argentina"),
|
||||
),
|
||||
)
|
||||
|
||||
result = agent.kickoff(messages="Top 10 best players in the world?")
|
||||
|
||||
assert result.pydantic == Player(name="Lionel Messi", country="Argentina")
|
||||
|
||||
|
||||
def test_lite_agent_with_custom_llm_and_guardrails():
|
||||
"""Test that CustomLLM (inheriting from BaseLLM) works with guardrails."""
|
||||
|
||||
class CustomLLM(BaseLLM):
|
||||
def __init__(self, response: str = "Custom response"):
|
||||
super().__init__(model="custom-model")
|
||||
self.response = response
|
||||
self.call_count = 0
|
||||
|
||||
def call(self, messages, tools=None, callbacks=None, available_functions=None, from_task=None, from_agent=None) -> str:
|
||||
def call(
|
||||
self,
|
||||
messages,
|
||||
tools=None,
|
||||
callbacks=None,
|
||||
available_functions=None,
|
||||
from_task=None,
|
||||
from_agent=None,
|
||||
) -> str:
|
||||
self.call_count += 1
|
||||
|
||||
if "valid" in str(messages) and "feedback" in str(messages):
|
||||
@@ -456,7 +495,7 @@ def test_lite_agent_with_custom_llm_and_guardrails():
|
||||
goal="Analyze soccer players",
|
||||
backstory="You analyze soccer players and their performance.",
|
||||
llm=custom_llm,
|
||||
guardrail="Only include Brazilian players"
|
||||
guardrail="Only include Brazilian players",
|
||||
)
|
||||
|
||||
result = agent.kickoff("Tell me about the best soccer players")
|
||||
@@ -474,7 +513,7 @@ def test_lite_agent_with_custom_llm_and_guardrails():
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm=custom_llm2,
|
||||
guardrail=test_guardrail
|
||||
guardrail=test_guardrail,
|
||||
)
|
||||
|
||||
result2 = agent2.kickoff("Test message")
|
||||
@@ -484,12 +523,12 @@ def test_lite_agent_with_custom_llm_and_guardrails():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_lite_agent_with_invalid_llm():
|
||||
"""Test that LiteAgent raises proper error when create_llm returns None."""
|
||||
with patch('crewai.lite_agent.create_llm', return_value=None):
|
||||
with patch("crewai.lite_agent.create_llm", return_value=None):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
LiteAgent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm="invalid-model"
|
||||
llm="invalid-model",
|
||||
)
|
||||
assert "Expected LLM instance of type BaseLLM" in str(exc_info.value)
|
||||
assert "Expected LLM instance of type BaseLLM" in str(exc_info.value)
|
||||
@@ -508,4 +508,162 @@ interactions:
|
||||
- req_80eed127ea0361c637657470cf9b647e
|
||||
http_version: HTTP/1.1
|
||||
status_code: 200
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Crew Manager. You
|
||||
are a seasoned manager with a knack for getting the best out of your team.\nYou
|
||||
are also known for your ability to delegate work to the right people, and to
|
||||
ask the right questions to get the best out of your team.\nEven though you don''t
|
||||
perform tasks by yourself, you have a lot of experience in the field, which
|
||||
allows you to properly evaluate the work of your team members.\nYour personal
|
||||
goal is: Manage the team to complete the task in the best way possible.\nYou
|
||||
ONLY have access to the following tools, and should NEVER make up tools that
|
||||
are not listed here:\n\nTool Name: Delegate work to coworker\nTool Arguments:
|
||||
{''task'': {''description'': ''The task to delegate'', ''type'': ''str''}, ''context'':
|
||||
{''description'': ''The context for the task'', ''type'': ''str''}, ''coworker'':
|
||||
{''description'': ''The role/name of the coworker to delegate to'', ''type'':
|
||||
''str''}}\nTool Description: Delegate a specific task to one of the following
|
||||
coworkers: Researcher\nThe input to this tool should be the coworker, the task
|
||||
you want them to do, and ALL necessary context to execute the task, they know
|
||||
nothing about the task, so share absolutely everything you know, don''t reference
|
||||
things but instead explain them.\nTool Name: Ask question to coworker\nTool
|
||||
Arguments: {''question'': {''description'': ''The question to ask'', ''type'':
|
||||
''str''}, ''context'': {''description'': ''The context for the question'', ''type'':
|
||||
''str''}, ''coworker'': {''description'': ''The role/name of the coworker to
|
||||
ask'', ''type'': ''str''}}\nTool Description: Ask a specific question to one
|
||||
of the following coworkers: Researcher\nThe input to this tool should be the
|
||||
coworker, the question you have for them, and ALL necessary context to ask the
|
||||
question properly, they know nothing about the question, so share absolutely
|
||||
everything you know, don''t reference things but instead explain them.\n\nIMPORTANT:
|
||||
Use the following format in your response:\n\n```\nThought: you should always
|
||||
think about what to do\nAction: the action to take, only one name of [Delegate
|
||||
work to coworker, Ask question to coworker], just the name, exactly as it''s
|
||||
written.\nAction Input: the input to the action, just a simple JSON object,
|
||||
enclosed in curly braces, using \" to wrap keys and values.\nObservation: the
|
||||
result of the action\n```\n\nOnce all necessary information is gathered, return
|
||||
the following format:\n\n```\nThought: I now know the final answer\nFinal Answer:
|
||||
the final answer to the original input question\n```"}, {"role": "user", "content":
|
||||
"\nCurrent Task: Ask the researched to say hi!\n\nThis is the expected criteria
|
||||
for your final answer: Howdy!\nyou MUST return the actual complete content as
|
||||
the final answer, not a summary.\n\nBegin! This is VERY important to you, use
|
||||
the tools available and give your best Final Answer, your job depends on it!\n\nThought:"},
|
||||
{"role": "assistant", "content": "Thought: To complete the task, I need to ask
|
||||
the researcher to say \"Howdy!\" I will use the \"Ask question to coworker\"
|
||||
tool to instruct the researcher accordingly.\n\nAction: Ask question to coworker\nAction
|
||||
Input: {\"question\": \"Can you please say hi?\", \"context\": \"The expected
|
||||
greeting is: Howdy!\", \"coworker\": \"Researcher\"}\nObservation: Howdy!"}],
|
||||
"model": "gpt-4o", "stop": ["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '3318'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- _cfuvid=g371YzJ.yOdjD9dcZxJ8VI4huWlRJL2j8lbKDhE0qV8-1743463280779-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.93.0
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.93.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFJdj9MwEHzPr1j83KL0Ky15A6SKQ7oHdCAh4BS5ziYxdWxjbyjl1P9+
|
||||
stNrcnAn3Ysle3bGM7t7lwAwWbIcmGg4idaq6fvN7nrb/LrObm72/su3d8dP9uux83+3/mr5kU0C
|
||||
w+x+oqAH1mthWquQpNE9LBxywqA6W6+yWZbNs0UEWlOiCrTa0nRppvN0vpymm2manYmNkQI9y+F7
|
||||
AgBwF89gUZf4h+WQTh5eWvSe18jySxEAc0aFF8a9l564JjYZQGE0oY6uPzemqxvK4Qq0OcA+HNQg
|
||||
VFJzBVz7AzqAH3ob72/jPYcP5lAeX40lHVad5yGR7pQaAVxrQzx0JIa5PSOni31lauvMzv9DZZXU
|
||||
0jeFQ+6NDlY9GcsiekoAbmObukfJmXWmtVSQ2WP8bp3Oez02DGZAZ6szSIa4GrFm68kTekWJxKXy
|
||||
o0YzwUWD5UAdpsK7UpoRkIxS/+/mKe0+udT1S+QHQAi0hGVhHZZSPE48lDkMe/tc2aXL0TDz6H5L
|
||||
gQVJdGESJVa8U/1KMX/0hG1RSV2js072e1XZYpO+WWWrxULsWHJK7gEAAP//AwAfm3tPYAMAAA==
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 974f08878cb7239e-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 25 Aug 2025 23:57:43 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=veVLwO9uj9FvIBC_v55vfVlSdU8NHY.wdBapmpmtHn4-1756166263-1.0.1.1-LJ9bNXe6v06jg_mjVZp8vfvLT.9Hf8xUHOf2FempuntDnL5ogQXRuJvIJipz1trGr96_3WUCWlsexhQlkdtveEH6NbFMrm__Y61khA_IyPM;
|
||||
path=/; expires=Tue, 26-Aug-25 00:27:43 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=miqmxIsNDZtXvrtcyhpeSI3TjT_zcUas9Enn6gGtIsI-1756166263603-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '472'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '550'
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999222'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 1ms
|
||||
x-request-id:
|
||||
- req_4f627ee303b9429db117e9699b00656a
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
|
||||
@@ -498,4 +498,204 @@ interactions:
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Crew Manager. You
|
||||
are a seasoned manager with a knack for getting the best out of your team.\nYou
|
||||
are also known for your ability to delegate work to the right people, and to
|
||||
ask the right questions to get the best out of your team.\nEven though you don''t
|
||||
perform tasks by yourself, you have a lot of experience in the field, which
|
||||
allows you to properly evaluate the work of your team members.\nYour personal
|
||||
goal is: Manage the team to complete the task in the best way possible.\nYou
|
||||
ONLY have access to the following tools, and should NEVER make up tools that
|
||||
are not listed here:\n\nTool Name: Delegate work to coworker\nTool Arguments:
|
||||
{''task'': {''description'': ''The task to delegate'', ''type'': ''str''}, ''context'':
|
||||
{''description'': ''The context for the task'', ''type'': ''str''}, ''coworker'':
|
||||
{''description'': ''The role/name of the coworker to delegate to'', ''type'':
|
||||
''str''}}\nTool Description: Delegate a specific task to one of the following
|
||||
coworkers: First Agent\nThe input to this tool should be the coworker, the task
|
||||
you want them to do, and ALL necessary context to execute the task, they know
|
||||
nothing about the task, so share absolutely everything you know, don''t reference
|
||||
things but instead explain them.\nTool Name: Ask question to coworker\nTool
|
||||
Arguments: {''question'': {''description'': ''The question to ask'', ''type'':
|
||||
''str''}, ''context'': {''description'': ''The context for the question'', ''type'':
|
||||
''str''}, ''coworker'': {''description'': ''The role/name of the coworker to
|
||||
ask'', ''type'': ''str''}}\nTool Description: Ask a specific question to one
|
||||
of the following coworkers: First Agent\nThe input to this tool should be the
|
||||
coworker, the question you have for them, and ALL necessary context to ask the
|
||||
question properly, they know nothing about the question, so share absolutely
|
||||
everything you know, don''t reference things but instead explain them.\n\nIMPORTANT:
|
||||
Use the following format in your response:\n\n```\nThought: you should always
|
||||
think about what to do\nAction: the action to take, only one name of [Delegate
|
||||
work to coworker, Ask question to coworker], just the name, exactly as it''s
|
||||
written.\nAction Input: the input to the action, just a simple JSON object,
|
||||
enclosed in curly braces, using \" to wrap keys and values.\nObservation: the
|
||||
result of the action\n```\n\nOnce all necessary information is gathered, return
|
||||
the following format:\n\n```\nThought: I now know the final answer\nFinal Answer:
|
||||
the final answer to the original input question\n```"}, {"role": "user", "content":
|
||||
"\nCurrent Task: Process initial data\n\nThis is the expected criteria for your
|
||||
final answer: Initial analysis\nyou MUST return the actual complete content
|
||||
as the final answer, not a summary.\n\nBegin! This is VERY important to you,
|
||||
use the tools available and give your best Final Answer, your job depends on
|
||||
it!\n\nThought:"}, {"role": "assistant", "content": "Thought: I need to delegate
|
||||
the task of processing the initial data to the First Agent to ensure we have
|
||||
a thorough and accurate analysis. I will provide them with all the necessary
|
||||
details to complete this task effectively.\n\nAction: Delegate work to coworker\nAction
|
||||
Input: {\"task\": \"Process initial data\", \"context\": \"The task involves
|
||||
analyzing the initial data set we have received. This includes cleaning the
|
||||
data, categorizing it for analysis, identifying any trends or patterns, and
|
||||
summarizing the findings. The goal is to have a clear understanding of what
|
||||
the data indicates and any initial insights that can be drawn from it.\", \"coworker\":
|
||||
\"First Agent\"}\nObservation: To process the initial data set effectively,
|
||||
we will follow several key steps to ensure we clean, categorize, analyze, and
|
||||
summarize our findings comprehensively.\n\n1. **Data Cleaning**: \n - First,
|
||||
we will review the dataset for missing, inconsistent, or erroneous entries.
|
||||
\n - We will handle missing values by either imputing them based on surrounding
|
||||
data or removing records where necessary.\n - Additionally, we will standardize
|
||||
categorical variables to ensure consistency (e.g., ensuring all location names
|
||||
are spelled the same way).\n\n2. **Data Categorization**: \n - Next, we will
|
||||
categorize the data into relevant segments that will aid our analysis. \n -
|
||||
This involves grouping data points based on common characteristics, such as
|
||||
demographics, time periods, or any key performance indicators (KPIs) we are
|
||||
focusing on.\n\n3. **Trend and Pattern Identification**: \n - With the cleaned
|
||||
and categorized data, we will perform a detailed analysis to identify trends
|
||||
and patterns.\n - This will involve using statistical tools and visualizations
|
||||
to uncover relationships within the data. We will look at time series analysis,
|
||||
correlation coefficients, and significant outliers that may require further
|
||||
investigation.\n\n4. **Summarizing Findings**: \n - Finally, we will compile
|
||||
a summary of our findings, including both qualitative insights and quantitative
|
||||
metrics.\n - This summary should encapsulate the key trends identified, any
|
||||
notable patterns, and implications of these findings.\n - We will also document
|
||||
any limitations of the data and suggest areas for further research if necessary.\n\nBy
|
||||
completing these steps, we will not only have a clear understanding of what
|
||||
the data indicates but also provide actionable insights that can guide our next
|
||||
steps. This comprehensive analysis will serve as a solid foundation for any
|
||||
additional exploration or decision-making initiatives related to our project."}],
|
||||
"model": "gpt-4o", "stop": ["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '5593'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- _cfuvid=YeODa6MF5ug3OZUV6ob1dSrBKCM8BXbKkS77TIihYoE-1755550362828-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.93.0
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.93.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFXbbtxGDH33VxB6KRCsDV8T129pmgBBH5qiblO0DuzxDCWxHnFUklpn
|
||||
HeTfC472ljYF+rLAisPLOTwkPx0ANJSaK2hiHywOYz58dTn2H18/9T+d/KT8+Otr/uXFu6ffLs/k
|
||||
+rvfY7Nwj3L/J0bbeB3FMowZjQrP5igYDD3qyYuL56enJ+enF9UwlITZ3brRDs/L4enx6fnh8eXh
|
||||
8fO1Y18oojZX8McBAMCn+uslcsKPzRUcLzZfBlQNHTZX20cAjZTsX5qgSmqBrVnsjLGwIdeq7+7u
|
||||
bvi6L1PX2xW8BS6P8OA/1iO0xCFDYH1EueE39d/L+u8KrguMUiKq1qfEZBQypGABFA2wbTEaLTGv
|
||||
FvCI8Eg5w8QJxcIDguISJWSIQkYxZFDDUcEKIOskCAGcS8EeWWmJEDjklZIewQ3f8MkRPHv2ved6
|
||||
lTEwcffs2RW8X6dRC2JwvwJKyEbtiriDwAliEfGquIOBVIm7BRDHwk4Ssi2gCKBIYSyTArIJocIj
|
||||
WU9ccTo+RTuC654UiJclL1Eh9qV4OECyHgVoGCcLrgMY0PqSFNoim6SwDHlC9WSCQ1n6Jy+jagdB
|
||||
MBZJegRvJvFoQxGsHA6TmoPjFCTRE0IMhl2RSuAyCIX7jHOmiaktMpCtjpyv0x1fa5enWp2z9iNH
|
||||
hOg0YlpsQc5M3iMU6QLTEyYgtgIjihEjGyh2A7J5/4NBGyJlsmD7raoZOynT6BA3Ee+DYoLCoH0Q
|
||||
TBD7ICEaCqlRVMj0gJBwKJ2EsfcvRcBoQE9OJekCkPvA0YN6vYIZl8FhlBaIlbreFJKER67ozxz9
|
||||
tSCnqoJ3wQyF4e2sDopbKt6T9VtSHbIzsYAyCbQlTrqWV0+tuVT35fWAKzBPoTXHOOdQWFLwllnF
|
||||
FjJYKXl+siSdQl43Qo/gB1xtxVLzEMc8JZyhK1YpbrjdyTnPMlObEqEu4LGnjKDUcYXGBmWyTCi6
|
||||
bQAXw1Rl0k7mw6ZRJiOetXLubP08DUMQenJob4gTcafO0HWPEEstzE0+tbspSBLaOlwBtPqvZm3U
|
||||
AhgV/nLATsYSd30KuXCnlNDNbBv7gCYUdT1ptfRYlii140Pweax0L7ZcL2bmi0/ydhdlGmiexLVd
|
||||
p65DNWjn4QJBxSCxh0R1N3gvnIaXqUepCiueUnG9oropSGBDbwZYX8R35/6yKuwy3MzRN+rLIK9F
|
||||
pot5B47EPE+9T6kLDSO56+EQHjbbat2cUYrfGBj7oOi13d3d7W9zwXbS4MeEp5z3DIG5rKH7Hfmw
|
||||
tnzeXo5culHKvf7DtWmJSftbwaCF/UqolbGp1s8HAB/qhZq+ODrNKGUY7dbKA9Z0Jycn53PAZncU
|
||||
d+bTi43VioW853f2/HLxlZC3CS1Q1r0z18QQe0w7391NDFOismc42AP+73q+FnsGT9z9n/A7Q4w4
|
||||
GqbbUTBR/BLz7pmgd/S/nm2JrgU3irKkiLdGKN6MhG2Y8nzQG12p4XDbEncoo9B81dvx9vL424vn
|
||||
F2dn8b45+HzwNwAAAP//AwDhfkSS3ggAAA==
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 97544b3fd9c66894-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Tue, 26 Aug 2025 15:17:10 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=AK6x7s00CdjvAhZqoKc.oyU2huXbBJAB_qi1o9cIHkk-1756221430-1.0.1.1-s9cWi1kLPHCBoqRe8BhCYWgaKEG.LQvm0b0NNJkJrpuMMIAUz9sSqijPatK.t2wknR3Qo65.PTew2trnDH5_.mL1l4JewiW1VndksvCWngY;
|
||||
path=/; expires=Tue, 26-Aug-25 15:47:10 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=3NkIk1Ua5GwknkJHax_bb1dBUHU9Yobu11sjZ9yu7Rg-1756221430892-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '5563'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '5651'
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29998658'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 2ms
|
||||
x-request-id:
|
||||
- req_8ee5ddbc01374cf487da8763d7dee507
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
|
||||
@@ -1158,4 +1158,201 @@ interactions:
|
||||
- req_04364e8bad0fc80241f59550b9320360
|
||||
http_version: HTTP/1.1
|
||||
status_code: 200
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Crew Manager. You
|
||||
are a seasoned manager with a knack for getting the best out of your team.\nYou
|
||||
are also known for your ability to delegate work to the right people, and to
|
||||
ask the right questions to get the best out of your team.\nEven though you don''t
|
||||
perform tasks by yourself, you have a lot of experience in the field, which
|
||||
allows you to properly evaluate the work of your team members.\nYour personal
|
||||
goal is: Manage the team to complete the task in the best way possible.\nYou
|
||||
ONLY have access to the following tools, and should NEVER make up tools that
|
||||
are not listed here:\n\nTool Name: Delegate work to coworker\nTool Arguments:
|
||||
{''task'': {''description'': ''The task to delegate'', ''type'': ''str''}, ''context'':
|
||||
{''description'': ''The context for the task'', ''type'': ''str''}, ''coworker'':
|
||||
{''description'': ''The role/name of the coworker to delegate to'', ''type'':
|
||||
''str''}}\nTool Description: Delegate a specific task to one of the following
|
||||
coworkers: Senior Writer, Researcher, CEO\nThe input to this tool should be
|
||||
the coworker, the task you want them to do, and ALL necessary context to execute
|
||||
the task, they know nothing about the task, so share absolutely everything you
|
||||
know, don''t reference things but instead explain them.\nTool Name: Ask question
|
||||
to coworker\nTool Arguments: {''question'': {''description'': ''The question
|
||||
to ask'', ''type'': ''str''}, ''context'': {''description'': ''The context for
|
||||
the question'', ''type'': ''str''}, ''coworker'': {''description'': ''The role/name
|
||||
of the coworker to ask'', ''type'': ''str''}}\nTool Description: Ask a specific
|
||||
question to one of the following coworkers: Senior Writer, Researcher, CEO\nThe
|
||||
input to this tool should be the coworker, the question you have for them, and
|
||||
ALL necessary context to ask the question properly, they know nothing about
|
||||
the question, so share absolutely everything you know, don''t reference things
|
||||
but instead explain them.\n\nIMPORTANT: Use the following format in your response:\n\n```\nThought:
|
||||
you should always think about what to do\nAction: the action to take, only one
|
||||
name of [Delegate work to coworker, Ask question to coworker], just the name,
|
||||
exactly as it''s written.\nAction Input: the input to the action, just a simple
|
||||
JSON object, enclosed in curly braces, using \" to wrap keys and values.\nObservation:
|
||||
the result of the action\n```\n\nOnce all necessary information is gathered,
|
||||
return the following format:\n\n```\nThought: I now know the final answer\nFinal
|
||||
Answer: the final answer to the original input question\n```"}, {"role": "user",
|
||||
"content": "\nCurrent Task: Write one amazing paragraph about AI.\n\nThis is
|
||||
the expected criteria for your final answer: A single paragraph with 4 sentences.\nyou
|
||||
MUST return the actual complete content as the final answer, not a summary.\n\nThis
|
||||
is the context you''re working with:\n\"Artificial Intelligence (AI) stands
|
||||
at the forefront of technological advancement, utilizing machine learning, neural
|
||||
networks, and natural language processing to replicate human cognitive functions.
|
||||
The transformative power of AI spans across industries, driving revolutionary
|
||||
changes from predictive diagnostics in healthcare to algorithmic trading in
|
||||
finance. By rapidly analyzing large datasets, AI reveals hidden patterns, enhances
|
||||
efficiency, and sparks unprecedented innovation. As we embrace this technology,
|
||||
it is crucial to prioritize ethical considerations and responsible development
|
||||
to fully leverage its potential while addressing associated risks.\"\n\nBegin!
|
||||
This is VERY important to you, use the tools available and give your best Final
|
||||
Answer, your job depends on it!\n\nThought:"}, {"role": "assistant", "content":
|
||||
"Thought: To craft an amazing paragraph about AI as described in the context,
|
||||
I should delegate this task to the Senior Writer and include all the detailed
|
||||
context provided.\n\nAction: Delegate work to coworker\nAction Input: {\"task\":
|
||||
\"Write one amazing paragraph about AI\", \"context\": \"Artificial Intelligence
|
||||
(AI) is a transformative technology that leverages machine learning, neural
|
||||
networks, and natural language processing to perform tasks that traditionally
|
||||
require human intelligence, such as visual perception, speech recognition, and
|
||||
decision-making. Its potential to revolutionize industries\u2014from healthcare
|
||||
with predictive diagnostics to finance with algorithmic trading\u2014underscores
|
||||
its profound impact on our daily lives and the global economy. With the ability
|
||||
to analyze vast amounts of data quickly and accurately, AI can unveil hidden
|
||||
patterns, drive efficiencies, and foster innovation at an unprecedented scale.
|
||||
As we move forward, the ethical and responsible development of AI will be paramount
|
||||
to harness its benefits while mitigating risks.\", \"coworker\": \"Senior Writer\"}\n\nObservation:
|
||||
The Senior Writer has received the task and context.\nObservation: Artificial
|
||||
Intelligence (AI) is a transformative technology that leverages machine learning,
|
||||
neural networks, and natural language processing to perform tasks that traditionally
|
||||
require human intelligence, such as visual perception, speech recognition, and
|
||||
decision-making. Its potential to revolutionize industries\u2014from healthcare
|
||||
with predictive diagnostics to finance with algorithmic trading\u2014underscores
|
||||
its profound impact on our daily lives and the global economy. With the ability
|
||||
to analyze vast amounts of data quickly and accurately, AI can unveil hidden
|
||||
patterns, drive efficiencies, and foster innovation at an unprecedented scale.
|
||||
As we move forward, the ethical and responsible development of AI will be paramount
|
||||
to harness its benefits while mitigating risks."}], "model": "gpt-4o", "stop":
|
||||
["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '5729'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- _cfuvid=Q23zZGhbuNaTNh.RPoM_1O4jWXLFM.KtSgSytn2NO.Q-1744492727869-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.93.0
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.93.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFTbbhtHDH33VxD71AKyIcu27PhNCNBCBYICRdACqQOZmuHuspolN0Ou
|
||||
VCUwkI/IF+ZLiln5orYp0JfB7pzhmXPIIT+dAFQcq1uoQoseuj6dvr5Zv1m+e/P6yvI+XL+bCl//
|
||||
+uN08cvPP81u3laTEqHrPyj4U9RZ0K5P5KxygEMmdCqs59dX8/P5fDafj0CnkVIJa3o/vdTT2XR2
|
||||
eTq9OZ3OHwNb5UBW3cLvJwAAn8a1SJRIf1a3MJ087XRkhg1Vt8+HAKqsqexUaMbmKF5NXsCg4iSj
|
||||
6vv7+zt52+rQtH4LSxDdwaYs3hLULJgAxXaUz+7kh/F3Mf7ewiI71xwYEyzFKSVuSALBd4vl98AG
|
||||
CJ5RrNbcofOWwCm0okmbPXiLDom2lLEhgw5Dy0KQCLOwNBMQGjImEPKd5o1NACWCoI+7CaUZsCHo
|
||||
swYyY2nAFXrK5TJwtI0drvCMkUs1MKU9ZPowcCZohw4F+Ej0BGwILaDBlm3AVLgC9SVyAtYThRYy
|
||||
BW2ED3tFTqTAxiqnHW5YmjNYukGvJbMlKa6QaatpKBH8kYAlDuaZyb5+/lJn7aAlTN4GzAQ79hb6
|
||||
TJHDmKzI2Iiac7DCVCpRkjsew9RoZm87DgeH0nz9/GWQSNmCZjLgoiRrrYNE4K7H4KACOmSIyGkP
|
||||
ibdko4tS5ybpGhNQUNFufwa/lVsKgGtO7PuiAAXT/iPBFs0BOx3EDbSGiI7wYeCwSfuREEMYMjql
|
||||
/QQWSwgoMMiWOEHLMZJAj+6UxSYQc7FKdXlGJIHpsdC1mlMGFtEtlvQBOow8faZAkcQpggVMdAYL
|
||||
gx1Bp1uCWvMOc5yM2slbDuPzjZDJehXjdSKItKWkfUfiRf9iCTtOCdYEPeaDseK3xSxkh1SuSagu
|
||||
H7uWE0HHzg16eXaZbWNnd3J/f3/cYJnqwbD0twwpHQEooj5aGlv7/SPy8NzMSZs+69r+EVrVLGzt
|
||||
KhOaSmlcc+2rEX04AXg/Do3hb3Og6rN2va9cNzRedz69fnUgrF7m1BF8efGIujqmI2A2m02+QbmK
|
||||
5MjJjiZPFTC0FF9iX8YUDpH1CDg5Mv5vPd/iPphnaf4P/QsQSidTXD1117Hnl2OZyiD/r2PPiR4F
|
||||
V0Z5y4FWzpRLMSLVOKTDjK1sb07dqmZpKPeZD4O27lc301dX86uLi7CuTh5O/gIAAP//AwAI5M9k
|
||||
cQYAAA==
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 974f089aacfe239e-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 25 Aug 2025 23:57:47 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=x3Z22TL7MiNrhhoqoBAkoerRf2slLEFNFLerhsr8uWg-1756166267-1.0.1.1-YsxFr.4IpdTvhEkob8OJGHggHZj5mRzRdK0Ta9PllabKVSUEp9sdbUxQpYatb5wW12dk9nglKUcebewHsL7RgYB0Y0BPnFRoyX0r5kn14HI;
|
||||
path=/; expires=Tue, 26-Aug-25 00:27:47 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=H4uvKyyXorF4TlJkoLerJYI2PABtQ6M4T.XxqO7FHCQ-1756166267474-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '1314'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '1380'
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29998623'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 2ms
|
||||
x-request-id:
|
||||
- req_47e6588839144756828145be5a5a3b63
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user