Compare commits

...

19 Commits

Author SHA1 Message Date
Brandon Hancock
c9476769e1 Merge branch 'main' into fix-cli 2024-10-23 09:39:17 -04:00
Brandon Hancock (bhancock_ai)
d59ecb22e6 Fix spelling mistake 2024-10-23 09:09:56 -04:00
Rip&Tear
263544524d ruff updates 2024-10-23 17:20:05 +08:00
Rip&Tear
098a4312ab allow user to bypass api key entry + incorect number selected logic + ruff formatting 2024-10-23 17:17:05 +08:00
Rip&Tear
c724c0af70 Minor doc updates 2024-10-22 09:04:32 +08:00
Rip&Tear
f6f430b26a Added docs for new CLI provider + fixed missing API prompt 2024-10-18 10:23:34 +08:00
Brandon Hancock
a5f70d2307 fix unnecessary deps 2024-10-17 10:00:04 -04:00
Rip&Tear
b55fc40c83 Merge branch 'main' into feat/cli-model-selection-and-API-submission 2024-10-17 11:39:01 +08:00
Rip&Tear
d0ed4f5274 small comment cleanup 2024-10-17 11:25:37 +08:00
Rip&Tear
ee34399b71 refactor/Move functions into utils file, added new provider file and migrated fucntions thre, new constants file + general function refactor 2024-10-17 11:16:10 +08:00
Rip&Tear
39903f0c50 cleanup of comments 2024-10-13 18:14:09 +08:00
Rip&Tear
c4bf713113 refactored select_provider to have an ealry return 2024-10-13 18:13:24 +08:00
Rip&Tear
5d18c6312d refactered select_choice function for early return 2024-10-13 18:09:33 +08:00
Rip&Tear
1f9baf9b2c feat: implement crew creation CLI command
- refactor code to multiple functions
- Added ability for users to select provider and model when uing crewai create command and ave API key to .env
2024-10-13 00:04:05 +08:00
Rip&Tear
6fbc97b298 removed all unnecessary comments 2024-10-12 13:22:48 +08:00
Rip&Tear
08bacfa892 Merge branch 'feat/cli-model-selection-and-API-submission' of https://github.com/crewAIInc/crewAI into feat/cli-model-selection-and-API-submission 2024-10-12 13:06:16 +08:00
Rip&Tear
1ea8115d56 updated click prompt to remove default number 2024-10-12 13:05:55 +08:00
Brandon Hancock (bhancock_ai)
6b906f09cf Merge branch 'main' into feat/cli-model-selection-and-API-submission 2024-10-11 14:44:24 -04:00
Rip&Tear
6c29ebafea updated CLI to allow for submitting API keys 2024-10-11 23:33:49 +08:00
3 changed files with 206 additions and 84 deletions

View File

