Files
crewAI/src/crewai/cli/create_crew.py
2024-10-12 13:22:48 +08:00

303 lines
12 KiB
Python

from pathlib import Path
import click
import requests
from collections import defaultdict
import json
import time
from urllib.parse import urlparse # Added import
from crewai.cli.utils import copy_template
PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama']
ENV_VARS = {
'openai': ['OPENAI_API_KEY'],
'anthropic': ['ANTHROPIC_API_KEY'],
'gemini': ['GEMINI_API_KEY'],
'groq': ['GROQ_API_KEY'],
'ollama': ['FAKE_KEY'],
}
MODELS = {
'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini','o1-mini', 'o1-preview'],
'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'],
'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'],
'ollama': ['llama3.1', 'mixtral'],
}
def create_crew(name, parent_folder=None):
"""Create a new crew."""
provider = None
model = None
folder_name = name.replace(" ", "_").replace("-", "_").lower()
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
if parent_folder:
folder_path = Path(parent_folder) / folder_name
else:
folder_path = Path(folder_name)
click.secho(
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
fg="green",
bold=True,
)
if not folder_path.exists():
folder_path.mkdir(parents=True)
(folder_path / "tests").mkdir(exist_ok=True)
if not parent_folder:
(folder_path / "src" / folder_name).mkdir(parents=True)
(folder_path / "src" / folder_name / "tools").mkdir(parents=True)
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
else:
click.secho(
f"\tFolder {folder_name} already exists. Updating .env file...",
fg="yellow",
)
env_file_path = folder_path / ".env"
env_vars = {}
if env_file_path.exists():
with open(env_file_path, "r") as file:
for line in file:
key_value = line.strip().split('=', 1)
if len(key_value) == 2:
env_vars[key_value[0]] = key_value[1]
cache_dir = Path.home() / '.crewai'
cache_dir.mkdir(exist_ok=True)
cache_file = cache_dir / 'provider_cache.json'
cache_expiry = 24 * 3600
json_url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
current_time = time.time()
data = {}
"""
Attempts to load provider data from the cache. If the cache is valid (i.e., not expired), it loads the data from the cache.
If the cache is invalid or corrupted, it fetches the provider data from the web.
"""
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry:
click.secho("Loading provider data from cache...", fg="cyan")
try:
with open(cache_file, "r") as f:
data = json.load(f)
click.secho("Provider data loaded from cache successfully.", fg="green")
except json.JSONDecodeError:
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow")
data = {}
else:
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan")
data = {}
if not data:
try:
with requests.get(json_url, stream=True, timeout=10) as response:
response.raise_for_status()
total_size = response.headers.get('content-length')
total_size = int(total_size) if total_size else None
block_size = 8192
data_chunks = []
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)
progress_bar.update(len(chunk))
data_content = b''.join(data_chunks)
data = json.loads(data_content.decode('utf-8'))
with open(cache_file, "w") as f:
json.dump(data, f)
click.secho("Provider data fetched and cached successfully.", fg="green")
except requests.RequestException as e:
click.secho(f"Error fetching provider data: {e}", fg="red")
return
except json.JSONDecodeError:
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
return
provider_models = defaultdict(list)
for model_name, properties in data.items():
provider_full = properties.get("litellm_provider")
if provider_full:
provider_key = provider_full.strip().lower()
if 'http' in provider_key:
click.secho(f"Skipping invalid provider entry: '{provider_full}'", fg="yellow")
continue
if provider_key and provider_key != 'other':
provider_models[provider_key].append(model_name)
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = set(predefined_providers)
all_providers.update(provider_models.keys())
all_providers = sorted(all_providers)
if provider:
provider_lower = provider.lower()
if provider_lower == 'other':
all_providers = sorted(provider_models.keys())
if not all_providers:
click.secho("No additional providers available.", fg="yellow")
return
click.secho("Select a provider from the full list:", fg="cyan")
for index, provider_name in enumerate(all_providers, start=1):
click.secho(f"{index}. {provider_name}", fg="cyan")
while True:
try:
selected_index = click.prompt(
"Enter the number of your choice", type=int
) - 1
if 0 <= selected_index < len(all_providers):
provider = all_providers[selected_index]
break
else:
click.secho("Invalid selection. Please try again.", fg="red")
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return
else:
if provider_lower not in provider_models and provider_lower not in [p.lower() for p in PROVIDERS]:
click.secho(f"Invalid provider: {provider}", fg="red")
return
else:
click.secho("Select a provider to set up:", fg="cyan")
for index, provider_name in enumerate(PROVIDERS + ['other'], start=1):
click.secho(f"{index}. {provider_name}", fg="cyan")
while True:
try:
selected_index = click.prompt(
"Enter the number of your choice", type=int
) - 1
if 0 <= selected_index < len(PROVIDERS) + 1:
selected_provider = (PROVIDERS + ['other'])[selected_index]
if selected_provider.lower() == 'other':
if not all_providers:
click.secho("No additional providers available.", fg="yellow")
return
click.secho("Select a provider from the full list:", fg="cyan")
for idx, provider_name in enumerate(all_providers, start=1):
display_name = provider_name.capitalize()
click.secho(f"{idx}. {display_name}", fg="cyan")
while True:
try:
selected_sub_index = click.prompt(
"Enter the number of your choice", type=int
) - 1
if 0 <= selected_sub_index < len(all_providers):
provider = all_providers[selected_sub_index]
break
else:
click.secho("Invalid selection. Please try again.", fg="red")
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return
break
else:
provider = selected_provider.lower()
break
else:
click.secho("Invalid selection. Please try again.", fg="red")
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return
provider = provider.strip().lower()
if provider in predefined_providers:
available_models = MODELS.get(provider, [])
else:
available_models = provider_models.get(provider, [])
if not available_models:
click.secho(f"No models available for provider '{provider}'.", fg="red")
click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow")
return
if model:
if model not in available_models:
click.secho(f"Invalid model '{model}' for provider '{provider}'.", fg="red")
return
else:
click.secho(f"Select a model to use for {provider.capitalize()}:", fg="cyan")
for idx, model_name in enumerate(available_models, start=1):
click.secho(f"{idx}. {model_name}", fg="cyan")
while True:
try:
selected_model_index = click.prompt(
"Enter the number of your choice", type=int
) - 1
if 0 <= selected_model_index < len(available_models):
model = available_models[selected_model_index]
break
else:
click.secho("Invalid selection. Please try again.", fg="red")
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return
if provider.lower() in ENV_VARS:
api_key_var = ENV_VARS[provider.lower()][0]
else:
api_key_var = f"{provider.upper()}_API_KEY"
if api_key_var not in env_vars:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider} API key", type=str, hide_input=True
)
else:
click.secho(f"API key already exists for {provider}.", fg="yellow")
if model:
env_vars['MODEL'] = model
click.secho(f"Selected model: {model}", fg="green")
with open(env_file_path, "w") as file:
for key, value in env_vars.items():
file.write(f"{key}={value}\n")
package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew"
root_template_files = (
[".gitignore", "pyproject.toml", "README.md"] if not parent_folder else []
)
tools_template_files = ["tools/custom_tool.py", "tools/__init__.py"]
config_template_files = ["config/agents.yaml", "config/tasks.yaml"]
src_template_files = (
["__init__.py", "main.py", "crew.py"] if not parent_folder else ["crew.py"]
)
for file_name in root_template_files:
src_file = templates_dir / file_name
dst_file = folder_path / file_name
copy_template(src_file, dst_file, name, class_name, folder_name)
src_folder = folder_path / "src" / folder_name if not parent_folder else folder_path
for file_name in src_template_files:
src_file = templates_dir / file_name
dst_file = src_folder / file_name
copy_template(src_file, dst_file, name, class_name, folder_name)
if not parent_folder:
for file_name in tools_template_files + config_template_files:
src_file = templates_dir / file_name
dst_file = src_folder / file_name
copy_template(src_file, dst_file, name, class_name, folder_name)
click.secho(f"Crew {name} created successfully!", fg="green", bold=True)