Compare commits

...

2 Commits

Author SHA1 Message Date
Lorenze Jay
1b7c5d1821 Merge branch 'main' into fix/cli-create-provider-flag 2025-04-01 10:23:40 -07:00
theCyberTech
fcaf0d264f fix(cli): ensure create_crew respects --provider flag 2025-03-31 08:21:58 +08:00
2 changed files with 196 additions and 70 deletions

View File

@@ -93,50 +93,66 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path)
if not skip_provider:
if not provider:
provider_models = get_provider_data()
if not provider_models:
return
existing_provider = None
for provider, env_keys in ENV_VARS.items():
if any(
"key_name" in details and details["key_name"] in env_vars
for details in env_keys
):
existing_provider = provider
break
if existing_provider:
if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
):
click.secho("Keeping existing provider configuration.", fg="yellow")
return
provider_models = get_provider_data()
if not provider_models:
click.secho("Could not retrieve provider data.", fg="red")
return
while True:
selected_provider = select_provider(provider_models)
if selected_provider is None: # User typed 'q'
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_provider: # Valid selection
break
click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red"
)
selected_provider = None
if provider:
provider = provider.lower()
if provider in provider_models:
selected_provider = provider
click.secho(f"Using specified provider: {selected_provider.capitalize()}", fg="green")
else:
click.secho(f"Warning: Specified provider '{provider}' is not recognized. Please select one.", fg="yellow")
if not selected_provider:
existing_provider = None
for p, env_keys in ENV_VARS.items():
if any(
"key_name" in details and details["key_name"] in env_vars
for details in env_keys
):
existing_provider = p
break
if existing_provider:
if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
):
click.secho("Keeping existing provider configuration. Exiting provider setup.", fg="yellow")
copy_template_files(folder_path, name, class_name, parent_folder)
click.secho(f"Crew '{name}' created successfully!", fg="green")
click.secho(f"To run your crew, cd into '{folder_name}' and run 'crewai run'", fg="cyan")
return
else:
pass
while True:
selected_provider = select_provider(provider_models)
if selected_provider is None:
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_provider:
break
click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red"
)
if not selected_provider:
click.secho("Provider selection failed. Exiting.", fg="red")
sys.exit(1)
# Check if the selected provider has predefined models
if selected_provider in MODELS and MODELS[selected_provider]:
while True:
selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q'
if selected_model is None:
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_model: # Valid selection
if selected_model:
break
click.secho(
"No model selected. Please try again or press 'q' to exit.",
@@ -144,17 +160,14 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
)
env_vars["MODEL"] = selected_model
# Check if the selected provider requires API keys
if selected_provider in ENV_VARS:
provider_env_vars = ENV_VARS[selected_provider]
for details in provider_env_vars:
if details.get("default", False):
# Automatically add default key-value pairs
for key, value in details.items():
if key not in ["prompt", "key_name", "default"]:
env_vars[key] = value
elif "key_name" in details:
# Prompt for non-default key-value pairs
prompt = details["prompt"]
key_name = details["key_name"]
api_key_value = click.prompt(prompt, default="", show_default=False)
@@ -167,41 +180,12 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
click.secho("API keys and model saved to .env file", fg="green")
else:
click.secho(
"No API keys provided. Skipping .env file creation.", fg="yellow"
"No API keys provided or required by provider. Skipping .env file creation.", fg="yellow"
)
click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")
package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew"
copy_template_files(folder_path, name, class_name, parent_folder)
root_template_files = (
[".gitignore", "pyproject.toml", "README.md", "knowledge/user_preference.txt"]
if not parent_folder
else []
)
tools_template_files = ["tools/custom_tool.py", "tools/__init__.py"]
config_template_files = ["config/agents.yaml", "config/tasks.yaml"]
src_template_files = (
["__init__.py", "main.py", "crew.py"] if not parent_folder else ["crew.py"]
)
for file_name in root_template_files:
src_file = templates_dir / file_name
dst_file = folder_path / file_name
copy_template(src_file, dst_file, name, class_name, folder_name)
src_folder = folder_path / "src" / folder_name if not parent_folder else folder_path
for file_name in src_template_files:
src_file = templates_dir / file_name
dst_file = src_folder / file_name
copy_template(src_file, dst_file, name, class_name, folder_name)
if not parent_folder:
for file_name in tools_template_files + config_template_files:
src_file = templates_dir / file_name
dst_file = src_folder / file_name
copy_template(src_file, dst_file, name, class_name, folder_name)
click.secho(f"Crew {name} created successfully!", fg="green", bold=True)
click.secho(f"Crew '{name}' created successfully!", fg="green")
click.secho(f"To run your crew, cd into '{folder_name}' and run 'crewai run'", fg="cyan")

View File

