mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
allow user to bypass api key entry + incorect number selected logic + ruff formatting
This commit is contained in:
@@ -1,8 +1,15 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import click
|
import click
|
||||||
from crewai.cli.utils import copy_template,load_env_vars, write_env_file
|
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
|
||||||
from crewai.cli.provider import get_provider_data, select_provider, select_model, PROVIDERS
|
from crewai.cli.provider import (
|
||||||
|
get_provider_data,
|
||||||
|
select_provider,
|
||||||
|
select_model,
|
||||||
|
PROVIDERS,
|
||||||
|
)
|
||||||
from crewai.cli.constants import ENV_VARS
|
from crewai.cli.constants import ENV_VARS
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
def create_folder_structure(name, parent_folder=None):
|
def create_folder_structure(name, parent_folder=None):
|
||||||
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
||||||
@@ -13,11 +20,19 @@ def create_folder_structure(name, parent_folder=None):
|
|||||||
else:
|
else:
|
||||||
folder_path = Path(folder_name)
|
folder_path = Path(folder_name)
|
||||||
|
|
||||||
click.secho(
|
if folder_path.exists():
|
||||||
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
|
if not click.confirm(
|
||||||
fg="green",
|
f"Folder {folder_name} already exists. Do you want to override it?"
|
||||||
bold=True,
|
):
|
||||||
)
|
click.secho("Operation cancelled.", fg="yellow")
|
||||||
|
sys.exit(0)
|
||||||
|
click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True)
|
||||||
|
else:
|
||||||
|
click.secho(
|
||||||
|
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
|
||||||
|
fg="green",
|
||||||
|
bold=True,
|
||||||
|
)
|
||||||
|
|
||||||
if not folder_path.exists():
|
if not folder_path.exists():
|
||||||
folder_path.mkdir(parents=True)
|
folder_path.mkdir(parents=True)
|
||||||
@@ -26,16 +41,10 @@ def create_folder_structure(name, parent_folder=None):
|
|||||||
(folder_path / "src" / folder_name).mkdir(parents=True)
|
(folder_path / "src" / folder_name).mkdir(parents=True)
|
||||||
(folder_path / "src" / folder_name / "tools").mkdir(parents=True)
|
(folder_path / "src" / folder_name / "tools").mkdir(parents=True)
|
||||||
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
|
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
|
||||||
else:
|
|
||||||
click.secho(
|
|
||||||
f"\tFolder {folder_name} already exists.",
|
|
||||||
fg="yellow",
|
|
||||||
)
|
|
||||||
|
|
||||||
return folder_path, folder_name, class_name
|
return folder_path, folder_name, class_name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def copy_template_files(folder_path, name, class_name, parent_folder):
|
def copy_template_files(folder_path, name, class_name, parent_folder):
|
||||||
package_dir = Path(__file__).parent
|
package_dir = Path(__file__).parent
|
||||||
templates_dir = package_dir / "templates" / "crew"
|
templates_dir = package_dir / "templates" / "crew"
|
||||||
@@ -54,7 +63,9 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
|
|||||||
dst_file = folder_path / file_name
|
dst_file = folder_path / file_name
|
||||||
copy_template(src_file, dst_file, name, class_name, folder_path.name)
|
copy_template(src_file, dst_file, name, class_name, folder_path.name)
|
||||||
|
|
||||||
src_folder = folder_path / "src" / folder_path.name if not parent_folder else folder_path
|
src_folder = (
|
||||||
|
folder_path / "src" / folder_path.name if not parent_folder else folder_path
|
||||||
|
)
|
||||||
|
|
||||||
for file_name in src_template_files:
|
for file_name in src_template_files:
|
||||||
src_file = templates_dir / file_name
|
src_file = templates_dir / file_name
|
||||||
@@ -72,39 +83,73 @@ def create_crew(name, parent_folder=None):
|
|||||||
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
|
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
|
||||||
env_vars = load_env_vars(folder_path)
|
env_vars = load_env_vars(folder_path)
|
||||||
|
|
||||||
|
existing_provider = None
|
||||||
|
for provider, env_keys in ENV_VARS.items():
|
||||||
|
if any(key in env_vars for key in env_keys):
|
||||||
|
existing_provider = provider
|
||||||
|
break
|
||||||
|
|
||||||
|
if existing_provider:
|
||||||
|
if not click.confirm(
|
||||||
|
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
|
||||||
|
):
|
||||||
|
click.secho("Keeping existing provider configuration.", fg="yellow")
|
||||||
|
return
|
||||||
|
|
||||||
provider_models = get_provider_data()
|
provider_models = get_provider_data()
|
||||||
if not provider_models:
|
if not provider_models:
|
||||||
return
|
return
|
||||||
|
|
||||||
selected_provider = select_provider(provider_models)
|
while True:
|
||||||
if not selected_provider:
|
selected_provider = select_provider(provider_models)
|
||||||
return
|
if selected_provider is None: # User typed 'q'
|
||||||
provider = selected_provider
|
click.secho("Exiting...", fg="yellow")
|
||||||
|
sys.exit(0)
|
||||||
selected_model = select_model(provider, provider_models)
|
if selected_provider: # Valid selection
|
||||||
if not selected_model:
|
break
|
||||||
return
|
click.secho(
|
||||||
model = selected_model
|
"No provider selected. Please try again or press 'q' to exit.", fg="red"
|
||||||
|
|
||||||
if provider in PROVIDERS:
|
|
||||||
api_key_var = ENV_VARS[provider][0]
|
|
||||||
else:
|
|
||||||
api_key_var = click.prompt(
|
|
||||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
|
||||||
type=str
|
|
||||||
)
|
)
|
||||||
|
|
||||||
api_key_value = click.prompt(
|
while True:
|
||||||
f"Enter your {provider.capitalize()} API key",
|
selected_model = select_model(selected_provider, provider_models)
|
||||||
type=str,
|
if selected_model is None: # User typed 'q'
|
||||||
hide_input=True
|
click.secho("Exiting...", fg="yellow")
|
||||||
|
sys.exit(0)
|
||||||
|
if selected_model: # Valid selection
|
||||||
|
break
|
||||||
|
click.secho(
|
||||||
|
"No model selected. Please try again or press 'q' to exit.", fg="red"
|
||||||
|
)
|
||||||
|
|
||||||
|
if selected_provider in PROVIDERS:
|
||||||
|
api_key_var = ENV_VARS[selected_provider][0]
|
||||||
|
else:
|
||||||
|
api_key_var = click.prompt(
|
||||||
|
f"Enter the environment variable name for your {selected_provider.capitalize()} API key",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key_value = ""
|
||||||
|
click.echo(
|
||||||
|
f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ",
|
||||||
|
nl=False,
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
api_key_value = input()
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
api_key_value = ""
|
||||||
|
|
||||||
env_vars = {api_key_var: api_key_value}
|
if api_key_value.strip():
|
||||||
write_env_file(folder_path, env_vars)
|
env_vars = {api_key_var: api_key_value}
|
||||||
|
write_env_file(folder_path, env_vars)
|
||||||
|
click.secho("API key saved to .env file", fg="green")
|
||||||
|
else:
|
||||||
|
click.secho("No API key provided. Skipping .env file creation.", fg="yellow")
|
||||||
|
|
||||||
env_vars['MODEL'] = model
|
env_vars["MODEL"] = selected_model
|
||||||
click.secho(f"Selected model: {model}", fg="green")
|
click.secho(f"Selected model: {selected_model}", fg="green")
|
||||||
|
|
||||||
package_dir = Path(__file__).parent
|
package_dir = Path(__file__).parent
|
||||||
templates_dir = package_dir / "templates" / "crew"
|
templates_dir = package_dir / "templates" / "crew"
|
||||||
|
|||||||
@@ -3,65 +3,83 @@ import time
|
|||||||
import requests
|
import requests
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import click
|
import click
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL
|
from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL
|
||||||
|
|
||||||
|
|
||||||
def select_choice(prompt_message, choices):
|
def select_choice(prompt_message, choices):
|
||||||
"""
|
"""
|
||||||
Presents a list of choices to the user and prompts them to select one.
|
Presents a list of choices to the user and prompts them to select one.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- prompt_message (str): The message to display to the user before presenting the choices.
|
- prompt_message (str): The message to display to the user before presenting the choices.
|
||||||
- choices (list): A list of options to present to the user.
|
- choices (list): A list of options to present to the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- str: The selected choice from the list, or None if the operation is aborted or an invalid selection is made.
|
- str: The selected choice from the list, or None if the user chooses to quit.
|
||||||
"""
|
"""
|
||||||
click.secho(prompt_message, fg="cyan")
|
click.secho(prompt_message, fg="cyan")
|
||||||
for idx, choice in enumerate(choices, start=1):
|
for idx, choice in enumerate(choices, start=1):
|
||||||
click.secho(f"{idx}. {choice}", fg="cyan")
|
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||||
try:
|
click.secho("q. Quit", fg="cyan")
|
||||||
selected_index = click.prompt("Enter the number of your choice", type=int) - 1
|
|
||||||
except click.exceptions.Abort:
|
while True:
|
||||||
click.secho("Operation aborted by the user.", fg="red")
|
choice = click.prompt(
|
||||||
return None
|
"Enter the number of your choice or 'q' to quit", type=str
|
||||||
if not (0 <= selected_index < len(choices)):
|
)
|
||||||
click.secho("Invalid selection.", fg="red")
|
|
||||||
return None
|
if choice.lower() == "q":
|
||||||
return choices[selected_index]
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
selected_index = int(choice) - 1
|
||||||
|
if 0 <= selected_index < len(choices):
|
||||||
|
return choices[selected_index]
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
click.secho(
|
||||||
|
"Invalid selection. Please select a number between 1 and 6 or 'q' to quit.",
|
||||||
|
fg="red",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def select_provider(provider_models):
|
def select_provider(provider_models):
|
||||||
"""
|
"""
|
||||||
Presents a list of providers to the user and prompts them to select one.
|
Presents a list of providers to the user and prompts them to select one.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- provider_models (dict): A dictionary of provider models.
|
- provider_models (dict): A dictionary of provider models.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- str: The selected provider, or None if the operation is aborted or an invalid selection is made.
|
- str: The selected provider
|
||||||
|
- None: If user explicitly quits
|
||||||
"""
|
"""
|
||||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||||
|
|
||||||
provider = select_choice("Select a provider to set up:", predefined_providers + ['other'])
|
provider = select_choice(
|
||||||
if not provider:
|
"Select a provider to set up:", predefined_providers + ["other"]
|
||||||
|
)
|
||||||
|
if provider is None: # User typed 'q'
|
||||||
return None
|
return None
|
||||||
provider = provider.lower()
|
|
||||||
|
|
||||||
if provider == 'other':
|
if provider == "other":
|
||||||
provider = select_choice("Select a provider from the full list:", all_providers)
|
provider = select_choice("Select a provider from the full list:", all_providers)
|
||||||
if not provider:
|
if provider is None: # User typed 'q'
|
||||||
return None
|
return None
|
||||||
return provider
|
|
||||||
|
return provider.lower() if provider else False
|
||||||
|
|
||||||
|
|
||||||
def select_model(provider, provider_models):
|
def select_model(provider, provider_models):
|
||||||
"""
|
"""
|
||||||
Presents a list of models for a given provider to the user and prompts them to select one.
|
Presents a list of models for a given provider to the user and prompts them to select one.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- provider (str): The provider for which to select a model.
|
- provider (str): The provider for which to select a model.
|
||||||
- provider_models (dict): A dictionary of provider models.
|
- provider_models (dict): A dictionary of provider models.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
|
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
|
||||||
"""
|
"""
|
||||||
@@ -76,37 +94,49 @@ def select_model(provider, provider_models):
|
|||||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models)
|
selected_model = select_choice(
|
||||||
|
f"Select a model to use for {provider.capitalize()}:", available_models
|
||||||
|
)
|
||||||
return selected_model
|
return selected_model
|
||||||
|
|
||||||
|
|
||||||
def load_provider_data(cache_file, cache_expiry):
|
def load_provider_data(cache_file, cache_expiry):
|
||||||
"""
|
"""
|
||||||
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
|
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- cache_file (Path): The path to the cache file.
|
- cache_file (Path): The path to the cache file.
|
||||||
- cache_expiry (int): The cache expiry time in seconds.
|
- cache_expiry (int): The cache expiry time in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict or None: The loaded provider data or None if the operation fails.
|
- dict or None: The loaded provider data or None if the operation fails.
|
||||||
"""
|
"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry:
|
if (
|
||||||
|
cache_file.exists()
|
||||||
|
and (current_time - cache_file.stat().st_mtime) < cache_expiry
|
||||||
|
):
|
||||||
data = read_cache_file(cache_file)
|
data = read_cache_file(cache_file)
|
||||||
if data:
|
if data:
|
||||||
return data
|
return data
|
||||||
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow")
|
click.secho(
|
||||||
|
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan")
|
click.secho(
|
||||||
|
"Cache expired or not found. Fetching provider data from the web...",
|
||||||
|
fg="cyan",
|
||||||
|
)
|
||||||
return fetch_provider_data(cache_file)
|
return fetch_provider_data(cache_file)
|
||||||
|
|
||||||
|
|
||||||
def read_cache_file(cache_file):
|
def read_cache_file(cache_file):
|
||||||
"""
|
"""
|
||||||
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
|
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- cache_file (Path): The path to the cache file.
|
- cache_file (Path): The path to the cache file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
|
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
|
||||||
"""
|
"""
|
||||||
@@ -116,13 +146,14 @@ def read_cache_file(cache_file):
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def fetch_provider_data(cache_file):
|
def fetch_provider_data(cache_file):
|
||||||
"""
|
"""
|
||||||
Fetches provider data from a specified URL and caches it to a file.
|
Fetches provider data from a specified URL and caches it to a file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- cache_file (Path): The path to the cache file.
|
- cache_file (Path): The path to the cache file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict or None: The fetched provider data or None if the operation fails.
|
- dict or None: The fetched provider data or None if the operation fails.
|
||||||
"""
|
"""
|
||||||
@@ -139,38 +170,42 @@ def fetch_provider_data(cache_file):
|
|||||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def download_data(response):
|
def download_data(response):
|
||||||
"""
|
"""
|
||||||
Downloads data from a given HTTP response and returns the JSON content.
|
Downloads data from a given HTTP response and returns the JSON content.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- response (requests.Response): The HTTP response object.
|
- response (requests.Response): The HTTP response object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict: The JSON content of the response.
|
- dict: The JSON content of the response.
|
||||||
"""
|
"""
|
||||||
total_size = int(response.headers.get('content-length', 0))
|
total_size = int(response.headers.get("content-length", 0))
|
||||||
block_size = 8192
|
block_size = 8192
|
||||||
data_chunks = []
|
data_chunks = []
|
||||||
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
|
with click.progressbar(
|
||||||
|
length=total_size, label="Downloading", show_pos=True
|
||||||
|
) as progress_bar:
|
||||||
for chunk in response.iter_content(block_size):
|
for chunk in response.iter_content(block_size):
|
||||||
if chunk:
|
if chunk:
|
||||||
data_chunks.append(chunk)
|
data_chunks.append(chunk)
|
||||||
progress_bar.update(len(chunk))
|
progress_bar.update(len(chunk))
|
||||||
data_content = b''.join(data_chunks)
|
data_content = b"".join(data_chunks)
|
||||||
return json.loads(data_content.decode('utf-8'))
|
return json.loads(data_content.decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
def get_provider_data():
|
def get_provider_data():
|
||||||
"""
|
"""
|
||||||
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
|
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
|
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
|
||||||
"""
|
"""
|
||||||
cache_dir = Path.home() / '.crewai'
|
cache_dir = Path.home() / ".crewai"
|
||||||
cache_dir.mkdir(exist_ok=True)
|
cache_dir.mkdir(exist_ok=True)
|
||||||
cache_file = cache_dir / 'provider_cache.json'
|
cache_file = cache_dir / "provider_cache.json"
|
||||||
cache_expiry = 24 * 3600
|
cache_expiry = 24 * 3600
|
||||||
|
|
||||||
data = load_provider_data(cache_file, cache_expiry)
|
data = load_provider_data(cache_file, cache_expiry)
|
||||||
if not data:
|
if not data:
|
||||||
@@ -179,8 +214,8 @@ def get_provider_data():
|
|||||||
provider_models = defaultdict(list)
|
provider_models = defaultdict(list)
|
||||||
for model_name, properties in data.items():
|
for model_name, properties in data.items():
|
||||||
provider = properties.get("litellm_provider", "").strip().lower()
|
provider = properties.get("litellm_provider", "").strip().lower()
|
||||||
if 'http' in provider or provider == 'other':
|
if "http" in provider or provider == "other":
|
||||||
continue
|
continue
|
||||||
if provider:
|
if provider:
|
||||||
provider_models[provider].append(model_name)
|
provider_models[provider].append(model_name)
|
||||||
return provider_models
|
return provider_models
|
||||||
|
|||||||
Reference in New Issue
Block a user