diff --git a/src/crewai/cli/utils.py b/src/crewai/cli/utils.py index 9780b52fa..bcf6030ca 100644 --- a/src/crewai/cli/utils.py +++ b/src/crewai/cli/utils.py @@ -94,17 +94,18 @@ def _get_project_attribute( attribute = _get_nested_value(pyproject_content, keys) except FileNotFoundError: - print(f"Error: {pyproject_path} not found.") + console.print(f"Error: {pyproject_path} not found.", style="bold red") except KeyError: - print(f"Error: {pyproject_path} is not a valid pyproject.toml file.") + console.print(f"Error: {pyproject_path} is not a valid pyproject.toml file.", style="bold red") except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore - print( + console.print( f"Error: {pyproject_path} is not a valid TOML file." if sys.version_info >= (3, 11) - else f"Error reading the pyproject.toml file: {e}" + else f"Error reading the pyproject.toml file: {e}", + style="bold red", ) except Exception as e: - print(f"Error reading the pyproject.toml file: {e}") + console.print(f"Error reading the pyproject.toml file: {e}", style="bold red") if require and not attribute: console.print( @@ -137,9 +138,9 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: return env_dict except FileNotFoundError: - print(f"Error: {env_file_path} not found.") + console.print(f"Error: {env_file_path} not found.", style="bold red") except Exception as e: - print(f"Error reading the .env file: {e}") + console.print(f"Error reading the .env file: {e}", style="bold red") return {} @@ -255,50 +256,69 @@ def write_env_file(folder_path, env_vars): def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: - """Get the crew instances from the a file.""" + """Get the crew instances from a file.""" crew_instances = [] try: import importlib.util - for root, _, files in os.walk("."): - if crew_path in files: - crew_os_path = os.path.join(root, crew_path) - try: - spec = importlib.util.spec_from_file_location( - "crew_module", crew_os_path - ) - if not spec or not spec.loader: - continue - module = importlib.util.module_from_spec(spec) + # Add the current directory to sys.path to ensure imports resolve correctly + current_dir = os.getcwd() + if current_dir not in sys.path: + sys.path.insert(0, current_dir) + + # If we're not in src directory but there's a src directory, add it to path + src_dir = os.path.join(current_dir, "src") + if os.path.isdir(src_dir) and src_dir not in sys.path: + sys.path.insert(0, src_dir) + + # Search in both current directory and src directory if it exists + search_paths = [".", "src"] if os.path.isdir("src") else ["."] + + for search_path in search_paths: + for root, _, files in os.walk(search_path): + if crew_path in files and "cli/templates" not in root: + crew_os_path = os.path.join(root, crew_path) try: - sys.modules[spec.name] = module - spec.loader.exec_module(module) - - for attr_name in dir(module): - module_attr = getattr(module, attr_name) - - try: - crew_instances.extend(fetch_crews(module_attr)) - except Exception as e: - print(f"Error processing attribute {attr_name}: {e}") - continue - - except Exception as exec_error: - print(f"Error executing module: {exec_error}") - import traceback - - print(f"Traceback: {traceback.format_exc()}") - except (ImportError, AttributeError) as e: - if require: - console.print( - f"Error importing crew from {crew_path}: {str(e)}", - style="bold red", + spec = importlib.util.spec_from_file_location( + "crew_module", crew_os_path ) + if not spec or not spec.loader: + continue + + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + + try: + spec.loader.exec_module(module) + + for attr_name in dir(module): + module_attr = getattr(module, attr_name) + try: + crew_instances.extend(fetch_crews(module_attr)) + except Exception as e: + console.print(f"Error processing attribute {attr_name}: {e}", style="bold red") + continue + + # If we found crew instances, break out of the loop + if crew_instances: + break + + except Exception as exec_error: + console.print(f"Error executing module: {exec_error}", style="bold red") + + except (ImportError, AttributeError) as e: + if require: + console.print( + f"Error importing crew from {crew_path}: {str(e)}", + style="bold red", + ) continue + # If we found crew instances in this search path, break out of the search paths loop + if crew_instances: break - if require: + if require and not crew_instances: console.print("No valid Crew instance found in crew.py", style="bold red") raise SystemExit @@ -318,11 +338,15 @@ def get_crew_instance(module_attr) -> Crew | None: and module_attr.is_crew_class ): return module_attr().crew() - if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints( - module_attr - ).get("return") is Crew: - return module_attr() - elif isinstance(module_attr, Crew): + try: + if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints( + module_attr + ).get("return") is Crew: + return module_attr() + except Exception: + return None + + if isinstance(module_attr, Crew): return module_attr else: return None @@ -402,7 +426,8 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]: if not hasattr(module, "__all__"): console.print( - f"[bold yellow]Warning: No __all__ defined in {init_file}[/bold yellow]" + f"Warning: No __all__ defined in {init_file}", + style="bold yellow", ) raise SystemExit(1) diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index 115bb67eb..517a1c236 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -261,3 +261,104 @@ __all__ = ['MyTool'] captured = capsys.readouterr() assert "was never closed" in captured.out + + +@pytest.fixture +def mock_crew(): + from crewai.crew import Crew + + class MockCrew(Crew): + def __init__(self): + pass + + return MockCrew() + + +@pytest.fixture +def temp_crew_project(): + with tempfile.TemporaryDirectory() as temp_dir: + old_cwd = os.getcwd() + os.chdir(temp_dir) + + crew_content = """ + from crewai.crew import Crew + from crewai.agent import Agent + + def create_crew() -> Crew: + agent = Agent(role="test", goal="test", backstory="test") + return Crew(agents=[agent], tasks=[]) + + # Direct crew instance + direct_crew = Crew(agents=[], tasks=[]) + """ + + with open("crew.py", "w") as f: + f.write(crew_content) + + os.makedirs("src", exist_ok=True) + with open(os.path.join("src", "crew.py"), "w") as f: + f.write(crew_content) + + # Create a src/templates directory that should be ignored + os.makedirs(os.path.join("src", "templates"), exist_ok=True) + with open(os.path.join("src", "templates", "crew.py"), "w") as f: + f.write("# This should be ignored") + + yield temp_dir + + os.chdir(old_cwd) + + +def test_get_crews_finds_valid_crews(temp_crew_project, monkeypatch, mock_crew): + def mock_fetch_crews(module_attr): + return [mock_crew] + + monkeypatch.setattr(utils, "fetch_crews", mock_fetch_crews) + + crews = utils.get_crews() + + assert len(crews) > 0 + assert mock_crew in crews + + +def test_get_crews_with_nonexistent_file(temp_crew_project): + crews = utils.get_crews(crew_path="nonexistent.py", require=False) + assert len(crews) == 0 + + +def test_get_crews_with_required_nonexistent_file(temp_crew_project, capsys): + with pytest.raises(SystemExit): + utils.get_crews(crew_path="nonexistent.py", require=True) + + captured = capsys.readouterr() + assert "No valid Crew instance found" in captured.out + + +def test_get_crews_with_invalid_module(temp_crew_project, capsys): + with open("crew.py", "w") as f: + f.write("import nonexistent_module\n") + + crews = utils.get_crews(crew_path="crew.py", require=False) + assert len(crews) == 0 + + with pytest.raises(SystemExit): + utils.get_crews(crew_path="crew.py", require=True) + + captured = capsys.readouterr() + assert "Error" in captured.out + + +def test_get_crews_ignores_template_directories(temp_crew_project, monkeypatch, mock_crew): + template_crew_detected = False + + def mock_fetch_crews(module_attr): + nonlocal template_crew_detected + if hasattr(module_attr, "__file__") and "templates" in module_attr.__file__: + template_crew_detected = True + return [mock_crew] + + monkeypatch.setattr(utils, "fetch_crews", mock_fetch_crews) + + utils.get_crews() + + assert not template_crew_detected