@@ -6,7 +6,7 @@ icon: terminal
# CrewAI CLI Documentation # CrewAI CLI Documentation
The CrewAI CLI provides a set of commands to interact with CrewAI, allowing you to create, train, run, and manage crews and pipelines. The CrewAI CLI provides a set of commands to interact with CrewAI, allowing you to create, train, run, and manage crews & flows.
## Installation ## Installation
@@ -146,3 +146,34 @@ crewai run
Make sure to run these commands from the directory where your CrewAI project is set up. Make sure to run these commands from the directory where your CrewAI project is set up.
Some commands may require additional configuration or setup within your project structure. Some commands may require additional configuration or setup within your project structure.
</Note> </Note>
### 9. API Keys
When running ```crewai create crew``` command, the CLI will first show you the top 5 most common LLM providers and ask you to select one.
Once you've selected an LLM provider, you will be prompted for API keys.
#### Initial API key providers
The CLI will initially prompt for API keys for the following services:
* OpenAI
* Groq
* Anthropic
* Google Gemini
When you select a provider, the CLI will prompt you to enter your API key.
#### Other Options
If you select option 6, you will be able to select from a list of LiteLLM supported providers.
When you select a provider, the CLI will prompt you to enter the Key name and the API key.
See the following link for each provider's key name:
* [LiteLLM Providers](https://docs.litellm.ai/docs/providers)

View File

@@ -1,8 +1,16 @@
import sys
from pathlib import Path from pathlib import Path
import click import click
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
from crewai.cli.provider import get_provider_data, select_provider, PROVIDERS
from crewai.cli.constants import ENV_VARS from crewai.cli.constants import ENV_VARS
from crewai.cli.provider import (
PROVIDERS,
get_provider_data,
select_model,
select_provider,
)
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
def create_folder_structure(name, parent_folder=None): def create_folder_structure(name, parent_folder=None):
@@ -14,11 +22,19 @@ def create_folder_structure(name, parent_folder=None):
else: else:
folder_path = Path(folder_name) folder_path = Path(folder_name)
click.secho( if folder_path.exists():
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", if not click.confirm(
fg="green", f"Folder {folder_name} already exists. Do you want to override it?"
bold=True, ):
) click.secho("Operation cancelled.", fg="yellow")
sys.exit(0)
click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True)
else:
click.secho(
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
fg="green",
bold=True,
)
if not folder_path.exists(): if not folder_path.exists():
folder_path.mkdir(parents=True) folder_path.mkdir(parents=True)
@@ -27,11 +43,6 @@ def create_folder_structure(name, parent_folder=None):
(folder_path / "src" / folder_name).mkdir(parents=True) (folder_path / "src" / folder_name).mkdir(parents=True)
(folder_path / "src" / folder_name / "tools").mkdir(parents=True) (folder_path / "src" / folder_name / "tools").mkdir(parents=True)
(folder_path / "src" / folder_name / "config").mkdir(parents=True) (folder_path / "src" / folder_name / "config").mkdir(parents=True)
else:
click.secho(
f"\tFolder {folder_name} already exists.",
fg="yellow",
)
return folder_path, folder_name, class_name return folder_path, folder_name, class_name
@@ -70,38 +81,77 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
copy_template(src_file, dst_file, name, class_name, folder_path.name) copy_template(src_file, dst_file, name, class_name, folder_path.name)
def create_crew(name, provider=None, parent_folder=None): def create_crew(name, parent_folder=None):
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path) env_vars = load_env_vars(folder_path)
if not provider: existing_provider = None
provider_models = get_provider_data() for provider, env_keys in ENV_VARS.items():
if not provider_models: if any(key in env_vars for key in env_keys):
existing_provider = provider
break
if existing_provider:
if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
):
click.secho("Keeping existing provider configuration.", fg="yellow")
return return
provider_models = get_provider_data()
if not provider_models:
return
while True:
selected_provider = select_provider(provider_models) selected_provider = select_provider(provider_models)
if not selected_provider: if selected_provider is None: # User typed 'q'
return click.secho("Exiting...", fg="yellow")
provider = selected_provider sys.exit(0)
if selected_provider: # Valid selection
# selected_model = select_model(provider, provider_models) break
# if not selected_model: click.secho(
# return "No provider selected. Please try again or press 'q' to exit.", fg="red"
# model = selected_model
if provider in PROVIDERS:
api_key_var = ENV_VARS[provider][0]
else:
api_key_var = click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
) )
env_vars = {api_key_var: "YOUR_API_KEY_HERE"} while True:
write_env_file(folder_path, env_vars) selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q'
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_model: # Valid selection
break
click.secho(
"No model selected. Please try again or press 'q' to exit.", fg="red"
)
# env_vars['MODEL'] = model if selected_provider in PROVIDERS:
# click.secho(f"Selected model: {model}", fg="green") api_key_var = ENV_VARS[selected_provider][0]
else:
api_key_var = click.prompt(
f"Enter the environment variable name for your {selected_provider.capitalize()} API key",
type=str,
default="",
)
api_key_value = ""
click.echo(
f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ",
nl=False,
)
try:
api_key_value = input()
except (KeyboardInterrupt, EOFError):
api_key_value = ""
if api_key_value.strip():
env_vars = {api_key_var: api_key_value}
write_env_file(folder_path, env_vars)
click.secho("API key saved to .env file", fg="green")
else:
click.secho("No API key provided. Skipping .env file creation.", fg="yellow")
env_vars["MODEL"] = selected_model
click.secho(f"Selected model: {selected_model}", fg="green")
package_dir = Path(__file__).parent package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew" templates_dir = package_dir / "templates" / "crew"

