mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
removed all unnecessary comments
This commit is contained in:
@@ -8,10 +8,8 @@ from urllib.parse import urlparse # Added import
|
|||||||
|
|
||||||
from crewai.cli.utils import copy_template
|
from crewai.cli.utils import copy_template
|
||||||
|
|
||||||
# Constants for predefined providers
|
|
||||||
PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama']
|
PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama']
|
||||||
|
|
||||||
# Each provider has their own environment variables
|
|
||||||
ENV_VARS = {
|
ENV_VARS = {
|
||||||
'openai': ['OPENAI_API_KEY'],
|
'openai': ['OPENAI_API_KEY'],
|
||||||
'anthropic': ['ANTHROPIC_API_KEY'],
|
'anthropic': ['ANTHROPIC_API_KEY'],
|
||||||
@@ -20,7 +18,6 @@ ENV_VARS = {
|
|||||||
'ollama': ['FAKE_KEY'],
|
'ollama': ['FAKE_KEY'],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Each provider has their own models
|
|
||||||
MODELS = {
|
MODELS = {
|
||||||
'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini','o1-mini', 'o1-preview'],
|
'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini','o1-mini', 'o1-preview'],
|
||||||
'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
|
'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
|
||||||
@@ -47,7 +44,6 @@ def create_crew(name, parent_folder=None):
|
|||||||
bold=True,
|
bold=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create necessary directories
|
|
||||||
if not folder_path.exists():
|
if not folder_path.exists():
|
||||||
folder_path.mkdir(parents=True)
|
folder_path.mkdir(parents=True)
|
||||||
(folder_path / "tests").mkdir(exist_ok=True)
|
(folder_path / "tests").mkdir(exist_ok=True)
|
||||||
@@ -61,10 +57,8 @@ def create_crew(name, parent_folder=None):
|
|||||||
fg="yellow",
|
fg="yellow",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Path to the .env file
|
|
||||||
env_file_path = folder_path / ".env"
|
env_file_path = folder_path / ".env"
|
||||||
|
|
||||||
# Initialize env_vars
|
|
||||||
env_vars = {}
|
env_vars = {}
|
||||||
if env_file_path.exists():
|
if env_file_path.exists():
|
||||||
with open(env_file_path, "r") as file:
|
with open(env_file_path, "r") as file:
|
||||||
@@ -73,13 +67,12 @@ def create_crew(name, parent_folder=None):
|
|||||||
if len(key_value) == 2:
|
if len(key_value) == 2:
|
||||||
env_vars[key_value[0]] = key_value[1]
|
env_vars[key_value[0]] = key_value[1]
|
||||||
|
|
||||||
# Caching setup
|
|
||||||
cache_dir = Path.home() / '.crewai'
|
cache_dir = Path.home() / '.crewai'
|
||||||
cache_dir.mkdir(exist_ok=True)
|
cache_dir.mkdir(exist_ok=True)
|
||||||
cache_file = cache_dir / 'provider_cache.json'
|
cache_file = cache_dir / 'provider_cache.json'
|
||||||
cache_expiry = 24 * 3600
|
cache_expiry = 24 * 3600
|
||||||
|
|
||||||
# Load API providers and models from JSON with caching
|
|
||||||
json_url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
json_url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
data = {}
|
data = {}
|
||||||
@@ -107,10 +100,9 @@ def create_crew(name, parent_folder=None):
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
total_size = response.headers.get('content-length')
|
total_size = response.headers.get('content-length')
|
||||||
total_size = int(total_size) if total_size else None
|
total_size = int(total_size) if total_size else None
|
||||||
block_size = 8192 # Increased block size for faster download
|
block_size = 8192
|
||||||
data_chunks = []
|
data_chunks = []
|
||||||
|
|
||||||
# Removed 'dynamic_ncols=True' as it is not a valid argument for Click's progressbar
|
|
||||||
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
|
with click.progressbar(length=total_size, label='Downloading', show_pos=True) as progress_bar:
|
||||||
for chunk in response.iter_content(block_size):
|
for chunk in response.iter_content(block_size):
|
||||||
if chunk:
|
if chunk:
|
||||||
@@ -120,7 +112,6 @@ def create_crew(name, parent_folder=None):
|
|||||||
data_content = b''.join(data_chunks)
|
data_content = b''.join(data_chunks)
|
||||||
data = json.loads(data_content.decode('utf-8'))
|
data = json.loads(data_content.decode('utf-8'))
|
||||||
|
|
||||||
# Save fetched data to cache
|
|
||||||
with open(cache_file, "w") as f:
|
with open(cache_file, "w") as f:
|
||||||
json.dump(data, f)
|
json.dump(data, f)
|
||||||
click.secho("Provider data fetched and cached successfully.", fg="green")
|
click.secho("Provider data fetched and cached successfully.", fg="green")
|
||||||
@@ -131,34 +122,28 @@ def create_crew(name, parent_folder=None):
|
|||||||
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Extract unique providers based on 'litellm_provider' and map models
|
|
||||||
provider_models = defaultdict(list)
|
provider_models = defaultdict(list)
|
||||||
for model_name, properties in data.items():
|
for model_name, properties in data.items():
|
||||||
provider_full = properties.get("litellm_provider")
|
provider_full = properties.get("litellm_provider")
|
||||||
if provider_full:
|
if provider_full:
|
||||||
provider_key = provider_full.strip().lower() # Ensure consistent casing and strip whitespace
|
provider_key = provider_full.strip().lower()
|
||||||
|
|
||||||
# Skip invalid provider entries
|
|
||||||
if 'http' in provider_key:
|
if 'http' in provider_key:
|
||||||
click.secho(f"Skipping invalid provider entry: '{provider_full}'", fg="yellow")
|
click.secho(f"Skipping invalid provider entry: '{provider_full}'", fg="yellow")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if provider_key and provider_key != 'other': # Exclude 'other' and empty strings
|
if provider_key and provider_key != 'other':
|
||||||
provider_models[provider_key].append(model_name)
|
provider_models[provider_key].append(model_name)
|
||||||
|
|
||||||
# Merge predefined PROVIDERS with providers from JSON, ensuring consistent casing
|
|
||||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||||
all_providers = set(predefined_providers)
|
all_providers = set(predefined_providers)
|
||||||
all_providers.update(provider_models.keys())
|
all_providers.update(provider_models.keys())
|
||||||
|
|
||||||
# Convert to a sorted list for consistent display
|
|
||||||
all_providers = sorted(all_providers)
|
all_providers = sorted(all_providers)
|
||||||
|
|
||||||
# Adjust provider selection logic to handle 'other' by displaying all providers from JSON data
|
|
||||||
if provider:
|
if provider:
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
if provider_lower == 'other':
|
if provider_lower == 'other':
|
||||||
# Load all providers from JSON data
|
|
||||||
all_providers = sorted(provider_models.keys())
|
all_providers = sorted(provider_models.keys())
|
||||||
if not all_providers:
|
if not all_providers:
|
||||||
click.secho("No additional providers available.", fg="yellow")
|
click.secho("No additional providers available.", fg="yellow")
|
||||||
@@ -173,7 +158,7 @@ def create_crew(name, parent_folder=None):
|
|||||||
"Enter the number of your choice", type=int
|
"Enter the number of your choice", type=int
|
||||||
) - 1
|
) - 1
|
||||||
if 0 <= selected_index < len(all_providers):
|
if 0 <= selected_index < len(all_providers):
|
||||||
provider = all_providers[selected_index] # Update provider to the selected one
|
provider = all_providers[selected_index]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
click.secho("Invalid selection. Please try again.", fg="red")
|
click.secho("Invalid selection. Please try again.", fg="red")
|
||||||
@@ -181,12 +166,11 @@ def create_crew(name, parent_folder=None):
|
|||||||
click.secho("Operation aborted by the user.", fg="red")
|
click.secho("Operation aborted by the user.", fg="red")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
# Validate provider
|
|
||||||
if provider_lower not in provider_models and provider_lower not in [p.lower() for p in PROVIDERS]:
|
if provider_lower not in provider_models and provider_lower not in [p.lower() for p in PROVIDERS]:
|
||||||
click.secho(f"Invalid provider: {provider}", fg="red")
|
click.secho(f"Invalid provider: {provider}", fg="red")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
# Prompt for provider from predefined PROVIDERS and 'other'
|
|
||||||
click.secho("Select a provider to set up:", fg="cyan")
|
click.secho("Select a provider to set up:", fg="cyan")
|
||||||
for index, provider_name in enumerate(PROVIDERS + ['other'], start=1):
|
for index, provider_name in enumerate(PROVIDERS + ['other'], start=1):
|
||||||
click.secho(f"{index}. {provider_name}", fg="cyan")
|
click.secho(f"{index}. {provider_name}", fg="cyan")
|
||||||
@@ -199,13 +183,12 @@ def create_crew(name, parent_folder=None):
|
|||||||
if 0 <= selected_index < len(PROVIDERS) + 1:
|
if 0 <= selected_index < len(PROVIDERS) + 1:
|
||||||
selected_provider = (PROVIDERS + ['other'])[selected_index]
|
selected_provider = (PROVIDERS + ['other'])[selected_index]
|
||||||
if selected_provider.lower() == 'other':
|
if selected_provider.lower() == 'other':
|
||||||
# Display all providers from JSON data
|
|
||||||
if not all_providers:
|
if not all_providers:
|
||||||
click.secho("No additional providers available.", fg="yellow")
|
click.secho("No additional providers available.", fg="yellow")
|
||||||
return
|
return
|
||||||
click.secho("Select a provider from the full list:", fg="cyan")
|
click.secho("Select a provider from the full list:", fg="cyan")
|
||||||
for idx, provider_name in enumerate(all_providers, start=1):
|
for idx, provider_name in enumerate(all_providers, start=1):
|
||||||
display_name = provider_name.capitalize() # Format for display
|
display_name = provider_name.capitalize()
|
||||||
click.secho(f"{idx}. {display_name}", fg="cyan")
|
click.secho(f"{idx}. {display_name}", fg="cyan")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -215,33 +198,29 @@ def create_crew(name, parent_folder=None):
|
|||||||
) - 1
|
) - 1
|
||||||
if 0 <= selected_sub_index < len(all_providers):
|
if 0 <= selected_sub_index < len(all_providers):
|
||||||
provider = all_providers[selected_sub_index]
|
provider = all_providers[selected_sub_index]
|
||||||
break # Break from inner loop
|
break
|
||||||
else:
|
else:
|
||||||
click.secho("Invalid selection. Please try again.", fg="red")
|
click.secho("Invalid selection. Please try again.", fg="red")
|
||||||
except click.exceptions.Abort:
|
except click.exceptions.Abort:
|
||||||
click.secho("Operation aborted by the user.", fg="red")
|
click.secho("Operation aborted by the user.", fg="red")
|
||||||
return
|
return
|
||||||
# **Add this break to exit the outer loop**
|
break
|
||||||
break # Break from outer loop after selecting provider
|
|
||||||
else:
|
else:
|
||||||
provider = selected_provider.lower() # Ensure consistent casing
|
provider = selected_provider.lower()
|
||||||
break # Break from outer loop
|
break
|
||||||
else:
|
else:
|
||||||
click.secho("Invalid selection. Please try again.", fg="red")
|
click.secho("Invalid selection. Please try again.", fg="red")
|
||||||
except click.exceptions.Abort:
|
except click.exceptions.Abort:
|
||||||
click.secho("Operation aborted by the user.", fg="red")
|
click.secho("Operation aborted by the user.", fg="red")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Ensure that 'provider' is in lowercase without extra whitespace
|
|
||||||
provider = provider.strip().lower()
|
provider = provider.strip().lower()
|
||||||
|
|
||||||
# Handle model selection based on provider
|
|
||||||
if provider in predefined_providers:
|
if provider in predefined_providers:
|
||||||
available_models = MODELS.get(provider, [])
|
available_models = MODELS.get(provider, [])
|
||||||
else:
|
else:
|
||||||
available_models = provider_models.get(provider, [])
|
available_models = provider_models.get(provider, [])
|
||||||
|
|
||||||
# Add a debug message if no models are found
|
|
||||||
if not available_models:
|
if not available_models:
|
||||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||||
click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow")
|
click.secho(f"Available providers: {list(provider_models.keys())}", fg="yellow")
|
||||||
@@ -283,10 +262,9 @@ def create_crew(name, parent_folder=None):
|
|||||||
click.secho(f"API key already exists for {provider}.", fg="yellow")
|
click.secho(f"API key already exists for {provider}.", fg="yellow")
|
||||||
|
|
||||||
if model:
|
if model:
|
||||||
env_vars['MODEL'] = model # Use 'MODEL' as the key name
|
env_vars['MODEL'] = model
|
||||||
click.secho(f"Selected model: {model}", fg="green")
|
click.secho(f"Selected model: {model}", fg="green")
|
||||||
|
|
||||||
# Write the environment variables to .env file
|
|
||||||
with open(env_file_path, "w") as file:
|
with open(env_file_path, "w") as file:
|
||||||
for key, value in env_vars.items():
|
for key, value in env_vars.items():
|
||||||
file.write(f"{key}={value}\n")
|
file.write(f"{key}={value}\n")
|
||||||
@@ -294,7 +272,6 @@ def create_crew(name, parent_folder=None):
|
|||||||
package_dir = Path(__file__).parent
|
package_dir = Path(__file__).parent
|
||||||
templates_dir = package_dir / "templates" / "crew"
|
templates_dir = package_dir / "templates" / "crew"
|
||||||
|
|
||||||
# List of template files to copy
|
|
||||||
root_template_files = (
|
root_template_files = (
|
||||||
[".gitignore", "pyproject.toml", "README.md"] if not parent_folder else []
|
[".gitignore", "pyproject.toml", "README.md"] if not parent_folder else []
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user