mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
feat: Add model name validation and expand test coverage
- Add validation for Azure and Cerebras model names - Add validation handling in create_crew.py - Expand test coverage for model env var cases Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -67,6 +67,7 @@ ENV_VARS = {
|
||||
{
|
||||
"prompt": "Enter your Azure deployment name (must start with 'azure/')",
|
||||
"key_name": "MODEL", # Uppercase MODEL used for consistency across environment variables
|
||||
"validator": lambda x: x.startswith("azure/") or "Model name must start with 'azure/'"
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AZURE API key (press Enter to skip)",
|
||||
@@ -85,6 +86,7 @@ ENV_VARS = {
|
||||
{
|
||||
"prompt": "Enter your Cerebras model name (must start with 'cerebras/')",
|
||||
"key_name": "MODEL", # Uppercase MODEL used for consistency across environment variables
|
||||
"validator": lambda x: x.startswith("cerebras/") or "Model name must start with 'cerebras/'"
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your Cerebras API version (press Enter to skip)",
|
||||
|
||||
@@ -157,10 +157,19 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
# Prompt for non-default key-value pairs
|
||||
prompt = details["prompt"]
|
||||
key_name = details["key_name"]
|
||||
api_key_value = click.prompt(prompt, default="", show_default=False)
|
||||
while True:
|
||||
api_key_value = click.prompt(prompt, default="", show_default=False)
|
||||
if not api_key_value.strip():
|
||||
break
|
||||
|
||||
if "validator" in details:
|
||||
validation_result = details["validator"](api_key_value)
|
||||
if isinstance(validation_result, str):
|
||||
click.secho(f"Invalid input: {validation_result}", fg="red")
|
||||
continue
|
||||
|
||||
if api_key_value.strip():
|
||||
env_vars[key_name] = api_key_value
|
||||
break
|
||||
|
||||
if env_vars:
|
||||
write_env_file(folder_path, env_vars)
|
||||
|
||||
@@ -21,14 +21,16 @@ from crewai.utilities import RPMController
|
||||
from crewai.utilities.events import Emitter
|
||||
|
||||
|
||||
def test_agent_azure_model_env_var():
|
||||
"""Test Azure MODEL environment variable handling with various cases."""
|
||||
def test_agent_model_env_var():
|
||||
"""Test MODEL environment variable handling with various cases."""
|
||||
# Store original environment variables
|
||||
original_model = os.environ.get("MODEL")
|
||||
|
||||
test_cases = [
|
||||
("azure/test-model", "azure/test-model"), # Valid case
|
||||
("azure/minimal", "azure/minimal"), # Another valid case
|
||||
("azure/test-model", "azure/test-model"), # Valid Azure case
|
||||
("azure/minimal", "azure/minimal"), # Another valid Azure case
|
||||
("cerebras/test-model", "cerebras/test-model"), # Valid Cerebras case
|
||||
("cerebras/minimal", "cerebras/minimal"), # Another valid Cerebras case
|
||||
]
|
||||
|
||||
for input_model, expected_model in test_cases:
|
||||
|
||||
Reference in New Issue
Block a user