Compare commits

...

9 Commits

Author SHA1 Message Date
Greyson Lalonde
7afca5daab refactor: remove cli/ from crewai package and relocate to proper modules
Move framework infrastructure out of crewai/cli/ to dedicated modules:
- cli/authentication/ → crewai/auth/
- cli/config.py → crewai/settings.py
- cli/constants.py → crewai/constants.py
- cli/plus_api.py → crewai/plus_api.py
- cli/version.py → crewai/version.py
- cli/crew_chat.py → crewai/utilities/crew_chat.py
- cli/reset_memories_command.py → crewai/utilities/reset_memories.py
- cli/utils.py (framework parts) → crewai/utilities/project_utils.py

Delete CLI-only duplicates (command.py, git.py, provider.py) already
present in crewai_cli. Replace _login_to_tool_repository with a
_post_login() hook in AuthenticationCommand. Update all imports and
mock.patch paths across both packages and tests.
2026-03-15 19:39:55 -04:00
Greyson LaLonde
cf1636c300 fix(ci): exclude crewai_cli templates from ruff linting
Ruff fails when checking .py files in the templates directory because
it discovers the nearby pyproject.toml which contains {{folder_name}}
placeholders that are invalid TOML. Add the new template path to the
CI grep filter, matching the existing exclusion for the original path.
2026-03-14 22:38:48 -04:00
Greyson LaLonde
dfea5fb650 refactor: remove CLI shim from crewai package
The backward-compat shim is unnecessary — nothing imports from
crewai.cli.cli and the entry point lives in crewai-cli now.
2026-03-14 22:24:34 -04:00
Greyson LaLonde
8fd7a73423 fix(deploy): add pre-flight validation before deployment
Validate that pyproject.toml, a lockfile (uv.lock or poetry.lock),
and the expected src/<project>/crew.py or config directory exist
locally before making any API calls. This surfaces clear, actionable
errors on the CLI instead of cryptic server-side deployment failures.
2026-03-14 22:21:02 -04:00
Greyson LaLonde
b7bd7aea50 Merge branch 'main' into gl/chore/refactor-cli
# Conflicts:
#	lib/crewai/src/crewai/cli/cli.py
2026-03-14 22:17:02 -04:00
Greyson LaLonde
96fc584ab8 refactor: remove CLI from crewai package and add backward-compat shim
Remove all CLI modules and tests that have been moved to the
crewai-cli package. Replace cli.py with a thin shim that re-exports
from crewai_cli when available, or shows an install hint otherwise.

Update crewai pyproject.toml to add a [cli] extra pointing to
crewai-cli and comment out the old entry point. Add py.typed marker
to crewai_cli for mypy compatibility.
2026-03-14 22:12:38 -04:00
Greyson LaLonde
3732de7b88 test: add CLI tests to crewai-cli package
Move and adapt all CLI tests from lib/crewai/tests/cli/ to
lib/cli/tests/, updating import paths from crewai.cli.* to
crewai_cli.* and adjusting mock targets accordingly.
2026-03-14 22:09:38 -04:00
Greyson LaLonde
4f9a8f4112 refactor: move CLI source modules to crewai-cli package
Copy all CLI source modules from lib/crewai/src/crewai/cli/ to the
new lib/cli/src/crewai_cli/ package, updating internal imports from
crewai.cli.* to crewai_cli.* throughout.

Includes: authentication, deploy, enterprise, organization, settings,
tools, triggers, templates, and all top-level CLI command modules.

Also excludes lib/cli/ from pre-commit mypy checks to match existing
behavior (original CLI code has the same type gaps).
2026-03-14 22:08:48 -04:00
Greyson LaLonde
c0689aa6dc chore: scaffold crewai-cli package and update workspace config
Add the new lib/cli package skeleton with pyproject.toml, README,
and __init__.py. Register it as a uv workspace member and update
root linting, mypy, bandit, and pytest config to include the new
package paths.
2026-03-14 22:04:37 -04:00
167 changed files with 5668 additions and 977 deletions

View File

@@ -55,6 +55,7 @@ jobs:
echo "${{ steps.changed-files.outputs.files }}" \
| tr ' ' '\n' \
| grep -v 'src/crewai/cli/templates/' \
| grep -v 'src/crewai_cli/templates/' \
| grep -v '/tests/' \
| xargs -I{} uv run ruff check "{}"

View File

@@ -19,7 +19,7 @@ repos:
language: system
pass_filenames: true
types: [python]
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/cli/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.9.3
hooks:

View File

@@ -226,7 +226,7 @@ def vcr_cassette_dir(request: Any) -> str:
for parent in test_file.parents:
if (
parent.name in ("crewai", "crewai-tools", "crewai-files")
parent.name in ("crewai", "crewai-tools", "crewai-files", "cli")
and parent.parent.name == "lib"
):
package_root = parent

15
lib/cli/README.md Normal file
View File

@@ -0,0 +1,15 @@
# crewai-cli
CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework.
## Installation
```bash
pip install crewai-cli
```
Or install alongside the full framework:
```bash
pip install crewai[cli]
```

39
lib/cli/pyproject.toml Normal file
View File

@@ -0,0 +1,39 @@
[project]
name = "crewai-cli"
version = "1.10.0"
description = "CLI for CrewAI - scaffold, run, deploy and manage AI agent crews without installing the full framework."
readme = "README.md"
authors = [
{ name = "Joao Moura", email = "joao@crewai.com" }
]
requires-python = ">=3.10, <3.14"
dependencies = [
"click~=8.1.7",
"pydantic~=2.11.9",
"pydantic-settings~=2.10.1",
"appdirs~=1.4.4",
"httpx~=0.28.1",
"pyjwt>=2.9.0,<3",
"rich>=13.7.1",
"tomli~=2.0.2",
"tomli-w~=1.1.0",
"packaging>=23.0",
"python-dotenv~=1.1.1",
"uv~=0.9.13",
"portalocker~=2.7.0",
]
[project.urls]
Homepage = "https://crewai.com"
Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
[project.scripts]
crewai = "crewai_cli.cli:crewai"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/crewai_cli"]

View File

@@ -0,0 +1 @@
__version__ = "1.10.0"

View File

