diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index b2f2307d1..654582b05 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -67,6 +67,7 @@ ENV_VARS = { { "prompt": "Enter your Azure deployment name (must start with 'azure/')", "key_name": "MODEL", # Uppercase MODEL used for consistency across environment variables + "validator": lambda x: x.startswith("azure/") or "Model name must start with 'azure/'" }, { "prompt": "Enter your AZURE API key (press Enter to skip)", @@ -85,6 +86,7 @@ ENV_VARS = { { "prompt": "Enter your Cerebras model name (must start with 'cerebras/')", "key_name": "MODEL", # Uppercase MODEL used for consistency across environment variables + "validator": lambda x: x.startswith("cerebras/") or "Model name must start with 'cerebras/'" }, { "prompt": "Enter your Cerebras API version (press Enter to skip)", diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index c658b0de1..036c3a77e 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -157,10 +157,19 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): # Prompt for non-default key-value pairs prompt = details["prompt"] key_name = details["key_name"] - api_key_value = click.prompt(prompt, default="", show_default=False) + while True: + api_key_value = click.prompt(prompt, default="", show_default=False) + if not api_key_value.strip(): + break + + if "validator" in details: + validation_result = details["validator"](api_key_value) + if isinstance(validation_result, str): + click.secho(f"Invalid input: {validation_result}", fg="red") + continue - if api_key_value.strip(): env_vars[key_name] = api_key_value + break if env_vars: write_env_file(folder_path, env_vars) diff --git a/tests/agent_test.py b/tests/agent_test.py index 6f9ec3574..5e5f7ba24 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -21,14 +21,16 @@ from crewai.utilities import RPMController from crewai.utilities.events import Emitter -def test_agent_azure_model_env_var(): - """Test Azure MODEL environment variable handling with various cases.""" +def test_agent_model_env_var(): + """Test MODEL environment variable handling with various cases.""" # Store original environment variables original_model = os.environ.get("MODEL") test_cases = [ - ("azure/test-model", "azure/test-model"), # Valid case - ("azure/minimal", "azure/minimal"), # Another valid case + ("azure/test-model", "azure/test-model"), # Valid Azure case + ("azure/minimal", "azure/minimal"), # Another valid Azure case + ("cerebras/test-model", "cerebras/test-model"), # Valid Cerebras case + ("cerebras/minimal", "cerebras/minimal"), # Another valid Cerebras case ] for input_model, expected_model in test_cases: