mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
13 Commits
devin/1761
...
feat/cli-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5f70d2307 | ||
|
|
b55fc40c83 | ||
|
|
d0ed4f5274 | ||
|
|
ee34399b71 | ||
|
|
39903f0c50 | ||
|
|
c4bf713113 | ||
|
|
5d18c6312d | ||
|
|
1f9baf9b2c | ||
|
|
6fbc97b298 | ||
|
|
08bacfa892 | ||
|
|
1ea8115d56 | ||
|
|
6b906f09cf | ||
|
|
6c29ebafea |
19
src/crewai/cli/constants.py
Normal file
19
src/crewai/cli/constants.py
Normal file
@@ -0,0 +1,19 @@
|
||||
ENV_VARS = {
|
||||
'openai': ['OPENAI_API_KEY'],
|
||||
'anthropic': ['ANTHROPIC_API_KEY'],
|
||||
'gemini': ['GEMINI_API_KEY'],
|
||||
'groq': ['GROQ_API_KEY'],
|
||||
'ollama': ['FAKE_KEY'],
|
||||
}
|
||||
|
||||
PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama']
|
||||
|
||||
MODELS = {
|
||||
'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'],
|
||||
'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
|
||||
'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'],
|
||||
'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'],
|
||||
'ollama': ['llama3.1', 'mixtral'],
|
||||
}
|
||||
|
||||
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
@@ -1,12 +1,10 @@
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from crewai.cli.utils import copy_template,load_env_vars, write_env_file
|
||||
from crewai.cli.provider import get_provider_data, select_provider, select_model, PROVIDERS
|
||||
from crewai.cli.constants import ENV_VARS
|
||||
|
||||
from crewai.cli.utils import copy_template
|
||||
|
||||
|
||||
def create_crew(name, parent_folder=None):
|
||||
"""Create a new crew."""
|
||||
def create_folder_structure(name, parent_folder=None):
|
||||
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
||||
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||
|
||||
@@ -28,19 +26,83 @@ def create_crew(name, parent_folder=None):
|
||||
(folder_path / "src" / folder_name).mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "tools").mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
|
||||
with open(folder_path / ".env", "w") as file:
|
||||
file.write("OPENAI_API_KEY=YOUR_API_KEY")
|
||||
else:
|
||||
click.secho(
|
||||
f"\tFolder {folder_name} already exists. Please choose a different name.",
|
||||
fg="red",
|
||||
f"\tFolder {folder_name} already exists.",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
return folder_path, folder_name, class_name
|
||||
|
||||
|
||||
|
||||
def copy_template_files(folder_path, name, class_name, parent_folder):
|
||||
package_dir = Path(__file__).parent
|
||||
templates_dir = package_dir / "templates" / "crew"
|
||||
|
||||
root_template_files = (
|
||||
[".gitignore", "pyproject.toml", "README.md"] if not parent_folder else []
|
||||
)
|
||||
tools_template_files = ["tools/custom_tool.py", "tools/__init__.py"]
|
||||
config_template_files = ["config/agents.yaml", "config/tasks.yaml"]
|
||||
src_template_files = (
|
||||
["__init__.py", "main.py", "crew.py"] if not parent_folder else ["crew.py"]
|
||||
)
|
||||
|
||||
for file_name in root_template_files:
|
||||
src_file = templates_dir / file_name
|
||||
dst_file = folder_path / file_name
|
||||
copy_template(src_file, dst_file, name, class_name, folder_path.name)
|
||||
|
||||
src_folder = folder_path / "src" / folder_path.name if not parent_folder else folder_path
|
||||
|
||||
for file_name in src_template_files:
|
||||
src_file = templates_dir / file_name
|
||||
dst_file = src_folder / file_name
|
||||
copy_template(src_file, dst_file, name, class_name, folder_path.name)
|
||||
|
||||
if not parent_folder:
|
||||
for file_name in tools_template_files + config_template_files:
|
||||
src_file = templates_dir / file_name
|
||||
dst_file = src_folder / file_name
|
||||
copy_template(src_file, dst_file, name, class_name, folder_path.name)
|
||||
|
||||
|
||||
def create_crew(name, parent_folder=None):
|
||||
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
|
||||
env_vars = load_env_vars(folder_path)
|
||||
|
||||
provider_models = get_provider_data()
|
||||
if not provider_models:
|
||||
return
|
||||
|
||||
selected_provider = select_provider(provider_models)
|
||||
if not selected_provider:
|
||||
return
|
||||
provider = selected_provider
|
||||
|
||||
selected_model = select_model(provider, provider_models)
|
||||
if not selected_model:
|
||||
return
|
||||
model = selected_model
|
||||
|
||||
if provider in PROVIDERS:
|
||||
api_key_var = ENV_VARS[provider][0]
|
||||
else:
|
||||
api_key_var = click.prompt(
|
||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||
type=str
|
||||
)
|
||||
|
||||
env_vars = {api_key_var: "YOUR_API_KEY_HERE"}
|
||||
write_env_file(folder_path, env_vars)
|
||||
|
||||
env_vars['MODEL'] = model
|
||||
click.secho(f"Selected model: {model}", fg="green")
|
||||
|
||||
package_dir = Path(__file__).parent
|
||||
templates_dir = package_dir / "templates" / "crew"
|
||||
|
||||
# List of template files to copy
|
||||
root_template_files = (
|
||||
[".gitignore", "pyproject.toml", "README.md"] if not parent_folder else []
|
||||
)
|
||||
|
||||
186
src/crewai/cli/provider.py
Normal file
186
src/crewai/cli/provider.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
from collections import defaultdict
|
||||
import click
|
||||
from pathlib import Path
|
||||
from crewai.cli.constants import PROVIDERS, MODELS, JSON_URL
|
||||
|
||||
def select_choice(prompt_message, choices):
|
||||
"""
|
||||
Presents a list of choices to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
- prompt_message (str): The message to display to the user before presenting the choices.
|
||||
- choices (list): A list of options to present to the user.
|
||||
|
||||
Returns:
|
||||
- str: The selected choice from the list, or None if the operation is aborted or an invalid selection is made.
|
||||
"""
|
||||
click.secho(prompt_message, fg="cyan")
|
||||
for idx, choice in enumerate(choices, start=1):
|
||||
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||
try:
|
||||
selected_index = click.prompt("Enter the number of your choice", type=int) - 1
|
||||
except click.exceptions.Abort:
|
||||
click.secho("Operation aborted by the user.", fg="red")
|
||||
return None
|
||||
if not (0 <= selected_index < len(choices)):
|
||||
click.secho("Invalid selection.", fg="red")
|
||||
return None
|
||||
return choices[selected_index]
|
||||
|
||||
def select_provider(provider_models):
|
||||
"""
|
||||
Presents a list of providers to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
- provider_models (dict): A dictionary of provider models.
|
||||
|
||||
Returns:
|
||||
- str: The selected provider, or None if the operation is aborted or an invalid selection is made.
|
||||
"""
|
||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||
|
||||
provider = select_choice("Select a provider to set up:", predefined_providers + ['other'])
|
||||
if not provider:
|
||||
return None
|
||||
provider = provider.lower()
|
||||
|
||||
if provider == 'other':
|
||||
provider = select_choice("Select a provider from the full list:", all_providers)
|
||||
if not provider:
|
||||
return None
|
||||
return provider
|
||||
|
||||
def select_model(provider, provider_models):
|
||||
"""
|
||||
Presents a list of models for a given provider to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
- provider (str): The provider for which to select a model.
|
||||
- provider_models (dict): A dictionary of provider models.
|
||||
|
||||
Returns:
|
||||
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
|
||||
"""
|
||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||
|
||||
if provider in predefined_providers:
|
||||
available_models = MODELS.get(provider, [])
|
||||
else:
|
||||
available_models = provider_models.get(provider, [])
|
||||
|
||||
if not available_models:
|
||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||
return None
|
||||
|
||||
selected_model = select_choice(f"Select a model to use for {provider.capitalize()}:", available_models)
|
||||
return selected_model
|
||||
|
||||
def load_provider_data(cache_file, cache_expiry):
|
||||
"""
|
||||
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
- cache_expiry (int): The cache expiry time in seconds.
|
||||
|
||||
Returns:
|
||||
- dict or None: The loaded provider data or None if the operation fails.
|
||||
"""
|
||||
current_time = time.time()
|
||||
if cache_file.exists() and (current_time - cache_file.stat().st_mtime) < cache_expiry:
|
||||
data = read_cache_file(cache_file)
|
||||
if data:
|
||||
return data
|
||||
click.secho("Cache is corrupted. Fetching provider data from the web...", fg="yellow")
|
||||
else:
|
||||
click.secho("Cache expired or not found. Fetching provider data from the web...", fg="cyan")
|
||||
return fetch_provider_data(cache_file)
|
||||
|
||||
def read_cache_file(cache_file):
|
||||
"""
|
||||
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
|
||||
Returns:
|
||||
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
|
||||
"""
|
||||
try:
|
||||
with open(cache_file, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def fetch_provider_data(cache_file):
|
||||
"""
|
||||
Fetches provider data from a specified URL and caches it to a file.
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
|
||||
Returns:
|
||||
- dict or None: The fetched provider data or None if the operation fails.
|
||||
"""
|
||||
try:
|
||||
response = requests.get(JSON_URL, stream=True, timeout=10)
|
||||
response.raise_for_status()
|
||||
data = download_data(response)
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
return data
|
||||
except requests.RequestException as e:
|
||||
click.secho(f"Error fetching provider data: {e}", fg="red")
|
||||
except json.JSONDecodeError:
|
||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||
return None
|
||||
|
||||
def download_data(response):
|
||||
"""
|
||||
Downloads data from a given HTTP response and returns the JSON content.
|
||||
|
||||
Args:
|
||||
- response (requests.Response): The HTTP response object.
|
||||
|
||||
Returns:
|
||||
- dict: The JSON content of the response.
|
||||
"""
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
block_size = 8192
|
||||
data_chunks = []
|
||||
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
|
||||
for chunk in response.iter_content(block_size):
|
||||
if chunk:
|
||||
data_chunks.append(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
data_content = b''.join(data_chunks)
|
||||
return json.loads(data_content.decode('utf-8'))
|
||||
|
||||
def get_provider_data():
|
||||
"""
|
||||
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
|
||||
|
||||
Returns:
|
||||
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
|
||||
"""
|
||||
cache_dir = Path.home() / '.crewai'
|
||||
cache_dir.mkdir(exist_ok=True)
|
||||
cache_file = cache_dir / 'provider_cache.json'
|
||||
cache_expiry = 24 * 3600
|
||||
|
||||
data = load_provider_data(cache_file, cache_expiry)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
provider_models = defaultdict(list)
|
||||
for model_name, properties in data.items():
|
||||
provider = properties.get("litellm_provider", "").strip().lower()
|
||||
if 'http' in provider or provider == 'other':
|
||||
continue
|
||||
if provider:
|
||||
provider_models[provider].append(model_name)
|
||||
return provider_models
|
||||
@@ -9,6 +9,7 @@ import click
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli.authentication.utils import TokenManager
|
||||
from crewai.cli.constants import ENV_VARS
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
@@ -200,3 +201,76 @@ def tree_find_and_replace(directory, find, replace):
|
||||
new_dirpath = os.path.join(path, new_dirname)
|
||||
old_dirpath = os.path.join(path, dirname)
|
||||
os.rename(old_dirpath, new_dirpath)
|
||||
|
||||
|
||||
def load_env_vars(folder_path):
|
||||
"""
|
||||
Loads environment variables from a .env file in the specified folder path.
|
||||
|
||||
Args:
|
||||
- folder_path (Path): The path to the folder containing the .env file.
|
||||
|
||||
Returns:
|
||||
- dict: A dictionary of environment variables.
|
||||
"""
|
||||
env_file_path = folder_path / ".env"
|
||||
env_vars = {}
|
||||
if env_file_path.exists():
|
||||
with open(env_file_path, "r") as file:
|
||||
for line in file:
|
||||
key, _, value = line.strip().partition("=")
|
||||
if key and value:
|
||||
env_vars[key] = value
|
||||
return env_vars
|
||||
|
||||
|
||||
def update_env_vars(env_vars, provider, model):
|
||||
"""
|
||||
Updates environment variables with the API key for the selected provider and model.
|
||||
|
||||
Args:
|
||||
- env_vars (dict): Environment variables dictionary.
|
||||
- provider (str): Selected provider.
|
||||
- model (str): Selected model.
|
||||
|
||||
Returns:
|
||||
- None
|
||||
"""
|
||||
api_key_var = ENV_VARS.get(
|
||||
provider,
|
||||
[
|
||||
click.prompt(
|
||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||
type=str,
|
||||
)
|
||||
],
|
||||
)[0]
|
||||
|
||||
if api_key_var not in env_vars:
|
||||
try:
|
||||
env_vars[api_key_var] = click.prompt(
|
||||
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
|
||||
)
|
||||
except click.exceptions.Abort:
|
||||
click.secho("Operation aborted by the user.", fg="red")
|
||||
return None
|
||||
else:
|
||||
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
|
||||
|
||||
env_vars["MODEL"] = model
|
||||
click.secho(f"Selected model: {model}", fg="green")
|
||||
return env_vars
|
||||
|
||||
|
||||
def write_env_file(folder_path, env_vars):
|
||||
"""
|
||||
Writes environment variables to a .env file in the specified folder.
|
||||
|
||||
Args:
|
||||
- folder_path (Path): The path to the folder where the .env file will be written.
|
||||
- env_vars (dict): A dictionary of environment variables to write.
|
||||
"""
|
||||
env_file_path = folder_path / ".env"
|
||||
with open(env_file_path, "w") as file:
|
||||
for key, value in env_vars.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
|
||||
Reference in New Issue
Block a user