Compare commits


2 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Devin AI | eee7439610 | Fix lint: Sort imports in test_llm.py (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-05-01 19:59:44 +00:00 |
| Devin AI | a0a536e737 | Fix issue #2738: Exclude stop parameter for o3 model (Co-Authored-By: Joe Moura <joao@crewai.com>) | 2025-05-01 19:57:17 +00:00 |
257 changed files with 6407 additions and 16639 deletions

.github/security.md vendored

@@ -1,27 +1,19 @@
## CrewAI Security Vulnerability Reporting Policy
CrewAI takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organization.
If you believe you have found a security vulnerability in any CrewAI product or service, please report it to us as described below.
CrewAI prioritizes the security of our software products, services, and GitHub repositories. To promptly address vulnerabilities, follow these steps for reporting security issues:
## Reporting a Vulnerability
Please do not report security vulnerabilities through public GitHub issues.
To report a vulnerability, please email us at security@crewai.com.
Please include the requested information listed below so that we can triage your report more quickly
### Reporting Process
Do **not** report vulnerabilities via public GitHub issues.
- Type of issue (e.g. SQL injection, cross-site scripting, etc.)
- Full paths of source file(s) related to the manifestation of the issue
- The location of the affected source code (tag/branch/commit or direct URL)
- Any special configuration required to reproduce the issue
- Step-by-step instructions to reproduce the issue (please include screenshots if needed)
- Proof-of-concept or exploit code (if possible)
- Impact of the issue, including how an attacker might exploit the issue
Email all vulnerability reports directly to:
**security@crewai.com**
Once we have received your report, we will respond to you at the email address you provide. If the issue is confirmed, we will release a patch as soon as possible depending on the complexity of the issue.
### Required Information
To help us quickly validate and remediate the issue, your report must include:
- **Vulnerability Type:** Clearly state the vulnerability type (e.g., SQL injection, XSS, privilege escalation).
- **Affected Source Code:** Provide full file paths and direct URLs (branch, tag, or commit).
- **Reproduction Steps:** Include detailed, step-by-step instructions. Screenshots are recommended.
- **Special Configuration:** Document any special settings or configurations required to reproduce.
- **Proof-of-Concept (PoC):** Provide exploit or PoC code (if available).
- **Impact Assessment:** Clearly explain the severity and potential exploitation scenarios.
### Our Response
- We will acknowledge receipt of your report promptly via your provided email.
- Confirmed vulnerabilities will receive priority remediation based on severity.
- Patches will be released as swiftly as possible following verification.
### Reward Notice
Currently, we do not offer a bug bounty program. Rewards, if issued, are discretionary.
At this time, we are not offering a bug bounty program. Any rewards will be at our discretion.


@@ -5,29 +5,12 @@ on: [pull_request]
jobs:
lint:
runs-on: ubuntu-latest
env:
TARGET_BRANCH: ${{ github.event.pull_request.base.ref }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Fetch Target Branch
run: git fetch origin $TARGET_BRANCH --depth=1
- name: Install Ruff
run: pip install ruff
- name: Get Changed Python Files
id: changed-files
- name: Install Requirements
run: |
merge_base=$(git merge-base origin/"$TARGET_BRANCH" HEAD)
changed_files=$(git diff --name-only --diff-filter=ACMRTUB "$merge_base" | grep '\.py$' || true)
echo "files<<EOF" >> $GITHUB_OUTPUT
echo "$changed_files" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
pip install ruff
- name: Run Ruff on Changed Files
if: ${{ steps.changed-files.outputs.files != '' }}
run: |
echo "${{ steps.changed-files.outputs.files }}" | tr " " "\n" | xargs -I{} ruff check "{}"
- name: Run Ruff Linter
run: ruff check


@@ -4,48 +4,6 @@ exclude = [
]
[lint]
select = ["ALL"]
ignore = [
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D106", # Missing docstring in public nested class
"D107", # Missing docstring in __init__
"D205", # 1 blank line required between summary line and description
"ANN001", # Missing type annotation for function argument
"ANN002", # Missing type annotation for *args
"ANN003", # Missing type annotation for **kwargs
"ANN201", # Missing return type annotation for public function
"ANN202", # Missing return type annotation for private function
"ANN204", # Missing return type annotation for special method
"ANN205", # Missing return type annotation for staticmethod
"ANN206", # Missing return type annotation for classmethod
"E501", # Line too long
"PT011", # pytest.raises() without match parameter
"PT012", # pytest.raises() block should contain a single simple statement
"SIM117", # Use a single `with` statement with multiple contexts
"PLR2004", # Magic value used in comparison
"B017", # Do not assert blind exception
select = [
"I", # isort rules
]
[lint.per-file-ignores]
"tests/*" = [
"S101", # Allow assert in tests
"SLF001", # Allow private member access in tests
"DTZ001", # Allow datetime without tzinfo in tests
"PTH107", # Allow os.remove instead of Path.unlink in tests
"PTH118", # Allow os.path.join() in tests
"PTH120", # Allow os.path.dirname() in tests
"PTH123", # Allow open() instead of Path.open() in tests
"PTH202", # Allow os.path.getsize in tests
"PT012", # Allow multiple statements in pytest.raises() block in tests
"SIM117", # Allow nested with statements in tests
"PLR2004", # Allow magic values in tests
"B017", # Allow asserting blind exceptions in tests
]
[lint.isort]
known-first-party = ["crewai"]


@@ -504,7 +504,7 @@ This example demonstrates how to:
CrewAI supports using various LLMs through a variety of connection options. By default your agents will use the OpenAI API when querying the model. However, there are several other ways to allow your agents to connect to models. For example, you can configure your agents to use a local model via the Ollama tool.
Please refer to the [Connect CrewAI to LLMs](https://docs.crewai.com/how-to/LLM-Connections/) page for details on configuring your agents' connections to models.
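For example, a local model served by Ollama can be wired in roughly like this (a minimal sketch; the model name and port are assumptions that depend on your local setup):

```python
from crewai import LLM

# Assumes an Ollama server running locally with the llama3.1 model pulled.
ollama_llm = LLM(
    model="ollama/llama3.1",
    base_url="http://localhost:11434",
)
```

The resulting `ollama_llm` can then be passed to an agent's `llm` parameter.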
## How CrewAI Compares


@@ -27,7 +27,7 @@ A crew in crewAI represents a collaborative group of agents working together to
| **Step Callback** _(optional)_ | `step_callback` | A function that is called after each step of every agent. This can be used to log the agent's actions or to perform other operations; it won't override the agent-specific `step_callback`. |
| **Task Callback** _(optional)_ | `task_callback` | A function that is called after the completion of each task. Useful for monitoring or additional operations post-task execution. |
| **Share Crew** _(optional)_ | `share_crew` | Whether you want to share the complete crew information and execution with the crewAI team to make the library better, and allow us to train models. |
| **Output Log File** _(optional)_ | `output_log_file` | Set to True to save logs as logs.txt in the current directory or provide a file path. Logs will be in JSON format if the filename ends in .json, otherwise .txt. Defaults to `None`. |
| **Manager Agent** _(optional)_ | `manager_agent` | `manager` sets a custom agent that will be used as a manager. |
| **Prompt File** _(optional)_ | `prompt_file` | Path to the prompt JSON file to be used for the crew. |
| **Planning** *(optional)* | `planning` | Adds planning ability to the Crew. When activated before each Crew iteration, all Crew data is sent to an AgentPlanner that will plan the tasks and this plan will be added to each task description. |
@@ -246,7 +246,7 @@ print(f"Token Usage: {crew_output.token_usage}")
You can see a real-time log of the crew execution by setting `output_log_file` to `True(Boolean)` or a `file_name(str)`. Logging of events is supported as both `file_name.txt` and `file_name.json`.
If set to `True(Boolean)`, logs will be saved as `logs.txt`.
If `output_log_file` is set to `False(Boolean)` or `None`, the logs will not be populated.
```python Code
# Save crew logs
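# The rest of this snippet is an illustrative completion (not from the original
# hunk); the agent and task definitions below are assumptions.
from crewai import Agent, Crew, Task

researcher = Agent(
    role="Researcher",
    goal="Find key facts about a topic",
    backstory="An experienced analyst.",
)
research_task = Task(
    description="Summarize the topic in three bullet points.",
    expected_output="Three bullet points.",
    agent=researcher,
)

crew = Crew(
    agents=[researcher],
    tasks=[research_task],
    output_log_file="crew_logs.json",  # .json -> JSON logs; pass True for logs.txt
)
crew.kickoff()
```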


@@ -397,53 +397,6 @@ result = crew.kickoff(inputs={"question": "What city does John live in and how o
John is 30 years old and lives in San Francisco.
```
</CodeGroup>
## Query Rewriting
CrewAI implements an intelligent query rewriting mechanism to optimize knowledge retrieval. When an agent needs to search through knowledge sources, the raw task prompt is automatically transformed into a more effective search query.
### How Query Rewriting Works
1. When an agent executes a task with knowledge sources available, the `_get_knowledge_search_query` method is triggered
2. The agent's LLM is used to transform the original task prompt into an optimized search query
3. This optimized query is then used to retrieve relevant information from knowledge sources
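Conceptually, the rewriting step boils down to one extra LLM call. The sketch below is illustrative only; the prompt wording is an assumption, not the exact system prompt CrewAI uses internally:

```python
from crewai import LLM

llm = LLM(model="gpt-4o-mini")  # in practice the agent's own LLM is used

task_prompt = (
    "Answer the following questions about the user's favorite movies: "
    "What movie did John watch last week? Format your answer in JSON."
)

rewriter_system_prompt = (
    "Rewrite the task below into a concise, context-aware search query for a "
    "vector database. Keep only the key concepts and drop output-format "
    "instructions. Return only the rewritten query."
)

rewritten_query = llm.call([
    {"role": "system", "content": rewriter_system_prompt},
    {"role": "user", "content": task_prompt},
])
# rewritten_query is then used to search the knowledge sources
```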
### Benefits of Query Rewriting
<CardGroup cols={2}>
<Card title="Improved Retrieval Accuracy" icon="bullseye-arrow">
By focusing on key concepts and removing irrelevant content, query rewriting helps retrieve more relevant information.
</Card>
<Card title="Context Awareness" icon="brain">
The rewritten queries are designed to be more specific and context-aware for vector database retrieval.
</Card>
</CardGroup>
### Implementation Details
Query rewriting happens transparently using a system prompt that instructs the LLM to:
- Focus on key words of the intended task
- Make the query more specific and context-aware
- Remove irrelevant content like output format instructions
- Generate only the rewritten query without preamble or postamble
<Tip>
This mechanism is fully automatic and requires no configuration from users. The agent's LLM is used to perform the query rewriting, so using a more capable LLM can improve the quality of rewritten queries.
</Tip>
### Example
```python
# Original task prompt
task_prompt = "Answer the following questions about the user's favorite movies: What movie did John watch last week? Format your answer in JSON."
# Behind the scenes, this might be rewritten as:
rewritten_query = "What movies did John watch last week?"
```
The rewritten query is more focused on the core information need and removes irrelevant instructions about output formatting.
## Clearing Knowledge
If you need to clear the knowledge stored in CrewAI, you can use the `crewai reset-memories` command with the `--knowledge` option.
@@ -700,11 +653,4 @@ recent_news = SpaceNewsKnowledgeSource(
- Configure appropriate embedding models
- Consider using local embedding providers for faster processing
</Accordion>
<Accordion title="One Time Knowledge">
- With the typical file structure provided by CrewAI, knowledge sources are embedded every time the kickoff is triggered.
- If the knowledge sources are large, this leads to inefficiency and increased latency, as the same data is embedded each time.
- To resolve this, directly initialize the knowledge parameter instead of the knowledge_sources parameter.
- See the linked [GitHub Issue](https://github.com/crewAIInc/crewAI/issues/2755) for the complete context
</Accordion>
</AccordionGroup>
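Following up on the embedding tips above, a local embedding provider is configured through the same `embedder` dictionary accepted by `Crew` and `Agent`. A minimal sketch (the provider and model names are illustrative assumptions):

```python
# Embedder configuration passed as Crew(embedder=...) or Agent(embedder=...).
# Assumes a local Ollama server with the nomic-embed-text model pulled.
local_embedder = {
    "provider": "ollama",
    "config": {"model": "nomic-embed-text"},
}
```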


@@ -27,19 +27,23 @@ Large Language Models (LLMs) are the core intelligence behind CrewAI agents. The
</Card>
</CardGroup>
## Setting up your LLM
## Setting Up Your LLM
There are different places in CrewAI code where you can specify the model to use. Once you specify the model you are using, you will need to provide the configuration (like an API key) for each of the model providers you use. See the [provider configuration examples](#provider-configuration-examples) section for your provider.
There are three ways to configure LLMs in CrewAI. Choose the method that best fits your workflow:
<Tabs>
<Tab title="1. Environment Variables">
The simplest way to get started. Set the model in your environment directly, through an `.env` file or in your app code. If you used `crewai create` to bootstrap your project, it will be set already.
The simplest way to get started. Set these variables in your environment:
```bash .env
MODEL=model-id # e.g. gpt-4o, gemini-2.0-flash, claude-3-sonnet-...
```bash
# Required: Your API key for authentication
OPENAI_API_KEY=<your-api-key>
# Be sure to set your API keys here too. See the Provider
# section below.
# Optional: Default model selection
OPENAI_MODEL_NAME=gpt-4o-mini # Default if not set
# Optional: Organization ID (if applicable)
OPENAI_ORGANIZATION_ID=<your-org-id>
```
<Warning>
@@ -49,13 +53,13 @@ There are different places in CrewAI code where you can specify the model to use
<Tab title="2. YAML Configuration">
Create a YAML file to define your agent configurations. This method is great for version control and team collaboration:
```yaml agents.yaml {6}
```yaml
researcher:
role: Research Specialist
goal: Conduct comprehensive research and analysis
backstory: A dedicated research professional with years of experience
verbose: true
llm: provider/model-id # e.g. openai/gpt-4o, google/gemini-2.0-flash, anthropic/claude...
llm: openai/gpt-4o-mini # your model here
# (see provider configuration examples below for more)
```
@@ -70,23 +74,23 @@ There are different places in CrewAI code where you can specify the model to use
<Tab title="3. Direct Code">
For maximum flexibility, configure LLMs directly in your Python code:
```python {4,8}
```python
from crewai import LLM
# Basic configuration
llm = LLM(model="model-id-here") # gpt-4o, gemini-2.0-flash, anthropic/claude...
llm = LLM(model="gpt-4")
# Advanced configuration with detailed parameters
llm = LLM(
model="model-id-here", # gpt-4o, gemini-2.0-flash, anthropic/claude...
model="gpt-4o-mini",
temperature=0.7, # Higher for more creative outputs
timeout=120, # Seconds to wait for response
max_tokens=4000, # Maximum length of response
top_p=0.9, # Nucleus sampling parameter
frequency_penalty=0.1 , # Reduce repetition
presence_penalty=0.1, # Encourage topic diversity
timeout=120, # Seconds to wait for response
max_tokens=4000, # Maximum length of response
top_p=0.9, # Nucleus sampling parameter
frequency_penalty=0.1, # Reduce repetition
presence_penalty=0.1, # Encourage topic diversity
response_format={"type": "json"}, # For structured outputs
seed=42 # For reproducible results
seed=42 # For reproducible results
)
```
@@ -106,6 +110,7 @@ There are different places in CrewAI code where you can specify the model to use
## Provider Configuration Examples
CrewAI supports a multitude of LLM providers, each offering unique features, authentication methods, and model capabilities.
In this section, you'll find detailed examples that help you select, configure, and optimize the LLM that best fits your project's needs.
@@ -169,55 +174,19 @@ In this section, you'll find detailed examples that help you select, configure,
```
</Accordion>
<Accordion title="Google (Gemini API)">
Set your API key in your `.env` file. If you need a key, or need to find an
existing key, check [AI Studio](https://aistudio.google.com/apikey).
<Accordion title="Google">
Set the following environment variables in your `.env` file:
```toml .env
```toml Code
# Option 1: Gemini accessed with an API key.
# https://ai.google.dev/gemini-api/docs/api-key
GEMINI_API_KEY=<your-api-key>
# Option 2: Vertex AI IAM credentials for Gemini, Anthropic, and Model Garden.
# https://cloud.google.com/vertex-ai/generative-ai/docs/overview
```
Example usage in your CrewAI project:
```python Code
from crewai import LLM
llm = LLM(
model="gemini/gemini-2.0-flash",
temperature=0.7,
)
```
### Gemini models
Google offers a range of powerful models optimized for different use cases.
| Model | Context Window | Best For |
|--------------------------------|----------------|-------------------------------------------------------------------|
| gemini-2.5-flash-preview-04-17 | 1M tokens | Adaptive thinking, cost efficiency |
| gemini-2.5-pro-preview-05-06 | 1M tokens | Enhanced thinking and reasoning, multimodal understanding, advanced coding, and more |
| gemini-2.0-flash | 1M tokens | Next generation features, speed, thinking, and realtime streaming |
| gemini-2.0-flash-lite | 1M tokens | Cost efficiency and low latency |
| gemini-1.5-flash | 1M tokens | Balanced multimodal model, good for most tasks |
| gemini-1.5-flash-8B | 1M tokens | Fastest, most cost-efficient, good for high-frequency tasks |
| gemini-1.5-pro | 2M tokens | Best performing, wide variety of reasoning tasks including logical reasoning, coding, and creative collaboration |
The full list of models is available in the [Gemini model docs](https://ai.google.dev/gemini-api/docs/models).
### Gemma
The Gemini API also allows you to use your API key to access [Gemma models](https://ai.google.dev/gemma/docs) hosted on Google infrastructure.
| Model | Context Window |
|----------------|----------------|
| gemma-3-1b-it | 32k tokens |
| gemma-3-4b-it | 32k tokens |
| gemma-3-12b-it | 32k tokens |
| gemma-3-27b-it | 128k tokens |
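A Gemma model can be selected the same way as the Gemini models above; a brief sketch (the exact model id is an assumption, pick any id from the table):
```python Code
from crewai import LLM

llm = LLM(
    model="gemini/gemma-3-27b-it",
    temperature=0.7,
)
```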
</Accordion>
<Accordion title="Google (Vertex AI)">
Get credentials from your Google Cloud Console and save it to a JSON file, then load it with the following code:
Get credentials from your Google Cloud Console and save it to a JSON file with the following code:
```python Code
import json
@@ -241,18 +210,14 @@ In this section, you'll find detailed examples that help you select, configure,
vertex_credentials=vertex_credentials_json
)
```
Google offers a range of powerful models optimized for different use cases:
| Model | Context Window | Best For |
|--------------------------------|----------------|-------------------------------------------------------------------|
| gemini-2.5-flash-preview-04-17 | 1M tokens | Adaptive thinking, cost efficiency |
| gemini-2.5-pro-preview-05-06 | 1M tokens | Enhanced thinking and reasoning, multimodal understanding, advanced coding, and more |
| gemini-2.0-flash | 1M tokens | Next generation features, speed, thinking, and realtime streaming |
| gemini-2.0-flash-lite | 1M tokens | Cost efficiency and low latency |
| gemini-1.5-flash | 1M tokens | Balanced multimodal model, good for most tasks |
| gemini-1.5-flash-8B | 1M tokens | Fastest, most cost-efficient, good for high-frequency tasks |
| gemini-1.5-pro | 2M tokens | Best performing, wide variety of reasoning tasks including logical reasoning, coding, and creative collaboration |
| Model | Context Window | Best For |
|-----------------------|----------------|------------------------------------------------------------------|
| gemini-2.0-flash-exp | 1M tokens | Higher quality at faster speed, multimodal model, good for most tasks |
| gemini-1.5-flash | 1M tokens | Balanced multimodal model, good for most tasks |
| gemini-1.5-flash-8B | 1M tokens | Fastest, most cost-efficient, good for high-frequency tasks |
| gemini-1.5-pro | 2M tokens | Best performing, wide variety of reasoning tasks including logical reasoning, coding, and creative collaboration |
</Accordion>
<Accordion title="Azure">
@@ -418,7 +383,7 @@ In this section, you'll find detailed examples that help you select, configure,
| microsoft/phi-3-medium-4k-instruct | 4,096 tokens | Lightweight, state-of-the-art open LLM with strong math and logical reasoning skills. |
| microsoft/phi-3-medium-128k-instruct | 128K tokens | Lightweight, state-of-the-art open LLM with strong math and logical reasoning skills. |
| microsoft/phi-3.5-mini-instruct | 128K tokens | Lightweight multilingual LLM powering AI applications in latency bound, memory/compute constrained environments |
| microsoft/phi-3.5-moe-instruct | 128K tokens | Advanced LLM based on Mixture of Experts architecture to deliver compute efficient content generation |
| microsoft/kosmos-2 | 1,024 tokens | Groundbreaking multimodal model designed to understand and reason about visual elements in images. |
| microsoft/phi-3-vision-128k-instruct | 128k tokens | Cutting-edge open multimodal model excelling in high-quality reasoning from images. |
| microsoft/phi-3.5-vision-instruct | 128k tokens | Cutting-edge open multimodal model excelling in high-quality reasoning from images. |
@@ -442,19 +407,19 @@ In this section, you'll find detailed examples that help you select, configure,
</Accordion>
<Accordion title="Local NVIDIA NIM Deployed using WSL2">
NVIDIA NIM enables you to run powerful LLMs locally on your Windows machine using WSL2 (Windows Subsystem for Linux).
This approach allows you to leverage your NVIDIA GPU for private, secure, and cost-effective AI inference without relying on cloud services.
Perfect for development, testing, or production scenarios where data privacy or offline capabilities are required.
Here is a step-by-step guide to setting up a local NVIDIA NIM model:
1. Follow installation instructions from [NVIDIA Website](https://docs.nvidia.com/nim/wsl2/latest/getting-started.html)
2. Install the local model. For Llama 3.1-8b follow [instructions](https://build.nvidia.com/meta/llama-3_1-8b-instruct/deploy)
3. Configure your crewai local models:
```python Code
from crewai.llm import LLM
@@ -476,7 +441,7 @@ In this section, you'll find detailed examples that help you select, configure,
config=self.agents_config['researcher'], # type: ignore[index]
llm=local_nvidia_nim_llm
)
# ...
```
</Accordion>
@@ -672,19 +637,19 @@ CrewAI supports streaming responses from LLMs, allowing your application to rece
When streaming is enabled, responses are delivered in chunks as they're generated, creating a more responsive user experience.
</Tab>
<Tab title="Event Handling">
CrewAI emits events for each chunk received during streaming:
```python
from crewai import LLM
from crewai.utilities.events import EventHandler, LLMStreamChunkEvent
class MyEventHandler(EventHandler):
def on_llm_stream_chunk(self, event: LLMStreamChunkEvent):
# Process each chunk as it arrives
print(f"Received chunk: {event.chunk}")
# Register the event handler
from crewai.utilities.events import crewai_event_bus
crewai_event_bus.register_handler(MyEventHandler())
@@ -820,7 +785,7 @@ Learn how to get the most out of your LLM configuration:
<Tip>
Use larger context models for extensive tasks
</Tip>
```python
# Large context model
llm = LLM(model="openai/gpt-4o") # 128K tokens


@@ -35,8 +35,7 @@ Let's get started building your first crew!
Before starting, make sure you have:
1. Installed CrewAI following the [installation guide](/installation)
2. Set up your LLM API key in your environment, following the [LLM setup
guide](/concepts/llms#setting-up-your-llm)
2. Set up your OpenAI API key in your environment variables
3. Basic understanding of Python
## Step 1: Create a New CrewAI Project
@@ -93,8 +92,7 @@ For our research crew, we'll create two agents:
1. A **researcher** who excels at finding and organizing information
2. An **analyst** who can interpret research findings and create insightful reports
Let's modify the `agents.yaml` file to define these specialized agents. Be sure
to set `llm` to the provider you are using.
Let's modify the `agents.yaml` file to define these specialized agents:
```yaml
# src/research_crew/config/agents.yaml
@@ -109,7 +107,7 @@ researcher:
finding relevant information from various sources. You excel at
organizing information in a clear and structured manner, making
complex topics accessible to others.
llm: provider/model-id # e.g. openai/gpt-4o, google/gemini-2.0-flash, anthropic/claude...
llm: openai/gpt-4o-mini
analyst:
role: >
@@ -122,7 +120,7 @@ analyst:
and technical writing. You have a talent for identifying patterns
and extracting meaningful insights from research data, then
communicating those insights effectively through well-crafted reports.
llm: provider/model-id # e.g. openai/gpt-4o, google/gemini-2.0-flash, anthropic/claude...
llm: openai/gpt-4o-mini
```
Notice how each agent has a distinct role, goal, and backstory. These elements aren't just descriptive - they actively shape how the agent approaches its tasks. By crafting these carefully, you can create agents with specialized skills and perspectives that complement each other.
@@ -284,12 +282,12 @@ This script prepares the environment, specifies our research topic, and kicks of
Create a `.env` file in your project root with your API keys:
```sh
```
OPENAI_API_KEY=your_openai_api_key
SERPER_API_KEY=your_serper_api_key
# Add your provider's API key here too.
```
See the [LLM Setup guide](/concepts/llms#setting-up-your-llm) for details on configuring your provider of choice. You can get a Serper API key from [Serper.dev](https://serper.dev/).
You can get a Serper API key from [Serper.dev](https://serper.dev/).
## Step 8: Install Dependencies


@@ -45,8 +45,7 @@ Let's dive in and build your first flow!
Before starting, make sure you have:
1. Installed CrewAI following the [installation guide](/installation)
2. Set up your LLM API key in your environment, following the [LLM setup
guide](/concepts/llms#setting-up-your-llm)
2. Set up your OpenAI API key in your environment variables
3. Basic understanding of Python
## Step 1: Create a New CrewAI Flow Project
@@ -108,8 +107,6 @@ Now, let's modify the generated files for the content writer crew. We'll set up
1. First, update the agents configuration file to define our content creation team:
Remember to set `llm` to the provider you are using.
```yaml
# src/guide_creator_flow/crews/content_crew/config/agents.yaml
content_writer:
@@ -122,7 +119,7 @@ content_writer:
You are a talented educational writer with expertise in creating clear, engaging
content. You have a gift for explaining complex concepts in accessible language
and organizing information in a way that helps readers build their understanding.
llm: provider/model-id # e.g. openai/gpt-4o, google/gemini-2.0-flash, anthropic/claude...
llm: openai/gpt-4o-mini
content_reviewer:
role: >
@@ -135,7 +132,7 @@ content_reviewer:
content. You have an eye for detail, clarity, and coherence. You excel at
improving content while maintaining the original author's voice and ensuring
consistent quality across multiple sections.
llm: provider/model-id # e.g. openai/gpt-4o, google/gemini-2.0-flash, anthropic/claude...
llm: openai/gpt-4o-mini
```
These agent definitions establish the specialized roles and perspectives that will shape how our AI agents approach content creation. Notice how each agent has a distinct purpose and expertise.
@@ -444,15 +441,10 @@ This is the power of flows - combining different types of processing (user inter
## Step 6: Set Up Your Environment Variables
Create a `.env` file in your project root with your API keys. See the [LLM setup
guide](/concepts/llms#setting-up-your-llm) for details on configuring a provider.
Create a `.env` file in your project root with your API keys:
```sh .env
```
OPENAI_API_KEY=your_openai_api_key
# or
GEMINI_API_KEY=your_gemini_api_key
# or
ANTHROPIC_API_KEY=your_anthropic_api_key
```
## Step 7: Install Dependencies
@@ -555,10 +547,7 @@ Let's break down the key components of flows to help you understand how to build
Flows allow you to make direct calls to language models when you need simple, structured responses:
```python
llm = LLM(
model="model-id-here", # gpt-4o, gemini-2.0-flash, anthropic/claude...
response_format=GuideOutline
)
llm = LLM(model="openai/gpt-4o-mini", response_format=GuideOutline)
response = llm.call(messages=messages)
```
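For context, `GuideOutline` in the snippet above is a Pydantic model describing the structured response; a minimal sketch (the field names are assumptions, not the tutorial's actual schema):
```python
from pydantic import BaseModel

class GuideOutline(BaseModel):
    title: str
    sections: list[str]
```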


@@ -68,13 +68,7 @@ We'll create a CrewAI application where two agents collaborate to research and w
```python
from crewai import Agent, Crew, Process, Task
from crewai_tools import SerperDevTool
from openinference.instrumentation.crewai import CrewAIInstrumentor
from phoenix.otel import register
# setup monitoring for your crew
tracer_provider = register(
endpoint="http://localhost:6006/v1/traces")
CrewAIInstrumentor().instrument(skip_dep_check=True, tracer_provider=tracer_provider)
search_tool = SerperDevTool()
# Define your agents with roles and goals


@@ -71,10 +71,6 @@ If you haven't installed `uv` yet, follow **step 1** to quickly get it set up on
```
</Warning>
<Warning>
If you encounter the `chroma-hnswlib==0.7.6` build error (`fatal error C1083: Cannot open include file: 'float.h'`) on Windows, install [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/) with *Desktop development with C++*.
</Warning>
- To verify that `crewai` is installed, run:
```shell
uv tool list


@@ -180,9 +180,8 @@ Follow the steps below to get Crewing! 🚣‍♂️
</Step>
<Step title="Set your environment variables">
Before running your crew, make sure you have the following keys set as environment variables in your `.env` file:
- An [OpenAI API key](https://platform.openai.com/account/api-keys) (or other LLM API key): `OPENAI_API_KEY=sk-...`
- A [Serper.dev](https://serper.dev/) API key: `SERPER_API_KEY=YOUR_KEY_HERE`
- The configuration for your choice of model, such as an API key. See the
[LLM setup guide](/concepts/llms#setting-up-your-llm) to learn how to configure models from any provider.
</Step>
<Step title="Lock and install the dependencies">
- Lock the dependencies and install them by using the CLI command:
@@ -318,7 +317,7 @@ email_summarizer:
Summarize emails into a concise and clear summary
backstory: >
You will create a 5 bullet point summary of the report
llm: provider/model-id # Add your choice of model here
llm: openai/gpt-4o
```
<Tip>


@@ -22,7 +22,7 @@ streamlining the process of finding specific information within large document c
Install the crewai_tools package by running the following command in your terminal:
```shell
uv pip install docx2txt 'crewai[tools]'
pip install 'crewai[tools]'
```
## Example
@@ -76,4 +76,4 @@ tool = DOCXSearchTool(
),
)
)
```
```


@@ -8,10 +8,10 @@ icon: language
## Description
This tool is used to convert natural language to SQL queries. When passed to the agent, it will generate queries and then use them to interact with the database.
This enables multiple workflows, such as having an Agent access the database, fetch information based on its goal, and then use that information to generate a response, report, or any other output.
It also provides the ability for the Agent to update the database based on its goal.
**Attention**: Make sure that the Agent has access to a Read-Replica or that it is okay for the Agent to run insert/update queries on the database.
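A typical setup looks roughly like this (a sketch; the connection string is a placeholder, and the `db_uri` parameter name follows the crewai_tools `NL2SQLTool` convention):
```python Code
from crewai_tools import NL2SQLTool

# Placeholder connection string; point this at a read replica where possible.
nl2sql = NL2SQLTool(db_uri="postgresql://user:password@localhost:5432/mydb")

# The tool can then be passed to an Agent via tools=[nl2sql].
```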
@@ -81,4 +81,4 @@ The Tool provides endless possibilities on the logic of the Agent and how it can
```md
DB -> Agent -> ... -> Agent -> DB
```
```


@@ -143,30 +143,12 @@ config = {
"config": {
"model": "text-embedding-ada-002"
}
},
"vectordb": {
"provider": "elasticsearch",
"config": {
"collection_name": "my-collection",
"cloud_id": "deployment-name:xxxx",
"api_key": "your-key",
"verify_certs": False
}
},
"chunker": {
"chunk_size": 400,
"chunk_overlap": 100,
"length_function": "len",
"min_chunk_size": 0
}
}
rag_tool = RagTool(config=config, summarize=True)
```
The internal RAG tool utilizes the Embedchain adapter, allowing you to pass any configuration options that are supported by Embedchain.
You can refer to the [Embedchain documentation](https://docs.embedchain.ai/components/introduction) for details.
Make sure to review the configuration options available in the .yaml file.
## Conclusion
The `RagTool` provides a powerful way to create and query knowledge bases from various data sources. By leveraging Retrieval-Augmented Generation, it enables agents to access and retrieve relevant information efficiently, enhancing their ability to provide accurate and contextually appropriate responses.


@@ -1,6 +1,6 @@
[project]
name = "crewai"
version = "0.119.0"
version = "0.118.0"
description = "Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks."
readme = "README.md"
requires-python = ">=3.10,<3.13"
@@ -11,7 +11,7 @@ dependencies = [
# Core Dependencies
"pydantic>=2.4.2",
"openai>=1.13.3",
"litellm==1.68.0",
"litellm==1.67.1",
"instructor>=1.3.3",
# Text Processing
"pdfplumber>=0.11.4",
@@ -45,7 +45,7 @@ Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = ["crewai-tools~=0.44.0"]
tools = ["crewai-tools~=0.42.2"]
embeddings = [
"tiktoken~=0.7.0"
]
@@ -94,10 +94,8 @@ crewai = "crewai.cli.cli:crewai"
[tool.mypy]
ignore_missing_imports = true
disable_error_code = 'import-untyped,union-attr'
disable_error_code = 'import-untyped'
exclude = ["cli/templates"]
implicit_optional = true
strict_optional = false
[tool.bandit]
exclude_dirs = ["src/crewai/cli/templates"]


@@ -17,7 +17,7 @@ warnings.filterwarnings(
category=UserWarning,
module="pydantic.main",
)
__version__ = "0.119.0"
__version__ = "0.118.0"
__all__ = [
"Agent",
"Crew",


@@ -1,7 +1,6 @@
import shutil
import subprocess
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any, Dict, List, Literal, Optional, Sequence, Type, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -32,14 +31,6 @@ from crewai.utilities.events.agent_events import (
AgentExecutionStartedEvent,
)
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
from crewai.utilities.events.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -68,41 +59,40 @@ class Agent(BaseAgent):
step_callback: Callback to be executed after each step of the agent execution.
knowledge_sources: Knowledge sources for the agent.
embedder: Embedder configuration for the agent.
"""
_times_executed: int = PrivateAttr(default=0)
max_execution_time: int | None = Field(
max_execution_time: Optional[int] = Field(
default=None,
description="Maximum execution time for an agent to execute a task",
)
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
step_callback: Any | None = Field(
step_callback: Optional[Any] = Field(
default=None,
description="Callback to be executed after each step of the agent execution.",
)
use_system_prompt: bool | None = Field(
use_system_prompt: Optional[bool] = Field(
default=True,
description="Use system prompt for the agent.",
)
llm: str | InstanceOf[BaseLLM] | Any = Field(
description="Language model that will run the agent.", default=None,
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
description="Language model that will run the agent.", default=None,
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
system_template: str | None = Field(
default=None, description="System format for the agent.",
system_template: Optional[str] = Field(
default=None, description="System format for the agent."
)
prompt_template: str | None = Field(
default=None, description="Prompt format for the agent.",
prompt_template: Optional[str] = Field(
default=None, description="Prompt format for the agent."
)
response_template: str | None = Field(
default=None, description="Response format for the agent.",
response_template: Optional[str] = Field(
default=None, description="Response format for the agent."
)
allow_code_execution: bool | None = Field(
default=False, description="Enable code execution for the agent.",
allow_code_execution: Optional[bool] = Field(
default=False, description="Enable code execution for the agent."
)
respect_context_window: bool = Field(
default=True,
@@ -120,22 +110,18 @@ class Agent(BaseAgent):
default="safe",
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
)
embedder: dict[str, Any] | None = Field(
embedder: Optional[Dict[str, Any]] = Field(
default=None,
description="Embedder configuration for the agent.",
)
agent_knowledge_context: str | None = Field(
agent_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the agent.",
)
crew_knowledge_context: str | None = Field(
crew_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the crew.",
)
knowledge_search_query: str | None = Field(
default=None,
description="Knowledge search query for the agent dynamically generated by the agent.",
)
@model_validator(mode="after")
def post_init_setup(self):
@@ -143,7 +129,7 @@ class Agent(BaseAgent):
self.llm = create_llm(self.llm)
if self.function_calling_llm and not isinstance(
self.function_calling_llm, BaseLLM,
self.function_calling_llm, BaseLLM
):
self.function_calling_llm = create_llm(self.function_calling_llm)
@@ -155,12 +141,12 @@ class Agent(BaseAgent):
return self
def _setup_agent_executor(self) -> None:
def _setup_agent_executor(self):
if not self.cache_handler:
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
@@ -176,8 +162,7 @@ class Agent(BaseAgent):
storage=self.knowledge_storage or None,
)
except (TypeError, ValueError) as e:
msg = f"Invalid Knowledge Configuration: {e!s}"
raise ValueError(msg)
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
def _is_any_available_memory(self) -> bool:
"""Check if any memory is available."""
@@ -199,8 +184,8 @@ class Agent(BaseAgent):
def execute_task(
self,
task: Task,
context: str | None = None,
tools: list[BaseTool] | None = None,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None
) -> str:
"""Execute a task with the agent.
@@ -216,7 +201,6 @@ class Agent(BaseAgent):
TimeoutError: If execution exceeds the maximum execution time.
ValueError: If the max execution time is not a positive integer.
RuntimeError: If the agent execution fails for other reasons.
"""
if self.tools_handler:
self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
@@ -232,18 +216,18 @@ class Agent(BaseAgent):
# schema = json.dumps(task.output_json, indent=2)
schema = generate_model_description(task.output_json)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions",
"formatted_task_instructions"
).format(output_format=schema)
elif task.output_pydantic:
schema = generate_model_description(task.output_pydantic)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions",
"formatted_task_instructions"
).format(output_format=schema)
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context,
task=task_prompt, context=context
)
if self._is_any_available_memory():
@@ -261,65 +245,27 @@ class Agent(BaseAgent):
knowledge_config = (
self.knowledge_config.model_dump() if self.knowledge_config else {}
)
if self.knowledge:
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalStartedEvent(
agent=self,
),
agent_knowledge_snippets = self.knowledge.query(
[task.prompt()], **knowledge_config
)
try:
self.knowledge_search_query = self._get_knowledge_search_query(
task_prompt,
if agent_knowledge_snippets:
self.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if self.knowledge_search_query:
agent_knowledge_snippets = self.knowledge.query(
[self.knowledge_search_query], **knowledge_config,
)
if agent_knowledge_snippets:
self.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets,
)
if self.agent_knowledge_context:
task_prompt += self.agent_knowledge_context
if self.crew:
knowledge_snippets = self.crew.query_knowledge(
[self.knowledge_search_query], **knowledge_config,
)
if knowledge_snippets:
self.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets,
)
if self.crew_knowledge_context:
task_prompt += self.crew_knowledge_context
if self.agent_knowledge_context:
task_prompt += self.agent_knowledge_context
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalCompletedEvent(
query=self.knowledge_search_query,
agent=self,
retrieved_knowledge=(
(self.agent_knowledge_context or "")
+ (
"\n"
if self.agent_knowledge_context
and self.crew_knowledge_context
else ""
)
+ (self.crew_knowledge_context or "")
),
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=KnowledgeSearchQueryFailedEvent(
query=self.knowledge_search_query or "",
agent=self,
error=str(e),
),
if self.crew:
knowledge_snippets = self.crew.query_knowledge(
[task.prompt()], **knowledge_config
)
if knowledge_snippets:
self.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if self.crew_knowledge_context:
task_prompt += self.crew_knowledge_context
tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task)
@@ -342,20 +288,12 @@ class Agent(BaseAgent):
# Determine execution method based on timeout setting
if self.max_execution_time is not None:
if (
not isinstance(self.max_execution_time, int)
or self.max_execution_time <= 0
):
msg = "Max Execution time must be a positive integer greater than zero"
raise ValueError(
msg,
)
result = self._execute_with_timeout(
task_prompt, task, self.max_execution_time,
)
if not isinstance(self.max_execution_time, int) or self.max_execution_time <= 0:
raise ValueError("Max Execution time must be a positive integer greater than zero")
result = self._execute_with_timeout(task_prompt, task, self.max_execution_time)
else:
result = self._execute_without_timeout(task_prompt, task)
except TimeoutError as e:
# Propagate TimeoutError without retry
crewai_event_bus.emit(
@@ -366,7 +304,7 @@ class Agent(BaseAgent):
error=str(e),
),
)
raise
raise e
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
@@ -378,7 +316,7 @@ class Agent(BaseAgent):
error=str(e),
),
)
raise
raise e
self._times_executed += 1
if self._times_executed > self.max_retry_limit:
crewai_event_bus.emit(
@@ -389,7 +327,7 @@ class Agent(BaseAgent):
error=str(e),
),
)
raise
raise e
result = self.execute_task(task, context, tools)
if self.max_rpm and self._rpm_controller:
@@ -407,52 +345,56 @@ class Agent(BaseAgent):
)
return result
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> str:
def _execute_with_timeout(
self,
task_prompt: str,
task: Task,
timeout: int
) -> str:
"""Execute a task with a timeout.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
timeout: Maximum execution time in seconds.
Returns:
The output of the agent.
Raises:
TimeoutError: If execution exceeds the timeout.
RuntimeError: If execution fails for other reasons.
"""
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
self._execute_without_timeout, task_prompt=task_prompt, task=task,
self._execute_without_timeout,
task_prompt=task_prompt,
task=task
)
try:
return future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
future.cancel()
msg = f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
raise TimeoutError(
msg,
)
raise TimeoutError(f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task.")
except Exception as e:
future.cancel()
msg = f"Task execution failed: {e!s}"
raise RuntimeError(msg)
raise RuntimeError(f"Task execution failed: {str(e)}")
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
def _execute_without_timeout(
self,
task_prompt: str,
task: Task
) -> str:
"""Execute a task without a timeout.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
Returns:
The output of the agent.
"""
return self.agent_executor.invoke(
{
@@ -460,19 +402,18 @@ class Agent(BaseAgent):
"tool_names": self.agent_executor.tools_names,
"tools": self.agent_executor.tools_description,
"ask_for_human_input": task.human_input,
},
}
)["output"]
def create_agent_executor(
self, tools: list[BaseTool] | None = None, task=None,
self, tools: Optional[List[BaseTool]] = None, task=None
) -> None:
"""Create an agent executor for the agent.
Returns:
An instance of the CrewAgentExecutor class.
"""
raw_tools: list[BaseTool] = tools or self.tools or []
raw_tools: List[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools)
prompt = Prompts(
@@ -489,7 +430,7 @@ class Agent(BaseAgent):
if self.response_template:
stop_words.append(
self.response_template.split("{{ .Response }}")[1].strip(),
self.response_template.split("{{ .Response }}")[1].strip()
)
self.agent_executor = CrewAgentExecutor(
@@ -514,9 +455,10 @@ class Agent(BaseAgent):
callbacks=[TokenCalcHandler(self._token_process)],
)
def get_delegation_tools(self, agents: list[BaseAgent]):
def get_delegation_tools(self, agents: List[BaseAgent]):
agent_tools = AgentTools(agents=agents)
return agent_tools.tools()
tools = agent_tools.tools()
return tools
def get_multimodal_tools(self) -> Sequence[BaseTool]:
from crewai.tools.agent_tools.add_image_tool import AddImageTool
@@ -532,7 +474,7 @@ class Agent(BaseAgent):
return [CodeInterpreterTool(unsafe_mode=unsafe_mode)]
except ModuleNotFoundError:
self._logger.log(
"info", "Coding tools not available. Install crewai_tools. ",
"info", "Coding tools not available. Install crewai_tools. "
)
def get_output_converter(self, llm, text, model, instructions):
@@ -564,7 +506,7 @@ class Agent(BaseAgent):
)
return task_prompt
def _render_text_description(self, tools: list[Any]) -> str:
def _render_text_description(self, tools: List[Any]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:
@@ -574,111 +516,57 @@ class Agent(BaseAgent):
search: This tool is used for search
calculator: This tool is used for math
"""
return "\n".join(
description = "\n".join(
[
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
for tool in tools
],
]
)
return description
def _validate_docker_installation(self) -> None:
"""Check if Docker is installed and running."""
if not shutil.which("docker"):
msg = f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
raise RuntimeError(
msg,
f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
)
try:
subprocess.run(
["docker", "info"],
check=True,
capture_output=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError:
msg = f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
raise RuntimeError(
msg,
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
)
def __repr__(self) -> str:
def __repr__(self):
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
@property
def fingerprint(self) -> Fingerprint:
"""Get the agent's fingerprint.
"""
Get the agent's fingerprint.
Returns:
Fingerprint: The agent's fingerprint
"""
return self.security_config.fingerprint
def set_fingerprint(self, fingerprint: Fingerprint) -> None:
def set_fingerprint(self, fingerprint: Fingerprint):
self.security_config.fingerprint = fingerprint
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
"""Generate a search query for the knowledge base based on the task description."""
crewai_event_bus.emit(
self,
event=KnowledgeQueryStartedEvent(
task_prompt=task_prompt,
agent=self,
),
)
query = self.i18n.slice("knowledge_search_query").format(
task_prompt=task_prompt,
)
rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
if not isinstance(self.llm, BaseLLM):
self._logger.log(
"warning",
f"Knowledge search query failed: LLM for agent '{self.role}' is not an instance of BaseLLM",
)
crewai_event_bus.emit(
self,
event=KnowledgeQueryFailedEvent(
agent=self,
error="LLM is not compatible with knowledge search queries",
),
)
return None
try:
rewritten_query = self.llm.call(
[
{
"role": "system",
"content": rewriter_prompt,
},
{"role": "user", "content": query},
],
)
crewai_event_bus.emit(
self,
event=KnowledgeQueryCompletedEvent(
query=query,
agent=self,
),
)
return rewritten_query
except Exception as e:
crewai_event_bus.emit(
self,
event=KnowledgeQueryFailedEvent(
agent=self,
error=str(e),
),
)
return None
def kickoff(
self,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
) -> LiteAgentOutput:
"""Execute the agent with the given messages using a LiteAgent instance.
"""
Execute the agent with the given messages using a LiteAgent instance.
This method is useful when you want to use the Agent configuration but
with the simpler and more direct execution flow of LiteAgent.
@@ -691,7 +579,6 @@ class Agent(BaseAgent):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
lite_agent = LiteAgent(
role=self.role,
@@ -712,10 +599,11 @@ class Agent(BaseAgent):
async def kickoff_async(
self,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
) -> LiteAgentOutput:
"""Execute the agent asynchronously with the given messages using a LiteAgent instance.
"""
Execute the agent asynchronously with the given messages using a LiteAgent instance.
This is the async version of the kickoff method.
@@ -727,7 +615,6 @@ class Agent(BaseAgent):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
lite_agent = LiteAgent(
role=self.role,


@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Dict, List, Optional
from pydantic import PrivateAttr
@@ -16,27 +16,27 @@ class BaseAgentAdapter(BaseAgent, ABC):
"""
adapted_structured_output: bool = False
_agent_config: dict[str, Any] | None = PrivateAttr(default=None)
_agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}
def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any) -> None:
def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any):
super().__init__(adapted_agent=True, **kwargs)
self._agent_config = agent_config
@abstractmethod
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure and adapt tools for the specific agent implementation.
Args:
tools: Optional list of BaseTool instances to be configured
"""
pass
def configure_structured_output(self, structured_output: Any) -> None:
"""Configure the structured output for the specific agent implementation.
Args:
structured_output: The structured output to be configured
"""
pass


@@ -8,7 +8,7 @@ class BaseConverterAdapter(ABC):
converter adapters must implement for converting structured output.
"""
def __init__(self, agent_adapter) -> None:
def __init__(self, agent_adapter):
self.agent_adapter = agent_adapter
@abstractmethod
@@ -16,11 +16,14 @@ class BaseConverterAdapter(ABC):
"""Configure agents to return structured output.
Must support json and pydantic output.
"""
pass
@abstractmethod
def enhance_system_prompt(self, base_prompt: str) -> str:
"""Enhance the system prompt with structured output instructions."""
pass
@abstractmethod
def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format: string."""
pass


@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, List, Optional
from crewai.tools.base_tool import BaseTool
@@ -12,23 +12,23 @@ class BaseToolAdapter(ABC):
different frameworks and platforms.
"""
original_tools: list[BaseTool]
converted_tools: list[Any]
original_tools: List[BaseTool]
converted_tools: List[Any]
def __init__(self, tools: list[BaseTool] | None = None) -> None:
def __init__(self, tools: Optional[List[BaseTool]] = None):
self.original_tools = tools or []
self.converted_tools = []
@abstractmethod
def configure_tools(self, tools: list[BaseTool]) -> None:
def configure_tools(self, tools: List[BaseTool]) -> None:
"""Configure and convert tools for the specific implementation.
Args:
tools: List of BaseTool instances to be configured and converted
"""
pass
def tools(self) -> list[Any]:
def tools(self) -> List[Any]:
"""Return all converted tools."""
return self.converted_tools


@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, AsyncIterable, Dict, List, Optional
from pydantic import Field, PrivateAttr
@@ -52,17 +52,16 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
role: str,
goal: str,
backstory: str,
tools: list[BaseTool] | None = None,
tools: Optional[List[BaseTool]] = None,
llm: Any = None,
max_iterations: int = 10,
agent_config: dict[str, Any] | None = None,
agent_config: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
):
"""Initialize the LangGraph agent adapter."""
if not LANGGRAPH_AVAILABLE:
msg = "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
raise ImportError(
msg,
"LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
)
super().__init__(
role=role,
@@ -83,7 +82,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
try:
self._memory = MemorySaver()
converted_tools: list[Any] = self._tool_adapter.tools()
converted_tools: List[Any] = self._tool_adapter.tools()
if self._agent_config:
self._graph = create_react_agent(
model=self.llm,
@@ -102,18 +101,18 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
except ImportError as e:
self._logger.log(
"error", f"Failed to import LangGraph dependencies: {e!s}",
"error", f"Failed to import LangGraph dependencies: {str(e)}"
)
raise
except Exception as e:
self._logger.log("error", f"Error setting up LangGraph agent: {e!s}")
self._logger.log("error", f"Error setting up LangGraph agent: {str(e)}")
raise
def _build_system_prompt(self) -> str:
"""Build a system prompt for the LangGraph agent."""
base_prompt = f"""
You are {self.role}.
Your goal is: {self.goal}
Your backstory: {self.backstory}
@@ -125,8 +124,8 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
def execute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
) -> str:
"""Execute a task using the LangGraph workflow."""
self.create_agent_executor(tools)
@@ -138,7 +137,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context,
task=task_prompt, context=context
)
crewai_event_bus.emit(
@@ -160,7 +159,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
"messages": [
("system", self._build_system_prompt()),
("user", task_prompt),
],
]
},
config,
)
@@ -181,14 +180,14 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(
agent=self, task=task, output=final_answer,
agent=self, task=task, output=final_answer
),
)
return final_answer
except Exception as e:
self._logger.log("error", f"Error executing LangGraph task: {e!s}")
self._logger.log("error", f"Error executing LangGraph task: {str(e)}")
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -199,11 +198,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
)
raise
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure the LangGraph agent for execution."""
self.configure_tools(tools)
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure tools for the LangGraph agent."""
if tools:
all_tools = list(self.tools or []) + list(tools or [])
@@ -211,13 +210,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
available_tools = self._tool_adapter.tools()
self._graph.tools = available_tools
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
"""Implement delegation tools support for LangGraph."""
agent_tools = AgentTools(agents=agents)
return agent_tools.tools()
def get_output_converter(
self, llm: Any, text: str, model: Any, instructions: str,
self, llm: Any, text: str, model: Any, instructions: str
) -> Any:
"""Convert output format if needed."""
return Converter(llm=llm, text=text, model=model, instructions=instructions)

View File

@@ -1,25 +1,29 @@
import inspect
from typing import Any
from typing import Any, List, Optional
from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
from crewai.tools.base_tool import BaseTool
class LangGraphToolAdapter(BaseToolAdapter):
"""Adapts CrewAI tools to LangGraph agent tool compatible format."""
"""Adapts CrewAI tools to LangGraph agent tool compatible format"""
def __init__(self, tools: list[BaseTool] | None = None) -> None:
def __init__(self, tools: Optional[List[BaseTool]] = None):
self.original_tools = tools or []
self.converted_tools = []
def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure and convert CrewAI tools to LangGraph-compatible format.
def configure_tools(self, tools: List[BaseTool]) -> None:
"""
Configure and convert CrewAI tools to LangGraph-compatible format.
LangGraph expects tools in langchain_core.tools format.
"""
from langchain_core.tools import BaseTool, StructuredTool
converted_tools = []
all_tools = tools + self.original_tools if self.original_tools else tools
if self.original_tools:
all_tools = tools + self.original_tools
else:
all_tools = tools
for tool in all_tools:
if isinstance(tool, BaseTool):
converted_tools.append(tool)
@@ -53,5 +57,5 @@ class LangGraphToolAdapter(BaseToolAdapter):
self.converted_tools = converted_tools
def tools(self) -> list[Any]:
def tools(self) -> List[Any]:
return self.converted_tools or []

View File

@@ -5,10 +5,10 @@ from crewai.utilities.converter import generate_model_description
class LangGraphConverterAdapter(BaseConverterAdapter):
"""Adapter for handling structured output conversion in LangGraph agents."""
"""Adapter for handling structured output conversion in LangGraph agents"""
def __init__(self, agent_adapter) -> None:
"""Initialize the converter adapter with a reference to the agent adapter."""
def __init__(self, agent_adapter):
"""Initialize the converter adapter with a reference to the agent adapter"""
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
@@ -32,7 +32,7 @@ class LangGraphConverterAdapter(BaseConverterAdapter):
self._system_prompt_appendix = self._generate_system_prompt_appendix()
def _generate_system_prompt_appendix(self) -> str:
"""Generate an appendix for the system prompt to enforce structured output."""
"""Generate an appendix for the system prompt to enforce structured output"""
if not self._output_format or not self._schema:
return ""
@@ -41,19 +41,19 @@ Important: Your final answer MUST be provided in the following structured format
{self._schema}
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
The output should be raw JSON that exactly matches the specified schema.
"""
def enhance_system_prompt(self, original_prompt: str) -> str:
"""Add structured output instructions to the system prompt if needed."""
"""Add structured output instructions to the system prompt if needed"""
if not self._system_prompt_appendix:
return original_prompt
return f"{original_prompt}\n{self._system_prompt_appendix}"
def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format."""
"""Post-process the result to ensure it matches the expected format"""
if not self._output_format:
return result

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, List, Optional
from pydantic import Field, PrivateAttr
@@ -29,13 +29,13 @@ except ImportError:
class OpenAIAgentAdapter(BaseAgentAdapter):
"""Adapter for OpenAI Assistants."""
"""Adapter for OpenAI Assistants"""
model_config = {"arbitrary_types_allowed": True}
_openai_agent: "OpenAIAgent" = PrivateAttr()
_logger: Logger = PrivateAttr(default_factory=lambda: Logger())
_active_thread: str | None = PrivateAttr(default=None)
_active_thread: Optional[str] = PrivateAttr(default=None)
function_calling_llm: Any = Field(default=None)
step_callback: Any = Field(default=None)
_tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
@@ -44,35 +44,35 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
def __init__(
self,
model: str = "gpt-4o-mini",
tools: list[BaseTool] | None = None,
agent_config: dict | None = None,
tools: Optional[List[BaseTool]] = None,
agent_config: Optional[dict] = None,
**kwargs,
) -> None:
):
if not OPENAI_AVAILABLE:
msg = "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
raise ImportError(
msg,
"OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
)
role = kwargs.pop("role", None)
goal = kwargs.pop("goal", None)
backstory = kwargs.pop("backstory", None)
super().__init__(
role=role,
goal=goal,
backstory=backstory,
tools=tools,
agent_config=agent_config,
**kwargs,
)
self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
self.llm = model
self._converter_adapter = OpenAIConverterAdapter(self)
else:
role = kwargs.pop("role", None)
goal = kwargs.pop("goal", None)
backstory = kwargs.pop("backstory", None)
super().__init__(
role=role,
goal=goal,
backstory=backstory,
tools=tools,
agent_config=agent_config,
**kwargs,
)
self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
self.llm = model
self._converter_adapter = OpenAIConverterAdapter(self)
def _build_system_prompt(self) -> str:
"""Build a system prompt for the OpenAI agent."""
base_prompt = f"""
You are {self.role}.
Your goal is: {self.goal}
Your backstory: {self.backstory}
@@ -84,10 +84,10 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
def execute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
) -> str:
"""Execute a task using the OpenAI Assistant."""
"""Execute a task using the OpenAI Assistant"""
self._converter_adapter.configure_structured_output(task)
self.create_agent_executor(tools)
@@ -98,7 +98,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
task_prompt = task.prompt()
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context,
task=task_prompt, context=context
)
crewai_event_bus.emit(
self,
@@ -114,13 +114,13 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(
agent=self, task=task, output=final_answer,
agent=self, task=task, output=final_answer
),
)
return final_answer
except Exception as e:
self._logger.log("error", f"Error executing OpenAI task: {e!s}")
self._logger.log("error", f"Error executing OpenAI task: {str(e)}")
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -131,8 +131,9 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
)
raise
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
"""Configure the OpenAI agent for execution.
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
"""
Configure the OpenAI agent for execution.
While OpenAI handles execution differently through Runner,
we can use this method to set up tools and configurations.
"""
@@ -151,27 +152,27 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
self.agent_executor = Runner
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure tools for the OpenAI Assistant."""
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure tools for the OpenAI Assistant"""
if tools:
self._tool_adapter.configure_tools(tools)
if self._tool_adapter.converted_tools:
self._openai_agent.tools = self._tool_adapter.converted_tools
def handle_execution_result(self, result: Any) -> str:
"""Process OpenAI Assistant execution result converting any structured output to a string."""
"""Process OpenAI Assistant execution result converting any structured output to a string"""
return self._converter_adapter.post_process_result(result.final_output)
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
"""Implement delegation tools support."""
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
"""Implement delegation tools support"""
agent_tools = AgentTools(agents=agents)
return agent_tools.tools()
tools = agent_tools.tools()
return tools
def configure_structured_output(self, task) -> None:
"""Configure the structured output for the specific agent implementation.
Args:
structured_output: The structured output to be configured
"""
self._converter_adapter.configure_structured_output(task)
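
For orientation only (not part of the diff), constructing this adapter might look roughly like the sketch below. The module path is a guess, since file names are stripped from this view, and the optional `openai-agents` dependency must be installed as the error message above notes.

```python
# Hypothetical usage; the module path is assumed, not taken from this diff.
from crewai.agents.agent_adapters.openai_agents.openai_adapter import OpenAIAgentAdapter  # assumed path

agent = OpenAIAgentAdapter(
    model="gpt-4o-mini",                      # default shown in the signature above
    role="Research Analyst",
    goal="Summarize recent AI papers",
    backstory="You condense technical material into short briefs.",
)
```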

View File

@@ -1,5 +1,5 @@
import inspect
from typing import Any
from typing import Any, List, Optional
from agents import FunctionTool, Tool
@@ -8,36 +8,42 @@ from crewai.tools import BaseTool
class OpenAIAgentToolAdapter(BaseToolAdapter):
"""Adapter for OpenAI Assistant tools."""
"""Adapter for OpenAI Assistant tools"""
def __init__(self, tools: list[BaseTool] | None = None) -> None:
def __init__(self, tools: Optional[List[BaseTool]] = None):
self.original_tools = tools or []
def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure tools for the OpenAI Assistant."""
all_tools = tools + self.original_tools if self.original_tools else tools
def configure_tools(self, tools: List[BaseTool]) -> None:
"""Configure tools for the OpenAI Assistant"""
if self.original_tools:
all_tools = tools + self.original_tools
else:
all_tools = tools
if all_tools:
self.converted_tools = self._convert_tools_to_openai_format(all_tools)
def _convert_tools_to_openai_format(
self, tools: list[BaseTool] | None,
) -> list[Tool]:
"""Convert CrewAI tools to OpenAI Assistant tool format."""
self, tools: Optional[List[BaseTool]]
) -> List[Tool]:
"""Convert CrewAI tools to OpenAI Assistant tool format"""
if not tools:
return []
def sanitize_tool_name(name: str) -> str:
"""Convert tool name to match OpenAI's required pattern."""
"""Convert tool name to match OpenAI's required pattern"""
import re
return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
return sanitized
def create_tool_wrapper(tool: BaseTool):
"""Create a wrapper function that handles the OpenAI function tool interface."""
"""Create a wrapper function that handles the OpenAI function tool interface"""
async def wrapper(context_wrapper: Any, arguments: Any) -> Any:
# Get the parameter name from the schema
param_name = next(iter(tool.args_schema.model_json_schema()["properties"].keys()))
param_name = list(
tool.args_schema.model_json_schema()["properties"].keys()
)[0]
# Handle different argument types
if isinstance(arguments, dict):

View File

@@ -7,7 +7,8 @@ from crewai.utilities.i18n import I18N
class OpenAIConverterAdapter(BaseConverterAdapter):
"""Adapter for handling structured output conversion in OpenAI agents.
"""
Adapter for handling structured output conversion in OpenAI agents.
This adapter enhances the OpenAI agent to handle structured output formats
and post-processes the results when needed.
@@ -16,22 +17,21 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
_output_format: The expected output format (json, pydantic, or None)
_schema: The schema description for the expected output
_output_model: The Pydantic model for the output
"""
def __init__(self, agent_adapter) -> None:
"""Initialize the converter adapter with a reference to the agent adapter."""
def __init__(self, agent_adapter):
"""Initialize the converter adapter with a reference to the agent adapter"""
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
self._output_model = None
def configure_structured_output(self, task) -> None:
"""Configure the structured output for OpenAI agent based on task requirements.
"""
Configure the structured output for OpenAI agent based on task requirements.
Args:
task: The task containing output format requirements
"""
# Reset configuration
self._output_format = None
@@ -55,14 +55,14 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
self._output_model = task.output_pydantic
def enhance_system_prompt(self, base_prompt: str) -> str:
"""Enhance the base system prompt with structured output requirements if needed.
"""
Enhance the base system prompt with structured output requirements if needed.
Args:
base_prompt: The original system prompt
Returns:
Enhanced system prompt with output format instructions if needed
"""
if not self._output_format:
return base_prompt
@@ -76,7 +76,8 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
return f"{base_prompt}\n\n{output_schema}"
def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format.
"""
Post-process the result to ensure it matches the expected format.
This method attempts to extract valid JSON from the result if necessary.
@@ -85,7 +86,6 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
Returns:
Processed result conforming to the expected output format
"""
if not self._output_format:
return result

View File

@@ -1,9 +1,8 @@
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, TypeVar
from typing import Any, Callable, Dict, List, Optional, TypeVar
from pydantic import (
UUID4,
@@ -15,7 +14,6 @@ from pydantic import (
model_validator,
)
from pydantic_core import PydanticCustomError
from typing_extensions import Self
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
@@ -27,6 +25,7 @@ from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController
from crewai.utilities.config import process_config
from crewai.utilities.converter import Converter
from crewai.utilities.string_utils import interpolate_only
T = TypeVar("T", bound="BaseAgent")
@@ -78,31 +77,30 @@ class BaseAgent(ABC, BaseModel):
Set the rpm controller for the agent.
set_private_attrs() -> "BaseAgent":
Set private attributes.
"""
__hash__ = object.__hash__ # type: ignore
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: RPMController | None = PrivateAttr(default=None)
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
_original_role: str | None = PrivateAttr(default=None)
_original_goal: str | None = PrivateAttr(default=None)
_original_backstory: str | None = PrivateAttr(default=None)
_original_role: Optional[str] = PrivateAttr(default=None)
_original_goal: Optional[str] = PrivateAttr(default=None)
_original_backstory: Optional[str] = PrivateAttr(default=None)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
config: dict[str, Any] | None = Field(
description="Configuration for the agent", default=None, exclude=True,
config: Optional[Dict[str, Any]] = Field(
description="Configuration for the agent", default=None, exclude=True
)
cache: bool = Field(
default=True, description="Whether the agent should use a cache for tool usage.",
default=True, description="Whether the agent should use a cache for tool usage."
)
verbose: bool = Field(
default=False, description="Verbose mode for the Agent Execution",
default=False, description="Verbose mode for the Agent Execution"
)
max_rpm: int | None = Field(
max_rpm: Optional[int] = Field(
default=None,
description="Maximum number of requests per minute for the agent execution to be respected.",
)
@@ -110,41 +108,41 @@ class BaseAgent(ABC, BaseModel):
default=False,
description="Enable agent to delegate and ask questions among each other.",
)
tools: list[BaseTool] | None = Field(
default_factory=list, description="Tools at agents' disposal",
tools: Optional[List[BaseTool]] = Field(
default_factory=list, description="Tools at agents' disposal"
)
max_iter: int = Field(
default=25, description="Maximum iterations for an agent to execute a task",
default=25, description="Maximum iterations for an agent to execute a task"
)
agent_executor: InstanceOf = Field(
default=None, description="An instance of the CrewAgentExecutor class.",
default=None, description="An instance of the CrewAgentExecutor class."
)
llm: Any = Field(
default=None, description="Language model that will run the agent.",
default=None, description="Language model that will run the agent."
)
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
cache_handler: InstanceOf[CacheHandler] | None = Field(
default=None, description="An instance of the CacheHandler class.",
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
default=None, description="An instance of the CacheHandler class."
)
tools_handler: InstanceOf[ToolsHandler] = Field(
default_factory=ToolsHandler,
description="An instance of the ToolsHandler class.",
)
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent.",
tools_results: List[Dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
)
max_tokens: int | None = Field(
default=None, description="Maximum number of tokens for the agent's execution.",
max_tokens: Optional[int] = Field(
default=None, description="Maximum number of tokens for the agent's execution."
)
knowledge: Knowledge | None = Field(
default=None, description="Knowledge for the agent.",
knowledge: Optional[Knowledge] = Field(
default=None, description="Knowledge for the agent."
)
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
default=None,
description="Knowledge sources for the agent.",
)
knowledge_storage: Any | None = Field(
knowledge_storage: Optional[Any] = Field(
default=None,
description="Custom knowledge storage for the agent.",
)
@@ -152,13 +150,13 @@ class BaseAgent(ABC, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the agent, including fingerprinting.",
)
callbacks: list[Callable] = Field(
default=[], description="Callbacks to be used for the agent",
callbacks: List[Callable] = Field(
default=[], description="Callbacks to be used for the agent"
)
adapted_agent: bool = Field(
default=False, description="Whether the agent is adapted",
default=False, description="Whether the agent is adapted"
)
knowledge_config: KnowledgeConfig | None = Field(
knowledge_config: Optional[KnowledgeConfig] = Field(
default=None,
description="Knowledge configuration for the agent such as limits and threshold",
)
@@ -170,7 +168,7 @@ class BaseAgent(ABC, BaseModel):
@field_validator("tools")
@classmethod
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
@@ -190,14 +188,11 @@ class BaseAgent(ABC, BaseModel):
# Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool))
else:
msg = (
raise ValueError(
f"Invalid tool type: {type(tool)}. "
"Tool must be an instance of BaseTool or "
"an object with 'name', 'func', and 'description' attributes."
)
raise ValueError(
msg,
)
return processed_tools
@model_validator(mode="after")
@@ -205,16 +200,15 @@ class BaseAgent(ABC, BaseModel):
# Validate required fields
for field in ["role", "goal", "backstory"]:
if getattr(self, field) is None:
msg = f"{field} must be provided either directly or through config"
raise ValueError(
msg,
f"{field} must be provided either directly or through config"
)
# Set private attributes
self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller:
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger,
max_rpm=self.max_rpm, logger=self._logger
)
if not self._token_process:
self._token_process = TokenProcess()
@@ -227,11 +221,10 @@ class BaseAgent(ABC, BaseModel):
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
if v:
msg = "may_not_set_field"
raise PydanticCustomError(
msg, "This field is not to be set by the user.", {},
"may_not_set_field", "This field is not to be set by the user.", {}
)
@model_validator(mode="after")
@@ -240,7 +233,7 @@ class BaseAgent(ABC, BaseModel):
self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller:
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger,
max_rpm=self.max_rpm, logger=self._logger
)
if not self._token_process:
self._token_process = TokenProcess()
@@ -259,8 +252,8 @@ class BaseAgent(ABC, BaseModel):
def execute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
) -> str:
pass
@@ -269,10 +262,11 @@ class BaseAgent(ABC, BaseModel):
pass
@abstractmethod
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
"""Set the task tools that init BaseAgenTools class."""
pass
def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
"""Create a deep copy of the Agent."""
exclude = {
"id",
@@ -315,7 +309,7 @@ class BaseAgent(ABC, BaseModel):
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
return type(self)(
copied_agent = type(self)(
**copied_data,
llm=existing_llm,
tools=self.tools,
@@ -324,8 +318,9 @@ class BaseAgent(ABC, BaseModel):
knowledge_storage=copied_knowledge_storage,
)
return copied_agent
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory."""
if self._original_role is None:
self._original_role = self.role
@@ -336,13 +331,13 @@ class BaseAgent(ABC, BaseModel):
if inputs:
self.role = interpolate_only(
input_string=self._original_role, inputs=inputs,
input_string=self._original_role, inputs=inputs
)
self.goal = interpolate_only(
input_string=self._original_goal, inputs=inputs,
input_string=self._original_goal, inputs=inputs
)
self.backstory = interpolate_only(
input_string=self._original_backstory, inputs=inputs,
input_string=self._original_backstory, inputs=inputs
)
def set_cache_handler(self, cache_handler: CacheHandler) -> None:
@@ -350,7 +345,6 @@ class BaseAgent(ABC, BaseModel):
Args:
cache_handler: An instance of the CacheHandler class.
"""
self.tools_handler = ToolsHandler()
if self.cache:
@@ -363,11 +357,10 @@ class BaseAgent(ABC, BaseModel):
Args:
rpm_controller: An instance of the RPMController class.
"""
if not self._rpm_controller:
self._rpm_controller = rpm_controller
self.create_agent_executor()
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
pass
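
Most of the changes in this file (and the surrounding ones) swap PEP 604/585 annotations for their `typing`-module equivalents. The two spellings are interchangeable at runtime, as the standalone sketch below shows; the motivation — presumably compatibility with Python versions older than 3.10 — is an inference, not stated in the diff.

```python
from typing import List, Optional


def run(tools: Optional[List[str]] = None) -> str:
    """Equivalent to `tools: list[str] | None = None` on Python 3.10+."""
    return ", ".join(tools or [])


print(run(["search", "scrape"]))  # -> search, scrape
print(run())                      # -> (empty string)
```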

View File

@@ -1,4 +1,3 @@
import contextlib
import time
from typing import TYPE_CHECKING
@@ -44,7 +43,8 @@ class CrewAgentExecutorMixin:
},
agent=self.agent.role,
)
except Exception:
except Exception as e:
print(f"Failed to add to short term memory: {e}")
pass
def _create_external_memory(self, output) -> None:
@@ -56,7 +56,7 @@ class CrewAgentExecutorMixin:
and hasattr(self.crew, "_external_memory")
and self.crew._external_memory
):
with contextlib.suppress(Exception):
try:
self.crew._external_memory.save(
value=output.text,
metadata={
@@ -64,6 +64,9 @@ class CrewAgentExecutorMixin:
},
agent=self.agent.role,
)
except Exception as e:
print(f"Failed to add to external memory: {e}")
pass
def _create_long_term_memory(self, output) -> None:
"""Create and save long-term and entity memory items based on evaluation."""
@@ -100,13 +103,15 @@ class CrewAgentExecutorMixin:
type=entity.type,
description=entity.description,
relationships="\n".join(
[f"- {r}" for r in entity.relationships],
[f"- {r}" for r in entity.relationships]
),
)
self.crew._entity_memory.save(entity_memory)
except AttributeError:
except AttributeError as e:
print(f"Missing attributes for long term memory: {e}")
pass
except Exception:
except Exception as e:
print(f"Failed to add to long term memory: {e}")
pass
elif (
self.crew
@@ -121,7 +126,7 @@ class CrewAgentExecutorMixin:
def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging."""
self._printer.print(
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m",
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
)
# Training mode prompt (single iteration)

View File

@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Optional
from pydantic import BaseModel, Field
class OutputConverter(BaseModel, ABC):
"""Abstract base class for converting task results into structured formats.
"""
Abstract base class for converting task results into structured formats.
This class provides a framework for converting unstructured text into
either Pydantic models or JSON, tailored for specific agent requirements.
@@ -18,7 +19,6 @@ class OutputConverter(BaseModel, ABC):
model (Any): The target model for structuring the output.
instructions (str): Specific instructions for the conversion process.
max_attempts (int): Maximum number of conversion attempts (default: 3).
"""
text: str = Field(description="Text to be converted.")
@@ -33,7 +33,9 @@ class OutputConverter(BaseModel, ABC):
@abstractmethod
def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic."""
pass
@abstractmethod
def to_json(self, current_attempt=1) -> dict:
"""Convert text to json."""
pass

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict, Optional
from pydantic import BaseModel, PrivateAttr
@@ -6,10 +6,10 @@ from pydantic import BaseModel, PrivateAttr
class CacheHandler(BaseModel):
"""Callback handler for tool usage."""
_cache: dict[str, Any] = PrivateAttr(default_factory=dict)
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
def add(self, tool, input, output) -> None:
def add(self, tool, input, output):
self._cache[f"{tool}-{input}"] = output
def read(self, tool, input) -> str | None:
def read(self, tool, input) -> Optional[str]:
return self._cache.get(f"{tool}-{input}")
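
A quick usage sketch of the cache above (not part of the diff); the import path matches the one used elsewhere in this changeset, and the key scheme is the `f"{tool}-{input}"` string shown in `add` and `read`.

```python
from crewai.agents.cache.cache_handler import CacheHandler

cache = CacheHandler()
cache.add(tool="search", input='{"query": "crewai"}', output="3 results")
print(cache.read(tool="search", input='{"query": "crewai"}'))  # -> 3 results
print(cache.read(tool="search", input='{"query": "other"}'))   # -> None (cache miss)
```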

View File

@@ -1,5 +1,6 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import json
import re
from typing import Any, Callable, Dict, List, Optional, Union
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
@@ -9,6 +10,8 @@ from crewai.agents.parser import (
OutputParserException,
)
from crewai.agents.tools_handler import ToolsHandler
from crewai.llm import BaseLLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer
@@ -31,10 +34,6 @@ from crewai.utilities.logger import Logger
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.training_handler import CrewTrainingHandler
if TYPE_CHECKING:
from crewai.llm import BaseLLM
from crewai.tools.base_tool import BaseTool
class CrewAgentExecutor(CrewAgentExecutorMixin):
_logger: Logger = Logger()
@@ -47,22 +46,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
agent: BaseAgent,
prompt: dict[str, str],
max_iter: int,
tools: list[CrewStructuredTool],
tools: List[CrewStructuredTool],
tools_names: str,
stop_words: list[str],
stop_words: List[str],
tools_description: str,
tools_handler: ToolsHandler,
step_callback: Any = None,
original_tools: list[Any] | None = None,
original_tools: List[Any] = [],
function_calling_llm: Any = None,
respect_context_window: bool = False,
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
) -> None:
if callbacks is None:
callbacks = []
if original_tools is None:
original_tools = []
request_within_rpm_limit: Optional[Callable[[], bool]] = None,
callbacks: List[Any] = [],
):
self._i18n: I18N = I18N()
self.llm: BaseLLM = llm
self.task = task
@@ -84,10 +79,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.ask_for_human_input = False
self.messages: list[dict[str, str]] = []
self.messages: List[Dict[str, str]] = []
self.iterations = 0
self.log_error_after = 3
self.tool_name_to_tool_map: dict[str, CrewStructuredTool | BaseTool] = {
self.tool_name_to_tool_map: Dict[str, Union[CrewStructuredTool, BaseTool]] = {
tool.name: tool for tool in self.tools
}
existing_stop = self.llm.stop or []
@@ -95,11 +90,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
set(
existing_stop + self.stop
if isinstance(existing_stop, list)
else self.stop,
),
else self.stop
)
)
def invoke(self, inputs: dict[str, str]) -> dict[str, Any]:
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
if "system" in self.prompt:
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
user_prompt = self._format_prompt(self.prompt.get("user", ""), inputs)
@@ -125,8 +120,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
handle_unknown_error(self._printer, e)
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
raise
raise
raise e
else:
raise e
if self.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer)
@@ -137,7 +133,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return {"output": formatted_answer.output}
def _invoke_loop(self) -> AgentFinish:
"""Main loop to invoke the agent's thought process until it reaches a conclusion
"""
Main loop to invoke the agent's thought process until it reaches a conclusion
or the maximum number of iterations is reached.
"""
formatted_answer = None
@@ -173,8 +170,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
):
fingerprint_context = {
"agent_fingerprint": str(
self.agent.security_config.fingerprint,
),
self.agent.security_config.fingerprint
)
}
tool_result = execute_tool_and_check_finality(
@@ -190,7 +187,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
function_calling_llm=self.function_calling_llm,
)
formatted_answer = self._handle_agent_action(
formatted_answer, tool_result,
formatted_answer, tool_result
)
self._invoke_step_callback(formatted_answer)
@@ -208,7 +205,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
raise
raise e
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
@@ -219,8 +216,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
i18n=self._i18n,
)
continue
handle_unknown_error(self._printer, e)
raise
else:
handle_unknown_error(self._printer, e)
raise e
finally:
self.iterations += 1
@@ -233,8 +231,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return formatted_answer
def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult,
) -> AgentAction | AgentFinish:
self, formatted_answer: AgentAction, tool_result: ToolResult
) -> Union[AgentAction, AgentFinish]:
"""Handle the AgentAction, execute tools, and process the results."""
# Special case for add_image_tool
add_image_tool = self._i18n.tools("add_image")
@@ -263,26 +261,24 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"""Append a message to the message list with the given role."""
self.messages.append(format_message_for_llm(text, role=role))
def _show_start_logs(self) -> None:
def _show_start_logs(self):
"""Show logs for the start of agent execution."""
if self.agent is None:
msg = "Agent cannot be None"
raise ValueError(msg)
raise ValueError("Agent cannot be None")
show_agent_logs(
printer=self._printer,
agent_role=self.agent.role,
task_description=(
self.task.description if self.task else "Not Found"
getattr(self.task, "description") if self.task else "Not Found"
),
verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
)
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
"""Show logs for the agent's execution."""
if self.agent is None:
msg = "Agent cannot be None"
raise ValueError(msg)
raise ValueError("Agent cannot be None")
show_agent_logs(
printer=self._printer,
agent_role=self.agent.role,
@@ -304,11 +300,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
summary = self.llm.call(
[
format_message_for_llm(
self._i18n.slice("summarizer_system_message"), role="system",
self._i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
self._i18n.slice("summarize_instruction").format(
group=group["content"],
group=group["content"]
),
),
],
@@ -320,12 +316,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages = [
format_message_for_llm(
self._i18n.slice("summary").format(merged_summary=merged_summary),
),
self._i18n.slice("summary").format(merged_summary=merged_summary)
)
]
def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: str | None = None,
self, result: AgentFinish, human_feedback: Optional[str] = None
) -> None:
"""Handle the process of saving training data."""
agent_id = str(self.agent.id) # type: ignore
@@ -352,27 +348,29 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"initial_output": result.output,
"human_feedback": human_feedback,
}
# Save improved output
elif train_iteration in agent_training_data:
agent_training_data[train_iteration]["improved_output"] = result.output
else:
self._printer.print(
content=(
f"No existing training data for agent {agent_id} and iteration "
f"{train_iteration}. Cannot save improved output."
),
color="red",
)
return
# Save improved output
if train_iteration in agent_training_data:
agent_training_data[train_iteration]["improved_output"] = result.output
else:
self._printer.print(
content=(
f"No existing training data for agent {agent_id} and iteration "
f"{train_iteration}. Cannot save improved output."
),
color="red",
)
return
# Update the training data and save
training_data[agent_id] = agent_training_data
training_handler.save(training_data)
def _format_prompt(self, prompt: str, inputs: dict[str, str]) -> str:
def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str:
prompt = prompt.replace("{input}", inputs["input"])
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
return prompt.replace("{tools}", inputs["tools"])
prompt = prompt.replace("{tools}", inputs["tools"])
return prompt
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
"""Handle human feedback with different flows for training vs regular use.
@@ -382,7 +380,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
AgentFinish: The final answer after processing feedback
"""
human_feedback = self._ask_human_input(formatted_answer.output)
@@ -396,14 +393,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return bool(self.crew and self.crew._train)
def _handle_training_feedback(
self, initial_answer: AgentFinish, feedback: str,
self, initial_answer: AgentFinish, feedback: str
) -> AgentFinish:
"""Process feedback for training scenarios with single iteration."""
self._handle_crew_training_output(initial_answer, feedback)
self.messages.append(
format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback),
),
self._i18n.slice("feedback_instructions").format(feedback=feedback)
)
)
improved_answer = self._invoke_loop()
self._handle_crew_training_output(improved_answer)
@@ -411,7 +408,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return improved_answer
def _handle_regular_feedback(
self, current_answer: AgentFinish, initial_feedback: str,
self, current_answer: AgentFinish, initial_feedback: str
) -> AgentFinish:
"""Process feedback for regular use with potential multiple iterations."""
feedback = initial_feedback
@@ -431,8 +428,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"""Process a single feedback iteration."""
self.messages.append(
format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback),
),
self._i18n.slice("feedback_instructions").format(feedback=feedback)
)
)
return self._invoke_loop()
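
The `_format_prompt` helper touched in this file is plain string substitution. Below is a standalone sketch of the same logic, for illustration only — it is not the executor's actual code.

```python
def format_prompt(prompt: str, inputs: dict) -> str:
    # Same three placeholder substitutions as _format_prompt above.
    prompt = prompt.replace("{input}", inputs["input"])
    prompt = prompt.replace("{tool_names}", inputs["tool_names"])
    return prompt.replace("{tools}", inputs["tools"])


print(format_prompt(
    "Tools:\n{tools}\nUse one of [{tool_names}].\nTask: {input}",
    {"input": "triage the bug report", "tool_names": "search", "tools": "search: web search"},
))
```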

View File

@@ -1,5 +1,5 @@
import re
from typing import Any
from typing import Any, Optional, Union
from json_repair import repair_json
@@ -18,7 +18,7 @@ class AgentAction:
text: str
result: str
def __init__(self, thought: str, tool: str, tool_input: str, text: str) -> None:
def __init__(self, thought: str, tool: str, tool_input: str, text: str):
self.thought = thought
self.tool = tool
self.tool_input = tool_input
@@ -30,7 +30,7 @@ class AgentFinish:
output: str
text: str
def __init__(self, thought: str, output: str, text: str) -> None:
def __init__(self, thought: str, output: str, text: str):
self.thought = thought
self.output = output
self.text = text
@@ -39,7 +39,7 @@ class AgentFinish:
class OutputParserException(Exception):
error: str
def __init__(self, error: str) -> None:
def __init__(self, error: str):
self.error = error
@@ -67,24 +67,24 @@ class CrewAgentParser:
_i18n: I18N = I18N()
agent: Any = None
def __init__(self, agent: Any | None = None) -> None:
def __init__(self, agent: Optional[Any] = None):
self.agent = agent
@staticmethod
def parse_text(text: str) -> AgentAction | AgentFinish:
"""Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
def parse_text(text: str) -> Union[AgentAction, AgentFinish]:
"""
Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
Args:
text: The text to parse.
Returns:
Either an AgentAction or AgentFinish based on the parsed content.
"""
parser = CrewAgentParser()
return parser.parse(text)
def parse(self, text: str) -> AgentAction | AgentFinish:
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
thought = self._extract_thought(text)
includes_answer = FINAL_ANSWER_ACTION in text
regex = (
@@ -102,7 +102,7 @@ class CrewAgentParser:
final_answer = final_answer[:-3].rstrip()
return AgentFinish(thought, final_answer, text)
if action_match:
elif action_match:
action = action_match.group(1)
clean_action = self._clean_action(action)
@@ -114,21 +114,21 @@ class CrewAgentParser:
return AgentAction(thought, clean_action, safe_tool_input, text)
if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
msg = f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}"
raise OutputParserException(
msg,
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}",
)
if not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL,
elif not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
):
raise OutputParserException(
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
)
format = self._i18n.slice("format_without_tools")
error = f"{format}"
raise OutputParserException(
error,
)
else:
format = self._i18n.slice("format_without_tools")
error = f"{format}"
raise OutputParserException(
error,
)
def _extract_thought(self, text: str) -> str:
thought_index = text.find("\nAction")
@@ -138,7 +138,8 @@ class CrewAgentParser:
return ""
thought = text[:thought_index].strip()
# Remove any triple backticks from the thought string
return thought.replace("```", "").strip()
thought = thought.replace("```", "").strip()
return thought
def _clean_action(self, text: str) -> str:
"""Clean action string by removing non-essential formatting characters."""

View File

@@ -1,8 +1,7 @@
from typing import Any
from crewai.tools.cache_tools.cache_tools import CacheTools
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
from typing import Any, Optional, Union
from ..tools.cache_tools.cache_tools import CacheTools
from ..tools.tool_calling import InstructorToolCalling, ToolCalling
from .cache.cache_handler import CacheHandler
@@ -10,16 +9,16 @@ class ToolsHandler:
"""Callback handler for tool usage."""
last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling")
cache: CacheHandler | None
cache: Optional[CacheHandler]
def __init__(self, cache: CacheHandler | None = None) -> None:
def __init__(self, cache: Optional[CacheHandler] = None):
"""Initialize the callback handler."""
self.cache = cache
self.last_used_tool = {} # type: ignore # BUG?: same as above
def on_tool_use(
self,
calling: ToolCalling | InstructorToolCalling,
calling: Union[ToolCalling, InstructorToolCalling],
output: str,
should_cache: bool = True,
) -> Any:

View File

@@ -9,9 +9,9 @@ def add_crew_to_flow(crew_name: str) -> None:
"""Add a new crew to the current flow."""
# Check if pyproject.toml exists in the current directory
if not Path("pyproject.toml").exists():
msg = "This command must be run from the root of a flow project."
print("This command must be run from the root of a flow project.")
raise click.ClickException(
msg,
"This command must be run from the root of a flow project."
)
# Determine the flow folder based on the current directory
@@ -19,8 +19,8 @@ def add_crew_to_flow(crew_name: str) -> None:
crews_folder = flow_folder / "src" / flow_folder.name / "crews"
if not crews_folder.exists():
msg = "Crews folder does not exist in the current flow."
raise click.ClickException(msg)
print("Crews folder does not exist in the current flow.")
raise click.ClickException("Crews folder does not exist in the current flow.")
# Create the crew within the flow's crews directory
create_embedded_crew(crew_name, parent_folder=crews_folder)
@@ -39,7 +39,7 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
if crew_folder.exists():
if not click.confirm(
f"Crew {folder_name} already exists. Do you want to override it?",
f"Crew {folder_name} already exists. Do you want to override it?"
):
click.secho("Operation cancelled.", fg="yellow")
return
@@ -66,5 +66,5 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
copy_template(src_file, dst_file, crew_name, class_name, folder_name)
click.secho(
f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True,
f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True
)

View File

@@ -1,6 +1,6 @@
import time
import webbrowser
from typing import Any
from typing import Any, Dict
import requests
from rich.console import Console
@@ -17,37 +17,38 @@ class AuthenticationCommand:
DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
def __init__(self) -> None:
def __init__(self):
self.token_manager = TokenManager()
def signup(self) -> None:
"""Sign up to CrewAI+."""
"""Sign up to CrewAI+"""
console.print("Signing Up to CrewAI+ \n", style="bold blue")
device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data)
return self._poll_for_token(device_code_data)
def _get_device_code(self) -> dict[str, Any]:
def _get_device_code(self) -> Dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
"client_id": AUTH0_CLIENT_ID,
"scope": "openid",
"audience": AUTH0_AUDIENCE,
}
response = requests.post(
url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20,
url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20
)
response.raise_for_status()
return response.json()
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
def _display_auth_instructions(self, device_code_data: Dict[str, str]) -> None:
"""Display the authentication instructions to the user."""
console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
"""Poll the server for the token."""
token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
@@ -80,7 +81,7 @@ class AuthenticationCommand:
)
console.print(
"\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n",
"\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n"
)
return
@@ -91,5 +92,5 @@ class AuthenticationCommand:
attempts += 1
console.print(
"Timeout: Failed to get the token. Please try again.", style="bold red",
"Timeout: Failed to get the token. Please try again.", style="bold red"
)

View File

@@ -5,5 +5,5 @@ def get_auth_token() -> str:
"""Get the authentication token."""
access_token = TokenManager().get_token()
if not access_token:
raise Exception
raise Exception()
return access_token

View File

@@ -3,6 +3,7 @@ import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from auth0.authentication.token_verifier import (
AsymmetricSignatureVerifier,
@@ -14,7 +15,8 @@ from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
def validate_token(id_token: str) -> None:
"""Verify the token and its precedence.
"""
Verify the token and its precedence
:param id_token:
"""
@@ -22,14 +24,15 @@ def validate_token(id_token: str) -> None:
issuer = f"https://{AUTH0_DOMAIN}/"
signature_verifier = AsymmetricSignatureVerifier(jwks_url)
token_verifier = TokenVerifier(
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID,
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
)
token_verifier.verify(id_token)
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""Initialize the TokenManager class.
"""
Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
"""
@@ -38,7 +41,8 @@ class TokenManager:
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""Get or create the encryption key.
"""
Get or create the encryption key.
:return: The encryption key.
"""
@@ -53,7 +57,8 @@ class TokenManager:
return new_key
def save_tokens(self, access_token: str, expires_in: int) -> None:
"""Save the access token and its expiration time.
"""
Save the access token and its expiration time.
:param access_token: The access token to save.
:param expires_in: The expiration time of the access token in seconds.
@@ -66,8 +71,9 @@ class TokenManager:
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> str | None:
"""Get the access token if it is valid and not expired.
def get_token(self) -> Optional[str]:
"""
Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
"""
@@ -83,7 +89,8 @@ class TokenManager:
return data["access_token"]
def get_secure_storage_path(self) -> Path:
"""Get the secure storage path based on the operating system.
"""
Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
@@ -105,7 +112,8 @@ class TokenManager:
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""Save the content to a secure file.
"""
Save the content to a secure file.
:param filename: The name of the file.
:param content: The content to save.
@@ -119,8 +127,9 @@ class TokenManager:
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> bytes | None:
"""Read the content of a secure file.
def read_secure_file(self, filename: str) -> Optional[bytes]:
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
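
A rough usage sketch of the token manager above (editorial addition, not part of the diff). The import path is a guess, since file names are stripped from this view; the call writes an encrypted `tokens.enc` under the platform-specific location returned by `get_secure_storage_path`.

```python
from crewai.cli.authentication.token import TokenManager  # assumed module path

manager = TokenManager()                      # default file_path="tokens.enc"
manager.save_tokens("example-access-token", expires_in=3600)
print(manager.get_token())                    # -> example-access-token (until expiry)
```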

View File

@@ -1,4 +1,6 @@
import os
from importlib.metadata import version as get_version
from typing import Optional, Tuple
import click
@@ -26,7 +28,7 @@ from .update_crew import update_crew
@click.group()
@click.version_option(get_version("crewai"))
def crewai() -> None:
def crewai():
"""Top-level command group for crewai."""
@@ -35,7 +37,7 @@ def crewai() -> None:
@click.argument("name")
@click.option("--provider", type=str, help="The provider to use for the crew")
@click.option("--skip_provider", is_flag=True, help="Skip provider validation")
def create(type, name, provider, skip_provider=False) -> None:
def create(type, name, provider, skip_provider=False):
"""Create a new crew, or flow."""
if type == "crew":
create_crew(name, provider, skip_provider)
@@ -47,9 +49,9 @@ def create(type, name, provider, skip_provider=False) -> None:
@crewai.command()
@click.option(
"--tools", is_flag=True, help="Show the installed version of crewai tools",
"--tools", is_flag=True, help="Show the installed version of crewai tools"
)
def version(tools) -> None:
def version(tools):
"""Show the installed version of crewai."""
try:
crewai_version = get_version("crewai")
@@ -80,7 +82,7 @@ def version(tools) -> None:
default="trained_agents_data.pkl",
help="Path to a custom file for training",
)
def train(n_iterations: int, filename: str) -> None:
def train(n_iterations: int, filename: str):
"""Train the crew."""
click.echo(f"Training the Crew for {n_iterations} iterations")
train_crew(n_iterations, filename)
@@ -94,11 +96,11 @@ def train(n_iterations: int, filename: str) -> None:
help="Replay the crew from this task ID, including all subsequent tasks.",
)
def replay(task_id: str) -> None:
"""Replay the crew execution from a specific task.
"""
Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
try:
click.echo(f"Replaying the crew from task {task_id}")
@@ -109,14 +111,16 @@ def replay(task_id: str) -> None:
@crewai.command()
def log_tasks_outputs() -> None:
"""Retrieve your latest crew.kickoff() task outputs."""
"""
Retrieve your latest crew.kickoff() task outputs.
"""
try:
storage = KickoffTaskOutputsSQLiteStorage()
tasks = storage.load()
if not tasks:
click.echo(
"No task outputs found. Only crew kickoff task outputs are logged.",
"No task outputs found. Only crew kickoff task outputs are logged."
)
return
@@ -149,11 +153,13 @@ def reset_memories(
kickoff_outputs: bool,
all: bool,
) -> None:
"""Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved."""
"""
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.
"""
try:
if not all and not (long or short or entities or knowledge or kickoff_outputs):
click.echo(
"Please specify at least one memory type to reset using the appropriate flags.",
"Please specify at least one memory type to reset using the appropriate flags."
)
return
reset_memories_command(long, short, entities, knowledge, kickoff_outputs, all)
@@ -176,69 +182,71 @@ def reset_memories(
default="gpt-4o-mini",
help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.",
)
def test(n_iterations: int, model: str) -> None:
def test(n_iterations: int, model: str):
"""Test the crew and evaluate the results."""
click.echo(f"Testing the crew for {n_iterations} iterations with model {model}")
evaluate_crew(n_iterations, model)
@crewai.command(
context_settings={
"ignore_unknown_options": True,
"allow_extra_args": True,
},
context_settings=dict(
ignore_unknown_options=True,
allow_extra_args=True,
)
)
@click.pass_context
def install(context) -> None:
def install(context):
"""Install the Crew."""
install_crew(context.args)
@crewai.command()
def run() -> None:
def run():
"""Run the Crew."""
run_crew()
@crewai.command()
def update() -> None:
def update():
"""Update the pyproject.toml of the Crew project to use uv."""
update_crew()
@crewai.command()
def signup() -> None:
def signup():
"""Sign Up/Login to CrewAI+."""
AuthenticationCommand().signup()
@crewai.command()
def login() -> None:
def login():
"""Sign Up/Login to CrewAI+."""
AuthenticationCommand().signup()
# DEPLOY CREWAI+ COMMANDS
@crewai.group()
def deploy() -> None:
def deploy():
"""Deploy the Crew CLI group."""
pass
@crewai.group()
def tool() -> None:
def tool():
"""Tool Repository related commands."""
pass
@deploy.command(name="create")
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
def deploy_create(yes: bool) -> None:
def deploy_create(yes: bool):
"""Create a Crew deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.create_crew(yes)
@deploy.command(name="list")
def deploy_list() -> None:
def deploy_list():
"""List all deployments."""
deploy_cmd = DeployCommand()
deploy_cmd.list_crews()
@@ -246,7 +254,7 @@ def deploy_list() -> None:
@deploy.command(name="push")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_push(uuid: str | None) -> None:
def deploy_push(uuid: Optional[str]):
"""Deploy the Crew."""
deploy_cmd = DeployCommand()
deploy_cmd.deploy(uuid=uuid)
@@ -254,7 +262,7 @@ def deploy_push(uuid: str | None) -> None:
@deploy.command(name="status")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deply_status(uuid: str | None) -> None:
def deply_status(uuid: Optional[str]):
"""Get the status of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_status(uuid=uuid)
@@ -262,7 +270,7 @@ def deply_status(uuid: str | None) -> None:
@deploy.command(name="logs")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_logs(uuid: str | None) -> None:
def deploy_logs(uuid: Optional[str]):
"""Get the logs of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_logs(uuid=uuid)
@@ -270,7 +278,7 @@ def deploy_logs(uuid: str | None) -> None:
@deploy.command(name="remove")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_remove(uuid: str | None) -> None:
def deploy_remove(uuid: Optional[str]):
"""Remove a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.remove_crew(uuid=uuid)
@@ -278,14 +286,14 @@ def deploy_remove(uuid: str | None) -> None:
@tool.command(name="create")
@click.argument("handle")
def tool_create(handle: str) -> None:
def tool_create(handle: str):
tool_cmd = ToolCommand()
tool_cmd.create(handle)
@tool.command(name="install")
@click.argument("handle")
def tool_install(handle: str) -> None:
def tool_install(handle: str):
tool_cmd = ToolCommand()
tool_cmd.login()
tool_cmd.install(handle)
@@ -301,26 +309,27 @@ def tool_install(handle: str) -> None:
)
@click.option("--public", "is_public", flag_value=True, default=False)
@click.option("--private", "is_public", flag_value=False)
def tool_publish(is_public: bool, force: bool) -> None:
def tool_publish(is_public: bool, force: bool):
tool_cmd = ToolCommand()
tool_cmd.login()
tool_cmd.publish(is_public, force)
@crewai.group()
def flow() -> None:
def flow():
"""Flow related commands."""
pass
@flow.command(name="kickoff")
def flow_run() -> None:
def flow_run():
"""Kickoff the Flow."""
click.echo("Running the Flow")
kickoff_flow()
@flow.command(name="plot")
def flow_plot() -> None:
def flow_plot():
"""Plot the Flow."""
click.echo("Plotting the Flow")
plot_flow()
@@ -328,19 +337,20 @@ def flow_plot() -> None:
@flow.command(name="add-crew")
@click.argument("crew_name")
def flow_add_crew(crew_name) -> None:
def flow_add_crew(crew_name):
"""Add a crew to an existing flow."""
click.echo(f"Adding crew {crew_name} to the flow")
add_crew_to_flow(crew_name)
@crewai.command()
def chat() -> None:
"""Start a conversation with the Crew, collecting user-supplied inputs,
def chat():
"""
Start a conversation with the Crew, collecting user-supplied inputs,
and using the Chat LLM to generate responses.
"""
click.secho(
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
)
run_chat()
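For reference, every command in this file hangs off the same Click pattern: a top-level `crewai` group, plain `@crewai.command()` subcommands, and nested groups such as `deploy`, `tool`, and `flow`. A minimal, self-contained sketch of that structure (the command names below are illustrative, not part of this diff):

import click

@click.group()
def crewai() -> None:
    """Top-level CLI group; every subcommand attaches to it."""

@crewai.command()
@click.option("-n", "--n-iterations", type=int, default=5, help="Number of iterations.")
def demo(n_iterations: int) -> None:
    """A hypothetical subcommand mirroring the decorator pattern above."""
    click.echo(f"Running {n_iterations} iterations")

@crewai.group()
def deploy() -> None:
    """Nested group, like the deploy/tool/flow groups in this file."""

@deploy.command(name="status")
def deploy_status() -> None:
    """Subcommand registered on the nested group: invoked as `crewai deploy status`."""
    click.echo("status ok")

if __name__ == "__main__":
    crewai()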

View File

@@ -10,13 +10,13 @@ console = Console()
class BaseCommand:
def __init__(self) -> None:
def __init__(self):
self._telemetry = Telemetry()
self._telemetry.set_tracer()
class PlusAPIMixin:
def __init__(self, telemetry) -> None:
def __init__(self, telemetry):
try:
telemetry.set_tracer()
self.plus_api_client = PlusAPI(api_key=get_auth_token())
@@ -30,11 +30,11 @@ class PlusAPIMixin:
raise SystemExit
def _validate_response(self, response: requests.Response) -> None:
"""Handle and display error messages from API responses.
"""
Handle and display error messages from API responses.
Args:
response (requests.Response): The response from the Plus API
"""
try:
json_response = response.json()
@@ -55,13 +55,13 @@ class PlusAPIMixin:
for field, messages in json_response.items():
for message in messages:
console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}",
f"* [bold red]{field.capitalize()}[/bold red] {message}"
)
raise SystemExit
if not response.ok:
console.print(
"Request to Enterprise API failed. Details:", style="bold red",
"Request to Enterprise API failed. Details:", style="bold red"
)
details = (
json_response.get("error")

View File

@@ -1,5 +1,6 @@
import json
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
@@ -7,16 +8,16 @@ DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
class Settings(BaseModel):
tool_repository_username: str | None = Field(
None, description="Username for interacting with the Tool Repository",
tool_repository_username: Optional[str] = Field(
None, description="Username for interacting with the Tool Repository"
)
tool_repository_password: str | None = Field(
None, description="Password for interacting with the Tool Repository",
tool_repository_password: Optional[str] = Field(
None, description="Password for interacting with the Tool Repository"
)
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True)
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data) -> None:
"""Load Settings from config path."""
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
"""Load Settings from config path"""
config_path.parent.mkdir(parents=True, exist_ok=True)
file_data = {}
@@ -31,7 +32,7 @@ class Settings(BaseModel):
super().__init__(config_path=config_path, **merged_data)
def dump(self) -> None:
"""Save current settings to settings.json."""
"""Save current settings to settings.json"""
if self.config_path.is_file():
with self.config_path.open("r") as f:
existing_data = json.load(f)
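For context, the Settings model above persists CLI configuration to a JSON file and merges three layers on load: model defaults, whatever is already stored on disk, and the keyword arguments passed in (explicit arguments win). A rough standalone sketch of that load/merge/dump round trip, assuming pydantic v2; the path and field below are illustrative only:

import json
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, Field

CONFIG_PATH = Path.home() / ".config" / "example" / "settings.json"  # illustrative path

class ExampleSettings(BaseModel):
    tool_repository_username: Optional[str] = Field(None, description="Example persisted field")
    config_path: Path = Field(default=CONFIG_PATH, exclude=True)

    def __init__(self, config_path: Path = CONFIG_PATH, **data):
        config_path.parent.mkdir(parents=True, exist_ok=True)
        file_data = {}
        if config_path.is_file():
            try:
                file_data = json.loads(config_path.read_text())
            except json.JSONDecodeError:
                file_data = {}
        # Values passed explicitly override values already stored on disk.
        super().__init__(config_path=config_path, **{**file_data, **data})

    def dump(self) -> None:
        # Merge with the existing file so keys written elsewhere are preserved.
        existing = json.loads(self.config_path.read_text()) if self.config_path.is_file() else {}
        existing.update(self.model_dump(exclude_unset=True))
        self.config_path.write_text(json.dumps(existing, indent=2))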

View File

@@ -3,31 +3,31 @@ ENV_VARS = {
{
"prompt": "Enter your OPENAI API key (press Enter to skip)",
"key_name": "OPENAI_API_KEY",
},
}
],
"anthropic": [
{
"prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
"key_name": "ANTHROPIC_API_KEY",
},
}
],
"gemini": [
{
"prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)",
"prompt": "Enter your GEMINI API key (press Enter to skip)",
"key_name": "GEMINI_API_KEY",
},
}
],
"nvidia_nim": [
{
"prompt": "Enter your NVIDIA API key (press Enter to skip)",
"key_name": "NVIDIA_NIM_API_KEY",
},
}
],
"groq": [
{
"prompt": "Enter your GROQ API key (press Enter to skip)",
"key_name": "GROQ_API_KEY",
},
}
],
"watson": [
{
@@ -47,7 +47,7 @@ ENV_VARS = {
{
"default": True,
"API_BASE": "http://localhost:11434",
},
}
],
"bedrock": [
{
@@ -101,7 +101,7 @@ ENV_VARS = {
{
"prompt": "Enter your SambaNovaCloud API key (press Enter to skip)",
"key_name": "SAMBANOVA_API_KEY",
},
}
],
}

View File

@@ -24,7 +24,7 @@ def create_folder_structure(name, parent_folder=None):
if folder_path.exists():
if not click.confirm(
f"Folder {folder_name} already exists. Do you want to override it?",
f"Folder {folder_name} already exists. Do you want to override it?"
):
click.secho("Operation cancelled.", fg="yellow")
sys.exit(0)
@@ -48,7 +48,7 @@ def create_folder_structure(name, parent_folder=None):
return folder_path, folder_name, class_name
def copy_template_files(folder_path, name, class_name, parent_folder) -> None:
def copy_template_files(folder_path, name, class_name, parent_folder):
package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew"
@@ -89,7 +89,7 @@ def copy_template_files(folder_path, name, class_name, parent_folder) -> None:
copy_template(src_file, dst_file, name, class_name, folder_path.name)
def create_crew(name, provider=None, skip_provider=False, parent_folder=None) -> None:
def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path)
if not skip_provider:
@@ -109,7 +109,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None) ->
if existing_provider:
if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?",
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
):
click.secho("Keeping existing provider configuration.", fg="yellow")
return
@@ -126,11 +126,11 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None) ->
if selected_provider: # Valid selection
break
click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red",
"No provider selected. Please try again or press 'q' to exit.", fg="red"
)
# Check if the selected provider has predefined models
if MODELS.get(selected_provider):
if selected_provider in MODELS and MODELS[selected_provider]:
while True:
selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q'
@@ -167,7 +167,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None) ->
click.secho("API keys and model saved to .env file", fg="green")
else:
click.secho(
"No API keys provided. Skipping .env file creation.", fg="yellow",
"No API keys provided. Skipping .env file creation.", fg="yellow"
)
click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")

View File

@@ -5,7 +5,7 @@ import click
from crewai.telemetry import Telemetry
def create_flow(name) -> None:
def create_flow(name):
"""Create a new flow."""
folder_name = name.replace(" ", "_").replace("-", "_").lower()
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
@@ -43,12 +43,12 @@ def create_flow(name) -> None:
"poem_crew",
]
def process_file(src_file, dst_file) -> None:
def process_file(src_file, dst_file):
if src_file.suffix in [".pyc", ".pyo", ".pyd"]:
return
try:
with open(src_file, encoding="utf-8") as file:
with open(src_file, "r", encoding="utf-8") as file:
content = file.read()
except Exception as e:
click.secho(f"Error processing file {src_file}: {e}", fg="red")

View File

@@ -5,7 +5,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Any
from typing import Any, Dict, List, Optional, Set, Tuple
import click
import tomli
@@ -22,9 +22,10 @@ MIN_REQUIRED_VERSION = "0.98.0"
def check_conversational_crews_version(
crewai_version: str, pyproject_data: dict,
crewai_version: str, pyproject_data: dict
) -> bool:
"""Check if the installed crewAI version supports conversational crews.
"""
Check if the installed crewAI version supports conversational crews.
Args:
crewai_version: The current version of crewAI.
@@ -32,7 +33,6 @@ def check_conversational_crews_version(
Returns:
bool: True if version check passes, False otherwise.
"""
try:
if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION):
@@ -48,8 +48,9 @@ def check_conversational_crews_version(
return True
def run_chat() -> None:
"""Runs an interactive chat loop using the Crew's chat LLM with function calling.
def run_chat():
"""
Runs an interactive chat loop using the Crew's chat LLM with function calling.
Incorporates crew_name, crew_description, and input fields to build a tool schema.
Exits if crew_name or crew_description are missing.
"""
@@ -83,7 +84,7 @@ def run_chat() -> None:
# Call the LLM to generate the introductory message
introductory_message = chat_llm.call(
messages=[{"role": "system", "content": system_message}],
messages=[{"role": "system", "content": system_message}]
)
finally:
# Stop loading indicator
@@ -107,13 +108,15 @@ def run_chat() -> None:
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
def show_loading(event: threading.Event) -> None:
def show_loading(event: threading.Event):
"""Display animated loading dots while processing."""
while not event.is_set():
print(".", end="", flush=True)
time.sleep(1)
print()
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)
@@ -154,7 +157,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
)
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
"""Creates a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs):
@@ -163,7 +166,7 @@ def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
return run_crew_tool_with_messages
def flush_input() -> None:
def flush_input():
"""Flush any pending input from the user."""
if platform.system() == "Windows":
# Windows platform
@@ -178,7 +181,7 @@ def flush_input() -> None:
termios.tcflush(sys.stdin, termios.TCIFLUSH)
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions) -> None:
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
"""Main chat loop for interacting with the user."""
while True:
try:
@@ -187,7 +190,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions) -> None
user_input = get_user_input()
handle_user_input(
user_input, chat_llm, messages, crew_tool_schema, available_functions,
user_input, chat_llm, messages, crew_tool_schema, available_functions
)
except KeyboardInterrupt:
@@ -218,9 +221,9 @@ def get_user_input() -> str:
def handle_user_input(
user_input: str,
chat_llm: LLM,
messages: list[dict[str, str]],
crew_tool_schema: dict[str, Any],
available_functions: dict[str, Any],
messages: List[Dict[str, str]],
crew_tool_schema: Dict[str, Any],
available_functions: Dict[str, Any],
) -> None:
if user_input.strip().lower() == "exit":
click.echo("Exiting chat. Goodbye!")
@@ -248,7 +251,8 @@ def handle_user_input(
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
"""Dynamically build a Littellm 'function' schema for the given crew.
"""
Dynamically build a Littellm 'function' schema for the given crew.
crew_name: The name of the crew (used for the function 'name').
crew_inputs: A ChatInputs object containing crew_description
@@ -277,8 +281,9 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
}
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
"""Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
"""
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
Args:
crew (Crew): The crew instance to run.
@@ -290,7 +295,6 @@ def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
Raises:
SystemExit: Exits the chat if an error occurs during crew execution.
"""
try:
# Serialize 'messages' to JSON string before adding to kwargs
@@ -300,8 +304,9 @@ def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
crew_output = crew.kickoff(inputs=kwargs)
# Convert CrewOutput to a string to send back to the user
return str(crew_output)
result = str(crew_output)
return result
except Exception as e:
# Exit the chat and show the error message
click.secho("An error occurred while running the crew:", fg="red")
@@ -309,12 +314,12 @@ def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
sys.exit(1)
def load_crew_and_name() -> tuple[Crew, str]:
"""Loads the crew by importing the crew class from the user's project.
def load_crew_and_name() -> Tuple[Crew, str]:
"""
Loads the crew by importing the crew class from the user's project.
Returns:
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
"""
# Get the current working directory
cwd = Path.cwd()
@@ -322,8 +327,7 @@ def load_crew_and_name() -> tuple[Crew, str]:
# Path to the pyproject.toml file
pyproject_path = cwd / "pyproject.toml"
if not pyproject_path.exists():
msg = "pyproject.toml not found in the current directory."
raise FileNotFoundError(msg)
raise FileNotFoundError("pyproject.toml not found in the current directory.")
# Load the pyproject.toml file using 'tomli'
with pyproject_path.open("rb") as f:
@@ -347,16 +351,14 @@ def load_crew_and_name() -> tuple[Crew, str]:
try:
crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
except ImportError as e:
msg = f"Failed to import crew module {crew_module_name}: {e}"
raise ImportError(msg)
raise ImportError(f"Failed to import crew module {crew_module_name}: {e}")
# Get the crew class from the module
try:
crew_class = getattr(crew_module, crew_class_name)
except AttributeError:
msg = f"Crew class {crew_class_name} not found in module {crew_module_name}"
raise AttributeError(
msg,
f"Crew class {crew_class_name} not found in module {crew_module_name}"
)
# Instantiate the crew
@@ -365,7 +367,8 @@ def load_crew_and_name() -> tuple[Crew, str]:
def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
"""Generates the ChatInputs required for the crew by analyzing the tasks and agents.
"""
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
Args:
crew (Crew): The crew object containing tasks and agents.
@@ -374,7 +377,6 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
Returns:
ChatInputs: An object containing the crew's name, description, and input fields.
"""
# Extract placeholders from tasks and agents
required_inputs = fetch_required_inputs(crew)
@@ -389,22 +391,22 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
crew_description = generate_crew_description_with_ai(crew, chat_llm)
return ChatInputs(
crew_name=crew_name, crew_description=crew_description, inputs=input_fields,
crew_name=crew_name, crew_description=crew_description, inputs=input_fields
)
def fetch_required_inputs(crew: Crew) -> set[str]:
"""Extracts placeholders from the crew's tasks and agents.
def fetch_required_inputs(crew: Crew) -> Set[str]:
"""
Extracts placeholders from the crew's tasks and agents.
Args:
crew (Crew): The crew object.
Returns:
Set[str]: A set of placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: set[str] = set()
required_inputs: Set[str] = set()
# Scan tasks
for task in crew.tasks:
@@ -420,7 +422,8 @@ def fetch_required_inputs(crew: Crew) -> set[str]:
def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
"""Generates an input description using AI based on the context of the crew.
"""
Generates an input description using AI based on the context of the crew.
Args:
input_name (str): The name of the input placeholder.
@@ -429,7 +432,6 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
Returns:
str: A concise description of the input.
"""
# Gather context from tasks and agents where the input is used
context_texts = []
@@ -442,10 +444,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
):
# Replace placeholders with input names
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or "",
lambda m: m.group(1), task.description or ""
)
expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output or "",
lambda m: m.group(1), task.expected_output or ""
)
context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}")
@@ -459,7 +461,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub(
lambda m: m.group(1), agent.backstory or "",
lambda m: m.group(1), agent.backstory or ""
)
context_texts.append(f"Agent Role: {agent_role}")
context_texts.append(f"Agent Goal: {agent_goal}")
@@ -468,8 +470,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
context = "\n".join(context_texts)
if not context:
# If no context is found for the input, raise an exception as per instruction
msg = f"No context found for input '{input_name}'."
raise ValueError(msg)
raise ValueError(f"No context found for input '{input_name}'.")
prompt = (
f"Based on the following context, write a concise description (15 words or less) of the input '{input_name}'.\n"
@@ -478,12 +479,14 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
return response.strip()
description = response.strip()
return description
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
"""Generates a brief description of the crew using AI.
"""
Generates a brief description of the crew using AI.
Args:
crew (Crew): The crew object.
@@ -491,7 +494,6 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
Returns:
str: A concise description of the crew's purpose (15 words or less).
"""
# Gather context from tasks and agents
context_texts = []
@@ -500,10 +502,10 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
for task in crew.tasks:
# Replace placeholders with input names
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or "",
lambda m: m.group(1), task.description or ""
)
expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output or "",
lambda m: m.group(1), task.expected_output or ""
)
context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}")
@@ -512,7 +514,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub(
lambda m: m.group(1), agent.backstory or "",
lambda m: m.group(1), agent.backstory or ""
)
context_texts.append(f"Agent Role: {agent_role}")
context_texts.append(f"Agent Goal: {agent_goal}")
@@ -520,8 +522,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
context = "\n".join(context_texts)
if not context:
msg = "No context found for generating crew description."
raise ValueError(msg)
raise ValueError("No context found for generating crew description.")
prompt = (
"Based on the following context, write a concise, action-oriented description (15 words or less) of the crew's purpose.\n"
@@ -530,5 +531,6 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
return response.strip()
crew_description = response.strip()
return crew_description
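All of the helpers above revolve around the same `\{(.+?)\}` placeholder pattern: collect the `{placeholder}` names that appear in task and agent text to learn which inputs the crew needs, and strip the braces when building prompt context. A small self-contained illustration of both operations (the sample strings are made up):

import re

PLACEHOLDER_PATTERN = re.compile(r"\{(.+?)\}")

task_description = "Research {topic} and summarize it for {audience}."
agent_goal = "Write a short report about {topic}."

# Collect the distinct placeholder names, as fetch_required_inputs does.
required_inputs = set(PLACEHOLDER_PATTERN.findall(task_description))
required_inputs.update(PLACEHOLDER_PATTERN.findall(agent_goal))
print(required_inputs)  # {'topic', 'audience'}

# Drop the braces when assembling LLM context, as the description
# generators do with placeholder_pattern.sub(lambda m: m.group(1), ...).
print(PLACEHOLDER_PATTERN.sub(lambda m: m.group(1), task_description))
# Research topic and summarize it for audience.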

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict, List, Optional
from rich.console import Console
@@ -10,27 +10,34 @@ console = Console()
class DeployCommand(BaseCommand, PlusAPIMixin):
"""A class to handle deployment-related operations for CrewAI projects."""
"""
A class to handle deployment-related operations for CrewAI projects.
"""
def __init__(self):
"""
Initialize the DeployCommand with project name and API client.
"""
def __init__(self) -> None:
"""Initialize the DeployCommand with project name and API client."""
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
self.project_name = get_project_name(require=True)
def _standard_no_param_error_message(self) -> None:
"""Display a standard error message when no UUID or project name is available."""
"""
Display a standard error message when no UUID or project name is available.
"""
console.print(
"No UUID provided, project pyproject.toml not found or with error.",
style="bold red",
)
def _display_deployment_info(self, json_response: dict[str, Any]) -> None:
"""Display deployment information.
def _display_deployment_info(self, json_response: Dict[str, Any]) -> None:
"""
Display deployment information.
Args:
json_response (Dict[str, Any]): The deployment information to display.
"""
console.print("Deploying the crew...\n", style="bold blue")
for key, value in json_response.items():
@@ -40,24 +47,24 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
console.print(" or")
console.print(f"crewai deploy status --uuid \"{json_response['uuid']}\"")
def _display_logs(self, log_messages: list[dict[str, Any]]) -> None:
"""Display log messages.
def _display_logs(self, log_messages: List[Dict[str, Any]]) -> None:
"""
Display log messages.
Args:
log_messages (List[Dict[str, Any]]): The log messages to display.
"""
for log_message in log_messages:
console.print(
f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}",
f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}"
)
def deploy(self, uuid: str | None = None) -> None:
"""Deploy a crew using either UUID or project name.
def deploy(self, uuid: Optional[str] = None) -> None:
"""
Deploy a crew using either UUID or project name.
Args:
uuid (Optional[str]): The UUID of the crew to deploy.
"""
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
console.print("Starting deployment...", style="bold blue")
@@ -73,7 +80,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._display_deployment_info(response.json())
def create_crew(self, confirm: bool = False) -> None:
"""Create a new crew deployment."""
"""
Create a new crew deployment.
"""
self._create_crew_deployment_span = (
self._telemetry.create_crew_deployment_span()
)
@@ -101,28 +110,29 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._display_creation_success(response.json())
def _confirm_input(
self, env_vars: dict[str, str], remote_repo_url: str, confirm: bool,
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
) -> None:
"""Confirm input parameters with the user.
"""
Confirm input parameters with the user.
Args:
env_vars (Dict[str, str]): Environment variables.
remote_repo_url (str): Remote repository URL.
confirm (bool): Whether to confirm input.
"""
if not confirm:
input(f"Press Enter to continue with the following Env vars: {env_vars}")
input(
f"Press Enter to continue with the following remote repository: {remote_repo_url}\n",
f"Press Enter to continue with the following remote repository: {remote_repo_url}\n"
)
def _create_payload(
self,
env_vars: dict[str, str],
env_vars: Dict[str, str],
remote_repo_url: str,
) -> dict[str, Any]:
"""Create the payload for crew creation.
) -> Dict[str, Any]:
"""
Create the payload for crew creation.
Args:
remote_repo_url (str): Remote repository URL.
@@ -130,26 +140,25 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
Returns:
Dict[str, Any]: The payload for crew creation.
"""
return {
"deploy": {
"name": self.project_name,
"repo_clone_url": remote_repo_url,
"env": env_vars,
},
}
}
def _display_creation_success(self, json_response: dict[str, Any]) -> None:
"""Display success message after crew creation.
def _display_creation_success(self, json_response: Dict[str, Any]) -> None:
"""
Display success message after crew creation.
Args:
json_response (Dict[str, Any]): The response containing crew information.
"""
console.print("Deployment created successfully!\n", style="bold green")
console.print(
f"Name: {self.project_name} ({json_response['uuid']})", style="bold green",
f"Name: {self.project_name} ({json_response['uuid']})", style="bold green"
)
console.print(f"Status: {json_response['status']}", style="bold green")
console.print("\nTo (re)deploy the crew, run:")
@@ -158,7 +167,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
console.print(f"crewai deploy push --uuid {json_response['uuid']}")
def list_crews(self) -> None:
"""List all available crews."""
"""
List all available crews.
"""
console.print("Listing all Crews\n", style="bold blue")
response = self.plus_api_client.list_crews()
@@ -168,29 +179,31 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
else:
self._display_no_crews_message()
def _display_crews(self, crews_data: list[dict[str, Any]]) -> None:
"""Display the list of crews.
def _display_crews(self, crews_data: List[Dict[str, Any]]) -> None:
"""
Display the list of crews.
Args:
crews_data (List[Dict[str, Any]]): List of crew data to display.
"""
for crew_data in crews_data:
console.print(
f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]",
f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]"
)
def _display_no_crews_message(self) -> None:
"""Display a message when no crews are available."""
"""
Display a message when no crews are available.
"""
console.print("You don't have any Crews yet. Let's create one!", style="yellow")
console.print(" crewai create crew <crew_name>", style="green")
def get_crew_status(self, uuid: str | None = None) -> None:
"""Get the status of a crew.
def get_crew_status(self, uuid: Optional[str] = None) -> None:
"""
Get the status of a crew.
Args:
uuid (Optional[str]): The UUID of the crew to check.
"""
console.print("Fetching deployment status...", style="bold blue")
if uuid:
@@ -204,23 +217,23 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._validate_response(response)
self._display_crew_status(response.json())
def _display_crew_status(self, status_data: dict[str, str]) -> None:
"""Display the status of a crew.
def _display_crew_status(self, status_data: Dict[str, str]) -> None:
"""
Display the status of a crew.
Args:
status_data (Dict[str, str]): The status data to display.
"""
console.print(f"Name:\t {status_data['name']}")
console.print(f"Status:\t {status_data['status']}")
def get_crew_logs(self, uuid: str | None, log_type: str = "deployment") -> None:
"""Get logs for a crew.
def get_crew_logs(self, uuid: Optional[str], log_type: str = "deployment") -> None:
"""
Get logs for a crew.
Args:
uuid (Optional[str]): The UUID of the crew to get logs for.
log_type (str): The type of logs to retrieve (default: "deployment").
"""
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
console.print(f"Fetching {log_type} logs...", style="bold blue")
@@ -236,12 +249,12 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._validate_response(response)
self._display_logs(response.json())
def remove_crew(self, uuid: str | None) -> None:
"""Remove a crew deployment.
def remove_crew(self, uuid: Optional[str]) -> None:
"""
Remove a crew deployment.
Args:
uuid (Optional[str]): The UUID of the crew to remove.
"""
self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
console.print("Removing deployment...", style="bold blue")
@@ -256,9 +269,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
if response.status_code == 204:
console.print(
f"Crew '{self.project_name}' removed successfully.", style="green",
f"Crew '{self.project_name}' removed successfully.", style="green"
)
else:
console.print(
f"Failed to remove crew '{self.project_name}'", style="bold red",
f"Failed to remove crew '{self.project_name}'", style="bold red"
)
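Most methods in DeployCommand follow the same shape: hit the `*_by_uuid` endpoint when a UUID was supplied, otherwise fall back to the project name read from pyproject.toml, then validate and render the response. A compressed sketch of that dispatch; the client below is a stand-in, not the real PlusAPI:

class FakeClient:
    # Stand-ins for the by-uuid / by-name endpoints used above.
    def crew_status_by_uuid(self, uuid):
        return {"name": "demo", "status": "online"}

    def crew_status_by_name(self, project_name):
        return {"name": project_name, "status": "online"}

def get_crew_status(client, uuid=None, project_name=None):
    if uuid:
        response = client.crew_status_by_uuid(uuid)
    elif project_name:
        response = client.crew_status_by_name(project_name)
    else:
        raise SystemExit("No UUID provided and no project name available.")
    # The real command validates the HTTP response and prints it with rich.
    return response

print(get_crew_status(FakeClient(), project_name="my_crew"))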

View File

@@ -4,19 +4,18 @@ import click
def evaluate_crew(n_iterations: int, model: str) -> None:
"""Test and Evaluate the crew by running a command in the UV environment.
"""
Test and Evaluate the crew by running a command in the UV environment.
Args:
n_iterations (int): The number of iterations to test the crew.
model (str): The model to test the crew with.
"""
command = ["uv", "run", "test", str(n_iterations), model]
try:
if n_iterations <= 0:
msg = "The number of iterations must be a positive integer."
raise ValueError(msg)
raise ValueError("The number of iterations must be a positive integer.")
result = subprocess.run(command, capture_output=False, text=True, check=True)

View File

@@ -1,18 +1,16 @@
import subprocess
from functools import cache
from functools import lru_cache
class Repository:
def __init__(self, path=".") -> None:
def __init__(self, path="."):
self.path = path
if not self.is_git_installed():
msg = "Git is not installed or not found in your PATH."
raise ValueError(msg)
raise ValueError("Git is not installed or not found in your PATH.")
if not self.is_git_repo():
msg = f"{self.path} is not a Git repository."
raise ValueError(msg)
raise ValueError(f"{self.path} is not a Git repository.")
self.fetch()
@@ -20,7 +18,7 @@ class Repository:
"""Check if Git is installed and available in the system."""
try:
subprocess.run(
["git", "--version"], capture_output=True, check=True, text=True,
["git", "--version"], capture_output=True, check=True, text=True
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
@@ -38,7 +36,7 @@ class Repository:
encoding="utf-8",
).strip()
@cache
@lru_cache(maxsize=None)
def is_git_repo(self) -> bool:
"""Check if the current directory is a git repository."""
try:
@@ -64,7 +62,10 @@ class Repository:
def is_synced(self) -> bool:
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
return not (self.has_uncommitted_changes() or self.is_ahead_or_behind())
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
return False
else:
return True
def origin_url(self) -> str | None:
"""Get the Git repository's remote URL."""

View File

@@ -4,13 +4,15 @@ import click
# Be mindful about changing this.
# on some environments we don't use this command but instead uv sync directly
# on some enviorments we don't use this command but instead uv sync directly
# so if you expect this to support more things you will need to replicate it there
# ask @joaomdmoura if you are unsure
def install_crew(proxy_options: list[str]) -> None:
"""Install the crew by running the UV command to lock and install."""
"""
Install the crew by running the UV command to lock and install.
"""
try:
command = ["uv", "sync", *proxy_options]
command = ["uv", "sync"] + proxy_options
subprocess.run(command, check=True, capture_output=False, text=True)
except subprocess.CalledProcessError as e:

View File

@@ -4,7 +4,9 @@ import click
def kickoff_flow() -> None:
"""Kickoff the flow by running a command in the UV environment."""
"""
Kickoff the flow by running a command in the UV environment.
"""
command = ["uv", "run", "kickoff"]
try:

View File

@@ -4,7 +4,9 @@ import click
def plot_flow() -> None:
"""Plot the flow by running a command in the UV environment."""
"""
Plot the flow by running a command in the UV environment.
"""
command = ["uv", "run", "plot"]
try:

View File

@@ -1,4 +1,5 @@
from os import getenv
from typing import Optional
from urllib.parse import urljoin
import requests
@@ -7,7 +8,9 @@ from crewai.cli.version import get_crewai_version
class PlusAPI:
"""This class exposes methods for working with the CrewAI+ API."""
"""
This class exposes methods for working with the CrewAI+ API.
"""
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
@@ -39,7 +42,7 @@ class PlusAPI:
handle: str,
is_public: bool,
version: str,
description: str | None,
description: Optional[str],
encoded_file: str,
):
params = {
@@ -53,7 +56,7 @@ class PlusAPI:
def deploy_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy",
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> requests.Response:
@@ -61,29 +64,29 @@ class PlusAPI:
def crew_status_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status",
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
)
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment",
self, project_name: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}",
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def crew_by_uuid(
self, uuid: str, log_type: str = "deployment",
self, uuid: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}",
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
)
def delete_crew_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}",
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
)
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:

View File

@@ -10,7 +10,8 @@ from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message, choices):
"""Presents a list of choices to the user and prompts them to select one.
"""
Presents a list of choices to the user and prompts them to select one.
Args:
- prompt_message (str): The message to display to the user before presenting the choices.
@@ -18,11 +19,11 @@ def select_choice(prompt_message, choices):
Returns:
- str: The selected choice from the list, or None if the user chooses to quit.
"""
provider_models = get_provider_data()
if not provider_models:
return None
return
click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan")
@@ -30,7 +31,7 @@ def select_choice(prompt_message, choices):
while True:
choice = click.prompt(
"Enter the number of your choice or 'q' to quit", type=str,
"Enter the number of your choice or 'q' to quit", type=str
)
if choice.lower() == "q":
@@ -50,7 +51,8 @@ def select_choice(prompt_message, choices):
def select_provider(provider_models):
"""Presents a list of providers to the user and prompts them to select one.
"""
Presents a list of providers to the user and prompts them to select one.
Args:
- provider_models (dict): A dictionary of provider models.
@@ -58,13 +60,12 @@ def select_provider(provider_models):
Returns:
- str: The selected provider
- None: If user explicitly quits
"""
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice(
"Select a provider to set up:", [*predefined_providers, "other"],
"Select a provider to set up:", predefined_providers + ["other"]
)
if provider is None: # User typed 'q'
return None
@@ -78,7 +79,8 @@ def select_provider(provider_models):
def select_model(provider, provider_models):
"""Presents a list of models for a given provider to the user and prompts them to select one.
"""
Presents a list of models for a given provider to the user and prompts them to select one.
Args:
- provider (str): The provider for which to select a model.
@@ -86,7 +88,6 @@ def select_model(provider, provider_models):
Returns:
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
@@ -99,13 +100,15 @@ def select_model(provider, provider_models):
click.secho(f"No models available for provider '{provider}'.", fg="red")
return None
return select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models,
selected_model = select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
)
return selected_model
def load_provider_data(cache_file, cache_expiry):
"""Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
"""
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
Args:
- cache_file (Path): The path to the cache file.
@@ -113,7 +116,6 @@ def load_provider_data(cache_file, cache_expiry):
Returns:
- dict or None: The loaded provider data or None if the operation fails.
"""
current_time = time.time()
if (
@@ -124,7 +126,7 @@ def load_provider_data(cache_file, cache_expiry):
if data:
return data
click.secho(
"Cache is corrupted. Fetching provider data from the web...", fg="yellow",
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
)
else:
click.secho(
@@ -135,31 +137,31 @@ def load_provider_data(cache_file, cache_expiry):
def read_cache_file(cache_file):
"""Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
"""
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
Args:
- cache_file (Path): The path to the cache file.
Returns:
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
"""
try:
with open(cache_file) as f:
with open(cache_file, "r") as f:
return json.load(f)
except json.JSONDecodeError:
return None
def fetch_provider_data(cache_file):
"""Fetches provider data from a specified URL and caches it to a file.
"""
Fetches provider data from a specified URL and caches it to a file.
Args:
- cache_file (Path): The path to the cache file.
Returns:
- dict or None: The fetched provider data or None if the operation fails.
"""
try:
response = requests.get(JSON_URL, stream=True, timeout=60)
@@ -176,20 +178,20 @@ def fetch_provider_data(cache_file):
def download_data(response):
"""Downloads data from a given HTTP response and returns the JSON content.
"""
Downloads data from a given HTTP response and returns the JSON content.
Args:
- response (requests.Response): The HTTP response object.
Returns:
- dict: The JSON content of the response.
"""
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
data_chunks = []
with click.progressbar(
length=total_size, label="Downloading", show_pos=True,
length=total_size, label="Downloading", show_pos=True
) as progress_bar:
for chunk in response.iter_content(block_size):
if chunk:
@@ -200,11 +202,11 @@ def download_data(response):
def get_provider_data():
"""Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
"""
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
"""
cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True)

View File

@@ -4,11 +4,11 @@ import click
def replay_task_command(task_id: str) -> None:
"""Replay the crew execution from a specific task.
"""
Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
command = ["uv", "run", "replay", task_id]

View File

@@ -2,7 +2,7 @@ import subprocess
import click
from crewai.cli.utils import get_crews
from crewai.cli.utils import get_crew
def reset_memories_command(
@@ -13,7 +13,8 @@ def reset_memories_command(
kickoff_outputs,
all,
) -> None:
"""Reset the crew memories.
"""
Reset the crew memories.
Args:
long (bool): Whether to reset the long-term memory.
@@ -22,51 +23,38 @@ def reset_memories_command(
kickoff_outputs (bool): Whether to reset the latest kickoff task outputs.
all (bool): Whether to reset all memories.
knowledge (bool): Whether to reset the knowledge.
"""
try:
if not any([long, short, entity, kickoff_outputs, knowledge, all]):
crew = get_crew()
if not crew:
raise ValueError("No crew found.")
if all:
crew.reset_memories(command_type="all")
click.echo("All memories have been reset.")
return
if not any([long, short, entity, kickoff_outputs, knowledge]):
click.echo(
"No memory type specified. Please specify at least one type to reset.",
"No memory type specified. Please specify at least one type to reset."
)
return
crews = get_crews()
if not crews:
msg = "No crew found."
raise ValueError(msg)
for crew in crews:
if all:
crew.reset_memories(command_type="all")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed.",
)
continue
if long:
crew.reset_memories(command_type="long")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset.",
)
if short:
crew.reset_memories(command_type="short")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset.",
)
if entity:
crew.reset_memories(command_type="entity")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset.",
)
if kickoff_outputs:
crew.reset_memories(command_type="kickoff_outputs")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset.",
)
if knowledge:
crew.reset_memories(command_type="knowledge")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset.",
)
if long:
crew.reset_memories(command_type="long")
click.echo("Long term memory has been reset.")
if short:
crew.reset_memories(command_type="short")
click.echo("Short term memory has been reset.")
if entity:
crew.reset_memories(command_type="entity")
click.echo("Entity memory has been reset.")
if kickoff_outputs:
crew.reset_memories(command_type="kickoff_outputs")
click.echo("Latest Kickoff outputs stored has been reset.")
if knowledge:
crew.reset_memories(command_type="knowledge")
click.echo("Knowledge has been reset.")
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True)

View File

@@ -1,5 +1,6 @@
import subprocess
from enum import Enum
from typing import List, Optional
import click
from packaging import version
@@ -14,7 +15,8 @@ class CrewType(Enum):
def run_crew() -> None:
"""Run the crew or flow by running a command in the UV environment.
"""
Run the crew or flow by running a command in the UV environment.
Starting from version 0.103.0, this command can be used to run both
standard crews and flows. For flows, it detects the type from pyproject.toml
@@ -46,11 +48,11 @@ def run_crew() -> None:
def execute_command(crew_type: CrewType) -> None:
"""Execute the appropriate command based on crew type.
"""
Execute the appropriate command based on crew type.
Args:
crew_type: The type of crew to run
"""
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
@@ -65,12 +67,12 @@ def execute_command(crew_type: CrewType) -> None:
def handle_error(error: subprocess.CalledProcessError, crew_type: CrewType) -> None:
"""Handle subprocess errors with appropriate messaging.
"""
Handle subprocess errors with appropriate messaging.
Args:
error: The subprocess error that occurred
crew_type: The type of crew that was being run
"""
entity_type = "flow" if crew_type == CrewType.FLOW else "crew"
click.echo(f"An error occurred while running the {entity_type}: {error}", err=True)

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.13"
dependencies = [
"crewai[tools]>=0.119.0,<1.0.0"
"crewai[tools]>=0.118.0,<1.0.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.13"
dependencies = [
"crewai[tools]>=0.119.0,<1.0.0",
"crewai[tools]>=0.118.0,<1.0.0",
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.13"
dependencies = [
"crewai[tools]>=0.119.0"
"crewai[tools]>=0.118.0"
]
[tool.crewai]

View File

@@ -22,13 +22,15 @@ console = Console()
class ToolCommand(BaseCommand, PlusAPIMixin):
"""A class to handle tool repository related operations for CrewAI projects."""
"""
A class to handle tool repository related operations for CrewAI projects.
"""
def __init__(self) -> None:
def __init__(self):
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
def create(self, handle: str) -> None:
def create(self, handle: str):
self._ensure_not_in_project()
folder_name = handle.replace(" ", "_").replace("-", "_").lower()
@@ -38,7 +40,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
if project_root.exists():
click.secho(f"Folder {folder_name} already exists.", fg="red")
raise SystemExit
os.makedirs(project_root)
else:
os.makedirs(project_root)
click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)
@@ -53,12 +56,12 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
self.login()
subprocess.run(["git", "init"], check=True)
console.print(
f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]",
f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
)
finally:
os.chdir(old_directory)
def publish(self, is_public: bool, force: bool = False) -> None:
def publish(self, is_public: bool, force: bool = False):
if not git.Repository().is_synced() and not force:
console.print(
"[bold red]Failed to publish tool.[/bold red]\n"
@@ -66,9 +69,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
"* [bold]Commit[/bold] your changes.\n"
"* [bold]Push[/bold] to sync with the remote.\n"
"* [bold]Pull[/bold] the latest changes from the remote.\n"
"\nOnce your repository is up-to-date, retry publishing the tool.",
"\nOnce your repository is up-to-date, retry publishing the tool."
)
raise SystemExit
raise SystemExit()
project_name = get_project_name(require=True)
assert isinstance(project_name, str)
@@ -87,7 +90,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
)
tarball_filename = next(
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None,
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
)
if not tarball_filename:
console.print(
@@ -120,7 +123,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
style="bold green",
)
def install(self, handle: str) -> None:
def install(self, handle: str):
get_response = self.plus_api_client.get_tool(handle)
if get_response.status_code == 404:
@@ -129,9 +132,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
style="bold red",
)
raise SystemExit
if get_response.status_code != 200:
elif get_response.status_code != 200:
console.print(
"Failed to get tool details. Please try again later.", style="bold red",
"Failed to get tool details. Please try again later.", style="bold red"
)
raise SystemExit
@@ -139,7 +142,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
console.print(f"Successfully installed {handle}", style="bold green")
def login(self) -> None:
def login(self):
login_response = self.plus_api_client.login_to_tool_repository()
if login_response.status_code != 200:
@@ -161,10 +164,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
settings.dump()
console.print(
"Successfully authenticated to the tool repository.", style="bold green",
"Successfully authenticated to the tool repository.", style="bold green"
)
def _add_package(self, tool_details) -> None:
def _add_package(self, tool_details):
tool_handle = tool_details["handle"]
repository_handle = tool_details["repository"]["handle"]
repository_url = tool_details["repository"]["url"]
@@ -189,16 +192,16 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
click.echo(add_package_result.stderr, err=True)
raise SystemExit
def _ensure_not_in_project(self) -> None:
def _ensure_not_in_project(self):
if os.path.isfile("./pyproject.toml"):
console.print(
"[bold red]Oops! It looks like you're inside a project.[/bold red]",
"[bold red]Oops! It looks like you're inside a project.[/bold red]"
)
console.print(
"You can't create a new tool while inside an existing project.",
"You can't create a new tool while inside an existing project."
)
console.print(
"[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again.",
"[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again."
)
raise SystemExit
@@ -208,10 +211,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
settings.tool_repository_username or "",
settings.tool_repository_username or ""
)
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
settings.tool_repository_password or "",
settings.tool_repository_password or ""
)
return env
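The helper ending above builds the environment that `uv` appears to read for an authenticated named index: `UV_INDEX_<HANDLE>_USERNAME` and `UV_INDEX_<HANDLE>_PASSWORD` layered on top of the current process environment. A standalone sketch of that construction (handle and credentials below are placeholders):

import os

def build_uv_index_env(repository_handle, username, password):
    env = os.environ.copy()
    env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(username or "")
    env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(password or "")
    return env

env = build_uv_index_env("CREWAI", "alice", "s3cret")  # placeholder values
# The resulting mapping would then be passed to the uv invocation, e.g.
# subprocess.run(["uv", "add", "some-tool"], env=env, check=True)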

View File

@@ -4,22 +4,20 @@ import click
def train_crew(n_iterations: int, filename: str) -> None:
"""Train the crew by running a command in the UV environment.
"""
Train the crew by running a command in the UV environment.
Args:
n_iterations (int): The number of iterations to train the crew.
"""
command = ["uv", "run", "train", str(n_iterations), filename]
try:
if n_iterations <= 0:
msg = "The number of iterations must be a positive integer."
raise ValueError(msg)
raise ValueError("The number of iterations must be a positive integer.")
if not filename.endswith(".pkl"):
msg = "The filename must not end with .pkl"
raise ValueError(msg)
raise ValueError("The filename must not end with .pkl")
result = subprocess.run(command, capture_output=False, text=True, check=True)

View File

@@ -11,8 +11,9 @@ def update_crew() -> None:
migrate_pyproject("pyproject.toml", "pyproject.toml")
def migrate_pyproject(input_file, output_file) -> None:
"""Migrate the pyproject.toml to the new format.
def migrate_pyproject(input_file, output_file):
"""
Migrate the pyproject.toml to the new format.
It is also used when migrating an existing project to uv.
@@ -80,7 +81,7 @@ def migrate_pyproject(input_file, output_file) -> None:
# Extract the module name from any existing script
existing_scripts = new_pyproject["project"]["scripts"]
module_name = next(
value.split(".")[0] for value in existing_scripts.values() if "." in value
(value.split(".")[0] for value in existing_scripts.values() if "." in value)
)
new_pyproject["project"]["scripts"]["run_crew"] = f"{module_name}.main:run"
@@ -92,19 +93,22 @@ def migrate_pyproject(input_file, output_file) -> None:
# Backup the old pyproject.toml
backup_file = "pyproject-old.toml"
shutil.copy2(input_file, backup_file)
print(f"Original pyproject.toml backed up as {backup_file}")
# Rename the poetry.lock file
lock_file = "poetry.lock"
lock_backup = "poetry-old.lock"
if os.path.exists(lock_file):
os.rename(lock_file, lock_backup)
print(f"Original poetry.lock renamed to {lock_backup}")
else:
pass
print("No poetry.lock file found to rename.")
# Write the new pyproject.toml
with open(output_file, "wb") as f:
tomli_w.dump(new_pyproject, f)
print(f"Migration complete. New pyproject.toml written to {output_file}")
def parse_version(version: str) -> str:

View File

@@ -2,8 +2,7 @@ import os
import shutil
import sys
from functools import reduce
from inspect import isfunction, ismethod
from typing import Any, get_type_hints
from typing import Any, Dict, List
import click
import tomli
@@ -11,7 +10,6 @@ from rich.console import Console
from crewai.cli.constants import ENV_VARS
from crewai.crew import Crew
from crewai.flow import Flow
if sys.version_info >= (3, 11):
import tomllib
@@ -19,9 +17,9 @@ if sys.version_info >= (3, 11):
console = Console()
def copy_template(src, dst, name, class_name, folder_name) -> None:
def copy_template(src, dst, name, class_name, folder_name):
"""Copy a file from src to dst."""
with open(src) as file:
with open(src, "r") as file:
content = file.read()
# Interpolate the content
@@ -39,7 +37,8 @@ def copy_template(src, dst, name, class_name, folder_name) -> None:
def read_toml(file_path: str = "pyproject.toml"):
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
return tomli.load(f)
toml_dict = tomli.load(f)
return toml_dict
def parse_toml(content):
@@ -49,56 +48,59 @@ def parse_toml(content):
def get_project_name(
pyproject_path: str = "pyproject.toml", require: bool = False,
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project name from the pyproject.toml file."""
return _get_project_attribute(pyproject_path, ["project", "name"], require=require)
def get_project_version(
pyproject_path: str = "pyproject.toml", require: bool = False,
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project version from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "version"], require=require,
pyproject_path, ["project", "version"], require=require
)
def get_project_description(
pyproject_path: str = "pyproject.toml", require: bool = False,
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project description from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "description"], require=require,
pyproject_path, ["project", "description"], require=require
)
def _get_project_attribute(
pyproject_path: str, keys: list[str], require: bool,
pyproject_path: str, keys: List[str], require: bool
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
try:
with open(pyproject_path) as f:
with open(pyproject_path, "r") as f:
pyproject_content = parse_toml(f.read())
dependencies = (
_get_nested_value(pyproject_content, ["project", "dependencies"]) or []
)
if not any(True for dep in dependencies if "crewai" in dep):
msg = "crewai is not in the dependencies."
raise Exception(msg)
raise Exception("crewai is not in the dependencies.")
attribute = _get_nested_value(pyproject_content, keys)
except FileNotFoundError:
pass
print(f"Error: {pyproject_path} not found.")
except KeyError:
pass
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception: # type: ignore
pass
except Exception:
pass
print(f"Error: {pyproject_path} is not a valid pyproject.toml file.")
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
print(
f"Error: {pyproject_path} is not a valid TOML file."
if sys.version_info >= (3, 11)
else f"Error reading the pyproject.toml file: {e}"
)
except Exception as e:
print(f"Error reading the pyproject.toml file: {e}")
if require and not attribute:
console.print(
@@ -110,7 +112,7 @@ def _get_project_attribute(
return attribute
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
@@ -118,7 +120,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
"""Fetch the environment variables from a .env file and return them as a dictionary."""
try:
# Read the .env file
with open(env_file_path) as f:
with open(env_file_path, "r") as f:
env_content = f.read()
# Parse the .env file content to a dictionary
@@ -131,14 +133,14 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
return env_dict
except FileNotFoundError:
pass
except Exception:
pass
print(f"Error: {env_file_path} not found.")
except Exception as e:
print(f"Error reading the .env file: {e}")
return {}
def tree_copy(source, destination) -> None:
def tree_copy(source, destination):
"""Copies the entire directory structure from the source to the destination."""
for item in os.listdir(source):
source_item = os.path.join(source, item)
@@ -149,7 +151,7 @@ def tree_copy(source, destination) -> None:
shutil.copy2(source_item, destination_item)
def tree_find_and_replace(directory, find, replace) -> None:
def tree_find_and_replace(directory, find, replace):
"""Recursively searches through a directory, replacing a target string in
both file contents and filenames with a specified replacement string.
"""
@@ -157,7 +159,7 @@ def tree_find_and_replace(directory, find, replace) -> None:
for filename in files:
filepath = os.path.join(path, filename)
with open(filepath) as file:
with open(filepath, "r") as file:
contents = file.read()
with open(filepath, "w") as file:
file.write(contents.replace(find, replace))
@@ -176,19 +178,19 @@ def tree_find_and_replace(directory, find, replace) -> None:
def load_env_vars(folder_path):
"""Loads environment variables from a .env file in the specified folder path.
"""
Loads environment variables from a .env file in the specified folder path.
Args:
- folder_path (Path): The path to the folder containing the .env file.
Returns:
- dict: A dictionary of environment variables.
"""
env_file_path = folder_path / ".env"
env_vars = {}
if env_file_path.exists():
with open(env_file_path) as file:
with open(env_file_path, "r") as file:
for line in file:
key, _, value = line.strip().partition("=")
if key and value:
@@ -197,7 +199,8 @@ def load_env_vars(folder_path):
def update_env_vars(env_vars, provider, model):
"""Updates environment variables with the API key for the selected provider and model.
"""
Updates environment variables with the API key for the selected provider and model.
Args:
- env_vars (dict): Environment variables dictionary.
@@ -206,7 +209,6 @@ def update_env_vars(env_vars, provider, model):
Returns:
- None
"""
api_key_var = ENV_VARS.get(
provider,
@@ -214,14 +216,14 @@ def update_env_vars(env_vars, provider, model):
click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
),
)
],
)[0]
if api_key_var not in env_vars:
try:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True,
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
)
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
@@ -234,13 +236,13 @@ def update_env_vars(env_vars, provider, model):
return env_vars
def write_env_file(folder_path, env_vars) -> None:
"""Writes environment variables to a .env file in the specified folder.
def write_env_file(folder_path, env_vars):
"""
Writes environment variables to a .env file in the specified folder.
Args:
- folder_path (Path): The path to the folder where the .env file will be written.
- env_vars (dict): A dictionary of environment variables to write.
"""
env_file_path = folder_path / ".env"
with open(env_file_path, "w") as file:
@@ -248,18 +250,18 @@ def write_env_file(folder_path, env_vars) -> None:
file.write(f"{key}={value}\n")
def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
"""Get the crew instances from a file."""
crew_instances = []
def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
"""Get the crew instance from the crew.py file."""
try:
import importlib.util
import os
for root, _, files in os.walk("."):
if crew_path in files:
crew_os_path = os.path.join(root, crew_path)
try:
spec = importlib.util.spec_from_file_location(
"crew_module", crew_os_path,
"crew_module", crew_os_path
)
if not spec or not spec.loader:
continue
@@ -269,20 +271,26 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
spec.loader.exec_module(module)
for attr_name in dir(module):
module_attr = getattr(module, attr_name)
attr = getattr(module, attr_name)
try:
crew_instances.extend(fetch_crews(module_attr))
except Exception:
if callable(attr) and hasattr(attr, "crew"):
crew_instance = attr().crew()
return crew_instance
except Exception as e:
print(f"Error processing attribute {attr_name}: {e}")
continue
except Exception:
except Exception as exec_error:
print(f"Error executing module: {exec_error}")
import traceback
print(f"Traceback: {traceback.format_exc()}")
except (ImportError, AttributeError) as e:
if require:
console.print(
f"Error importing crew from {crew_path}: {e!s}",
f"Error importing crew from {crew_path}: {str(e)}",
style="bold red",
)
continue
@@ -292,42 +300,12 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
if require:
console.print("No valid Crew instance found in crew.py", style="bold red")
raise SystemExit
return None
except Exception as e:
if require:
console.print(
f"Unexpected error while loading crew: {e!s}", style="bold red",
f"Unexpected error while loading crew: {str(e)}", style="bold red"
)
raise SystemExit
return crew_instances
def get_crew_instance(module_attr) -> Crew | None:
if (
callable(module_attr)
and hasattr(module_attr, "is_crew_class")
and module_attr.is_crew_class
):
return module_attr().crew()
if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
module_attr,
).get("return") is Crew:
return module_attr()
if isinstance(module_attr, Crew):
return module_attr
return None
def fetch_crews(module_attr) -> list[Crew]:
crew_instances: list[Crew] = []
if crew_instance := get_crew_instance(module_attr):
crew_instances.append(crew_instance)
if isinstance(module_attr, type) and issubclass(module_attr, Flow):
instance = module_attr()
for attr_name in dir(instance):
attr = getattr(instance, attr_name)
if crew_instance := get_crew_instance(attr):
crew_instances.append(crew_instance)
return crew_instances
return None
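
The `get_crew_instance` helper in the hunk above uses `typing.get_type_hints` to treat any plain function whose resolved return annotation is `Crew` as a crew factory. A minimal standard-library sketch of that annotation check, with a hypothetical `Crew` stand-in so it runs on its own:

```python
from inspect import isfunction, ismethod
from typing import get_type_hints


class Crew:  # hypothetical stand-in for crewai.crew.Crew
    pass


def build_crew() -> Crew:  # factory function annotated to return a Crew
    return Crew()


def unrelated():
    return "not a crew"


def returns_crew(obj) -> bool:
    """Mirror the check above: functions/methods whose resolved return hint is Crew."""
    if not (isfunction(obj) or ismethod(obj)):
        return False
    return get_type_hints(obj).get("return") is Crew


print(returns_crew(build_crew))  # True
print(returns_crew(unrelated))   # False
```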


@@ -2,5 +2,5 @@ import importlib.metadata
def get_crewai_version() -> str:
"""Get the version number of CrewAI running the CLI."""
"""Get the version number of CrewAI running the CLI"""
return importlib.metadata.version("crewai")

File diff suppressed because it is too large.


@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
@@ -12,28 +12,27 @@ class CrewOutput(BaseModel):
"""Class that represents the result of a crew."""
raw: str = Field(description="Raw output of crew", default="")
pydantic: BaseModel | None = Field(
description="Pydantic output of Crew", default=None,
pydantic: Optional[BaseModel] = Field(
description="Pydantic output of Crew", default=None
)
json_dict: dict[str, Any] | None = Field(
description="JSON dict output of Crew", default=None,
json_dict: Optional[Dict[str, Any]] = Field(
description="JSON dict output of Crew", default=None
)
tasks_output: list[TaskOutput] = Field(
description="Output of each task", default=[],
description="Output of each task", default=[]
)
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
@property
def json(self) -> str | None:
def json(self) -> Optional[str]:
if self.tasks_output[-1].output_format != OutputFormat.JSON:
msg = "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
raise ValueError(
msg,
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
)
return json.dumps(self.json_dict)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""
output_dict = {}
if self.json_dict:
@@ -45,12 +44,12 @@ class CrewOutput(BaseModel):
def __getitem__(self, key):
if self.pydantic and hasattr(self.pydantic, key):
return getattr(self.pydantic, key)
if self.json_dict and key in self.json_dict:
elif self.json_dict and key in self.json_dict:
return self.json_dict[key]
msg = f"Key '{key}' not found in CrewOutput."
raise KeyError(msg)
else:
raise KeyError(f"Key '{key}' not found in CrewOutput.")
def __str__(self) -> str:
def __str__(self):
if self.pydantic:
return str(self.pydantic)
if self.json_dict:
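
For readers skimming the `__getitem__` change above: the lookup order stays "Pydantic payload first, JSON dict second, otherwise KeyError". A rough, self-contained sketch of that precedence (the class and field names here are invented):

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel


class Summary(BaseModel):  # hypothetical structured output
    title: str


class MiniOutput:
    """Toy reproduction of CrewOutput.__getitem__'s lookup order."""

    def __init__(self, pydantic: Optional[BaseModel], json_dict: Optional[Dict[str, Any]]) -> None:
        self.pydantic = pydantic
        self.json_dict = json_dict

    def __getitem__(self, key: str) -> Any:
        if self.pydantic and hasattr(self.pydantic, key):
            return getattr(self.pydantic, key)   # structured output wins
        if self.json_dict and key in self.json_dict:
            return self.json_dict[key]           # fall back to the JSON dict
        raise KeyError(f"Key '{key}' not found.")


out = MiniOutput(Summary(title="Report"), {"title": "ignored", "score": 9})
print(out["title"])  # "Report"
print(out["score"])  # 9
```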


@@ -2,11 +2,17 @@ import asyncio
import copy
import inspect
import logging
from collections.abc import Callable
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Set,
Type,
TypeVar,
Union,
cast,
)
from uuid import uuid4
@@ -42,14 +48,14 @@ class FlowState(BaseModel):
# Type variables with explicit bounds
T = TypeVar(
"T", bound=dict[str, Any] | BaseModel,
"T", bound=Union[Dict[str, Any], BaseModel]
) # Generic flow state type parameter
StateT = TypeVar(
"StateT", bound=dict[str, Any] | BaseModel,
"StateT", bound=Union[Dict[str, Any], BaseModel]
) # State validation type parameter
def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
"""Ensure state matches expected type with proper validation.
Args:
@@ -62,7 +68,6 @@ def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
Raises:
TypeError: If state doesn't match expected type
ValueError: If state validation fails
"""
"""Ensure state matches expected type with proper validation.
@@ -79,22 +84,20 @@ def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
"""
if expected_type is dict:
if not isinstance(state, dict):
msg = f"Expected dict, got {type(state).__name__}"
raise TypeError(msg)
return cast("StateT", state)
raise TypeError(f"Expected dict, got {type(state).__name__}")
return cast(StateT, state)
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
if not isinstance(state, expected_type):
msg = f"Expected {expected_type.__name__}, got {type(state).__name__}"
raise TypeError(
msg,
f"Expected {expected_type.__name__}, got {type(state).__name__}"
)
return cast("StateT", state)
msg = f"Invalid expected_type: {expected_type}"
raise TypeError(msg)
return cast(StateT, state)
raise TypeError(f"Invalid expected_type: {expected_type}")
def start(condition: str | dict | Callable | None = None) -> Callable:
"""Marks a method as a flow's starting point.
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
"""
Marks a method as a flow's starting point.
This decorator designates a method as an entry point for the flow execution.
It can optionally specify conditions that trigger the start based on other
@@ -132,7 +135,6 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
>>> @start(and_("method1", "method2")) # Start after multiple methods
>>> def complex_start(self):
... pass
"""
def decorator(func):
@@ -152,17 +154,17 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
msg = "Condition must be a method, string, or a result of or_() or and_()"
raise ValueError(
msg,
"Condition must be a method, string, or a result of or_() or and_()"
)
return func
return decorator
def listen(condition: str | dict | Callable) -> Callable:
"""Creates a listener that executes when specified conditions are met.
def listen(condition: Union[str, dict, Callable]) -> Callable:
"""
Creates a listener that executes when specified conditions are met.
This decorator sets up a method to execute in response to other method
executions in the flow. It supports both simple and complex triggering
@@ -195,7 +197,6 @@ def listen(condition: str | dict | Callable) -> Callable:
>>> @listen(or_("success", "failure")) # Listen to multiple methods
>>> def handle_completion(self):
... pass
"""
def decorator(func):
@@ -213,17 +214,17 @@ def listen(condition: str | dict | Callable) -> Callable:
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
msg = "Condition must be a method, string, or a result of or_() or and_()"
raise ValueError(
msg,
"Condition must be a method, string, or a result of or_() or and_()"
)
return func
return decorator
def router(condition: str | dict | Callable) -> Callable:
"""Creates a routing method that directs flow execution based on conditions.
def router(condition: Union[str, dict, Callable]) -> Callable:
"""
Creates a routing method that directs flow execution based on conditions.
This decorator marks a method as a router, which can dynamically determine
the next steps in the flow based on its return value. Routers are triggered
@@ -261,7 +262,6 @@ def router(condition: str | dict | Callable) -> Callable:
... if all([self.state.valid, self.state.processed]):
... return CONTINUE
... return STOP
"""
def decorator(func):
@@ -280,17 +280,17 @@ def router(condition: str | dict | Callable) -> Callable:
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
msg = "Condition must be a method, string, or a result of or_() or and_()"
raise ValueError(
msg,
"Condition must be a method, string, or a result of or_() or and_()"
)
return func
return decorator
def or_(*conditions: str | dict | Callable) -> dict:
"""Combines multiple conditions with OR logic for flow control.
def or_(*conditions: Union[str, dict, Callable]) -> dict:
"""
Combines multiple conditions with OR logic for flow control.
Creates a condition that is satisfied when any of the specified conditions
are met. This is used with @start, @listen, or @router decorators to create
@@ -320,7 +320,6 @@ def or_(*conditions: str | dict | Callable) -> dict:
>>> @listen(or_("success", "timeout"))
>>> def handle_completion(self):
... pass
"""
methods = []
for condition in conditions:
@@ -331,13 +330,13 @@ def or_(*conditions: str | dict | Callable) -> dict:
elif callable(condition):
methods.append(getattr(condition, "__name__", repr(condition)))
else:
msg = "Invalid condition in or_()"
raise ValueError(msg)
raise ValueError("Invalid condition in or_()")
return {"type": "OR", "methods": methods}
def and_(*conditions: str | dict | Callable) -> dict:
"""Combines multiple conditions with AND logic for flow control.
def and_(*conditions: Union[str, dict, Callable]) -> dict:
"""
Combines multiple conditions with AND logic for flow control.
Creates a condition that is satisfied only when all specified conditions
are met. This is used with @start, @listen, or @router decorators to create
@@ -367,7 +366,6 @@ def and_(*conditions: str | dict | Callable) -> dict:
>>> @listen(and_("validated", "processed"))
>>> def handle_complete_data(self):
... pass
"""
methods = []
for condition in conditions:
@@ -378,8 +376,7 @@ def and_(*conditions: str | dict | Callable) -> dict:
elif callable(condition):
methods.append(getattr(condition, "__name__", repr(condition)))
else:
msg = "Invalid condition in and_()"
raise ValueError(msg)
raise ValueError("Invalid condition in and_()")
return {"type": "AND", "methods": methods}
@@ -419,10 +416,10 @@ class FlowMeta(type):
if possible_returns:
router_paths[attr_name] = possible_returns
cls._start_methods = start_methods
cls._listeners = listeners
cls._routers = routers
cls._router_paths = router_paths
setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
setattr(cls, "_routers", routers)
setattr(cls, "_router_paths", router_paths)
return cls
@@ -430,18 +427,17 @@ class FlowMeta(type):
class Flow(Generic[T], metaclass=FlowMeta):
"""Base class for all flows.
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel.
"""
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
_printer = Printer()
_start_methods: list[str] = []
_listeners: dict[str, tuple[str, list[str]]] = {}
_routers: set[str] = set()
_router_paths: dict[str, list[str]] = {}
initial_state: type[T] | T | None = None
_start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Set[str] = set()
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None
def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]:
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
class _FlowGeneric(cls): # type: ignore
_initial_state_T = item # type: ignore
@@ -450,7 +446,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
def __init__(
self,
persistence: FlowPersistence | None = None,
persistence: Optional[FlowPersistence] = None,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
@@ -458,14 +454,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
Args:
persistence: Optional persistence backend for storing flow states
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
self._methods: dict[str, Callable] = {}
self._method_execution_counts: dict[str, int] = {}
self._pending_and_listeners: dict[str, set[str]] = {}
self._method_outputs: list[Any] = [] # List to store all method outputs
self._persistence: FlowPersistence | None = persistence
self._methods: Dict[str, Callable] = {}
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
self._persistence: Optional[FlowPersistence] = persistence
# Initialize state with initial values
self._state = self._create_initial_state()
@@ -507,61 +502,58 @@ class Flow(Generic[T], metaclass=FlowMeta):
Raises:
ValueError: If structured state model lacks 'id' field
TypeError: If state is neither BaseModel nor dictionary
"""
# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_T"):
state_type = self._initial_state_T
state_type = getattr(self, "_initial_state_T")
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance without id, then set it
instance = state_type()
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast("T", instance)
if issubclass(state_type, BaseModel):
setattr(instance, "id", str(uuid4()))
return cast(T, instance)
elif issubclass(state_type, BaseModel):
# Create a new type that includes the ID field
class StateWithId(state_type, FlowState): # type: ignore
pass
instance = StateWithId()
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast("T", instance)
if state_type is dict:
return cast("T", {"id": str(uuid4())})
setattr(instance, "id", str(uuid4()))
return cast(T, instance)
elif state_type is dict:
return cast(T, {"id": str(uuid4())})
# Handle case where no initial state is provided
if self.initial_state is None:
return cast("T", {"id": str(uuid4())})
return cast(T, {"id": str(uuid4())})
# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast("T", self.initial_state()) # Uses model defaults
if issubclass(self.initial_state, BaseModel):
return cast(T, self.initial_state()) # Uses model defaults
elif issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
msg = "Flow state model must have an 'id' field"
raise ValueError(msg)
return cast("T", self.initial_state()) # Uses model defaults
if self.initial_state is dict:
return cast("T", {"id": str(uuid4())})
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state()) # Uses model defaults
elif self.initial_state is dict:
return cast(T, {"id": str(uuid4())})
# Handle dictionary instance case
if isinstance(self.initial_state, dict):
new_state = dict(self.initial_state) # Copy to avoid mutations
if "id" not in new_state:
new_state["id"] = str(uuid4())
return cast("T", new_state)
return cast(T, new_state)
# Handle BaseModel instance case
if isinstance(self.initial_state, BaseModel):
model = cast("BaseModel", self.initial_state)
model = cast(BaseModel, self.initial_state)
if not hasattr(model, "id"):
msg = "Flow state model must have an 'id' field"
raise ValueError(msg)
raise ValueError("Flow state model must have an 'id' field")
# Create new instance with same values to avoid mutations
if hasattr(model, "model_dump"):
@@ -578,10 +570,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Create new instance of the same class
model_class = type(model)
return cast("T", model_class(**state_dict))
msg = f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
return cast(T, model_class(**state_dict))
raise TypeError(
msg,
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
def _copy_state(self) -> T:
@@ -592,7 +583,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
return self._state
@property
def method_outputs(self) -> list[Any]:
def method_outputs(self) -> List[Any]:
"""Returns the list of all outputs from executed methods."""
return self._method_outputs
@@ -616,7 +607,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
flow = MyFlow()
print(f"Current flow ID: {flow.flow_id}") # Safely get flow ID
```
"""
try:
if not hasattr(self, "_state"):
@@ -624,13 +614,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict):
return str(self._state.get("id", ""))
if isinstance(self._state, BaseModel):
elif isinstance(self._state, BaseModel):
return str(getattr(self._state, "id", ""))
return ""
except (AttributeError, TypeError):
return "" # Safely handle any unexpected attribute access issues
def _initialize_state(self, inputs: dict[str, Any]) -> None:
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
"""Initialize or update flow state with new inputs.
Args:
@@ -639,7 +629,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
Raises:
ValueError: If validation fails for structured state
TypeError: If state is neither BaseModel nor dictionary
"""
if isinstance(self._state, dict):
# For dict states, preserve existing fields unless overridden
@@ -655,7 +644,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
elif isinstance(self._state, BaseModel):
# For BaseModel states, preserve existing fields unless overridden
try:
model = cast("BaseModel", self._state)
model = cast(BaseModel, self._state)
# Get current state as dict
if hasattr(model, "model_dump"):
current_state = model.model_dump()
@@ -673,21 +662,19 @@ class Flow(Generic[T], metaclass=FlowMeta):
model_class = type(model)
if hasattr(model_class, "model_validate"):
# Pydantic v2
self._state = cast("T", model_class.model_validate(new_state))
self._state = cast(T, model_class.model_validate(new_state))
elif hasattr(model_class, "parse_obj"):
# Pydantic v1
self._state = cast("T", model_class.parse_obj(new_state))
self._state = cast(T, model_class.parse_obj(new_state))
else:
# Fallback for other BaseModel implementations
self._state = cast("T", model_class(**new_state))
self._state = cast(T, model_class(**new_state))
except ValidationError as e:
msg = f"Invalid inputs for structured state: {e}"
raise ValueError(msg) from e
raise ValueError(f"Invalid inputs for structured state: {e}") from e
else:
msg = "State must be a BaseModel instance or a dictionary."
raise TypeError(msg)
raise TypeError("State must be a BaseModel instance or a dictionary.")
def _restore_state(self, stored_state: dict[str, Any]) -> None:
def _restore_state(self, stored_state: Dict[str, Any]) -> None:
"""Restore flow state from persistence.
Args:
@@ -696,13 +683,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
Raises:
ValueError: If validation fails for structured state
TypeError: If state is neither BaseModel nor dictionary
"""
# When restoring from persistence, use the stored ID
stored_id = stored_state.get("id")
if not stored_id:
msg = "Stored state must have an 'id' field"
raise ValueError(msg)
raise ValueError("Stored state must have an 'id' field")
if isinstance(self._state, dict):
# For dict states, update all fields from stored state
@@ -710,22 +695,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._state.update(stored_state)
elif isinstance(self._state, BaseModel):
# For BaseModel states, create new instance with stored values
model = cast("BaseModel", self._state)
model = cast(BaseModel, self._state)
if hasattr(model, "model_validate"):
# Pydantic v2
self._state = cast("T", type(model).model_validate(stored_state))
self._state = cast(T, type(model).model_validate(stored_state))
elif hasattr(model, "parse_obj"):
# Pydantic v1
self._state = cast("T", type(model).parse_obj(stored_state))
self._state = cast(T, type(model).parse_obj(stored_state))
else:
# Fallback for other BaseModel implementations
self._state = cast("T", type(model)(**stored_state))
self._state = cast(T, type(model)(**stored_state))
else:
msg = f"State must be dict or BaseModel, got {type(self._state)}"
raise TypeError(msg)
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
"""Start the flow execution in a synchronous context.
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Start the flow execution in a synchronous context.
This method wraps kickoff_async so that all state initialization and event
emission is handled in the asynchronous method.
@@ -736,8 +721,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
return asyncio.run(run_flow())
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
"""Start the flow execution asynchronously.
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Start the flow execution asynchronously.
This method performs state restoration (if an 'id' is provided and persistence is available)
and updates the flow state with any additional inputs. It then emits the FlowStartedEvent,
@@ -749,7 +735,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
Returns:
The final output from the flow, which is the result of the last executed method.
"""
if inputs:
# Override the id in the state if it exists in inputs
@@ -757,7 +742,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict):
self._state["id"] = inputs["id"]
elif isinstance(self._state, BaseModel):
self._state.id = inputs["id"]
setattr(self._state, "id", inputs["id"])
# If persistence is enabled, attempt to restore the stored state using the provided id.
if "id" in inputs and self._persistence is not None:
@@ -771,7 +756,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._restore_state(stored_state)
else:
self._log_flow_event(
f"No flow state found for UUID: {restore_uuid}", color="red",
f"No flow state found for UUID: {restore_uuid}", color="red"
)
# Update state with any additional inputs (ignoring the 'id' key)
@@ -789,7 +774,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
),
)
self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold_magenta",
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
)
if inputs is not None and "id" not in inputs:
@@ -815,7 +800,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
return final_output
async def _execute_start_method(self, start_method_name: str) -> None:
"""Executes a flow's start method and its triggered listeners.
"""
Executes a flow's start method and its triggered listeners.
This internal method handles the execution of methods marked with @start
decorator and manages the subsequent chain of listener executions.
@@ -830,15 +816,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Executes the start method and captures its result
- Triggers execution of any listeners waiting on this start method
- Part of the flow's initialization sequence
"""
result = await self._execute_method(
start_method_name, self._methods[start_method_name],
start_method_name, self._methods[start_method_name]
)
await self._execute_listeners(start_method_name, result)
async def _execute_method(
self, method_name: str, method: Callable, *args: Any, **kwargs: Any,
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
) -> Any:
try:
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
@@ -888,10 +873,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
error=e,
),
)
raise
raise e
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
"""Executes all listeners and routers triggered by a method completion.
"""
Executes all listeners and routers triggered by a method completion.
This internal method manages the execution flow by:
1. First executing all triggered routers sequentially
@@ -911,7 +897,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Each router's result becomes a new trigger_method
- Normal listeners are executed in parallel for efficiency
- Listeners can receive the trigger method's result as a parameter
"""
# First, handle routers repeatedly until no router triggers anymore
router_results = []
@@ -919,7 +904,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
while True:
routers_triggered = self._find_triggered_methods(
current_trigger, router_only=True,
current_trigger, router_only=True
)
if not routers_triggered:
break
@@ -935,12 +920,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
# Now execute normal listeners for all router results and the original trigger
all_triggers = [trigger_method, *router_results]
all_triggers = [trigger_method] + router_results
for current_trigger in all_triggers:
if current_trigger: # Skip None results
listeners_triggered = self._find_triggered_methods(
current_trigger, router_only=False,
current_trigger, router_only=False
)
if listeners_triggered:
tasks = [
@@ -950,9 +935,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
await asyncio.gather(*tasks)
def _find_triggered_methods(
self, trigger_method: str, router_only: bool,
) -> list[str]:
"""Finds all methods that should be triggered based on conditions.
self, trigger_method: str, router_only: bool
) -> List[str]:
"""
Finds all methods that should be triggered based on conditions.
This internal method evaluates both OR and AND conditions to determine
which methods should be executed next in the flow.
@@ -977,7 +963,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
* AND: Triggers only when all conditions are met
- Maintains state for AND conditions using _pending_and_listeners
- Separates router and normal listener evaluation
"""
triggered = []
for listener_name, (condition_type, methods) in self._listeners.items():
@@ -1007,7 +992,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
return triggered
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
"""Executes a single listener method with proper event handling.
"""
Executes a single listener method with proper event handling.
This internal method manages the execution of an individual listener,
including parameter inspection, event emission, and error handling.
@@ -1032,7 +1018,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
-------------
Catches and logs any exceptions during execution, preventing
individual listener failures from breaking the entire flow.
"""
try:
method = self._methods[listener_name]
@@ -1043,7 +1028,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if method_params:
listener_result = await self._execute_method(
listener_name, method, result,
listener_name, method, result
)
else:
listener_result = await self._execute_method(listener_name, method)
@@ -1051,14 +1036,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Execute listeners (and possibly routers) of this listener
await self._execute_listeners(listener_name, listener_result)
except Exception:
except Exception as e:
print(
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
)
import traceback
traceback.print_exc()
raise
def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info",
self, message: str, color: str = "yellow", level: str = "info"
) -> None:
"""Centralized logging method for flow events.
@@ -1076,7 +1064,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
Note:
This method uses the Printer utility for colored console output
and the standard logging module for log level support.
"""
self._printer.print(message, color=color)
if level == "info":
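
Taken together, the decorators and the kickoff path in this module compose into flows like the sketch below: an unstructured dict state, one `@start` method, and one `@listen` listener that receives the trigger's return value. The method names are invented and the import path is assumed from the module layout shown here.

```python
from crewai.flow.flow import Flow, listen, start


class GreetingFlow(Flow):
    """Tiny flow: the start method produces a name, the listener greets it."""

    @start()
    def pick_name(self):
        # Values written to self.state persist across methods (dict state by default).
        self.state["name"] = "Ada"
        return self.state["name"]

    @listen(pick_name)
    def greet(self, name):
        # Listeners with a parameter receive the trigger method's result.
        return f"Hello, {name}!"


result = GreetingFlow().kickoff()
print(result)  # "Hello, Ada!" -- the last executed method's output
```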


@@ -1,43 +0,0 @@
import inspect
from pydantic import BaseModel, Field, InstanceOf, model_validator
from crewai.flow import Flow
class FlowTrackable(BaseModel):
"""Mixin that tracks the Flow instance that instantiated the object, e.g. a
Flow instance that created a Crew or Agent.
Automatically finds and stores a reference to the parent Flow instance by
inspecting the call stack.
"""
parent_flow: InstanceOf[Flow] | None = Field(
default=None,
description="The parent flow of the instance, if it was created inside a flow.",
)
@model_validator(mode="after")
def _set_parent_flow(self, max_depth: int = 5) -> "FlowTrackable":
frame = inspect.currentframe()
try:
if frame is None:
return self
frame = frame.f_back
for _ in range(max_depth):
if frame is None:
break
candidate = frame.f_locals.get("self")
if isinstance(candidate, Flow):
self.parent_flow = candidate
break
frame = frame.f_back
finally:
del frame
return self
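
The `FlowTrackable` mixin above hinges on one trick: walk `f_back` frames from `inspect.currentframe()` and look for a `self` local of the parent type. A stripped-down, standard-library-only sketch of that idea (class names are illustrative):

```python
import inspect


class Orchestrator:
    """Stand-in for the parent object (e.g. a Flow) that creates children."""

    def build_child(self) -> "Child":
        return Child()


class Child:
    def __init__(self, max_depth: int = 5) -> None:
        self.parent = None
        frame = inspect.currentframe()
        try:
            frame = frame.f_back if frame else None
            for _ in range(max_depth):
                if frame is None:
                    break
                candidate = frame.f_locals.get("self")
                if isinstance(candidate, Orchestrator):
                    self.parent = candidate  # found the creating instance
                    break
                frame = frame.f_back
        finally:
            del frame  # drop frame references to avoid reference cycles


child = Orchestrator().build_child()
print(isinstance(child.parent, Orchestrator))  # True
```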


@@ -1,13 +1,14 @@
# flow_visualizer.py
import os
from pathlib import Path
from pyvis.network import Network
from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.path_utils import safe_path_join
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import (
add_edges,
@@ -19,8 +20,9 @@ from crewai.flow.visualization_utils import (
class FlowPlot:
"""Handles the creation and rendering of flow visualization diagrams."""
def __init__(self, flow) -> None:
"""Initialize FlowPlot with a flow object.
def __init__(self, flow):
"""
Initialize FlowPlot with a flow object.
Parameters
----------
@@ -31,24 +33,21 @@ class FlowPlot:
------
ValueError
If flow object is invalid or missing required attributes.
"""
if not hasattr(flow, "_methods"):
msg = "Invalid flow object: missing '_methods' attribute"
raise ValueError(msg)
if not hasattr(flow, "_listeners"):
msg = "Invalid flow object: missing '_listeners' attribute"
raise ValueError(msg)
if not hasattr(flow, "_start_methods"):
msg = "Invalid flow object: missing '_start_methods' attribute"
raise ValueError(msg)
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not hasattr(flow, '_listeners'):
raise ValueError("Invalid flow object: missing '_listeners' attribute")
if not hasattr(flow, '_start_methods'):
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
self.flow = flow
self.colors = COLORS
self.node_styles = NODE_STYLES
def plot(self, filename) -> None:
"""Generate and save an HTML visualization of the flow.
def plot(self, filename):
"""
Generate and save an HTML visualization of the flow.
Parameters
----------
@@ -63,12 +62,10 @@ class FlowPlot:
If file operations fail or visualization cannot be generated.
RuntimeError
If network visualization generation fails.
"""
if not filename or not isinstance(filename, str):
msg = "Filename must be a non-empty string"
raise ValueError(msg)
raise ValueError("Filename must be a non-empty string")
try:
# Initialize network
net = Network(
@@ -92,63 +89,58 @@ class FlowPlot:
"enabled": false
}
}
""",
"""
)
# Calculate levels for nodes
try:
node_levels = calculate_node_levels(self.flow)
except Exception as e:
msg = f"Failed to calculate node levels: {e!s}"
raise ValueError(msg)
raise ValueError(f"Failed to calculate node levels: {str(e)}")
# Compute positions
try:
node_positions = compute_positions(self.flow, node_levels)
except Exception as e:
msg = f"Failed to compute node positions: {e!s}"
raise ValueError(msg)
raise ValueError(f"Failed to compute node positions: {str(e)}")
# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
except Exception as e:
msg = f"Failed to add nodes to network: {e!s}"
raise RuntimeError(msg)
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
# Add edges to the network
try:
add_edges(net, self.flow, node_positions, self.colors)
except Exception as e:
msg = f"Failed to add edges to network: {e!s}"
raise RuntimeError(msg)
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
# Generate HTML
try:
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
except Exception as e:
msg = f"Failed to generate network visualization: {e!s}"
raise RuntimeError(msg)
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
# Save the final HTML content to the file
try:
with open(f"{filename}.html", "w", encoding="utf-8") as f:
f.write(final_html_content)
except OSError as e:
msg = f"Failed to save flow visualization to {filename}.html: {e!s}"
raise OSError(msg)
print(f"Plot saved as {filename}.html")
except IOError as e:
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
except (OSError, ValueError, RuntimeError):
raise
except (ValueError, RuntimeError, IOError) as e:
raise e
except Exception as e:
msg = f"Unexpected error during flow visualization: {e!s}"
raise RuntimeError(msg)
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
finally:
self._cleanup_pyvis_lib()
def _generate_final_html(self, network_html):
"""Generate the final HTML content with network visualization and legend.
"""
Generate the final HTML content with network visualization and legend.
Parameters
----------
@@ -166,11 +158,9 @@ class FlowPlot:
If template or logo files cannot be accessed.
ValueError
If network_html is invalid.
"""
if not network_html:
msg = "Invalid network HTML content"
raise ValueError(msg)
raise ValueError("Invalid network HTML content")
try:
# Extract just the body content from the generated HTML
@@ -179,11 +169,9 @@ class FlowPlot:
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
if not os.path.exists(template_path):
msg = f"Template file not found: {template_path}"
raise OSError(msg)
raise IOError(f"Template file not found: {template_path}")
if not os.path.exists(logo_path):
msg = f"Logo file not found: {logo_path}"
raise OSError(msg)
raise IOError(f"Logo file not found: {logo_path}")
html_handler = HTMLTemplateHandler(template_path, logo_path)
network_body = html_handler.extract_body_content(network_html)
@@ -191,15 +179,16 @@ class FlowPlot:
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
return html_handler.generate_final_html(
network_body, legend_items_html,
final_html_content = html_handler.generate_final_html(
network_body, legend_items_html
)
return final_html_content
except Exception as e:
msg = f"Failed to generate visualization HTML: {e!s}"
raise OSError(msg)
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
def _cleanup_pyvis_lib(self) -> None:
"""Clean up the generated lib folder from pyvis.
def _cleanup_pyvis_lib(self):
"""
Clean up the generated lib folder from pyvis.
This method safely removes the temporary lib directory created by pyvis
during network visualization generation.
@@ -209,14 +198,15 @@ class FlowPlot:
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
shutil.rmtree(lib_folder)
except ValueError:
pass
except Exception:
pass
except ValueError as e:
print(f"Error validating lib folder path: {e}")
except Exception as e:
print(f"Error cleaning up lib folder: {e}")
def plot_flow(flow, filename="flow_plot") -> None:
"""Convenience function to create and save a flow visualization.
def plot_flow(flow, filename="flow_plot"):
"""
Convenience function to create and save a flow visualization.
Parameters
----------
@@ -231,7 +221,6 @@ def plot_flow(flow, filename="flow_plot") -> None:
If flow object or filename is invalid.
IOError
If file operations fail.
"""
visualizer = FlowPlot(flow)
visualizer.plot(filename)
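
Typical use of the visualizer is a single call; assuming the module path matches the file name (`crewai.flow.flow_visualizer`) and pyvis is installed, something like:

```python
from crewai.flow.flow import Flow, start
from crewai.flow.flow_visualizer import plot_flow


class TinyFlow(Flow):
    @start()
    def begin(self):
        return "done"


plot_flow(TinyFlow(), "tiny_flow")  # writes tiny_flow.html in the working directory
```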


@@ -1,14 +1,16 @@
import base64
import re
from pathlib import Path
from crewai.flow.path_utils import validate_path_exists
from crewai.flow.path_utils import safe_path_join, validate_path_exists
class HTMLTemplateHandler:
"""Handles HTML template processing and generation for flow visualization diagrams."""
def __init__(self, template_path, logo_path) -> None:
"""Initialize HTMLTemplateHandler with validated template and logo paths.
def __init__(self, template_path, logo_path):
"""
Initialize HTMLTemplateHandler with validated template and logo paths.
Parameters
----------
@@ -21,18 +23,16 @@ class HTMLTemplateHandler:
------
ValueError
If template or logo paths are invalid or files don't exist.
"""
try:
self.template_path = validate_path_exists(template_path, "file")
self.logo_path = validate_path_exists(logo_path, "file")
except ValueError as e:
msg = f"Invalid template or logo path: {e}"
raise ValueError(msg)
raise ValueError(f"Invalid template or logo path: {e}")
def read_template(self):
"""Read and return the HTML template file contents."""
with open(self.template_path, encoding="utf-8") as f:
with open(self.template_path, "r", encoding="utf-8") as f:
return f.read()
def encode_logo(self):
@@ -81,12 +81,13 @@ class HTMLTemplateHandler:
final_html_content = html_template.replace("{{ title }}", title)
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body,
"{{ network_content }}", network_body
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64,
"{{ logo_svg_base64 }}", logo_svg_base64
)
return final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html,
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
return final_html_content


@@ -1,14 +1,18 @@
"""Path utilities for secure file operations in CrewAI flow module.
"""
Path utilities for secure file operations in CrewAI flow module.
This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries.
"""
import os
from pathlib import Path
from typing import List, Union
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
"""Safely join path components and ensure the result is within allowed boundaries.
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
"""
Safely join path components and ensure the result is within allowed boundaries.
Parameters
----------
@@ -27,43 +31,39 @@ def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
ValueError
If the resulting path would be outside the root directory
or if any path component is invalid.
"""
if not parts:
msg = "No path components provided"
raise ValueError(msg)
raise ValueError("No path components provided")
try:
# Convert all parts to strings and clean them
clean_parts = [str(part).strip() for part in parts if part]
if not clean_parts:
msg = "No valid path components provided"
raise ValueError(msg)
raise ValueError("No valid path components provided")
# Establish root directory
root_path = Path(root).resolve() if root else Path.cwd()
# Join and resolve the full path
full_path = Path(root_path, *clean_parts).resolve()
# Check if the resolved path is within root
if not str(full_path).startswith(str(root_path)):
msg = f"Invalid path: Potential directory traversal. Path must be within {root_path}"
raise ValueError(
msg,
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
)
return str(full_path)
except Exception as e:
if isinstance(e, ValueError):
raise
msg = f"Invalid path components: {e!s}"
raise ValueError(msg)
raise ValueError(f"Invalid path components: {str(e)}")
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
"""Validate that a path exists and is of the expected type.
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
"""
Validate that a path exists and is of the expected type.
Parameters
----------
@@ -81,33 +81,29 @@ def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
------
ValueError
If path doesn't exist or is not of expected type.
"""
try:
path_obj = Path(path).resolve()
if not path_obj.exists():
msg = f"Path does not exist: {path}"
raise ValueError(msg)
raise ValueError(f"Path does not exist: {path}")
if file_type == "file" and not path_obj.is_file():
msg = f"Path is not a file: {path}"
raise ValueError(msg)
if file_type == "directory" and not path_obj.is_dir():
msg = f"Path is not a directory: {path}"
raise ValueError(msg)
raise ValueError(f"Path is not a file: {path}")
elif file_type == "directory" and not path_obj.is_dir():
raise ValueError(f"Path is not a directory: {path}")
return str(path_obj)
except Exception as e:
if isinstance(e, ValueError):
raise
msg = f"Invalid path: {e!s}"
raise ValueError(msg)
raise ValueError(f"Invalid path: {str(e)}")
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
"""Safely list files in a directory matching a pattern.
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
"""
Safely list files in a directory matching a pattern.
Parameters
----------
@@ -125,18 +121,15 @@ def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
------
ValueError
If directory is invalid or inaccessible.
"""
try:
dir_path = Path(directory).resolve()
if not dir_path.is_dir():
msg = f"Not a directory: {directory}"
raise ValueError(msg)
raise ValueError(f"Not a directory: {directory}")
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
except Exception as e:
if isinstance(e, ValueError):
raise
msg = f"Error listing files: {e!s}"
raise ValueError(msg)
raise ValueError(f"Error listing files: {str(e)}")
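
The heart of `safe_path_join` is the resolve-then-prefix check that rejects `..` escapes. A condensed standard-library sketch of just that check:

```python
from pathlib import Path


def join_inside(root: str, *parts: str) -> str:
    """Resolve root + parts and refuse any result that escapes root."""
    root_path = Path(root).resolve()
    full_path = Path(root_path, *parts).resolve()
    if not str(full_path).startswith(str(root_path)):
        raise ValueError(f"Path must stay within {root_path}")
    return str(full_path)


print(join_inside(".", "assets", "crewai_logo.svg"))  # stays inside the root
try:
    join_inside(".", "..", "etc", "passwd")           # traversal attempt
except ValueError as err:
    print(err)                                        # rejected
```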


@@ -1,52 +1,53 @@
"""Base class for flow state persistence."""
import abc
from typing import Any
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel
class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence.
This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
"""
@abc.abstractmethod
def init_db(self) -> None:
"""Initialize the persistence backend.
This method should handle any necessary setup, such as:
- Creating tables
- Establishing connections
- Setting up indexes
"""
pass
@abc.abstractmethod
def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: dict[str, Any] | BaseModel,
state_data: Union[Dict[str, Any], BaseModel]
) -> None:
"""Persist the flow state after method completion.
Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
pass
@abc.abstractmethod
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.
Args:
flow_uuid: Unique identifier for the flow instance
Returns:
The most recent state as a dictionary, or None if no state exists
"""
pass
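
As a concrete reading of the interface above, a backend only has to implement `init_db`, `save_state`, and `load_state`. This hypothetical in-memory variant (not part of the library; the import path for `FlowPersistence` is assumed) keeps the latest state per flow UUID in a dict:

```python
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel

from crewai.flow.persistence import FlowPersistence


class InMemoryFlowPersistence(FlowPersistence):
    """Keeps only the most recent state per flow UUID, in process memory."""

    def __init__(self) -> None:
        self._states: Dict[str, Dict[str, Any]] = {}
        self.init_db()

    def init_db(self) -> None:
        # Nothing to create for a plain dict backend.
        self._states.clear()

    def save_state(
        self,
        flow_uuid: str,
        method_name: str,
        state_data: Union[Dict[str, Any], BaseModel],
    ) -> None:
        # Normalize Pydantic models to plain dicts, mirroring the dict/BaseModel split above.
        if isinstance(state_data, BaseModel):
            state_data = state_data.model_dump()
        self._states[flow_uuid] = dict(state_data)

    def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
        return self._states.get(flow_uuid)
```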


@@ -1,4 +1,5 @@
"""Decorators for flow state persistence.
"""
Decorators for flow state persistence.
Example:
```python
@@ -18,16 +19,18 @@ Example:
# Asynchronous method implementation
await some_async_operation()
```
"""
import asyncio
import functools
import logging
from collections.abc import Callable
from typing import (
Any,
Callable,
Optional,
Type,
TypeVar,
Union,
cast,
)
@@ -45,7 +48,7 @@ LOG_MESSAGES = {
"save_state": "Saving flow state to memory for ID: {}",
"save_error": "Failed to persist state for method {}: {}",
"state_missing": "Flow instance has no state",
"id_missing": "Flow state must have an 'id' field for persistence",
"id_missing": "Flow state must have an 'id' field for persistence"
}
@@ -71,23 +74,20 @@ class PersistenceDecorator:
ValueError: If flow has no state or state lacks an ID
RuntimeError: If state persistence fails
AttributeError: If flow instance lacks required state attributes
"""
try:
state = getattr(flow_instance, "state", None)
state = getattr(flow_instance, 'state', None)
if state is None:
msg = "Flow instance has no state"
raise ValueError(msg)
raise ValueError("Flow instance has no state")
flow_uuid: str | None = None
flow_uuid: Optional[str] = None
if isinstance(state, dict):
flow_uuid = state.get("id")
flow_uuid = state.get('id')
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, "id", None)
flow_uuid = getattr(state, 'id', None)
if not flow_uuid:
msg = "Flow state must have an 'id' field for persistence"
raise ValueError(msg)
raise ValueError("Flow state must have an 'id' field for persistence")
# Log state saving only if verbose is True
if verbose:
@@ -103,22 +103,21 @@ class PersistenceDecorator:
except Exception as e:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red")
logger.exception(error_msg)
msg = f"State persistence failed: {e!s}"
raise RuntimeError(msg) from e
logger.error(error_msg)
raise RuntimeError(f"State persistence failed: {str(e)}") from e
except AttributeError:
error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red")
logger.exception(error_msg)
logger.error(error_msg)
raise ValueError(error_msg)
except (TypeError, ValueError) as e:
error_msg = LOG_MESSAGES["id_missing"]
cls._printer.print(error_msg, color="red")
logger.exception(error_msg)
logger.error(error_msg)
raise ValueError(error_msg) from e
def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
"""Decorator to persist flow state.
This decorator can be applied at either the class level or method level.
@@ -144,23 +143,22 @@ def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
@start()
def begin(self):
pass
"""
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()
if isinstance(target, type):
# Class decoration
original_init = target.__init__
original_init = getattr(target, "__init__")
@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if "persistence" not in kwargs:
kwargs["persistence"] = actual_persistence
if 'persistence' not in kwargs:
kwargs['persistence'] = actual_persistence
original_init(self, *args, **kwargs)
target.__init__ = new_init
setattr(target, "__init__", new_init)
# Store original methods to preserve their decorators
original_methods = {}
@@ -193,7 +191,7 @@ def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
wrapped.__is_flow_method__ = True
setattr(wrapped, "__is_flow_method__", True)
# Update the class with the wrapped method
setattr(target, name, wrapped)
@@ -213,42 +211,44 @@ def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
wrapped.__is_flow_method__ = True
setattr(wrapped, "__is_flow_method__", True)
# Update the class with the wrapped method
setattr(target, name, wrapped)
return target
# Method decoration
method = target
method.__is_flow_method__ = True
else:
# Method decoration
method = target
setattr(method, "__is_flow_method__", True)
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
method_async_wrapper.__is_flow_method__ = True
return cast("Callable[..., T]", method_async_wrapper)
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
setattr(method_async_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_async_wrapper)
else:
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
method_sync_wrapper.__is_flow_method__ = True
return cast("Callable[..., T]", method_sync_wrapper)
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
setattr(method_sync_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_sync_wrapper)
return decorator
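
For the class-level form mentioned in the `persist` docstring, decorating the flow class wraps every flow method and falls back to `SQLiteFlowPersistence` when no backend is passed. A hedged sketch (import paths assumed):

```python
from crewai.flow.flow import Flow, start
from crewai.flow.persistence import persist


@persist(verbose=True)  # class-level: every flow method persists state after it runs
class CountingFlow(Flow):
    @start()
    def begin(self):
        self.state["counter"] = self.state.get("counter", 0) + 1
        return self.state["counter"]


flow = CountingFlow()
print(flow.kickoff())    # 1 on a fresh state
print(flow.state["id"])  # the UUID to pass back via inputs={"id": ...} to resume
```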


@@ -1,10 +1,12 @@
"""SQLite-based implementation of flow state persistence."""
"""
SQLite-based implementation of flow state persistence.
"""
import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel
@@ -21,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence):
db_path: str
def __init__(self, db_path: str | None = None) -> None:
def __init__(self, db_path: Optional[str] = None):
"""Initialize SQLite persistence.
Args:
@@ -30,7 +32,6 @@ class SQLiteFlowPersistence(FlowPersistence):
Raises:
ValueError: If db_path is invalid
"""
from crewai.utilities.paths import db_storage_path
@@ -38,8 +39,7 @@ class SQLiteFlowPersistence(FlowPersistence):
path = db_path or str(Path(db_storage_path()) / "flow_states.db")
if not path:
msg = "Database path must be provided"
raise ValueError(msg)
raise ValueError("Database path must be provided")
self.db_path = path # Now mypy knows this is str
self.init_db()
@@ -56,21 +56,21 @@ class SQLiteFlowPersistence(FlowPersistence):
timestamp DATETIME NOT NULL,
state_json TEXT NOT NULL
)
""",
"""
)
# Add index for faster UUID lookups
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
ON flow_states(flow_uuid)
""",
"""
)
def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: dict[str, Any] | BaseModel,
state_data: Union[Dict[str, Any], BaseModel],
) -> None:
"""Save the current flow state to SQLite.
@@ -78,7 +78,6 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
# Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel):
@@ -86,9 +85,8 @@ class SQLiteFlowPersistence(FlowPersistence):
elif isinstance(state_data, dict):
state_dict = state_data
else:
msg = f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
raise ValueError(
msg,
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
with sqlite3.connect(self.db_path) as conn:
@@ -109,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence):
),
)
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.
Args:
@@ -117,7 +115,6 @@ class SQLiteFlowPersistence(FlowPersistence):
Returns:
The most recent state as a dictionary, or None if no state exists
"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute(
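A usage sketch of this class; the import path and the exact shape of the returned dict are assumptions, only the class itself appears in this diff.

from crewai.flow.persistence import SQLiteFlowPersistence  # assumed export path

persistence = SQLiteFlowPersistence(db_path="/tmp/flow_states.db")
persistence.save_state(
    flow_uuid="demo-uuid",
    method_name="begin",
    state_data={"step": "begin", "count": 1},  # dicts and Pydantic models are both accepted per save_state above
)
print(persistence.load_state("demo-uuid"))     # most recent state for that UUID, e.g. {'step': 'begin', 'count': 1}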
View File
@@ -1,32 +1,33 @@
"""Utility functions for flow visualization and dependency analysis.
"""
Utility functions for flow visualization and dependency analysis.
This module provides core functionality for analyzing and manipulating flow structures,
including node level calculation, ancestor tracking, and return value analysis.
Functions in this module are primarily used by the visualization system to create
accurate and informative flow diagrams.
Example:
Example
-------
>>> flow = Flow()
>>> node_levels = calculate_node_levels(flow)
>>> ancestors = build_ancestor_dict(flow)
"""
import ast
import inspect
import textwrap
from collections import defaultdict, deque
from typing import Any
from typing import Any, Deque, Dict, List, Optional, Set, Union
def get_possible_return_constants(function: Any) -> list[str] | None:
def get_possible_return_constants(function: Any) -> Optional[List[str]]:
try:
source = inspect.getsource(function)
except OSError:
# Can't get source code
return None
except Exception:
except Exception as e:
print(f"Error retrieving source code for function {function.__name__}: {e}")
return None
try:
@@ -34,18 +35,24 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
source = textwrap.dedent(source)
# Parse the source code into an AST
code_ast = ast.parse(source)
except IndentationError:
except IndentationError as e:
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except SyntaxError:
except SyntaxError as e:
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except Exception:
except Exception as e:
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
return_values = set()
dict_definitions = {}
class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node) -> None:
def visit_Assign(self, node):
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
@@ -62,10 +69,10 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
self.generic_visit(node)
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node) -> None:
def visit_Return(self, node):
# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str,
node.value.value, str
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
@@ -87,8 +94,9 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
return list(return_values) if return_values else None
def calculate_node_levels(flow: Any) -> dict[str, int]:
"""Calculate the hierarchical level of each node in the flow.
def calculate_node_levels(flow: Any) -> Dict[str, int]:
"""
Calculate the hierarchical level of each node in the flow.
Performs a breadth-first traversal of the flow graph to assign levels
to nodes, starting with start methods at level 0.
@@ -109,12 +117,11 @@ def calculate_node_levels(flow: Any) -> dict[str, int]:
- Each subsequent connected node is assigned level = parent_level + 1
- Handles both OR and AND conditions for listeners
- Processes router paths separately
"""
levels: dict[str, int] = {}
queue: deque[str] = deque()
visited: set[str] = set()
pending_and_listeners: dict[str, set[str]] = {}
levels: Dict[str, int] = {}
queue: Deque[str] = deque()
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
@@ -165,8 +172,9 @@ def calculate_node_levels(flow: Any) -> dict[str, int]:
return levels
def count_outgoing_edges(flow: Any) -> dict[str, int]:
"""Count the number of outgoing edges for each method in the flow.
def count_outgoing_edges(flow: Any) -> Dict[str, int]:
"""
Count the number of outgoing edges for each method in the flow.
Parameters
----------
@@ -177,7 +185,6 @@ def count_outgoing_edges(flow: Any) -> dict[str, int]:
-------
Dict[str, int]
Dictionary mapping method names to their outgoing edge count.
"""
counts = {}
for method_name in flow._methods:
@@ -190,8 +197,9 @@ def count_outgoing_edges(flow: Any) -> dict[str, int]:
return counts
def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
"""Build a dictionary mapping each node to its ancestor nodes.
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
"""
Build a dictionary mapping each node to its ancestor nodes.
Parameters
----------
@@ -202,10 +210,9 @@ def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
-------
Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes.
"""
ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
visited: set[str] = set()
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
@@ -213,9 +220,10 @@ def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
def dfs_ancestors(
node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any,
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
) -> None:
"""Perform depth-first search to build ancestor relationships.
"""
Perform depth-first search to build ancestor relationships.
Parameters
----------
@@ -232,7 +240,6 @@ def dfs_ancestors(
-----
This function modifies the ancestors dictionary in-place to build
the complete ancestor graph.
"""
if node in visited:
return
@@ -258,9 +265,10 @@ def dfs_ancestors(
def is_ancestor(
node: str, ancestor_candidate: str, ancestors: dict[str, set[str]],
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
) -> bool:
"""Check if one node is an ancestor of another.
"""
Check if one node is an ancestor of another.
Parameters
----------
@@ -275,13 +283,13 @@ def is_ancestor(
-------
bool
True if ancestor_candidate is an ancestor of node, False otherwise.
"""
return ancestor_candidate in ancestors.get(node, set())
def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
"""Build a dictionary mapping parent nodes to their children.
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
"""
Build a dictionary mapping parent nodes to their children.
Parameters
----------
@@ -298,9 +306,8 @@ def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
- Maps listeners to their trigger methods
- Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering
"""
parent_children: dict[str, list[str]] = {}
parent_children: Dict[str, List[str]] = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -325,9 +332,10 @@ def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
def get_child_index(
parent: str, child: str, parent_children: dict[str, list[str]],
parent: str, child: str, parent_children: Dict[str, List[str]]
) -> int:
"""Get the index of a child node in its parent's sorted children list.
"""
Get the index of a child node in its parent's sorted children list.
Parameters
----------
@@ -342,25 +350,27 @@ def get_child_index(
-------
int
Zero-based index of the child in its parent's sorted children list.
"""
children = parent_children.get(parent, [])
children.sort()
return children.index(child)
def process_router_paths(flow, current, current_level, levels, queue) -> None:
"""Handle the router connections for the current node."""
def process_router_paths(flow, current, current_level, levels, queue):
"""
Handle the router connections for the current node.
"""
if current in flow._routers:
paths = flow._router_paths.get(current, [])
for path in paths:
for listener_name, (
_condition_type,
condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods and (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
queue.append(listener_name)
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
queue.append(listener_name)
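A toy, self-contained version of the breadth-first level assignment described above; the real implementation also handles AND conditions and router paths.

from collections import deque

def bfs_levels(start_nodes, listeners):
    """listeners maps listener name -> list of trigger method names."""
    levels = {node: 0 for node in start_nodes}  # start methods sit at level 0
    queue = deque(start_nodes)
    while queue:
        current = queue.popleft()
        for node, triggers in listeners.items():
            if current in triggers and (node not in levels or levels[node] > levels[current] + 1):
                levels[node] = levels[current] + 1  # each connected node is one level below its parent
                queue.append(node)
    return levels

print(bfs_levels(["begin"], {"analyze": ["begin"], "report": ["analyze"]}))
# -> {'begin': 0, 'analyze': 1, 'report': 2}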
View File
@@ -1,23 +1,23 @@
"""Utilities for creating visual representations of flow structures.
"""
Utilities for creating visual representations of flow structures.
This module provides functions for generating network visualizations of flows,
including node placement, edge creation, and visual styling. It handles the
conversion of flow structures into visual network graphs with appropriate
styling and layout.
Example:
Example
-------
>>> flow = Flow()
>>> net = Network(directed=True)
>>> node_positions = compute_positions(flow, node_levels)
>>> add_nodes_to_network(net, flow, node_positions, node_styles)
>>> add_edges(net, flow, node_positions, colors)
"""
import ast
import inspect
from typing import Any
from typing import Any, Dict, List, Optional, Tuple, Union
from .utils import (
build_ancestor_dict,
@@ -28,7 +28,8 @@ from .utils import (
def method_calls_crew(method: Any) -> bool:
"""Check if the method contains a call to `.crew()`.
"""
Check if the method contains a call to `.crew()`.
Parameters
----------
@@ -44,22 +45,21 @@ def method_calls_crew(method: Any) -> bool:
-----
Uses AST analysis to detect method calls, specifically looking for
attribute access of 'crew'.
"""
try:
source = inspect.getsource(method)
source = inspect.cleandoc(source)
tree = ast.parse(source)
except Exception:
except Exception as e:
print(f"Could not parse method {method.__name__}: {e}")
return False
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""
def __init__(self) -> None:
def __init__(self):
self.found = False
def visit_Call(self, node) -> None:
def visit_Call(self, node):
if isinstance(node.func, ast.Attribute):
if node.func.attr == "crew":
self.found = True
@@ -73,10 +73,11 @@ def method_calls_crew(method: Any) -> bool:
def add_nodes_to_network(
net: Any,
flow: Any,
node_positions: dict[str, tuple[float, float]],
node_styles: dict[str, dict[str, Any]],
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, Dict[str, Any]]
) -> None:
"""Add nodes to the network visualization with appropriate styling.
"""
Add nodes to the network visualization with appropriate styling.
Parameters
----------
@@ -96,7 +97,6 @@ def add_nodes_to_network(
- Router methods
- Crew methods
- Regular methods
"""
def human_friendly_label(method_name):
return method_name.replace("_", " ").title()
@@ -123,7 +123,7 @@ def add_nodes_to_network(
"multi": "html",
"color": node_style.get("font", {}).get("color", "#FFFFFF"),
},
},
}
)
net.add_node(
@@ -138,11 +138,12 @@ def add_nodes_to_network(
def compute_positions(
flow: Any,
node_levels: dict[str, int],
node_levels: Dict[str, int],
y_spacing: float = 150,
x_spacing: float = 150,
) -> dict[str, tuple[float, float]]:
"""Compute the (x, y) positions for each node in the flow graph.
x_spacing: float = 150
) -> Dict[str, Tuple[float, float]]:
"""
Compute the (x, y) positions for each node in the flow graph.
Parameters
----------
@@ -159,10 +160,9 @@ def compute_positions(
-------
Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates.
"""
level_nodes: dict[int, list[str]] = {}
node_positions: dict[str, tuple[float, float]] = {}
level_nodes: Dict[int, List[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {}
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +180,10 @@ def compute_positions(
def add_edges(
net: Any,
flow: Any,
node_positions: dict[str, tuple[float, float]],
colors: dict[str, str],
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str]
) -> None:
edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
"""
Add edges to the network visualization with appropriate styling.
@@ -245,7 +245,7 @@ def add_edges(
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": bool(is_router_edge or is_and_condition),
"dashes": True if is_router_edge or is_and_condition else False,
"smooth": edge_smooth,
}
@@ -261,7 +261,9 @@ def add_edges(
# If it's a known router edge and the method is known, don't warn.
# This means the path is legitimate, just not reflected as nodes here.
if not (is_router_edge and method_known):
pass
print(
f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
)
# Edges for router return paths
for router_method_name, paths in flow._router_paths.items():
@@ -276,7 +278,7 @@ def add_edges(
and listener_name in node_positions
):
is_cycle_edge = is_ancestor(
router_method_name, listener_name, ancestors,
router_method_name, listener_name, ancestors
)
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
@@ -291,7 +293,7 @@ def add_edges(
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(
router_method_name, listener_name, parent_children,
router_method_name, listener_name, parent_children
)
edge_smooth = {
"type": smooth_type,
@@ -314,4 +316,6 @@ def add_edges(
# Same check here: known router edge and known method?
method_known = listener_name in flow._methods
if not method_known:
pass
print(
f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
)
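A rough, standalone sketch of the level-based layout compute_positions performs; the exact centering used by the real function is not shown in this excerpt, so the numbers below are only indicative.

def grid_positions(node_levels, x_spacing=150.0, y_spacing=150.0):
    by_level = {}
    for name, level in node_levels.items():
        by_level.setdefault(level, []).append(name)
    positions = {}
    for level, nodes in by_level.items():
        width = (len(nodes) - 1) * x_spacing  # total horizontal span of this row
        for i, name in enumerate(sorted(nodes)):
            positions[name] = (i * x_spacing - width / 2, level * y_spacing)
    return positions

print(grid_positions({"begin": 0, "analyze": 1, "report": 1}))
# -> {'begin': (0.0, 0.0), 'analyze': (-75.0, 150.0), 'report': (75.0, 150.0)}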
View File
@@ -1,48 +1,55 @@
from abc import ABC, abstractmethod
from typing import List
import numpy as np
class BaseEmbedder(ABC):
"""Abstract base class for text embedding models."""
"""
Abstract base class for text embedding models
"""
@abstractmethod
def embed_chunks(self, chunks: list[str]) -> np.ndarray:
"""Generate embeddings for a list of text chunks.
def embed_chunks(self, chunks: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of text chunks
Args:
chunks: List of text chunks to embed
Returns:
Array of embeddings
"""
pass
@abstractmethod
def embed_texts(self, texts: list[str]) -> np.ndarray:
"""Generate embeddings for a list of texts.
def embed_texts(self, texts: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of texts
Args:
texts: List of texts to embed
Returns:
Array of embeddings
"""
pass
@abstractmethod
def embed_text(self, text: str) -> np.ndarray:
"""Generate embedding for a single text.
"""
Generate embedding for a single text
Args:
text: Text to embed
Returns:
Embedding array
"""
pass
@property
@abstractmethod
def dimension(self) -> int:
"""Get the dimension of the embeddings."""
"""Get the dimension of the embeddings"""
pass
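A hypothetical concrete subclass, only to show the contract the abstract methods define; ToyEmbedder is not a real model and the BaseEmbedder import path is assumed.

import numpy as np
from typing import List

from crewai.knowledge.embedder.base_embedder import BaseEmbedder  # assumed module path

class ToyEmbedder(BaseEmbedder):
    """Deterministic 8-dimensional vectors derived from character codes."""

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        return np.array([[sum(map(ord, t)) % (i + 2) for i in range(8)] for t in texts], dtype=float)

    def embed_chunks(self, chunks: List[str]) -> np.ndarray:
        return self.embed_texts(chunks)

    def embed_text(self, text: str) -> np.ndarray:
        return self.embed_texts([text])[0]

    @property
    def dimension(self) -> int:
        return 8

print(ToyEmbedder().embed_text("hello").shape)  # (8,)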
View File
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
@@ -18,74 +19,75 @@ except ImportError:
class FastEmbed(BaseEmbedder):
"""A wrapper class for text embedding models using FastEmbed."""
"""
A wrapper class for text embedding models using FastEmbed
"""
def __init__(
self,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: str | Path | None = None,
) -> None:
"""Initialize the embedding model.
cache_dir: Optional[Union[str, Path]] = None,
):
"""
Initialize the embedding model
Args:
model_name: Name of the model to use
cache_dir: Directory to cache the model
gpu: Whether to use GPU acceleration
"""
if not FASTEMBED_AVAILABLE:
msg = (
raise ImportError(
"FastEmbed is not installed. Please install it with: "
"uv pip install fastembed or uv pip install fastembed-gpu for GPU support"
)
raise ImportError(
msg,
)
self.model = TextEmbedding(
model_name=model_name,
cache_dir=str(cache_dir) if cache_dir else None,
)
def embed_chunks(self, chunks: list[str]) -> list[np.ndarray]:
"""Generate embeddings for a list of text chunks.
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
"""
Generate embeddings for a list of text chunks
Args:
chunks: List of text chunks to embed
Returns:
List of embeddings
"""
return list(self.model.embed(chunks))
embeddings = list(self.model.embed(chunks))
return embeddings
def embed_texts(self, texts: list[str]) -> list[np.ndarray]:
"""Generate embeddings for a list of texts.
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
"""
Generate embeddings for a list of texts
Args:
texts: List of texts to embed
Returns:
List of embeddings
"""
return list(self.model.embed(texts))
embeddings = list(self.model.embed(texts))
return embeddings
def embed_text(self, text: str) -> np.ndarray:
"""Generate embedding for a single text.
"""
Generate embedding for a single text
Args:
text: Text to embed
Returns:
Embedding array
"""
return self.embed_texts([text])[0]
@property
def dimension(self) -> int:
"""Get the dimension of the embeddings."""
"""Get the dimension of the embeddings"""
# Generate a test embedding to get dimensions
test_embed = self.embed_text("test")
return len(test_embed)
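Typical usage, assuming the optional fastembed dependency is installed and the module path shown below is correct.

from crewai.knowledge.embedder.fastembed import FastEmbed  # assumed module path

embedder = FastEmbed()  # defaults to BAAI/bge-small-en-v1.5
vectors = embedder.embed_chunks(["CrewAI flows", "Knowledge sources"])
print(len(vectors), embedder.dimension)  # 2 and the model's dimension (384 for bge-small)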
View File
@@ -1,5 +1,5 @@
import os
from typing import Any
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
@@ -10,70 +10,69 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
class Knowledge(BaseModel):
"""Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None.
embedder: Optional[Dict[str, Any]] = None
"""
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
collection_name: str | None = None
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None
def __init__(
self,
collection_name: str,
sources: list[BaseKnowledgeSource],
embedder: dict[str, Any] | None = None,
storage: KnowledgeStorage | None = None,
sources: List[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
**data,
) -> None:
):
super().__init__(**data)
if storage:
self.storage = storage
else:
self.storage = KnowledgeStorage(
embedder=embedder, collection_name=collection_name,
embedder=embedder, collection_name=collection_name
)
self.sources = sources
self.storage.initialize_knowledge_storage()
self._add_sources()
def query(
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35,
) -> list[dict[str, Any]]:
"""Query across all knowledge sources to find the most relevant information.
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
Raises:
ValueError: If storage is not initialized.
"""
if self.storage is None:
msg = "Storage is not initialized."
raise ValueError(msg)
raise ValueError("Storage is not initialized.")
return self.storage.search(
results = self.storage.search(
query,
limit=results_limit,
score_threshold=score_threshold,
)
return results
def add_sources(self) -> None:
def _add_sources(self):
try:
for source in self.sources:
source.storage = self.storage
source.add()
except Exception:
raise
except Exception as e:
raise e
def reset(self) -> None:
if self.storage:
self.storage.reset()
else:
msg = "Storage is not initialized."
raise ValueError(msg)
raise ValueError("Storage is not initialized.")
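A usage sketch, under the assumption that an embedder is configured (the default falls back to OpenAI's text-embedding-3-small, so OPENAI_API_KEY must be set) and that the import paths below are correct.

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

source = StringKnowledgeSource(content="CrewAI flows can persist state between runs.")
knowledge = Knowledge(collection_name="demo", sources=[source])  # chunks, embeds, and stores the source
hits = knowledge.query(["How do flows persist state?"], results_limit=3)
print(hits[0]["context"], hits[0]["score"])  # best-matching chunk and its relevance score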
View File
@@ -7,7 +7,6 @@ class KnowledgeConfig(BaseModel):
Args:
results_limit (int): The number of relevant documents to return.
score_threshold (float): The minimum score for a document to be considered relevant.
"""
results_limit: int = Field(default=3, description="The number of results to return")
View File
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union
from pydantic import Field, field_validator
@@ -13,43 +14,43 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""
_logger: Logger = Logger(verbose=True)
file_path: Path | list[Path] | str | list[str] | None = Field(
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file",
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
)
content: dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage | None = Field(default=None)
safe_file_paths: list[Path] = Field(default_factory=list)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: Optional[KnowledgeStorage] = Field(default=None)
safe_file_paths: List[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(self, v, info):
def validate_file_path(cls, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
v is None
and info.data.get(
"file_path" if info.field_name == "file_paths" else "file_paths",
"file_path" if info.field_name == "file_paths" else "file_paths"
)
is None
):
msg = "Either file_path or file_paths must be provided"
raise ValueError(msg)
raise ValueError("Either file_path or file_paths must be provided")
return v
def model_post_init(self, _) -> None:
def model_post_init(self, _):
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
self.validate_content()
self.content = self.load_content()
@abstractmethod
def load_content(self) -> dict[Path, str]:
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
pass
def validate_content(self) -> None:
def validate_content(self):
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -58,8 +59,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
color="red",
)
msg = f"File not found: {path}"
raise FileNotFoundError(msg)
raise FileNotFoundError(f"File not found: {path}")
if not path.is_file():
self._logger.log(
"error",
@@ -67,20 +67,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
color="red",
)
def _save_documents(self) -> None:
def _save_documents(self):
"""Save the documents to the storage."""
if self.storage:
self.storage.save(self.chunks)
else:
msg = "No storage found to save documents."
raise ValueError(msg)
raise ValueError("No storage found to save documents.")
def convert_to_path(self, path: Path | str) -> Path:
def convert_to_path(self, path: Union[Path, str]) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
def _process_file_paths(self) -> list[Path]:
def _process_file_paths(self) -> List[Path]:
"""Convert file_path to a list of Path objects."""
if hasattr(self, "file_path") and self.file_path is not None:
self._logger.log(
"warning",
@@ -90,11 +90,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
self.file_paths = self.file_path
if self.file_paths is None:
msg = "Your source must be provided with a file_paths: []"
raise ValueError(msg)
raise ValueError("Your source must be provided with a file_paths: []")
# Convert single path to list
path_list: list[Path | str] = (
path_list: List[Union[Path, str]] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
@@ -103,9 +102,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
)
if not path_list:
msg = "file_path/file_paths must be a Path, str, or a list of these types"
raise ValueError(
msg,
"file_path/file_paths must be a Path, str, or a list of these types"
)
return [self.convert_to_path(path) for path in path_list]
View File
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Dict, List, Optional
import numpy as np
from pydantic import BaseModel, ConfigDict, Field
@@ -12,39 +12,41 @@ class BaseKnowledgeSource(BaseModel, ABC):
chunk_size: int = 4000
chunk_overlap: int = 200
chunks: list[str] = Field(default_factory=list)
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
chunks: List[str] = Field(default_factory=list)
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: str | None = Field(default=None)
storage: Optional[KnowledgeStorage] = Field(default=None)
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: Optional[str] = Field(default=None)
@abstractmethod
def validate_content(self) -> Any:
"""Load and preprocess content from the source."""
pass
@abstractmethod
def add(self) -> None:
"""Process content, chunk it, compute embeddings, and save them."""
pass
def get_embeddings(self) -> list[np.ndarray]:
def get_embeddings(self) -> List[np.ndarray]:
"""Return the list of embeddings for the chunks."""
return self.chunk_embeddings
def _chunk_text(self, text: str) -> list[str]:
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]
def _save_documents(self) -> None:
"""Save the documents to the storage.
def _save_documents(self):
"""
Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
"""
if self.storage:
self.storage.save(self.chunks)
else:
msg = "No storage found to save documents."
raise ValueError(msg)
raise ValueError("No storage found to save documents.")
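A standalone illustration of the slicing arithmetic in _chunk_text: with chunk_size=10 and chunk_overlap=3 the window advances by 10 - 3 = 7 characters, so consecutive chunks share 3 characters.

chunk_size, chunk_overlap = 10, 3
text = "abcdefghijklmnopqrstuvwxyz"
chunks = [text[i : i + chunk_size] for i in range(0, len(text), chunk_size - chunk_overlap)]
print(chunks)  # ['abcdefghij', 'hijklmnopq', 'opqrstuvwx', 'vwxyz']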
View File
@@ -1,6 +1,5 @@
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Iterator, List, Optional, Union
from urllib.parse import urlparse
try:
@@ -8,6 +7,7 @@ try:
from docling.document_converter import DocumentConverter
from docling.exceptions import ConversionError
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.types.doc.document import DoclingDocument
DOCLING_AVAILABLE = True
except ImportError:
@@ -19,33 +19,27 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
if TYPE_CHECKING:
from docling_core.types.doc.document import DoclingDocument
class CrewDoclingSource(BaseKnowledgeSource):
"""Default Source class for converting documents to markdown or json
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
"""
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args, **kwargs):
if not DOCLING_AVAILABLE:
msg = (
raise ImportError(
"The docling package is required to use CrewDoclingSource. "
"Please install it using: uv add docling"
)
raise ImportError(
msg,
)
super().__init__(*args, **kwargs)
_logger: Logger = Logger(verbose=True)
file_path: list[Path | str] | None = Field(default=None)
file_paths: list[Path | str] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
safe_file_paths: list[Path | str] = Field(default_factory=list)
content: list["DoclingDocument"] = Field(default_factory=list)
file_path: Optional[List[Union[Path, str]]] = Field(default=None)
file_paths: List[Union[Path, str]] = Field(default_factory=list)
chunks: List[str] = Field(default_factory=list)
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
content: List["DoclingDocument"] = Field(default_factory=list)
document_converter: "DocumentConverter" = Field(
default_factory=lambda: DocumentConverter(
allowed_formats=[
@@ -57,8 +51,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
InputFormat.IMAGE,
InputFormat.XLSX,
InputFormat.PPTX,
],
),
]
)
)
def model_post_init(self, _) -> None:
@@ -72,7 +66,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.safe_file_paths = self.validate_content()
self.content = self._load_content()
def _load_content(self) -> list["DoclingDocument"]:
def _load_content(self) -> List["DoclingDocument"]:
try:
return self._convert_source_to_docling_documents()
except ConversionError as e:
@@ -81,10 +75,10 @@ class CrewDoclingSource(BaseKnowledgeSource):
f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
"red",
)
raise
raise e
except Exception as e:
self._logger.log("error", f"Error loading content: {e}")
raise
raise e
def add(self) -> None:
if self.content is None:
@@ -94,7 +88,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()
def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]:
def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]
@@ -103,8 +97,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
for chunk in chunker.chunk(doc):
yield chunk.text
def validate_content(self) -> list[Path | str]:
processed_paths: list[Path | str] = []
def validate_content(self) -> List[Union[Path, str]]:
processed_paths: List[Union[Path, str]] = []
for path in self.file_paths:
if isinstance(path, str):
if path.startswith(("http://", "https://")):
@@ -112,18 +106,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
if self._validate_url(path):
processed_paths.append(path)
else:
msg = f"Invalid URL format: {path}"
raise ValueError(msg)
raise ValueError(f"Invalid URL format: {path}")
except Exception as e:
msg = f"Invalid URL: {path}. Error: {e!s}"
raise ValueError(msg)
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
else:
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
if local_path.exists():
processed_paths.append(local_path)
else:
msg = f"File not found: {local_path}"
raise FileNotFoundError(msg)
raise FileNotFoundError(f"File not found: {local_path}")
else:
# this is an instance of Path
processed_paths.append(path)
@@ -137,7 +128,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
result.scheme in ("http", "https"),
result.netloc,
len(result.netloc.split(".")) >= 2, # Ensure domain has TLD
],
]
)
except Exception:
return False
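The URL check above, extracted into a standalone helper for clarity; same logic, hypothetical function name.

from urllib.parse import urlparse

def looks_like_valid_url(candidate: str) -> bool:
    try:
        result = urlparse(candidate)
        return all([
            result.scheme in ("http", "https"),
            result.netloc,
            len(result.netloc.split(".")) >= 2,  # require a dotted domain, e.g. example.com
        ])
    except Exception:
        return False

print(looks_like_valid_url("https://docs.crewai.com/concepts/flows"))  # True
print(looks_like_valid_url("not-a-url"))                               # False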
View File
@@ -1,5 +1,6 @@
import csv
from pathlib import Path
from typing import Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -7,11 +8,11 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class CSVKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries CSV file content using embeddings."""
def load_content(self) -> dict[Path, str]:
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess CSV file content."""
content_dict = {}
for file_path in self.safe_file_paths:
with open(file_path, encoding="utf-8") as csvfile:
with open(file_path, "r", encoding="utf-8") as csvfile:
reader = csv.reader(csvfile)
content = ""
for row in reader:
@@ -20,7 +21,8 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
return content_dict
def add(self) -> None:
"""Add CSV file content to the knowledge source, chunk it, compute embeddings,
"""
Add CSV file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
content_str = (
@@ -30,7 +32,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> list[str]:
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
View File
@@ -1,4 +1,6 @@
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Union
from urllib.parse import urlparse
from pydantic import Field, field_validator
@@ -14,34 +16,34 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
_logger: Logger = Logger(verbose=True)
file_path: Path | list[Path] | str | list[str] | None = Field(
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file",
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
)
chunks: list[str] = Field(default_factory=list)
content: dict[Path, dict[str, str]] = Field(default_factory=dict)
safe_file_paths: list[Path] = Field(default_factory=list)
chunks: List[str] = Field(default_factory=list)
content: Dict[Path, Dict[str, str]] = Field(default_factory=dict)
safe_file_paths: List[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(self, v, info):
def validate_file_path(cls, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
v is None
and info.data.get(
"file_path" if info.field_name == "file_paths" else "file_paths",
"file_path" if info.field_name == "file_paths" else "file_paths"
)
is None
):
msg = "Either file_path or file_paths must be provided"
raise ValueError(msg)
raise ValueError("Either file_path or file_paths must be provided")
return v
def _process_file_paths(self) -> list[Path]:
def _process_file_paths(self) -> List[Path]:
"""Convert file_path to a list of Path objects."""
if hasattr(self, "file_path") and self.file_path is not None:
self._logger.log(
"warning",
@@ -51,11 +53,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.file_paths = self.file_path
if self.file_paths is None:
msg = "Your source must be provided with a file_paths: []"
raise ValueError(msg)
raise ValueError("Your source must be provided with a file_paths: []")
# Convert single path to list
path_list: list[Path | str] = (
path_list: List[Union[Path, str]] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
@@ -64,14 +65,13 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
)
if not path_list:
msg = "file_path/file_paths must be a Path, str, or a list of these types"
raise ValueError(
msg,
"file_path/file_paths must be a Path, str, or a list of these types"
)
return [self.convert_to_path(path) for path in path_list]
def validate_content(self) -> None:
def validate_content(self):
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -80,8 +80,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
color="red",
)
msg = f"File not found: {path}"
raise FileNotFoundError(msg)
raise FileNotFoundError(f"File not found: {path}")
if not path.is_file():
self._logger.log(
"error",
@@ -101,7 +100,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.validate_content()
self.content = self._load_content()
def _load_content(self) -> dict[Path, dict[str, str]]:
def _load_content(self) -> Dict[Path, Dict[str, str]]:
"""Load and preprocess Excel file content from multiple sheets.
Each sheet's content is converted to CSV format and stored.
@@ -112,7 +111,6 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
Raises:
ImportError: If required dependencies are missing.
FileNotFoundError: If the specified Excel file cannot be opened.
"""
pd = self._import_dependencies()
content_dict = {}
@@ -121,14 +119,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
with pd.ExcelFile(file_path) as xl:
sheet_dict = {
str(sheet_name): str(
pd.read_excel(xl, sheet_name).to_csv(index=False),
pd.read_excel(xl, sheet_name).to_csv(index=False)
)
for sheet_name in xl.sheet_names
}
content_dict[file_path] = sheet_dict
return content_dict
def convert_to_path(self, path: Path | str) -> Path:
def convert_to_path(self, path: Union[Path, str]) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
@@ -140,13 +138,13 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
return pd
except ImportError as e:
missing_package = str(e).split()[-1]
msg = f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
raise ImportError(
msg,
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
)
def add(self) -> None:
"""Add Excel file content to the knowledge source, chunk it, compute embeddings,
"""
Add Excel file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
# Convert dictionary values to a single string if content is a dictionary
@@ -163,7 +161,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> list[str]:
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
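A rough single-workbook equivalent of _load_content above; it requires pandas plus an Excel engine such as openpyxl, and the file path in the comment is hypothetical.

import pandas as pd

def sheets_as_csv(path: str) -> dict:
    """Return {sheet name: CSV text} for every sheet in the workbook."""
    with pd.ExcelFile(path) as xl:
        return {
            str(sheet): pd.read_excel(xl, sheet).to_csv(index=False)
            for sheet in xl.sheet_names
        }

# sheets_as_csv("knowledge/report.xlsx") -> {"Sheet1": "col_a,col_b\n1,2\n", ...}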
View File
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Any
from typing import Any, Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -8,12 +8,12 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class JSONKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries JSON file content using embeddings."""
def load_content(self) -> dict[Path, str]:
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess JSON file content."""
content: dict[Path, str] = {}
content: Dict[Path, str] = {}
for path in self.safe_file_paths:
path = self.convert_to_path(path)
with open(path, encoding="utf-8") as json_file:
with open(path, "r", encoding="utf-8") as json_file:
data = json.load(json_file)
content[path] = self._json_to_text(data)
return content
@@ -29,11 +29,12 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
for item in data:
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
else:
text += f"{data!s}"
text += f"{str(data)}"
return text
def add(self) -> None:
"""Add JSON file content to the knowledge source, chunk it, compute embeddings,
"""
Add JSON file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
content_str = (
@@ -43,7 +44,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> list[str]:
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
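A toy version of the _json_to_text recursion above, runnable on its own; the exact formatting differs slightly from the real method.

def json_to_text(data, level=0):
    indent = "  " * level
    if isinstance(data, dict):
        return "".join(f"{indent}{key}: {json_to_text(value, level + 1)}\n" for key, value in data.items())
    if isinstance(data, list):
        return "".join(f"{indent}- {json_to_text(item, level + 1)}\n" for item in data)
    return str(data)

print(json_to_text({"name": "CrewAI", "topics": ["flows", "knowledge"]}))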
View File
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -6,7 +7,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class PDFKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries PDF file content using embeddings."""
def load_content(self) -> dict[Path, str]:
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess PDF file content."""
pdfplumber = self._import_pdfplumber()
@@ -30,21 +31,21 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
return pdfplumber
except ImportError:
msg = "pdfplumber is not installed. Please install it with: pip install pdfplumber"
raise ImportError(
msg,
"pdfplumber is not installed. Please install it with: pip install pdfplumber"
)
def add(self) -> None:
"""Add PDF file content to the knowledge source, chunk it, compute embeddings,
"""
Add PDF file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for text in self.content.values():
for _, text in self.content.items():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> list[str]:
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
View File
@@ -1,3 +1,4 @@
from typing import List, Optional
from pydantic import Field
@@ -8,17 +9,16 @@ class StringKnowledgeSource(BaseKnowledgeSource):
"""A knowledge source that stores and queries plain text content using embeddings."""
content: str = Field(...)
collection_name: str | None = Field(default=None)
collection_name: Optional[str] = Field(default=None)
def model_post_init(self, _) -> None:
def model_post_init(self, _):
"""Post-initialization method to validate content."""
self.validate_content()
def validate_content(self) -> None:
def validate_content(self):
"""Validate string content."""
if not isinstance(self.content, str):
msg = "StringKnowledgeSource only accepts string content"
raise ValueError(msg)
raise ValueError("StringKnowledgeSource only accepts string content")
def add(self) -> None:
"""Add string content to the knowledge source, chunk it, compute embeddings, and save them."""
@@ -26,7 +26,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> list[str]:
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
View File
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -6,25 +7,26 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries text file content using embeddings."""
def load_content(self) -> dict[Path, str]:
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess text file content."""
content = {}
for path in self.safe_file_paths:
path = self.convert_to_path(path)
with open(path, encoding="utf-8") as f:
with open(path, "r", encoding="utf-8") as f:
content[path] = f.read()
return content
def add(self) -> None:
"""Add text file content to the knowledge source, chunk it, compute embeddings,
"""
Add text file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for text in self.content.values():
for _, text in self.content.items():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> list[str]:
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
View File
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Dict, List, Optional
class BaseKnowledgeStorage(ABC):
@@ -8,19 +8,22 @@ class BaseKnowledgeStorage(ABC):
@abstractmethod
def search(
self,
query: list[str],
query: List[str],
limit: int = 3,
filter: dict | None = None,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> list[dict[str, Any]]:
) -> List[Dict[str, Any]]:
"""Search for documents in the knowledge base."""
pass
@abstractmethod
def save(
self, documents: list[str], metadata: dict[str, Any] | list[dict[str, Any]],
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
) -> None:
"""Save documents to the knowledge base."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
pass
View File
@@ -4,11 +4,12 @@ import io
import logging
import os
import shutil
from typing import TYPE_CHECKING, Any
from typing import Any, Dict, List, Optional, Union
import chromadb
import chromadb.errors
from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany
from chromadb.config import Settings
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
@@ -18,9 +19,6 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
if TYPE_CHECKING:
from chromadb.api.types import OneOrMany
@contextlib.contextmanager
def suppress_logging(
@@ -40,29 +38,30 @@ def suppress_logging(
class KnowledgeStorage(BaseKnowledgeStorage):
"""Extends Storage to handle embeddings for memory entries, improving
"""
Extends Storage to handle embeddings for memory entries, improving
search efficiency.
"""
collection: chromadb.Collection | None = None
collection_name: str | None = "knowledge"
app: ClientAPI | None = None
collection: Optional[chromadb.Collection] = None
collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None
def __init__(
self,
embedder: dict[str, Any] | None = None,
collection_name: str | None = None,
) -> None:
embedder: Optional[Dict[str, Any]] = None,
collection_name: Optional[str] = None,
):
self.collection_name = collection_name
self._set_embedder_config(embedder)
def search(
self,
query: list[str],
query: List[str],
limit: int = 3,
filter: dict | None = None,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> list[dict[str, Any]]:
) -> List[Dict[str, Any]]:
with suppress_logging():
if self.collection:
fetched = self.collection.query(
@@ -81,10 +80,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):
if result["score"] >= score_threshold:
results.append(result)
return results
msg = "Collection not initialized"
raise Exception(msg)
else:
raise Exception("Collection not initialized")
def initialize_knowledge_storage(self) -> None:
def initialize_knowledge_storage(self):
base_path = os.path.join(db_storage_path(), "knowledge")
chroma_client = chromadb.PersistentClient(
path=base_path,
@@ -105,13 +104,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
embedding_function=self.embedder,
)
else:
msg = "Vector Database Client not initialized"
raise Exception(msg)
raise Exception("Vector Database Client not initialized")
except Exception:
msg = "Failed to create or get collection"
raise Exception(msg)
raise Exception("Failed to create or get collection")
def reset(self) -> None:
def reset(self):
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
if not self.app:
self.app = chromadb.PersistentClient(
@@ -126,12 +123,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def save(
self,
documents: list[str],
metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
) -> None:
documents: List[str],
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
):
if not self.collection:
msg = "Collection not initialized"
raise Exception(msg)
raise Exception("Collection not initialized")
try:
# Create a dictionary to store unique documents
@@ -160,7 +156,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
filtered_ids.append(doc_id)
# If we have no metadata at all, set it to None
final_metadata: OneOrMany[chromadb.Metadata] | None = (
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
None if all(m is None for m in filtered_metadata) else filtered_metadata
)
@@ -175,13 +171,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
msg = (
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
)
raise ValueError(
msg,
) from e
except Exception as e:
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
@@ -193,16 +186,15 @@ class KnowledgeStorage(BaseKnowledgeStorage):
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
def _set_embedder_config(self, embedder: dict[str, Any] | None = None) -> None:
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
"""Set the embedding configuration for the knowledge storage.
Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
"""
self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder)
View File
@@ -1,7 +1,7 @@
from typing import Any
from typing import Any, Dict, List
def extract_knowledge_context(knowledge_snippets: list[dict[str, Any]]) -> str:
def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
"""Extract knowledge from the task prompt."""
valid_snippets = [
result["context"]
View File
@@ -1,7 +1,7 @@
import asyncio
import uuid
from collections.abc import Callable
from typing import Any, cast
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
@@ -13,7 +13,6 @@ from crewai.agents.parser import (
AgentFinish,
OutputParserException,
)
from crewai.flow.flow_trackable import FlowTrackable
from crewai.llm import LLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
@@ -35,7 +34,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
show_agent_logs,
)
from crewai.utilities.converter import generate_model_description
from crewai.utilities.converter import convert_to_model, generate_model_description
from crewai.utilities.events.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
@@ -60,15 +59,15 @@ class LiteAgentOutput(BaseModel):
model_config = {"arbitrary_types_allowed": True}
raw: str = Field(description="Raw output of the agent", default="")
pydantic: BaseModel | None = Field(
description="Pydantic output of the agent", default=None,
pydantic: Optional[BaseModel] = Field(
description="Pydantic output of the agent", default=None
)
agent_role: str = Field(description="Role of the agent that produced this output")
usage_metrics: dict[str, Any] | None = Field(
description="Token usage metrics for this execution", default=None,
usage_metrics: Optional[Dict[str, Any]] = Field(
description="Token usage metrics for this execution", default=None
)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
"""Convert pydantic_output to a dictionary."""
if self.pydantic:
return self.pydantic.model_dump()
@@ -81,8 +80,9 @@ class LiteAgentOutput(BaseModel):
return self.raw
class LiteAgent(FlowTrackable, BaseModel):
"""A lightweight agent that can process messages and use tools.
class LiteAgent(BaseModel):
"""
A lightweight agent that can process messages and use tools.
This agent is simpler than the full Agent class, focusing on direct execution
rather than task delegation. It's designed to be used for simple interactions
@@ -98,7 +98,6 @@ class LiteAgent(FlowTrackable, BaseModel):
max_iterations: Maximum number of iterations for tool usage.
max_execution_time: Maximum execution time in seconds.
response_format: Optional Pydantic model for structured output.
"""
model_config = {"arbitrary_types_allowed": True}
@@ -107,19 +106,19 @@ class LiteAgent(FlowTrackable, BaseModel):
role: str = Field(description="Role of the agent")
goal: str = Field(description="Goal of the agent")
backstory: str = Field(description="Backstory of the agent")
llm: str | InstanceOf[LLM] | Any | None = Field(
default=None, description="Language model that will run the agent",
llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
default=None, description="Language model that will run the agent"
)
tools: list[BaseTool] = Field(
default_factory=list, description="Tools at agent's disposal",
tools: List[BaseTool] = Field(
default_factory=list, description="Tools at agent's disposal"
)
# Execution Control Properties
max_iterations: int = Field(
default=15, description="Maximum number of iterations for tool usage",
default=15, description="Maximum number of iterations for tool usage"
)
max_execution_time: int | None = Field(
default=None, description="Maximum execution time in seconds",
max_execution_time: Optional[int] = Field(
default=None, description="Maximum execution time in seconds"
)
respect_context_window: bool = Field(
default=True,
@@ -129,48 +128,47 @@ class LiteAgent(FlowTrackable, BaseModel):
default=True,
description="Whether to use stop words to prevent the LLM from using tools",
)
request_within_rpm_limit: Callable[[], bool] | None = Field(
request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
default=None,
description="Callback to check if the request is within the RPM limit",
)
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
# Output and Formatting Properties
response_format: type[BaseModel] | None = Field(
default=None, description="Pydantic model for structured output",
response_format: Optional[Type[BaseModel]] = Field(
default=None, description="Pydantic model for structured output"
)
verbose: bool = Field(
default=False, description="Whether to print execution details",
default=False, description="Whether to print execution details"
)
callbacks: list[Callable] = Field(
default=[], description="Callbacks to be used for the agent",
callbacks: List[Callable] = Field(
default=[], description="Callbacks to be used for the agent"
)
# State and Results
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent.",
tools_results: List[Dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
)
# Reference of Agent
original_agent: BaseAgent | None = Field(
default=None, description="Reference to the agent that created this LiteAgent",
original_agent: Optional[BaseAgent] = Field(
default=None, description="Reference to the agent that created this LiteAgent"
)
# Private Attributes
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer)
@model_validator(mode="after")
def setup_llm(self):
"""Set up the LLM and other components after initialization."""
self.llm = create_llm(self.llm)
if not isinstance(self.llm, LLM):
msg = "Unable to create LLM instance"
raise ValueError(msg)
raise ValueError("Unable to create LLM instance")
# Initialize callbacks
token_callback = TokenCalcHandler(token_cost_process=self._token_process)
@@ -195,8 +193,9 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces."""
return self.role
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
"""Execute the agent with the given messages.
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
"""
Execute the agent with the given messages.
Args:
messages: Either a string query or a list of message dictionaries.
@@ -205,7 +204,6 @@ class LiteAgent(FlowTrackable, BaseModel):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
# Create agent info for event emission
agent_info = {
@@ -236,18 +234,18 @@ class LiteAgent(FlowTrackable, BaseModel):
# Execute the agent using invoke loop
agent_finish = self._invoke_loop()
formatted_result: BaseModel | None = None
formatted_result: Optional[BaseModel] = None
if self.response_format:
try:
# Cast to BaseModel to ensure type safety
result = self.response_format.model_validate_json(
agent_finish.output,
agent_finish.output
)
if isinstance(result, BaseModel):
formatted_result = result
except Exception as e:
self._printer.print(
content=f"Failed to parse output into response format: {e!s}",
content=f"Failed to parse output into response format: {str(e)}",
color="yellow",
)
@@ -287,12 +285,13 @@ class LiteAgent(FlowTrackable, BaseModel):
error=str(e),
),
)
raise
raise e
async def kickoff_async(
self, messages: str | list[dict[str, str]],
self, messages: Union[str, List[Dict[str, str]]]
) -> LiteAgentOutput:
"""Execute the agent asynchronously with the given messages.
"""
Execute the agent asynchronously with the given messages.
Args:
messages: Either a string query or a list of message dictionaries.
@@ -301,7 +300,6 @@ class LiteAgent(FlowTrackable, BaseModel):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
return await asyncio.to_thread(self.kickoff, messages)
@@ -320,7 +318,7 @@ class LiteAgent(FlowTrackable, BaseModel):
else:
# Use the prompt template for agents without tools
base_prompt = self.i18n.slice(
"lite_agent_system_prompt_without_tools",
"lite_agent_system_prompt_without_tools"
).format(
role=self.role,
backstory=self.backstory,
@@ -331,14 +329,14 @@ class LiteAgent(FlowTrackable, BaseModel):
if self.response_format:
schema = generate_model_description(self.response_format)
base_prompt += self.i18n.slice("lite_agent_response_format").format(
response_format=schema,
response_format=schema
)
return base_prompt
def _format_messages(
self, messages: str | list[dict[str, str]],
) -> list[dict[str, str]]:
self, messages: Union[str, List[Dict[str, str]]]
) -> List[Dict[str, str]]:
"""Format messages for the LLM."""
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
@@ -354,11 +352,11 @@ class LiteAgent(FlowTrackable, BaseModel):
return formatted_messages
def _invoke_loop(self) -> AgentFinish:
"""Run the agent's thought process until it reaches a conclusion or max iterations.
"""
Run the agent's thought process until it reaches a conclusion or max iterations.
Returns:
AgentFinish: The final result of the agent execution.
"""
# Execute the agent loop
formatted_answer = None
@@ -370,7 +368,7 @@ class LiteAgent(FlowTrackable, BaseModel):
printer=self._printer,
i18n=self.i18n,
messages=self._messages,
llm=cast("LLM", self.llm),
llm=cast(LLM, self.llm),
callbacks=self._callbacks,
)
@@ -388,7 +386,7 @@ class LiteAgent(FlowTrackable, BaseModel):
try:
answer = get_llm_response(
llm=cast("LLM", self.llm),
llm=cast(LLM, self.llm),
messages=self._messages,
callbacks=self._callbacks,
printer=self._printer,
@@ -408,7 +406,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self,
event=LLMCallFailedEvent(error=str(e)),
)
raise
raise e
formatted_answer = process_llm_response(answer, self.use_stop_words)
@@ -422,8 +420,8 @@ class LiteAgent(FlowTrackable, BaseModel):
agent_role=self.role,
agent=self.original_agent,
)
except Exception:
raise
except Exception as e:
raise e
formatted_answer = handle_agent_action_core(
formatted_answer=formatted_answer,
@@ -444,19 +442,20 @@ class LiteAgent(FlowTrackable, BaseModel):
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
raise
raise e
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self._messages,
llm=cast("LLM", self.llm),
llm=cast(LLM, self.llm),
callbacks=self._callbacks,
i18n=self.i18n,
)
continue
handle_unknown_error(self._printer, e)
raise
else:
handle_unknown_error(self._printer, e)
raise e
finally:
self._iterations += 1
@@ -465,7 +464,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self._show_logs(formatted_answer)
return formatted_answer
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
"""Show logs for the agent's execution."""
show_agent_logs(
printer=self._printer,
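
The LiteAgent hunks above are typing and style changes; per the kickoff() docstring, the agent still accepts either a bare string or a list of chat messages. A minimal standalone sketch of that normalization, assuming nothing beyond the behavior shown in _format_messages (the helper name below is illustrative, not the private method itself):

from typing import Dict, List, Union

Messages = Union[str, List[Dict[str, str]]]


def format_messages(messages: Messages) -> List[Dict[str, str]]:
    # Mirrors the behavior shown above: a bare string query becomes a
    # single user turn; an already-formed message list passes through.
    if isinstance(messages, str):
        return [{"role": "user", "content": messages}]
    return messages


print(format_messages("What changed for the o3 model?"))
# [{'role': 'user', 'content': 'What changed for the o3 model?'}]
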

View File

@@ -6,10 +6,17 @@ import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from types import SimpleNamespace
from typing import (
Any,
DefaultDict,
Dict,
List,
Literal,
Optional,
Type,
TypedDict,
Union,
cast,
)
@@ -24,6 +31,7 @@ from crewai.utilities.events.llm_events import (
LLMCallType,
LLMStreamChunkEvent,
)
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
@@ -47,7 +55,7 @@ load_dotenv()
class FilteredStream:
def __init__(self, original_stream) -> None:
def __init__(self, original_stream):
self._original_stream = original_stream
self._lock = threading.Lock()
@@ -202,7 +210,7 @@ def suppress_warnings():
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
warnings.filterwarnings(
"ignore", message="open_text is deprecated*", category=DeprecationWarning,
"ignore", message="open_text is deprecated*", category=DeprecationWarning
)
# Redirect stdout and stderr
@@ -218,14 +226,14 @@ def suppress_warnings():
class Delta(TypedDict):
content: str | None
role: str | None
content: Optional[str]
role: Optional[str]
class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: str | None
finish_reason: Optional[str]
class FunctionArgs(BaseModel):
@@ -241,31 +249,29 @@ class LLM(BaseLLM):
def __init__(
self,
model: str,
timeout: float | None = None,
temperature: float | None = None,
top_p: float | None = None,
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
logprobs: int | None = None,
top_logprobs: int | None = None,
base_url: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
stream: bool = False,
**kwargs,
) -> None:
if callbacks is None:
callbacks = []
):
self.model = model
self.timeout = timeout
self.temperature = temperature
@@ -295,7 +301,7 @@ class LLM(BaseLLM):
# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: list[str] = []
self.stop: List[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -312,16 +318,15 @@ class LLM(BaseLLM):
Returns:
bool: True if the model is from Anthropic, False otherwise.
"""
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _prepare_completion_params(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
"""Prepare parameters for the completion call.
Args:
@@ -332,7 +337,6 @@ class LLM(BaseLLM):
Returns:
Dict[str, Any]: Parameters for the completion call
"""
# --- 1) Format messages according to provider requirements
if isinstance(messages, str):
@@ -347,7 +351,6 @@ class LLM(BaseLLM):
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
@@ -365,15 +368,18 @@ class LLM(BaseLLM):
"reasoning_effort": self.reasoning_effort,
**self.additional_params,
}
if self.stop and self.supports_stop_words():
params["stop"] = self.stop
# Remove None values from params
return {k: v for k, v in params.items() if v is not None}
def _handle_streaming_response(
self,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> str:
"""Handle a streaming response from the LLM.
@@ -387,7 +393,6 @@ class LLM(BaseLLM):
Raises:
Exception: If no content is received from the streaming response
"""
# --- 1) Initialize response tracking
full_response = ""
@@ -396,8 +401,8 @@ class LLM(BaseLLM):
usage_info = None
tool_calls = None
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs,
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs
)
# --- 2) Make sure stream is set to True and include usage metrics
@@ -421,16 +426,16 @@ class LLM(BaseLLM):
choices = chunk["choices"]
elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value
if not isinstance(chunk.choices, type):
choices = chunk.choices
if not isinstance(getattr(chunk, "choices"), type):
choices = getattr(chunk, "choices")
# Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"]
elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value
if not isinstance(chunk.usage, type):
usage_info = chunk.usage
if not isinstance(getattr(chunk, "usage"), type):
usage_info = getattr(chunk, "usage")
if choices and len(choices) > 0:
choice = choices[0]
@@ -440,7 +445,7 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"]
elif hasattr(choice, "delta"):
delta = choice.delta
delta = getattr(choice, "delta")
# Extract content from delta
if delta:
@@ -450,7 +455,7 @@ class LLM(BaseLLM):
chunk_content = delta["content"]
# Handle object format
elif hasattr(delta, "content"):
chunk_content = delta.content
chunk_content = getattr(delta, "content")
# Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict):
@@ -488,21 +493,21 @@ class LLM(BaseLLM):
# --- 4) Fallback to non-streaming if no content received
if not full_response.strip() and chunk_count == 0:
logging.warning(
"No chunks received in streaming response, falling back to non-streaming",
"No chunks received in streaming response, falling back to non-streaming"
)
non_streaming_params = params.copy()
non_streaming_params["stream"] = False
non_streaming_params.pop(
"stream_options", None,
"stream_options", None
) # Remove stream_options for non-streaming call
return self._handle_non_streaming_response(
non_streaming_params, callbacks, available_functions,
non_streaming_params, callbacks, available_functions
)
# --- 5) Handle empty response with chunks
if not full_response.strip() and chunk_count > 0:
logging.warning(
f"Received {chunk_count} chunks but no content was extracted",
f"Received {chunk_count} chunks but no content was extracted"
)
if last_chunk is not None:
try:
@@ -511,8 +516,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if choices and len(choices) > 0:
choice = choices[0]
@@ -522,31 +527,30 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = choice.message
message = getattr(choice, "message")
if message:
content = None
if isinstance(message, dict) and "content" in message:
content = message["content"]
elif hasattr(message, "content"):
content = message.content
content = getattr(message, "content")
if content:
full_response = content
logging.info(
f"Extracted content from last chunk message: {full_response}",
f"Extracted content from last chunk message: {full_response}"
)
except Exception as e:
logging.debug(f"Error extracting content from last chunk: {e}")
logging.debug(
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}",
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}"
)
# --- 6) If still empty, raise an error instead of using a default response
if not full_response.strip() and len(accumulated_tool_args) == 0:
msg = "No content received from streaming response. Received empty chunks or failed to extract content."
raise Exception(
msg,
"No content received from streaming response. Received empty chunks or failed to extract content."
)
# --- 7) Check for tool calls in the final response
@@ -557,8 +561,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if choices and len(choices) > 0:
choice = choices[0]
@@ -567,13 +571,13 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = choice.message
message = getattr(choice, "message")
if message:
if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"):
tool_calls = message.tool_calls
tool_calls = getattr(message, "tool_calls")
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
@@ -603,9 +607,9 @@ class LLM(BaseLLM):
# decide whether to summarize the content or abort based on the respect_context_window flag.
raise LLMContextLengthExceededException(str(e))
except Exception as e:
logging.exception(f"Error in streaming response: {e!s}")
logging.error(f"Error in streaming response: {str(e)}")
if full_response.strip():
logging.warning(f"Returning partial response despite error: {e!s}")
logging.warning(f"Returning partial response despite error: {str(e)}")
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
return full_response
@@ -615,14 +619,13 @@ class LLM(BaseLLM):
self,
event=LLMCallFailedEvent(error=str(e)),
)
msg = f"Failed to get streaming response: {e!s}"
raise Exception(msg)
raise Exception(f"Failed to get streaming response: {str(e)}")
def _handle_streaming_tool_calls(
self,
tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: dict[str, Any] | None = None,
tool_calls: List[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None,
) -> None | str:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -661,9 +664,9 @@ class LLM(BaseLLM):
def _handle_streaming_callbacks(
self,
callbacks: list[Any] | None,
usage_info: dict[str, Any] | None,
last_chunk: Any | None,
callbacks: Optional[List[Any]],
usage_info: Optional[Dict[str, Any]],
last_chunk: Optional[Any],
) -> None:
"""Handle callbacks with usage info for streaming responses.
@@ -671,7 +674,6 @@ class LLM(BaseLLM):
callbacks: Optional list of callback functions
usage_info: Usage information collected during streaming
last_chunk: The last chunk received from the streaming response
"""
if callbacks and len(callbacks) > 0:
for callback in callbacks:
@@ -688,9 +690,9 @@ class LLM(BaseLLM):
usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"):
if not isinstance(
last_chunk.usage, type,
getattr(last_chunk, "usage"), type
):
usage_info = last_chunk.usage
usage_info = getattr(last_chunk, "usage")
except Exception as e:
logging.debug(f"Error extracting usage info: {e}")
@@ -704,9 +706,9 @@ class LLM(BaseLLM):
def _handle_non_streaming_response(
self,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> str:
"""Handle a non-streaming response from the LLM.
@@ -717,7 +719,6 @@ class LLM(BaseLLM):
Returns:
str: The response text
"""
# --- 1) Make the completion call
try:
@@ -732,7 +733,7 @@ class LLM(BaseLLM):
raise LLMContextLengthExceededException(str(e))
# --- 2) Extract response message and content
response_message = cast("Choices", cast("ModelResponse", response).choices)[
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or ""
@@ -769,9 +770,9 @@ class LLM(BaseLLM):
def _handle_tool_call(
self,
tool_calls: list[Any],
available_functions: dict[str, Any] | None = None,
) -> str | None:
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
"""Handle a tool call from the LLM.
Args:
@@ -780,7 +781,6 @@ class LLM(BaseLLM):
Returns:
Optional[str]: The result of the tool call, or None if no tool call was made
"""
# --- 1) Validate tool calls and available functions
if not tool_calls or not available_functions:
@@ -807,23 +807,23 @@ class LLM(BaseLLM):
except Exception as e:
# --- 3.4) Handle execution errors
fn = available_functions.get(
function_name, lambda: None,
function_name, lambda: None
) # Ensure fn is always a callable
logging.exception(f"Error executing function '{function_name}': {e}")
logging.error(f"Error executing function '{function_name}': {e}")
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
)
return None
def call(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""High-level LLM call method.
Args:
@@ -846,7 +846,6 @@ class LLM(BaseLLM):
TypeError: If messages format is invalid
ValueError: If response format is not supported
LLMContextLengthExceededException: If input exceeds model's context limit
"""
# --- 1) Emit call started event
assert hasattr(crewai_event_bus, "emit")
@@ -885,11 +884,12 @@ class LLM(BaseLLM):
# --- 7) Make the completion call and handle response
if self.stream:
return self._handle_streaming_response(
params, callbacks, available_functions,
params, callbacks, available_functions
)
else:
return self._handle_non_streaming_response(
params, callbacks, available_functions
)
return self._handle_non_streaming_response(
params, callbacks, available_functions,
)
except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled
@@ -902,16 +902,15 @@ class LLM(BaseLLM):
self,
event=LLMCallFailedEvent(error=str(e)),
)
logging.exception(f"LiteLLM call failed: {e!s}")
logging.error(f"LiteLLM call failed: {str(e)}")
raise
def _handle_emit_call_events(self, response: Any, call_type: LLMCallType) -> None:
def _handle_emit_call_events(self, response: Any, call_type: LLMCallType):
"""Handle the events for the LLM call.
Args:
response (str): The response from the LLM call.
call_type (str): The type of call, either "tool_call" or "llm_call".
"""
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
@@ -920,8 +919,8 @@ class LLM(BaseLLM):
)
def _format_messages_for_provider(
self, messages: list[dict[str, str]],
) -> list[dict[str, str]]:
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
"""Format messages according to provider requirements.
Args:
@@ -934,18 +933,15 @@ class LLM(BaseLLM):
Raises:
TypeError: If messages is None or contains invalid message format.
"""
if messages is None:
msg = "Messages cannot be None"
raise TypeError(msg)
raise TypeError("Messages cannot be None")
# Validate message format first
for msg in messages:
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
msg = "Invalid message format. Each message must be a dict with 'role' and 'content' keys"
raise TypeError(
msg,
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
)
# Handle O1 models specially
@@ -955,7 +951,7 @@ class LLM(BaseLLM):
# Convert system messages to assistant messages
if msg["role"] == "system":
formatted_messages.append(
{"role": "assistant", "content": msg["content"]},
{"role": "assistant", "content": msg["content"]}
)
else:
formatted_messages.append(msg)
@@ -983,8 +979,9 @@ class LLM(BaseLLM):
return messages
def _get_custom_llm_provider(self) -> str | None:
"""Derives the custom_llm_provider from the model string.
def _get_custom_llm_provider(self) -> Optional[str]:
"""
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
- If the model is "gemini/gemini-1.5-pro", returns "gemini".
- If there is no '/', defaults to "openai".
@@ -994,7 +991,8 @@ class LLM(BaseLLM):
return None
def _validate_call_params(self) -> None:
"""Validate parameters before making a call. Currently this only checks if
"""
Validate parameters before making a call. Currently this only checks if
a response_format is provided and whether the model supports it.
The custom_llm_provider is dynamically determined from the model:
- E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
@@ -1006,22 +1004,19 @@ class LLM(BaseLLM):
model=self.model,
custom_llm_provider=provider,
):
msg = (
raise ValueError(
f"The model {self.model} does not support response_format for provider '{provider}'. "
"Please remove response_format or use a supported model."
)
raise ValueError(
msg,
)
def supports_function_calling(self) -> bool:
try:
provider = self._get_custom_llm_provider()
return litellm.utils.supports_function_calling(
self.model, custom_llm_provider=provider,
self.model, custom_llm_provider=provider
)
except Exception as e:
logging.exception(f"Failed to check function calling support: {e!s}")
logging.error(f"Failed to check function calling support: {str(e)}")
return False
def supports_stop_words(self) -> bool:
@@ -1029,16 +1024,16 @@ class LLM(BaseLLM):
params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params
except Exception as e:
logging.exception(f"Failed to get supported params: {e!s}")
logging.error(f"Failed to get supported params: {str(e)}")
return False
def get_context_window_size(self) -> int:
"""Returns the context window size, using 75% of the maximum to avoid
"""
Returns the context window size, using 75% of the maximum to avoid
cutting off messages mid-thread.
Raises:
ValueError: If a model's context window size is outside valid bounds (1024-2097152)
"""
if self.context_window_size != 0:
return self.context_window_size
@@ -1049,21 +1044,21 @@ class LLM(BaseLLM):
# Validate all context window sizes
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < MIN_CONTEXT or value > MAX_CONTEXT:
msg = f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
raise ValueError(
msg,
f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
)
self.context_window_size = int(
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO,
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
)
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: list[Any]) -> None:
"""Attempt to keep a single set of callbacks in litellm by removing old
def set_callbacks(self, callbacks: List[Any]):
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
"""
with suppress_warnings():
@@ -1078,8 +1073,9 @@ class LLM(BaseLLM):
litellm.callbacks = callbacks
def set_env_callbacks(self) -> None:
"""Sets the success and failure callbacks for the LiteLLM library from environment variables.
def set_env_callbacks(self):
"""
Sets the success and failure callbacks for the LiteLLM library from environment variables.
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
environment variables, which should contain comma-separated lists of callback names.
@@ -1095,7 +1091,6 @@ class LLM(BaseLLM):
This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
`litellm.failure_callback` to ["langfuse"].
"""
with suppress_warnings():
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
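
The substantive change in this file is the "stop" handling in _prepare_completion_params: with the fix, the key is attached only when self.stop is non-empty and supports_stop_words() reports support, so models that reject the parameter (issue #2738) never receive it. A minimal sketch of that pattern, with support passed in as a flag rather than queried from litellm; this is an illustration, not the PR's exact code:

from typing import Any, Dict, List, Optional, Union


def prepare_completion_params(
    model: str,
    messages: List[Dict[str, str]],
    stop: Optional[Union[str, List[str]]] = None,
    supports_stop_words: bool = True,
) -> Dict[str, Any]:
    # Normalize stop to a list, as the constructor above does.
    stop_list: List[str] = [] if stop is None else ([stop] if isinstance(stop, str) else stop)
    params: Dict[str, Any] = {"model": model, "messages": messages}
    # Attach "stop" only when there is something to send and the model
    # supports it; otherwise the key is omitted entirely.
    if stop_list and supports_stop_words:
        params["stop"] = stop_list
    # Drop unset values, mirroring the None-filtering above.
    return {k: v for k, v in params.items() if v is not None}


print(prepare_completion_params("o3", [{"role": "user", "content": "hi"}],
                                stop="Observation:", supports_stop_words=False))
# {'model': 'o3', 'messages': [{'role': 'user', 'content': 'hi'}]}
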

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Callable, Dict, List, Optional, Union
class BaseLLM(ABC):
@@ -17,18 +17,17 @@ class BaseLLM(ABC):
Attributes:
stop (list): A list of stop sequences that the LLM should use to stop generation.
This is used by the CrewAgentExecutor and other components.
"""
model: str
temperature: float | None = None
stop: list[str] | None = None
temperature: Optional[float] = None
stop: Optional[List[str]] = None
def __init__(
self,
model: str,
temperature: float | None = None,
) -> None:
temperature: Optional[float] = None,
):
"""Initialize the BaseLLM with default attributes.
This constructor sets default values for attributes that are expected
@@ -44,11 +43,11 @@ class BaseLLM(ABC):
@abstractmethod
def call(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with the given messages.
Args:
@@ -71,15 +70,14 @@ class BaseLLM(ABC):
ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
"""
pass
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
bool: True if the LLM supports stop words, False otherwise.
"""
return True # Default implementation assumes support for stop words
@@ -88,7 +86,6 @@ class BaseLLM(ABC):
Returns:
int: The number of tokens/characters the model can handle.
"""
# Default implementation - subclasses should override with model-specific values
return 4096
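
BaseLLM above is the abstract interface that custom models implement. A minimal subclass sketch; EchoLLM is illustrative only and not part of this diff, though the import path matches the one used elsewhere in the diff:

from typing import Any, Dict, List, Optional, Union

from crewai.llms.base_llm import BaseLLM


class EchoLLM(BaseLLM):
    """Toy LLM that returns the last user message; handy for offline tests."""

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        if isinstance(messages, str):
            return messages
        return messages[-1]["content"] if messages else ""


llm = EchoLLM(model="echo-1", temperature=0.0)
print(llm.call([{"role": "user", "content": "ping"}]))  # -> ping
print(llm.supports_stop_words())       # -> True (default implementation above)
print(llm.get_context_window_size())   # -> 4096 (default shown above)
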

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict, List, Optional, Union
import aisuite as ai
@@ -6,17 +6,17 @@ from crewai.llms.base_llm import BaseLLM
class AISuiteLLM(BaseLLM):
def __init__(self, model: str, temperature: float | None = None, **kwargs) -> None:
def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
super().__init__(model, temperature, **kwargs)
self.client = ai.Client()
def call(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
completion_params = self._prepare_completion_params(messages, tools)
response = self.client.chat.completions.create(**completion_params)
@@ -24,9 +24,9 @@ class AISuiteLLM(BaseLLM):
def _prepare_completion_params(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
return {
"model": self.model,
"messages": messages,

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict, Optional
from crewai.memory import (
EntityMemory,
@@ -12,13 +12,13 @@ from crewai.memory import (
class ContextualMemory:
def __init__(
self,
memory_config: dict[str, Any] | None,
memory_config: Optional[Dict[str, Any]],
stm: ShortTermMemory,
ltm: LongTermMemory,
em: EntityMemory,
um: UserMemory,
exm: ExternalMemory,
) -> None:
):
if memory_config is not None:
self.memory_provider = memory_config.get("provider")
else:
@@ -30,7 +30,8 @@ class ContextualMemory:
self.exm = exm
def build_context_for_task(self, task, context) -> str:
"""Automatically builds a minimal, highly relevant set of contextual information
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
"""
query = f"{task.description} {context}".strip()
@@ -48,9 +49,11 @@ class ContextualMemory:
return "\n".join(filter(None, context))
def _fetch_stm_context(self, query) -> str:
"""Fetches recent relevant insights from STM related to the task's description and expected_output,
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
"""
if self.stm is None:
return ""
@@ -59,14 +62,16 @@ class ContextualMemory:
[
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
for result in stm_results
],
]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> str | None:
"""Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
def _fetch_ltm_context(self, task) -> Optional[str]:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
"""
if self.ltm is None:
return ""
@@ -85,7 +90,8 @@ class ContextualMemory:
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:
"""Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
"""
if self.em is None:
@@ -96,20 +102,19 @@ class ContextualMemory:
[
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
for result in em_results
], # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
)
return f"Entities:\n{formatted_results}" if em_results else ""
def _fetch_user_context(self, query: str) -> str:
"""Fetches and formats relevant user information from User Memory.
"""
Fetches and formats relevant user information from User Memory.
Args:
query (str): The search query to find relevant user memories.
Returns:
str: Formatted user memories as bullet points, or an empty string if none found.
"""
if self.um is None:
return ""
@@ -123,14 +128,12 @@ class ContextualMemory:
return f"User memories/preferences:\n{formatted_memories}"
def _fetch_external_context(self, query: str) -> str:
"""Fetches and formats relevant information from External Memory.
"""
Fetches and formats relevant information from External Memory.
Args:
query (str): The search query to find relevant information.
Returns:
str: Formatted information as bullet points, or an empty string if none found.
"""
if self.exm is None:
return ""

View File

@@ -1,3 +1,4 @@
from typing import Optional
from pydantic import PrivateAttr
@@ -7,14 +8,15 @@ from crewai.memory.storage.rag_storage import RAGStorage
class EntityMemory(Memory):
"""EntityMemory class for managing structured information about entities
"""
EntityMemory class for managing structured information about entities
and their relationships using SQLite storage.
Inherits from the Memory class.
"""
_memory_provider: str | None = PrivateAttr()
_memory_provider: Optional[str] = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
else:
@@ -24,9 +26,8 @@ class EntityMemory(Memory):
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
raise ImportError(
msg,
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
storage = Mem0Storage(type="entities", crew=crew)
else:
@@ -62,5 +63,4 @@ class EntityMemory(Memory):
try:
self.storage.reset()
except Exception as e:
msg = f"An error occurred while resetting the entity memory: {e}"
raise Exception(msg)
raise Exception(f"An error occurred while resetting the entity memory: {e}")

View File

@@ -5,7 +5,7 @@ class EntityMemoryItem:
type: str,
description: str,
relationships: str,
) -> None:
):
self.name = name
self.type = type
self.description = description

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Optional
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
@@ -9,44 +9,41 @@ if TYPE_CHECKING:
class ExternalMemory(Memory):
def __init__(self, storage: Storage | None = None, **data: Any) -> None:
def __init__(self, storage: Optional[Storage] = None, **data: Any):
super().__init__(storage=storage, **data)
@staticmethod
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
@staticmethod
def external_supported_storages() -> dict[str, Any]:
def external_supported_storages() -> Dict[str, Any]:
return {
"mem0": ExternalMemory._configure_mem0,
}
@staticmethod
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
if not embedder_config:
msg = "embedder_config is required"
raise ValueError(msg)
raise ValueError("embedder_config is required")
if "provider" not in embedder_config:
msg = "embedder_config must include a 'provider' key"
raise ValueError(msg)
raise ValueError("embedder_config must include a 'provider' key")
provider = embedder_config["provider"]
supported_storages = ExternalMemory.external_supported_storages()
if provider not in supported_storages:
msg = f"Provider {provider} not supported"
raise ValueError(msg)
raise ValueError(f"Provider {provider} not supported")
return supported_storages[provider](crew, embedder_config.get("config", {}))
def save(
self,
value: Any,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
"""Saves a value into the external storage."""
item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)
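
create_storage above validates embedder_config and dispatches on its "provider" key through the external_supported_storages() mapping. A self-contained sketch of that dispatch-and-validate pattern, using a toy in-memory backend instead of Mem0Storage:

from typing import Any, Callable, Dict, Optional


class DictStorage:
    """Toy storage backend used only for this sketch."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config
        self.items: list = []


SUPPORTED: Dict[str, Callable[[Dict[str, Any]], DictStorage]] = {
    "dict": DictStorage,
}


def create_storage(embedder_config: Optional[Dict[str, Any]]) -> DictStorage:
    # Same validation order as the method above: config present, provider
    # present, provider known; otherwise raise a descriptive ValueError.
    if not embedder_config:
        raise ValueError("embedder_config is required")
    if "provider" not in embedder_config:
        raise ValueError("embedder_config must include a 'provider' key")
    provider = embedder_config["provider"]
    if provider not in SUPPORTED:
        raise ValueError(f"Provider {provider} not supported")
    return SUPPORTED[provider](embedder_config.get("config", {}))


storage = create_storage({"provider": "dict", "config": {"namespace": "demo"}})
print(storage.config)  # -> {'namespace': 'demo'}
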

Some files were not shown because too many files have changed in this diff.