@@ -2,19 +2,15 @@ from pathlib import Path
import click
from crewai.cli.utils import copy_template
from crewai.utilities.printer import Printer
_printer = Printer()
from crewai_cli.utils import copy_template
def add_crew_to_flow(crew_name: str) -> None:
"""Add a new crew to the current flow."""
# Check if pyproject.toml exists in the current directory
if not Path("pyproject.toml").exists():
_printer.print(
"This command must be run from the root of a flow project.", color="red"
click.secho(
"This command must be run from the root of a flow project.", fg="red"
)
raise click.ClickException(
"This command must be run from the root of a flow project."
@@ -25,7 +21,7 @@ def add_crew_to_flow(crew_name: str) -> None:
crews_folder = flow_folder / "src" / flow_folder.name / "crews"
if not crews_folder.exists():
_printer.print("Crews folder does not exist in the current flow.", color="red")
click.secho("Crews folder does not exist in the current flow.", fg="red")
raise click.ClickException("Crews folder does not exist in the current flow.")
# Create the crew within the flow's crews directory

View File

@@ -0,0 +1,4 @@
from crewai_cli.authentication.main import AuthenticationCommand
__all__ = ["AuthenticationCommand"]

View File

@@ -6,9 +6,9 @@ import httpx
from pydantic import BaseModel, Field
from rich.console import Console
from crewai.cli.authentication.utils import validate_jwt_token
from crewai.cli.config import Settings
from crewai.cli.shared.token_manager import TokenManager
from crewai_cli.authentication.utils import validate_jwt_token
from crewai_cli.config import Settings
from crewai_cli.shared.token_manager import TokenManager
console = Console()
@@ -51,7 +51,7 @@ class Oauth2Settings(BaseModel):
if TYPE_CHECKING:
from crewai.cli.authentication.providers.base_provider import BaseProvider
from crewai_cli.authentication.providers.base_provider import BaseProvider
class ProviderFactory:
@@ -65,7 +65,7 @@ class ProviderFactory:
import importlib
module = importlib.import_module(
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
f"crewai_cli.authentication.providers.{settings.provider.lower()}"
)
# Converts from snake_case to CamelCase to obtain the provider class name.
provider = getattr(
@@ -180,7 +180,7 @@ class AuthenticationCommand:
def _login_to_tool_repository(self) -> None:
"""Login to the tool repository."""
from crewai.cli.tools.main import ToolCommand
from crewai_cli.tools.main import ToolCommand
try:
console.print(

View File

@@ -1,4 +1,4 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
from crewai_cli.authentication.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from crewai.cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.main import Oauth2Settings
class BaseProvider(ABC):

View File

@@ -1,6 +1,6 @@
from typing import cast
from crewai.cli.authentication.providers.base_provider import BaseProvider
from crewai_cli.authentication.providers.base_provider import BaseProvider
class EntraIdProvider(BaseProvider):

View File

@@ -1,4 +1,4 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
from crewai_cli.authentication.providers.base_provider import BaseProvider
class KeycloakProvider(BaseProvider):

View File

@@ -1,4 +1,4 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
from crewai_cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):

View File

@@ -1,4 +1,4 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
from crewai_cli.authentication.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):

View File

@@ -1,4 +1,4 @@
from crewai.cli.shared.token_manager import TokenManager
from crewai_cli.shared.token_manager import TokenManager
class AuthError(Exception):

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from importlib.metadata import version as get_version
import os
import subprocess
@@ -5,44 +7,58 @@ from typing import Any
import click
from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.config import Settings
from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat
from crewai.cli.deploy.main import DeployCommand
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
from crewai.cli.evaluate_crew import evaluate_crew
from crewai.cli.install_crew import install_crew
from crewai.cli.kickoff_flow import kickoff_flow
from crewai.cli.organization.main import OrganizationCommand
from crewai.cli.plot_flow import plot_flow
from crewai.cli.replay_from_task import replay_task_command
from crewai.cli.reset_memories_command import reset_memories_command
from crewai.cli.run_crew import run_crew
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.tools.main import ToolCommand
from crewai.cli.train_crew import train_crew
from crewai.cli.triggers.main import TriggersCommand
from crewai.cli.update_crew import update_crew
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
from crewai_cli.add_crew_to_flow import add_crew_to_flow
from crewai_cli.authentication.main import AuthenticationCommand
from crewai_cli.config import Settings
from crewai_cli.create_crew import create_crew
from crewai_cli.create_flow import create_flow
from crewai_cli.crew_chat import run_chat
from crewai_cli.deploy.main import DeployCommand
from crewai_cli.enterprise.main import EnterpriseConfigureCommand
from crewai_cli.evaluate_crew import evaluate_crew
from crewai_cli.install_crew import install_crew
from crewai_cli.kickoff_flow import kickoff_flow
from crewai_cli.organization.main import OrganizationCommand
from crewai_cli.plot_flow import plot_flow
from crewai_cli.replay_from_task import replay_task_command
from crewai_cli.reset_memories_command import reset_memories_command
from crewai_cli.run_crew import run_crew
from crewai_cli.settings.main import SettingsCommand
from crewai_cli.task_outputs import load_task_outputs
from crewai_cli.tools.main import ToolCommand
from crewai_cli.train_crew import train_crew
from crewai_cli.triggers.main import TriggersCommand
from crewai_cli.update_crew import update_crew
from crewai_cli.user_data import (
_load_user_data,
_save_user_data,
is_tracing_enabled,
)
from crewai_cli.utils import build_env_with_tool_repository_credentials, read_toml
def _get_cli_version() -> str:
"""Return the best available version string for the CLI."""
# Prefer crewai version if installed (keeps existing UX)
try:
return get_version("crewai")
except Exception: # noqa: S110
pass
try:
return get_version("crewai-cli")
except Exception:
return "unknown"
@click.group()
@click.version_option(get_version("crewai"))
@click.version_option(_get_cli_version())
def crewai():
"""Top-level command group for crewai."""
@crewai.command(
name="uv",
context_settings=dict(
ignore_unknown_options=True,
),
context_settings={"ignore_unknown_options": True},
)
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
def uv(uv_args):
@@ -107,7 +123,7 @@ def version(tools):
if tools:
try:
tools_version = get_version("crewai")
tools_version = get_version("crewai-tools")
click.echo(f"crewai tools version: {tools_version}")
except Exception:
click.echo("crewai tools not installed")
@@ -142,12 +158,7 @@ def train(n_iterations: int, filename: str):
help="Replay the crew from this task ID, including all subsequent tasks.",
)
def replay(task_id: str) -> None:
"""
Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
"""Replay the crew execution from a specific task."""
try:
click.echo(f"Replaying the crew from task {task_id}")
replay_task_command(task_id)
@@ -157,12 +168,9 @@ def replay(task_id: str) -> None:
@crewai.command()
def log_tasks_outputs() -> None:
"""
Retrieve your latest crew.kickoff() task outputs.
"""
"""Retrieve your latest crew.kickoff() task outputs."""
try:
storage = KickoffTaskOutputsSQLiteStorage()
tasks = storage.load()
tasks = load_task_outputs()
if not tasks:
click.echo(
@@ -220,11 +228,8 @@ def reset_memories(
agent_knowledge: bool,
all: bool,
) -> None:
"""
Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved.
"""
"""Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved."""
try:
# Treat legacy flags as --memory with a deprecation warning
if long or short or entities:
legacy_used = [
f
@@ -291,7 +296,7 @@ def memory(
) -> None:
"""Open the Memory TUI to browse scopes and recall memories."""
try:
from crewai.cli.memory_tui import MemoryTUI
from crewai_cli.memory_tui import MemoryTUI
except ImportError as exc:
click.echo(
"Textual is required for the memory TUI but could not be imported. "
@@ -341,10 +346,10 @@ def test(n_iterations: int, model: str):
@crewai.command(
context_settings=dict(
ignore_unknown_options=True,
allow_extra_args=True,
)
context_settings={
"ignore_unknown_options": True,
"allow_extra_args": True,
}
)
@click.pass_context
def install(context):
@@ -509,14 +514,12 @@ def triggers_run(trigger_path: str):
@crewai.command()
def chat():
"""
Start a conversation with the Crew, collecting user-supplied inputs,
"""Start a conversation with the Crew, collecting user-supplied inputs,
and using the Chat LLM to generate responses.
"""
click.secho(
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
)
run_chat()
@@ -627,7 +630,7 @@ def env_view():
table.add_row(
"CREWAI_TRACING_ENABLED",
"[dim]Not set[/dim]",
"[dim][/dim]",
"[dim]---[/dim]",
)
# Check other related env vars
@@ -646,7 +649,7 @@ def env_view():
# Check if .env file exists
table.add_row(
".env file",
"Found" if env_file_exists else "Not found",
"Found" if env_file_exists else "Not found",
str(env_file.resolve()) if env_file_exists else "N/A",
)
@@ -662,11 +665,11 @@ def env_view():
# Show helpful message
if env_file_exists:
console.print(
"\n[dim]💡 Tip: To enable tracing via .env, add: CREWAI_TRACING_ENABLED=true[/dim]"
"\n[dim]Tip: To enable tracing via .env, add: CREWAI_TRACING_ENABLED=true[/dim]"
)
else:
console.print(
"\n[dim]💡 Tip: Create a .env file in your project root and add: CREWAI_TRACING_ENABLED=true[/dim]"
"\n[dim]Tip: Create a .env file in your project root and add: CREWAI_TRACING_ENABLED=true[/dim]"
)
console.print()
@@ -682,14 +685,16 @@ def traces_enable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import update_user_data
console = Console()
update_user_data({"trace_consent": True, "first_execution_done": True})
# Update user data to enable traces
user_data = _load_user_data()
user_data["trace_consent"] = True
user_data["first_execution_done"] = True
_save_user_data(user_data)
panel = Panel(
"Trace collection has been enabled!\n\n"
"Trace collection has been enabled!\n\n"
"Your crew/flow executions will now send traces to CrewAI+.\n"
"Use 'crewai traces disable' to turn off trace collection.",
title="Traces Enabled",
@@ -705,14 +710,16 @@ def traces_disable():
from rich.console import Console
from rich.panel import Panel
from crewai.events.listeners.tracing.utils import update_user_data
console = Console()
update_user_data({"trace_consent": False, "first_execution_done": True})
# Update user data to disable traces
user_data = _load_user_data()
user_data["trace_consent"] = False
user_data["first_execution_done"] = True
_save_user_data(user_data)
panel = Panel(
"Trace collection has been disabled!\n\n"
"Trace collection has been disabled!\n\n"
"Your crew/flow executions will no longer send traces.\n"
"Use 'crewai traces enable' to turn trace collection back on.",
title="Traces Disabled",
@@ -731,11 +738,6 @@ def traces_status():
from rich.panel import Panel
from rich.table import Table
from crewai.events.listeners.tracing.utils import (
_load_user_data,
is_tracing_enabled,
)
console = Console()
user_data = _load_user_data()
@@ -750,19 +752,19 @@ def traces_status():
# Check user consent
trace_consent = user_data.get("trace_consent")
if trace_consent is True:
consent_status = "Enabled (user consented)"
consent_status = "Enabled (user consented)"
elif trace_consent is False:
consent_status = "Disabled (user declined)"
consent_status = "Disabled (user declined)"
else:
consent_status = "Not set (first-time user)"
consent_status = "Not set (first-time user)"
table.add_row("User Consent", consent_status)
# Check overall status
if is_tracing_enabled():
overall_status = "ENABLED"
overall_status = "ENABLED"
border_style = "green"
else:
overall_status = "DISABLED"
overall_status = "DISABLED"
border_style = "red"
table.add_row("Overall Status", overall_status)

View File

@@ -1,11 +1,12 @@
from __future__ import annotations
import json
import httpx
from rich.console import Console
from crewai.cli.authentication.token import get_auth_token
from crewai.cli.plus_api import PlusAPI
from crewai.telemetry.telemetry import Telemetry
from crewai_cli.authentication.token import get_auth_token
from crewai_cli.plus_api import PlusAPI
console = Console()
@@ -13,17 +14,14 @@ console = Console()
class BaseCommand:
def __init__(self) -> None:
self._telemetry = Telemetry()
self._telemetry.set_tracer()
pass
class PlusAPIMixin:
def __init__(self, telemetry: Telemetry) -> None:
def __init__(self) -> None:
try:
telemetry.set_tracer()
self.plus_api_client = PlusAPI(api_key=get_auth_token())
except Exception:
telemetry.deploy_signup_error_span()
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
@@ -32,12 +30,6 @@ class PlusAPIMixin:
raise SystemExit from None
def _validate_response(self, response: httpx.Response) -> None:
"""
Handle and display error messages from API responses.
Args:
response (httpx.Response): The response from the Plus API
"""
try:
json_response = response.json()
except (json.JSONDecodeError, ValueError):

View File

@@ -6,14 +6,14 @@ from typing import Any
from pydantic import BaseModel, Field
from crewai.cli.constants import (
from crewai_cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
DEFAULT_CREWAI_ENTERPRISE_URL,
)
from crewai.cli.shared.token_manager import TokenManager
from crewai_cli.shared.token_manager import TokenManager
logger = getLogger(__name__)

View File

@@ -5,13 +5,13 @@ import sys
import click
import tomli
from crewai.cli.constants import ENV_VARS, MODELS
from crewai.cli.provider import (
from crewai_cli.constants import ENV_VARS, MODELS
from crewai_cli.provider import (
get_provider_data,
select_model,
select_provider,
)
from crewai.cli.utils import copy_template, load_env_vars, write_env_file
from crewai_cli.utils import copy_template, load_env_vars, write_env_file
def get_reserved_script_names() -> set[str]:

View File

@@ -3,8 +3,6 @@ import shutil
import click
from crewai.telemetry import Telemetry
def create_flow(name):
"""Create a new flow."""
@@ -18,10 +16,6 @@ def create_flow(name):
click.secho(f"Error: Folder {folder_name} already exists.", fg="red")
return
# Initialize telemetry
telemetry = Telemetry()
telemetry.flow_creation_span(class_name)
# Create directory structure
(project_root / "src" / folder_name).mkdir(parents=True)
(project_root / "src" / folder_name / "crews").mkdir(parents=True)

View File

@@ -0,0 +1,23 @@
"""Wrapper for the crew chat command.
Delegates to ``crewai.utilities.crew_chat.run_chat`` when the full crewai
package is installed, otherwise prints a helpful error message.
"""
from __future__ import annotations
import click
def run_chat() -> None:
try:
from crewai.utilities.crew_chat import run_chat as _run_chat
except ImportError:
click.secho(
"The 'chat' command requires the full crewai package.\n"
"Install it with: pip install crewai",
fg="red",
)
raise SystemExit(1) from None
_run_chat()

View File

@@ -1,10 +1,11 @@
from pathlib import Path
from typing import Any
from rich.console import Console
from crewai.cli import git
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.utils import fetch_and_json_env_file, get_project_name
from crewai_cli import git
from crewai_cli.command import BaseCommand, PlusAPIMixin
from crewai_cli.utils import fetch_and_json_env_file, get_project_name
console = Console()
@@ -21,8 +22,43 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
"""
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
PlusAPIMixin.__init__(self)
self.project_name = get_project_name(require=True)
self._validate_project_structure()
def _validate_project_structure(self) -> None:
"""Validate that the local project has the files required for deployment."""
errors: list[str] = []
if not Path("pyproject.toml").exists():
errors.append("Cannot find pyproject.toml in the current directory.")
has_lockfile = Path("uv.lock").exists() or Path("poetry.lock").exists()
if not has_lockfile:
errors.append(
"No uv.lock or poetry.lock found. "
"Run 'uv lock' or 'poetry lock' to generate one."
)
src_dir = Path("src") / (self.project_name or "")
crew_py = src_dir / "crew.py"
config_dir = src_dir / "config"
if not crew_py.exists() and not config_dir.exists():
errors.append(
f"Cannot find src/{self.project_name}/crew.py or "
f"src/{self.project_name}/config. "
"Ensure you are running this command from the project root."
)
if errors:
console.print(
"\n[bold red]Pre-flight check failed:[/bold red] "
"Your project is missing required files for deployment.\n"
)
for error in errors:
console.print(f"{error}", style="red")
console.print()
raise SystemExit(1)
def _standard_no_param_error_message(self) -> None:
"""
@@ -67,7 +103,6 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
Args:
uuid (Optional[str]): The UUID of the crew to deploy.
"""
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
console.print("Starting deployment...", style="bold blue")
if uuid:
response = self.plus_api_client.deploy_by_uuid(uuid)
@@ -84,9 +119,6 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
"""
Create a new crew deployment.
"""
self._create_crew_deployment_span = (
self._telemetry.create_crew_deployment_span()
)
console.print("Creating deployment...", style="bold blue")
env_vars = fetch_and_json_env_file()
@@ -236,7 +268,6 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
uuid (Optional[str]): The UUID of the crew to get logs for.
log_type (str): The type of logs to retrieve (default: "deployment").
"""
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
console.print(f"Fetching {log_type} logs...", style="bold blue")
if uuid:
@@ -257,7 +288,6 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
Args:
uuid (Optional[str]): The UUID of the crew to remove.
"""
self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
console.print("Removing deployment...", style="bold blue")
if uuid:

View File

@@ -4,10 +4,10 @@ from typing import Any, cast
import httpx
from rich.console import Console
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
from crewai.cli.command import BaseCommand
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.version import get_crewai_version
from crewai_cli.authentication.main import Oauth2Settings, ProviderFactory
from crewai_cli.command import BaseCommand
from crewai_cli.settings.main import SettingsCommand
from crewai_cli.version import get_crewai_version
console = Console()

View File

@@ -2,8 +2,8 @@ from httpx import HTTPStatusError
from rich.console import Console
from rich.table import Table
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.config import Settings
from crewai_cli.command import BaseCommand, PlusAPIMixin
from crewai_cli.config import Settings
console = Console()
@@ -12,7 +12,7 @@ console = Console()
class OrganizationCommand(BaseCommand, PlusAPIMixin):
def __init__(self) -> None:
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
PlusAPIMixin.__init__(self)
def list(self) -> None:
try:

View File

@@ -0,0 +1,210 @@
import os
from typing import Any
from urllib.parse import urljoin
import httpx
from crewai_cli.config import Settings
from crewai_cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai_cli.version import get_crewai_version
class PlusAPI:
"""
This class exposes methods for working with the CrewAI+ API.
"""
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
ORGANIZATIONS_RESOURCE = "/crewai_plus/api/v1/me/organizations"
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
TRACING_RESOURCE = "/crewai_plus/api/v1/tracing"
EPHEMERAL_TRACING_RESOURCE = "/crewai_plus/api/v1/tracing/ephemeral"
INTEGRATIONS_RESOURCE = "/crewai_plus/api/v1/integrations"
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
"X-Crewai-Version": get_crewai_version(),
}
settings = Settings()
if settings.org_uuid:
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
self.base_url = (
os.getenv("CREWAI_PLUS_URL")
or str(settings.enterprise_base_url)
or DEFAULT_CREWAI_ENTERPRISE_URL
)
def _make_request(
self, method: str, endpoint: str, **kwargs: Any
) -> httpx.Response:
url = urljoin(self.base_url, endpoint)
verify = kwargs.pop("verify", True)
with httpx.Client(trust_env=False, verify=verify) as client:
return client.request(method, url, headers=self.headers, **kwargs)
def login_to_tool_repository(self) -> httpx.Response:
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
def get_tool(self, handle: str) -> httpx.Response:
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
async def get_agent(self, handle: str) -> httpx.Response:
url = urljoin(self.base_url, f"{self.AGENTS_RESOURCE}/{handle}")
async with httpx.AsyncClient() as client:
return await client.get(url, headers=self.headers)
def publish_tool(
self,
handle: str,
is_public: bool,
version: str,
description: str | None,
encoded_file: str,
available_exports: list[dict[str, Any]] | None = None,
) -> httpx.Response:
params = {
"handle": handle,
"public": is_public,
"version": version,
"file": encoded_file,
"description": description,
"available_exports": available_exports,
}
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
def deploy_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
def crew_status_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
)
def crew_status_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment"
) -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def crew_by_uuid(self, uuid: str, log_type: str = "deployment") -> httpx.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
)
def delete_crew_by_name(self, project_name: str) -> httpx.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
)
def delete_crew_by_uuid(self, uuid: str) -> httpx.Response:
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
def list_crews(self) -> httpx.Response:
return self._make_request("GET", self.CREWS_RESOURCE)
def create_crew(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
def get_organizations(self) -> httpx.Response:
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
def initialize_trace_batch(self, payload: dict[str, Any]) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches",
json=payload,
timeout=30,
)
def initialize_ephemeral_trace_batch(
self, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
json=payload,
)
def send_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def send_ephemeral_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
json=payload,
timeout=30,
)
def finalize_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
json=payload,
timeout=30,
)
def finalize_ephemeral_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
json=payload,
timeout=30,
)
def mark_trace_batch_as_failed(
self, trace_batch_id: str, error_message: str
) -> httpx.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
json={"status": "failed", "failure_reason": error_message},
timeout=30,
)
def get_mcp_configs(self, slugs: list[str]) -> httpx.Response:
"""Get MCP server configurations for the given slugs."""
return self._make_request(
"GET",
f"{self.INTEGRATIONS_RESOURCE}/mcp_configs",
params={"slugs": ",".join(slugs)},
timeout=30,
)
def get_triggers(self) -> httpx.Response:
"""Get all available triggers from integrations."""
return self._make_request("GET", f"{self.INTEGRATIONS_RESOURCE}/apps")
def get_trigger_payload(self, app_slug: str, trigger_slug: str) -> httpx.Response:
"""Get sample payload for a specific trigger."""
return self._make_request(
"GET", f"{self.INTEGRATIONS_RESOURCE}/{app_slug}/{trigger_slug}/payload"
)

View File

@@ -10,7 +10,7 @@ import certifi
import click
import httpx
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
from crewai_cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:

View File

@@ -0,0 +1,31 @@
"""Wrapper for the reset-memories command.
Delegates to ``crewai.utilities.reset_memories`` when the full crewai
package is installed, otherwise prints a helpful error message.
"""
from __future__ import annotations
import click
def reset_memories_command(
memory: bool,
knowledge: bool,
agent_knowledge: bool,
kickoff_outputs: bool,
all: bool,
) -> None:
try:
from crewai.utilities.reset_memories import (
reset_memories_command as _reset,
)
except ImportError:
click.secho(
"The 'reset-memories' command requires the full crewai package.\n"
"Install it with: pip install crewai",
fg="red",
)
raise SystemExit(1) from None
_reset(memory, knowledge, agent_knowledge, kickoff_outputs, all)

View File

@@ -5,8 +5,8 @@ import subprocess
import click
from packaging import version
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
from crewai.cli.version import get_crewai_version
from crewai_cli.utils import build_env_with_tool_repository_credentials, read_toml
from crewai_cli.version import get_crewai_version
class CrewType(Enum):

View File

@@ -5,9 +5,9 @@ from typing import Any
from rich.console import Console
from rich.table import Table
from crewai.cli.command import BaseCommand
from crewai.cli.config import HIDDEN_SETTINGS_KEYS, READONLY_SETTINGS_KEYS, Settings
from crewai.events.listeners.tracing.utils import _load_user_data
from crewai_cli.command import BaseCommand
from crewai_cli.config import HIDDEN_SETTINGS_KEYS, READONLY_SETTINGS_KEYS, Settings
from crewai_cli.user_data import _load_user_data
console = Console()

View File

@@ -0,0 +1,54 @@
"""Lightweight SQLite reader for kickoff task outputs.
Only used by the ``crewai log-tasks-outputs`` CLI command. Depends solely on
the standard library + *appdirs* so crewai-cli can read stored outputs without
importing the full crewai framework.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
import sqlite3
from typing import Any
from crewai_cli.user_data import _db_storage_path
logger = logging.getLogger(__name__)
def load_task_outputs(db_path: str | None = None) -> list[dict[str, Any]]:
"""Return all rows from the kickoff task outputs database."""
if db_path is None:
db_path = str(Path(_db_storage_path()) / "latest_kickoff_task_outputs.db")
if not Path(db_path).exists():
return []
try:
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT *
FROM latest_kickoff_task_outputs
ORDER BY task_index
""")
rows = cursor.fetchall()
results: list[dict[str, Any]] = [
{
"task_id": row[0],
"expected_output": row[1],
"output": json.loads(row[2]),
"task_index": row[3],
"inputs": json.loads(row[4]),
"was_replayed": row[5],
"timestamp": row[6],
}
for row in rows
]
return results
except sqlite3.Error as e:
logger.error("Failed to load task outputs: %s", e)
return []

View File

@@ -8,13 +8,14 @@ import tempfile
from typing import Any
import click
from crewai.events.listeners.tracing.utils import get_user_id
from rich.console import Console
from crewai.cli import git
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.config import Settings
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai.cli.utils import (
from crewai_cli import git
from crewai_cli.command import BaseCommand, PlusAPIMixin
from crewai_cli.config import Settings
from crewai_cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai_cli.utils import (
build_env_with_tool_repository_credentials,
extract_available_exports,
get_project_description,
@@ -23,7 +24,6 @@ from crewai.cli.utils import (
tree_copy,
tree_find_and_replace,
)
from crewai.events.listeners.tracing.utils import get_user_id
console = Console()
@@ -36,7 +36,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
def __init__(self) -> None:
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
PlusAPIMixin.__init__(self)
def create(self, handle: str) -> None:
self._ensure_not_in_project()

View File

@@ -1,6 +1,6 @@
"""Triggers command module for CrewAI CLI."""
from crewai.cli.triggers.main import TriggersCommand
from crewai_cli.triggers.main import TriggersCommand
__all__ = ["TriggersCommand"]

View File

@@ -5,7 +5,7 @@ from typing import Any
from rich.console import Console
from rich.table import Table
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai_cli.command import BaseCommand, PlusAPIMixin
console = Console()
@@ -18,7 +18,7 @@ class TriggersCommand(BaseCommand, PlusAPIMixin):
def __init__(self):
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
PlusAPIMixin.__init__(self)
def list_triggers(self) -> None:
"""List all available triggers from integrations."""

View File

@@ -3,7 +3,7 @@ import shutil
import tomli_w
from crewai.cli.utils import read_toml
from crewai_cli.utils import read_toml
def update_crew() -> None:

View File

@@ -0,0 +1,66 @@
"""Standalone user-data helpers for the CLI package.
These mirror the functions in ``crewai.events.listeners.tracing.utils`` but
depend only on the standard library + *appdirs* so that crewai-cli can work
without importing the full crewai framework.
"""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import Any, cast
import appdirs
logger = logging.getLogger(__name__)
def _get_project_directory_name() -> str:
return os.environ.get("CREWAI_STORAGE_DIR", Path.cwd().name)
def _db_storage_path() -> str:
app_name = _get_project_directory_name()
app_author = "CrewAI"
data_dir = Path(appdirs.user_data_dir(app_name, app_author))
data_dir.mkdir(parents=True, exist_ok=True)
return str(data_dir)
def _user_data_file() -> Path:
base = Path(_db_storage_path())
base.mkdir(parents=True, exist_ok=True)
return base / ".crewai_user.json"
def _load_user_data() -> dict[str, Any]:
p = _user_data_file()
if p.exists():
try:
return cast(dict[str, Any], json.loads(p.read_text()))
except (json.JSONDecodeError, OSError, PermissionError) as e:
logger.warning("Failed to load user data: %s", e)
return {}
def _save_user_data(data: dict[str, Any]) -> None:
try:
p = _user_data_file()
p.write_text(json.dumps(data, indent=2))
except (OSError, PermissionError) as e:
logger.warning("Failed to save user data: %s", e)
def is_tracing_enabled() -> bool:
"""Check if tracing is enabled (mirrors crewai core logic)."""
data = _load_user_data()
if (
data.get("first_execution_done", False)
and data.get("trace_consent", False) is False
):
return False
return os.getenv("CREWAI_TRACING_ENABLED", "false").lower() == "true"

View File

@@ -0,0 +1,369 @@
from __future__ import annotations
from functools import reduce
from inspect import getmro, isclass
import os
from pathlib import Path
import shutil
import sys
from typing import Any, cast
import click
from rich.console import Console
import tomli
from crewai_cli.config import Settings
from crewai_cli.constants import ENV_VARS
if sys.version_info >= (3, 11):
import tomllib
console = Console()
def copy_template(
src: Path, dst: Path, name: str, class_name: str, folder_name: str
) -> None:
"""Copy a file from src to dst."""
with open(src, "r") as file:
content = file.read()
content = content.replace("{{name}}", name)
content = content.replace("{{crew_name}}", class_name)
content = content.replace("{{folder_name}}", folder_name)
with open(dst, "w") as file:
file.write(content)
click.secho(f" - Created {dst}", fg="green")
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
return tomli.load(f)
def parse_toml(content: str) -> dict[str, Any]:
if sys.version_info >= (3, 11):
return tomllib.loads(content)
return tomli.loads(content)
def get_project_name(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project name from the pyproject.toml file."""
return _get_project_attribute(pyproject_path, ["project", "name"], require=require)
def get_project_version(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project version from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "version"], require=require
)
def get_project_description(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project description from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "description"], require=require
)
def _get_project_attribute(
pyproject_path: str, keys: list[str], require: bool
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
try:
with open(pyproject_path, "r") as f:
pyproject_content = parse_toml(f.read())
dependencies = (
_get_nested_value(pyproject_content, ["project", "dependencies"]) or []
)
if not any(True for dep in dependencies if "crewai" in dep):
raise Exception("crewai is not in the dependencies.")
attribute = _get_nested_value(pyproject_content, keys)
except FileNotFoundError:
console.print(f"Error: {pyproject_path} not found.", style="bold red")
except KeyError:
console.print(
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
style="bold red",
)
except Exception as e:
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
console.print(
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
)
else:
console.print(
f"Error reading the pyproject.toml file: {e}", style="bold red"
)
if require and not attribute:
console.print(
f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
style="bold red",
)
raise SystemExit
return attribute
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
"""Fetch the environment variables from a .env file and return them as a dictionary."""
try:
with open(env_file_path, "r") as f:
env_content = f.read()
env_dict = {}
for line in env_content.splitlines():
if line.strip() and not line.strip().startswith("#"):
key, value = line.split("=", 1)
env_dict[key.strip()] = value.strip()
return env_dict
except FileNotFoundError:
console.print(f"Error: {env_file_path} not found.", style="bold red")
except Exception as e:
console.print(f"Error reading the .env file: {e}", style="bold red")
return {}
def tree_copy(source: Path, destination: Path) -> None:
"""Copies the entire directory structure from the source to the destination."""
for item in os.listdir(source):
source_item = os.path.join(source, item)
destination_item = os.path.join(destination, item)
if os.path.isdir(source_item):
shutil.copytree(source_item, destination_item)
else:
shutil.copy2(source_item, destination_item)
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
"""Recursively searches through a directory, replacing a target string in
both file contents and filenames with a specified replacement string.
"""
for path, dirs, files in os.walk(os.path.abspath(directory), topdown=False):
for filename in files:
filepath = os.path.join(path, filename)
with open(filepath, "r", encoding="utf-8", errors="ignore") as file:
contents = file.read()
with open(filepath, "w") as file:
file.write(contents.replace(find, replace))
if find in filename:
new_filename = filename.replace(find, replace)
new_filepath = os.path.join(path, new_filename)
os.rename(filepath, new_filepath)
for dirname in dirs:
if find in dirname:
new_dirname = dirname.replace(find, replace)
new_dirpath = os.path.join(path, new_dirname)
old_dirpath = os.path.join(path, dirname)
os.rename(old_dirpath, new_dirpath)
def load_env_vars(folder_path: Path) -> dict[str, Any]:
"""Loads environment variables from a .env file in the specified folder path."""
env_file_path = folder_path / ".env"
env_vars = {}
if env_file_path.exists():
with open(env_file_path, "r") as file:
for line in file:
key, _, value = line.strip().partition("=")
if key and value:
env_vars[key] = value
return env_vars
def update_env_vars(
env_vars: dict[str, Any], provider: str, model: str
) -> dict[str, Any] | None:
"""Updates environment variables with the API key for the selected provider and model."""
provider_config = cast(
list[str],
ENV_VARS.get(
provider,
[
click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
)
],
),
)
api_key_var = provider_config[0]
if api_key_var not in env_vars:
try:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
)
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
return None
else:
click.secho(f"API key already exists for {provider.capitalize()}.", fg="yellow")
env_vars["MODEL"] = model
click.secho(f"Selected model: {model}", fg="green")
return env_vars
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
"""Writes environment variables to a .env file in the specified folder."""
env_file_path = folder_path / ".env"
with open(env_file_path, "w") as file:
for key, value in env_vars.items():
file.write(f"{key.upper()}={value}\n")
def is_valid_tool(obj: Any) -> bool:
"""Check if an object is a valid tool class.
Works without importing crewai by checking MRO class names.
Falls back to crewai's ``is_valid_tool`` when available.
"""
try:
from crewai.utilities.project_utils import is_valid_tool as _core_is_valid_tool
return _core_is_valid_tool(obj)
except ImportError:
pass
if isclass(obj):
try:
return any(base.__name__ == "BaseTool" for base in getmro(obj))
except (TypeError, AttributeError):
return False
return False
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
"""Extract available tool classes from the project's __init__.py files."""
try:
init_files = Path(dir_path).glob("**/__init__.py")
available_exports: list[dict[str, Any]] = []
for init_file in init_files:
tools = _load_tools_from_init(init_file)
available_exports.extend(tools)
if not available_exports:
_print_no_tools_warning()
raise SystemExit(1)
return available_exports
except SystemExit:
raise
except Exception as e:
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
console.print(
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
)
raise SystemExit(1) from e
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
"""Load and validate tools from a given __init__.py file."""
import importlib.util as _importlib_util
spec = _importlib_util.spec_from_file_location("temp_module", init_file)
if not spec or not spec.loader:
return []
module = _importlib_util.module_from_spec(spec)
sys.modules["temp_module"] = module
try:
spec.loader.exec_module(module)
if not hasattr(module, "__all__"):
console.print(
f"Warning: No __all__ defined in {init_file}",
style="bold yellow",
)
raise SystemExit(1)
return [
{"name": name}
for name in module.__all__
if hasattr(module, name) and is_valid_tool(getattr(module, name))
]
except SystemExit:
raise
except Exception as e:
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
raise SystemExit(1) from e
finally:
sys.modules.pop("temp_module", None)
def _print_no_tools_warning() -> None:
"""Display warning and usage instructions if no tools were found."""
console.print(
"\n[bold yellow]Warning: No valid tools were exposed in your __init__.py file![/bold yellow]"
)
console.print(
"Your __init__.py file must contain all classes that inherit from [bold]BaseTool[/bold] "
"or functions decorated with [bold]@tool[/bold]."
)
console.print(
"\nExample:\n[dim]# In your __init__.py file[/dim]\n"
"[green]__all__ = ['YourTool', 'your_tool_function'][/green]\n\n"
"[dim]# In your tool.py file[/dim]\n"
"[green]from crewai.tools import BaseTool, tool\n\n"
"# Tool class example\n"
"class YourTool(BaseTool):\n"
' name = "your_tool"\n'
' description = "Your tool description"\n'
" # ... rest of implementation\n\n"
"# Decorated function example\n"
"@tool\n"
"def your_tool_function(text: str) -> str:\n"
' """Your tool description"""\n'
" # ... implementation\n"
" return result\n"
)
def build_env_with_tool_repository_credentials(
repository_handle: str,
) -> dict[str, Any]:
repository_handle = repository_handle.upper().replace("-", "_")
settings = Settings()
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
settings.tool_repository_username or ""
)
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
settings.tool_repository_password or ""
)
return env

