feat: implement crew creation CLI command

- refactor code to multiple functions
- Added ability for users to select provider and model when uing crewai create command and ave API key to .env
This commit is contained in:
Rip&Tear
2024-10-13 00:04:05 +08:00
parent 6fbc97b298
commit 1f9baf9b2c

View File

@@ -1,14 +1,14 @@
from collections import defaultdict
from pathlib import Path
import click
import requests
from collections import defaultdict
import json
import requests
import time
from urllib.parse import urlparse # Added import
from crewai.cli.utils import copy_template
PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama']
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama']
ENV_VARS = {
'openai': ['OPENAI_API_KEY'],
@@ -19,17 +19,109 @@ ENV_VARS = {
}
MODELS = {
'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini','o1-mini', 'o1-preview'],
'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'],
'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'],
'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'],
'ollama': ['llama3.1', 'mixtral'],
}
def create_crew(name, parent_folder=None):
"""Create a new crew."""
provider = None
model = None
def load_provider_data(cache_file, cache_expiry):
current_time = time.time()
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry:
try:
with open(cache_file, "r") as f:
return json.load(f)
except json.JSONDecodeError:
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow")
else:
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan")
try:
response = requests.get(JSON_URL, stream=True, timeout=10)
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
block_size = 8192
data_chunks = []
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)
progress_bar.update(len(chunk))
data_content = b''.join(data_chunks)
data = json.loads(data_content.decode('utf-8'))
with open(cache_file, "w") as f:
json.dump(data, f)
return data
except requests.RequestException as e:
click.secho(f"Error fetching provider data: {e}", fg="red")
return None
except json.JSONDecodeError:
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return None
def select_choice(prompt_message, choices):
click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan")
try:
selected_index = click.prompt("Enter the number of your choice", type=int) - 1
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return None
if 0 <= selected_index < len(choices):
return choices[selected_index]
else:
click.secho("Invalid selection.", fg="red")
return None
def select_provider(provider, all_providers, PROVIDERS):
if provider and provider.lower() not in all_providers and provider.lower() != 'other':
click.secho(f"Invalid provider: {provider}", fg="red")
return None
if not provider:
options = PROVIDERS + ['other']
selected_provider = select_choice("Select a provider to set up:", options)
if not selected_provider:
return None
if selected_provider.lower() == 'other':
if not all_providers:
click.secho("No additional providers available.", fg="yellow")
return None
selected_provider = select_choice("Select a provider from the full list:", all_providers)
if not selected_provider:
return None
else:
selected_provider = provider.lower()
if selected_provider == 'other':
if not all_providers:
click.secho("No additional providers available.", fg="yellow")
return None
selected_provider = select_choice("Select a provider from the full list:", all_providers)
if not selected_provider:
return None
return selected_provider.lower()
def select_model(provider, predefined_providers, MODELS, provider_models):
provider = provider.lower()
if provider in predefined_providers:
available_models = MODELS.get(provider, [])
else:
available_models = provider_models.get(provider, [])
if not available_models:
click.secho(f"No models available for provider '{provider}'.", fg="red")
click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow")
return None
selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models)
if not selected_model:
return None
return selected_model
def create_folder_structure(name, parent_folder=None):
folder_name = name.replace(" ", "_").replace("-", "_").lower()
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
@@ -53,9 +145,14 @@ def create_crew(name, parent_folder=None):
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
else:
click.secho(
f"\tFolder {folder_name} already exists. Updating .env file...",
f"\tFolder {folder_name} already exists.",
fg="yellow",
)
return folder_path, folder_name, class_name
def create_crew(name, parent_folder=None):
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_file_path = folder_path / ".env"
@@ -67,203 +164,53 @@ def create_crew(name, parent_folder=None):
if len(key_value) == 2:
env_vars[key_value[0]] = key_value[1]
cache_dir = Path.home() / '.crewai'
cache_dir.mkdir(exist_ok=True)
cache_file = cache_dir / 'provider_cache.json'
cache_expiry = 24 * 3600
cache_expiry = 24 * 3600
json_url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
current_time = time.time()
data = {}
"""
Attempts to load provider data from the cache. If the cache is valid (i.e., not expired), it loads the data from the cache.
If the cache is invalid or corrupted, it fetches the provider data from the web.
"""
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry:
click.secho("Loading provider data from cache...", fg="cyan")
try:
with open(cache_file, "r") as f:
data = json.load(f)
click.secho("Provider data loaded from cache successfully.", fg="green")
except json.JSONDecodeError:
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow")
data = {}
else:
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan")
data = {}
if not data:
try:
with requests.get(json_url, stream=True, timeout=10) as response:
response.raise_for_status()
total_size = response.headers.get('content-length')
total_size = int(total_size) if total_size else None
block_size = 8192
data_chunks = []
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)
progress_bar.update(len(chunk))
data_content = b''.join(data_chunks)
data = json.loads(data_content.decode('utf-8'))
with open(cache_file, "w") as f:
json.dump(data, f)
click.secho("Provider data fetched and cached successfully.", fg="green")
except requests.RequestException as e:
click.secho(f"Error fetching provider data: {e}", fg="red")
return
except json.JSONDecodeError:
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return
provider_models = defaultdict(list)
for model_name, properties in data.items():
provider_full = properties.get("litellm_provider")
if provider_full:
provider_key = provider_full.strip().lower()
if 'http' in provider_key:
click.secho(f"Skipping invalid provider entry: '{provider_full}'", fg="yellow")
continue
if provider_key and provider_key != 'other':
provider_models[provider_key].append(model_name)
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = set(predefined_providers)
all_providers.update(provider_models.keys())
all_providers = sorted(all_providers)
if provider:
provider_lower = provider.lower()
if provider_lower == 'other':
all_providers = sorted(provider_models.keys())
if not all_providers:
click.secho("No additional providers available.", fg="yellow")
return
click.secho("Select a provider from the full list:", fg="cyan")
for index, provider_name in enumerate(all_providers, start=1):
click.secho(f"{index}. {provider_name}", fg="cyan")
while True:
try:
selected_index = click.prompt(
"Enter the number of your choice", type=int
) - 1
if 0 <= selected_index < len(all_providers):
provider = all_providers[selected_index]
break
else:
click.secho("Invalid selection. Please try again.", fg="red")
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return
else:
if provider_lower not in provider_models and provider_lower not in [p.lower() for p in PROVIDERS]:
click.secho(f"Invalid provider: {provider}", fg="red")
return
else:
click.secho("Select a provider to set up:", fg="cyan")
for index, provider_name in enumerate(PROVIDERS + ['other'], start=1):
click.secho(f"{index}. {provider_name}", fg="cyan")
while True:
try:
selected_index = click.prompt(
"Enter the number of your choice", type=int
) - 1
if 0 <= selected_index < len(PROVIDERS) + 1:
selected_provider = (PROVIDERS + ['other'])[selected_index]
if selected_provider.lower() == 'other':
if not all_providers:
click.secho("No additional providers available.", fg="yellow")
return
click.secho("Select a provider from the full list:", fg="cyan")
for idx, provider_name in enumerate(all_providers, start=1):
display_name = provider_name.capitalize()
click.secho(f"{idx}. {display_name}", fg="cyan")
while True:
try:
selected_sub_index = click.prompt(
"Enter the number of your choice", type=int
) - 1
if 0 <= selected_sub_index < len(all_providers):
provider = all_providers[selected_sub_index]
break
else:
click.secho("Invalid selection. Please try again.", fg="red")
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return
break
else:
provider = selected_provider.lower()
break
else:
click.secho("Invalid selection. Please try again.", fg="red")
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return
provider = provider.strip().lower()
if provider in predefined_providers:
available_models = MODELS.get(provider, [])
else:
available_models = provider_models.get(provider, [])
if not available_models:
click.secho(f"No models available for provider '{provider}'.", fg="red")
click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow")
provider_models = get_provider_data(cache_file, cache_expiry)
if not provider_models:
return
if model:
if model not in available_models:
click.secho(f"Invalid model '{model}' for provider '{provider}'.", fg="red")
return
else:
click.secho(f"Select a model to use for {provider.capitalize()}:", fg="cyan")
for idx, model_name in enumerate(available_models, start=1):
click.secho(f"{idx}. {model_name}", fg="cyan")
while True:
try:
selected_model_index = click.prompt(
"Enter the number of your choice", type=int
) - 1
if 0 <= selected_model_index < len(available_models):
model = available_models[selected_model_index]
break
else:
click.secho("Invalid selection. Please try again.", fg="red")
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return
predefined_providers = [p.lower() for p in PROVIDERS]
provider_models = {k.lower(): v for k, v in provider_models.items()}
if provider.lower() in ENV_VARS:
api_key_var = ENV_VARS[provider.lower()][0]
all_providers = set(predefined_providers)
all_providers.update(provider_models.keys())
all_providers = sorted(all_providers)
selected_provider = select_provider(None, all_providers, predefined_providers)
if not selected_provider:
return
provider = selected_provider.lower()
selected_model = select_model(provider, predefined_providers, MODELS, provider_models)
if not selected_model:
return
model = selected_model
if provider in predefined_providers:
api_key_var = ENV_VARS[provider][0]
else:
api_key_var = f"{provider.upper()}_API_KEY"
api_key_var = click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str
)
if api_key_var not in env_vars:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider} API key", type=str, hide_input=True
)
try:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
)
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return
else:
click.secho(f"API key already exists for {provider}.", fg="yellow")
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
if model:
env_vars['MODEL'] = model
click.secho(f"Selected model: {model}", fg="green")
env_vars['MODEL'] = model
click.secho(f"Selected model: {model}", fg="green")
with open(env_file_path, "w") as file:
for key, value in env_vars.items():
@@ -300,3 +247,19 @@ def create_crew(name, parent_folder=None):
copy_template(src_file, dst_file, name, class_name, folder_name)
click.secho(f"Crew {name} created successfully!", fg="green", bold=True)
def get_provider_data(cache_file, cache_expiry):
data = load_provider_data(cache_file, cache_expiry)
if not data:
return None
provider_models = defaultdict(list)
for model_name, properties in data.items():
provider_full = properties.get("litellm_provider")
if provider_full:
provider_key = provider_full.strip().lower()
if 'http' in provider_key:
continue
if provider_key and provider_key != 'other':
provider_models[provider_key].append(model_name)
return provider_models