View File

@@ -1,67 +1,91 @@
import json import json
import time import time
import requests
from collections import defaultdict from collections import defaultdict
from pathlib import Path
import click import click
from pathlib import Path import requests
from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message, choices): def select_choice(prompt_message, choices):
""" """
Presents a list of choices to the user and prompts them to select one. Presents a list of choices to the user and prompts them to select one.
Args: Args:
- prompt_message (str): The message to display to the user before presenting the choices. - prompt_message (str): The message to display to the user before presenting the choices.
- choices (list): A list of options to present to the user. - choices (list): A list of options to present to the user.
Returns: Returns:
- str: The selected choice from the list, or None if the operation is aborted or an invalid selection is made. - str: The selected choice from the list, or None if the user chooses to quit.
""" """
provider_models = get_provider_data()
if not provider_models:
return
click.secho(prompt_message, fg="cyan") click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1): for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan") click.secho(f"{idx}. {choice}", fg="cyan")
try: click.secho("q. Quit", fg="cyan")
selected_index = click.prompt("Enter the number of your choice", type=int) - 1
except click.exceptions.Abort: while True:
click.secho("Operation aborted by the user.", fg="red") choice = click.prompt(
return None "Enter the number of your choice or 'q' to quit", type=str
if not (0 <= selected_index < len(choices)): )
click.secho("Invalid selection.", fg="red")
return None if choice.lower() == "q":
return choices[selected_index] return None
try:
selected_index = int(choice) - 1
if 0 <= selected_index < len(choices):
return choices[selected_index]
except ValueError:
pass
click.secho(
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
fg="red",
)
def select_provider(provider_models): def select_provider(provider_models):
""" """
Presents a list of providers to the user and prompts them to select one. Presents a list of providers to the user and prompts them to select one.
Args: Args:
- provider_models (dict): A dictionary of provider models. - provider_models (dict): A dictionary of provider models.
Returns: Returns:
- str: The selected provider, or None if the operation is aborted or an invalid selection is made. - str: The selected provider
- None: If user explicitly quits
""" """
predefined_providers = [p.lower() for p in PROVIDERS] predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice("Select a provider to set up:", predefined_providers + ['other']) provider = select_choice(
if not provider: "Select a provider to set up:", predefined_providers + ["other"]
)
if provider is None: # User typed 'q'
return None return None
provider = provider.lower()
if provider == 'other': if provider == "other":
provider = select_choice("Select a provider from the full list:", all_providers) provider = select_choice("Select a provider from the full list:", all_providers)
if not provider: if provider is None: # User typed 'q'
return None return None
return provider
return provider.lower() if provider else False
def select_model(provider, provider_models): def select_model(provider, provider_models):
""" """
Presents a list of models for a given provider to the user and prompts them to select one. Presents a list of models for a given provider to the user and prompts them to select one.
Args: Args:
- provider (str): The provider for which to select a model. - provider (str): The provider for which to select a model.
- provider_models (dict): A dictionary of provider models. - provider_models (dict): A dictionary of provider models.
Returns: Returns:
- str: The selected model, or None if the operation is aborted or an invalid selection is made. - str: The selected model, or None if the operation is aborted or an invalid selection is made.
""" """
@@ -76,37 +100,49 @@ def select_model(provider, provider_models):
click.secho(f"No models available for provider '{provider}'.", fg="red") click.secho(f"No models available for provider '{provider}'.", fg="red")
return None return None
selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models) selected_model = select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
)
return selected_model return selected_model
def load_provider_data(cache_file, cache_expiry): def load_provider_data(cache_file, cache_expiry):
""" """
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web. Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
Args: Args:
- cache_file (Path): The path to the cache file. - cache_file (Path): The path to the cache file.
- cache_expiry (int): The cache expiry time in seconds. - cache_expiry (int): The cache expiry time in seconds.
Returns: Returns:
- dict or None: The loaded provider data or None if the operation fails. - dict or None: The loaded provider data or None if the operation fails.
""" """
current_time = time.time() current_time = time.time()
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry: if (
cache_file.exists()
and (current_time - cache_file.stat().st_mtime) < cache_expiry
):
data = read_cache_file(cache_file) data = read_cache_file(cache_file)
if data: if data:
return data return data
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow") click.secho(
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
)
else: else:
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan") click.secho(
"Cache expired or not found. Fetching provider data from the web...",
fg="cyan",
)
return fetch_provider_data(cache_file) return fetch_provider_data(cache_file)
def read_cache_file(cache_file): def read_cache_file(cache_file):
""" """
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON. Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
Args: Args:
- cache_file (Path): The path to the cache file. - cache_file (Path): The path to the cache file.
Returns: Returns:
- dict or None: The JSON content of the cache file or None if the JSON is invalid. - dict or None: The JSON content of the cache file or None if the JSON is invalid.
""" """
@@ -116,13 +152,14 @@ def read_cache_file(cache_file):
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
def fetch_provider_data(cache_file): def fetch_provider_data(cache_file):
""" """
Fetches provider data from a specified URL and caches it to a file. Fetches provider data from a specified URL and caches it to a file.
Args: Args:
- cache_file (Path): The path to the cache file. - cache_file (Path): The path to the cache file.
Returns: Returns:
- dict or None: The fetched provider data or None if the operation fails. - dict or None: The fetched provider data or None if the operation fails.
""" """
@@ -139,38 +176,42 @@ def fetch_provider_data(cache_file):
click.secho("Error parsing provider data. Invalid JSON format.", fg="red") click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None return None
def download_data(response): def download_data(response):
""" """
Downloads data from a given HTTP response and returns the JSON content. Downloads data from a given HTTP response and returns the JSON content.
Args: Args:
- response (requests.Response): The HTTP response object. - response (requests.Response): The HTTP response object.
Returns: Returns:
- dict: The JSON content of the response. - dict: The JSON content of the response.
""" """
total_size = int(response.headers.get('content-length', 0)) total_size = int(response.headers.get("content-length", 0))
block_size = 8192 block_size = 8192
data_chunks = [] data_chunks = []
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar: with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as progress_bar:
for chunk in response.iter_content(block_size): for chunk in response.iter_content(block_size):
if chunk: if chunk:
data_chunks.append(chunk) data_chunks.append(chunk)
progress_bar.update(len(chunk)) progress_bar.update(len(chunk))
data_content = b''.join(data_chunks) data_content = b"".join(data_chunks)
return json.loads(data_content.decode('utf-8')) return json.loads(data_content.decode("utf-8"))
def get_provider_data(): def get_provider_data():
""" """
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models. Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
Returns: Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails. - dict or None: A dictionary of providers mapped to their models or None if the operation fails.
""" """
cache_dir = Path.home() / '.crewai' cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True) cache_dir.mkdir(exist_ok=True)
cache_file = cache_dir / 'provider_cache.json' cache_file = cache_dir / "provider_cache.json"
cache_expiry = 24 * 3600 cache_expiry = 24 * 3600
data = load_provider_data(cache_file, cache_expiry) data = load_provider_data(cache_file, cache_expiry)
if not data: if not data:
@@ -179,8 +220,8 @@ def get_provider_data():
provider_models = defaultdict(list) provider_models = defaultdict(list)
for model_name, properties in data.items(): for model_name, properties in data.items():
provider = properties.get("litellm_provider", "").strip().lower() provider = properties.get("litellm_provider", "").strip().lower()
if 'http' in provider or provider == 'other': if "http" in provider or provider == "other":
continue continue
if provider: if provider:
provider_models[provider].append(model_name) provider_models[provider].append(model_name)
return provider_models return provider_models