View File

@@ -0,0 +1,91 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.auth0 import Auth0Provider
class TestAuth0Provider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="auth0",
domain="test-domain.auth0.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = Auth0Provider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = Auth0Provider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "auth0"
assert provider.settings.domain == "test-domain.auth0.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.auth0.com/oauth/device/code"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="my-company.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://my-company.auth0.com/oauth/device/code"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.auth0.com/oauth/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="another-domain.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://another-domain.auth0.com/oauth/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.auth0.com/.well-known/jwks.json"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="dev.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://dev.auth0.com/.well-known/jwks.json"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.auth0.com/"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="prod.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_issuer = "https://prod.auth0.com/"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -0,0 +1,141 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.entra_id import EntraIdProvider
class TestEntraIdProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "openid profile email api://crewai-cli-dev/read"
}
)
self.provider = EntraIdProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = EntraIdProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "entra_id"
assert provider.settings.domain == "tenant-id-abcdef123456"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="my-company.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="another-domain.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="dev.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="other-tenant-id-xpto",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="entra_id",
domain="test-tenant-id",
client_id="test-client-id",
audience=None,
)
provider = EntraIdProvider(settings)
with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["scope"])
def test_get_oauth_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
def test_base_url(self):
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"

View File

@@ -0,0 +1,138 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.keycloak import KeycloakProvider
class TestKeycloakProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="keycloak",
domain="keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
self.provider = KeycloakProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = KeycloakProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "keycloak"
assert provider.settings.domain == "keycloak.example.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
assert provider.settings.extra.get("realm") == "test-realm"
def test_get_authorize_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/auth/device"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="auth.company.com",
client_id="test-client",
audience="test-audience",
extra={
"realm": "my-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://auth.company.com/realms/my-realm/protocol/openid-connect/auth/device"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="sso.enterprise.com",
client_id="test-client",
audience="test-audience",
extra={
"realm": "enterprise-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://sso.enterprise.com/realms/enterprise-realm/protocol/openid-connect/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/certs"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="identity.org",
client_id="test-client",
audience="test-audience",
extra={
"realm": "org-realm"
}
)
provider = KeycloakProvider(settings)
expected_url = "https://identity.org/realms/org-realm/protocol/openid-connect/certs"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://keycloak.example.com/realms/test-realm"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="keycloak",
domain="login.myapp.io",
client_id="test-client",
audience="test-audience",
extra={
"realm": "app-realm"
}
)
provider = KeycloakProvider(settings)
expected_issuer = "https://login.myapp.io/realms/app-realm"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert self.provider.get_required_fields() == ["realm"]
def test_oauth2_base_url(self):
assert self.provider._oauth2_base_url() == "https://keycloak.example.com"
def test_oauth2_base_url_strips_https_prefix(self):
settings = Oauth2Settings(
provider="keycloak",
domain="https://keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
provider = KeycloakProvider(settings)
assert provider._oauth2_base_url() == "https://keycloak.example.com"
def test_oauth2_base_url_strips_http_prefix(self):
settings = Oauth2Settings(
provider="keycloak",
domain="http://keycloak.example.com",
client_id="test-client-id",
audience="test-audience",
extra={
"realm": "test-realm"
}
)
provider = KeycloakProvider(settings)
assert provider._oauth2_base_url() == "https://keycloak.example.com"

View File

@@ -0,0 +1,257 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.okta import OktaProvider
class TestOktaProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience="test-audience",
)
self.provider = OktaProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = OktaProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "okta"
assert provider.settings.domain == "test-domain.okta.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="my-company.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="another-domain.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="dev.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.okta.com/oauth2/default"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="prod.okta.com",
client_id="test-client",
audience="test-audience",
)
provider = OktaProvider(settings)
expected_issuer = "https://prod.okta.com/oauth2/default"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
)
provider = OktaProvider(settings)
with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"])
def test_oauth2_base_url(self):
assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default"
def test_oauth2_base_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
def test_oauth2_base_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"

