Compare commits

...

22 Commits

Author SHA1 Message Date
Brandon Hancock
c9476769e1 Merge branch 'main' into fix-cli 2024-10-23 09:39:17 -04:00
Brandon Hancock (bhancock_ai)
d59ecb22e6 Fix spelling mistake 2024-10-23 09:09:56 -04:00
Rip&Tear
263544524d ruff updates 2024-10-23 17:20:05 +08:00
Rip&Tear
098a4312ab allow user to bypass api key entry + incorrect number selected logic + ruff formatting 2024-10-23 17:17:05 +08:00
João Moura
b8a3c29745 preparing new version 2024-10-23 05:34:34 -03:00
Brandon Hancock (bhancock_ai)
9cd4ff05c9 use copy to split testing and training on crews (#1491)
* use copy to split testing and training on crews

* make tests handle new copy functionality on train and test

* fix last test

* fix test
2024-10-22 21:31:44 -04:00
Lorenze Jay
4687779702 ensure original embedding config works (#1476)
* ensure original embedding config works

* some fixes

* raise error on unsupported provider

* WIP: Brandon's notes

* fixes

* rm prints

* fixed docs

* fixed run types

* updates to add more docs and correct imports with huggingface embedding server enabled

---------

Co-authored-by: Brandon Hancock <brandon@brandonhancock.io>
2024-10-22 12:30:30 -07:00
Rip&Tear
c724c0af70 Minor doc updates 2024-10-22 09:04:32 +08:00
Rip&Tear
f6f430b26a Added docs for new CLI provider + fixed missing API prompt 2024-10-18 10:23:34 +08:00
Brandon Hancock
a5f70d2307 fix unnecessary deps 2024-10-17 10:00:04 -04:00
Rip&Tear
b55fc40c83 Merge branch 'main' into feat/cli-model-selection-and-API-submission 2024-10-17 11:39:01 +08:00
Rip&Tear
d0ed4f5274 small comment cleanup 2024-10-17 11:25:37 +08:00
Rip&Tear
ee34399b71 refactor/Move functions into utils file, added new provider file and migrated functions there, new constants file + general function refactor 2024-10-17 11:16:10 +08:00
Rip&Tear
39903f0c50 cleanup of comments 2024-10-13 18:14:09 +08:00
Rip&Tear
c4bf713113 refactored select_provider to have an early return 2024-10-13 18:13:24 +08:00
Rip&Tear
5d18c6312d refactored select_choice function for early return 2024-10-13 18:09:33 +08:00
Rip&Tear
1f9baf9b2c feat: implement crew creation CLI command
- refactor code to multiple functions
- Added ability for users to select provider and model when using the crewai create command, and save the API key to .env
2024-10-13 00:04:05 +08:00
Rip&Tear
6fbc97b298 removed all unnecessary comments 2024-10-12 13:22:48 +08:00
Rip&Tear
08bacfa892 Merge branch 'feat/cli-model-selection-and-API-submission' of https://github.com/crewAIInc/crewAI into feat/cli-model-selection-and-API-submission 2024-10-12 13:06:16 +08:00
Rip&Tear
1ea8115d56 updated click prompt to remove default number 2024-10-12 13:05:55 +08:00
Brandon Hancock (bhancock_ai)
6b906f09cf Merge branch 'main' into feat/cli-model-selection-and-API-submission 2024-10-11 14:44:24 -04:00
Rip&Tear
6c29ebafea updated CLI to allow for submitting API keys 2024-10-11 23:33:49 +08:00
17 changed files with 426 additions and 177 deletions

View File

@@ -6,7 +6,7 @@ icon: terminal
# CrewAI CLI Documentation
The CrewAI CLI provides a set of commands to interact with CrewAI, allowing you to create, train, run, and manage crews and pipelines.
The CrewAI CLI provides a set of commands to interact with CrewAI, allowing you to create, train, run, and manage crews & flows.
## Installation
@@ -146,3 +146,34 @@ crewai run
Make sure to run these commands from the directory where your CrewAI project is set up.
Some commands may require additional configuration or setup within your project structure.
</Note>
### 9. API Keys
When you run the ```crewai create crew``` command, the CLI will first show you the top 5 most common LLM providers and ask you to select one.
Once you've selected an LLM provider, you will be prompted for API keys.
#### Initial API key providers
The CLI will initially prompt for API keys for the following services:
* OpenAI
* Groq
* Anthropic
* Google Gemini
When you select a provider, the CLI will prompt you to enter your API key.
#### Other Options
If you select option 6, you will be able to choose from the full list of LiteLLM-supported providers.
When you select a provider, the CLI will prompt you to enter the key name and the API key.
See the following link for each provider's key name:
* [LiteLLM Providers](https://docs.litellm.ai/docs/providers)
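For illustration, a hypothetical `.env` produced by this flow (the variable name comes from the provider you pick; the values below are placeholders, and the `MODEL` entry mirrors the model the CLI records for the crew):
```
OPENAI_API_KEY=YOUR_API_KEY_HERE
MODEL=gpt-4o-mini
```
If you skip the API key prompt, the CLI skips writing the key to `.env` entirely.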

View File

@@ -105,9 +105,48 @@ my_crew = Crew(
process=Process.sequential,
memory=True,
verbose=True,
embedder=embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
embedder={
"provider": "openai",
"config": {
"model": 'text-embedding-3-small'
}
}
)
```
Alternatively, you can directly pass the OpenAIEmbeddingFunction to the embedder parameter.
Example:
```python Code
from crewai import Crew, Agent, Task, Process
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder=OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"),
)
```
### Using Ollama embeddings
```python Code
from crewai import Crew, Agent, Task, Process
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "ollama",
"config": {
"model": "mxbai-embed-large"
}
}
)
```
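Under the hood (per the RAGStorage changes later in this diff), the `ollama` provider wraps Ollama's OpenAI-compatible endpoint. A roughly equivalent low-level call, for intuition — the endpoint and model come from the diff, the input text is illustrative:
```python Code
from openai import OpenAI

# Ollama exposes an OpenAI-compatible API; "ollama" is a dummy API key.
client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
response = client.embeddings.create(input=["hello world"], model="mxbai-embed-large")
vector = response.data[0].embedding  # one embedding per input string
```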
@@ -122,10 +161,13 @@ my_crew = Crew(
process=Process.sequential,
memory=True,
verbose=True,
embedder=embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-ada-002"
)
embedder={
"provider": "google",
"config": {
"api_key": "<YOUR_API_KEY>",
"model_name": "<model_name>"
}
}
)
```
@@ -181,10 +223,32 @@ my_crew = Crew(
process=Process.sequential,
memory=True,
verbose=True,
embedder=embedding_functions.CohereEmbeddingFunction(
api_key=YOUR_API_KEY,
model_name="<model_name>"
)
embedder={
"provider": "cohere",
"config": {
"api_key": "YOUR_API_KEY",
"model_name": "<model_name>"
}
}
)
```
### Using HuggingFace embeddings
```python Code
from crewai import Crew, Agent, Task, Process
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "huggingface",
"config": {
"api_url": "<api_url>",
}
}
)
```

View File

@@ -1,6 +1,6 @@
[project]
name = "crewai"
version = "0.74.2"
version = "0.75.1"
description = "Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks."
readme = "README.md"
requires-python = ">=3.10,<=3.13"

View File

@@ -14,5 +14,5 @@ warnings.filterwarnings(
category=UserWarning,
module="pydantic.main",
)
__version__ = "0.74.2"
__version__ = "0.75.1"
__all__ = ["Agent", "Crew", "Process", "Task", "Pipeline", "Router", "LLM", "Flow"]

View File

@@ -32,10 +32,11 @@ def crewai():
@crewai.command()
@click.argument("type", type=click.Choice(["crew", "pipeline", "flow"]))
@click.argument("name")
def create(type, name):
@click.option("--provider", type=str, help="The provider to use for the crew")
def create(type, name, provider):
"""Create a new crew, pipeline, or flow."""
if type == "crew":
create_crew(name)
create_crew(name, provider)
elif type == "pipeline":
create_pipeline(name)
elif type == "flow":
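For reference, a hypothetical invocation of the new option (the crew name and provider are illustrative):
```shell
crewai create crew my_crew --provider openai
```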

View File

@@ -1,8 +1,16 @@
import sys
from pathlib import Path
import click
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
from crewai.cli.provider import get_provider_data, select_provider, PROVIDERS
from crewai.cli.constants import ENV_VARS
from crewai.cli.provider import (
PROVIDERS,
get_provider_data,
select_model,
select_provider,
)
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
def create_folder_structure(name, parent_folder=None):
@@ -14,11 +22,19 @@ def create_folder_structure(name, parent_folder=None):
else:
folder_path = Path(folder_name)
click.secho(
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
fg="green",
bold=True,
)
if folder_path.exists():
if not click.confirm(
f"Folder {folder_name} already exists. Do you want to override it?"
):
click.secho("Operation cancelled.", fg="yellow")
sys.exit(0)
click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True)
else:
click.secho(
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
fg="green",
bold=True,
)
if not folder_path.exists():
folder_path.mkdir(parents=True)
@@ -27,11 +43,6 @@ def create_folder_structure(name, parent_folder=None):
(folder_path / "src" / folder_name).mkdir(parents=True)
(folder_path / "src" / folder_name / "tools").mkdir(parents=True)
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
else:
click.secho(
f"\tFolder {folder_name} already exists.",
fg="yellow",
)
return folder_path, folder_name, class_name
@@ -74,33 +85,73 @@ def create_crew(name, parent_folder=None):
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path)
existing_provider = None
for provider, env_keys in ENV_VARS.items():
if any(key in env_vars for key in env_keys):
existing_provider = provider
break
if existing_provider:
if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
):
click.secho("Keeping existing provider configuration.", fg="yellow")
return
provider_models = get_provider_data()
if not provider_models:
return
selected_provider = select_provider(provider_models)
if not selected_provider:
return
provider = selected_provider
# selected_model = select_model(provider, provider_models)
# if not selected_model:
# return
# model = selected_model
if provider in PROVIDERS:
api_key_var = ENV_VARS[provider][0]
else:
api_key_var = click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
while True:
selected_provider = select_provider(provider_models)
if selected_provider is None: # User typed 'q'
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_provider: # Valid selection
break
click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red"
)
env_vars = {api_key_var: "YOUR_API_KEY_HERE"}
write_env_file(folder_path, env_vars)
while True:
selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q'
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_model: # Valid selection
break
click.secho(
"No model selected. Please try again or press 'q' to exit.", fg="red"
)
# env_vars['MODEL'] = model
# click.secho(f"Selected model: {model}", fg="green")
if selected_provider in PROVIDERS:
api_key_var = ENV_VARS[selected_provider][0]
else:
api_key_var = click.prompt(
f"Enter the environment variable name for your {selected_provider.capitalize()} API key",
type=str,
default="",
)
api_key_value = ""
click.echo(
f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ",
nl=False,
)
try:
api_key_value = input()
except (KeyboardInterrupt, EOFError):
api_key_value = ""
if api_key_value.strip():
env_vars = {api_key_var: api_key_value}
write_env_file(folder_path, env_vars)
click.secho("API key saved to .env file", fg="green")
else:
click.secho("No API key provided. Skipping .env file creation.", fg="yellow")
env_vars["MODEL"] = selected_model
click.secho(f"Selected model: {selected_model}", fg="green")
package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew"

View File

@@ -1,67 +1,91 @@
import json
import time
import requests
from collections import defaultdict
from pathlib import Path
import click
from pathlib import Path
from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL
import requests
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message, choices):
"""
Presents a list of choices to the user and prompts them to select one.
Args:
- prompt_message (str): The message to display to the user before presenting the choices.
- choices (list): A list of options to present to the user.
Returns:
- str: The selected choice from the list, or None if the operation is aborted or an invalid selection is made.
- str: The selected choice from the list, or None if the user chooses to quit.
"""
provider_models = get_provider_data()
if not provider_models:
return
click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan")
try:
selected_index = click.prompt("Enter the number of your choice", type=int) - 1
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return None
if not (0 <= selected_index < len(choices)):
click.secho("Invalid selection.", fg="red")
return None
return choices[selected_index]
click.secho("q. Quit", fg="cyan")
while True:
choice = click.prompt(
"Enter the number of your choice or 'q' to quit", type=str
)
if choice.lower() == "q":
return None
try:
selected_index = int(choice) - 1
if 0 <= selected_index < len(choices):
return choices[selected_index]
except ValueError:
pass
click.secho(
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
fg="red",
)
def select_provider(provider_models):
"""
Presents a list of providers to the user and prompts them to select one.
Args:
- provider_models (dict): A dictionary of provider models.
Returns:
- str: The selected provider, or None if the operation is aborted or an invalid selection is made.
- str: The selected provider
- None: If user explicitly quits
"""
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice("Select a provider to set up:", predefined_providers + ['other'])
if not provider:
provider = select_choice(
"Select a provider to set up:", predefined_providers + ["other"]
)
if provider is None: # User typed 'q'
return None
provider = provider.lower()
if provider == 'other':
if provider == "other":
provider = select_choice("Select a provider from the full list:", all_providers)
if not provider:
if provider is None: # User typed 'q'
return None
return provider
return provider.lower() if provider else False
def select_model(provider, provider_models):
"""
Presents a list of models for a given provider to the user and prompts them to select one.
Args:
- provider (str): The provider for which to select a model.
- provider_models (dict): A dictionary of provider models.
Returns:
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
"""
@@ -76,37 +100,49 @@ def select_model(provider, provider_models):
click.secho(f"No models available for provider '{provider}'.", fg="red")
return None
selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models)
selected_model = select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
)
return selected_model
def load_provider_data(cache_file, cache_expiry):
"""
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
Args:
- cache_file (Path): The path to the cache file.
- cache_expiry (int): The cache expiry time in seconds.
Returns:
- dict or None: The loaded provider data or None if the operation fails.
"""
current_time = time.time()
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry:
if (
cache_file.exists()
and (current_time - cache_file.stat().st_mtime) < cache_expiry
):
data = read_cache_file(cache_file)
if data:
return data
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow")
click.secho(
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
)
else:
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan")
click.secho(
"Cache expired or not found. Fetching provider data from the web...",
fg="cyan",
)
return fetch_provider_data(cache_file)
def read_cache_file(cache_file):
"""
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
Args:
- cache_file (Path): The path to the cache file.
Returns:
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
"""
@@ -116,13 +152,14 @@ def read_cache_file(cache_file):
except json.JSONDecodeError:
return None
def fetch_provider_data(cache_file):
"""
Fetches provider data from a specified URL and caches it to a file.
Args:
- cache_file (Path): The path to the cache file.
Returns:
- dict or None: The fetched provider data or None if the operation fails.
"""
@@ -139,38 +176,42 @@ def fetch_provider_data(cache_file):
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None
def download_data(response):
"""
Downloads data from a given HTTP response and returns the JSON content.
Args:
- response (requests.Response): The HTTP response object.
Returns:
- dict: The JSON content of the response.
"""
total_size = int(response.headers.get('content-length', 0))
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
data_chunks = []
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as progress_bar:
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)
progress_bar.update(len(chunk))
data_content = b''.join(data_chunks)
return json.loads(data_content.decode('utf-8'))
data_content = b"".join(data_chunks)
return json.loads(data_content.decode("utf-8"))
def get_provider_data():
"""
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
"""
cache_dir = Path.home() / '.crewai'
cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True)
cache_file = cache_dir / 'provider_cache.json'
cache_expiry = 24 * 3600
cache_file = cache_dir / "provider_cache.json"
cache_expiry = 24 * 3600
data = load_provider_data(cache_file, cache_expiry)
if not data:
@@ -179,8 +220,8 @@ def get_provider_data():
provider_models = defaultdict(list)
for model_name, properties in data.items():
provider = properties.get("litellm_provider", "").strip().lower()
if 'http' in provider or provider == 'other':
if "http" in provider or provider == "other":
continue
if provider:
provider_models[provider].append(model_name)
return provider_models
return provider_models
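Taken together, a minimal sketch of how these helpers compose — mirroring create_crew's flow, with its retry loops omitted for brevity:
```python
from crewai.cli.provider import get_provider_data, select_model, select_provider

# Provider/model catalog, cached in ~/.crewai/provider_cache.json for 24h.
provider_models = get_provider_data()
if provider_models:
    provider = select_provider(provider_models)  # None means the user typed 'q'
    if provider:
        model = select_model(provider, provider_models)  # None means the user typed 'q'
```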

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<=3.13"
dependencies = [
"crewai[tools]>=0.74.2,<1.0.0"
"crewai[tools]>=0.75.1,<1.0.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<=3.13"
dependencies = [
"crewai[tools]>=0.74.2,<1.0.0",
"crewai[tools]>=0.75.1,<1.0.0",
]
[project.scripts]

View File

@@ -6,7 +6,7 @@ authors = ["Your Name <you@example.com>"]
[tool.poetry.dependencies]
python = ">=3.10,<=3.13"
crewai = { extras = ["tools"], version = ">=0.74.2,<1.0.0" }
crewai = { extras = ["tools"], version = ">=0.75.1,<1.0.0" }
asyncio = "*"
[tool.poetry.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = ["Your Name <you@example.com>"]
requires-python = ">=3.10,<=3.13"
dependencies = [
"crewai[tools]>=0.74.2,<1.0.0"
"crewai[tools]>=0.75.1,<1.0.0"
]
[project.scripts]

View File

@@ -5,6 +5,6 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<=3.13"
dependencies = [
"crewai[tools]>=0.74.2"
"crewai[tools]>=0.75.1"
]

View File

@@ -435,15 +435,16 @@ class Crew(BaseModel):
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
) -> None:
"""Trains the crew for a given number of iterations."""
self._setup_for_training(filename)
train_crew = self.copy()
train_crew._setup_for_training(filename)
for n_iteration in range(n_iterations):
self._train_iteration = n_iteration
self.kickoff(inputs=inputs)
train_crew._train_iteration = n_iteration
train_crew.kickoff(inputs=inputs)
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
for agent in self.agents:
for agent in train_crew.agents:
result = TaskEvaluator(agent).evaluate_training_data(
training_data=training_data, agent_id=str(agent.id)
)
@@ -987,17 +988,19 @@ class Crew(BaseModel):
inputs: Optional[Dict[str, Any]] = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
self._test_execution_span = self._telemetry.test_execution_span(
self,
test_crew = self.copy()
self._test_execution_span = test_crew._telemetry.test_execution_span(
test_crew,
n_iterations,
inputs,
openai_model_name, # type: ignore[arg-type]
) # type: ignore[arg-type]
evaluator = CrewEvaluator(self, openai_model_name) # type: ignore[arg-type]
evaluator = CrewEvaluator(test_crew, openai_model_name) # type: ignore[arg-type]
for i in range(1, n_iterations + 1):
evaluator.set_iteration(i)
self.kickoff(inputs=inputs)
test_crew.kickoff(inputs=inputs)
evaluator.print_crew_evaluation_result()
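The upshot of both changes: `train()` and `test()` now run against a copy, so repeated kickoffs never mutate the caller's crew. An illustrative usage, with agents and tasks as in the tests further down:
```python
crew = Crew(agents=[researcher, writer], tasks=[task])
crew.train(n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl")
# The training iterations ran on an internal copy; `crew` itself is unchanged
# and can be kicked off normally afterwards.
result = crew.kickoff(inputs={"topic": "AI"})
```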

View File

@@ -16,7 +16,7 @@ class EntityMemory(Memory):
if storage
else RAGStorage(
type="entities",
allow_reset=False,
allow_reset=True,
embedder_config=embedder_config,
crew=crew,
)
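Flipping `allow_reset` to True matters because Chroma refuses `reset()` unless the client was created with that setting — exactly what the RAGStorage change below threads through via `Settings(allow_reset=self.allow_reset)`. A standalone sketch of the Chroma behavior (the path is illustrative):
```python
import chromadb
from chromadb.config import Settings

client = chromadb.PersistentClient(
    path="/tmp/demo_db",  # illustrative path
    settings=Settings(allow_reset=True),
)
client.reset()  # wipes the store; raises if allow_reset had been False
```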

View File

@@ -8,6 +8,9 @@ from typing import Any, Dict, List, Optional
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import cast
@contextlib.contextmanager
@@ -41,16 +44,93 @@ class RAGStorage(BaseRAGStorage):
self.agents = agents
self.type = type
self.embedder_config = embedder_config or self._create_embedding_function()
self.allow_reset = allow_reset
self._initialize_app()
def _set_embedder_config(self):
import chromadb.utils.embedding_functions as embedding_functions
if self.embedder_config is None:
self.embedder_config = self._create_default_embedding_function()
if isinstance(self.embedder_config, dict):
provider = self.embedder_config.get("provider")
config = self.embedder_config.get("config", {})
model_name = config.get("model")
if provider == "openai":
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
)
elif provider == "azure":
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
)
elif provider == "ollama":
from openai import OpenAI
class OllamaEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
client = OpenAI(
base_url="http://localhost:11434/v1",
api_key=config.get("api_key", "ollama"),
)
try:
response = client.embeddings.create(
input=input, model=model_name
)
embeddings = [item.embedding for item in response.data]
return cast(Embeddings, embeddings)
except Exception as e:
raise e
self.embedder_config = OllamaEmbeddingFunction()
elif provider == "vertexai":
self.embedder_config = (
embedding_functions.GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
)
elif provider == "google":
self.embedder_config = (
embedding_functions.GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
)
elif provider == "cohere":
self.embedder_config = embedding_functions.CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "huggingface":
self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
else:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
)
else:
validate_embedding_function(self.embedder_config) # type: ignore # validates embedder_config when a custom embedding function/class is provided
self.embedder_config = self.embedder_config
def _initialize_app(self):
import chromadb
from chromadb.config import Settings
self._set_embedder_config()
chroma_client = chromadb.PersistentClient(
path=f"{db_storage_path()}/{self.type}/{self.agents}"
path=f"{db_storage_path()}/{self.type}/{self.agents}",
settings=Settings(allow_reset=self.allow_reset),
)
self.app = chroma_client
try:
@@ -122,11 +202,15 @@ class RAGStorage(BaseRAGStorage):
if self.app:
self.app.reset()
except Exception as e:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
if "attempt to write a readonly database" in str(e):
# Ignore this specific error
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_embedding_function(self):
def _create_default_embedding_function(self):
import chromadb.utils.embedding_functions as embedding_functions
return embedding_functions.OpenAIEmbeddingFunction(
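The `else` branch above accepts a prebuilt embedding function and validates it via chromadb's `validate_embedding_function`. A minimal sketch of a custom function that would pass that validation — the class name and vector size are hypothetical:
```python
from chromadb import Documents, EmbeddingFunction, Embeddings

class FixedVectorEmbedder(EmbeddingFunction):
    """Toy embedder: one constant 384-dim vector per input document."""

    def __call__(self, input: Documents) -> Embeddings:
        return [[0.0] * 384 for _ in input]

# e.g. Crew(..., embedder=FixedVectorEmbedder()) would take the else branch.
```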

View File

@@ -9,6 +9,7 @@ from unittest.mock import MagicMock, patch
import instructor
import pydantic_core
import pytest
from crewai.agent import Agent
from crewai.agents.cache import CacheHandler
from crewai.crew import Crew
@@ -497,6 +498,7 @@ def test_cache_hitting_between_agents():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_api_calls_throttling(capsys):
from unittest.mock import patch
from crewai_tools import tool
@tool
@@ -779,11 +781,14 @@ def test_async_task_execution_call_count():
list_important_history.output = mock_task_output
write_article.output = mock_task_output
with patch.object(
Task, "execute_sync", return_value=mock_task_output
) as mock_execute_sync, patch.object(
Task, "execute_async", return_value=mock_future
) as mock_execute_async:
with (
patch.object(
Task, "execute_sync", return_value=mock_task_output
) as mock_execute_sync,
patch.object(
Task, "execute_async", return_value=mock_future
) as mock_execute_async,
):
crew.kickoff()
assert mock_execute_async.call_count == 2
@@ -1105,6 +1110,7 @@ def test_dont_set_agents_step_callback_if_already_set():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_function_calling_llm():
from unittest.mock import patch
from crewai_tools import tool
llm = "gpt-4o"
@@ -1448,52 +1454,6 @@ def test_crew_does_not_interpolate_without_inputs():
interpolate_task_inputs.assert_not_called()
# def test_crew_partial_inputs():
# agent = Agent(
# role="{topic} Researcher",
# goal="Express hot takes on {topic}.",
# backstory="You have a lot of experience with {topic}.",
# )
# task = Task(
# description="Give me an analysis around {topic}.",
# expected_output="{points} bullet points about {topic}.",
# )
# crew = Crew(agents=[agent], tasks=[task], inputs={"topic": "AI"})
# inputs = {"topic": "AI"}
# crew._interpolate_inputs(inputs=inputs) # Manual call for now
# assert crew.tasks[0].description == "Give me an analysis around AI."
# assert crew.tasks[0].expected_output == "{points} bullet points about AI."
# assert crew.agents[0].role == "AI Researcher"
# assert crew.agents[0].goal == "Express hot takes on AI."
# assert crew.agents[0].backstory == "You have a lot of experience with AI."
# def test_crew_invalid_inputs():
# agent = Agent(
# role="{topic} Researcher",
# goal="Express hot takes on {topic}.",
# backstory="You have a lot of experience with {topic}.",
# )
# task = Task(
# description="Give me an analysis around {topic}.",
# expected_output="{points} bullet points about {topic}.",
# )
# crew = Crew(agents=[agent], tasks=[task], inputs={"subject": "AI"})
# inputs = {"subject": "AI"}
# crew._interpolate_inputs(inputs=inputs) # Manual call for now
# assert crew.tasks[0].description == "Give me an analysis around {topic}."
# assert crew.tasks[0].expected_output == "{points} bullet points about {topic}."
# assert crew.agents[0].role == "{topic} Researcher"
# assert crew.agents[0].goal == "Express hot takes on {topic}."
# assert crew.agents[0].backstory == "You have a lot of experience with {topic}."
def test_task_callback_on_crew():
from unittest.mock import MagicMock, patch
@@ -1770,7 +1730,10 @@ def test_manager_agent_with_tools_raises_exception():
@patch("crewai.crew.Crew.kickoff")
@patch("crewai.crew.CrewTrainingHandler")
@patch("crewai.crew.TaskEvaluator")
def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
@patch("crewai.crew.Crew.copy")
def test_crew_train_success(
copy_mock, task_evaluator, crew_training_handler, kickoff_mock
):
task = Task(
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
expected_output="5 bullet points with a paragraph for each idea.",
@@ -1781,9 +1744,19 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
agents=[researcher, writer],
tasks=[task],
)
# Create a mock for the copied crew
copy_mock.return_value = crew
crew.train(
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
)
# Ensure kickoff is called on the copied crew
kickoff_mock.assert_has_calls(
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
)
task_evaluator.assert_has_calls(
[
mock.call(researcher),
@@ -1822,10 +1795,6 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
]
)
kickoff.assert_has_calls(
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
)
def test_crew_train_error():
task = Task(
@@ -1840,7 +1809,7 @@ def test_crew_train_error():
)
with pytest.raises(TypeError) as e:
crew.train()
crew.train() # type: ignore purposefully throwing err
assert "train() missing 1 required positional argument: 'n_iterations'" in str(
e
)
@@ -2536,8 +2505,9 @@ def test_conditional_should_execute():
@mock.patch("crewai.crew.CrewEvaluator")
@mock.patch("crewai.crew.Crew.copy")
@mock.patch("crewai.crew.Crew.kickoff")
def test_crew_testing_function(mock_kickoff, crew_evaluator):
def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator):
task = Task(
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
expected_output="5 bullet points with a paragraph for each idea.",
@@ -2548,11 +2518,15 @@ def test_crew_testing_function(mock_kickoff, crew_evaluator):
agents=[researcher],
tasks=[task],
)
# Create a mock for the copied crew
copy_mock.return_value = crew
n_iterations = 2
crew.test(n_iterations, openai_model_name="gpt-4o-mini", inputs={"topic": "AI"})
assert len(mock_kickoff.mock_calls) == n_iterations
mock_kickoff.assert_has_calls(
# Ensure kickoff is called on the copied crew
kickoff_mock.assert_has_calls(
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
)

uv.lock generated
View File

@@ -627,7 +627,7 @@ wheels = [
[[package]]
name = "crewai"
version = "0.74.2"
version = "0.75.1"
source = { editable = "." }
dependencies = [
{ name = "appdirs" },