diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py new file mode 100644 index 000000000..9a0b36c39 --- /dev/null +++ b/src/crewai/cli/constants.py @@ -0,0 +1,19 @@ +ENV_VARS = { + 'openai': ['OPENAI_API_KEY'], + 'anthropic': ['ANTHROPIC_API_KEY'], + 'gemini': ['GEMINI_API_KEY'], + 'groq': ['GROQ_API_KEY'], + 'ollama': ['FAKE_KEY'], +} + +PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama'] + +MODELS = { + 'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'], + 'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'], + 'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'], + 'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'], + 'ollama': ['llama3.1', 'mixtral'], +} + +JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" \ No newline at end of file diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index 9c559b798..23088cb58 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -1,120 +1,8 @@ -from collections import defaultdict from pathlib import Path import click -import json -import requests -import time -from crewai.cli.utils import copy_template - -JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" - -PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama'] - -ENV_VARS = { - 'openai': ['OPENAI_API_KEY'], - 'anthropic': ['ANTHROPIC_API_KEY'], - 'gemini': ['GEMINI_API_KEY'], - 'groq': ['GROQ_API_KEY'], - 'ollama': ['FAKE_KEY'], -} - -MODELS = { - 'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'], - 'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'], - 'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'], - 'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'], - 'ollama': ['llama3.1', 'mixtral'], -} - -def load_provider_data(cache_file, cache_expiry): - current_time = time.time() - if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry: - try: - with open(cache_file, "r") as f: - return json.load(f) - except json.JSONDecodeError: - click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow") - else: - click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan") - - try: - response = requests.get(JSON_URL, stream=True, timeout=10) - response.raise_for_status() - total_size = int(response.headers.get('content-length', 0)) - block_size = 8192 - data_chunks = [] - with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar: - for chunk in response.iter_content(block_size): - if chunk: - data_chunks.append(chunk) - progress_bar.update(len(chunk)) - data_content = b''.join(data_chunks) - data = json.loads(data_content.decode('utf-8')) - with open(cache_file, "w") as f: - json.dump(data, f) - return data - except requests.RequestException as e: - click.secho(f"Error fetching provider data: {e}", fg="red") - return None - except json.JSONDecodeError: - click.secho("Error parsing provider data. Invalid JSON format.", fg="red") - return None - -def select_choice(prompt_message, choices): - click.secho(prompt_message, fg="cyan") - for idx, choice in enumerate(choices, start=1): - click.secho(f"{idx}. {choice}", fg="cyan") - try: - selected_index = click.prompt("Enter the number of your choice", type=int) - 1 - except click.exceptions.Abort: - click.secho("Operation aborted by the user.", fg="red") - return None - - if not (0 <= selected_index < len(choices)): - click.secho("Invalid selection.", fg="red") - return None - - return choices[selected_index] - -def select_provider(provider, all_providers, predefined_providers): - provider = provider.lower() if provider else None - - if provider and provider not in all_providers and provider != 'other': - click.secho(f"Invalid provider: {provider}", fg="red") - return None - - if not provider: - options = predefined_providers + ['other'] - provider = select_choice("Select a provider to set up:", options) - if not provider: - return None - - if provider == 'other': - if not all_providers: - click.secho("No additional providers available.", fg="yellow") - return None - provider = select_choice("Select a provider from the full list:", all_providers) - if not provider: - return None - - return provider.lower() - -def select_model(provider, predefined_providers, MODELS, provider_models): - provider = provider.lower() - if provider in predefined_providers: - available_models = MODELS.get(provider, []) - else: - available_models = provider_models.get(provider, []) - - if not available_models: - click.secho(f"No models available for provider '{provider}'.", fg="red") - click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow") - return None - - selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models) - if not selected_model: - return None - return selected_model +from crewai.cli.utils import copy_template,load_env_vars, write_env_file +from crewai.cli.provider import get_provider_data, select_provider, select_model, PROVIDERS +from crewai.cli.constants import ENV_VARS def create_folder_structure(name, parent_folder=None): folder_name = name.replace(" ", "_").replace("-", "_").lower() @@ -146,46 +34,59 @@ def create_folder_structure(name, parent_folder=None): return folder_path, folder_name, class_name + + +def copy_template_files(folder_path, name, class_name, parent_folder): + package_dir = Path(__file__).parent + templates_dir = package_dir / "templates" / "crew" + + root_template_files = ( + [".gitignore", "pyproject.toml", "README.md"] if not parent_folder else [] + ) + tools_template_files = ["tools/custom_tool.py", "tools/__init__.py"] + config_template_files = ["config/agents.yaml", "config/tasks.yaml"] + src_template_files = ( + ["__init__.py", "main.py", "crew.py"] if not parent_folder else ["crew.py"] + ) + + for file_name in root_template_files: + src_file = templates_dir / file_name + dst_file = folder_path / file_name + copy_template(src_file, dst_file, name, class_name, folder_path.name) + + src_folder = folder_path / "src" / folder_path.name if not parent_folder else folder_path + + for file_name in src_template_files: + src_file = templates_dir / file_name + dst_file = src_folder / file_name + copy_template(src_file, dst_file, name, class_name, folder_path.name) + + if not parent_folder: + for file_name in tools_template_files + config_template_files: + src_file = templates_dir / file_name + dst_file = src_folder / file_name + copy_template(src_file, dst_file, name, class_name, folder_path.name) + + def create_crew(name, parent_folder=None): folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) - - env_file_path = folder_path / ".env" + env_vars = load_env_vars(folder_path) - env_vars = {} - if env_file_path.exists(): - with open(env_file_path, "r") as file: - for line in file: - key_value = line.strip().split('=', 1) - if len(key_value) == 2: - env_vars[key_value[0]] = key_value[1] - - cache_dir = Path.home() / '.crewai' - cache_dir.mkdir(exist_ok=True) - cache_file = cache_dir / 'provider_cache.json' - cache_expiry = 24 * 3600 - - provider_models = get_provider_data(cache_file, cache_expiry) + provider_models = get_provider_data() if not provider_models: return - predefined_providers = [p.lower() for p in PROVIDERS] - provider_models = {k.lower(): v for k, v in provider_models.items()} - - all_providers = set(predefined_providers) - all_providers.update(provider_models.keys()) - all_providers = sorted(all_providers) - - selected_provider = select_provider(None, all_providers, predefined_providers) + selected_provider = select_provider(provider_models) if not selected_provider: return - provider = selected_provider.lower() + provider = selected_provider - selected_model = select_model(provider, predefined_providers, MODELS, provider_models) + selected_model = select_model(provider, provider_models) if not selected_model: return model = selected_model - if provider in predefined_providers: + if provider in PROVIDERS: api_key_var = ENV_VARS[provider][0] else: api_key_var = click.prompt( @@ -193,24 +94,12 @@ def create_crew(name, parent_folder=None): type=str ) - if api_key_var not in env_vars: - try: - env_vars[api_key_var] = click.prompt( - f"Enter your {provider.capitalize()} API key", type=str, hide_input=True - ) - except click.exceptions.Abort: - click.secho("Operation aborted by the user.", fg="red") - return - else: - click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow") + env_vars = {api_key_var: "YOUR_API_KEY_HERE"} + write_env_file(folder_path, env_vars) env_vars['MODEL'] = model click.secho(f"Selected model: {model}", fg="green") - with open(env_file_path, "w") as file: - for key, value in env_vars.items(): - file.write(f"{key}={value}\n") - package_dir = Path(__file__).parent templates_dir = package_dir / "templates" / "crew" @@ -242,19 +131,3 @@ def create_crew(name, parent_folder=None): copy_template(src_file, dst_file, name, class_name, folder_name) click.secho(f"Crew {name} created successfully!", fg="green", bold=True) - -def get_provider_data(cache_file, cache_expiry): - data = load_provider_data(cache_file, cache_expiry) - if not data: - return None - - provider_models = defaultdict(list) - for model_name, properties in data.items(): - provider_full = properties.get("litellm_provider") - if provider_full: - provider_key = provider_full.strip().lower() - if 'http' in provider_key: - continue - if provider_key and provider_key != 'other': - provider_models[provider_key].append(model_name) - return provider_models \ No newline at end of file diff --git a/src/crewai/cli/provider.py b/src/crewai/cli/provider.py new file mode 100644 index 000000000..6164de82b --- /dev/null +++ b/src/crewai/cli/provider.py @@ -0,0 +1,186 @@ +import json +import time +import requests +from collections import defaultdict +import click +from pathlib import Path # Add this import +from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL + +def select_choice(prompt_message, choices): + """ + Presents a list of choices to the user and prompts them to select one. + + Args: + - prompt_message (str): The message to display to the user before presenting the choices. + - choices (list): A list of options to present to the user. + + Returns: + - str: The selected choice from the list, or None if the operation is aborted or an invalid selection is made. + """ + click.secho(prompt_message, fg="cyan") + for idx, choice in enumerate(choices, start=1): + click.secho(f"{idx}. {choice}", fg="cyan") + try: + selected_index = click.prompt("Enter the number of your choice", type=int) - 1 + except click.exceptions.Abort: + click.secho("Operation aborted by the user.", fg="red") + return None + if not (0 <= selected_index < len(choices)): + click.secho("Invalid selection.", fg="red") + return None + return choices[selected_index] + +def select_provider(provider_models): + """ + Presents a list of providers to the user and prompts them to select one. + + Args: + - provider_models (dict): A dictionary of provider models. + + Returns: + - str: The selected provider, or None if the operation is aborted or an invalid selection is made. + """ + predefined_providers = [p.lower() for p in PROVIDERS] + all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) + + provider = select_choice("Select a provider to set up:", predefined_providers + ['other']) + if not provider: + return None + provider = provider.lower() + + if provider == 'other': + provider = select_choice("Select a provider from the full list:", all_providers) + if not provider: + return None + return provider + +def select_model(provider, provider_models): + """ + Presents a list of models for a given provider to the user and prompts them to select one. + + Args: + - provider (str): The provider for which to select a model. + - provider_models (dict): A dictionary of provider models. + + Returns: + - str: The selected model, or None if the operation is aborted or an invalid selection is made. + """ + predefined_providers = [p.lower() for p in PROVIDERS] + + if provider in predefined_providers: + available_models = MODELS.get(provider, []) + else: + available_models = provider_models.get(provider, []) + + if not available_models: + click.secho(f"No models available for provider '{provider}'.", fg="red") + return None + + selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models) + return selected_model + +def load_provider_data(cache_file, cache_expiry): + """ + Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web. + + Args: + - cache_file (Path): The path to the cache file. + - cache_expiry (int): The cache expiry time in seconds. + + Returns: + - dict or None: The loaded provider data or None if the operation fails. + """ + current_time = time.time() + if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry: + data = read_cache_file(cache_file) + if data: + return data + click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow") + else: + click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan") + return fetch_provider_data(cache_file) + +def read_cache_file(cache_file): + """ + Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON. + + Args: + - cache_file (Path): The path to the cache file. + + Returns: + - dict or None: The JSON content of the cache file or None if the JSON is invalid. + """ + try: + with open(cache_file, "r") as f: + return json.load(f) + except json.JSONDecodeError: + return None + +def fetch_provider_data(cache_file): + """ + Fetches provider data from a specified URL and caches it to a file. + + Args: + - cache_file (Path): The path to the cache file. + + Returns: + - dict or None: The fetched provider data or None if the operation fails. + """ + try: + response = requests.get(JSON_URL, stream=True, timeout=10) + response.raise_for_status() + data = download_data(response) + with open(cache_file, "w") as f: + json.dump(data, f) + return data + except requests.RequestException as e: + click.secho(f"Error fetching provider data: {e}", fg="red") + except json.JSONDecodeError: + click.secho("Error parsing provider data. Invalid JSON format.", fg="red") + return None + +def download_data(response): + """ + Downloads data from a given HTTP response and returns the JSON content. + + Args: + - response (requests.Response): The HTTP response object. + + Returns: + - dict: The JSON content of the response. + """ + total_size = int(response.headers.get('content-length', 0)) + block_size = 8192 + data_chunks = [] + with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar: + for chunk in response.iter_content(block_size): + if chunk: + data_chunks.append(chunk) + progress_bar.update(len(chunk)) + data_content = b''.join(data_chunks) + return json.loads(data_content.decode('utf-8')) + +def get_provider_data(): + """ + Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models. + + Returns: + - dict or None: A dictionary of providers mapped to their models or None if the operation fails. + """ + cache_dir = Path.home() / '.crewai' + cache_dir.mkdir(exist_ok=True) + cache_file = cache_dir / 'provider_cache.json' + cache_expiry = 24 * 3600 # Cache expiry time in seconds + + data = load_provider_data(cache_file, cache_expiry) + if not data: + return None + + provider_models = defaultdict(list) + for model_name, properties in data.items(): + provider = properties.get("litellm_provider", "").strip().lower() + if 'http' in provider or provider == 'other': + continue + if provider: + provider_models[provider].append(model_name) + return provider_models \ No newline at end of file diff --git a/src/crewai/cli/utils.py b/src/crewai/cli/utils.py index ba553f034..5e87f5ec4 100644 --- a/src/crewai/cli/utils.py +++ b/src/crewai/cli/utils.py @@ -5,6 +5,7 @@ import sys import importlib.metadata from crewai.cli.authentication.utils import TokenManager +from crewai.cli.constants import ENV_VARS from functools import reduce from rich.console import Console from typing import Any, Dict, List @@ -203,3 +204,69 @@ def tree_find_and_replace(directory, find, replace): new_dirpath = os.path.join(path, new_dirname) old_dirpath = os.path.join(path, dirname) os.rename(old_dirpath, new_dirpath) + + +def load_env_vars(folder_path): + """ + Loads environment variables from a .env file in the specified folder path. + + Args: + - folder_path (Path): The path to the folder containing the .env file. + + Returns: + - dict: A dictionary of environment variables. + """ + env_file_path = folder_path / ".env" + env_vars = {} + if env_file_path.exists(): + with open(env_file_path, "r") as file: + for line in file: + key, _, value = line.strip().partition('=') + if key and value: + env_vars[key] = value + return env_vars + +def update_env_vars(env_vars, provider, model): + """ + Updates environment variables with the API key for the selected provider and model. + + Args: + - env_vars (dict): Environment variables dictionary. + - provider (str): Selected provider. + - model (str): Selected model. + + Returns: + - None + """ + api_key_var = ENV_VARS.get(provider, [click.prompt( + f"Enter the environment variable name for your {provider.capitalize()} API key", + type=str + )])[0] + + if api_key_var not in env_vars: + try: + env_vars[api_key_var] = click.prompt( + f"Enter your {provider.capitalize()} API key", type=str, hide_input=True + ) + except click.exceptions.Abort: + click.secho("Operation aborted by the user.", fg="red") + return None + else: + click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow") + + env_vars['MODEL'] = model + click.secho(f"Selected model: {model}", fg="green") + return env_vars + +def write_env_file(folder_path, env_vars): + """ + Writes environment variables to a .env file in the specified folder. + + Args: + - folder_path (Path): The path to the folder where the .env file will be written. + - env_vars (dict): A dictionary of environment variables to write. + """ + env_file_path = folder_path / ".env" + with open(env_file_path, "w") as file: + for key, value in env_vars.items(): + file.write(f"{key}={value}\n")