mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 20:38:29 +00:00
Compare commits
22 Commits
fix/memory
...
fix-cli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9476769e1 | ||
|
|
d59ecb22e6 | ||
|
|
263544524d | ||
|
|
098a4312ab | ||
|
|
b8a3c29745 | ||
|
|
9cd4ff05c9 | ||
|
|
4687779702 | ||
|
|
c724c0af70 | ||
|
|
f6f430b26a | ||
|
|
a5f70d2307 | ||
|
|
b55fc40c83 | ||
|
|
d0ed4f5274 | ||
|
|
ee34399b71 | ||
|
|
39903f0c50 | ||
|
|
c4bf713113 | ||
|
|
5d18c6312d | ||
|
|
1f9baf9b2c | ||
|
|
6fbc97b298 | ||
|
|
08bacfa892 | ||
|
|
1ea8115d56 | ||
|
|
6b906f09cf | ||
|
|
6c29ebafea |
@@ -6,7 +6,7 @@ icon: terminal
|
||||
|
||||
# CrewAI CLI Documentation
|
||||
|
||||
The CrewAI CLI provides a set of commands to interact with CrewAI, allowing you to create, train, run, and manage crews and pipelines.
|
||||
The CrewAI CLI provides a set of commands to interact with CrewAI, allowing you to create, train, run, and manage crews & flows.
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -146,3 +146,34 @@ crewai run
|
||||
Make sure to run these commands from the directory where your CrewAI project is set up.
|
||||
Some commands may require additional configuration or setup within your project structure.
|
||||
</Note>
|
||||
|
||||
|
||||
### 9. API Keys
|
||||
|
||||
When running ```crewai create crew``` command, the CLI will first show you the top 5 most common LLM providers and ask you to select one.
|
||||
|
||||
Once you've selected an LLM provider, you will be prompted for API keys.
|
||||
|
||||
#### Initial API key providers
|
||||
|
||||
The CLI will initially prompt for API keys for the following services:
|
||||
|
||||
* OpenAI
|
||||
* Groq
|
||||
* Anthropic
|
||||
* Google Gemini
|
||||
|
||||
When you select a provider, the CLI will prompt you to enter your API key.
|
||||
|
||||
#### Other Options
|
||||
|
||||
If you select option 6, you will be able to select from a list of LiteLLM supported providers.
|
||||
|
||||
When you select a provider, the CLI will prompt you to enter the Key name and the API key.
|
||||
|
||||
See the following link for each provider's key name:
|
||||
|
||||
* [LiteLLM Providers](https://docs.litellm.ai/docs/providers)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -105,9 +105,48 @@ my_crew = Crew(
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
embedder={
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": 'text-embedding-3-small'
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
Alternatively, you can directly pass the OpenAIEmbeddingFunction to the embedder parameter.
|
||||
|
||||
Example:
|
||||
```python Code
|
||||
from crewai import Crew, Agent, Task, Process
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
|
||||
|
||||
my_crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"),
|
||||
)
|
||||
```
|
||||
|
||||
### Using Ollama embeddings
|
||||
|
||||
```python Code
|
||||
from crewai import Crew, Agent, Task, Process
|
||||
|
||||
my_crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder={
|
||||
"provider": "ollama",
|
||||
"config": {
|
||||
"model": "mxbai-embed-large"
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -122,10 +161,13 @@ my_crew = Crew(
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model_name="text-embedding-ada-002"
|
||||
)
|
||||
embedder={
|
||||
"provider": "google",
|
||||
"config": {
|
||||
"api_key": "<YOUR_API_KEY>",
|
||||
"model_name": "<model_name>"
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -181,10 +223,32 @@ my_crew = Crew(
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=embedding_functions.CohereEmbeddingFunction(
|
||||
api_key=YOUR_API_KEY,
|
||||
model_name="<model_name>"
|
||||
)
|
||||
embedder={
|
||||
"provider": "cohere",
|
||||
"config": {
|
||||
"api_key": "YOUR_API_KEY",
|
||||
"model_name": "<model_name>"
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
### Using HuggingFace embeddings
|
||||
|
||||
```python Code
|
||||
from crewai import Crew, Agent, Task, Process
|
||||
|
||||
my_crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder={
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_url": "<api_url>",
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "crewai"
|
||||
version = "0.74.2"
|
||||
version = "0.75.1"
|
||||
description = "Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<=3.13"
|
||||
|
||||
@@ -14,5 +14,5 @@ warnings.filterwarnings(
|
||||
category=UserWarning,
|
||||
module="pydantic.main",
|
||||
)
|
||||
__version__ = "0.74.2"
|
||||
__version__ = "0.75.1"
|
||||
__all__ = ["Agent", "Crew", "Process", "Task", "Pipeline", "Router", "LLM", "Flow"]
|
||||
|
||||
@@ -32,10 +32,11 @@ def crewai():
|
||||
@crewai.command()
|
||||
@click.argument("type", type=click.Choice(["crew", "pipeline", "flow"]))
|
||||
@click.argument("name")
|
||||
def create(type, name):
|
||||
@click.option("--provider", type=str, help="The provider to use for the crew")
|
||||
def create(type, name, provider):
|
||||
"""Create a new crew, pipeline, or flow."""
|
||||
if type == "crew":
|
||||
create_crew(name)
|
||||
create_crew(name, provider)
|
||||
elif type == "pipeline":
|
||||
create_pipeline(name)
|
||||
elif type == "flow":
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
|
||||
from crewai.cli.provider import get_provider_data, select_provider, PROVIDERS
|
||||
|
||||
from crewai.cli.constants import ENV_VARS
|
||||
from crewai.cli.provider import (
|
||||
PROVIDERS,
|
||||
get_provider_data,
|
||||
select_model,
|
||||
select_provider,
|
||||
)
|
||||
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
|
||||
|
||||
|
||||
def create_folder_structure(name, parent_folder=None):
|
||||
@@ -14,11 +22,19 @@ def create_folder_structure(name, parent_folder=None):
|
||||
else:
|
||||
folder_path = Path(folder_name)
|
||||
|
||||
click.secho(
|
||||
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
|
||||
fg="green",
|
||||
bold=True,
|
||||
)
|
||||
if folder_path.exists():
|
||||
if not click.confirm(
|
||||
f"Folder {folder_name} already exists. Do you want to override it?"
|
||||
):
|
||||
click.secho("Operation cancelled.", fg="yellow")
|
||||
sys.exit(0)
|
||||
click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True)
|
||||
else:
|
||||
click.secho(
|
||||
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
|
||||
fg="green",
|
||||
bold=True,
|
||||
)
|
||||
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True)
|
||||
@@ -27,11 +43,6 @@ def create_folder_structure(name, parent_folder=None):
|
||||
(folder_path / "src" / folder_name).mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "tools").mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
|
||||
else:
|
||||
click.secho(
|
||||
f"\tFolder {folder_name} already exists.",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
return folder_path, folder_name, class_name
|
||||
|
||||
@@ -74,33 +85,73 @@ def create_crew(name, parent_folder=None):
|
||||
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
|
||||
env_vars = load_env_vars(folder_path)
|
||||
|
||||
existing_provider = None
|
||||
for provider, env_keys in ENV_VARS.items():
|
||||
if any(key in env_vars for key in env_keys):
|
||||
existing_provider = provider
|
||||
break
|
||||
|
||||
if existing_provider:
|
||||
if not click.confirm(
|
||||
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
|
||||
):
|
||||
click.secho("Keeping existing provider configuration.", fg="yellow")
|
||||
return
|
||||
|
||||
provider_models = get_provider_data()
|
||||
if not provider_models:
|
||||
return
|
||||
|
||||
selected_provider = select_provider(provider_models)
|
||||
if not selected_provider:
|
||||
return
|
||||
provider = selected_provider
|
||||
|
||||
# selected_model = select_model(provider, provider_models)
|
||||
# if not selected_model:
|
||||
# return
|
||||
# model = selected_model
|
||||
|
||||
if provider in PROVIDERS:
|
||||
api_key_var = ENV_VARS[provider][0]
|
||||
else:
|
||||
api_key_var = click.prompt(
|
||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||
type=str,
|
||||
while True:
|
||||
selected_provider = select_provider(provider_models)
|
||||
if selected_provider is None: # User typed 'q'
|
||||
click.secho("Exiting...", fg="yellow")
|
||||
sys.exit(0)
|
||||
if selected_provider: # Valid selection
|
||||
break
|
||||
click.secho(
|
||||
"No provider selected. Please try again or press 'q' to exit.", fg="red"
|
||||
)
|
||||
|
||||
env_vars = {api_key_var: "YOUR_API_KEY_HERE"}
|
||||
write_env_file(folder_path, env_vars)
|
||||
while True:
|
||||
selected_model = select_model(selected_provider, provider_models)
|
||||
if selected_model is None: # User typed 'q'
|
||||
click.secho("Exiting...", fg="yellow")
|
||||
sys.exit(0)
|
||||
if selected_model: # Valid selection
|
||||
break
|
||||
click.secho(
|
||||
"No model selected. Please try again or press 'q' to exit.", fg="red"
|
||||
)
|
||||
|
||||
# env_vars['MODEL'] = model
|
||||
# click.secho(f"Selected model: {model}", fg="green")
|
||||
if selected_provider in PROVIDERS:
|
||||
api_key_var = ENV_VARS[selected_provider][0]
|
||||
else:
|
||||
api_key_var = click.prompt(
|
||||
f"Enter the environment variable name for your {selected_provider.capitalize()} API key",
|
||||
type=str,
|
||||
default="",
|
||||
)
|
||||
|
||||
api_key_value = ""
|
||||
click.echo(
|
||||
f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ",
|
||||
nl=False,
|
||||
)
|
||||
try:
|
||||
api_key_value = input()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
api_key_value = ""
|
||||
|
||||
if api_key_value.strip():
|
||||
env_vars = {api_key_var: api_key_value}
|
||||
write_env_file(folder_path, env_vars)
|
||||
click.secho("API key saved to .env file", fg="green")
|
||||
else:
|
||||
click.secho("No API key provided. Skipping .env file creation.", fg="yellow")
|
||||
|
||||
env_vars["MODEL"] = selected_model
|
||||
click.secho(f"Selected model: {selected_model}", fg="green")
|
||||
|
||||
package_dir = Path(__file__).parent
|
||||
templates_dir = package_dir / "templates" / "crew"
|
||||
|
||||
@@ -1,67 +1,91 @@
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from pathlib import Path
|
||||
from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL
|
||||
import requests
|
||||
|
||||
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
|
||||
|
||||
|
||||
def select_choice(prompt_message, choices):
|
||||
"""
|
||||
Presents a list of choices to the user and prompts them to select one.
|
||||
|
||||
|
||||
Args:
|
||||
- prompt_message (str): The message to display to the user before presenting the choices.
|
||||
- choices (list): A list of options to present to the user.
|
||||
|
||||
|
||||
Returns:
|
||||
- str: The selected choice from the list, or None if the operation is aborted or an invalid selection is made.
|
||||
- str: The selected choice from the list, or None if the user chooses to quit.
|
||||
"""
|
||||
|
||||
provider_models = get_provider_data()
|
||||
if not provider_models:
|
||||
return
|
||||
click.secho(prompt_message, fg="cyan")
|
||||
for idx, choice in enumerate(choices, start=1):
|
||||
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||
try:
|
||||
selected_index = click.prompt("Enter the number of your choice", type=int) - 1
|
||||
except click.exceptions.Abort:
|
||||
click.secho("Operation aborted by the user.", fg="red")
|
||||
return None
|
||||
if not (0 <= selected_index < len(choices)):
|
||||
click.secho("Invalid selection.", fg="red")
|
||||
return None
|
||||
return choices[selected_index]
|
||||
click.secho("q. Quit", fg="cyan")
|
||||
|
||||
while True:
|
||||
choice = click.prompt(
|
||||
"Enter the number of your choice or 'q' to quit", type=str
|
||||
)
|
||||
|
||||
if choice.lower() == "q":
|
||||
return None
|
||||
|
||||
try:
|
||||
selected_index = int(choice) - 1
|
||||
if 0 <= selected_index < len(choices):
|
||||
return choices[selected_index]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
click.secho(
|
||||
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
|
||||
fg="red",
|
||||
)
|
||||
|
||||
|
||||
def select_provider(provider_models):
|
||||
"""
|
||||
Presents a list of providers to the user and prompts them to select one.
|
||||
|
||||
|
||||
Args:
|
||||
- provider_models (dict): A dictionary of provider models.
|
||||
|
||||
|
||||
Returns:
|
||||
- str: The selected provider, or None if the operation is aborted or an invalid selection is made.
|
||||
- str: The selected provider
|
||||
- None: If user explicitly quits
|
||||
"""
|
||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||
|
||||
provider = select_choice("Select a provider to set up:", predefined_providers + ['other'])
|
||||
if not provider:
|
||||
provider = select_choice(
|
||||
"Select a provider to set up:", predefined_providers + ["other"]
|
||||
)
|
||||
if provider is None: # User typed 'q'
|
||||
return None
|
||||
provider = provider.lower()
|
||||
|
||||
if provider == 'other':
|
||||
if provider == "other":
|
||||
provider = select_choice("Select a provider from the full list:", all_providers)
|
||||
if not provider:
|
||||
if provider is None: # User typed 'q'
|
||||
return None
|
||||
return provider
|
||||
|
||||
return provider.lower() if provider else False
|
||||
|
||||
|
||||
def select_model(provider, provider_models):
|
||||
"""
|
||||
Presents a list of models for a given provider to the user and prompts them to select one.
|
||||
|
||||
|
||||
Args:
|
||||
- provider (str): The provider for which to select a model.
|
||||
- provider_models (dict): A dictionary of provider models.
|
||||
|
||||
|
||||
Returns:
|
||||
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
|
||||
"""
|
||||
@@ -76,37 +100,49 @@ def select_model(provider, provider_models):
|
||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||
return None
|
||||
|
||||
selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models)
|
||||
selected_model = select_choice(
|
||||
f"Select a model to use for {provider.capitalize()}:", available_models
|
||||
)
|
||||
return selected_model
|
||||
|
||||
|
||||
def load_provider_data(cache_file, cache_expiry):
|
||||
"""
|
||||
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
|
||||
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
- cache_expiry (int): The cache expiry time in seconds.
|
||||
|
||||
|
||||
Returns:
|
||||
- dict or None: The loaded provider data or None if the operation fails.
|
||||
"""
|
||||
current_time = time.time()
|
||||
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry:
|
||||
if (
|
||||
cache_file.exists()
|
||||
and (current_time - cache_file.stat().st_mtime) < cache_expiry
|
||||
):
|
||||
data = read_cache_file(cache_file)
|
||||
if data:
|
||||
return data
|
||||
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow")
|
||||
click.secho(
|
||||
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
|
||||
)
|
||||
else:
|
||||
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan")
|
||||
click.secho(
|
||||
"Cache expired or not found. Fetching provider data from the web...",
|
||||
fg="cyan",
|
||||
)
|
||||
return fetch_provider_data(cache_file)
|
||||
|
||||
|
||||
def read_cache_file(cache_file):
|
||||
"""
|
||||
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
|
||||
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
|
||||
|
||||
Returns:
|
||||
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
|
||||
"""
|
||||
@@ -116,13 +152,14 @@ def read_cache_file(cache_file):
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_provider_data(cache_file):
|
||||
"""
|
||||
Fetches provider data from a specified URL and caches it to a file.
|
||||
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
|
||||
|
||||
Returns:
|
||||
- dict or None: The fetched provider data or None if the operation fails.
|
||||
"""
|
||||
@@ -139,38 +176,42 @@ def fetch_provider_data(cache_file):
|
||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||
return None
|
||||
|
||||
|
||||
def download_data(response):
|
||||
"""
|
||||
Downloads data from a given HTTP response and returns the JSON content.
|
||||
|
||||
|
||||
Args:
|
||||
- response (requests.Response): The HTTP response object.
|
||||
|
||||
|
||||
Returns:
|
||||
- dict: The JSON content of the response.
|
||||
"""
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
block_size = 8192
|
||||
data_chunks = []
|
||||
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
|
||||
with click.progressbar(
|
||||
length=total_size, label="Downloading", show_pos=True
|
||||
) as progress_bar:
|
||||
for chunk in response.iter_content(block_size):
|
||||
if chunk:
|
||||
data_chunks.append(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
data_content = b''.join(data_chunks)
|
||||
return json.loads(data_content.decode('utf-8'))
|
||||
data_content = b"".join(data_chunks)
|
||||
return json.loads(data_content.decode("utf-8"))
|
||||
|
||||
|
||||
def get_provider_data():
|
||||
"""
|
||||
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
|
||||
|
||||
|
||||
Returns:
|
||||
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
|
||||
"""
|
||||
cache_dir = Path.home() / '.crewai'
|
||||
cache_dir = Path.home() / ".crewai"
|
||||
cache_dir.mkdir(exist_ok=True)
|
||||
cache_file = cache_dir / 'provider_cache.json'
|
||||
cache_expiry = 24 * 3600
|
||||
cache_file = cache_dir / "provider_cache.json"
|
||||
cache_expiry = 24 * 3600
|
||||
|
||||
data = load_provider_data(cache_file, cache_expiry)
|
||||
if not data:
|
||||
@@ -179,8 +220,8 @@ def get_provider_data():
|
||||
provider_models = defaultdict(list)
|
||||
for model_name, properties in data.items():
|
||||
provider = properties.get("litellm_provider", "").strip().lower()
|
||||
if 'http' in provider or provider == 'other':
|
||||
if "http" in provider or provider == "other":
|
||||
continue
|
||||
if provider:
|
||||
provider_models[provider].append(model_name)
|
||||
return provider_models
|
||||
return provider_models
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<=3.13"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.74.2,<1.0.0"
|
||||
"crewai[tools]>=0.75.1,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<=3.13"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.74.2,<1.0.0",
|
||||
"crewai[tools]>=0.75.1,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -6,7 +6,7 @@ authors = ["Your Name <you@example.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
crewai = { extras = ["tools"], version = ">=0.74.2,<1.0.0" }
|
||||
crewai = { extras = ["tools"], version = ">=0.75.1,<1.0.0" }
|
||||
asyncio = "*"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = ["Your Name <you@example.com>"]
|
||||
requires-python = ">=3.10,<=3.13"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.74.2,<1.0.0"
|
||||
"crewai[tools]>=0.75.1,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,6 +5,6 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<=3.13"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.74.2"
|
||||
"crewai[tools]>=0.75.1"
|
||||
]
|
||||
|
||||
|
||||
@@ -435,15 +435,16 @@ class Crew(BaseModel):
|
||||
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
|
||||
) -> None:
|
||||
"""Trains the crew for a given number of iterations."""
|
||||
self._setup_for_training(filename)
|
||||
train_crew = self.copy()
|
||||
train_crew._setup_for_training(filename)
|
||||
|
||||
for n_iteration in range(n_iterations):
|
||||
self._train_iteration = n_iteration
|
||||
self.kickoff(inputs=inputs)
|
||||
train_crew._train_iteration = n_iteration
|
||||
train_crew.kickoff(inputs=inputs)
|
||||
|
||||
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
||||
|
||||
for agent in self.agents:
|
||||
for agent in train_crew.agents:
|
||||
result = TaskEvaluator(agent).evaluate_training_data(
|
||||
training_data=training_data, agent_id=str(agent.id)
|
||||
)
|
||||
@@ -987,17 +988,19 @@ class Crew(BaseModel):
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
self._test_execution_span = self._telemetry.test_execution_span(
|
||||
self,
|
||||
test_crew = self.copy()
|
||||
|
||||
self._test_execution_span = test_crew._telemetry.test_execution_span(
|
||||
test_crew,
|
||||
n_iterations,
|
||||
inputs,
|
||||
openai_model_name, # type: ignore[arg-type]
|
||||
) # type: ignore[arg-type]
|
||||
evaluator = CrewEvaluator(self, openai_model_name) # type: ignore[arg-type]
|
||||
evaluator = CrewEvaluator(test_crew, openai_model_name) # type: ignore[arg-type]
|
||||
|
||||
for i in range(1, n_iterations + 1):
|
||||
evaluator.set_iteration(i)
|
||||
self.kickoff(inputs=inputs)
|
||||
test_crew.kickoff(inputs=inputs)
|
||||
|
||||
evaluator.print_crew_evaluation_result()
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class EntityMemory(Memory):
|
||||
if storage
|
||||
else RAGStorage(
|
||||
type="entities",
|
||||
allow_reset=False,
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,9 @@ from typing import Any, Dict, List, Optional
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from typing import cast
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -41,16 +44,93 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.agents = agents
|
||||
|
||||
self.type = type
|
||||
self.embedder_config = embedder_config or self._create_embedding_function()
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
self._initialize_app()
|
||||
|
||||
def _set_embedder_config(self):
|
||||
import chromadb.utils.embedding_functions as embedding_functions
|
||||
|
||||
if self.embedder_config is None:
|
||||
self.embedder_config = self._create_default_embedding_function()
|
||||
|
||||
if isinstance(self.embedder_config, dict):
|
||||
provider = self.embedder_config.get("provider")
|
||||
config = self.embedder_config.get("config", {})
|
||||
model_name = config.get("model")
|
||||
if provider == "openai":
|
||||
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||
model_name=model_name,
|
||||
)
|
||||
elif provider == "azure":
|
||||
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=config.get("api_key"),
|
||||
api_base=config.get("api_base"),
|
||||
api_type=config.get("api_type", "azure"),
|
||||
api_version=config.get("api_version"),
|
||||
model_name=model_name,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
from openai import OpenAI
|
||||
|
||||
class OllamaEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:11434/v1",
|
||||
api_key=config.get("api_key", "ollama"),
|
||||
)
|
||||
try:
|
||||
response = client.embeddings.create(
|
||||
input=input, model=model_name
|
||||
)
|
||||
embeddings = [item.embedding for item in response.data]
|
||||
return cast(Embeddings, embeddings)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
self.embedder_config = OllamaEmbeddingFunction()
|
||||
elif provider == "vertexai":
|
||||
self.embedder_config = (
|
||||
embedding_functions.GoogleVertexEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
)
|
||||
)
|
||||
elif provider == "google":
|
||||
self.embedder_config = (
|
||||
embedding_functions.GoogleGenerativeAiEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
)
|
||||
)
|
||||
elif provider == "cohere":
|
||||
self.embedder_config = embedding_functions.CohereEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
)
|
||||
elif provider == "huggingface":
|
||||
self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
|
||||
url=config.get("api_url"),
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
|
||||
)
|
||||
else:
|
||||
validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class
|
||||
self.embedder_config = self.embedder_config
|
||||
|
||||
def _initialize_app(self):
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
self._set_embedder_config()
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=f"{db_storage_path()}/{self.type}/{self.agents}"
|
||||
path=f"{db_storage_path()}/{self.type}/{self.agents}",
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
try:
|
||||
@@ -122,11 +202,15 @@ class RAGStorage(BaseRAGStorage):
|
||||
if self.app:
|
||||
self.app.reset()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
if "attempt to write a readonly database" in str(e):
|
||||
# Ignore this specific error
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
|
||||
def _create_embedding_function(self):
|
||||
def _create_default_embedding_function(self):
|
||||
import chromadb.utils.embedding_functions as embedding_functions
|
||||
|
||||
return embedding_functions.OpenAIEmbeddingFunction(
|
||||
|
||||
@@ -9,6 +9,7 @@ from unittest.mock import MagicMock, patch
|
||||
import instructor
|
||||
import pydantic_core
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.crew import Crew
|
||||
@@ -497,6 +498,7 @@ def test_cache_hitting_between_agents():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_api_calls_throttling(capsys):
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai_tools import tool
|
||||
|
||||
@tool
|
||||
@@ -779,11 +781,14 @@ def test_async_task_execution_call_count():
|
||||
list_important_history.output = mock_task_output
|
||||
write_article.output = mock_task_output
|
||||
|
||||
with patch.object(
|
||||
Task, "execute_sync", return_value=mock_task_output
|
||||
) as mock_execute_sync, patch.object(
|
||||
Task, "execute_async", return_value=mock_future
|
||||
) as mock_execute_async:
|
||||
with (
|
||||
patch.object(
|
||||
Task, "execute_sync", return_value=mock_task_output
|
||||
) as mock_execute_sync,
|
||||
patch.object(
|
||||
Task, "execute_async", return_value=mock_future
|
||||
) as mock_execute_async,
|
||||
):
|
||||
crew.kickoff()
|
||||
|
||||
assert mock_execute_async.call_count == 2
|
||||
@@ -1105,6 +1110,7 @@ def test_dont_set_agents_step_callback_if_already_set():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_function_calling_llm():
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai_tools import tool
|
||||
|
||||
llm = "gpt-4o"
|
||||
@@ -1448,52 +1454,6 @@ def test_crew_does_not_interpolate_without_inputs():
|
||||
interpolate_task_inputs.assert_not_called()
|
||||
|
||||
|
||||
# def test_crew_partial_inputs():
|
||||
# agent = Agent(
|
||||
# role="{topic} Researcher",
|
||||
# goal="Express hot takes on {topic}.",
|
||||
# backstory="You have a lot of experience with {topic}.",
|
||||
# )
|
||||
|
||||
# task = Task(
|
||||
# description="Give me an analysis around {topic}.",
|
||||
# expected_output="{points} bullet points about {topic}.",
|
||||
# )
|
||||
|
||||
# crew = Crew(agents=[agent], tasks=[task], inputs={"topic": "AI"})
|
||||
# inputs = {"topic": "AI"}
|
||||
# crew._interpolate_inputs(inputs=inputs) # Manual call for now
|
||||
|
||||
# assert crew.tasks[0].description == "Give me an analysis around AI."
|
||||
# assert crew.tasks[0].expected_output == "{points} bullet points about AI."
|
||||
# assert crew.agents[0].role == "AI Researcher"
|
||||
# assert crew.agents[0].goal == "Express hot takes on AI."
|
||||
# assert crew.agents[0].backstory == "You have a lot of experience with AI."
|
||||
|
||||
|
||||
# def test_crew_invalid_inputs():
|
||||
# agent = Agent(
|
||||
# role="{topic} Researcher",
|
||||
# goal="Express hot takes on {topic}.",
|
||||
# backstory="You have a lot of experience with {topic}.",
|
||||
# )
|
||||
|
||||
# task = Task(
|
||||
# description="Give me an analysis around {topic}.",
|
||||
# expected_output="{points} bullet points about {topic}.",
|
||||
# )
|
||||
|
||||
# crew = Crew(agents=[agent], tasks=[task], inputs={"subject": "AI"})
|
||||
# inputs = {"subject": "AI"}
|
||||
# crew._interpolate_inputs(inputs=inputs) # Manual call for now
|
||||
|
||||
# assert crew.tasks[0].description == "Give me an analysis around {topic}."
|
||||
# assert crew.tasks[0].expected_output == "{points} bullet points about {topic}."
|
||||
# assert crew.agents[0].role == "{topic} Researcher"
|
||||
# assert crew.agents[0].goal == "Express hot takes on {topic}."
|
||||
# assert crew.agents[0].backstory == "You have a lot of experience with {topic}."
|
||||
|
||||
|
||||
def test_task_callback_on_crew():
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -1770,7 +1730,10 @@ def test_manager_agent_with_tools_raises_exception():
|
||||
@patch("crewai.crew.Crew.kickoff")
|
||||
@patch("crewai.crew.CrewTrainingHandler")
|
||||
@patch("crewai.crew.TaskEvaluator")
|
||||
def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
|
||||
@patch("crewai.crew.Crew.copy")
|
||||
def test_crew_train_success(
|
||||
copy_mock, task_evaluator, crew_training_handler, kickoff_mock
|
||||
):
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
|
||||
expected_output="5 bullet points with a paragraph for each idea.",
|
||||
@@ -1781,9 +1744,19 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
|
||||
agents=[researcher, writer],
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
# Create a mock for the copied crew
|
||||
copy_mock.return_value = crew
|
||||
|
||||
crew.train(
|
||||
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
|
||||
)
|
||||
|
||||
# Ensure kickoff is called on the copied crew
|
||||
kickoff_mock.assert_has_calls(
|
||||
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
|
||||
)
|
||||
|
||||
task_evaluator.assert_has_calls(
|
||||
[
|
||||
mock.call(researcher),
|
||||
@@ -1822,10 +1795,6 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
|
||||
]
|
||||
)
|
||||
|
||||
kickoff.assert_has_calls(
|
||||
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
|
||||
)
|
||||
|
||||
|
||||
def test_crew_train_error():
|
||||
task = Task(
|
||||
@@ -1840,7 +1809,7 @@ def test_crew_train_error():
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
crew.train()
|
||||
crew.train() # type: ignore purposefully throwing err
|
||||
assert "train() missing 1 required positional argument: 'n_iterations'" in str(
|
||||
e
|
||||
)
|
||||
@@ -2536,8 +2505,9 @@ def test_conditional_should_execute():
|
||||
|
||||
|
||||
@mock.patch("crewai.crew.CrewEvaluator")
|
||||
@mock.patch("crewai.crew.Crew.copy")
|
||||
@mock.patch("crewai.crew.Crew.kickoff")
|
||||
def test_crew_testing_function(mock_kickoff, crew_evaluator):
|
||||
def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
|
||||
expected_output="5 bullet points with a paragraph for each idea.",
|
||||
@@ -2548,11 +2518,15 @@ def test_crew_testing_function(mock_kickoff, crew_evaluator):
|
||||
agents=[researcher],
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
# Create a mock for the copied crew
|
||||
copy_mock.return_value = crew
|
||||
|
||||
n_iterations = 2
|
||||
crew.test(n_iterations, openai_model_name="gpt-4o-mini", inputs={"topic": "AI"})
|
||||
|
||||
assert len(mock_kickoff.mock_calls) == n_iterations
|
||||
mock_kickoff.assert_has_calls(
|
||||
# Ensure kickoff is called on the copied crew
|
||||
kickoff_mock.assert_has_calls(
|
||||
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user