View File

@@ -0,0 +1,100 @@
import pytest
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.workos import WorkosProvider
class TestWorkosProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = WorkosProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = WorkosProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "workos"
assert provider.settings.domain == "login.company.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.company.com/oauth2/device_authorization"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="login.example.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://login.example.com/oauth2/device_authorization"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.company.com/oauth2/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="api.workos.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://api.workos.com/oauth2/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.company.com/oauth2/jwks"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="auth.enterprise.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://auth.enterprise.com/oauth2/jwks"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.company.com"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="sso.company.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_issuer = "https://sso.company.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_fallback_to_default(self):
settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience=None
)
provider = WorkosProvider(settings)
assert provider.get_audience() == ""
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -0,0 +1,348 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, call, patch
import pytest
import httpx
from crewai_cli.authentication.main import AuthenticationCommand
from crewai_cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
class TestAuthenticationCommand:
def setup_method(self):
# Mock Settings so we always use default constants regardless of local config.
with patch("crewai_cli.authentication.main.Settings") as mock_settings:
instance = mock_settings.return_value
instance.oauth2_provider = "workos"
instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN
instance.oauth2_client_id = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID
instance.oauth2_audience = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE
instance.oauth2_extra = {}
self.auth_command = AuthenticationCommand()
@pytest.mark.parametrize(
"user_provider,expected_urls",
[
(
"workos",
{
"device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization",
"token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token",
"client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
},
),
],
)
@patch("crewai_cli.authentication.main.AuthenticationCommand._get_device_code")
@patch(
"crewai_cli.authentication.main.AuthenticationCommand._display_auth_instructions"
)
@patch("crewai_cli.authentication.main.AuthenticationCommand._poll_for_token")
@patch("crewai_cli.authentication.main.console.print")
def test_login(
self,
mock_console_print,
mock_poll,
mock_display,
mock_get_device,
user_provider,
expected_urls,
):
mock_get_device.return_value = {
"device_code": "test_code",
"user_code": "123456",
}
self.auth_command.login()
mock_console_print.assert_called_once_with(
"Signing in to CrewAI AMP...\n", style="bold blue"
)
mock_get_device.assert_called_once()
mock_display.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"}
)
mock_poll.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"},
)
assert (
self.auth_command.oauth2_provider.get_client_id()
== expected_urls["client_id"]
)
assert (
self.auth_command.oauth2_provider.get_audience()
== expected_urls["audience"]
)
assert (
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
)
@patch("crewai_cli.authentication.main.webbrowser")
@patch("crewai_cli.authentication.main.console.print")
def test_display_auth_instructions(self, mock_console_print, mock_webbrowser):
device_code_data = {
"verification_uri_complete": "https://example.com/auth",
"user_code": "123456",
}
self.auth_command._display_auth_instructions(device_code_data)
expected_calls = [
call("1. Navigate to: ", "https://example.com/auth"),
call("2. Enter the following code: ", "123456"),
]
mock_console_print.assert_has_calls(expected_calls)
mock_webbrowser.open.assert_called_once_with("https://example.com/auth")
@pytest.mark.parametrize(
"user_provider,jwt_config",
[
(
"workos",
{
"jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks",
"issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}",
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
},
),
],
)
@pytest.mark.parametrize("has_expiration", [True, False])
@patch("crewai_cli.authentication.main.validate_jwt_token")
@patch("crewai_cli.authentication.main.TokenManager.save_tokens")
def test_validate_and_save_token(
self,
mock_save_tokens,
mock_validate_jwt,
user_provider,
jwt_config,
has_expiration,
):
from crewai_cli.authentication.main import Oauth2Settings
from crewai_cli.authentication.providers.workos import WorkosProvider
if user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(
settings=Oauth2Settings(
provider=user_provider,
client_id="test-client-id",
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
audience=jwt_config["audience"],
)
)
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
if has_expiration:
future_timestamp = int((datetime.now() + timedelta(days=100)).timestamp())
decoded_token = {"exp": future_timestamp}
else:
decoded_token = {}
mock_validate_jwt.return_value = decoded_token
self.auth_command._validate_and_save_token(token_data)
mock_validate_jwt.assert_called_once_with(
jwt_token="test_access_token",
jwks_url=jwt_config["jwks_url"],
issuer=jwt_config["issuer"],
audience=jwt_config["audience"],
)
if has_expiration:
mock_save_tokens.assert_called_once_with(
"test_access_token", future_timestamp
)
else:
mock_save_tokens.assert_called_once_with("test_access_token", 0)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_cli.authentication.main.Settings")
@patch("crewai_cli.authentication.main.console.print")
def test_login_to_tool_repository_success(
self, mock_console_print, mock_settings, mock_tool_command
):
mock_tool_instance = MagicMock()
mock_tool_command.return_value = mock_tool_instance
mock_settings_instance = MagicMock()
mock_settings_instance.org_name = "Test Org"
mock_settings_instance.org_uuid = "test-uuid-123"
mock_settings.return_value = mock_settings_instance
self.auth_command._login_to_tool_repository()
mock_tool_command.assert_called_once()
mock_tool_instance.login.assert_called_once()
expected_calls = [
call(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
),
call("Success!\n", style="bold green"),
call(
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
style="green",
),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_cli.authentication.main.console.print")
def test_login_to_tool_repository_error(
self, mock_console_print, mock_tool_command
):
mock_tool_instance = MagicMock()
mock_tool_instance.login.side_effect = Exception("Tool repository error")
mock_tool_command.return_value = mock_tool_instance
self.auth_command._login_to_tool_repository()
mock_tool_command.assert_called_once()
mock_tool_instance.login.assert_called_once()
expected_calls = [
call(
"Now logging you in to the Tool Repository... ",
style="bold blue",
end="",
),
call(
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
style="yellow",
),
call(
"Other features will work normally, but you may experience limitations with downloading and publishing tools.\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.authentication.main.httpx.post")
def test_get_device_code(self, mock_post):
mock_response = MagicMock()
mock_response.json.return_value = {
"device_code": "test_device_code",
"user_code": "123456",
"verification_uri_complete": "https://example.com/auth",
}
mock_post.return_value = mock_response
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
"https://example.com/device"
)
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
result = self.auth_command._get_device_code()
mock_post.assert_called_once_with(
url="https://example.com/device",
data={
"client_id": "test_client",
"scope": "openid profile email",
"audience": "test_audience",
},
timeout=20,
)
assert result == {
"device_code": "test_device_code",
"user_code": "123456",
"verification_uri_complete": "https://example.com/auth",
}
@patch("crewai_cli.authentication.main.httpx.post")
@patch("crewai_cli.authentication.main.console.print")
def test_poll_for_token_success(self, mock_console_print, mock_post):
mock_response_success = MagicMock()
mock_response_success.status_code = 200
mock_response_success.json.return_value = {
"access_token": "test_access_token",
"id_token": "test_id_token",
}
mock_post.return_value = mock_response_success
device_code_data = {"device_code": "test_device_code", "interval": 1}
with (
patch.object(
self.auth_command, "_validate_and_save_token"
) as mock_validate,
patch.object(
self.auth_command, "_login_to_tool_repository"
) as mock_tool_login,
):
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_token_url.return_value = (
"https://example.com/token"
)
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command._poll_for_token(device_code_data)
mock_post.assert_called_once_with(
"https://example.com/token",
data={
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": "test_device_code",
"client_id": "test_client",
},
timeout=30,
)
mock_validate.assert_called_once()
mock_tool_login.assert_called_once()
expected_calls = [
call("\nWaiting for authentication... ", style="bold blue", end=""),
call("Success!", style="bold green"),
call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"),
]
mock_console_print.assert_has_calls(expected_calls)
@patch("crewai_cli.authentication.main.httpx.post")
@patch("crewai_cli.authentication.main.console.print")
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
mock_response_pending = MagicMock()
mock_response_pending.status_code = 400
mock_response_pending.json.return_value = {"error": "authorization_pending"}
mock_post.return_value = mock_response_pending
device_code_data = {
"device_code": "test_device_code",
"interval": 0.1, # Short interval for testing
}
self.auth_command._poll_for_token(device_code_data)
mock_console_print.assert_any_call(
"Timeout: Failed to get the token. Please try again.", style="bold red"
)
@patch("crewai_cli.authentication.main.httpx.post")
def test_poll_for_token_error(self, mock_post):
"""Test the method to poll for token (error path)."""
# Setup mock to return error
mock_response_error = MagicMock()
mock_response_error.status_code = 400
mock_response_error.json.return_value = {
"error": "access_denied",
"error_description": "User denied access",
}
mock_post.return_value = mock_response_error
device_code_data = {"device_code": "test_device_code", "interval": 1}
with pytest.raises(httpx.HTTPError):
self.auth_command._poll_for_token(device_code_data)

View File

@@ -0,0 +1,107 @@
import unittest
from unittest.mock import MagicMock, patch
import jwt
from crewai_cli.authentication.utils import validate_jwt_token
@patch("crewai_cli.authentication.utils.PyJWKClient", return_value=MagicMock())
@patch("crewai_cli.authentication.utils.jwt")
class TestUtils(unittest.TestCase):
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.return_value = {"exp": 1719859200}
# Create signing key object mock with a .key attribute
mock_pyjwkclient.return_value.get_signing_key_from_jwt.return_value = MagicMock(
key="mock_signing_key"
)
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
decoded_token = validate_jwt_token(
jwt_token=jwt_token,
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
mock_jwt.decode.assert_called_with(
jwt_token,
"mock_signing_key",
algorithms=["RS256"],
audience="app_id_xxxx",
issuer="https://mock_issuer",
leeway=10.0,
options={
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
mock_pyjwkclient.assert_called_once_with("https://mock_jwks_url")
self.assertEqual(decoded_token, {"exp": 1719859200})
def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_missing_required_claims(
self, mock_jwt, mock_pyjwkclient
):
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)
def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.side_effect = jwt.InvalidTokenError
with self.assertRaises(Exception): # noqa: B017
validate_jwt_token(
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
jwks_url="https://mock_jwks_url",
issuer="https://mock_issuer",
audience="app_id_xxxx",
)

View File

@@ -7,14 +7,14 @@ import pytest
import json
import httpx
from crewai.cli.deploy.main import DeployCommand
from crewai.cli.utils import parse_toml
from crewai_cli.deploy.main import DeployCommand
from crewai_cli.utils import parse_toml
class TestDeployCommand(unittest.TestCase):
@patch("crewai.cli.command.get_auth_token")
@patch("crewai.cli.deploy.main.get_project_name")
@patch("crewai.cli.command.PlusAPI")
@patch("crewai_cli.command.get_auth_token")
@patch("crewai_cli.deploy.main.get_project_name")
@patch("crewai_cli.command.PlusAPI")
def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token):
self.mock_get_auth_token = mock_get_auth_token
self.mock_get_project_name = mock_get_project_name
@@ -30,7 +30,7 @@ class TestDeployCommand(unittest.TestCase):
self.assertEqual(self.deploy_command.project_name, "test_project")
self.mock_plus_api.assert_called_once_with(api_key="test_token")
@patch("crewai.cli.command.get_auth_token")
@patch("crewai_cli.command.get_auth_token")
def test_init_failure(self, mock_get_auth_token):
mock_get_auth_token.side_effect = Exception("Auth failed")
@@ -118,7 +118,7 @@ class TestDeployCommand(unittest.TestCase):
)
self.assertIn("2023-01-01 - INFO: Test log", fake_out.getvalue())
@patch("crewai.cli.deploy.main.DeployCommand._display_deployment_info")
@patch("crewai_cli.deploy.main.DeployCommand._display_deployment_info")
def test_deploy_with_uuid(self, mock_display):
mock_response = MagicMock()
mock_response.status_code = 200
@@ -130,7 +130,7 @@ class TestDeployCommand(unittest.TestCase):
self.mock_client.deploy_by_uuid.assert_called_once_with("test-uuid")
mock_display.assert_called_once_with({"uuid": "test-uuid"})
@patch("crewai.cli.deploy.main.DeployCommand._display_deployment_info")
@patch("crewai_cli.deploy.main.DeployCommand._display_deployment_info")
def test_deploy_with_project_name(self, mock_display):
mock_response = MagicMock()
mock_response.status_code = 200
@@ -142,8 +142,8 @@ class TestDeployCommand(unittest.TestCase):
self.mock_client.deploy_by_name.assert_called_once_with("test_project")
mock_display.assert_called_once_with({"uuid": "test-uuid"})
@patch("crewai.cli.deploy.main.fetch_and_json_env_file")
@patch("crewai.cli.deploy.main.git.Repository.origin_url")
@patch("crewai_cli.deploy.main.fetch_and_json_env_file")
@patch("crewai_cli.deploy.main.git.Repository.origin_url")
@patch("builtins.input")
def test_create_crew(self, mock_input, mock_git_origin_url, mock_fetch_env):
mock_fetch_env.return_value = {"ENV_VAR": "value"}
@@ -236,7 +236,7 @@ class TestDeployCommand(unittest.TestCase):
""",
)
def test_get_project_name_python_310(self, mock_open):
from crewai.cli.utils import get_project_name
from crewai_cli.utils import get_project_name
project_name = get_project_name()
print("project_name", project_name)
@@ -255,12 +255,12 @@ class TestDeployCommand(unittest.TestCase):
""",
)
def test_get_project_name_python_311_plus(self, mock_open):
from crewai.cli.utils import get_project_name
from crewai_cli.utils import get_project_name
project_name = get_project_name()
self.assertEqual(project_name, "test_project")
def test_get_crewai_version(self):
from crewai.cli.version import get_crewai_version
from crewai_cli.version import get_crewai_version
assert isinstance(get_crewai_version(), str)

View File

Some files were not shown because too many files have changed in this diff Show More