From ed3edc5c43a6df57c16450b56eb1f4aa70f42654 Mon Sep 17 00:00:00 2001 From: Rip&Tear <84775494+theCyberTech@users.noreply.github.com> Date: Wed, 23 Oct 2024 21:41:14 +0800 Subject: [PATCH] fix/fixed missing API prompt + CLI docs update (#1464) * updated CLI to allow for submitting API keys * updated click prompt to remove default number * removed all unnecessary comments * feat: implement crew creation CLI command - refactor code to multiple functions - Added ability for users to select provider and model when uing crewai create command and ave API key to .env * refactered select_choice function for early return * refactored select_provider to have an ealry return * cleanup of comments * refactor/Move functions into utils file, added new provider file and migrated fucntions thre, new constants file + general function refactor * small comment cleanup * fix unnecessary deps * Added docs for new CLI provider + fixed missing API prompt * Minor doc updates * allow user to bypass api key entry + incorect number selected logic + ruff formatting * ruff updates * Fix spelling mistake --------- Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com> Co-authored-by: Brandon Hancock --- docs/concepts/cli.mdx | 33 +++++++- src/crewai/cli/create_crew.py | 120 ++++++++++++++++++++--------- src/crewai/cli/provider.py | 137 ++++++++++++++++++++++------------ 3 files changed, 206 insertions(+), 84 deletions(-) diff --git a/docs/concepts/cli.mdx b/docs/concepts/cli.mdx index 8297ee6aa..2afc6b56c 100644 --- a/docs/concepts/cli.mdx +++ b/docs/concepts/cli.mdx @@ -6,7 +6,7 @@ icon: terminal # CrewAI CLI Documentation -The CrewAI CLI provides a set of commands to interact with CrewAI, allowing you to create, train, run, and manage crews and pipelines. +The CrewAI CLI provides a set of commands to interact with CrewAI, allowing you to create, train, run, and manage crews & flows. ## Installation @@ -146,3 +146,34 @@ crewai run Make sure to run these commands from the directory where your CrewAI project is set up. Some commands may require additional configuration or setup within your project structure. + + +### 9. API Keys + +When running ```crewai create crew``` command, the CLI will first show you the top 5 most common LLM providers and ask you to select one. + +Once you've selected an LLM provider, you will be prompted for API keys. + +#### Initial API key providers + +The CLI will initially prompt for API keys for the following services: + +* OpenAI +* Groq +* Anthropic +* Google Gemini + +When you select a provider, the CLI will prompt you to enter your API key. + +#### Other Options + +If you select option 6, you will be able to select from a list of LiteLLM supported providers. + +When you select a provider, the CLI will prompt you to enter the Key name and the API key. + +See the following link for each provider's key name: + +* [LiteLLM Providers](https://docs.litellm.ai/docs/providers) + + + diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index f3a50f5f4..f336c3f52 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -1,8 +1,16 @@ +import sys from pathlib import Path + import click -from crewai.cli.utils import copy_template, load_env_vars, write_env_file -from crewai.cli.provider import get_provider_data, select_provider, PROVIDERS + from crewai.cli.constants import ENV_VARS +from crewai.cli.provider import ( + PROVIDERS, + get_provider_data, + select_model, + select_provider, +) +from crewai.cli.utils import copy_template, load_env_vars, write_env_file def create_folder_structure(name, parent_folder=None): @@ -14,11 +22,19 @@ def create_folder_structure(name, parent_folder=None): else: folder_path = Path(folder_name) - click.secho( - f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", - fg="green", - bold=True, - ) + if folder_path.exists(): + if not click.confirm( + f"Folder {folder_name} already exists. Do you want to override it?" + ): + click.secho("Operation cancelled.", fg="yellow") + sys.exit(0) + click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True) + else: + click.secho( + f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", + fg="green", + bold=True, + ) if not folder_path.exists(): folder_path.mkdir(parents=True) @@ -27,11 +43,6 @@ def create_folder_structure(name, parent_folder=None): (folder_path / "src" / folder_name).mkdir(parents=True) (folder_path / "src" / folder_name / "tools").mkdir(parents=True) (folder_path / "src" / folder_name / "config").mkdir(parents=True) - else: - click.secho( - f"\tFolder {folder_name} already exists.", - fg="yellow", - ) return folder_path, folder_name, class_name @@ -70,38 +81,77 @@ def copy_template_files(folder_path, name, class_name, parent_folder): copy_template(src_file, dst_file, name, class_name, folder_path.name) -def create_crew(name, provider=None, parent_folder=None): +def create_crew(name, parent_folder=None): folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) env_vars = load_env_vars(folder_path) - if not provider: - provider_models = get_provider_data() - if not provider_models: + existing_provider = None + for provider, env_keys in ENV_VARS.items(): + if any(key in env_vars for key in env_keys): + existing_provider = provider + break + + if existing_provider: + if not click.confirm( + f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?" + ): + click.secho("Keeping existing provider configuration.", fg="yellow") return + provider_models = get_provider_data() + if not provider_models: + return + + while True: selected_provider = select_provider(provider_models) - if not selected_provider: - return - provider = selected_provider - - # selected_model = select_model(provider, provider_models) - # if not selected_model: - # return - # model = selected_model - - if provider in PROVIDERS: - api_key_var = ENV_VARS[provider][0] - else: - api_key_var = click.prompt( - f"Enter the environment variable name for your {provider.capitalize()} API key", - type=str, + if selected_provider is None: # User typed 'q' + click.secho("Exiting...", fg="yellow") + sys.exit(0) + if selected_provider: # Valid selection + break + click.secho( + "No provider selected. Please try again or press 'q' to exit.", fg="red" ) - env_vars = {api_key_var: "YOUR_API_KEY_HERE"} - write_env_file(folder_path, env_vars) + while True: + selected_model = select_model(selected_provider, provider_models) + if selected_model is None: # User typed 'q' + click.secho("Exiting...", fg="yellow") + sys.exit(0) + if selected_model: # Valid selection + break + click.secho( + "No model selected. Please try again or press 'q' to exit.", fg="red" + ) - # env_vars['MODEL'] = model - # click.secho(f"Selected model: {model}", fg="green") + if selected_provider in PROVIDERS: + api_key_var = ENV_VARS[selected_provider][0] + else: + api_key_var = click.prompt( + f"Enter the environment variable name for your {selected_provider.capitalize()} API key", + type=str, + default="", + ) + + api_key_value = "" + click.echo( + f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ", + nl=False, + ) + try: + api_key_value = input() + except (KeyboardInterrupt, EOFError): + api_key_value = "" + + if api_key_value.strip(): + env_vars = {api_key_var: api_key_value} + write_env_file(folder_path, env_vars) + click.secho("API key saved to .env file", fg="green") + else: + click.secho("No API key provided. Skipping .env file creation.", fg="yellow") + + env_vars["MODEL"] = selected_model + click.secho(f"Selected model: {selected_model}", fg="green") package_dir = Path(__file__).parent templates_dir = package_dir / "templates" / "crew" diff --git a/src/crewai/cli/provider.py b/src/crewai/cli/provider.py index f829ca9fd..4bfeb9324 100644 --- a/src/crewai/cli/provider.py +++ b/src/crewai/cli/provider.py @@ -1,67 +1,91 @@ import json import time -import requests from collections import defaultdict +from pathlib import Path + import click -from pathlib import Path -from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL +import requests + +from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS + def select_choice(prompt_message, choices): """ Presents a list of choices to the user and prompts them to select one. - + Args: - prompt_message (str): The message to display to the user before presenting the choices. - choices (list): A list of options to present to the user. - + Returns: - - str: The selected choice from the list, or None if the operation is aborted or an invalid selection is made. + - str: The selected choice from the list, or None if the user chooses to quit. """ + + provider_models = get_provider_data() + if not provider_models: + return click.secho(prompt_message, fg="cyan") for idx, choice in enumerate(choices, start=1): click.secho(f"{idx}. {choice}", fg="cyan") - try: - selected_index = click.prompt("Enter the number of your choice", type=int) - 1 - except click.exceptions.Abort: - click.secho("Operation aborted by the user.", fg="red") - return None - if not (0 <= selected_index < len(choices)): - click.secho("Invalid selection.", fg="red") - return None - return choices[selected_index] + click.secho("q. Quit", fg="cyan") + + while True: + choice = click.prompt( + "Enter the number of your choice or 'q' to quit", type=str + ) + + if choice.lower() == "q": + return None + + try: + selected_index = int(choice) - 1 + if 0 <= selected_index < len(choices): + return choices[selected_index] + except ValueError: + pass + + click.secho( + "Invalid selection. Please select a number between 1 and 6 or 'q' to quit.", + fg="red", + ) + def select_provider(provider_models): """ Presents a list of providers to the user and prompts them to select one. - + Args: - provider_models (dict): A dictionary of provider models. - + Returns: - - str: The selected provider, or None if the operation is aborted or an invalid selection is made. + - str: The selected provider + - None: If user explicitly quits """ predefined_providers = [p.lower() for p in PROVIDERS] all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) - provider = select_choice("Select a provider to set up:", predefined_providers + ['other']) - if not provider: + provider = select_choice( + "Select a provider to set up:", predefined_providers + ["other"] + ) + if provider is None: # User typed 'q' return None - provider = provider.lower() - if provider == 'other': + if provider == "other": provider = select_choice("Select a provider from the full list:", all_providers) - if not provider: + if provider is None: # User typed 'q' return None - return provider + + return provider.lower() if provider else False + def select_model(provider, provider_models): """ Presents a list of models for a given provider to the user and prompts them to select one. - + Args: - provider (str): The provider for which to select a model. - provider_models (dict): A dictionary of provider models. - + Returns: - str: The selected model, or None if the operation is aborted or an invalid selection is made. """ @@ -76,37 +100,49 @@ def select_model(provider, provider_models): click.secho(f"No models available for provider '{provider}'.", fg="red") return None - selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models) + selected_model = select_choice( + f"Select a model to use for {provider.capitalize()}:", available_models + ) return selected_model + def load_provider_data(cache_file, cache_expiry): """ Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web. - + Args: - cache_file (Path): The path to the cache file. - cache_expiry (int): The cache expiry time in seconds. - + Returns: - dict or None: The loaded provider data or None if the operation fails. """ current_time = time.time() - if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry: + if ( + cache_file.exists() + and (current_time - cache_file.stat().st_mtime) < cache_expiry + ): data = read_cache_file(cache_file) if data: return data - click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow") + click.secho( + "Cache is corrupted. Fetching provider data from the web...", fg="yellow" + ) else: - click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan") + click.secho( + "Cache expired or not found. Fetching provider data from the web...", + fg="cyan", + ) return fetch_provider_data(cache_file) + def read_cache_file(cache_file): """ Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON. - + Args: - cache_file (Path): The path to the cache file. - + Returns: - dict or None: The JSON content of the cache file or None if the JSON is invalid. """ @@ -116,13 +152,14 @@ def read_cache_file(cache_file): except json.JSONDecodeError: return None + def fetch_provider_data(cache_file): """ Fetches provider data from a specified URL and caches it to a file. - + Args: - cache_file (Path): The path to the cache file. - + Returns: - dict or None: The fetched provider data or None if the operation fails. """ @@ -139,38 +176,42 @@ def fetch_provider_data(cache_file): click.secho("Error parsing provider data. Invalid JSON format.", fg="red") return None + def download_data(response): """ Downloads data from a given HTTP response and returns the JSON content. - + Args: - response (requests.Response): The HTTP response object. - + Returns: - dict: The JSON content of the response. """ - total_size = int(response.headers.get('content-length', 0)) + total_size = int(response.headers.get("content-length", 0)) block_size = 8192 data_chunks = [] - with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar: + with click.progressbar( + length=total_size, label="Downloading", show_pos=True + ) as progress_bar: for chunk in response.iter_content(block_size): if chunk: data_chunks.append(chunk) progress_bar.update(len(chunk)) - data_content = b''.join(data_chunks) - return json.loads(data_content.decode('utf-8')) + data_content = b"".join(data_chunks) + return json.loads(data_content.decode("utf-8")) + def get_provider_data(): """ Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models. - + Returns: - dict or None: A dictionary of providers mapped to their models or None if the operation fails. """ - cache_dir = Path.home() / '.crewai' + cache_dir = Path.home() / ".crewai" cache_dir.mkdir(exist_ok=True) - cache_file = cache_dir / 'provider_cache.json' - cache_expiry = 24 * 3600 + cache_file = cache_dir / "provider_cache.json" + cache_expiry = 24 * 3600 data = load_provider_data(cache_file, cache_expiry) if not data: @@ -179,8 +220,8 @@ def get_provider_data(): provider_models = defaultdict(list) for model_name, properties in data.items(): provider = properties.get("litellm_provider", "").strip().lower() - if 'http' in provider or provider == 'other': + if "http" in provider or provider == "other": continue if provider: provider_models[provider].append(model_name) - return provider_models \ No newline at end of file + return provider_models