diff --git a/src/crewai/agent.py b/src/crewai/agent.py
index b487c3c5a..925ad2e05 100644
--- a/src/crewai/agent.py
+++ b/src/crewai/agent.py
@@ -8,6 +8,7 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
 from crewai.agents import CacheHandler
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.crew_agent_executor import CrewAgentExecutor
+from crewai.cli.constants import ENV_VARS
 from crewai.llm import LLM
 from crewai.memory.contextual.contextual_memory import ContextualMemory
 from crewai.tools.agent_tools import AgentTools
@@ -134,8 +135,12 @@ class Agent(BaseAgent):
             pass
         elif self.llm is None:
             print("No LLM provided")
-            # If it's None, use environment variables or default
-            model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-4o-mini")
+            # Determine the model name from environment variables or use default
+            model_name = (
+                os.environ.get("OPENAI_MODEL_NAME")
+                or os.environ.get("MODEL")
+                or "gpt-4o-mini"
+            )
             llm_params = {"model": model_name}
 
             api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get(
@@ -144,10 +149,29 @@ class Agent(BaseAgent):
             if api_base:
                 llm_params["base_url"] = api_base
 
-            api_key = os.environ.get("OPENAI_API_KEY")
-            if api_key:
-                llm_params["api_key"] = api_key
+            # Iterate over all environment variables to find matching API keys or use defaults
+            for provider, env_vars in ENV_VARS.items():
+                for env_var in env_vars:
+                    # Check if the environment variable is set
+                    if "key_name" in env_var:
+                        env_value = os.environ.get(env_var["key_name"])
+                        if env_value:
+                            # Map key names containing "API_KEY" to "api_key"
+                            key_name = (
+                                "api_key"
+                                if "API_KEY" in env_var["key_name"]
+                                else env_var["key_name"]
+                            )
+                            llm_params[key_name] = env_value
+                    # Check for default values if the environment variable is not set
+                    elif env_var.get("default", False):
+                        for key, value in env_var.items():
+                            if key not in ["prompt", "key_name", "default"]:
+                                # Only add default if the key is already set in os.environ
+                                if key in os.environ:
+                                    llm_params[key] = value
 
+            print("LLM PARAMS:", llm_params)
             self.llm = LLM(**llm_params)
         else:
             print("IN ELSE")
diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py
index 94932c0c7..39b024792 100644
--- a/src/crewai/cli/constants.py
+++ b/src/crewai/cli/constants.py
@@ -37,10 +37,57 @@ ENV_VARS = {
             "key_name": "WATSONX_TOKEN",
         },
     ],
+    "ollama": [
+        {
+            "default": True,
+            "API_BASE": "http://localhost:11434",
+        }
+    ],
+    "bedrock": [
+        {
+            "prompt": "Enter your AWS Access Key ID (press Enter to skip)",
+            "key_name": "AWS_ACCESS_KEY_ID",
+        },
+        {
+            "prompt": "Enter your AWS Secret Access Key (press Enter to skip)",
+            "key_name": "AWS_SECRET_ACCESS_KEY",
+        },
+        {
+            "prompt": "Enter your AWS Region Name (press Enter to skip)",
+            "key_name": "AWS_REGION_NAME",
+        },
+    ],
+    "azure": [
+        {
+            "prompt": "Enter your Azure deployment name (must start with 'azure/')",
+            "key_name": "model",
+        },
+        {
+            "prompt": "Enter your AZURE API key (press Enter to skip)",
+            "key_name": "AZURE_API_KEY",
+        },
+        {
+            "prompt": "Enter your AZURE API base URL (press Enter to skip)",
+            "key_name": "AZURE_API_BASE",
+        },
+        {
+            "prompt": "Enter your AZURE API version (press Enter to skip)",
+            "key_name": "AZURE_API_VERSION",
+        },
+    ],
 }
 
 
-PROVIDERS = ["openai", "anthropic", "gemini", "groq", "ollama", "watson"]
+PROVIDERS = [
+    "openai",
+    "anthropic",
+    "gemini",
+    "groq",
+    "ollama",
+    "watson",
+    "bedrock",
+    "azure",
+]
 
 MODELS = {
     "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini", "o1-mini", "o1-preview"],
@@ -51,17 +98,17 @@ MODELS = {
         "claude-3-haiku-20240307",
     ],
     "gemini": [
-        "gemini-1.5-flash",
-        "gemini-1.5-pro",
-        "gemini-gemma-2-9b-it",
-        "gemini-gemma-2-27b-it",
+        "gemini/gemini-1.5-flash",
+        "gemini/gemini-1.5-pro",
+        "gemini/gemini-gemma-2-9b-it",
+        "gemini/gemini-gemma-2-27b-it",
     ],
     "groq": [
-        "llama-3.1-8b-instant",
-        "llama-3.1-70b-versatile",
-        "llama-3.1-405b-reasoning",
-        "gemma2-9b-it",
-        "gemma-7b-it",
+        "groq/llama-3.1-8b-instant",
+        "groq/llama-3.1-70b-versatile",
+        "groq/llama-3.1-405b-reasoning",
+        "groq/gemma2-9b-it",
+        "groq/gemma-7b-it",
     ],
     "ollama": ["ollama/llama3.1", "ollama/mixtral"],
     "watson": [
@@ -81,6 +128,30 @@ MODELS = {
         "watsonx/elyza/elyza-japanese-llama-2-7b-instruct",
         "watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q",
     ],
+    "bedrock": [
+        "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+        "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+        "bedrock/anthropic.claude-3-opus-20240229-v1:0",
+        "bedrock/anthropic.claude-v2:1",
+        "bedrock/anthropic.claude-v2",
+        "bedrock/anthropic.claude-instant-v1",
+        "bedrock/meta.llama3-1-405b-instruct-v1:0",
+        "bedrock/meta.llama3-1-70b-instruct-v1:0",
+        "bedrock/meta.llama3-1-8b-instruct-v1:0",
+        "bedrock/meta.llama3-70b-instruct-v1:0",
+        "bedrock/meta.llama3-8b-instruct-v1:0",
+        "bedrock/amazon.titan-text-lite-v1",
+        "bedrock/amazon.titan-text-express-v1",
+        "bedrock/cohere.command-text-v14",
+        "bedrock/ai21.j2-mid-v1",
+        "bedrock/ai21.j2-ultra-v1",
+        "bedrock/ai21.jamba-instruct-v1:0",
+        "bedrock/meta.llama2-13b-chat-v1",
+        "bedrock/meta.llama2-70b-chat-v1",
+        "bedrock/mistral.mistral-7b-instruct-v0:2",
+        "bedrock/mistral.mixtral-8x7b-instruct-v0:1",
+    ],
 }
 
 JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py
index bbb34c74d..06440d74e 100644
--- a/src/crewai/cli/create_crew.py
+++ b/src/crewai/cli/create_crew.py
@@ -1,9 +1,10 @@
+import shutil
 import sys
 from pathlib import Path
 
 import click
 
-from crewai.cli.constants import ENV_VARS
+from crewai.cli.constants import ENV_VARS, MODELS
 from crewai.cli.provider import (
     get_provider_data,
     select_model,
@@ -28,20 +29,20 @@ def create_folder_structure(name, parent_folder=None):
             click.secho("Operation cancelled.", fg="yellow")
             sys.exit(0)
         click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True)
-    else:
-        click.secho(
-            f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
-            fg="green",
-            bold=True,
-        )
+        shutil.rmtree(folder_path)  # Delete the existing folder and its contents
 
-    if not folder_path.exists():
-        folder_path.mkdir(parents=True)
-        (folder_path / "tests").mkdir(exist_ok=True)
-        if not parent_folder:
-            (folder_path / "src" / folder_name).mkdir(parents=True)
-            (folder_path / "src" / folder_name / "tools").mkdir(parents=True)
-            (folder_path / "src" / folder_name / "config").mkdir(parents=True)
+    click.secho(
+        f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
+        fg="green",
+        bold=True,
+    )
+
+    folder_path.mkdir(parents=True)
+    (folder_path / "tests").mkdir(exist_ok=True)
+    if not parent_folder:
+        (folder_path / "src" / folder_name).mkdir(parents=True)
+        (folder_path / "src" / folder_name / "tools").mkdir(parents=True)
+        (folder_path / "src" / folder_name / "config").mkdir(parents=True)
 
     return folder_path, folder_name, class_name
 
@@ -91,7 +92,10 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
 
         existing_provider = None
         for provider, env_keys in ENV_VARS.items():
-            if any(details["key_name"] in env_vars for details in env_keys):
+            if any(
+                "key_name" in details and details["key_name"] in env_vars
+                for details in env_keys
+            ):
                 existing_provider = provider
                 break
 
@@ -117,30 +121,38 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
                 "No provider selected. Please try again or press 'q' to exit.", fg="red"
            )
 
-        while True:
-            selected_model = select_model(selected_provider, provider_models)
-            if selected_model is None:  # User typed 'q'
-                click.secho("Exiting...", fg="yellow")
-                sys.exit(0)
-            if selected_model:  # Valid selection
-                break
-            click.secho(
-                "No model selected. Please try again or press 'q' to exit.", fg="red"
-            )
+        # Check if the selected provider has predefined models
+        if selected_provider in MODELS and MODELS[selected_provider]:
+            while True:
+                selected_model = select_model(selected_provider, provider_models)
+                if selected_model is None:  # User typed 'q'
+                    click.secho("Exiting...", fg="yellow")
+                    sys.exit(0)
+                if selected_model:  # Valid selection
+                    break
+                click.secho(
+                    "No model selected. Please try again or press 'q' to exit.",
+                    fg="red",
+                )
+            env_vars["MODEL"] = selected_model
 
         # Check if the selected provider requires API keys
         if selected_provider in ENV_VARS:
             provider_env_vars = ENV_VARS[selected_provider]
             for details in provider_env_vars:
-                prompt = details["prompt"]
-                key_name = details["key_name"]
-                api_key_value = click.prompt(prompt, default="", show_default=False)
+                if details.get("default", False):
+                    # Automatically add default key-value pairs
+                    for key, value in details.items():
+                        if key not in ["prompt", "key_name", "default"]:
+                            env_vars[key] = value
+                elif "key_name" in details:
+                    # Prompt for non-default key-value pairs
+                    prompt = details["prompt"]
+                    key_name = details["key_name"]
+                    api_key_value = click.prompt(prompt, default="", show_default=False)
 
-                if api_key_value.strip():
-                    env_vars[key_name] = api_key_value
-
-        # Save the selected model to env_vars
-        env_vars["MODEL"] = selected_model
+                    if api_key_value.strip():
+                        env_vars[key_name] = api_key_value
 
         if env_vars:
             write_env_file(folder_path, env_vars)
@@ -150,7 +162,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
                 "No API keys provided. Skipping .env file creation.", fg="yellow"
             )
 
-        click.secho(f"Selected model: {selected_model}", fg="green")
+        click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")
 
     package_dir = Path(__file__).parent
     templates_dir = package_dir / "templates" / "crew"
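Below is a minimal, self-contained sketch (not part of the diff) of how the ENV_VARS-driven lookup added to agent.py is expected to assemble LLM parameters; the ENV_VARS subset, the build_llm_params helper, and the environment values used here are illustrative assumptions, not the crewai API itself.

# Illustrative sketch only: mirrors the new loop in agent.py against an assumed ENV_VARS subset.
import os

ENV_VARS = {
    "azure": [
        {"prompt": "Enter your AZURE API key (press Enter to skip)", "key_name": "AZURE_API_KEY"},
    ],
    "ollama": [
        {"default": True, "API_BASE": "http://localhost:11434"},
    ],
}


def build_llm_params(model_name: str) -> dict:
    llm_params = {"model": model_name}
    for _provider, env_vars in ENV_VARS.items():
        for env_var in env_vars:
            if "key_name" in env_var:
                # A set variable such as AZURE_API_KEY is mapped to the generic "api_key" param
                env_value = os.environ.get(env_var["key_name"])
                if env_value:
                    key = "api_key" if "API_KEY" in env_var["key_name"] else env_var["key_name"]
                    llm_params[key] = env_value
            elif env_var.get("default", False):
                # Defaults (e.g. the ollama API_BASE) are only applied when the
                # corresponding variable name is itself present in os.environ
                for key, value in env_var.items():
                    if key not in ["prompt", "key_name", "default"] and key in os.environ:
                        llm_params[key] = value
    return llm_params


os.environ["AZURE_API_KEY"] = "sk-example"  # assumed value for the demo
print(build_llm_params("azure/gpt-4o"))
# -> {'model': 'azure/gpt-4o', 'api_key': 'sk-example'}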