@@ -0,0 +1,142 @@
import pytest
from click.testing import CliRunner
from unittest.mock import patch, MagicMock
from pathlib import Path
import sys
# Ensure the src directory is in the Python path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'src'))
from crewai.cli.cli import crewai
from crewai.cli import create_crew
from crewai.cli.constants import MODELS, ENV_VARS
# Mock provider data for testing
MOCK_PROVIDER_DATA = {
'openai': {'models': ['gpt-4', 'gpt-3.5-turbo']},
'google': {'models': ['gemini-pro']},
'anthropic': {'models': ['claude-3-opus']}
}
MOCK_VALID_PROVIDERS = list(MOCK_PROVIDER_DATA.keys())
@pytest.fixture
def runner():
return CliRunner()
@pytest.fixture(autouse=True)
def isolate_fs(monkeypatch):
# Prevent tests from interacting with the actual filesystem or real env vars
monkeypatch.setattr(Path, 'mkdir', lambda *args, **kwargs: None)
monkeypatch.setattr(Path, 'exists', lambda *args: False) # Assume folders don't exist initially
monkeypatch.setattr(create_crew, 'load_env_vars', lambda *args: {}) # Start with empty env vars
monkeypatch.setattr(create_crew, 'write_env_file', lambda *args, **kwargs: None)
monkeypatch.setattr(create_crew, 'copy_template_files', lambda *args, **kwargs: None)
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider')
@patch('crewai.cli.create_crew.select_model')
@patch('click.prompt')
@patch('click.confirm', return_value=True) # Default to confirming prompts
def test_create_crew_with_valid_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
"""Test `crewai create crew <name> --provider <valid_provider>`"""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--provider', 'openai'])
print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
assert "Using specified provider: Openai" in result.output
mock_select_provider.assert_not_called() # Should not ask interactively
# Depending on whether openai needs models/keys, check select_model/prompt calls
assert "Crew 'testcrew' created successfully!" in result.output
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider', return_value='google') # Simulate user selecting google
@patch('crewai.cli.create_crew.select_model', return_value='gemini-pro')
@patch('click.prompt')
@patch('click.confirm', return_value=True)
def test_create_crew_with_invalid_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
"""Test `crewai create crew <name> --provider <invalid_provider>`"""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--provider', 'invalidprovider'])
print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
assert "Warning: Specified provider 'invalidprovider' is not recognized." in result.output
mock_select_provider.assert_called_once() # Should ask interactively
# Check if subsequent steps for the selected provider (google) ran
mock_select_model.assert_called_once()
assert "Crew 'testcrew' created successfully!" in result.output
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider', return_value='anthropic') # Simulate user selecting anthropic
@patch('crewai.cli.create_crew.select_model', return_value='claude-3-opus')
@patch('click.prompt', return_value='sk-abc') # Simulate API key entry
@patch('click.confirm', return_value=True)
def test_create_crew_no_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
"""Test `crewai create crew <name>`"""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew'])
print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
assert "Using specified provider:" not in result.output # Should not mention specified provider
mock_select_provider.assert_called_once() # Should ask interactively
mock_select_model.assert_called_once()
# Check if prompt for API key was called (assuming anthropic needs one)
if 'anthropic' in ENV_VARS and any('key_name' in d for d in ENV_VARS['anthropic']):
mock_prompt.assert_called()
assert "Crew 'testcrew' created successfully!" in result.output
@patch('crewai.cli.create_crew.get_provider_data')
@patch('crewai.cli.create_crew.select_provider')
@patch('crewai.cli.create_crew.select_model')
@patch('click.prompt')
@patch('click.confirm')
def test_create_crew_skip_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
"""Test `crewai create crew <name> --skip_provider`"""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--skip_provider'])
print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
mock_get_data.assert_not_called()
mock_select_provider.assert_not_called()
mock_select_model.assert_not_called()
mock_prompt.assert_not_called()
mock_confirm.assert_not_called()
assert "Crew 'testcrew' created successfully!" in result.output
@patch('crewai.cli.create_crew.load_env_vars', return_value={'OPENAI_API_KEY': 'existing_key'}) # Simulate existing env
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider', return_value='google') # Simulate selecting new provider
@patch('crewai.cli.create_crew.select_model', return_value='gemini-pro')
@patch('click.prompt')
@patch('click.confirm', return_value=True) # User confirms override
def test_create_crew_existing_override(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, mock_load_env, runner):
"""Test `crewai create crew <name>` with existing config and user overrides."""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew'])
print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
mock_confirm.assert_called_once_with(
'Found existing environment variable configuration for Openai. Do you want to override it?'
)
mock_select_provider.assert_called_once() # Should ask for new provider after confirming override
assert "Crew 'testcrew' created successfully!" in result.output
@patch('crewai.cli.create_crew.load_env_vars', return_value={'OPENAI_API_KEY': 'existing_key'}) # Simulate existing env
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider')
@patch('crewai.cli.create_crew.select_model')
@patch('click.prompt')
@patch('click.confirm', return_value=False) # User denies override
def test_create_crew_existing_keep(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, mock_load_env, runner):
"""Test `crewai create crew <name>` with existing config and user keeps it."""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew'])
print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
mock_confirm.assert_called_once_with(
'Found existing environment variable configuration for Openai. Do you want to override it?'
)
assert "Keeping existing provider configuration. Exiting provider setup." in result.output
mock_select_provider.assert_not_called() # Should NOT ask for new provider
assert "Crew 'testcrew' created successfully!" in result.output