From 098a4312ab669a440bf75ed8980a65ebbe1374eb Mon Sep 17 00:00:00 2001 From: Rip&Tear <84775494+theCyberTech@users.noreply.github.com> Date: Wed, 23 Oct 2024 17:17:05 +0800 Subject: [PATCH] allow user to bypass api key entry + incorect number selected logic + ruff formatting --- src/crewai/cli/create_crew.py | 121 ++++++++++++++++++++++---------- src/crewai/cli/provider.py | 127 ++++++++++++++++++++++------------ 2 files changed, 164 insertions(+), 84 deletions(-) diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index 5803f0ac1..1e506c3b7 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -1,8 +1,15 @@ from pathlib import Path import click -from crewai.cli.utils import copy_template,load_env_vars, write_env_file -from crewai.cli.provider import get_provider_data, select_provider, select_model, PROVIDERS +from crewai.cli.utils import copy_template, load_env_vars, write_env_file +from crewai.cli.provider import ( + get_provider_data, + select_provider, + select_model, + PROVIDERS, +) from crewai.cli.constants import ENV_VARS +import sys + def create_folder_structure(name, parent_folder=None): folder_name = name.replace(" ", "_").replace("-", "_").lower() @@ -13,11 +20,19 @@ def create_folder_structure(name, parent_folder=None): else: folder_path = Path(folder_name) - click.secho( - f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", - fg="green", - bold=True, - ) + if folder_path.exists(): + if not click.confirm( + f"Folder {folder_name} already exists. Do you want to override it?" + ): + click.secho("Operation cancelled.", fg="yellow") + sys.exit(0) + click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True) + else: + click.secho( + f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", + fg="green", + bold=True, + ) if not folder_path.exists(): folder_path.mkdir(parents=True) @@ -26,16 +41,10 @@ def create_folder_structure(name, parent_folder=None): (folder_path / "src" / folder_name).mkdir(parents=True) (folder_path / "src" / folder_name / "tools").mkdir(parents=True) (folder_path / "src" / folder_name / "config").mkdir(parents=True) - else: - click.secho( - f"\tFolder {folder_name} already exists.", - fg="yellow", - ) return folder_path, folder_name, class_name - def copy_template_files(folder_path, name, class_name, parent_folder): package_dir = Path(__file__).parent templates_dir = package_dir / "templates" / "crew" @@ -54,7 +63,9 @@ def copy_template_files(folder_path, name, class_name, parent_folder): dst_file = folder_path / file_name copy_template(src_file, dst_file, name, class_name, folder_path.name) - src_folder = folder_path / "src" / folder_path.name if not parent_folder else folder_path + src_folder = ( + folder_path / "src" / folder_path.name if not parent_folder else folder_path + ) for file_name in src_template_files: src_file = templates_dir / file_name @@ -72,39 +83,73 @@ def create_crew(name, parent_folder=None): folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) env_vars = load_env_vars(folder_path) + existing_provider = None + for provider, env_keys in ENV_VARS.items(): + if any(key in env_vars for key in env_keys): + existing_provider = provider + break + + if existing_provider: + if not click.confirm( + f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?" + ): + click.secho("Keeping existing provider configuration.", fg="yellow") + return + provider_models = get_provider_data() if not provider_models: return - selected_provider = select_provider(provider_models) - if not selected_provider: - return - provider = selected_provider - - selected_model = select_model(provider, provider_models) - if not selected_model: - return - model = selected_model - - if provider in PROVIDERS: - api_key_var = ENV_VARS[provider][0] - else: - api_key_var = click.prompt( - f"Enter the environment variable name for your {provider.capitalize()} API key", - type=str + while True: + selected_provider = select_provider(provider_models) + if selected_provider is None: # User typed 'q' + click.secho("Exiting...", fg="yellow") + sys.exit(0) + if selected_provider: # Valid selection + break + click.secho( + "No provider selected. Please try again or press 'q' to exit.", fg="red" ) - api_key_value = click.prompt( - f"Enter your {provider.capitalize()} API key", - type=str, - hide_input=True + while True: + selected_model = select_model(selected_provider, provider_models) + if selected_model is None: # User typed 'q' + click.secho("Exiting...", fg="yellow") + sys.exit(0) + if selected_model: # Valid selection + break + click.secho( + "No model selected. Please try again or press 'q' to exit.", fg="red" + ) + + if selected_provider in PROVIDERS: + api_key_var = ENV_VARS[selected_provider][0] + else: + api_key_var = click.prompt( + f"Enter the environment variable name for your {selected_provider.capitalize()} API key", + type=str, + default="", + ) + + api_key_value = "" + click.echo( + f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ", + nl=False, ) + try: + api_key_value = input() + except (KeyboardInterrupt, EOFError): + api_key_value = "" - env_vars = {api_key_var: api_key_value} - write_env_file(folder_path, env_vars) + if api_key_value.strip(): + env_vars = {api_key_var: api_key_value} + write_env_file(folder_path, env_vars) + click.secho("API key saved to .env file", fg="green") + else: + click.secho("No API key provided. Skipping .env file creation.", fg="yellow") - env_vars['MODEL'] = model - click.secho(f"Selected model: {model}", fg="green") + env_vars["MODEL"] = selected_model + click.secho(f"Selected model: {selected_model}", fg="green") package_dir = Path(__file__).parent templates_dir = package_dir / "templates" / "crew" diff --git a/src/crewai/cli/provider.py b/src/crewai/cli/provider.py index f829ca9fd..8d68398ab 100644 --- a/src/crewai/cli/provider.py +++ b/src/crewai/cli/provider.py @@ -3,65 +3,83 @@ import time import requests from collections import defaultdict import click -from pathlib import Path +from pathlib import Path from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL + def select_choice(prompt_message, choices): """ Presents a list of choices to the user and prompts them to select one. - + Args: - prompt_message (str): The message to display to the user before presenting the choices. - choices (list): A list of options to present to the user. - + Returns: - - str: The selected choice from the list, or None if the operation is aborted or an invalid selection is made. + - str: The selected choice from the list, or None if the user chooses to quit. """ click.secho(prompt_message, fg="cyan") for idx, choice in enumerate(choices, start=1): click.secho(f"{idx}. {choice}", fg="cyan") - try: - selected_index = click.prompt("Enter the number of your choice", type=int) - 1 - except click.exceptions.Abort: - click.secho("Operation aborted by the user.", fg="red") - return None - if not (0 <= selected_index < len(choices)): - click.secho("Invalid selection.", fg="red") - return None - return choices[selected_index] + click.secho("q. Quit", fg="cyan") + + while True: + choice = click.prompt( + "Enter the number of your choice or 'q' to quit", type=str + ) + + if choice.lower() == "q": + return None + + try: + selected_index = int(choice) - 1 + if 0 <= selected_index < len(choices): + return choices[selected_index] + except ValueError: + pass + + click.secho( + "Invalid selection. Please select a number between 1 and 6 or 'q' to quit.", + fg="red", + ) + def select_provider(provider_models): """ Presents a list of providers to the user and prompts them to select one. - + Args: - provider_models (dict): A dictionary of provider models. - + Returns: - - str: The selected provider, or None if the operation is aborted or an invalid selection is made. + - str: The selected provider + - None: If user explicitly quits """ predefined_providers = [p.lower() for p in PROVIDERS] all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) - provider = select_choice("Select a provider to set up:", predefined_providers + ['other']) - if not provider: + provider = select_choice( + "Select a provider to set up:", predefined_providers + ["other"] + ) + if provider is None: # User typed 'q' return None - provider = provider.lower() - if provider == 'other': + if provider == "other": provider = select_choice("Select a provider from the full list:", all_providers) - if not provider: + if provider is None: # User typed 'q' return None - return provider + + return provider.lower() if provider else False + def select_model(provider, provider_models): """ Presents a list of models for a given provider to the user and prompts them to select one. - + Args: - provider (str): The provider for which to select a model. - provider_models (dict): A dictionary of provider models. - + Returns: - str: The selected model, or None if the operation is aborted or an invalid selection is made. """ @@ -76,37 +94,49 @@ def select_model(provider, provider_models): click.secho(f"No models available for provider '{provider}'.", fg="red") return None - selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models) + selected_model = select_choice( + f"Select a model to use for {provider.capitalize()}:", available_models + ) return selected_model + def load_provider_data(cache_file, cache_expiry): """ Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web. - + Args: - cache_file (Path): The path to the cache file. - cache_expiry (int): The cache expiry time in seconds. - + Returns: - dict or None: The loaded provider data or None if the operation fails. """ current_time = time.time() - if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry: + if ( + cache_file.exists() + and (current_time - cache_file.stat().st_mtime) < cache_expiry + ): data = read_cache_file(cache_file) if data: return data - click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow") + click.secho( + "Cache is corrupted. Fetching provider data from the web...", fg="yellow" + ) else: - click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan") + click.secho( + "Cache expired or not found. Fetching provider data from the web...", + fg="cyan", + ) return fetch_provider_data(cache_file) + def read_cache_file(cache_file): """ Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON. - + Args: - cache_file (Path): The path to the cache file. - + Returns: - dict or None: The JSON content of the cache file or None if the JSON is invalid. """ @@ -116,13 +146,14 @@ def read_cache_file(cache_file): except json.JSONDecodeError: return None + def fetch_provider_data(cache_file): """ Fetches provider data from a specified URL and caches it to a file. - + Args: - cache_file (Path): The path to the cache file. - + Returns: - dict or None: The fetched provider data or None if the operation fails. """ @@ -139,38 +170,42 @@ def fetch_provider_data(cache_file): click.secho("Error parsing provider data. Invalid JSON format.", fg="red") return None + def download_data(response): """ Downloads data from a given HTTP response and returns the JSON content. - + Args: - response (requests.Response): The HTTP response object. - + Returns: - dict: The JSON content of the response. """ - total_size = int(response.headers.get('content-length', 0)) + total_size = int(response.headers.get("content-length", 0)) block_size = 8192 data_chunks = [] - with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar: + with click.progressbar( + length=total_size, label="Downloading", show_pos=True + ) as progress_bar: for chunk in response.iter_content(block_size): if chunk: data_chunks.append(chunk) progress_bar.update(len(chunk)) - data_content = b''.join(data_chunks) - return json.loads(data_content.decode('utf-8')) + data_content = b"".join(data_chunks) + return json.loads(data_content.decode("utf-8")) + def get_provider_data(): """ Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models. - + Returns: - dict or None: A dictionary of providers mapped to their models or None if the operation fails. """ - cache_dir = Path.home() / '.crewai' + cache_dir = Path.home() / ".crewai" cache_dir.mkdir(exist_ok=True) - cache_file = cache_dir / 'provider_cache.json' - cache_expiry = 24 * 3600 + cache_file = cache_dir / "provider_cache.json" + cache_expiry = 24 * 3600 data = load_provider_data(cache_file, cache_expiry) if not data: @@ -179,8 +214,8 @@ def get_provider_data(): provider_models = defaultdict(list) for model_name, properties in data.items(): provider = properties.get("litellm_provider", "").strip().lower() - if 'http' in provider or provider == 'other': + if "http" in provider or provider == "other": continue if provider: provider_models[provider].append(model_name) - return provider_models \ No newline at end of file + return provider_models