From 1f9baf9b2c7bad1c5ffa96348a6ff0ab4e71219a Mon Sep 17 00:00:00 2001 From: Rip&Tear <84775494+theCyberTech@users.noreply.github.com> Date: Sun, 13 Oct 2024 00:04:05 +0800 Subject: [PATCH] feat: implement crew creation CLI command - refactor code to multiple functions - Added ability for users to select provider and model when using crewai create command and save API key to .env --- src/crewai/cli/create_crew.py | 355 +++++++++++++++------------------- 1 file changed, 159 insertions(+), 196 deletions(-) diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index a7818a73f..d1a362000 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -1,14 +1,14 @@ +from collections import defaultdict from pathlib import Path import click -import requests -from collections import defaultdict import json +import requests import time -from urllib.parse import urlparse # Added import - from crewai.cli.utils import copy_template -PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama'] +JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" + +PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama'] ENV_VARS = { 'openai': ['OPENAI_API_KEY'], @@ -19,17 +19,109 @@ ENV_VARS = { } MODELS = { - 'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini','o1-mini', 'o1-preview'], + 'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'], 'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'], 'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'], 'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'], 'ollama': ['llama3.1', 'mixtral'], } -def create_crew(name, parent_folder=None): - """Create a new crew.""" - provider = None - model = None +def load_provider_data(cache_file, cache_expiry): + current_time = 
time.time() + if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry: + try: + with open(cache_file, "r") as f: + return json.load(f) + except json.JSONDecodeError: + click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow") + else: + click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan") + + try: + response = requests.get(JSON_URL, stream=True, timeout=10) + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + block_size = 8192 + data_chunks = [] + with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar: + for chunk in response.iter_content(block_size): + if chunk: + data_chunks.append(chunk) + progress_bar.update(len(chunk)) + data_content = b''.join(data_chunks) + data = json.loads(data_content.decode('utf-8')) + with open(cache_file, "w") as f: + json.dump(data, f) + return data + except requests.RequestException as e: + click.secho(f"Error fetching provider data: {e}", fg="red") + return None + except json.JSONDecodeError: + click.secho("Error parsing provider data. Invalid JSON format.", fg="red") + return None + +def select_choice(prompt_message, choices): + click.secho(prompt_message, fg="cyan") + for idx, choice in enumerate(choices, start=1): + click.secho(f"{idx}. 
{choice}", fg="cyan") + try: + selected_index = click.prompt("Enter the number of your choice", type=int) - 1 + except click.exceptions.Abort: + click.secho("Operation aborted by the user.", fg="red") + return None + if 0 <= selected_index < len(choices): + return choices[selected_index] + else: + click.secho("Invalid selection.", fg="red") + return None + +def select_provider(provider, all_providers, PROVIDERS): + if provider and provider.lower() not in all_providers and provider.lower() != 'other': + click.secho(f"Invalid provider: {provider}", fg="red") + return None + + if not provider: + options = PROVIDERS + ['other'] + selected_provider = select_choice("Select a provider to set up:", options) + if not selected_provider: + return None + if selected_provider.lower() == 'other': + if not all_providers: + click.secho("No additional providers available.", fg="yellow") + return None + selected_provider = select_choice("Select a provider from the full list:", all_providers) + if not selected_provider: + return None + else: + selected_provider = provider.lower() + if selected_provider == 'other': + if not all_providers: + click.secho("No additional providers available.", fg="yellow") + return None + selected_provider = select_choice("Select a provider from the full list:", all_providers) + if not selected_provider: + return None + + return selected_provider.lower() + +def select_model(provider, predefined_providers, MODELS, provider_models): + provider = provider.lower() + if provider in predefined_providers: + available_models = MODELS.get(provider, []) + else: + available_models = provider_models.get(provider, []) + + if not available_models: + click.secho(f"No models available for provider '{provider}'.", fg="red") + click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow") + return None + + selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models) + if not selected_model: + return None + return 
selected_model + +def create_folder_structure(name, parent_folder=None): folder_name = name.replace(" ", "_").replace("-", "_").lower() class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "") @@ -53,9 +145,14 @@ def create_crew(name, parent_folder=None): (folder_path / "src" / folder_name / "config").mkdir(parents=True) else: click.secho( - f"\tFolder {folder_name} already exists. Updating .env file...", + f"\tFolder {folder_name} already exists.", fg="yellow", ) + + return folder_path, folder_name, class_name + +def create_crew(name, parent_folder=None): + folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) env_file_path = folder_path / ".env" @@ -67,203 +164,53 @@ def create_crew(name, parent_folder=None): if len(key_value) == 2: env_vars[key_value[0]] = key_value[1] - cache_dir = Path.home() / '.crewai' cache_dir.mkdir(exist_ok=True) cache_file = cache_dir / 'provider_cache.json' - cache_expiry = 24 * 3600 + cache_expiry = 24 * 3600 - json_url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" - current_time = time.time() - data = {} - - """ - Attempts to load provider data from the cache. If the cache is valid (i.e., not expired), it loads the data from the cache. - If the cache is invalid or corrupted, it fetches the provider data from the web. - """ - if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry: - click.secho("Loading provider data from cache...", fg="cyan") - try: - with open(cache_file, "r") as f: - data = json.load(f) - click.secho("Provider data loaded from cache successfully.", fg="green") - except json.JSONDecodeError: - click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow") - data = {} - else: - click.secho("Cache expired or not found. 
Fetching provider data from the web...", fg="cyan") - data = {} - - if not data: - try: - with requests.get(json_url, stream=True, timeout=10) as response: - response.raise_for_status() - total_size = response.headers.get('content-length') - total_size = int(total_size) if total_size else None - block_size = 8192 - data_chunks = [] - - with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar: - for chunk in response.iter_content(block_size): - if chunk: - data_chunks.append(chunk) - progress_bar.update(len(chunk)) - - data_content = b''.join(data_chunks) - data = json.loads(data_content.decode('utf-8')) - - with open(cache_file, "w") as f: - json.dump(data, f) - click.secho("Provider data fetched and cached successfully.", fg="green") - except requests.RequestException as e: - click.secho(f"Error fetching provider data: {e}", fg="red") - return - except json.JSONDecodeError: - click.secho("Error parsing provider data. Invalid JSON format.", fg="red") - return - - provider_models = defaultdict(list) - for model_name, properties in data.items(): - provider_full = properties.get("litellm_provider") - if provider_full: - provider_key = provider_full.strip().lower() - - if 'http' in provider_key: - click.secho(f"Skipping invalid provider entry: '{provider_full}'", fg="yellow") - continue - - if provider_key and provider_key != 'other': - provider_models[provider_key].append(model_name) - - predefined_providers = [p.lower() for p in PROVIDERS] - all_providers = set(predefined_providers) - all_providers.update(provider_models.keys()) - - all_providers = sorted(all_providers) - - if provider: - provider_lower = provider.lower() - if provider_lower == 'other': - all_providers = sorted(provider_models.keys()) - if not all_providers: - click.secho("No additional providers available.", fg="yellow") - return - click.secho("Select a provider from the full list:", fg="cyan") - for index, provider_name in enumerate(all_providers, start=1): - 
click.secho(f"{index}. {provider_name}", fg="cyan") - - while True: - try: - selected_index = click.prompt( - "Enter the number of your choice", type=int - ) - 1 - if 0 <= selected_index < len(all_providers): - provider = all_providers[selected_index] - break - else: - click.secho("Invalid selection. Please try again.", fg="red") - except click.exceptions.Abort: - click.secho("Operation aborted by the user.", fg="red") - return - else: - - if provider_lower not in provider_models and provider_lower not in [p.lower() for p in PROVIDERS]: - click.secho(f"Invalid provider: {provider}", fg="red") - return - else: - click.secho("Select a provider to set up:", fg="cyan") - for index, provider_name in enumerate(PROVIDERS + ['other'], start=1): - click.secho(f"{index}. {provider_name}", fg="cyan") - - while True: - try: - selected_index = click.prompt( - "Enter the number of your choice", type=int - ) - 1 - if 0 <= selected_index < len(PROVIDERS) + 1: - selected_provider = (PROVIDERS + ['other'])[selected_index] - if selected_provider.lower() == 'other': - if not all_providers: - click.secho("No additional providers available.", fg="yellow") - return - click.secho("Select a provider from the full list:", fg="cyan") - for idx, provider_name in enumerate(all_providers, start=1): - display_name = provider_name.capitalize() - click.secho(f"{idx}. {display_name}", fg="cyan") - - while True: - try: - selected_sub_index = click.prompt( - "Enter the number of your choice", type=int - ) - 1 - if 0 <= selected_sub_index < len(all_providers): - provider = all_providers[selected_sub_index] - break - else: - click.secho("Invalid selection. Please try again.", fg="red") - except click.exceptions.Abort: - click.secho("Operation aborted by the user.", fg="red") - return - break - else: - provider = selected_provider.lower() - break - else: - click.secho("Invalid selection. 
Please try again.", fg="red") - except click.exceptions.Abort: - click.secho("Operation aborted by the user.", fg="red") - return - - provider = provider.strip().lower() - - if provider in predefined_providers: - available_models = MODELS.get(provider, []) - else: - available_models = provider_models.get(provider, []) - - if not available_models: - click.secho(f"No models available for provider '{provider}'.", fg="red") - click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow") + provider_models = get_provider_data(cache_file, cache_expiry) + if not provider_models: return - if model: - if model not in available_models: - click.secho(f"Invalid model '{model}' for provider '{provider}'.", fg="red") - return - else: - click.secho(f"Select a model to use for {provider.capitalize()}:", fg="cyan") - for idx, model_name in enumerate(available_models, start=1): - click.secho(f"{idx}. {model_name}", fg="cyan") - - while True: - try: - selected_model_index = click.prompt( - "Enter the number of your choice", type=int - ) - 1 - if 0 <= selected_model_index < len(available_models): - model = available_models[selected_model_index] - break - else: - click.secho("Invalid selection. 
Please try again.", fg="red") - except click.exceptions.Abort: - click.secho("Operation aborted by the user.", fg="red") - return + predefined_providers = [p.lower() for p in PROVIDERS] + provider_models = {k.lower(): v for k, v in provider_models.items()} - if provider.lower() in ENV_VARS: - api_key_var = ENV_VARS[provider.lower()][0] + all_providers = set(predefined_providers) + all_providers.update(provider_models.keys()) + all_providers = sorted(all_providers) + + selected_provider = select_provider(None, all_providers, predefined_providers) + if not selected_provider: + return + provider = selected_provider.lower() + + selected_model = select_model(provider, predefined_providers, MODELS, provider_models) + if not selected_model: + return + model = selected_model + + if provider in predefined_providers: + api_key_var = ENV_VARS[provider][0] else: - api_key_var = f"{provider.upper()}_API_KEY" + api_key_var = click.prompt( + f"Enter the environment variable name for your {provider.capitalize()} API key", + type=str + ) if api_key_var not in env_vars: - env_vars[api_key_var] = click.prompt( - f"Enter your {provider} API key", type=str, hide_input=True - ) + try: + env_vars[api_key_var] = click.prompt( + f"Enter your {provider.capitalize()} API key", type=str, hide_input=True + ) + except click.exceptions.Abort: + click.secho("Operation aborted by the user.", fg="red") + return else: - click.secho(f"API key already exists for {provider}.", fg="yellow") + click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow") - if model: - env_vars['MODEL'] = model - click.secho(f"Selected model: {model}", fg="green") + env_vars['MODEL'] = model + click.secho(f"Selected model: {model}", fg="green") with open(env_file_path, "w") as file: for key, value in env_vars.items(): @@ -300,3 +247,19 @@ def create_crew(name, parent_folder=None): copy_template(src_file, dst_file, name, class_name, folder_name) click.secho(f"Crew {name} created successfully!", 
fg="green", bold=True) + +def get_provider_data(cache_file, cache_expiry): + data = load_provider_data(cache_file, cache_expiry) + if not data: + return None + + provider_models = defaultdict(list) + for model_name, properties in data.items(): + provider_full = properties.get("litellm_provider") + if provider_full: + provider_key = provider_full.strip().lower() + if 'http' in provider_key: + continue + if provider_key and provider_key != 'other': + provider_models[provider_key].append(model_name) + return provider_models \ No newline at end of file