fix: reject reserved script names for crew folders

This commit is contained in:
Greyson LaLonde
2026-02-03 09:16:55 -05:00
committed by GitHub
parent 6a8483fcb6
commit c1d2801be2
4 changed files with 118 additions and 55 deletions

View File

@@ -1,10 +1,13 @@
from typing import Any
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com" DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos" CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8" CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX" CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com" CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
ENV_VARS = { ENV_VARS: dict[str, list[dict[str, Any]]] = {
"openai": [ "openai": [
{ {
"prompt": "Enter your OPENAI API key (press Enter to skip)", "prompt": "Enter your OPENAI API key (press Enter to skip)",
@@ -112,7 +115,7 @@ ENV_VARS = {
} }
PROVIDERS = [ PROVIDERS: list[str] = [
"openai", "openai",
"anthropic", "anthropic",
"gemini", "gemini",
@@ -127,7 +130,7 @@ PROVIDERS = [
"sambanova", "sambanova",
] ]
MODELS = { MODELS: dict[str, list[str]] = {
"openai": [ "openai": [
"gpt-4", "gpt-4",
"gpt-4.1", "gpt-4.1",

View File

@@ -3,6 +3,7 @@ import shutil
import sys import sys
import click import click
import tomli
from crewai.cli.constants import ENV_VARS, MODELS from crewai.cli.constants import ENV_VARS, MODELS
from crewai.cli.provider import ( from crewai.cli.provider import (
@@ -13,7 +14,31 @@ from crewai.cli.provider import (
from crewai.cli.utils import copy_template, load_env_vars, write_env_file from crewai.cli.utils import copy_template, load_env_vars, write_env_file
def create_folder_structure(name, parent_folder=None): def get_reserved_script_names() -> set[str]:
"""Get reserved script names from pyproject.toml template.
Returns:
Set of reserved script names that would conflict with crew folder names.
"""
package_dir = Path(__file__).parent
template_path = package_dir / "templates" / "crew" / "pyproject.toml"
with open(template_path, "r") as f:
template_content = f.read()
template_content = template_content.replace("{{folder_name}}", "_placeholder_")
template_content = template_content.replace("{{name}}", "placeholder")
template_content = template_content.replace("{{crew_name}}", "Placeholder")
template_data = tomli.loads(template_content)
script_names = set(template_data.get("project", {}).get("scripts", {}).keys())
script_names.discard("_placeholder_")
return script_names
def create_folder_structure(
name: str, parent_folder: str | None = None
) -> tuple[Path, str, str]:
import keyword import keyword
import re import re
@@ -51,6 +76,14 @@ def create_folder_structure(name, parent_folder=None):
f"Project name '{name}' would generate invalid Python module name '{folder_name}'" f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
) )
reserved_names = get_reserved_script_names()
if folder_name in reserved_names:
raise ValueError(
f"Project name '{name}' would generate folder name '{folder_name}' which is reserved. "
f"Reserved names are: {', '.join(sorted(reserved_names))}. "
"Please choose a different name."
)
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "") class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name) class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
@@ -114,7 +147,9 @@ def create_folder_structure(name, parent_folder=None):
return folder_path, folder_name, class_name return folder_path, folder_name, class_name
def copy_template_files(folder_path, name, class_name, parent_folder): def copy_template_files(
folder_path: Path, name: str, class_name: str, parent_folder: str | None
) -> None:
package_dir = Path(__file__).parent package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew" templates_dir = package_dir / "templates" / "crew"
@@ -155,7 +190,12 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
copy_template(src_file, dst_file, name, class_name, folder_path.name) copy_template(src_file, dst_file, name, class_name, folder_path.name)
def create_crew(name, provider=None, skip_provider=False, parent_folder=None): def create_crew(
name: str,
provider: str | None = None,
skip_provider: bool = False,
parent_folder: str | None = None,
) -> None:
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path) env_vars = load_env_vars(folder_path)
if not skip_provider: if not skip_provider:
@@ -189,7 +229,9 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
if selected_provider is None: # User typed 'q' if selected_provider is None: # User typed 'q'
click.secho("Exiting...", fg="yellow") click.secho("Exiting...", fg="yellow")
sys.exit(0) sys.exit(0)
if selected_provider: # Valid selection if selected_provider and isinstance(
selected_provider, str
): # Valid selection
break break
click.secho( click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red" "No provider selected. Please try again or press 'q' to exit.", fg="red"

View File

@@ -1,8 +1,10 @@
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence
import json import json
import os import os
from pathlib import Path from pathlib import Path
import time import time
from typing import Any
import certifi import certifi
import click import click
@@ -11,16 +13,15 @@ import requests
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message, choices): def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
""" """Presents a list of choices to the user and prompts them to select one.
Presents a list of choices to the user and prompts them to select one.
Args: Args:
- prompt_message (str): The message to display to the user before presenting the choices. prompt_message: The message to display to the user before presenting the choices.
- choices (list): A list of options to present to the user. choices: A list of options to present to the user.
Returns: Returns:
- str: The selected choice from the list, or None if the user chooses to quit. The selected choice from the list, or None if the user chooses to quit.
""" """
provider_models = get_provider_data() provider_models = get_provider_data()
@@ -52,16 +53,14 @@ def select_choice(prompt_message, choices):
) )
def select_provider(provider_models): def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
""" """Presents a list of providers to the user and prompts them to select one.
Presents a list of providers to the user and prompts them to select one.
Args: Args:
- provider_models (dict): A dictionary of provider models. provider_models: A dictionary of provider models.
Returns: Returns:
- str: The selected provider The selected provider, None if user explicitly quits, or False if no selection.
- None: If user explicitly quits
""" """
predefined_providers = [p.lower() for p in PROVIDERS] predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
@@ -80,16 +79,15 @@ def select_provider(provider_models):
return provider.lower() if provider else False return provider.lower() if provider else False
def select_model(provider, provider_models): def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
""" """Presents a list of models for a given provider to the user and prompts them to select one.
Presents a list of models for a given provider to the user and prompts them to select one.
Args: Args:
- provider (str): The provider for which to select a model. provider: The provider for which to select a model.
- provider_models (dict): A dictionary of provider models. provider_models: A dictionary of provider models.
Returns: Returns:
- str: The selected model, or None if the operation is aborted or an invalid selection is made. The selected model, or None if the operation is aborted or an invalid selection is made.
""" """
predefined_providers = [p.lower() for p in PROVIDERS] predefined_providers = [p.lower() for p in PROVIDERS]
@@ -107,16 +105,17 @@ def select_model(provider, provider_models):
) )
def load_provider_data(cache_file, cache_expiry): def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
""" """Loads provider data from a cache file if it exists and is not expired.
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
If the cache is expired or corrupted, it fetches the data from the web.
Args: Args:
- cache_file (Path): The path to the cache file. cache_file: The path to the cache file.
- cache_expiry (int): The cache expiry time in seconds. cache_expiry: The cache expiry time in seconds.
Returns: Returns:
- dict or None: The loaded provider data or None if the operation fails. The loaded provider data or None if the operation fails.
""" """
current_time = time.time() current_time = time.time()
if ( if (
@@ -137,32 +136,31 @@ def load_provider_data(cache_file, cache_expiry):
return fetch_provider_data(cache_file) return fetch_provider_data(cache_file)
def read_cache_file(cache_file): def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
""" """Reads and returns the JSON content from a cache file.
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
Args: Args:
- cache_file (Path): The path to the cache file. cache_file: The path to the cache file.
Returns: Returns:
- dict or None: The JSON content of the cache file or None if the JSON is invalid. The JSON content of the cache file or None if the JSON is invalid.
""" """
try: try:
with open(cache_file, "r") as f: with open(cache_file, "r") as f:
return json.load(f) data: dict[str, Any] = json.load(f)
return data
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
def fetch_provider_data(cache_file): def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
""" """Fetches provider data from a specified URL and caches it to a file.
Fetches provider data from a specified URL and caches it to a file.
Args: Args:
- cache_file (Path): The path to the cache file. cache_file: The path to the cache file.
Returns: Returns:
- dict or None: The fetched provider data or None if the operation fails. The fetched provider data or None if the operation fails.
""" """
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where() ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
@@ -180,36 +178,39 @@ def fetch_provider_data(cache_file):
return None return None
def download_data(response): def download_data(response: requests.Response) -> dict[str, Any]:
""" """Downloads data from a given HTTP response and returns the JSON content.
Downloads data from a given HTTP response and returns the JSON content.
Args: Args:
- response (requests.Response): The HTTP response object. response: The HTTP response object.
Returns: Returns:
- dict: The JSON content of the response. The JSON content of the response.
""" """
total_size = int(response.headers.get("content-length", 0)) total_size = int(response.headers.get("content-length", 0))
block_size = 8192 block_size = 8192
data_chunks = [] data_chunks: list[bytes] = []
bar: Any
with click.progressbar( with click.progressbar(
length=total_size, label="Downloading", show_pos=True length=total_size, label="Downloading", show_pos=True
) as progress_bar: ) as bar:
for chunk in response.iter_content(block_size): for chunk in response.iter_content(block_size):
if chunk: if chunk:
data_chunks.append(chunk) data_chunks.append(chunk)
progress_bar.update(len(chunk)) bar.update(len(chunk))
data_content = b"".join(data_chunks) data_content = b"".join(data_chunks)
return json.loads(data_content.decode("utf-8")) result: dict[str, Any] = json.loads(data_content.decode("utf-8"))
return result
def get_provider_data(): def get_provider_data() -> dict[str, list[str]] | None:
""" """Retrieves provider data from a cache file.
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
Filters out models based on provider criteria, and returns a dictionary of providers
mapped to their models.
Returns: Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails. A dictionary of providers mapped to their models or None if the operation fails.
""" """
cache_dir = Path.home() / ".crewai" cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True) cache_dir.mkdir(exist_ok=True)

View File

@@ -296,6 +296,23 @@ def test_create_folder_structure_folder_name_validation():
shutil.rmtree(folder_path) shutil.rmtree(folder_path)
def test_create_folder_structure_rejects_reserved_names():
"""Test that reserved script names are rejected to prevent pyproject.toml conflicts."""
with tempfile.TemporaryDirectory() as temp_dir:
reserved_names = ["test", "train", "replay", "run_crew", "run_with_trigger"]
for reserved_name in reserved_names:
with pytest.raises(ValueError, match="which is reserved"):
create_folder_structure(reserved_name, parent_folder=temp_dir)
with pytest.raises(ValueError, match="which is reserved"):
create_folder_structure(f"{reserved_name}/", parent_folder=temp_dir)
capitalized = reserved_name.capitalize()
with pytest.raises(ValueError, match="which is reserved"):
create_folder_structure(capitalized, parent_folder=temp_dir)
@mock.patch("crewai.cli.create_crew.create_folder_structure") @mock.patch("crewai.cli.create_crew.create_folder_structure")
@mock.patch("crewai.cli.create_crew.copy_template") @mock.patch("crewai.cli.create_crew.copy_template")
@mock.patch("crewai.cli.create_crew.load_env_vars") @mock.patch("crewai.cli.create_crew.load_env_vars")