feat: Add model name validation and expand test coverage

- Add validation for Azure and Cerebras model names
- Add validation handling in create_crew.py
- Expand test coverage for model env var cases

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-15 16:12:22 +00:00
parent e1ed85d7bd
commit 583e6584eb
3 changed files with 19 additions and 6 deletions

View File

@@ -67,6 +67,7 @@ ENV_VARS = {
{
"prompt": "Enter your Azure deployment name (must start with 'azure/')",
"key_name": "MODEL", # Uppercase MODEL used for consistency across environment variables
"validator": lambda x: x.startswith("azure/") or "Model name must start with 'azure/'"
},
{
"prompt": "Enter your AZURE API key (press Enter to skip)",
@@ -85,6 +86,7 @@ ENV_VARS = {
{
"prompt": "Enter your Cerebras model name (must start with 'cerebras/')",
"key_name": "MODEL", # Uppercase MODEL used for consistency across environment variables
"validator": lambda x: x.startswith("cerebras/") or "Model name must start with 'cerebras/'"
},
{
"prompt": "Enter your Cerebras API version (press Enter to skip)",

View File

@@ -157,10 +157,19 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
# Prompt for non-default key-value pairs
prompt = details["prompt"]
key_name = details["key_name"]
api_key_value = click.prompt(prompt, default="", show_default=False)
while True:
api_key_value = click.prompt(prompt, default="", show_default=False)
if not api_key_value.strip():
break
if "validator" in details:
validation_result = details["validator"](api_key_value)
if isinstance(validation_result, str):
click.secho(f"Invalid input: {validation_result}", fg="red")
continue
if api_key_value.strip():
env_vars[key_name] = api_key_value
break
if env_vars:
write_env_file(folder_path, env_vars)

View File

@@ -21,14 +21,16 @@ from crewai.utilities import RPMController
from crewai.utilities.events import Emitter
def test_agent_azure_model_env_var():
"""Test Azure MODEL environment variable handling with various cases."""
def test_agent_model_env_var():
"""Test MODEL environment variable handling with various cases."""
# Store original environment variables
original_model = os.environ.get("MODEL")
test_cases = [
("azure/test-model", "azure/test-model"), # Valid case
("azure/minimal", "azure/minimal"), # Another valid case
("azure/test-model", "azure/test-model"), # Valid Azure case
("azure/minimal", "azure/minimal"), # Another valid Azure case
("cerebras/test-model", "cerebras/test-model"), # Valid Cerebras case
("cerebras/minimal", "cerebras/minimal"), # Another valid Cerebras case
]
for input_model, expected_model in test_cases: