mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-06 01:32:36 +00:00
fix: reject reserved script names for crew folders
This commit is contained in:
@@ -1,10 +1,13 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
|
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
|
||||||
|
|
||||||
ENV_VARS = {
|
ENV_VARS: dict[str, list[dict[str, Any]]] = {
|
||||||
"openai": [
|
"openai": [
|
||||||
{
|
{
|
||||||
"prompt": "Enter your OPENAI API key (press Enter to skip)",
|
"prompt": "Enter your OPENAI API key (press Enter to skip)",
|
||||||
@@ -112,7 +115,7 @@ ENV_VARS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
PROVIDERS = [
|
PROVIDERS: list[str] = [
|
||||||
"openai",
|
"openai",
|
||||||
"anthropic",
|
"anthropic",
|
||||||
"gemini",
|
"gemini",
|
||||||
@@ -127,7 +130,7 @@ PROVIDERS = [
|
|||||||
"sambanova",
|
"sambanova",
|
||||||
]
|
]
|
||||||
|
|
||||||
MODELS = {
|
MODELS: dict[str, list[str]] = {
|
||||||
"openai": [
|
"openai": [
|
||||||
"gpt-4",
|
"gpt-4",
|
||||||
"gpt-4.1",
|
"gpt-4.1",
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
import tomli
|
||||||
|
|
||||||
from crewai.cli.constants import ENV_VARS, MODELS
|
from crewai.cli.constants import ENV_VARS, MODELS
|
||||||
from crewai.cli.provider import (
|
from crewai.cli.provider import (
|
||||||
@@ -13,7 +14,31 @@ from crewai.cli.provider import (
|
|||||||
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
|
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
|
||||||
|
|
||||||
|
|
||||||
def create_folder_structure(name, parent_folder=None):
|
def get_reserved_script_names() -> set[str]:
|
||||||
|
"""Get reserved script names from pyproject.toml template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of reserved script names that would conflict with crew folder names.
|
||||||
|
"""
|
||||||
|
package_dir = Path(__file__).parent
|
||||||
|
template_path = package_dir / "templates" / "crew" / "pyproject.toml"
|
||||||
|
|
||||||
|
with open(template_path, "r") as f:
|
||||||
|
template_content = f.read()
|
||||||
|
|
||||||
|
template_content = template_content.replace("{{folder_name}}", "_placeholder_")
|
||||||
|
template_content = template_content.replace("{{name}}", "placeholder")
|
||||||
|
template_content = template_content.replace("{{crew_name}}", "Placeholder")
|
||||||
|
|
||||||
|
template_data = tomli.loads(template_content)
|
||||||
|
script_names = set(template_data.get("project", {}).get("scripts", {}).keys())
|
||||||
|
script_names.discard("_placeholder_")
|
||||||
|
return script_names
|
||||||
|
|
||||||
|
|
||||||
|
def create_folder_structure(
|
||||||
|
name: str, parent_folder: str | None = None
|
||||||
|
) -> tuple[Path, str, str]:
|
||||||
import keyword
|
import keyword
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -51,6 +76,14 @@ def create_folder_structure(name, parent_folder=None):
|
|||||||
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
|
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
reserved_names = get_reserved_script_names()
|
||||||
|
if folder_name in reserved_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Project name '{name}' would generate folder name '{folder_name}' which is reserved. "
|
||||||
|
f"Reserved names are: {', '.join(sorted(reserved_names))}. "
|
||||||
|
"Please choose a different name."
|
||||||
|
)
|
||||||
|
|
||||||
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||||
|
|
||||||
class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
|
class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
|
||||||
@@ -114,7 +147,9 @@ def create_folder_structure(name, parent_folder=None):
|
|||||||
return folder_path, folder_name, class_name
|
return folder_path, folder_name, class_name
|
||||||
|
|
||||||
|
|
||||||
def copy_template_files(folder_path, name, class_name, parent_folder):
|
def copy_template_files(
|
||||||
|
folder_path: Path, name: str, class_name: str, parent_folder: str | None
|
||||||
|
) -> None:
|
||||||
package_dir = Path(__file__).parent
|
package_dir = Path(__file__).parent
|
||||||
templates_dir = package_dir / "templates" / "crew"
|
templates_dir = package_dir / "templates" / "crew"
|
||||||
|
|
||||||
@@ -155,7 +190,12 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
|
|||||||
copy_template(src_file, dst_file, name, class_name, folder_path.name)
|
copy_template(src_file, dst_file, name, class_name, folder_path.name)
|
||||||
|
|
||||||
|
|
||||||
def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
def create_crew(
|
||||||
|
name: str,
|
||||||
|
provider: str | None = None,
|
||||||
|
skip_provider: bool = False,
|
||||||
|
parent_folder: str | None = None,
|
||||||
|
) -> None:
|
||||||
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
|
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
|
||||||
env_vars = load_env_vars(folder_path)
|
env_vars = load_env_vars(folder_path)
|
||||||
if not skip_provider:
|
if not skip_provider:
|
||||||
@@ -189,7 +229,9 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
|||||||
if selected_provider is None: # User typed 'q'
|
if selected_provider is None: # User typed 'q'
|
||||||
click.secho("Exiting...", fg="yellow")
|
click.secho("Exiting...", fg="yellow")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
if selected_provider: # Valid selection
|
if selected_provider and isinstance(
|
||||||
|
selected_provider, str
|
||||||
|
): # Valid selection
|
||||||
break
|
break
|
||||||
click.secho(
|
click.secho(
|
||||||
"No provider selected. Please try again or press 'q' to exit.", fg="red"
|
"No provider selected. Please try again or press 'q' to exit.", fg="red"
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Sequence
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import certifi
|
import certifi
|
||||||
import click
|
import click
|
||||||
@@ -11,16 +13,15 @@ import requests
|
|||||||
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
|
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
def select_choice(prompt_message, choices):
|
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
|
||||||
"""
|
"""Presents a list of choices to the user and prompts them to select one.
|
||||||
Presents a list of choices to the user and prompts them to select one.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- prompt_message (str): The message to display to the user before presenting the choices.
|
prompt_message: The message to display to the user before presenting the choices.
|
||||||
- choices (list): A list of options to present to the user.
|
choices: A list of options to present to the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- str: The selected choice from the list, or None if the user chooses to quit.
|
The selected choice from the list, or None if the user chooses to quit.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
provider_models = get_provider_data()
|
provider_models = get_provider_data()
|
||||||
@@ -52,16 +53,14 @@ def select_choice(prompt_message, choices):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def select_provider(provider_models):
|
def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
|
||||||
"""
|
"""Presents a list of providers to the user and prompts them to select one.
|
||||||
Presents a list of providers to the user and prompts them to select one.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- provider_models (dict): A dictionary of provider models.
|
provider_models: A dictionary of provider models.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- str: The selected provider
|
The selected provider, None if user explicitly quits, or False if no selection.
|
||||||
- None: If user explicitly quits
|
|
||||||
"""
|
"""
|
||||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||||
@@ -80,16 +79,15 @@ def select_provider(provider_models):
|
|||||||
return provider.lower() if provider else False
|
return provider.lower() if provider else False
|
||||||
|
|
||||||
|
|
||||||
def select_model(provider, provider_models):
|
def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
|
||||||
"""
|
"""Presents a list of models for a given provider to the user and prompts them to select one.
|
||||||
Presents a list of models for a given provider to the user and prompts them to select one.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- provider (str): The provider for which to select a model.
|
provider: The provider for which to select a model.
|
||||||
- provider_models (dict): A dictionary of provider models.
|
provider_models: A dictionary of provider models.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
|
The selected model, or None if the operation is aborted or an invalid selection is made.
|
||||||
"""
|
"""
|
||||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||||
|
|
||||||
@@ -107,16 +105,17 @@ def select_model(provider, provider_models):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_provider_data(cache_file, cache_expiry):
|
def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
|
||||||
"""
|
"""Loads provider data from a cache file if it exists and is not expired.
|
||||||
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
|
|
||||||
|
If the cache is expired or corrupted, it fetches the data from the web.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- cache_file (Path): The path to the cache file.
|
cache_file: The path to the cache file.
|
||||||
- cache_expiry (int): The cache expiry time in seconds.
|
cache_expiry: The cache expiry time in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict or None: The loaded provider data or None if the operation fails.
|
The loaded provider data or None if the operation fails.
|
||||||
"""
|
"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
if (
|
if (
|
||||||
@@ -137,32 +136,31 @@ def load_provider_data(cache_file, cache_expiry):
|
|||||||
return fetch_provider_data(cache_file)
|
return fetch_provider_data(cache_file)
|
||||||
|
|
||||||
|
|
||||||
def read_cache_file(cache_file):
|
def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
|
||||||
"""
|
"""Reads and returns the JSON content from a cache file.
|
||||||
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- cache_file (Path): The path to the cache file.
|
cache_file: The path to the cache file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
|
The JSON content of the cache file or None if the JSON is invalid.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with open(cache_file, "r") as f:
|
with open(cache_file, "r") as f:
|
||||||
return json.load(f)
|
data: dict[str, Any] = json.load(f)
|
||||||
|
return data
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def fetch_provider_data(cache_file):
|
def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
|
||||||
"""
|
"""Fetches provider data from a specified URL and caches it to a file.
|
||||||
Fetches provider data from a specified URL and caches it to a file.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- cache_file (Path): The path to the cache file.
|
cache_file: The path to the cache file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict or None: The fetched provider data or None if the operation fails.
|
The fetched provider data or None if the operation fails.
|
||||||
"""
|
"""
|
||||||
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
|
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||||
|
|
||||||
@@ -180,36 +178,39 @@ def fetch_provider_data(cache_file):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def download_data(response):
|
def download_data(response: requests.Response) -> dict[str, Any]:
|
||||||
"""
|
"""Downloads data from a given HTTP response and returns the JSON content.
|
||||||
Downloads data from a given HTTP response and returns the JSON content.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- response (requests.Response): The HTTP response object.
|
response: The HTTP response object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict: The JSON content of the response.
|
The JSON content of the response.
|
||||||
"""
|
"""
|
||||||
total_size = int(response.headers.get("content-length", 0))
|
total_size = int(response.headers.get("content-length", 0))
|
||||||
block_size = 8192
|
block_size = 8192
|
||||||
data_chunks = []
|
data_chunks: list[bytes] = []
|
||||||
|
bar: Any
|
||||||
with click.progressbar(
|
with click.progressbar(
|
||||||
length=total_size, label="Downloading", show_pos=True
|
length=total_size, label="Downloading", show_pos=True
|
||||||
) as progress_bar:
|
) as bar:
|
||||||
for chunk in response.iter_content(block_size):
|
for chunk in response.iter_content(block_size):
|
||||||
if chunk:
|
if chunk:
|
||||||
data_chunks.append(chunk)
|
data_chunks.append(chunk)
|
||||||
progress_bar.update(len(chunk))
|
bar.update(len(chunk))
|
||||||
data_content = b"".join(data_chunks)
|
data_content = b"".join(data_chunks)
|
||||||
return json.loads(data_content.decode("utf-8"))
|
result: dict[str, Any] = json.loads(data_content.decode("utf-8"))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_provider_data():
|
def get_provider_data() -> dict[str, list[str]] | None:
|
||||||
"""
|
"""Retrieves provider data from a cache file.
|
||||||
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
|
|
||||||
|
Filters out models based on provider criteria, and returns a dictionary of providers
|
||||||
|
mapped to their models.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
|
A dictionary of providers mapped to their models or None if the operation fails.
|
||||||
"""
|
"""
|
||||||
cache_dir = Path.home() / ".crewai"
|
cache_dir = Path.home() / ".crewai"
|
||||||
cache_dir.mkdir(exist_ok=True)
|
cache_dir.mkdir(exist_ok=True)
|
||||||
|
|||||||
@@ -296,6 +296,23 @@ def test_create_folder_structure_folder_name_validation():
|
|||||||
shutil.rmtree(folder_path)
|
shutil.rmtree(folder_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_folder_structure_rejects_reserved_names():
|
||||||
|
"""Test that reserved script names are rejected to prevent pyproject.toml conflicts."""
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
reserved_names = ["test", "train", "replay", "run_crew", "run_with_trigger"]
|
||||||
|
|
||||||
|
for reserved_name in reserved_names:
|
||||||
|
with pytest.raises(ValueError, match="which is reserved"):
|
||||||
|
create_folder_structure(reserved_name, parent_folder=temp_dir)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="which is reserved"):
|
||||||
|
create_folder_structure(f"{reserved_name}/", parent_folder=temp_dir)
|
||||||
|
|
||||||
|
capitalized = reserved_name.capitalize()
|
||||||
|
with pytest.raises(ValueError, match="which is reserved"):
|
||||||
|
create_folder_structure(capitalized, parent_folder=temp_dir)
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("crewai.cli.create_crew.create_folder_structure")
|
@mock.patch("crewai.cli.create_crew.create_folder_structure")
|
||||||
@mock.patch("crewai.cli.create_crew.copy_template")
|
@mock.patch("crewai.cli.create_crew.copy_template")
|
||||||
@mock.patch("crewai.cli.create_crew.load_env_vars")
|
@mock.patch("crewai.cli.create_crew.load_env_vars")
|
||||||
|
|||||||
Reference in New Issue
Block a user