mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
feat: implement crew creation CLI command
- refactor code to multiple functions - Added ability for users to select provider and model when uing crewai create command and ave API key to .env
This commit is contained in:
@@ -1,13 +1,13 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import click
|
import click
|
||||||
import requests
|
|
||||||
from collections import defaultdict
|
|
||||||
import json
|
import json
|
||||||
|
import requests
|
||||||
import time
|
import time
|
||||||
from urllib.parse import urlparse # Added import
|
|
||||||
|
|
||||||
from crewai.cli.utils import copy_template
|
from crewai.cli.utils import copy_template
|
||||||
|
|
||||||
|
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||||
|
|
||||||
PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama']
|
PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama']
|
||||||
|
|
||||||
ENV_VARS = {
|
ENV_VARS = {
|
||||||
@@ -19,17 +19,109 @@ ENV_VARS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
MODELS = {
|
MODELS = {
|
||||||
'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini','o1-mini', 'o1-preview'],
|
'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'],
|
||||||
'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
|
'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
|
||||||
'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'],
|
'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'],
|
||||||
'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'],
|
'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'],
|
||||||
'ollama': ['llama3.1', 'mixtral'],
|
'ollama': ['llama3.1', 'mixtral'],
|
||||||
}
|
}
|
||||||
|
|
||||||
def create_crew(name, parent_folder=None):
|
def load_provider_data(cache_file, cache_expiry):
|
||||||
"""Create a new crew."""
|
current_time = time.time()
|
||||||
provider = None
|
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry:
|
||||||
model = None
|
try:
|
||||||
|
with open(cache_file, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow")
|
||||||
|
else:
|
||||||
|
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(JSON_URL, stream=True, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
total_size = int(response.headers.get('content-length', 0))
|
||||||
|
block_size = 8192
|
||||||
|
data_chunks = []
|
||||||
|
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
|
||||||
|
for chunk in response.iter_content(block_size):
|
||||||
|
if chunk:
|
||||||
|
data_chunks.append(chunk)
|
||||||
|
progress_bar.update(len(chunk))
|
||||||
|
data_content = b''.join(data_chunks)
|
||||||
|
data = json.loads(data_content.decode('utf-8'))
|
||||||
|
with open(cache_file, "w") as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
return data
|
||||||
|
except requests.RequestException as e:
|
||||||
|
click.secho(f"Error fetching provider data: {e}", fg="red")
|
||||||
|
return None
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def select_choice(prompt_message, choices):
|
||||||
|
click.secho(prompt_message, fg="cyan")
|
||||||
|
for idx, choice in enumerate(choices, start=1):
|
||||||
|
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||||
|
try:
|
||||||
|
selected_index = click.prompt("Enter the number of your choice", type=int) - 1
|
||||||
|
except click.exceptions.Abort:
|
||||||
|
click.secho("Operation aborted by the user.", fg="red")
|
||||||
|
return None
|
||||||
|
if 0 <= selected_index < len(choices):
|
||||||
|
return choices[selected_index]
|
||||||
|
else:
|
||||||
|
click.secho("Invalid selection.", fg="red")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def select_provider(provider, all_providers, PROVIDERS):
|
||||||
|
if provider and provider.lower() not in all_providers and provider.lower() != 'other':
|
||||||
|
click.secho(f"Invalid provider: {provider}", fg="red")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not provider:
|
||||||
|
options = PROVIDERS + ['other']
|
||||||
|
selected_provider = select_choice("Select a provider to set up:", options)
|
||||||
|
if not selected_provider:
|
||||||
|
return None
|
||||||
|
if selected_provider.lower() == 'other':
|
||||||
|
if not all_providers:
|
||||||
|
click.secho("No additional providers available.", fg="yellow")
|
||||||
|
return None
|
||||||
|
selected_provider = select_choice("Select a provider from the full list:", all_providers)
|
||||||
|
if not selected_provider:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
selected_provider = provider.lower()
|
||||||
|
if selected_provider == 'other':
|
||||||
|
if not all_providers:
|
||||||
|
click.secho("No additional providers available.", fg="yellow")
|
||||||
|
return None
|
||||||
|
selected_provider = select_choice("Select a provider from the full list:", all_providers)
|
||||||
|
if not selected_provider:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return selected_provider.lower()
|
||||||
|
|
||||||
|
def select_model(provider, predefined_providers, MODELS, provider_models):
|
||||||
|
provider = provider.lower()
|
||||||
|
if provider in predefined_providers:
|
||||||
|
available_models = MODELS.get(provider, [])
|
||||||
|
else:
|
||||||
|
available_models = provider_models.get(provider, [])
|
||||||
|
|
||||||
|
if not available_models:
|
||||||
|
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||||
|
click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow")
|
||||||
|
return None
|
||||||
|
|
||||||
|
selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models)
|
||||||
|
if not selected_model:
|
||||||
|
return None
|
||||||
|
return selected_model
|
||||||
|
|
||||||
|
def create_folder_structure(name, parent_folder=None):
|
||||||
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
||||||
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||||
|
|
||||||
@@ -53,10 +145,15 @@ def create_crew(name, parent_folder=None):
|
|||||||
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
|
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
|
||||||
else:
|
else:
|
||||||
click.secho(
|
click.secho(
|
||||||
f"\tFolder {folder_name} already exists. Updating .env file...",
|
f"\tFolder {folder_name} already exists.",
|
||||||
fg="yellow",
|
fg="yellow",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return folder_path, folder_name, class_name
|
||||||
|
|
||||||
|
def create_crew(name, parent_folder=None):
|
||||||
|
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
|
||||||
|
|
||||||
env_file_path = folder_path / ".env"
|
env_file_path = folder_path / ".env"
|
||||||
|
|
||||||
env_vars = {}
|
env_vars = {}
|
||||||
@@ -67,201 +164,51 @@ def create_crew(name, parent_folder=None):
|
|||||||
if len(key_value) == 2:
|
if len(key_value) == 2:
|
||||||
env_vars[key_value[0]] = key_value[1]
|
env_vars[key_value[0]] = key_value[1]
|
||||||
|
|
||||||
|
|
||||||
cache_dir = Path.home() / '.crewai'
|
cache_dir = Path.home() / '.crewai'
|
||||||
cache_dir.mkdir(exist_ok=True)
|
cache_dir.mkdir(exist_ok=True)
|
||||||
cache_file = cache_dir / 'provider_cache.json'
|
cache_file = cache_dir / 'provider_cache.json'
|
||||||
cache_expiry = 24 * 3600
|
cache_expiry = 24 * 3600
|
||||||
|
|
||||||
json_url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
provider_models = get_provider_data(cache_file, cache_expiry)
|
||||||
current_time = time.time()
|
if not provider_models:
|
||||||
data = {}
|
|
||||||
|
|
||||||
"""
|
|
||||||
Attempts to load provider data from the cache. If the cache is valid (i.e., not expired), it loads the data from the cache.
|
|
||||||
If the cache is invalid or corrupted, it fetches the provider data from the web.
|
|
||||||
"""
|
|
||||||
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry:
|
|
||||||
click.secho("Loading provider data from cache...", fg="cyan")
|
|
||||||
try:
|
|
||||||
with open(cache_file, "r") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
click.secho("Provider data loaded from cache successfully.", fg="green")
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow")
|
|
||||||
data = {}
|
|
||||||
else:
|
|
||||||
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan")
|
|
||||||
data = {}
|
|
||||||
|
|
||||||
if not data:
|
|
||||||
try:
|
|
||||||
with requests.get(json_url, stream=True, timeout=10) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
total_size = response.headers.get('content-length')
|
|
||||||
total_size = int(total_size) if total_size else None
|
|
||||||
block_size = 8192
|
|
||||||
data_chunks = []
|
|
||||||
|
|
||||||
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
|
|
||||||
for chunk in response.iter_content(block_size):
|
|
||||||
if chunk:
|
|
||||||
data_chunks.append(chunk)
|
|
||||||
progress_bar.update(len(chunk))
|
|
||||||
|
|
||||||
data_content = b''.join(data_chunks)
|
|
||||||
data = json.loads(data_content.decode('utf-8'))
|
|
||||||
|
|
||||||
with open(cache_file, "w") as f:
|
|
||||||
json.dump(data, f)
|
|
||||||
click.secho("Provider data fetched and cached successfully.", fg="green")
|
|
||||||
except requests.RequestException as e:
|
|
||||||
click.secho(f"Error fetching provider data: {e}", fg="red")
|
|
||||||
return
|
return
|
||||||
except json.JSONDecodeError:
|
|
||||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
|
||||||
return
|
|
||||||
|
|
||||||
provider_models = defaultdict(list)
|
|
||||||
for model_name, properties in data.items():
|
|
||||||
provider_full = properties.get("litellm_provider")
|
|
||||||
if provider_full:
|
|
||||||
provider_key = provider_full.strip().lower()
|
|
||||||
|
|
||||||
if 'http' in provider_key:
|
|
||||||
click.secho(f"Skipping invalid provider entry: '{provider_full}'", fg="yellow")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if provider_key and provider_key != 'other':
|
|
||||||
provider_models[provider_key].append(model_name)
|
|
||||||
|
|
||||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||||
|
provider_models = {k.lower(): v for k, v in provider_models.items()}
|
||||||
|
|
||||||
all_providers = set(predefined_providers)
|
all_providers = set(predefined_providers)
|
||||||
all_providers.update(provider_models.keys())
|
all_providers.update(provider_models.keys())
|
||||||
|
|
||||||
all_providers = sorted(all_providers)
|
all_providers = sorted(all_providers)
|
||||||
|
|
||||||
if provider:
|
selected_provider = select_provider(None, all_providers, predefined_providers)
|
||||||
provider_lower = provider.lower()
|
if not selected_provider:
|
||||||
if provider_lower == 'other':
|
|
||||||
all_providers = sorted(provider_models.keys())
|
|
||||||
if not all_providers:
|
|
||||||
click.secho("No additional providers available.", fg="yellow")
|
|
||||||
return
|
return
|
||||||
click.secho("Select a provider from the full list:", fg="cyan")
|
|
||||||
for index, provider_name in enumerate(all_providers, start=1):
|
|
||||||
click.secho(f"{index}. {provider_name}", fg="cyan")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
selected_index = click.prompt(
|
|
||||||
"Enter the number of your choice", type=int
|
|
||||||
) - 1
|
|
||||||
if 0 <= selected_index < len(all_providers):
|
|
||||||
provider = all_providers[selected_index]
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
click.secho("Invalid selection. Please try again.", fg="red")
|
|
||||||
except click.exceptions.Abort:
|
|
||||||
click.secho("Operation aborted by the user.", fg="red")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
|
|
||||||
if provider_lower not in provider_models and provider_lower not in [p.lower() for p in PROVIDERS]:
|
|
||||||
click.secho(f"Invalid provider: {provider}", fg="red")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
click.secho("Select a provider to set up:", fg="cyan")
|
|
||||||
for index, provider_name in enumerate(PROVIDERS + ['other'], start=1):
|
|
||||||
click.secho(f"{index}. {provider_name}", fg="cyan")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
selected_index = click.prompt(
|
|
||||||
"Enter the number of your choice", type=int
|
|
||||||
) - 1
|
|
||||||
if 0 <= selected_index < len(PROVIDERS) + 1:
|
|
||||||
selected_provider = (PROVIDERS + ['other'])[selected_index]
|
|
||||||
if selected_provider.lower() == 'other':
|
|
||||||
if not all_providers:
|
|
||||||
click.secho("No additional providers available.", fg="yellow")
|
|
||||||
return
|
|
||||||
click.secho("Select a provider from the full list:", fg="cyan")
|
|
||||||
for idx, provider_name in enumerate(all_providers, start=1):
|
|
||||||
display_name = provider_name.capitalize()
|
|
||||||
click.secho(f"{idx}. {display_name}", fg="cyan")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
selected_sub_index = click.prompt(
|
|
||||||
"Enter the number of your choice", type=int
|
|
||||||
) - 1
|
|
||||||
if 0 <= selected_sub_index < len(all_providers):
|
|
||||||
provider = all_providers[selected_sub_index]
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
click.secho("Invalid selection. Please try again.", fg="red")
|
|
||||||
except click.exceptions.Abort:
|
|
||||||
click.secho("Operation aborted by the user.", fg="red")
|
|
||||||
return
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
provider = selected_provider.lower()
|
provider = selected_provider.lower()
|
||||||
break
|
|
||||||
else:
|
|
||||||
click.secho("Invalid selection. Please try again.", fg="red")
|
|
||||||
except click.exceptions.Abort:
|
|
||||||
click.secho("Operation aborted by the user.", fg="red")
|
|
||||||
return
|
|
||||||
|
|
||||||
provider = provider.strip().lower()
|
selected_model = select_model(provider, predefined_providers, MODELS, provider_models)
|
||||||
|
if not selected_model:
|
||||||
|
return
|
||||||
|
model = selected_model
|
||||||
|
|
||||||
if provider in predefined_providers:
|
if provider in predefined_providers:
|
||||||
available_models = MODELS.get(provider, [])
|
api_key_var = ENV_VARS[provider][0]
|
||||||
else:
|
else:
|
||||||
available_models = provider_models.get(provider, [])
|
api_key_var = click.prompt(
|
||||||
|
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||||
|
type=str
|
||||||
|
)
|
||||||
|
|
||||||
if not available_models:
|
if api_key_var not in env_vars:
|
||||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
|
||||||
click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow")
|
|
||||||
return
|
|
||||||
|
|
||||||
if model:
|
|
||||||
if model not in available_models:
|
|
||||||
click.secho(f"Invalid model '{model}' for provider '{provider}'.", fg="red")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
click.secho(f"Select a model to use for {provider.capitalize()}:", fg="cyan")
|
|
||||||
for idx, model_name in enumerate(available_models, start=1):
|
|
||||||
click.secho(f"{idx}. {model_name}", fg="cyan")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
try:
|
||||||
selected_model_index = click.prompt(
|
env_vars[api_key_var] = click.prompt(
|
||||||
"Enter the number of your choice", type=int
|
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
|
||||||
) - 1
|
)
|
||||||
if 0 <= selected_model_index < len(available_models):
|
|
||||||
model = available_models[selected_model_index]
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
click.secho("Invalid selection. Please try again.", fg="red")
|
|
||||||
except click.exceptions.Abort:
|
except click.exceptions.Abort:
|
||||||
click.secho("Operation aborted by the user.", fg="red")
|
click.secho("Operation aborted by the user.", fg="red")
|
||||||
return
|
return
|
||||||
|
|
||||||
if provider.lower() in ENV_VARS:
|
|
||||||
api_key_var = ENV_VARS[provider.lower()][0]
|
|
||||||
else:
|
else:
|
||||||
api_key_var = f"{provider.upper()}_API_KEY"
|
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
|
||||||
|
|
||||||
if api_key_var not in env_vars:
|
|
||||||
env_vars[api_key_var] = click.prompt(
|
|
||||||
f"Enter your {provider} API key", type=str, hide_input=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
click.secho(f"API key already exists for {provider}.", fg="yellow")
|
|
||||||
|
|
||||||
if model:
|
|
||||||
env_vars['MODEL'] = model
|
env_vars['MODEL'] = model
|
||||||
click.secho(f"Selected model: {model}", fg="green")
|
click.secho(f"Selected model: {model}", fg="green")
|
||||||
|
|
||||||
@@ -300,3 +247,19 @@ def create_crew(name, parent_folder=None):
|
|||||||
copy_template(src_file, dst_file, name, class_name, folder_name)
|
copy_template(src_file, dst_file, name, class_name, folder_name)
|
||||||
|
|
||||||
click.secho(f"Crew {name} created successfully!", fg="green", bold=True)
|
click.secho(f"Crew {name} created successfully!", fg="green", bold=True)
|
||||||
|
|
||||||
|
def get_provider_data(cache_file, cache_expiry):
|
||||||
|
data = load_provider_data(cache_file, cache_expiry)
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
provider_models = defaultdict(list)
|
||||||
|
for model_name, properties in data.items():
|
||||||
|
provider_full = properties.get("litellm_provider")
|
||||||
|
if provider_full:
|
||||||
|
provider_key = provider_full.strip().lower()
|
||||||
|
if 'http' in provider_key:
|
||||||
|
continue
|
||||||
|
if provider_key and provider_key != 'other':
|
||||||
|
provider_models[provider_key].append(model_name)
|
||||||
|
return provider_models
|
||||||
Reference in New Issue
Block a user