allow user to bypass api key entry + incorect number selected logic + ruff formatting

This commit is contained in:
Rip&Tear
2024-10-23 17:17:05 +08:00
parent c724c0af70
commit 098a4312ab
2 changed files with 164 additions and 84 deletions

View File

@@ -1,8 +1,15 @@
from pathlib import Path from pathlib import Path
import click import click
from crewai.cli.utils import copy_template,load_env_vars, write_env_file from crewai.cli.utils import copy_template, load_env_vars, write_env_file
from crewai.cli.provider import get_provider_data, select_provider, select_model, PROVIDERS from crewai.cli.provider import (
get_provider_data,
select_provider,
select_model,
PROVIDERS,
)
from crewai.cli.constants import ENV_VARS from crewai.cli.constants import ENV_VARS
import sys
def create_folder_structure(name, parent_folder=None): def create_folder_structure(name, parent_folder=None):
folder_name = name.replace(" ", "_").replace("-", "_").lower() folder_name = name.replace(" ", "_").replace("-", "_").lower()
@@ -13,11 +20,19 @@ def create_folder_structure(name, parent_folder=None):
else: else:
folder_path = Path(folder_name) folder_path = Path(folder_name)
click.secho( if folder_path.exists():
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...", if not click.confirm(
fg="green", f"Folder {folder_name} already exists. Do you want to override it?"
bold=True, ):
) click.secho("Operation cancelled.", fg="yellow")
sys.exit(0)
click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True)
else:
click.secho(
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
fg="green",
bold=True,
)
if not folder_path.exists(): if not folder_path.exists():
folder_path.mkdir(parents=True) folder_path.mkdir(parents=True)
@@ -26,16 +41,10 @@ def create_folder_structure(name, parent_folder=None):
(folder_path / "src" / folder_name).mkdir(parents=True) (folder_path / "src" / folder_name).mkdir(parents=True)
(folder_path / "src" / folder_name / "tools").mkdir(parents=True) (folder_path / "src" / folder_name / "tools").mkdir(parents=True)
(folder_path / "src" / folder_name / "config").mkdir(parents=True) (folder_path / "src" / folder_name / "config").mkdir(parents=True)
else:
click.secho(
f"\tFolder {folder_name} already exists.",
fg="yellow",
)
return folder_path, folder_name, class_name return folder_path, folder_name, class_name
def copy_template_files(folder_path, name, class_name, parent_folder): def copy_template_files(folder_path, name, class_name, parent_folder):
package_dir = Path(__file__).parent package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew" templates_dir = package_dir / "templates" / "crew"
@@ -54,7 +63,9 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
dst_file = folder_path / file_name dst_file = folder_path / file_name
copy_template(src_file, dst_file, name, class_name, folder_path.name) copy_template(src_file, dst_file, name, class_name, folder_path.name)
src_folder = folder_path / "src" / folder_path.name if not parent_folder else folder_path src_folder = (
folder_path / "src" / folder_path.name if not parent_folder else folder_path
)
for file_name in src_template_files: for file_name in src_template_files:
src_file = templates_dir / file_name src_file = templates_dir / file_name
@@ -72,39 +83,73 @@ def create_crew(name, parent_folder=None):
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path) env_vars = load_env_vars(folder_path)
existing_provider = None
for provider, env_keys in ENV_VARS.items():
if any(key in env_vars for key in env_keys):
existing_provider = provider
break
if existing_provider:
if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
):
click.secho("Keeping existing provider configuration.", fg="yellow")
return
provider_models = get_provider_data() provider_models = get_provider_data()
if not provider_models: if not provider_models:
return return
selected_provider = select_provider(provider_models) while True:
if not selected_provider: selected_provider = select_provider(provider_models)
return if selected_provider is None: # User typed 'q'
provider = selected_provider click.secho("Exiting...", fg="yellow")
sys.exit(0)
selected_model = select_model(provider, provider_models) if selected_provider: # Valid selection
if not selected_model: break
return click.secho(
model = selected_model "No provider selected. Please try again or press 'q' to exit.", fg="red"
if provider in PROVIDERS:
api_key_var = ENV_VARS[provider][0]
else:
api_key_var = click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str
) )
api_key_value = click.prompt( while True:
f"Enter your {provider.capitalize()} API key", selected_model = select_model(selected_provider, provider_models)
type=str, if selected_model is None: # User typed 'q'
hide_input=True click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_model: # Valid selection
break
click.secho(
"No model selected. Please try again or press 'q' to exit.", fg="red"
)
if selected_provider in PROVIDERS:
api_key_var = ENV_VARS[selected_provider][0]
else:
api_key_var = click.prompt(
f"Enter the environment variable name for your {selected_provider.capitalize()} API key",
type=str,
default="",
)
api_key_value = ""
click.echo(
f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ",
nl=False,
) )
try:
api_key_value = input()
except (KeyboardInterrupt, EOFError):
api_key_value = ""
env_vars = {api_key_var: api_key_value} if api_key_value.strip():
write_env_file(folder_path, env_vars) env_vars = {api_key_var: api_key_value}
write_env_file(folder_path, env_vars)
click.secho("API key saved to .env file", fg="green")
else:
click.secho("No API key provided. Skipping .env file creation.", fg="yellow")
env_vars['MODEL'] = model env_vars["MODEL"] = selected_model
click.secho(f"Selected model: {model}", fg="green") click.secho(f"Selected model: {selected_model}", fg="green")
package_dir = Path(__file__).parent package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew" templates_dir = package_dir / "templates" / "crew"

