From 775fea180b287bdc7854ee2a12a667d22c6cfd89 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 30 Oct 2024 11:30:45 -0400 Subject: [PATCH] getting cli and .env to work together for different models --- src/crewai/agent.py | 5 ++ src/crewai/cli/constants.py | 91 +++++++++++++++++++++++---- src/crewai/cli/create_crew.py | 40 +++++------- src/crewai/cli/templates/crew/crew.py | 2 +- src/crewai/cli/templates/crew/main.py | 4 ++ 5 files changed, 105 insertions(+), 37 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 937710f59..b487c3c5a 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -122,6 +122,9 @@ class Agent(BaseAgent): def post_init_setup(self): self.agent_ops_agent_name = self.role + print("IN POST INIT SETUP") + print("self.llm:", self.llm) + # Handle different cases for self.llm if isinstance(self.llm, str): # If it's a string, create an LLM instance @@ -130,6 +133,7 @@ class Agent(BaseAgent): # If it's already an LLM instance, keep it as is pass elif self.llm is None: + print("No LLM provided") # If it's None, use environment variables or default model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-4o-mini") llm_params = {"model": model_name} @@ -146,6 +150,7 @@ class Agent(BaseAgent): self.llm = LLM(**llm_params) else: + print("IN ELSE") # For any other type, attempt to extract relevant attributes llm_params = { "model": getattr(self.llm, "model_name", None) diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index 9a0b36c39..94932c0c7 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -1,19 +1,86 @@ ENV_VARS = { - 'openai': ['OPENAI_API_KEY'], - 'anthropic': ['ANTHROPIC_API_KEY'], - 'gemini': ['GEMINI_API_KEY'], - 'groq': ['GROQ_API_KEY'], - 'ollama': ['FAKE_KEY'], + "openai": [ + { + "prompt": "Enter your OPENAI API key (press Enter to skip)", + "key_name": "OPENAI_API_KEY", + } + ], + "anthropic": [ + { + "prompt": "Enter your ANTHROPIC API key (press Enter to skip)", + "key_name": "ANTHROPIC_API_KEY", + } + ], + "gemini": [ + { + "prompt": "Enter your GEMINI API key (press Enter to skip)", + "key_name": "GEMINI_API_KEY", + } + ], + "groq": [ + { + "prompt": "Enter your GROQ API key (press Enter to skip)", + "key_name": "GROQ_API_KEY", + } + ], + "watson": [ + { + "prompt": "Enter your WATSONX URL (press Enter to skip)", + "key_name": "WATSONX_URL", + }, + { + "prompt": "Enter your WATSONX API key (press Enter to skip)", + "key_name": "WATSONX_APIKEY", + }, + { + "prompt": "Enter your WATSONX token (press Enter to skip)", + "key_name": "WATSONX_TOKEN", + }, + ], } -PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama'] + +PROVIDERS = ["openai", "anthropic", "gemini", "groq", "ollama", "watson"] MODELS = { - 'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'], - 'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'], - 'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'], - 'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'], - 'ollama': ['llama3.1', 'mixtral'], + "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini", "o1-mini", "o1-preview"], + "anthropic": [ + "claude-3-5-sonnet-20240620", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", + "claude-3-haiku-20240307", + ], + "gemini": [ + "gemini-1.5-flash", + "gemini-1.5-pro", + "gemini-gemma-2-9b-it", + "gemini-gemma-2-27b-it", + ], + "groq": [ + "llama-3.1-8b-instant", + "llama-3.1-70b-versatile", + "llama-3.1-405b-reasoning", + "gemma2-9b-it", + "gemma-7b-it", + ], + "ollama": ["ollama/llama3.1", "ollama/mixtral"], + "watson": [ + "watsonx/google/flan-t5-xxl", + "watsonx/google/flan-ul2", + "watsonx/bigscience/mt0-xxl", + "watsonx/eleutherai/gpt-neox-20b", + "watsonx/ibm/mpt-7b-instruct2", + "watsonx/bigcode/starcoder", + "watsonx/meta-llama/llama-2-70b-chat", + "watsonx/meta-llama/llama-2-13b-chat", + "watsonx/ibm/granite-13b-instruct-v1", + "watsonx/ibm/granite-13b-chat-v1", + "watsonx/google/flan-t5-xl", + "watsonx/ibm/granite-13b-chat-v2", + "watsonx/ibm/granite-13b-instruct-v2", + "watsonx/elyza/elyza-japanese-llama-2-7b-instruct", + "watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q", + ], } -JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" \ No newline at end of file +JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index 5767b82a1..bbb34c74d 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -5,7 +5,6 @@ import click from crewai.cli.constants import ENV_VARS from crewai.cli.provider import ( - PROVIDERS, get_provider_data, select_model, select_provider, @@ -92,7 +91,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): existing_provider = None for provider, env_keys in ENV_VARS.items(): - if any(key in env_vars for key in env_keys): + if any(details["key_name"] in env_vars for details in env_keys): existing_provider = provider break @@ -129,35 +128,28 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): "No model selected. Please try again or press 'q' to exit.", fg="red" ) - if selected_provider in PROVIDERS: - api_key_var = ENV_VARS[selected_provider][0] - else: - api_key_var = click.prompt( - f"Enter the environment variable name for your {selected_provider.capitalize()} API key", - type=str, - default="", - ) + # Check if the selected provider requires API keys + if selected_provider in ENV_VARS: + provider_env_vars = ENV_VARS[selected_provider] + for details in provider_env_vars: + prompt = details["prompt"] + key_name = details["key_name"] + api_key_value = click.prompt(prompt, default="", show_default=False) - api_key_value = "" - click.echo( - f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ", - nl=False, - ) - try: - api_key_value = input() - except (KeyboardInterrupt, EOFError): - api_key_value = "" + if api_key_value.strip(): + env_vars[key_name] = api_key_value - if api_key_value.strip(): - env_vars = {api_key_var: api_key_value} + # Save the selected model to env_vars + env_vars["MODEL"] = selected_model + + if env_vars: write_env_file(folder_path, env_vars) - click.secho("API key saved to .env file", fg="green") + click.secho("API keys and model saved to .env file", fg="green") else: click.secho( - "No API key provided. Skipping .env file creation.", fg="yellow" + "No API keys provided. Skipping .env file creation.", fg="yellow" ) - env_vars["MODEL"] = selected_model click.secho(f"Selected model: {selected_model}", fg="green") package_dir = Path(__file__).parent diff --git a/src/crewai/cli/templates/crew/crew.py b/src/crewai/cli/templates/crew/crew.py index f950d13d4..392e29edd 100644 --- a/src/crewai/cli/templates/crew/crew.py +++ b/src/crewai/cli/templates/crew/crew.py @@ -48,4 +48,4 @@ class {{crew_name}}Crew(): process=Process.sequential, verbose=True, # process=Process.hierarchical, # In case you wanna use that instead https://docs.crewai.com/how-to/Hierarchical/ - ) \ No newline at end of file + ) diff --git a/src/crewai/cli/templates/crew/main.py b/src/crewai/cli/templates/crew/main.py index 88edfcbff..d441fa0fa 100644 --- a/src/crewai/cli/templates/crew/main.py +++ b/src/crewai/cli/templates/crew/main.py @@ -1,7 +1,11 @@ #!/usr/bin/env python import sys +import warnings + from {{folder_name}}.crew import {{crew_name}}Crew +warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd") + # This main file is intended to be a way for you to run your # crew locally, so refrain from adding unnecessary logic into this file. # Replace with inputs you want to test with, it will automatically