View File

@@ -3,65 +3,83 @@ import time
import requests import requests
from collections import defaultdict from collections import defaultdict
import click import click
from pathlib import Path from pathlib import Path
from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL
def select_choice(prompt_message, choices): def select_choice(prompt_message, choices):
""" """
Presents a list of choices to the user and prompts them to select one. Presents a list of choices to the user and prompts them to select one.
Args: Args:
- prompt_message (str): The message to display to the user before presenting the choices. - prompt_message (str): The message to display to the user before presenting the choices.
- choices (list): A list of options to present to the user. - choices (list): A list of options to present to the user.
Returns: Returns:
- str: The selected choice from the list, or None if the operation is aborted or an invalid selection is made. - str: The selected choice from the list, or None if the user chooses to quit.
""" """
click.secho(prompt_message, fg="cyan") click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1): for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan") click.secho(f"{idx}. {choice}", fg="cyan")
try: click.secho("q. Quit", fg="cyan")
selected_index = click.prompt("Enter the number of your choice", type=int) - 1
except click.exceptions.Abort: while True:
click.secho("Operation aborted by the user.", fg="red") choice = click.prompt(
return None "Enter the number of your choice or 'q' to quit", type=str
if not (0 <= selected_index < len(choices)): )
click.secho("Invalid selection.", fg="red")
return None if choice.lower() == "q":
return choices[selected_index] return None
try:
selected_index = int(choice) - 1
if 0 <= selected_index < len(choices):
return choices[selected_index]
except ValueError:
pass
click.secho(
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
fg="red",
)
def select_provider(provider_models): def select_provider(provider_models):
""" """
Presents a list of providers to the user and prompts them to select one. Presents a list of providers to the user and prompts them to select one.
Args: Args:
- provider_models (dict): A dictionary of provider models. - provider_models (dict): A dictionary of provider models.
Returns: Returns:
- str: The selected provider, or None if the operation is aborted or an invalid selection is made. - str: The selected provider
- None: If user explicitly quits
""" """
predefined_providers = [p.lower() for p in PROVIDERS] predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice("Select a provider to set up:", predefined_providers + ['other']) provider = select_choice(
if not provider: "Select a provider to set up:", predefined_providers + ["other"]
)
if provider is None: # User typed 'q'
return None return None
provider = provider.lower()
if provider == 'other': if provider == "other":
provider = select_choice("Select a provider from the full list:", all_providers) provider = select_choice("Select a provider from the full list:", all_providers)
if not provider: if provider is None: # User typed 'q'
return None return None
return provider
return provider.lower() if provider else False
def select_model(provider, provider_models): def select_model(provider, provider_models):
""" """
Presents a list of models for a given provider to the user and prompts them to select one. Presents a list of models for a given provider to the user and prompts them to select one.
Args: Args:
- provider (str): The provider for which to select a model. - provider (str): The provider for which to select a model.
- provider_models (dict): A dictionary of provider models. - provider_models (dict): A dictionary of provider models.
Returns: Returns:
- str: The selected model, or None if the operation is aborted or an invalid selection is made. - str: The selected model, or None if the operation is aborted or an invalid selection is made.
""" """
@@ -76,37 +94,49 @@ def select_model(provider, provider_models):
click.secho(f"No models available for provider '{provider}'.", fg="red") click.secho(f"No models available for provider '{provider}'.", fg="red")
return None return None
selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models) selected_model = select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
)
return selected_model return selected_model
def load_provider_data(cache_file, cache_expiry): def load_provider_data(cache_file, cache_expiry):
""" """
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web. Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
Args: Args:
- cache_file (Path): The path to the cache file. - cache_file (Path): The path to the cache file.
- cache_expiry (int): The cache expiry time in seconds. - cache_expiry (int): The cache expiry time in seconds.
Returns: Returns:
- dict or None: The loaded provider data or None if the operation fails. - dict or None: The loaded provider data or None if the operation fails.
""" """
current_time = time.time() current_time = time.time()
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry: if (
cache_file.exists()
and (current_time - cache_file.stat().st_mtime) < cache_expiry
):
data = read_cache_file(cache_file) data = read_cache_file(cache_file)
if data: if data:
return data return data
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow") click.secho(
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
)
else: else:
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan") click.secho(
"Cache expired or not found. Fetching provider data from the web...",
fg="cyan",
)
return fetch_provider_data(cache_file) return fetch_provider_data(cache_file)
def read_cache_file(cache_file): def read_cache_file(cache_file):
""" """
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON. Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
Args: Args:
- cache_file (Path): The path to the cache file. - cache_file (Path): The path to the cache file.
Returns: Returns:
- dict or None: The JSON content of the cache file or None if the JSON is invalid. - dict or None: The JSON content of the cache file or None if the JSON is invalid.
""" """
@@ -116,13 +146,14 @@ def read_cache_file(cache_file):
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
def fetch_provider_data(cache_file): def fetch_provider_data(cache_file):
""" """
Fetches provider data from a specified URL and caches it to a file. Fetches provider data from a specified URL and caches it to a file.
Args: Args:
- cache_file (Path): The path to the cache file. - cache_file (Path): The path to the cache file.
Returns: Returns:
- dict or None: The fetched provider data or None if the operation fails. - dict or None: The fetched provider data or None if the operation fails.
""" """
@@ -139,38 +170,42 @@ def fetch_provider_data(cache_file):
click.secho("Error parsing provider data. Invalid JSON format.", fg="red") click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None return None
def download_data(response): def download_data(response):
""" """
Downloads data from a given HTTP response and returns the JSON content. Downloads data from a given HTTP response and returns the JSON content.
Args: Args:
- response (requests.Response): The HTTP response object. - response (requests.Response): The HTTP response object.
Returns: Returns:
- dict: The JSON content of the response. - dict: The JSON content of the response.
""" """
total_size = int(response.headers.get('content-length', 0)) total_size = int(response.headers.get("content-length", 0))
block_size = 8192 block_size = 8192
data_chunks = [] data_chunks = []
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar: with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as progress_bar:
for chunk in response.iter_content(block_size): for chunk in response.iter_content(block_size):
if chunk: if chunk:
data_chunks.append(chunk) data_chunks.append(chunk)
progress_bar.update(len(chunk)) progress_bar.update(len(chunk))
data_content = b''.join(data_chunks) data_content = b"".join(data_chunks)
return json.loads(data_content.decode('utf-8')) return json.loads(data_content.decode("utf-8"))
def get_provider_data(): def get_provider_data():
""" """
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models. Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
Returns: Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails. - dict or None: A dictionary of providers mapped to their models or None if the operation fails.
""" """
cache_dir = Path.home() / '.crewai' cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True) cache_dir.mkdir(exist_ok=True)
cache_file = cache_dir / 'provider_cache.json' cache_file = cache_dir / "provider_cache.json"
cache_expiry = 24 * 3600 cache_expiry = 24 * 3600
data = load_provider_data(cache_file, cache_expiry) data = load_provider_data(cache_file, cache_expiry)
if not data: if not data:
@@ -179,8 +214,8 @@ def get_provider_data():
provider_models = defaultdict(list) provider_models = defaultdict(list)
for model_name, properties in data.items(): for model_name, properties in data.items():
provider = properties.get("litellm_provider", "").strip().lower() provider = properties.get("litellm_provider", "").strip().lower()
if 'http' in provider or provider == 'other': if "http" in provider or provider == "other":
continue continue
if provider: if provider:
provider_models[provider].append(model_name) provider_models[provider].append(model_name)
return provider_models return provider_models