mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-07 10:12:38 +00:00
refactor: extract CLI into standalone crewai-cli package
This commit is contained in:
26
lib/cli/README.md
Normal file
26
lib/cli/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# crewai-cli
|
||||
|
||||
CLI for CrewAI — scaffold, run, deploy and manage AI agent crews without
|
||||
installing the full framework.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install crewai-cli
|
||||
```
|
||||
|
||||
This pulls in `crewai-core` (shared utilities) but not the `crewai` framework
|
||||
itself, so commands that don't need a crew loaded — `crewai version`,
|
||||
`crewai login`, `crewai org list`, `crewai config *`, `crewai traces *`,
|
||||
`crewai create`, `crewai template *` — work standalone.
|
||||
|
||||
Commands that load a user's crew or flow (`crewai run`, `crewai train`,
|
||||
`crewai test`, `crewai chat`, `crewai replay`, `crewai reset-memories`,
|
||||
`crewai deploy push`, `crewai tool publish`) require `crewai` to be installed
|
||||
in the project's environment. They print a clear error if it is missing.
|
||||
|
||||
To install both at once:
|
||||
|
||||
```bash
|
||||
pip install crewai[cli]
|
||||
```
|
||||
43
lib/cli/pyproject.toml
Normal file
43
lib/cli/pyproject.toml
Normal file
@@ -0,0 +1,43 @@
|
||||
[project]
name = "crewai-cli"
# Version is sourced from src/crewai_cli/__init__.py — see [tool.hatch.version].
dynamic = ["version"]
description = "CLI for CrewAI — scaffold, run, deploy and manage AI agent crews."
readme = "README.md"
authors = [
    { name = "Joao Moura", email = "joao@crewai.com" }
]
requires-python = ">=3.10, <3.14"
# Deliberately depends on crewai-core only (not the full crewai framework) so
# standalone commands work without the framework installed — see README.md.
dependencies = [
    "crewai-core>=1.14.5a2",
    "click~=8.1.7",
    "pydantic>=2.11.9,<2.13",
    "pydantic-settings~=2.10.1",
    "appdirs~=1.4.4",
    "cryptography>=42.0",
    "httpx~=0.28.1",
    "pyjwt>=2.9.0,<3",
    "rich>=13.7.1",
    "tomli~=2.0.2",
    "tomli-w~=1.1.0",
    "packaging>=23.0",
    "python-dotenv>=1.2.2,<2",
    "uv~=0.11.6",
]

[project.urls]
Homepage = "https://crewai.com"
Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"

# Installs the `crewai` console entry point.
[project.scripts]
crewai = "crewai_cli.cli:crewai"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.version]
path = "src/crewai_cli/__init__.py"

[tool.hatch.build.targets.wheel]
packages = ["src/crewai_cli"]
|
||||
1
lib/cli/src/crewai_cli/__init__.py
Normal file
1
lib/cli/src/crewai_cli/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "1.14.5a2"
|
||||
73
lib/cli/src/crewai_cli/add_crew_to_flow.py
Normal file
73
lib/cli/src/crewai_cli/add_crew_to_flow.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from crewai_core.printer import PRINTER
|
||||
|
||||
from crewai_cli.utils import copy_template
|
||||
|
||||
|
||||
def add_crew_to_flow(crew_name: str) -> None:
    """Scaffold a new crew inside the flow project rooted at the current directory.

    Must be invoked from a flow project root (where ``pyproject.toml`` lives)
    that already contains a ``src/<project>/crews`` folder; aborts with a
    ``ClickException`` otherwise.
    """
    if not Path("pyproject.toml").exists():
        message = "This command must be run from the root of a flow project."
        PRINTER.print(message, color="red")
        raise click.ClickException(message)

    # The crews folder lives under src/<project-dir-name>/crews by convention.
    project_root = Path.cwd()
    crews_folder = project_root / "src" / project_root.name / "crews"
    if not crews_folder.exists():
        message = "Crews folder does not exist in the current flow."
        PRINTER.print(message, color="red")
        raise click.ClickException(message)

    create_embedded_crew(crew_name, parent_folder=crews_folder)

    click.echo(
        f"Crew {crew_name} added to the current flow successfully!",
    )
|
||||
|
||||
|
||||
def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
    """Create a new crew package under *parent_folder* of an existing flow project.

    Derives a snake_case folder name and a PascalCase class name from
    *crew_name*, asks before overriding an existing crew, then renders the
    bundled templates (the config YAMLs plus the crew module) into place.
    """
    snake_name = crew_name.replace(" ", "_").replace("-", "_").lower()
    pascal_name = crew_name.replace("_", " ").replace("-", " ").title().replace(" ", "")
    target_dir = parent_folder / snake_name

    if target_dir.exists():
        if not click.confirm(
            f"Crew {snake_name} already exists. Do you want to override it?"
        ):
            click.secho("Operation cancelled.", fg="yellow")
            return
        click.secho(f"Overriding crew {snake_name}...", fg="green", bold=True)
    else:
        click.secho(f"Creating crew {snake_name}...", fg="green", bold=True)
        target_dir.mkdir(parents=True)

    # Config YAMLs live in a nested config/ directory.
    config_dir = target_dir / "config"
    config_dir.mkdir(exist_ok=True)

    templates_root = Path(__file__).parent / "templates" / "crew"

    for yaml_name in ("agents.yaml", "tasks.yaml"):
        copy_template(
            templates_root / "config" / yaml_name,
            config_dir / yaml_name,
            crew_name,
            pascal_name,
            snake_name,
        )

    # The crew module file is named after the crew itself.
    copy_template(
        templates_root / "crew.py",
        target_dir / f"{snake_name}.py",
        crew_name,
        pascal_name,
        snake_name,
    )

    click.secho(
        f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True
    )
|
||||
8
lib/cli/src/crewai_cli/authentication/__init__.py
Normal file
8
lib/cli/src/crewai_cli/authentication/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""CLI authentication entry point."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_cli.authentication.main import AuthenticationCommand
|
||||
|
||||
|
||||
__all__ = ["AuthenticationCommand"]
|
||||
8
lib/cli/src/crewai_cli/authentication/constants.py
Normal file
8
lib/cli/src/crewai_cli/authentication/constants.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Re-export of authentication constants from ``crewai_core.auth.constants``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.constants import ALGORITHMS as ALGORITHMS
|
||||
|
||||
|
||||
__all__ = ["ALGORITHMS"]
|
||||
60
lib/cli/src/crewai_cli/authentication/main.py
Normal file
60
lib/cli/src/crewai_cli/authentication/main.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""CLI-side authentication wiring.
|
||||
|
||||
Re-exports the OAuth2 primitives from ``crewai_core.auth`` and overrides the
|
||||
``_post_login`` hook to also log into the tool repository.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.oauth2 import (
|
||||
AuthenticationCommand as _BaseAuthenticationCommand,
|
||||
Oauth2Settings as Oauth2Settings,
|
||||
ProviderFactory as ProviderFactory,
|
||||
console,
|
||||
)
|
||||
from crewai_core.settings import Settings
|
||||
|
||||
|
||||
__all__ = ["AuthenticationCommand", "Oauth2Settings", "ProviderFactory"]
|
||||
|
||||
|
||||
class AuthenticationCommand(_BaseAuthenticationCommand):
    """CLI-side login that also signs the user into the tool repository."""

    def _post_login(self) -> None:
        # Hook called by the base class after a successful OAuth2 login.
        self._login_to_tool_repository()

    def _login_to_tool_repository(self) -> None:
        # NOTE(review): imported lazily, presumably to avoid a circular import
        # between the auth and tools modules — confirm before moving to top level.
        from crewai_cli.tools.main import ToolCommand

        try:
            console.print(
                "Now logging you in to the Tool Repository... ",
                style="bold blue",
                end="",
            )

            ToolCommand().login()

            console.print(
                "Success!\n",
                style="bold green",
            )

            settings = Settings()

            console.print(
                f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
                style="green",
            )
        # Deliberately broad (including SystemExit): tool-repository login is
        # best-effort and must never abort the primary login flow.
        except (Exception, SystemExit):
            console.print(
                "\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
                style="yellow",
            )
            console.print(
                "Other features will work normally, but you may experience limitations "
                "with downloading and publishing tools."
                "\nRun [bold]crewai login[/bold] to try logging in again.\n",
                style="yellow",
            )
|
||||
@@ -0,0 +1 @@
|
||||
"""OAuth2 authentication providers — re-exported from ``crewai_core.auth.providers``."""
|
||||
8
lib/cli/src/crewai_cli/authentication/providers/auth0.py
Normal file
8
lib/cli/src/crewai_cli/authentication/providers/auth0.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Re-export of ``Auth0Provider`` from ``crewai_core.auth.providers.auth0``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.providers.auth0 import Auth0Provider as Auth0Provider
|
||||
|
||||
|
||||
__all__ = ["Auth0Provider"]
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Re-export of ``BaseProvider`` from ``crewai_core.auth.providers.base_provider``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.providers.base_provider import BaseProvider as BaseProvider
|
||||
|
||||
|
||||
__all__ = ["BaseProvider"]
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Re-export of ``EntraIdProvider`` from ``crewai_core.auth.providers.entra_id``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.providers.entra_id import EntraIdProvider as EntraIdProvider
|
||||
|
||||
|
||||
__all__ = ["EntraIdProvider"]
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Re-export of ``KeycloakProvider`` from ``crewai_core.auth.providers.keycloak``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.providers.keycloak import KeycloakProvider as KeycloakProvider
|
||||
|
||||
|
||||
__all__ = ["KeycloakProvider"]
|
||||
8
lib/cli/src/crewai_cli/authentication/providers/okta.py
Normal file
8
lib/cli/src/crewai_cli/authentication/providers/okta.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Re-export of ``OktaProvider`` from ``crewai_core.auth.providers.okta``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.providers.okta import OktaProvider as OktaProvider
|
||||
|
||||
|
||||
__all__ = ["OktaProvider"]
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Re-export of ``WorkosProvider`` from ``crewai_core.auth.providers.workos``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.providers.workos import WorkosProvider as WorkosProvider
|
||||
|
||||
|
||||
__all__ = ["WorkosProvider"]
|
||||
11
lib/cli/src/crewai_cli/authentication/token.py
Normal file
11
lib/cli/src/crewai_cli/authentication/token.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Re-exports of authentication token helpers from ``crewai_core.auth.token``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.token import (
|
||||
AuthError as AuthError,
|
||||
get_auth_token as get_auth_token,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["AuthError", "get_auth_token"]
|
||||
8
lib/cli/src/crewai_cli/authentication/utils.py
Normal file
8
lib/cli/src/crewai_cli/authentication/utils.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Re-export of ``validate_jwt_token`` from ``crewai_core.auth.utils``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.auth.utils import validate_jwt_token as validate_jwt_token
|
||||
|
||||
|
||||
__all__ = ["validate_jwt_token"]
|
||||
732
lib/cli/src/crewai_cli/checkpoint_cli.py
Normal file
732
lib/cli/src/crewai_cli/checkpoint_cli.py
Normal file
@@ -0,0 +1,732 @@
|
||||
"""CLI commands for inspecting checkpoint files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
|
||||
|
||||
|
||||
_SQLITE_MAGIC = b"SQLite format 3\x00"
|
||||
|
||||
_SELECT_ALL = """
|
||||
SELECT id, created_at, json(data)
|
||||
FROM checkpoints
|
||||
ORDER BY rowid DESC
|
||||
"""
|
||||
|
||||
_SELECT_ONE = """
|
||||
SELECT id, created_at, json(data)
|
||||
FROM checkpoints
|
||||
WHERE id = ?
|
||||
"""
|
||||
|
||||
_SELECT_LATEST = """
|
||||
SELECT id, created_at, json(data)
|
||||
FROM checkpoints
|
||||
ORDER BY rowid DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
_DELETE_OLDER_THAN = """
|
||||
DELETE FROM checkpoints
|
||||
WHERE created_at < ?
|
||||
"""
|
||||
|
||||
_DELETE_KEEP_N = """
|
||||
DELETE FROM checkpoints WHERE rowid NOT IN (
|
||||
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
|
||||
)
|
||||
"""
|
||||
|
||||
_COUNT_CHECKPOINTS = "SELECT COUNT(*) FROM checkpoints"
|
||||
|
||||
_SELECT_LIKE = """
|
||||
SELECT id, created_at, json(data)
|
||||
FROM checkpoints
|
||||
WHERE id LIKE ?
|
||||
ORDER BY rowid DESC
|
||||
"""
|
||||
|
||||
|
||||
_DEFAULT_DIR = "./.checkpoints"
|
||||
_DEFAULT_DB = "./.checkpoints.db"
|
||||
|
||||
|
||||
def _detect_location(location: str) -> str:
    """Resolve the default checkpoint location.

    Callers that pass the default directory path may in fact be using the
    SQLite backend: when the default directory is absent but the conventional
    ``.db`` file exists, prefer the database path. Any other *location* is
    returned untouched.
    """
    if location != _DEFAULT_DIR:
        return location
    if os.path.exists(_DEFAULT_DIR) or not os.path.exists(_DEFAULT_DB):
        return location
    return _DEFAULT_DB
|
||||
|
||||
|
||||
def _is_sqlite(path: str) -> bool:
|
||||
"""Check if a file is a SQLite database by reading its magic bytes."""
|
||||
if not os.path.isfile(path):
|
||||
return False
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
return f.read(16) == _SQLITE_MAGIC
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
    """Parse checkpoint JSON into metadata dict.

    Args:
        raw: The checkpoint's JSON payload as a string.
        source: Identifier for where the payload came from (filename or row id);
            recorded verbatim under the ``"source"`` key.

    Returns:
        A dict with keys: source, event_count, trigger, entities, branch,
        parent_id, inputs.

    Raises:
        json.JSONDecodeError: If *raw* is not valid JSON.
    """
    data = json.loads(raw)
    entities = data.get("entities", [])
    # Event count comes from the event-record node map, not the entity list.
    nodes = data.get("event_record", {}).get("nodes", {})
    event_count = len(nodes)

    trigger_event = data.get("trigger")

    parsed_entities: list[dict[str, Any]] = []
    for entity in entities:
        tasks = entity.get("tasks", [])
        # A task counts as completed when it carries a non-null output.
        completed = sum(1 for t in tasks if t.get("output") is not None)
        info: dict[str, Any] = {
            "type": entity.get("entity_type", "unknown"),
            "name": entity.get("name"),
            "id": entity.get("id"),
        }

        raw_agents = entity.get("agents", [])
        # Index agents by id so string agent references on tasks can be resolved.
        agents_by_id: dict[str, dict[str, Any]] = {}
        parsed_agents: list[dict[str, Any]] = []
        for ag in raw_agents:
            agent_info: dict[str, Any] = {
                "id": ag.get("id", ""),
                "role": ag.get("role", ""),
                "goal": ag.get("goal", ""),
            }
            parsed_agents.append(agent_info)
            if ag.get("id"):
                agents_by_id[str(ag["id"])] = agent_info
        if parsed_agents:
            info["agents"] = parsed_agents

        if tasks:
            info["tasks_completed"] = completed
            info["tasks_total"] = len(tasks)
            parsed_tasks: list[dict[str, Any]] = []
            for t in tasks:
                task_info: dict[str, Any] = {
                    "description": t.get("description", ""),
                    "completed": t.get("output") is not None,
                    "output": (t.get("output") or {}).get("raw", ""),
                }
                # A task's agent may be embedded as a dict or referenced by id.
                task_agent = t.get("agent")
                if isinstance(task_agent, dict):
                    task_info["agent_role"] = task_agent.get("role", "")
                    task_info["agent_id"] = task_agent.get("id", "")
                elif isinstance(task_agent, str) and task_agent in agents_by_id:
                    task_info["agent_role"] = agents_by_id[task_agent].get("role", "")
                    task_info["agent_id"] = task_agent
                parsed_tasks.append(task_info)
            info["tasks"] = parsed_tasks

        # Flow entities carry extra checkpoint bookkeeping.
        if entity.get("entity_type") == "flow":
            completed_methods = entity.get("checkpoint_completed_methods")
            if completed_methods:
                info["completed_methods"] = sorted(completed_methods)
            state = entity.get("checkpoint_state")
            if isinstance(state, dict):
                info["flow_state"] = state

        parsed_entities.append(info)

    # The first entity with recorded checkpoint inputs wins.
    inputs: dict[str, Any] = {}
    for entity in entities:
        cp_inputs = entity.get("checkpoint_inputs")
        if isinstance(cp_inputs, dict) and cp_inputs:
            inputs = dict(cp_inputs)
            break

    # Scan templated text for {placeholder} names and surface them as empty
    # inputs so a resume prompt knows what can be interpolated.
    for entity in entities:
        for task in entity.get("tasks", []):
            for field in (
                "checkpoint_original_description",
                "checkpoint_original_expected_output",
            ):
                text = task.get(field) or ""
                for match in _PLACEHOLDER_RE.findall(text):
                    if match not in inputs:
                        inputs[match] = ""
        for agent in entity.get("agents", []):
            for field in ("role", "goal", "backstory"):
                text = agent.get(field) or ""
                for match in _PLACEHOLDER_RE.findall(text):
                    if match not in inputs:
                        inputs[match] = ""

    branch = data.get("branch", "main")
    parent_id = data.get("parent_id")

    return {
        "source": source,
        "event_count": event_count,
        "trigger": trigger_event,
        "entities": parsed_entities,
        "branch": branch,
        "parent_id": parent_id,
        "inputs": inputs,
    }
|
||||
|
||||
|
||||
def _format_size(size: int) -> str:
|
||||
if size < 1024:
|
||||
return f"{size}B"
|
||||
if size < 1024 * 1024:
|
||||
return f"{size / 1024:.1f}KB"
|
||||
return f"{size / 1024 / 1024:.1f}MB"
|
||||
|
||||
|
||||
def _ts_from_name(name: str) -> str | None:
|
||||
"""Extract timestamp from checkpoint ID or filename."""
|
||||
stem = os.path.basename(name).split("_")[0].removesuffix(".json")
|
||||
try:
|
||||
dt = datetime.strptime(stem, "%Y%m%dT%H%M%S")
|
||||
except ValueError:
|
||||
return None
|
||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def _entity_summary(entities: list[dict[str, Any]]) -> str:
|
||||
parts = []
|
||||
for ent in entities:
|
||||
etype = ent.get("type", "unknown")
|
||||
ename = ent.get("name", "")
|
||||
completed = ent.get("tasks_completed")
|
||||
total = ent.get("tasks_total")
|
||||
if completed is not None and total is not None:
|
||||
parts.append(f"{etype}:{ename} [{completed}/{total} tasks]")
|
||||
else:
|
||||
parts.append(f"{etype}:{ename}")
|
||||
return ", ".join(parts) if parts else "empty"
|
||||
|
||||
|
||||
# --- JSON directory ---
|
||||
|
||||
|
||||
def _list_json(location: str) -> list[dict[str, Any]]:
    """Collect metadata for every JSON checkpoint under *location*.

    Files are found recursively and returned newest-first (by mtime).
    Unreadable or malformed files become stub entries instead of aborting
    the whole scan.
    """
    pattern = os.path.join(location, "**", "*.json")
    results = []
    for path in sorted(
        glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
    ):
        name = os.path.basename(path)
        try:
            with open(path) as f:
                raw = f.read()
            meta = _parse_checkpoint_json(raw, source=name)
            meta["name"] = name
            # Timestamp from the filename; may be None when not encoded there.
            meta["ts"] = _ts_from_name(name)
            meta["size"] = os.path.getsize(path)
            meta["path"] = path
        except Exception:
            # Best-effort: keep a stub so the broken file still shows up.
            meta = {"name": name, "ts": None, "size": 0, "entities": [], "source": name}
        results.append(meta)
    return results
|
||||
|
||||
|
||||
def _info_json_latest(location: str) -> dict[str, Any] | None:
    """Return metadata for the most recently modified JSON checkpoint.

    Scans *location* recursively; returns None when no ``.json`` file exists.
    Unlike ``_list_json`` this does not swallow parse errors.
    """
    pattern = os.path.join(location, "**", "*.json")
    files = sorted(
        glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
    )
    if not files:
        return None
    path = files[0]
    with open(path) as f:
        raw = f.read()
    meta = _parse_checkpoint_json(raw, source=os.path.basename(path))
    meta["name"] = os.path.basename(path)
    meta["ts"] = _ts_from_name(path)
    meta["size"] = os.path.getsize(path)
    meta["path"] = path
    return meta
|
||||
|
||||
|
||||
def _info_json_file(path: str) -> dict[str, Any]:
    """Return metadata for one specific JSON checkpoint file.

    Raises whatever ``open``/``json.loads`` raise for missing or malformed
    files — callers decide how to report that.
    """
    with open(path) as f:
        raw = f.read()
    meta = _parse_checkpoint_json(raw, source=os.path.basename(path))
    meta["name"] = os.path.basename(path)
    meta["ts"] = _ts_from_name(path)
    meta["size"] = os.path.getsize(path)
    meta["path"] = path
    return meta
|
||||
|
||||
|
||||
# --- SQLite ---
|
||||
|
||||
|
||||
def _list_sqlite(db_path: str) -> list[dict[str, Any]]:
    """Collect metadata for every checkpoint row in a SQLite store, newest-first."""
    results = []
    with sqlite3.connect(db_path) as conn:
        for row in conn.execute(_SELECT_ALL):
            checkpoint_id, created_at, raw = row
            try:
                meta = _parse_checkpoint_json(raw, source=checkpoint_id)
                meta["name"] = checkpoint_id
                # Prefer the timestamp encoded in the ID; fall back to the
                # row's created_at column.
                meta["ts"] = _ts_from_name(checkpoint_id) or created_at
            except Exception:
                # Stub entry for rows whose JSON payload cannot be parsed.
                meta = {
                    "name": checkpoint_id,
                    "ts": created_at,
                    "entities": [],
                    "source": checkpoint_id,
                }
            meta["db"] = db_path
            results.append(meta)
    return results
|
||||
|
||||
|
||||
def _info_sqlite_latest(db_path: str) -> dict[str, Any] | None:
    """Return metadata for the newest checkpoint row, or None for an empty store."""
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(_SELECT_LATEST).fetchone()
    if not row:
        return None
    checkpoint_id, created_at, raw = row
    meta = _parse_checkpoint_json(raw, source=checkpoint_id)
    meta["name"] = checkpoint_id
    meta["ts"] = _ts_from_name(checkpoint_id) or created_at
    meta["db"] = db_path
    return meta
|
||||
|
||||
|
||||
def _info_sqlite_id(db_path: str, checkpoint_id: str) -> dict[str, Any] | None:
    """Return metadata for the checkpoint matching *checkpoint_id*.

    Tries an exact-ID lookup first, then falls back to a substring (LIKE)
    match so users can pass an ID prefix. Returns None when nothing matches.
    """
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(_SELECT_ONE, (checkpoint_id,)).fetchone()
        if not row:
            row = conn.execute(_SELECT_LIKE, (f"%{checkpoint_id}%",)).fetchone()
    if not row:
        return None
    cid, created_at, raw = row
    meta = _parse_checkpoint_json(raw, source=cid)
    meta["name"] = cid
    meta["ts"] = _ts_from_name(cid) or created_at
    meta["db"] = db_path
    return meta
|
||||
|
||||
|
||||
# --- Public API ---
|
||||
|
||||
|
||||
def list_checkpoints(location: str) -> None:
    """List all checkpoints at a location.

    *location* may be a SQLite database file or a directory of JSON
    checkpoints; anything else is reported and skipped. Output is one line
    per checkpoint: name, timestamp, optional size/trigger, entity summary.
    """
    if _is_sqlite(location):
        entries = _list_sqlite(location)
        label = f"SQLite: {location}"
    elif os.path.isdir(location):
        entries = _list_json(location)
        label = location
    else:
        click.echo(f"Not a directory or SQLite database: {location}")
        return

    if not entries:
        click.echo(f"No checkpoints found in {label}")
        return

    click.echo(f"Found {len(entries)} checkpoint(s) in {label}\n")

    for entry in entries:
        ts = entry.get("ts") or "unknown"
        name = entry.get("name", "")
        # SQLite entries carry no on-disk size; omit the column for them.
        size = _format_size(entry["size"]) if "size" in entry else ""
        trigger = entry.get("trigger") or ""
        summary = _entity_summary(entry.get("entities", []))
        parts = [name, ts]
        if size:
            parts.append(size)
        if trigger:
            parts.append(trigger)
        parts.append(summary)
        click.echo(f" {' '.join(parts)}")
|
||||
|
||||
|
||||
def info_checkpoint(path: str) -> None:
    """Show details of a single checkpoint.

    Accepts, in order of precedence:
      * ``db_path#checkpoint_id`` — a specific row in a SQLite store,
      * a SQLite database file — shows its latest checkpoint,
      * a directory — shows the latest JSON checkpoint in it,
      * a JSON checkpoint file.
    """
    meta: dict[str, Any] | None = None

    # db_path#checkpoint_id format
    if "#" in path:
        db_path, checkpoint_id = path.rsplit("#", 1)
        if _is_sqlite(db_path):
            meta = _info_sqlite_id(db_path, checkpoint_id)
            if not meta:
                click.echo(f"Checkpoint not found: {checkpoint_id}")
                return

    # SQLite file — show latest
    if meta is None and _is_sqlite(path):
        meta = _info_sqlite_latest(path)
        if not meta:
            click.echo(f"No checkpoints in database: {path}")
            return
        click.echo(f"Latest checkpoint: {meta['name']}\n")

    # Directory — show latest JSON
    if meta is None and os.path.isdir(path):
        meta = _info_json_latest(path)
        if not meta:
            click.echo(f"No checkpoints found in {path}")
            return
        click.echo(f"Latest checkpoint: {meta['name']}\n")

    # Specific JSON file
    if meta is None and os.path.isfile(path):
        try:
            meta = _info_json_file(path)
        except Exception as exc:
            click.echo(f"Failed to read checkpoint: {exc}")
            return

    if meta is None:
        click.echo(f"Not found: {path}")
        return

    _print_info(meta)
|
||||
|
||||
|
||||
def _print_info(meta: dict[str, Any]) -> None:
    """Print one checkpoint's metadata (as built by ``_parse_checkpoint_json``)
    in a human-readable, multi-line layout."""
    ts = meta.get("ts") or "unknown"
    source = meta.get("path") or meta.get("db") or meta.get("source", "")
    click.echo(f"Source: {source}")
    click.echo(f"Name: {meta.get('name', '')}")
    click.echo(f"Time: {ts}")
    if "size" in meta:
        click.echo(f"Size: {_format_size(meta['size'])}")
    click.echo(f"Events: {meta.get('event_count', 0)}")
    trigger = meta.get("trigger")
    if trigger:
        click.echo(f"Trigger: {trigger}")
    click.echo(f"Branch: {meta.get('branch', 'main')}")
    parent_id = meta.get("parent_id")
    if parent_id:
        click.echo(f"Parent: {parent_id}")

    for ent in meta.get("entities", []):
        # Abbreviate entity IDs to their first 8 characters.
        eid = str(ent.get("id", ""))[:8]
        click.echo(f"\n {ent['type']}: {ent.get('name', 'unnamed')} ({eid}...)")

        # "tasks"/"tasks_completed"/"tasks_total" are set together by
        # _parse_checkpoint_json, so the direct indexing below is safe.
        tasks = ent.get("tasks")
        if isinstance(tasks, list):
            click.echo(
                f" Tasks: {ent['tasks_completed']}/{ent['tasks_total']} completed"
            )
            for i, task in enumerate(tasks):
                status = "done" if task.get("completed") else "pending"
                desc = str(task.get("description", ""))
                # Truncate long descriptions to keep one task per line.
                if len(desc) > 70:
                    desc = desc[:67] + "..."
                click.echo(f" {i + 1}. [{status}] {desc}")
|
||||
|
||||
|
||||
def _resolve_checkpoint(
    location: str, checkpoint_id: str | None
) -> dict[str, Any] | None:
    """Locate one checkpoint's metadata at *location*.

    Dispatches on what *location* is (SQLite store, JSON directory, or a
    single file); *checkpoint_id* narrows the lookup, otherwise the latest
    checkpoint is used. Returns None when nothing matches.
    """
    if _is_sqlite(location):
        if checkpoint_id:
            return _info_sqlite_id(location, checkpoint_id)
        return _info_sqlite_latest(location)
    if os.path.isdir(location):
        if checkpoint_id:
            # Requires the crewai framework; imported lazily so the rest of
            # the CLI works without it installed.
            from crewai.state.provider.json_provider import JsonProvider

            _json_provider: JsonProvider = JsonProvider()
            pattern: str = os.path.join(location, "**", "*.json")
            all_files: list[str] = glob.glob(pattern, recursive=True)
            # Substring match on the provider-extracted ID, newest file wins.
            matches: list[str] = [
                f for f in all_files if checkpoint_id in _json_provider.extract_id(f)
            ]
            matches.sort(key=os.path.getmtime, reverse=True)
            if matches:
                return _info_json_file(matches[0])
            return None
        return _info_json_latest(location)
    if os.path.isfile(location):
        return _info_json_file(location)
    return None
|
||||
|
||||
|
||||
def _entity_type_from_meta(meta: dict[str, Any]) -> str:
|
||||
for ent in meta.get("entities", []):
|
||||
if ent.get("type") == "flow":
|
||||
return "flow"
|
||||
if ent.get("type") == "agent":
|
||||
return "agent"
|
||||
return "crew"
|
||||
|
||||
|
||||
def resume_checkpoint(location: str, checkpoint_id: str | None) -> None:
    """Resume execution from a checkpoint at *location*.

    Resolves the checkpoint (latest when *checkpoint_id* is None), prints its
    details, then reconstructs and re-runs the owning Flow, Agent, or Crew.
    The framework imports are deferred: resuming requires the full ``crewai``
    package, which the standalone CLI does not depend on.
    """
    import asyncio

    meta: dict[str, Any] | None = _resolve_checkpoint(location, checkpoint_id)
    if meta is None:
        if checkpoint_id:
            click.echo(f"Checkpoint not found: {checkpoint_id}")
        else:
            click.echo(f"No checkpoints found in {location}")
        return

    # SQLite-backed checkpoints are addressed as "db_path#checkpoint_id".
    restore_path: str = meta.get("path") or meta.get("source", "")
    if meta.get("db"):
        restore_path = f"{meta['db']}#{meta['name']}"

    click.echo(f"Resuming from: {meta.get('name', restore_path)}")
    _print_info(meta)
    click.echo()

    from crewai.state.checkpoint_config import CheckpointConfig

    config: CheckpointConfig = CheckpointConfig(restore_from=restore_path)
    entity_type: str = _entity_type_from_meta(meta)
    # Empty inputs dict is normalized to None before kickoff.
    inputs: dict[str, Any] | None = meta.get("inputs") or None

    if entity_type == "flow":
        from crewai.flow.flow import Flow

        flow = Flow.from_checkpoint(config)
        result = asyncio.run(flow.kickoff_async(inputs=inputs))
    elif entity_type == "agent":
        from crewai.agent import Agent

        agent = Agent.from_checkpoint(config)
        result = asyncio.run(agent.akickoff(messages="Resume execution."))
    else:
        from crewai.crew import Crew

        crew = Crew.from_checkpoint(config)
        result = asyncio.run(crew.akickoff(inputs=inputs))

    # Kickoff results may be rich output objects (with .raw) or plain values.
    click.echo(f"\nResult: {getattr(result, 'raw', result)}")
|
||||
|
||||
|
||||
def _task_list_from_meta(meta: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
tasks: list[dict[str, Any]] = []
|
||||
for ent in meta.get("entities", []):
|
||||
tasks.extend(
|
||||
{
|
||||
"entity": ent.get("name", "unnamed"),
|
||||
"description": t.get("description", ""),
|
||||
"completed": t.get("completed", False),
|
||||
"output": t.get("output", ""),
|
||||
}
|
||||
for t in ent.get("tasks", [])
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
def diff_checkpoints(location: str, id1: str, id2: str) -> None:
    """Print a unified-diff-style comparison of two checkpoints.

    Compares top-level fields (time, branch, trigger, event count), the
    recorded inputs, and the flattened task lists of the checkpoints
    resolved by *id1* and *id2* at *location*.
    """
    meta1: dict[str, Any] | None = _resolve_checkpoint(location, id1)
    meta2: dict[str, Any] | None = _resolve_checkpoint(location, id2)

    if meta1 is None:
        click.echo(f"Checkpoint not found: {id1}")
        return
    if meta2 is None:
        click.echo(f"Checkpoint not found: {id2}")
        return

    name1: str = meta1.get("name", id1)
    name2: str = meta2.get("name", id2)

    click.echo(f"--- {name1}")
    click.echo(f"+++ {name2}")
    click.echo()

    # Scalar fields: print only the ones whose stringified values differ.
    fields: list[tuple[str, str]] = [
        ("Time", "ts"),
        ("Branch", "branch"),
        ("Trigger", "trigger"),
        ("Events", "event_count"),
    ]
    for label, key in fields:
        v1: str = str(meta1.get(key, ""))
        v2: str = str(meta2.get(key, ""))
        if v1 != v2:
            click.echo(f" {label}:")
            click.echo(f" - {v1}")
            click.echo(f" + {v2}")

    # Inputs: diff the union of keys from both sides.
    inputs1: dict[str, Any] = meta1.get("inputs", {})
    inputs2: dict[str, Any] = meta2.get("inputs", {})
    all_keys: list[str] = sorted(set(list(inputs1.keys()) + list(inputs2.keys())))
    changed_inputs: list[tuple[str, Any, Any]] = [
        (k, inputs1.get(k, ""), inputs2.get(k, ""))
        for k in all_keys
        if inputs1.get(k) != inputs2.get(k)
    ]
    if changed_inputs:
        click.echo("\n Inputs:")
        for key, v1, v2 in changed_inputs:
            click.echo(f" {key}:")
            click.echo(f" - {v1}")
            click.echo(f" + {v2}")

    # Tasks: compare positionally, flagging added/removed tail entries.
    tasks1: list[dict[str, Any]] = _task_list_from_meta(meta1)
    tasks2: list[dict[str, Any]] = _task_list_from_meta(meta2)

    max_tasks: int = max(len(tasks1), len(tasks2))
    if max_tasks == 0:
        return

    click.echo("\n Tasks:")
    for i in range(max_tasks):
        t1: dict[str, Any] | None = tasks1[i] if i < len(tasks1) else None
        t2: dict[str, Any] | None = tasks2[i] if i < len(tasks2) else None

        if t1 is None:
            desc: str = t2["description"][:60] if t2 else ""
            click.echo(f" + {i + 1}. [new] {desc}")
            continue
        if t2 is None:
            desc = t1["description"][:60]
            click.echo(f" - {i + 1}. [removed] {desc}")
            continue

        desc = str(t1["description"][:60])
        s1: str = "done" if t1["completed"] else "pending"
        s2: str = "done" if t2["completed"] else "pending"

        if s1 != s2:
            click.echo(f" {i + 1}. {desc}")
            click.echo(f" status: {s1} -> {s2}")

        out1: str = (t1.get("output") or "").strip()
        out2: str = (t2.get("output") or "").strip()
        if out1 != out2:
            # Avoid printing the task header twice when status already did.
            if s1 == s2:
                click.echo(f" {i + 1}. {desc}")
            preview1: str = (
                out1[:80] + ("..." if len(out1) > 80 else "") if out1 else "(empty)"
            )
            preview2: str = (
                out2[:80] + ("..." if len(out2) > 80 else "") if out2 else "(empty)"
            )
            click.echo(" output:")
            click.echo(f" - {preview1}")
            click.echo(f" + {preview2}")
|
||||
|
||||
|
||||
def _parse_duration(value: str) -> timedelta:
|
||||
match: re.Match[str] | None = re.match(r"^(\d+)([dhm])$", value.strip())
|
||||
if not match:
|
||||
raise click.BadParameter(
|
||||
f"Invalid duration: {value!r}. Use format like '7d', '24h', or '30m'."
|
||||
)
|
||||
amount: int = int(match.group(1))
|
||||
unit: str = match.group(2)
|
||||
if unit == "d":
|
||||
return timedelta(days=amount)
|
||||
if unit == "h":
|
||||
return timedelta(hours=amount)
|
||||
return timedelta(minutes=amount)
|
||||
|
||||
|
||||
def _prune_json(location: str, keep: int | None, older_than: timedelta | None) -> int:
|
||||
pattern: str = os.path.join(location, "**", "*.json")
|
||||
files: list[str] = sorted(
|
||||
glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
|
||||
)
|
||||
if not files:
|
||||
return 0
|
||||
|
||||
to_delete: set[str] = set()
|
||||
|
||||
if keep is not None and len(files) > keep:
|
||||
to_delete.update(files[keep:])
|
||||
|
||||
if older_than is not None:
|
||||
cutoff: datetime = datetime.now(timezone.utc) - older_than
|
||||
for path in files:
|
||||
mtime: datetime = datetime.fromtimestamp(
|
||||
os.path.getmtime(path), tz=timezone.utc
|
||||
)
|
||||
if mtime < cutoff:
|
||||
to_delete.add(path)
|
||||
|
||||
deleted: int = 0
|
||||
for path in to_delete:
|
||||
try:
|
||||
os.remove(path)
|
||||
deleted += 1
|
||||
except OSError: # noqa: PERF203
|
||||
pass
|
||||
|
||||
for dirpath, dirnames, filenames in os.walk(location, topdown=False):
|
||||
if dirpath != location and not filenames and not dirnames:
|
||||
try:
|
||||
os.rmdir(dirpath)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def _prune_sqlite(db_path: str, keep: int | None, older_than: timedelta | None) -> int:
    """Delete checkpoint rows from the SQLite store at *db_path*.

    Args:
        db_path: Path to the checkpoint database file.
        keep: When set, keep only the *keep* most recent checkpoints.
        older_than: When set, also delete checkpoints older than this window.

    Returns:
        Total number of rows deleted.
    """
    deleted: int = 0
    # sqlite3's connection context manager only scopes a transaction (commit
    # on success, rollback on error) — it does NOT close the connection. The
    # original leaked the file handle; close explicitly in a finally.
    conn = sqlite3.connect(db_path)
    try:
        with conn:  # transaction: committed on clean exit, so no manual commit
            if older_than is not None:
                # Checkpoint IDs embed a UTC timestamp in %Y%m%dT%H%M%S form,
                # so the cutoff is compared as a string.
                cutoff: str = (datetime.now(timezone.utc) - older_than).strftime(
                    "%Y%m%dT%H%M%S"
                )
                deleted += conn.execute(_DELETE_OLDER_THAN, (cutoff,)).rowcount
            if keep is not None:
                deleted += conn.execute(_DELETE_KEEP_N, (keep,)).rowcount
    finally:
        conn.close()
    return deleted
|
||||
|
||||
|
||||
def prune_checkpoints(
    location: str, keep: int | None, older_than: str | None, dry_run: bool = False
) -> None:
    """Prune checkpoints at *location* (JSON directory or SQLite database).

    At least one of *keep* / *older_than* must be given; *older_than* is a
    compact duration string such as ``7d`` or ``24h``. With *dry_run* only a
    count of candidate checkpoints is printed.
    """
    if keep is None and older_than is None:
        click.echo("Specify --keep N and/or --older-than DURATION (e.g. 7d, 24h)")
        return

    window: timedelta | None = _parse_duration(older_than) if older_than else None

    if _is_sqlite(location):
        if dry_run:
            with sqlite3.connect(location) as conn:
                total: int = conn.execute(_COUNT_CHECKPOINTS).fetchone()[0]
            click.echo(f"Would prune from {total} checkpoint(s) in {location}")
            return
        removed = _prune_sqlite(location, keep, window)
    elif os.path.isdir(location):
        if dry_run:
            candidates = glob.glob(
                os.path.join(location, "**", "*.json"), recursive=True
            )
            click.echo(f"Would prune from {len(candidates)} checkpoint(s) in {location}")
            return
        removed = _prune_json(location, keep, window)
    else:
        click.echo(f"Not a directory or SQLite database: {location}")
        return
    click.echo(f"Pruned {removed} checkpoint(s) from {location}")
|
||||
877
lib/cli/src/crewai_cli/checkpoint_tui.py
Normal file
877
lib/cli/src/crewai_cli/checkpoint_tui.py
Normal file
@@ -0,0 +1,877 @@
|
||||
"""Textual TUI for browsing checkpoint files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Horizontal, Vertical, VerticalScroll
|
||||
from textual.widgets import (
|
||||
Collapsible,
|
||||
Footer,
|
||||
Header,
|
||||
Input,
|
||||
Static,
|
||||
TabPane,
|
||||
TabbedContent,
|
||||
TextArea,
|
||||
Tree,
|
||||
)
|
||||
|
||||
from crewai_cli.checkpoint_cli import (
|
||||
_format_size,
|
||||
_is_sqlite,
|
||||
_list_json,
|
||||
_list_sqlite,
|
||||
)
|
||||
|
||||
|
||||
# CrewAI brand palette used in the TUI's Rich markup and Textual CSS below.
_PRIMARY = "#eb6658"
_SECONDARY = "#1F7982"
_TERTIARY = "#ffffff"
_DIM = "#888888"
_BG_DARK = "#0d1117"
_BG_PANEL = "#161b22"
_ACCENT = "#c9a227"
_SUCCESS = "#3fb950"
_PENDING = "#e3b341"

# Glyph shown next to each checkpointed entity kind in the tree/detail views.
_ENTITY_ICONS: dict[str, str] = {
    "flow": "◆",
    "crew": "●",
    "agent": "◈",
    "unknown": "○",
}
# Markup color applied to each entity kind's glyph and name.
_ENTITY_COLORS: dict[str, str] = {
    "flow": _ACCENT,
    "crew": _SECONDARY,
    "agent": _PRIMARY,
    "unknown": _DIM,
}
|
||||
|
||||
|
||||
def _load_entries(location: str) -> list[dict[str, Any]]:
    """List checkpoint entries from *location*, SQLite or JSON directory."""
    loader = _list_sqlite if _is_sqlite(location) else _list_json
    return loader(location)
|
||||
|
||||
|
||||
def _human_ts(ts: str) -> str:
|
||||
"""Turn '2026-04-17 17:05:00' into a short relative label."""
|
||||
try:
|
||||
dt = datetime.strptime(ts, "%Y-%m-%d %H:%M:%S")
|
||||
except ValueError:
|
||||
return ts
|
||||
now = datetime.now()
|
||||
delta = now.date() - dt.date()
|
||||
hour = dt.hour % 12 or 12
|
||||
ampm = "am" if dt.hour < 12 else "pm"
|
||||
time_str = f"{hour}:{dt.minute:02d}{ampm}"
|
||||
if delta.days == 0:
|
||||
return time_str
|
||||
if delta.days == 1:
|
||||
return f"yest {time_str}"
|
||||
if delta.days < 7:
|
||||
return f"{dt.strftime('%a').lower()} {time_str}"
|
||||
return f"{dt.strftime('%b')} {dt.day}"
|
||||
|
||||
|
||||
def _short_id(name: str) -> str:
|
||||
if len(name) > 30:
|
||||
return name[:27] + "..."
|
||||
return name
|
||||
|
||||
|
||||
def _entry_id(entry: dict[str, Any]) -> str:
|
||||
"""Normalize an entry's name into its checkpoint ID.
|
||||
|
||||
JSON filenames are ``{ts}_{uuid}_p-{parent}.json``; SQLite IDs
|
||||
are already ``{ts}_{uuid}``. This strips the JSON suffix so
|
||||
fork-parent lookups work in both providers.
|
||||
"""
|
||||
name = str(entry.get("name", ""))
|
||||
if name.endswith(".json"):
|
||||
name = name[: -len(".json")]
|
||||
idx = name.find("_p-")
|
||||
if idx != -1:
|
||||
name = name[:idx]
|
||||
return name
|
||||
|
||||
|
||||
def _build_progress_bar(completed: int, total: int, width: int = 20) -> str:
    """Render a Rich-markup progress bar plus a ``done/total (pct%)`` suffix."""
    if total == 0:
        # No tasks at all: all-dim track, zero counter.
        return f"[{_DIM}]{'░' * width}[/] 0/0"
    done_cells = int(width * completed / total)
    fill_color = _SUCCESS if completed == total else _PRIMARY
    track = (
        f"[{fill_color}]{'█' * done_cells}[/]"
        f"[{_DIM}]{'░' * (width - done_cells)}[/]"
    )
    percent = int(completed / total * 100)
    return f"{track} {completed}/{total} ({percent}%)"
|
||||
|
||||
|
||||
def _entity_icon(etype: str) -> str:
    """Colored icon markup for an entity type, falling back to 'unknown'."""
    glyph = _ENTITY_ICONS.get(etype, _ENTITY_ICONS["unknown"])
    return f"[{_ENTITY_COLORS.get(etype, _DIM)}]{glyph}[/]"
|
||||
|
||||
|
||||
# Result type of CheckpointTUI: (location, action, inputs, task_overrides,
# entity_type), where action is "resume" or "fork"; ``None`` when the user
# quits without selecting a checkpoint.
_TuiResult = (
    tuple[
        str,
        str,
        dict[str, Any] | None,
        dict[int, str] | None,
        Literal["crew", "flow", "agent"],
    ]
    | None
)
|
||||
|
||||
|
||||
class CheckpointTUI(App[_TuiResult]):
|
||||
"""TUI to browse and inspect checkpoints.
|
||||
|
||||
Returns ``(location, action, inputs, task_overrides, entity_type)``
|
||||
where action is ``"resume"`` or ``"fork"``, inputs is a parsed dict
|
||||
or ``None``, and entity_type is ``"crew"`` or ``"flow"``;
|
||||
or ``None`` if the user quit without selecting.
|
||||
"""
|
||||
|
||||
TITLE = "CrewAI Checkpoints"
|
||||
|
||||
CSS = f"""
|
||||
Screen {{
|
||||
background: {_BG_DARK};
|
||||
}}
|
||||
Header {{
|
||||
background: {_PRIMARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
Footer {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
Footer > .footer-key--key {{
|
||||
background: {_PRIMARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
#main-layout {{
|
||||
height: 1fr;
|
||||
}}
|
||||
#tree-panel {{
|
||||
width: 40%;
|
||||
background: {_BG_PANEL};
|
||||
border: round {_SECONDARY};
|
||||
padding: 0 1;
|
||||
scrollbar-color: {_PRIMARY};
|
||||
}}
|
||||
#tree-panel:focus-within {{
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
#detail-container {{
|
||||
width: 60%;
|
||||
height: 1fr;
|
||||
}}
|
||||
#status {{
|
||||
height: 1;
|
||||
padding: 0 2;
|
||||
color: {_DIM};
|
||||
}}
|
||||
#detail-tabs {{
|
||||
height: 1fr;
|
||||
}}
|
||||
TabbedContent > ContentSwitcher {{
|
||||
background: {_BG_PANEL};
|
||||
height: 1fr;
|
||||
}}
|
||||
TabPane {{
|
||||
padding: 0;
|
||||
}}
|
||||
Tabs {{
|
||||
background: {_BG_DARK};
|
||||
}}
|
||||
Tab {{
|
||||
background: {_BG_DARK};
|
||||
color: {_DIM};
|
||||
padding: 0 2;
|
||||
}}
|
||||
Tab.-active {{
|
||||
background: {_BG_PANEL};
|
||||
color: {_PRIMARY};
|
||||
}}
|
||||
Tab:hover {{
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
Underline > .underline--bar {{
|
||||
color: {_SECONDARY};
|
||||
background: {_BG_DARK};
|
||||
}}
|
||||
.tab-scroll {{
|
||||
background: {_BG_PANEL};
|
||||
height: 1fr;
|
||||
padding: 1 2;
|
||||
scrollbar-color: {_PRIMARY};
|
||||
}}
|
||||
.section-header {{
|
||||
padding: 0 0 0 1;
|
||||
margin: 1 0 0 0;
|
||||
}}
|
||||
.detail-line {{
|
||||
padding: 0 0 0 1;
|
||||
}}
|
||||
.task-label {{
|
||||
padding: 0 1;
|
||||
}}
|
||||
.task-output-editor {{
|
||||
height: auto;
|
||||
max-height: 10;
|
||||
margin: 0 1 1 3;
|
||||
border: round {_DIM};
|
||||
}}
|
||||
.task-output-editor:focus {{
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
Collapsible {{
|
||||
background: {_BG_PANEL};
|
||||
padding: 0;
|
||||
margin: 0 0 1 1;
|
||||
}}
|
||||
CollapsibleTitle {{
|
||||
background: {_BG_DARK};
|
||||
color: {_TERTIARY};
|
||||
padding: 0 1;
|
||||
}}
|
||||
CollapsibleTitle:hover {{
|
||||
background: {_SECONDARY};
|
||||
}}
|
||||
.input-row {{
|
||||
height: 3;
|
||||
padding: 0 1;
|
||||
}}
|
||||
.input-row Static {{
|
||||
width: auto;
|
||||
min-width: 12;
|
||||
padding: 1 1 0 0;
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
.input-row Input {{
|
||||
width: 1fr;
|
||||
}}
|
||||
.empty-state {{
|
||||
color: {_DIM};
|
||||
padding: 1;
|
||||
}}
|
||||
Tree {{
|
||||
background: {_BG_PANEL};
|
||||
}}
|
||||
Tree > .tree--cursor {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
"""
|
||||
|
||||
BINDINGS: ClassVar[list[Binding | tuple[str, str] | tuple[str, str, str]]] = [
|
||||
("q", "quit", "Quit"),
|
||||
("r", "refresh", "Refresh"),
|
||||
("e", "resume", "Resume"),
|
||||
("f", "fork", "Fork"),
|
||||
]
|
||||
|
||||
    def __init__(self, location: str = "./.checkpoints") -> None:
        """Create the browser rooted at *location* (JSON dir or SQLite file)."""
        super().__init__()
        self._location = location
        # Entries loaded by _refresh_tree; one dict per checkpoint.
        self._entries: list[dict[str, Any]] = []
        # Entry currently highlighted in the tree, if any.
        self._selected_entry: dict[str, Any] | None = None
        # Keys of the editable Input widgets rendered on the Inputs tab.
        self._input_keys: list[str] = []
        # (flat task index, TextArea id, original output text) for every
        # completed-task output editor rendered on the Tasks tab.
        self._task_output_ids: list[tuple[int, str, str]] = []
|
||||
|
||||
    def compose(self) -> ComposeResult:
        """Build the layout: checkpoint tree on the left, tabbed detail on the right."""
        yield Header(show_clock=False)
        with Horizontal(id="main-layout"):
            tree: Tree[dict[str, Any]] = Tree("Checkpoints", id="tree-panel")
            tree.show_root = False
            tree.guide_depth = 3
            yield tree
            with Vertical(id="detail-container"):
                # One-line status bar (checkpoint count + storage backend).
                yield Static("", id="status")
                with TabbedContent(id="detail-tabs"):
                    # Each tab starts with an empty-state placeholder that is
                    # replaced once a checkpoint is highlighted in the tree.
                    with TabPane("Overview", id="tab-overview"):
                        with VerticalScroll(classes="tab-scroll"):
                            yield Static(
                                f"[{_DIM}]Select a checkpoint from the tree[/]",  # noqa: S608
                                id="overview-empty",
                            )
                    with TabPane("Tasks", id="tab-tasks"):
                        with VerticalScroll(classes="tab-scroll"):
                            yield Static(
                                f"[{_DIM}]Select a checkpoint to view tasks[/]",
                                id="tasks-empty",
                            )
                    with TabPane("Inputs", id="tab-inputs"):
                        with VerticalScroll(classes="tab-scroll"):
                            yield Static(
                                f"[{_DIM}]Select a checkpoint to view inputs[/]",
                                id="inputs-empty",
                            )
        yield Footer()
|
||||
|
||||
    async def on_mount(self) -> None:
        """Populate the tree on startup and expand its (hidden) root."""
        self._refresh_tree()
        self.query_one("#tree-panel", Tree).root.expand()
|
||||
|
||||
# ── Tree building ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _top_level_entity(entry: dict[str, Any]) -> tuple[str, str]:
|
||||
etype, ename = "unknown", ""
|
||||
for ent in entry.get("entities", []):
|
||||
t = ent.get("type", "unknown")
|
||||
if t == "flow":
|
||||
return "flow", ent.get("name") or ""
|
||||
if t == "crew" and etype != "crew":
|
||||
etype, ename = "crew", ent.get("name") or ""
|
||||
return etype, ename
|
||||
|
||||
    def _refresh_tree(self) -> None:
        """Reload entries and rebuild the tree: entity → branch → checkpoint."""
        self._entries = _load_entries(self._location)
        self._selected_entry = None

        tree = self.query_one("#tree-panel", Tree)
        tree.clear()

        if not self._entries:
            self.sub_title = self._location
            self.query_one("#status", Static).update("")
            return

        # (entity_type, entity_name) -> branch name -> entries on that branch.
        grouped: dict[tuple[str, str], dict[str, list[dict[str, Any]]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for entry in self._entries:
            key = self._top_level_entity(entry)
            branch = entry.get("branch", "main")
            grouped[key][branch].append(entry)

        def _make_label(e: dict[str, Any]) -> str:
            # One-line tree label: "<time> <trigger> <done/total>".
            ts = e.get("ts") or ""
            trigger = e.get("trigger") or ""
            time_part = ts.split(" ")[-1] if " " in ts else ts

            total_c, total_t = 0, 0
            for ent in e.get("entities", []):
                c = ent.get("tasks_completed")
                t = ent.get("tasks_total")
                if c is not None and t is not None:
                    total_c += c
                    total_t += t

            parts: list[str] = []
            if time_part:
                parts.append(f"[{_DIM}]{time_part}[/]")
            if trigger:
                parts.append(f"[{_PRIMARY}]{trigger}[/]")
            if total_t:
                display_c = total_c
                # A task_started checkpoint is mid-task: show it as in-flight.
                if trigger == "task_started" and total_c < total_t:
                    display_c = total_c + 1
                color = _SUCCESS if total_c == total_t else _DIM
                parts.append(f"[{color}]{display_c}/{total_t}[/]")
            return " ".join(parts) if parts else _short_id(e.get("name", ""))

        # Checkpoints that some fork branch hangs off of; these render as
        # expandable nodes rather than leaves.
        fork_parents: set[str] = set()
        for branches in grouped.values():
            for branch_name, entries in branches.items():
                if branch_name == "main" or not entries:
                    continue
                oldest = min(entries, key=lambda e: str(e.get("name", "")))
                first_parent = oldest.get("parent_id")
                if first_parent:
                    fork_parents.add(str(first_parent))

        node_by_name: dict[str, Any] = {}

        def _add_checkpoint(parent_node: Any, e: dict[str, Any]) -> None:
            cp_id = _entry_id(e)
            if cp_id in fork_parents:
                node = parent_node.add(
                    _make_label(e), data=e, expand=False, allow_expand=True
                )
            else:
                node = parent_node.add_leaf(_make_label(e), data=e)
            node_by_name[cp_id] = node

        # Flows first, then crews, then anything else; alphabetical within.
        type_order = {"flow": 0, "crew": 1}
        sorted_keys = sorted(
            grouped.keys(), key=lambda k: (type_order.get(k[0], 9), k[1])
        )

        for etype, ename in sorted_keys:
            branches = grouped[(etype, ename)]
            icon = _entity_icon(etype)
            color = _ENTITY_COLORS.get(etype, _DIM)
            total = sum(len(v) for v in branches.values())

            label_parts = [f"{icon} [bold {color}]{etype.upper()}[/]"]
            if ename:
                label_parts.append(f"[bold]{ename}[/]")
            label_parts.append(f"[{_DIM}]({total})[/]")
            all_entries = [e for bl in branches.values() for e in bl]
            timestamps = [str(e.get("ts", "")) for e in all_entries if e.get("ts")]
            if timestamps:
                latest = max(timestamps)
                label_parts.append(f"[{_DIM}]{_human_ts(latest)}[/]")
            entity_label = " ".join(label_parts)
            entity_node = tree.root.add(entity_label, expand=True)

            # Main branch, newest checkpoint first.
            if "main" in branches:
                for entry in reversed(branches["main"]):
                    _add_checkpoint(entity_node, entry)

            # Fork branches attach under their parent checkpoint's node.
            # Multi-pass because a fork's parent may itself live on another
            # fork branch that has not been added yet.
            fork_branches = [
                (name, sorted(entries, key=lambda e: str(e.get("name", ""))))
                for name, entries in branches.items()
                if name != "main"
            ]
            remaining = fork_branches
            max_passes = len(remaining) + 1  # hard stop against cycles
            while remaining and max_passes > 0:
                max_passes -= 1
                deferred = []
                made_progress = False
                for branch_name, entries in remaining:
                    first_parent = entries[0].get("parent_id") if entries else None
                    if first_parent and str(first_parent) not in node_by_name:
                        # Parent not placed yet — retry on the next pass.
                        deferred.append((branch_name, entries))
                        continue
                    attach_to: Any = entity_node
                    if first_parent:
                        attach_to = node_by_name.get(str(first_parent), entity_node)
                    branch_label = (
                        f"[bold {_SECONDARY}]{branch_name}[/] "
                        f"[{_DIM}]({len(entries)})[/]"
                    )
                    branch_node = attach_to.add(branch_label, expand=False)
                    for entry in entries:
                        _add_checkpoint(branch_node, entry)
                    made_progress = True
                remaining = deferred
                if not made_progress:
                    break

            # Whatever is left has a parent we never found: show as orphaned
            # directly under the entity node rather than dropping it.
            for branch_name, entries in remaining:
                branch_label = (
                    f"[bold {_SECONDARY}]{branch_name}[/] "
                    f"[{_DIM}]({len(entries)})[/] [{_DIM}](orphaned)[/]"
                )
                branch_node = entity_node.add(branch_label, expand=False)
                for entry in entries:
                    _add_checkpoint(branch_node, entry)

        count = len(self._entries)
        storage = "SQLite" if _is_sqlite(self._location) else "JSON"
        self.sub_title = self._location
        self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}")
|
||||
|
||||
# ── Detail panel ───────────────────────────────────────────────
|
||||
|
||||
async def _clear_scroll(self, tab_id: str) -> VerticalScroll:
|
||||
tab = self.query_one(f"#{tab_id}", TabPane)
|
||||
scroll = tab.query_one(VerticalScroll)
|
||||
for child in list(scroll.children):
|
||||
await child.remove()
|
||||
return scroll
|
||||
|
||||
    async def _show_detail(self, entry: dict[str, Any]) -> None:
        """Mark *entry* selected and render it into all three detail tabs."""
        self._selected_entry = entry

        await self._render_overview(entry)
        await self._render_tasks(entry)
        await self._render_inputs(entry.get("inputs", {}))
|
||||
|
||||
    async def _render_overview(self, entry: dict[str, Any]) -> None:
        """Render the Overview tab: checkpoint metadata plus per-entity sections."""
        scroll = await self._clear_scroll("tab-overview")

        name = entry.get("name", "")
        ts = entry.get("ts") or "unknown"
        trigger = entry.get("trigger") or ""
        branch = entry.get("branch", "main")
        parent_id = entry.get("parent_id")

        # Header: checkpoint name, rule, then a key/value metadata list.
        header_lines = [
            f"[bold {_PRIMARY}]{name}[/]",
            f"[{_DIM}]{'─' * 50}[/]",
            "",
            f" [bold]Time[/] {ts}",
        ]
        if "size" in entry:
            # Only JSON entries carry an on-disk size.
            header_lines.append(f" [bold]Size[/] {_format_size(entry['size'])}")
        header_lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}")
        if trigger:
            header_lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]")
        header_lines.append(f" [bold]Branch[/] [{_SECONDARY}]{branch}[/]")
        if parent_id:
            header_lines.append(f" [bold]Parent[/] [{_DIM}]{parent_id}[/]")

        await scroll.mount(Static("\n".join(header_lines)))

        # One titled section per checkpointed entity (flow / crew / agent).
        for ent in entry.get("entities", []):
            etype = ent.get("type", "unknown")
            ename = ent.get("name", "unnamed")
            icon = _entity_icon(etype)
            color = _ENTITY_COLORS.get(etype, _DIM)

            eid = str(ent.get("id", ""))[:8]
            entity_title = (
                f"\n{icon} [bold {color}]{etype.upper()}[/] [bold]{ename}[/]"
            )
            if eid:
                entity_title += f" [{_DIM}]{eid}…[/]"
            await scroll.mount(Static(entity_title, classes="section-header"))
            await scroll.mount(Static(f"[{_DIM}]{'─' * 46}[/]", classes="detail-line"))

            if etype == "flow":
                methods = ent.get("completed_methods", [])
                if methods:
                    method_list = ", ".join(f"[{_SUCCESS}]{m}[/]" for m in methods)
                    await scroll.mount(
                        Static(
                            f" [bold]Methods[/] {method_list}",
                            classes="detail-line",
                        )
                    )
                flow_state = ent.get("flow_state")
                if isinstance(flow_state, dict) and flow_state:
                    # Show at most five state keys, each value clipped to 40 chars.
                    state_parts: list[str] = []
                    for k, v in list(flow_state.items())[:5]:
                        sv = str(v)
                        if len(sv) > 40:
                            sv = sv[:37] + "..."
                        state_parts.append(f"[{_DIM}]{k}[/]={sv}")
                    await scroll.mount(
                        Static(
                            f" [bold]State[/] {', '.join(state_parts)}",
                            classes="detail-line",
                        )
                    )

            agents = ent.get("agents", [])
            if agents:
                agent_lines: list[Static] = []
                for ag in agents:
                    role = ag.get("role", "unnamed")
                    goal = ag.get("goal", "")
                    if len(goal) > 60:
                        goal = goal[:57] + "..."
                    agent_line = f" {_entity_icon('agent')} [bold]{role}[/]"
                    if goal:
                        agent_line += f"\n [{_DIM}]{goal}[/]"
                    agent_lines.append(Static(agent_line))

                # Collapse the agent list by default when it is long.
                collapsible = Collapsible(
                    *agent_lines,
                    title=f"Agents ({len(agents)})",
                    collapsed=len(agents) > 3,
                )
                await scroll.mount(collapsible)
|
||||
|
||||
    async def _render_tasks(self, entry: dict[str, Any]) -> None:
        """Render the Tasks tab: per-entity progress plus one row per task.

        Completed tasks additionally get an editable TextArea holding their
        output; edits are later collected by ``_collect_task_overrides``,
        keyed by the task's flat index across all entities.
        """
        scroll = await self._clear_scroll("tab-tasks")

        self._task_output_ids = []
        # Flat index counting every task of every entity, in render order.
        flat_task_idx = 0
        has_tasks = False

        for ent_idx, ent in enumerate(entry.get("entities", [])):
            etype = ent.get("type", "unknown")
            ename = ent.get("name", "unnamed")
            icon = _entity_icon(etype)
            color = _ENTITY_COLORS.get(etype, _DIM)

            tasks = ent.get("tasks", [])
            if not tasks:
                continue
            has_tasks = True

            completed = ent.get("tasks_completed", 0)
            total = ent.get("tasks_total", 0)

            await scroll.mount(
                Static(
                    f"{icon} [bold {color}]{ename}[/] "
                    f"{_build_progress_bar(completed, total, width=16)}",
                    classes="section-header",
                )
            )

            for i, task in enumerate(tasks):
                desc = str(task.get("description", ""))
                if len(desc) > 50:
                    desc = desc[:47] + "..."
                agent_role = task.get("agent_role", "")

                if task.get("completed"):
                    status_icon = f"[{_SUCCESS}]✓[/]"
                    task_line = f" {status_icon} {i + 1}. {desc}"
                    if agent_role:
                        task_line += (
                            f" [{_DIM}]→ {_entity_icon('agent')} {agent_role}[/]"
                        )
                    await scroll.mount(Static(task_line, classes="task-label"))
                    output_text = task.get("output", "")
                    editor_id = f"task-output-{ent_idx}-{i}"
                    await scroll.mount(
                        TextArea(
                            str(output_text),
                            classes="task-output-editor",
                            id=editor_id,
                        )
                    )
                    # Remember the original text so only genuine edits are
                    # reported as overrides later.
                    self._task_output_ids.append(
                        (flat_task_idx, editor_id, str(output_text))
                    )
                else:
                    status_icon = f"[{_PENDING}]○[/]"
                    task_line = f" {status_icon} {i + 1}. {desc}"
                    if agent_role:
                        task_line += (
                            f" [{_DIM}]→ {_entity_icon('agent')} {agent_role}[/]"
                        )
                    await scroll.mount(Static(task_line, classes="task-label"))
                flat_task_idx += 1

        if not has_tasks:
            await scroll.mount(Static(f"[{_DIM}]No tasks[/]", classes="empty-state"))
|
||||
|
||||
    async def _render_inputs(self, inputs: dict[str, Any]) -> None:
        """Render the Inputs tab: one editable row per kickoff input."""
        scroll = await self._clear_scroll("tab-inputs")

        self._input_keys = []

        if not inputs:
            await scroll.mount(Static(f"[{_DIM}]No inputs[/]", classes="empty-state"))
            return

        for key, value in inputs.items():
            # Track keys so _collect_inputs can read the widgets back later.
            self._input_keys.append(key)
            row = Horizontal(classes="input-row")
            row.compose_add_child(Static(f"[bold]{key}[/]"))
            row.compose_add_child(
                Input(value=str(value), placeholder=key, id=f"input-{key}")
            )
            await scroll.mount(row)
|
||||
|
||||
# ── Data collection ────────────────────────────────────────────
|
||||
|
||||
def _collect_inputs(self) -> dict[str, Any] | None:
|
||||
if not self._input_keys:
|
||||
return None
|
||||
result: dict[str, Any] = {}
|
||||
for key in self._input_keys:
|
||||
widget = self.query_one(f"#input-{key}", Input)
|
||||
result[key] = widget.value
|
||||
return result
|
||||
|
||||
def _collect_task_overrides(self) -> dict[int, str] | None:
|
||||
if not self._task_output_ids or self._selected_entry is None:
|
||||
return None
|
||||
overrides: dict[int, str] = {}
|
||||
for task_idx, editor_id, original in self._task_output_ids:
|
||||
editor = self.query_one(f"#{editor_id}", TextArea)
|
||||
if editor.text != original:
|
||||
overrides[task_idx] = editor.text
|
||||
return overrides or None
|
||||
|
||||
def _detect_entity_type(
|
||||
self, entry: dict[str, Any]
|
||||
) -> Literal["crew", "flow", "agent"]:
|
||||
for ent in entry.get("entities", []):
|
||||
if ent.get("type") == "flow":
|
||||
return "flow"
|
||||
if ent.get("type") == "agent":
|
||||
return "agent"
|
||||
return "crew"
|
||||
|
||||
    def _resolve_location(self, entry: dict[str, Any]) -> str:
        """Build the restore location string for *entry*.

        JSON entries carry an absolute ``path``; SQLite entries are
        addressed as ``db_path#checkpoint_id``.
        """
        if "path" in entry:
            return str(entry["path"])
        if _is_sqlite(self._location):
            return f"{self._location}#{entry['name']}"
        return str(entry.get("name", ""))
|
||||
|
||||
# ── Events ─────────────────────────────────────────────────────
|
||||
|
||||
    async def on_tree_node_highlighted(
        self, event: Tree.NodeHighlighted[dict[str, Any]]
    ) -> None:
        """Show details when a checkpoint node (one carrying data) is highlighted."""
        # Entity/branch nodes carry no data; only checkpoint nodes do.
        if event.node.data is not None:
            await self._show_detail(event.node.data)
|
||||
|
||||
def _exit_with_action(self, action: str) -> None:
|
||||
if self._selected_entry is None:
|
||||
self.notify("No checkpoint selected", severity="warning")
|
||||
return
|
||||
inputs = self._collect_inputs()
|
||||
overrides = self._collect_task_overrides()
|
||||
loc = self._resolve_location(self._selected_entry)
|
||||
etype = self._detect_entity_type(self._selected_entry)
|
||||
name = self._selected_entry.get("name", "")[:30]
|
||||
self.notify(f"{action.title()}: {name}")
|
||||
self.exit((loc, action, inputs, overrides, etype))
|
||||
|
||||
    def action_resume(self) -> None:
        """Key binding 'e': exit the TUI requesting a resume of the selection."""
        self._exit_with_action("resume")
|
||||
|
||||
    def action_fork(self) -> None:
        """Key binding 'f': exit the TUI requesting a fork of the selection."""
        self._exit_with_action("fork")
|
||||
|
||||
    def action_refresh(self) -> None:
        """Key binding 'r': reload checkpoint entries and rebuild the tree."""
        self._refresh_tree()
|
||||
|
||||
|
||||
def _apply_task_overrides(crew: Any, task_overrides: dict[int, str]) -> None:
    """Apply task output overrides to a restored Crew and print modifications.

    For each overridden task this rewrites both the task's stored output and
    the owning agent's message history, then clears the outputs of every
    later, non-overridden task so it re-runs from the edited state.
    """
    import click

    click.echo("Modifications:")
    # Agents whose message history was patched; their histories are kept
    # when later tasks are reset below.
    overridden_agents: set[int] = set()
    for task_idx, new_output in task_overrides.items():
        if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None:
            desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}"
            if len(desc) > 60:
                desc = desc[:57] + "..."
            crew.tasks[task_idx].output.raw = new_output
            preview = new_output.replace("\n", " ")
            if len(preview) > 80:
                preview = preview[:77] + "..."
            click.echo(f" Task {task_idx + 1}: {desc}")
            click.echo(f" -> {preview}")
            agent = crew.tasks[task_idx].agent
            if agent and agent.agent_executor:
                # Each of this agent's task executions begins with a system
                # message; nth = which of its segments holds this task.
                nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent)
                messages = agent.agent_executor.messages
                system_positions = [
                    i for i, m in enumerate(messages) if m.get("role") == "system"
                ]
                if nth < len(system_positions):
                    seg_start = system_positions[nth]
                    seg_end = (
                        system_positions[nth + 1]
                        if nth + 1 < len(system_positions)
                        else len(messages)
                    )
                    # Patch the segment's last assistant message (the task's
                    # final answer), scanning backwards from the segment end.
                    for j in range(seg_end - 1, seg_start, -1):
                        if messages[j].get("role") == "assistant":
                            messages[j]["content"] = new_output
                            break
                overridden_agents.add(id(agent))

    # Everything after the earliest override must re-run: drop outputs of
    # subsequent tasks that were not themselves overridden.
    earliest = min(task_overrides)
    for offset, subsequent in enumerate(crew.tasks[earliest + 1 :], start=earliest + 1):
        if subsequent.output and offset not in task_overrides:
            subsequent.output = None
            if subsequent.agent and subsequent.agent.agent_executor:
                subsequent.agent.agent_executor._resuming = False
                if id(subsequent.agent) not in overridden_agents:
                    # Untouched agent: wipe its history so it starts fresh.
                    subsequent.agent.agent_executor.messages = []
    click.echo()
|
||||
|
||||
|
||||
async def _run_checkpoint_tui_async(location: str) -> None:
    """Async implementation of the checkpoint TUI flow.

    Runs the browser, then — depending on the user's selection — restores or
    forks the chosen flow/agent/crew, applies any input and task-output edits,
    and kicks it off. crewai imports are deferred so the TUI itself works
    without the full framework installed.
    """
    import click

    app = CheckpointTUI(location=location)
    selection = await app.run_async()

    if selection is None:
        # User quit without choosing a checkpoint.
        return

    selected, action, inputs, task_overrides, entity_type = selection

    from crewai.state.checkpoint_config import CheckpointConfig

    config = CheckpointConfig(restore_from=selected)

    if entity_type == "flow":
        from crewai.events.event_bus import crewai_event_bus
        from crewai.flow.flow import Flow

        if action == "fork":
            click.echo(f"\nForking flow from: {selected}\n")
            flow = Flow.fork(config)
        else:
            click.echo(f"\nResuming flow from: {selected}\n")
            flow = Flow.from_checkpoint(config)

        if task_overrides:
            from crewai.crew import Crew as CrewCls

            # The flow's restored crews live in the event bus runtime state;
            # task overrides use flat indices across all crews, so translate
            # each index into the owning crew's local index.
            state = crewai_event_bus._runtime_state
            if state is not None:
                flat_offset = 0
                for entity in state.root:
                    if not isinstance(entity, CrewCls) or not entity.tasks:
                        continue
                    n = len(entity.tasks)
                    local = {
                        idx - flat_offset: out
                        for idx, out in task_overrides.items()
                        if flat_offset <= idx < flat_offset + n
                    }
                    if local:
                        _apply_task_overrides(entity, local)
                    flat_offset += n

        if inputs:
            click.echo("Inputs:")
            for k, v in inputs.items():
                click.echo(f" {k}: {v}")
            click.echo()

        result = await flow.kickoff_async(inputs=inputs)
        click.echo(f"\nResult: {getattr(result, 'raw', result)}")
        return

    if entity_type == "agent":
        from crewai.agent import Agent

        if action == "fork":
            click.echo(f"\nForking agent from: {selected}\n")
            agent = Agent.fork(config)
        else:
            click.echo(f"\nResuming agent from: {selected}\n")
            agent = Agent.from_checkpoint(config)

        click.echo()
        result = await agent.akickoff(messages="Resume execution.")
        click.echo(f"\nResult: {getattr(result, 'raw', result)}")
        return

    # Default: a plain crew checkpoint.
    from crewai.crew import Crew

    if action == "fork":
        click.echo(f"\nForking from: {selected}\n")
        crew = Crew.fork(config)
    else:
        click.echo(f"\nResuming from: {selected}\n")
        crew = Crew.from_checkpoint(config)

    if task_overrides:
        _apply_task_overrides(crew, task_overrides)

    if inputs:
        click.echo("Inputs:")
        for k, v in inputs.items():
            click.echo(f" {k}: {v}")
        click.echo()

    result = await crew.akickoff(inputs=inputs)
    click.echo(f"\nResult: {getattr(result, 'raw', result)}")
|
||||
|
||||
|
||||
def run_checkpoint_tui(location: str = "./.checkpoints") -> None:
    """Launch the checkpoint browser TUI."""
    # Local import keeps module import cheap for callers that never run the TUI.
    import asyncio

    asyncio.run(_run_checkpoint_tui_async(location))
|
||||
960
lib/cli/src/crewai_cli/cli.py
Normal file
960
lib/cli/src/crewai_cli/cli.py
Normal file
@@ -0,0 +1,960 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib.metadata import version as get_version
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from crewai_core.token_manager import TokenManager
|
||||
|
||||
from crewai_cli.add_crew_to_flow import add_crew_to_flow
|
||||
from crewai_cli.authentication.main import AuthenticationCommand
|
||||
from crewai_cli.config import Settings
|
||||
from crewai_cli.create_crew import create_crew
|
||||
from crewai_cli.create_flow import create_flow
|
||||
from crewai_cli.crew_chat import run_chat
|
||||
from crewai_cli.deploy.main import DeployCommand
|
||||
from crewai_cli.enterprise.main import EnterpriseConfigureCommand
|
||||
from crewai_cli.evaluate_crew import evaluate_crew
|
||||
from crewai_cli.install_crew import install_crew
|
||||
from crewai_cli.kickoff_flow import kickoff_flow
|
||||
from crewai_cli.organization.main import OrganizationCommand
|
||||
from crewai_cli.plot_flow import plot_flow
|
||||
from crewai_cli.remote_template.main import TemplateCommand
|
||||
from crewai_cli.replay_from_task import replay_task_command
|
||||
from crewai_cli.reset_memories_command import reset_memories_command
|
||||
from crewai_cli.run_crew import run_crew
|
||||
from crewai_cli.settings.main import SettingsCommand
|
||||
from crewai_cli.task_outputs import load_task_outputs
|
||||
from crewai_cli.tools.main import ToolCommand
|
||||
from crewai_cli.train_crew import train_crew
|
||||
from crewai_cli.triggers.main import TriggersCommand
|
||||
from crewai_cli.update_crew import update_crew
|
||||
from crewai_cli.user_data import (
|
||||
_load_user_data,
|
||||
is_tracing_enabled,
|
||||
update_user_data,
|
||||
)
|
||||
from crewai_cli.utils import build_env_with_all_tool_credentials, read_toml
|
||||
|
||||
|
||||
def _get_cli_version() -> str:
|
||||
"""Return the best available version string for the CLI."""
|
||||
# Prefer crewai version if installed (keeps existing UX)
|
||||
try:
|
||||
return get_version("crewai")
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
try:
|
||||
return get_version("crewai-cli")
|
||||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
|
||||
@click.group()
@click.version_option(_get_cli_version())
def crewai() -> None:
    """Top-level command group for crewai.

    All subcommands (``create``, ``run``, ``deploy`` …) are registered on
    this group; invoking ``crewai --version`` prints the detected version.
    """
|
||||
|
||||
|
||||
@crewai.command(
    name="uv",
    context_settings={"ignore_unknown_options": True},
)
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
def uv(uv_args: tuple[str, ...]) -> None:
    """A wrapper around uv commands that adds custom tool authentication through env vars."""
    # A valid pyproject.toml must exist before delegating to uv.
    try:
        read_toml()
    except FileNotFoundError as e:
        raise SystemExit(
            "Error. A valid pyproject.toml file is required. Check that a valid pyproject.toml file exists in the current directory."
        ) from e
    except Exception as e:
        raise SystemExit(f"Error: {e}") from e

    # Inject tool-repository credentials into the child process environment.
    credential_env = build_env_with_all_tool_credentials()
    command = ["uv", *uv_args]

    try:
        subprocess.run(  # noqa: S603
            command,  # noqa: S607
            capture_output=False,
            env=credential_env,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        click.secho(f"uv command failed with exit code {e.returncode}", fg="red")
        raise SystemExit(e.returncode) from e
|
||||
|
||||
|
||||
@crewai.command()
|
||||
@click.argument("type", type=click.Choice(["crew", "flow"]))
|
||||
@click.argument("name")
|
||||
@click.option("--provider", type=str, help="The provider to use for the crew")
|
||||
@click.option("--skip_provider", is_flag=True, help="Skip provider validation")
|
||||
def create(
|
||||
type: str, name: str, provider: str | None, skip_provider: bool = False
|
||||
) -> None:
|
||||
"""Create a new crew, or flow."""
|
||||
if type == "crew":
|
||||
create_crew(name, provider, skip_provider)
|
||||
elif type == "flow":
|
||||
create_flow(name)
|
||||
else:
|
||||
click.secho("Error: Invalid type. Must be 'crew' or 'flow'.", fg="red")
|
||||
|
||||
|
||||
@crewai.command()
@click.option(
    "--tools", is_flag=True, help="Show the installed version of crewai tools"
)
def version(tools: bool) -> None:
    """Show the installed version of crewai."""
    try:
        crewai_version = get_version("crewai")
    except Exception:
        crewai_version = "unknown version"
    click.echo(f"crewai version: {crewai_version}")

    if not tools:
        return

    # Report the companion crewai-tools package when requested.
    try:
        tools_version = get_version("crewai-tools")
    except Exception:
        click.echo("crewai tools not installed")
    else:
        click.echo(f"crewai tools version: {tools_version}")
|
||||
|
||||
|
||||
@crewai.command()
@click.option(
    "-n",
    "--n_iterations",
    type=int,
    default=5,
    help="Number of iterations to train the crew",
)
@click.option(
    "-f",
    "--filename",
    type=str,
    default="trained_agents_data.pkl",
    help="Path to a custom file for training",
)
def train(n_iterations: int, filename: str) -> None:
    """Train the crew."""
    status = f"Training the Crew for {n_iterations} iterations"
    click.echo(status)
    train_crew(n_iterations, filename)
|
||||
|
||||
|
||||
@crewai.command()
@click.option(
    "-t",
    "--task_id",
    type=str,
    help="Replay the crew from this task ID, including all subsequent tasks.",
)
@click.option(
    "-f",
    "--filename",
    "trained_agents_file",
    type=str,
    default=None,
    help=(
        "Path to a trained-agents pickle (produced by `crewai train -f`). "
        "When set, agents load suggestions from this file instead of the "
        "default trained_agents_data.pkl. Equivalent to setting "
        "CREWAI_TRAINED_AGENTS_FILE."
    ),
)
def replay(task_id: str, trained_agents_file: str | None) -> None:
    """Replay the crew execution from a specific task.

    Args:
        task_id: The ID of the task to replay from.
        trained_agents_file: Optional trained-agents pickle path.
    """
    # Errors are reported on stderr rather than crashing the CLI.
    try:
        announcement = f"Replaying the crew from task {task_id}"
        click.echo(announcement)
        replay_task_command(task_id, trained_agents_file=trained_agents_file)
    except Exception as e:
        click.echo(f"An error occurred while replaying: {e}", err=True)
|
||||
|
||||
|
||||
@crewai.command()
def log_tasks_outputs() -> None:
    """Retrieve your latest crew.kickoff() task outputs."""
    try:
        tasks = load_task_outputs()

        if not tasks:
            click.echo(
                "No task outputs found. Only crew kickoff task outputs are logged."
            )
            return

        for index, task in enumerate(tasks, 1):
            click.echo(f"Task {index}: {task['task_id']}")
            # NOTE(review): the label says "Description" but the value printed
            # is the task's 'expected_output' field — confirm this is intended.
            click.echo(f"Description: {task['expected_output']}")
            click.echo("------")

    except Exception as e:
        click.echo(f"An error occurred while logging task outputs: {e}", err=True)
|
||||
|
||||
|
||||
@crewai.command()
@click.option("-m", "--memory", is_flag=True, help="Reset MEMORY")
@click.option(
    "-l",
    "--long",
    is_flag=True,
    hidden=True,
    help="[Deprecated: use --memory] Reset memory",
)
@click.option(
    "-s",
    "--short",
    is_flag=True,
    hidden=True,
    help="[Deprecated: use --memory] Reset memory",
)
@click.option(
    "-e",
    "--entities",
    is_flag=True,
    hidden=True,
    help="[Deprecated: use --memory] Reset memory",
)
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
@click.option(
    "-akn", "--agent-knowledge", is_flag=True, help="Reset AGENT KNOWLEDGE storage"
)
@click.option(
    "-k", "--kickoff-outputs", is_flag=True, help="Reset LATEST KICKOFF TASK OUTPUTS"
)
@click.option("-a", "--all", is_flag=True, help="Reset ALL memories")
def reset_memories(
    memory: bool,
    long: bool,
    short: bool,
    entities: bool,
    knowledge: bool,
    kickoff_outputs: bool,
    agent_knowledge: bool,
    all: bool,
) -> None:
    """Reset the crew memories (memory, knowledge, agent_knowledge, kickoff_outputs). This will delete all the data saved."""
    try:
        # Legacy flags (--long/--short/--entities) map onto the unified
        # --memory flag; warn the user but still honor the request.
        if long or short or entities:
            legacy_used = [
                f
                for f, v in [
                    ("--long", long),
                    ("--short", short),
                    ("--entities", entities),
                ]
                if v
            ]
            click.echo(
                f"Warning: {', '.join(legacy_used)} {'is' if len(legacy_used) == 1 else 'are'} "
                "deprecated. Use --memory (-m) instead. All memory is now unified."
            )
            memory = True

        # At least one category must be selected; otherwise nothing happens.
        memory_types = [
            memory,
            knowledge,
            agent_knowledge,
            kickoff_outputs,
            all,
        ]
        if not any(memory_types):
            click.echo(
                "Please specify at least one memory type to reset using the appropriate flags."
            )
            return
        reset_memories_command(memory, knowledge, agent_knowledge, kickoff_outputs, all)
    except Exception as e:
        click.echo(f"An error occurred while resetting memories: {e}", err=True)
|
||||
|
||||
|
||||
@crewai.command()
@click.option(
    "--storage-path",
    type=str,
    default=None,
    help="Path to LanceDB memory directory. If omitted, uses ./.crewai/memory.",
)
@click.option(
    "--embedder-provider",
    type=str,
    default=None,
    help="Embedder provider for recall queries (e.g. openai, google-vertex, cohere, ollama).",
)
@click.option(
    "--embedder-model",
    type=str,
    default=None,
    help="Embedder model name (e.g. text-embedding-3-small, gemini-embedding-001).",
)
@click.option(
    "--embedder-config",
    type=str,
    default=None,
    help='Full embedder config as JSON (e.g. \'{"provider": "cohere", "config": {"model_name": "embed-v4.0"}}\').',
)
def memory(
    storage_path: str | None,
    embedder_provider: str | None,
    embedder_model: str | None,
    embedder_config: str | None,
) -> None:
    """Open the Memory TUI to browse scopes and recall memories."""
    # The TUI depends on the optional `textual` package; import lazily and
    # fail with an actionable message rather than a raw traceback.
    try:
        from crewai_cli.memory_tui import MemoryTUI
    except ImportError as exc:
        click.echo(
            "Textual is required for the memory TUI but could not be imported. "
            "Try reinstalling crewai or: pip install textual"
        )
        raise SystemExit(1) from exc

    # Build embedder spec from CLI flags.
    # --embedder-config (full JSON) wins over --embedder-provider/--embedder-model.
    embedder_spec: dict[str, Any] | None = None
    if embedder_config:
        import json as _json

        try:
            embedder_spec = _json.loads(embedder_config)
        except _json.JSONDecodeError as exc:
            click.echo(f"Invalid --embedder-config JSON: {exc}")
            raise SystemExit(1) from exc
    elif embedder_provider:
        cfg: dict[str, str] = {}
        if embedder_model:
            cfg["model_name"] = embedder_model
        embedder_spec = {"provider": embedder_provider, "config": cfg}

    app = MemoryTUI(storage_path=storage_path, embedder_config=embedder_spec)
    app.run()
|
||||
|
||||
|
||||
@crewai.command()
@click.option(
    "-n",
    "--n_iterations",
    type=int,
    default=3,
    help="Number of iterations to Test the crew",
)
@click.option(
    "-m",
    "--model",
    type=str,
    default="gpt-4o-mini",
    help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.",
)
@click.option(
    "-f",
    "--filename",
    "trained_agents_file",
    type=str,
    default=None,
    help=(
        "Path to a trained-agents pickle (produced by `crewai train -f`). "
        "When set, agents load suggestions from this file instead of the "
        "default trained_agents_data.pkl. Equivalent to setting "
        "CREWAI_TRAINED_AGENTS_FILE."
    ),
)
def test(n_iterations: int, model: str, trained_agents_file: str | None) -> None:
    """Test the crew and evaluate the results."""
    summary = f"Testing the crew for {n_iterations} iterations with model {model}"
    click.echo(summary)
    evaluate_crew(n_iterations, model, trained_agents_file=trained_agents_file)
|
||||
|
||||
|
||||
@crewai.command(
    context_settings={
        "ignore_unknown_options": True,
        "allow_extra_args": True,
    }
)
@click.pass_context
def install(context: click.Context) -> None:
    """Install the Crew."""
    # All unparsed extra args are forwarded verbatim to the installer.
    extra_args = context.args
    install_crew(extra_args)
|
||||
|
||||
|
||||
@crewai.command()
@click.option(
    "-f",
    "--filename",
    "trained_agents_file",
    type=str,
    default=None,
    help=(
        "Path to a trained-agents pickle (produced by `crewai train -f`). "
        "When set, agents load suggestions from this file instead of the "
        "default trained_agents_data.pkl. Equivalent to setting "
        "CREWAI_TRAINED_AGENTS_FILE."
    ),
)
def run(trained_agents_file: str | None) -> None:
    """Run the Crew.

    Delegates to :func:`run_crew`, which loads and kicks off the crew
    defined in the current project.
    """
    run_crew(trained_agents_file=trained_agents_file)
|
||||
|
||||
|
||||
@crewai.command()
def update() -> None:
    """Update the pyproject.toml of the Crew project to use uv."""
    update_crew()
|
||||
|
||||
|
||||
@crewai.command()
def login() -> None:
    """Sign Up/Login to CrewAI AMP."""
    # Drop any stale user-scoped settings before starting a fresh login.
    settings = Settings()
    settings.clear_user_settings()
    AuthenticationCommand().login()
|
||||
|
||||
|
||||
@crewai.command()
@click.option(
    "--reset", is_flag=True, help="Also reset all CLI configuration to defaults"
)
def logout(reset: bool) -> None:
    """Logout from CrewAI AMP."""
    settings = Settings()
    if not reset:
        # Plain logout: drop auth tokens and user-scoped settings only.
        TokenManager().clear_tokens()
        settings.clear_user_settings()
        click.echo("Successfully logged out from CrewAI AMP.")
        return
    # --reset also restores every CLI configuration value to its default.
    settings.reset()
    click.echo("Successfully logged out and reset all CLI configuration.")
|
||||
|
||||
|
||||
# DEPLOY CREWAI+ COMMANDS
|
||||
# DEPLOY CREWAI+ COMMANDS
@crewai.group()
def deploy() -> None:
    """Deploy the Crew CLI group.

    Subcommands manage CrewAI+ deployments: create, list, push, validate,
    status, logs, and remove.
    """
|
||||
|
||||
|
||||
@deploy.command(name="create")
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
@click.option(
    "--skip-validate",
    is_flag=True,
    help="Skip the pre-deploy validation checks.",
)
def deploy_create(yes: bool, skip_validate: bool) -> None:
    """Create a Crew deployment."""
    DeployCommand().create_crew(yes, skip_validate=skip_validate)
|
||||
|
||||
|
||||
@deploy.command(name="list")
def deploy_list() -> None:
    """List all deployments."""
    DeployCommand().list_crews()
|
||||
|
||||
|
||||
@deploy.command(name="push")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
@click.option(
    "--skip-validate",
    is_flag=True,
    help="Skip the pre-deploy validation checks.",
)
def deploy_push(uuid: str | None, skip_validate: bool) -> None:
    """Deploy the Crew."""
    DeployCommand().deploy(uuid=uuid, skip_validate=skip_validate)
|
||||
|
||||
|
||||
@deploy.command(name="validate")
def deploy_validate() -> None:
    """Validate the current project against common deployment failures.

    Runs the same pre-deploy checks that `crewai deploy create` and
    `crewai deploy push` run automatically, without contacting the platform.
    Exits non-zero if any blocking issues are found.
    """
    # Imported lazily so the validator's dependencies load only when needed.
    from crewai_cli.deploy.validate import run_validate_command

    run_validate_command()
|
||||
|
||||
|
||||
@deploy.command(name="status")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_status(uuid: str | None) -> None:
    """Get the status of a deployment.

    Args:
        uuid: Optional deployment UUID; when omitted the current project's
            deployment is used.
    """
    # Renamed from the misspelled ``deply_status``. The user-facing command
    # name ("status") comes from the decorator and is unchanged.
    deploy_cmd = DeployCommand()
    deploy_cmd.get_crew_status(uuid=uuid)
|
||||
|
||||
|
||||
@deploy.command(name="logs")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_logs(uuid: str | None) -> None:
    """Get the logs of a deployment."""
    DeployCommand().get_crew_logs(uuid=uuid)
|
||||
|
||||
|
||||
@deploy.command(name="remove")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_remove(uuid: str | None) -> None:
    """Remove a deployment."""
    DeployCommand().remove_crew(uuid=uuid)
|
||||
|
||||
|
||||
@crewai.group()
def tool() -> None:
    """Tool Repository related commands.

    Subcommands scaffold, install, and publish tools against the CrewAI
    tool repository.
    """
|
||||
|
||||
|
||||
@tool.command(name="create")
@click.argument("handle")
def tool_create(handle: str) -> None:
    """Scaffold a new tool project for the given repository handle."""
    tool_cmd = ToolCommand()
    tool_cmd.create(handle)
|
||||
|
||||
|
||||
@tool.command(name="install")
@click.argument("handle")
def tool_install(handle: str) -> None:
    """Install a tool from the repository by handle.

    Logs in first so credentialed (private) tools can be fetched.
    """
    tool_cmd = ToolCommand()
    tool_cmd.login()
    tool_cmd.install(handle)
|
||||
|
||||
|
||||
@tool.command(name="publish")
@click.option(
    "--force",
    is_flag=True,
    show_default=True,
    default=False,
    help="Bypasses Git remote validations",
)
# --public / --private are mutually exclusive spellings of one flag;
# publishing defaults to private.
@click.option("--public", "is_public", flag_value=True, default=False)
@click.option("--private", "is_public", flag_value=False)
def tool_publish(is_public: bool, force: bool) -> None:
    """Publish the current tool project to the repository."""
    tool_cmd = ToolCommand()
    tool_cmd.login()
    tool_cmd.publish(is_public, force)
|
||||
|
||||
|
||||
@crewai.group()
def template() -> None:
    """Browse and install project templates."""
|
||||
|
||||
|
||||
@template.command(name="list")
def template_list() -> None:
    """List available templates and select one to install."""
    TemplateCommand().list_templates()
|
||||
|
||||
|
||||
@template.command(name="add")
@click.argument("name")
@click.option(
    "-o",
    "--output-dir",
    type=str,
    default=None,
    help="Directory name for the template (defaults to template name)",
)
def template_add(name: str, output_dir: str | None) -> None:
    """Add a template to the current directory."""
    TemplateCommand().add_template(name, output_dir)
|
||||
|
||||
|
||||
@crewai.group()
def flow() -> None:
    """Flow related commands."""
|
||||
|
||||
|
||||
@flow.command(name="kickoff")
def flow_run() -> None:
    """Kickoff the Flow."""
    announcement = "Running the Flow"
    click.echo(announcement)
    kickoff_flow()
|
||||
|
||||
|
||||
@flow.command(name="plot")
def flow_plot() -> None:
    """Plot the Flow."""
    announcement = "Plotting the Flow"
    click.echo(announcement)
    plot_flow()
|
||||
|
||||
|
||||
@flow.command(name="add-crew")
@click.argument("crew_name")
def flow_add_crew(crew_name: str) -> None:
    """Add a crew to an existing flow."""
    announcement = f"Adding crew {crew_name} to the flow"
    click.echo(announcement)
    add_crew_to_flow(crew_name)
|
||||
|
||||
|
||||
@crewai.group()
def triggers() -> None:
    """Trigger related commands. Use 'crewai triggers list' to see available triggers, or 'crewai triggers run app_slug/trigger_slug' to execute."""
|
||||
|
||||
|
||||
@triggers.command(name="list")
def triggers_list() -> None:
    """List all available triggers from integrations."""
    TriggersCommand().list_triggers()
|
||||
|
||||
|
||||
@triggers.command(name="run")
@click.argument("trigger_path")
def triggers_run(trigger_path: str) -> None:
    """Execute crew with trigger payload. Format: app_slug/trigger_slug"""
    TriggersCommand().execute_with_trigger(trigger_path)
|
||||
|
||||
|
||||
@crewai.command()
def chat() -> None:
    """Start a conversation with the Crew, collecting user-supplied inputs,
    and using the Chat LLM to generate responses.
    """
    click.secho(
        "\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
    )
    # run_chat blocks until the user exits the conversation loop.
    run_chat()
|
||||
|
||||
|
||||
@crewai.group(invoke_without_command=True)
@click.pass_context
def org(ctx: click.Context) -> None:
    """Organization management commands.

    When invoked without a subcommand, shows the currently active
    organization (same as ``crewai org current``).
    """
    # ``invoke_without_command=True`` previously made bare ``crewai org`` a
    # silent no-op; default to showing the current organization instead.
    if ctx.invoked_subcommand is None:
        OrganizationCommand().current()
|
||||
|
||||
|
||||
@org.command("list")
def org_list() -> None:
    """List available organizations."""
    OrganizationCommand().list()
|
||||
|
||||
|
||||
@org.command()
@click.argument("id")
def switch(id: str) -> None:
    """Switch to a specific organization."""
    # Parameter must be named ``id`` (shadowing the builtin) because click
    # maps the "id" argument onto it by name.
    OrganizationCommand().switch(id)
|
||||
|
||||
|
||||
@org.command()
def current() -> None:
    """Show the currently active organization."""
    org_command = OrganizationCommand()
    org_command.current()
|
||||
|
||||
|
||||
@crewai.group()
def enterprise() -> None:
    """Enterprise Configuration commands."""
|
||||
|
||||
|
||||
@enterprise.command("configure")
@click.argument("enterprise_url")
def enterprise_configure(enterprise_url: str) -> None:
    """Configure CrewAI AMP OAuth2 settings from the provided Enterprise URL."""
    EnterpriseConfigureCommand().configure(enterprise_url)
|
||||
|
||||
|
||||
@crewai.group()
def config() -> None:
    """CLI Configuration commands."""
|
||||
|
||||
|
||||
@config.command("list")
def config_list() -> None:
    """List all CLI configuration parameters."""
    SettingsCommand().list()
|
||||
|
||||
|
||||
@config.command("set")
@click.argument("key")
@click.argument("value")
def config_set(key: str, value: str) -> None:
    """Set a CLI configuration parameter."""
    SettingsCommand().set(key, value)
|
||||
|
||||
|
||||
@config.command("reset")
def config_reset() -> None:
    """Reset all CLI configuration parameters to default values."""
    SettingsCommand().reset_all_settings()
|
||||
|
||||
|
||||
@crewai.group()
def env() -> None:
    """Environment variable commands."""
|
||||
|
||||
|
||||
@env.command("view")
def env_view() -> None:
    """View tracing-related environment variables.

    Renders a rich table of the tracing-relevant environment variables
    (CREWAI_TRACING_ENABLED, CREWAI_TESTING, CREWAI_USER_ID, CREWAI_ORG_ID)
    and whether a ``.env`` file is present in the current directory.
    """
    from pathlib import Path

    from rich.console import Console
    from rich.panel import Panel
    from rich.table import Table

    console = Console()

    # Check for .env file
    env_file = Path(".env")
    env_file_exists = env_file.exists()

    # Create table for environment variables
    table = Table(show_header=True, header_style="bold cyan", expand=True)
    table.add_column("Environment Variable", style="cyan", width=30)
    table.add_column("Value", style="white", width=20)
    table.add_column("Source", style="yellow", width=20)

    # Check CREWAI_TRACING_ENABLED
    crewai_tracing = os.getenv("CREWAI_TRACING_ENABLED", "")
    if crewai_tracing:
        table.add_row(
            "CREWAI_TRACING_ENABLED",
            crewai_tracing,
            "Environment/Shell",
        )
    else:
        table.add_row(
            "CREWAI_TRACING_ENABLED",
            "[dim]Not set[/dim]",
            "[dim]—[/dim]",
        )

    # Check other related env vars (only shown when set)
    crewai_testing = os.getenv("CREWAI_TESTING", "")
    if crewai_testing:
        table.add_row("CREWAI_TESTING", crewai_testing, "Environment/Shell")

    crewai_user_id = os.getenv("CREWAI_USER_ID", "")
    if crewai_user_id:
        table.add_row("CREWAI_USER_ID", crewai_user_id, "Environment/Shell")

    crewai_org_id = os.getenv("CREWAI_ORG_ID", "")
    if crewai_org_id:
        table.add_row("CREWAI_ORG_ID", crewai_org_id, "Environment/Shell")

    # Check if .env file exists
    table.add_row(
        ".env file",
        "✅ Found" if env_file_exists else "❌ Not found",
        str(env_file.resolve()) if env_file_exists else "N/A",
    )

    panel = Panel(
        table,
        title="Tracing Environment Variables",
        border_style="blue",
        padding=(1, 2),
    )
    console.print("\n")
    console.print(panel)

    # Show helpful message
    if env_file_exists:
        console.print(
            "\n[dim]💡 Tip: To enable tracing via .env, add: CREWAI_TRACING_ENABLED=true[/dim]"
        )
    else:
        console.print(
            "\n[dim]💡 Tip: Create a .env file in your project root and add: CREWAI_TRACING_ENABLED=true[/dim]"
        )
    console.print()
|
||||
|
||||
|
||||
@crewai.group()
def traces() -> None:
    """Trace collection management commands."""
|
||||
|
||||
|
||||
@traces.command("enable")
def traces_enable() -> None:
    """Enable trace collection for crew/flow executions.

    Persists the opt-in (``trace_consent=True``) to the user-data store so
    future runs send traces without re-prompting.
    """
    from rich.console import Console
    from rich.panel import Panel

    console = Console()

    # Marking first_execution_done suppresses the first-run consent prompt.
    update_user_data({"trace_consent": True, "first_execution_done": True})

    panel = Panel(
        "✅ Trace collection enabled.\n\n"
        "Your crew/flow executions will now send traces to CrewAI+.\n"
        "Use 'crewai traces disable' to opt out.",
        title="Traces Enabled",
        border_style="green",
        padding=(1, 2),
    )
    console.print(panel)
|
||||
|
||||
|
||||
@traces.command("disable")
def traces_disable() -> None:
    """Disable trace collection for crew/flow executions.

    Persists the opt-out (``trace_consent=False``); the
    CREWAI_TRACING_ENABLED environment variable can still override it.
    """
    from rich.console import Console
    from rich.panel import Panel

    console = Console()

    update_user_data({"trace_consent": False, "first_execution_done": True})

    panel = Panel(
        "❌ Trace collection disabled.\n\n"
        "Your crew/flow executions will no longer send traces "
        "(unless [bold]CREWAI_TRACING_ENABLED=true[/bold] is set in the environment, "
        "which overrides the opt-out).\n"
        "Use 'crewai traces enable' to opt back in.",
        title="Traces Disabled",
        border_style="red",
        padding=(1, 2),
    )
    console.print(panel)
|
||||
|
||||
|
||||
@traces.command("status")
def traces_status() -> None:
    """Show current trace collection status.

    Reports the CREWAI_TRACING_ENABLED environment variable, the stored
    user consent, and the effective overall status.
    """

    from rich.console import Console
    from rich.panel import Panel
    from rich.table import Table

    console = Console()
    user_data = _load_user_data()

    table = Table(show_header=False, box=None)
    table.add_column("Setting", style="cyan")
    table.add_column("Value", style="white")

    # Check environment variable
    env_enabled = os.getenv("CREWAI_TRACING_ENABLED", "false")
    table.add_row("CREWAI_TRACING_ENABLED", env_enabled)

    # Check user consent (tri-state: True / False / unset)
    trace_consent = user_data.get("trace_consent")
    if trace_consent is True:
        consent_status = "✅ Enabled (user consented)"
    elif trace_consent is False:
        consent_status = "❌ Disabled (user declined)"
    else:
        consent_status = "⚪ Not set (first-time user)"
    table.add_row("User Consent", consent_status)

    # Check overall status; also drives the panel border color
    if is_tracing_enabled():
        overall_status = "✅ ENABLED"
        border_style = "green"
    else:
        overall_status = "❌ DISABLED"
        border_style = "red"
    table.add_row("Overall Status", overall_status)

    panel = Panel(
        table,
        title="Trace Collection Status",
        border_style=border_style,
        padding=(1, 2),
    )
    console.print(panel)
|
||||
|
||||
|
||||
@crewai.group(invoke_without_command=True)
@click.option(
    "--location", default="./.checkpoints", help="Checkpoint directory or SQLite file."
)
@click.pass_context
def checkpoint(ctx: click.Context, location: str) -> None:
    """Browse and inspect checkpoints. Launches a TUI when called without a subcommand."""
    from crewai_cli.checkpoint_cli import _detect_location

    # Normalize the location once and share it with all subcommands via ctx.
    resolved = _detect_location(location)
    ctx.ensure_object(dict)
    ctx.obj["location"] = resolved
    if ctx.invoked_subcommand is not None:
        return
    # No subcommand: open the interactive browser.
    from crewai_cli.checkpoint_tui import run_checkpoint_tui

    run_checkpoint_tui(resolved)
|
||||
|
||||
|
||||
@checkpoint.command("list")
@click.argument("location", default="./.checkpoints")
def checkpoint_list(location: str) -> None:
    """List checkpoints in a directory."""
    from crewai_cli.checkpoint_cli import _detect_location, list_checkpoints

    resolved = _detect_location(location)
    list_checkpoints(resolved)
|
||||
|
||||
|
||||
@checkpoint.command("info")
@click.argument("path", default="./.checkpoints")
def checkpoint_info(path: str) -> None:
    """Show details of a checkpoint. Pass a file or directory for latest."""
    from crewai_cli.checkpoint_cli import _detect_location, info_checkpoint

    resolved = _detect_location(path)
    info_checkpoint(resolved)
|
||||
|
||||
|
||||
@checkpoint.command("resume")
@click.argument("checkpoint_id", required=False, default=None)
@click.pass_context
def checkpoint_resume(ctx: click.Context, checkpoint_id: str | None) -> None:
    """Resume from a checkpoint. Defaults to the most recent."""
    from crewai_cli.checkpoint_cli import resume_checkpoint

    # Location was resolved and stashed by the parent `checkpoint` group.
    location = ctx.obj["location"]
    resume_checkpoint(location, checkpoint_id)
|
||||
|
||||
|
||||
@checkpoint.command("diff")
@click.argument("id1")
@click.argument("id2")
@click.pass_context
def checkpoint_diff(ctx: click.Context, id1: str, id2: str) -> None:
    """Compare two checkpoints side-by-side."""
    from crewai_cli.checkpoint_cli import diff_checkpoints

    location = ctx.obj["location"]
    diff_checkpoints(location, id1, id2)
|
||||
|
||||
|
||||
@checkpoint.command("prune")
@click.option(
    "--keep", type=int, default=None, help="Keep the N most recent checkpoints."
)
@click.option(
    "--older-than",
    default=None,
    help="Remove checkpoints older than duration (e.g. 7d, 24h, 30m).",
)
@click.option(
    "--dry-run", is_flag=True, help="Show what would be pruned without deleting."
)
@click.pass_context
def checkpoint_prune(
    ctx: click.Context, keep: int | None, older_than: str | None, dry_run: bool
) -> None:
    """Remove old checkpoints."""
    from crewai_cli.checkpoint_cli import prune_checkpoints

    location = ctx.obj["location"]
    prune_checkpoints(location, keep, older_than, dry_run)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running the CLI directly (e.g. `python -m crewai_cli.cli`).
    crewai()
|
||||
77
lib/cli/src/crewai_cli/command.py
Normal file
77
lib/cli/src/crewai_cli/command.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from crewai_core.telemetry import Telemetry
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
|
||||
from crewai_cli.authentication.token import get_auth_token
|
||||
from crewai_cli.plus_api import PlusAPI
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class BaseCommand:
    """Base class for CLI commands that need telemetry wired up."""

    def __init__(self) -> None:
        telemetry = Telemetry()
        telemetry.set_tracer()
        self._telemetry = telemetry
|
||||
|
||||
|
||||
class PlusAPIMixin:
    """Mixin that equips a command with an authenticated Plus API client.

    Exits the process with a login hint when no auth token is available.
    """

    def __init__(self, telemetry: Telemetry) -> None:
        try:
            telemetry.set_tracer()
            self.plus_api_client = PlusAPI(api_key=get_auth_token())
        except Exception:
            telemetry.deploy_signup_error_span()
            console.print(
                "Please sign up/login to CrewAI+ before using the CLI.",
                style="bold red",
            )
            console.print("Run 'crewai login' to sign up/login.", style="bold green")
            raise SystemExit from None

    def _validate_response(self, response: httpx.Response) -> None:
        """Handle and display error messages from API responses.

        Args:
            response: The response from the Plus API.

        Raises:
            SystemExit: If the response body is not JSON, is a 422
                validation error, or is any other non-success status.
        """
        try:
            json_response = response.json()
        except (json.JSONDecodeError, ValueError):
            # Fixed previously garbled message
            # ("Failed to parse response from Enterprise API failed.").
            console.print(
                "Failed to parse response from Enterprise API. Details:",
                style="bold red",
            )
            console.print(f"Status Code: {response.status_code}")
            console.print(
                f"Response:\n{response.content.decode('utf-8', errors='replace')}"
            )
            raise SystemExit from None

        # 422: field-level validation errors keyed by attribute name.
        if response.status_code == 422:
            console.print(
                "Failed to complete operation. Please fix the following errors:",
                style="bold red",
            )
            for field, messages in json_response.items():
                for message in messages:
                    console.print(
                        f"* [bold red]{field.capitalize()}[/bold red] {message}"
                    )
            raise SystemExit

        # Any other failure: surface the server-provided detail if present.
        if not response.is_success:
            console.print(
                "Request to Enterprise API failed. Details:", style="bold red"
            )
            details = (
                json_response.get("error")
                or json_response.get("message")
                or response.content.decode("utf-8", errors="replace")
            )
            console.print(f"{details}")
            raise SystemExit
|
||||
30
lib/cli/src/crewai_cli/config.py
Normal file
30
lib/cli/src/crewai_cli/config.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Re-exports of shared settings from ``crewai_core.settings``.
|
||||
|
||||
Kept as a stable import path for the CLI; new code should import from
|
||||
``crewai_core.settings`` directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.settings import (
|
||||
CLI_SETTINGS_KEYS as CLI_SETTINGS_KEYS,
|
||||
DEFAULT_CLI_SETTINGS as DEFAULT_CLI_SETTINGS,
|
||||
DEFAULT_CONFIG_PATH as DEFAULT_CONFIG_PATH,
|
||||
HIDDEN_SETTINGS_KEYS as HIDDEN_SETTINGS_KEYS,
|
||||
READONLY_SETTINGS_KEYS as READONLY_SETTINGS_KEYS,
|
||||
USER_SETTINGS_KEYS as USER_SETTINGS_KEYS,
|
||||
Settings as Settings,
|
||||
get_writable_config_path as get_writable_config_path,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CLI_SETTINGS_KEYS",
|
||||
"DEFAULT_CLI_SETTINGS",
|
||||
"DEFAULT_CONFIG_PATH",
|
||||
"HIDDEN_SETTINGS_KEYS",
|
||||
"READONLY_SETTINGS_KEYS",
|
||||
"USER_SETTINGS_KEYS",
|
||||
"Settings",
|
||||
"get_writable_config_path",
|
||||
]
|
||||
333
lib/cli/src/crewai_cli/constants.py
Normal file
333
lib/cli/src/crewai_cli/constants.py
Normal file
@@ -0,0 +1,333 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
|
||||
|
||||
# Per-provider environment-variable prompts used by `crewai create` when
# configuring a project's .env file. Entries with "default": True are written
# automatically instead of prompting.
ENV_VARS: dict[str, list[dict[str, Any]]] = {
    "openai": [
        {
            "prompt": "Enter your OPENAI API key (press Enter to skip)",
            "key_name": "OPENAI_API_KEY",
        }
    ],
    "anthropic": [
        {
            "prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
            "key_name": "ANTHROPIC_API_KEY",
        }
    ],
    "gemini": [
        {
            "prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)",
            "key_name": "GEMINI_API_KEY",
        }
    ],
    "nvidia_nim": [
        {
            "prompt": "Enter your NVIDIA API key (press Enter to skip)",
            "key_name": "NVIDIA_NIM_API_KEY",
        }
    ],
    "groq": [
        {
            "prompt": "Enter your GROQ API key (press Enter to skip)",
            "key_name": "GROQ_API_KEY",
        }
    ],
    "watson": [
        {
            "prompt": "Enter your WATSONX URL (press Enter to skip)",
            "key_name": "WATSONX_URL",
        },
        {
            "prompt": "Enter your WATSONX API Key (press Enter to skip)",
            "key_name": "WATSONX_APIKEY",
        },
        {
            "prompt": "Enter your WATSONX Project Id (press Enter to skip)",
            "key_name": "WATSONX_PROJECT_ID",
        },
    ],
    "ollama": [
        {
            "default": True,
            "API_BASE": "http://localhost:11434",
        }
    ],
    "bedrock": [
        {
            "prompt": "Enter your AWS Access Key ID (press Enter to skip)",
            "key_name": "AWS_ACCESS_KEY_ID",
        },
        {
            "prompt": "Enter your AWS Secret Access Key (press Enter to skip)",
            "key_name": "AWS_SECRET_ACCESS_KEY",
        },
        {
            "prompt": "Enter your AWS Region Name (press Enter to skip)",
            "key_name": "AWS_DEFAULT_REGION",
        },
    ],
    "azure": [
        {
            "prompt": "Enter your Azure deployment name (must start with 'azure/')",
            "key_name": "model",
        },
        {
            "prompt": "Enter your AZURE API key (press Enter to skip)",
            "key_name": "AZURE_API_KEY",
        },
        {
            "prompt": "Enter your AZURE API base URL (press Enter to skip)",
            "key_name": "AZURE_API_BASE",
        },
        {
            "prompt": "Enter your AZURE API version (press Enter to skip)",
            "key_name": "AZURE_API_VERSION",
        },
    ],
    "cerebras": [
        {
            "prompt": "Enter your Cerebras model name (must start with 'cerebras/')",
            "key_name": "model",
        },
        {
            # Bug fix: the prompt previously asked for an "API version" while
            # collecting CEREBRAS_API_KEY.
            "prompt": "Enter your Cerebras API key (press Enter to skip)",
            "key_name": "CEREBRAS_API_KEY",
        },
    ],
    "huggingface": [
        {
            "prompt": "Enter your Huggingface API key (HF_TOKEN) (press Enter to skip)",
            "key_name": "HF_TOKEN",
        },
    ],
    "sambanova": [
        {
            "prompt": "Enter your SambaNovaCloud API key (press Enter to skip)",
            "key_name": "SAMBANOVA_API_KEY",
        }
    ],
}
|
||||
|
||||
|
||||
# Providers offered interactively by `crewai create`, in menu order.
PROVIDERS: list[str] = [
    "openai", "anthropic", "gemini", "nvidia_nim",
    "groq", "huggingface", "ollama", "watson",
    "bedrock", "azure", "cerebras", "sambanova",
]
|
||||
|
||||
# Curated per-provider model choices offered during `crewai create`.
# Bug fix: removed a duplicate "nvidia_nim/meta/llama-3.1-70b-instruct" entry
# that appeared twice in the nvidia_nim list.
MODELS: dict[str, list[str]] = {
    "openai": [
        "gpt-4",
        "gpt-4.1",
        "gpt-4.1-mini-2025-04-14",
        "gpt-4.1-nano-2025-04-14",
        "gpt-4o",
        "gpt-4o-mini",
        "o1-mini",
        "o1-preview",
    ],
    "anthropic": [
        "claude-3-5-sonnet-20240620",
        "claude-3-sonnet-20240229",
        "claude-3-opus-20240229",
        "claude-3-haiku-20240307",
    ],
    "gemini": [
        "gemini/gemini-3-pro-preview",
        "gemini/gemini-1.5-flash",
        "gemini/gemini-1.5-pro",
        "gemini/gemini-2.0-flash-lite-001",
        "gemini/gemini-2.0-flash-001",
        "gemini/gemini-2.0-flash-thinking-exp-01-21",
        "gemini/gemini-2.5-flash-preview-04-17",
        "gemini/gemini-2.5-pro-exp-03-25",
        "gemini/gemini-gemma-2-9b-it",
        "gemini/gemini-gemma-2-27b-it",
        "gemini/gemma-3-1b-it",
        "gemini/gemma-3-4b-it",
        "gemini/gemma-3-12b-it",
        "gemini/gemma-3-27b-it",
    ],
    "nvidia_nim": [
        "nvidia_nim/nvidia/mistral-nemo-minitron-8b-8k-instruct",
        "nvidia_nim/nvidia/nemotron-4-mini-hindi-4b-instruct",
        "nvidia_nim/nvidia/llama-3.1-nemotron-70b-instruct",
        "nvidia_nim/nvidia/llama3-chatqa-1.5-8b",
        "nvidia_nim/nvidia/llama3-chatqa-1.5-70b",
        "nvidia_nim/nvidia/vila",
        "nvidia_nim/nvidia/neva-22",
        "nvidia_nim/nvidia/nemotron-mini-4b-instruct",
        "nvidia_nim/nvidia/usdcode-llama3-70b-instruct",
        "nvidia_nim/nvidia/nemotron-4-340b-instruct",
        "nvidia_nim/meta/codellama-70b",
        "nvidia_nim/meta/llama2-70b",
        "nvidia_nim/meta/llama3-8b-instruct",
        "nvidia_nim/meta/llama3-70b-instruct",
        "nvidia_nim/meta/llama-3.1-8b-instruct",
        "nvidia_nim/meta/llama-3.1-70b-instruct",
        "nvidia_nim/meta/llama-3.1-405b-instruct",
        "nvidia_nim/meta/llama-3.2-1b-instruct",
        "nvidia_nim/meta/llama-3.2-3b-instruct",
        "nvidia_nim/meta/llama-3.2-11b-vision-instruct",
        "nvidia_nim/meta/llama-3.2-90b-vision-instruct",
        "nvidia_nim/google/gemma-7b",
        "nvidia_nim/google/gemma-2b",
        "nvidia_nim/google/codegemma-7b",
        "nvidia_nim/google/codegemma-1.1-7b",
        "nvidia_nim/google/recurrentgemma-2b",
        "nvidia_nim/google/gemma-2-9b-it",
        "nvidia_nim/google/gemma-2-27b-it",
        "nvidia_nim/google/gemma-2-2b-it",
        "nvidia_nim/google/deplot",
        "nvidia_nim/google/paligemma",
        "nvidia_nim/mistralai/mistral-7b-instruct-v0.2",
        "nvidia_nim/mistralai/mixtral-8x7b-instruct-v0.1",
        "nvidia_nim/mistralai/mistral-large",
        "nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1",
        "nvidia_nim/mistralai/mistral-7b-instruct-v0.3",
        "nvidia_nim/nv-mistralai/mistral-nemo-12b-instruct",
        "nvidia_nim/mistralai/mamba-codestral-7b-v0.1",
        "nvidia_nim/microsoft/phi-3-mini-128k-instruct",
        "nvidia_nim/microsoft/phi-3-mini-4k-instruct",
        "nvidia_nim/microsoft/phi-3-small-8k-instruct",
        "nvidia_nim/microsoft/phi-3-small-128k-instruct",
        "nvidia_nim/microsoft/phi-3-medium-4k-instruct",
        "nvidia_nim/microsoft/phi-3-medium-128k-instruct",
        "nvidia_nim/microsoft/phi-3.5-mini-instruct",
        "nvidia_nim/microsoft/phi-3.5-moe-instruct",
        "nvidia_nim/microsoft/kosmos-2",
        "nvidia_nim/microsoft/phi-3-vision-128k-instruct",
        "nvidia_nim/microsoft/phi-3.5-vision-instruct",
        "nvidia_nim/databricks/dbrx-instruct",
        "nvidia_nim/snowflake/arctic",
        "nvidia_nim/aisingapore/sea-lion-7b-instruct",
        "nvidia_nim/ibm/granite-8b-code-instruct",
        "nvidia_nim/ibm/granite-34b-code-instruct",
        "nvidia_nim/ibm/granite-3.0-8b-instruct",
        "nvidia_nim/ibm/granite-3.0-3b-a800m-instruct",
        "nvidia_nim/mediatek/breeze-7b-instruct",
        "nvidia_nim/upstage/solar-10.7b-instruct",
        "nvidia_nim/writer/palmyra-med-70b-32k",
        "nvidia_nim/writer/palmyra-med-70b",
        "nvidia_nim/writer/palmyra-fin-70b-32k",
        "nvidia_nim/01-ai/yi-large",
        "nvidia_nim/deepseek-ai/deepseek-coder-6.7b-instruct",
        "nvidia_nim/rakuten/rakutenai-7b-instruct",
        "nvidia_nim/rakuten/rakutenai-7b-chat",
        "nvidia_nim/baichuan-inc/baichuan2-13b-chat",
    ],
    "groq": [
        "groq/llama-3.1-8b-instant",
        "groq/llama-3.1-70b-versatile",
        "groq/llama-3.1-405b-reasoning",
        "groq/gemma2-9b-it",
        "groq/gemma-7b-it",
    ],
    "ollama": ["ollama/llama3.1", "ollama/mixtral"],
    "watson": [
        "watsonx/meta-llama/llama-3-1-70b-instruct",
        "watsonx/meta-llama/llama-3-1-8b-instruct",
        "watsonx/meta-llama/llama-3-2-11b-vision-instruct",
        "watsonx/meta-llama/llama-3-2-1b-instruct",
        "watsonx/meta-llama/llama-3-2-90b-vision-instruct",
        "watsonx/meta-llama/llama-3-405b-instruct",
        "watsonx/mistral/mistral-large",
        "watsonx/ibm/granite-3-8b-instruct",
    ],
    "bedrock": [
        "bedrock/us.amazon.nova-pro-v1:0",
        "bedrock/us.amazon.nova-micro-v1:0",
        "bedrock/us.amazon.nova-lite-v1:0",
        "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
        "bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0",
        "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        "bedrock/us.anthropic.claude-3-sonnet-20240229-v1:0",
        "bedrock/us.anthropic.claude-3-opus-20240229-v1:0",
        "bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
        "bedrock/us.meta.llama3-2-11b-instruct-v1:0",
        "bedrock/us.meta.llama3-2-3b-instruct-v1:0",
        "bedrock/us.meta.llama3-2-90b-instruct-v1:0",
        "bedrock/us.meta.llama3-2-1b-instruct-v1:0",
        "bedrock/us.meta.llama3-1-8b-instruct-v1:0",
        "bedrock/us.meta.llama3-1-70b-instruct-v1:0",
        "bedrock/us.meta.llama3-3-70b-instruct-v1:0",
        "bedrock/us.meta.llama3-1-405b-instruct-v1:0",
        "bedrock/eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
        "bedrock/eu.anthropic.claude-3-sonnet-20240229-v1:0",
        "bedrock/eu.anthropic.claude-3-haiku-20240307-v1:0",
        "bedrock/eu.meta.llama3-2-3b-instruct-v1:0",
        "bedrock/eu.meta.llama3-2-1b-instruct-v1:0",
        "bedrock/apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
        "bedrock/apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
        "bedrock/apac.anthropic.claude-3-sonnet-20240229-v1:0",
        "bedrock/apac.anthropic.claude-3-haiku-20240307-v1:0",
        "bedrock/amazon.nova-pro-v1:0",
        "bedrock/amazon.nova-micro-v1:0",
        "bedrock/amazon.nova-lite-v1:0",
        "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
        "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
        "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
        "bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0",
        "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        "bedrock/anthropic.claude-3-opus-20240229-v1:0",
        "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
        "bedrock/anthropic.claude-v2:1",
        "bedrock/anthropic.claude-v2",
        "bedrock/anthropic.claude-instant-v1",
        "bedrock/meta.llama3-1-405b-instruct-v1:0",
        "bedrock/meta.llama3-1-70b-instruct-v1:0",
        "bedrock/meta.llama3-1-8b-instruct-v1:0",
        "bedrock/meta.llama3-70b-instruct-v1:0",
        "bedrock/meta.llama3-8b-instruct-v1:0",
        "bedrock/amazon.titan-text-lite-v1",
        "bedrock/amazon.titan-text-express-v1",
        "bedrock/cohere.command-text-v14",
        "bedrock/ai21.j2-mid-v1",
        "bedrock/ai21.j2-ultra-v1",
        "bedrock/ai21.jamba-instruct-v1:0",
        "bedrock/mistral.mistral-7b-instruct-v0:2",
        "bedrock/mistral.mixtral-8x7b-instruct-v0:1",
    ],
    "huggingface": [
        "huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
        "huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1",
        "huggingface/tiiuae/falcon-180B-chat",
        "huggingface/google/gemma-7b-it",
    ],
    "sambanova": [
        "sambanova/Meta-Llama-3.3-70B-Instruct",
        "sambanova/QwQ-32B-Preview",
        "sambanova/Qwen2.5-72B-Instruct",
        "sambanova/Qwen2.5-Coder-32B-Instruct",
        "sambanova/Meta-Llama-3.1-405B-Instruct",
        "sambanova/Meta-Llama-3.1-70B-Instruct",
        "sambanova/Meta-Llama-3.1-8B-Instruct",
        "sambanova/Llama-3.2-90B-Vision-Instruct",
        "sambanova/Llama-3.2-11B-Vision-Instruct",
        "sambanova/Meta-Llama-3.2-3B-Instruct",
        "sambanova/Meta-Llama-3.2-1B-Instruct",
    ],
}
|
||||
|
||||
# Model used when the user skips explicit provider/model selection.
DEFAULT_LLM_MODEL = "gpt-4.1-mini"

# LiteLLM's published model/pricing catalog; fetched at runtime to offer an
# up-to-date provider/model list during project creation.
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"

# Connection parameters recognized by LiteLLM when configuring a provider.
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]
|
||||
321
lib/cli/src/crewai_cli/create_crew.py
Normal file
321
lib/cli/src/crewai_cli/create_crew.py
Normal file
@@ -0,0 +1,321 @@
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
import click
|
||||
import tomli
|
||||
|
||||
from crewai_cli.constants import ENV_VARS, MODELS
|
||||
from crewai_cli.provider import (
|
||||
get_provider_data,
|
||||
select_model,
|
||||
select_provider,
|
||||
)
|
||||
from crewai_cli.utils import copy_template, load_env_vars, write_env_file
|
||||
|
||||
|
||||
def get_reserved_script_names() -> set[str]:
    """Get reserved script names from pyproject.toml template.

    Returns:
        Set of reserved script names that would conflict with crew folder names.
    """
    template_path = Path(__file__).parent / "templates" / "crew" / "pyproject.toml"

    # Render the template placeholders so the TOML parses cleanly.
    rendered = (
        template_path.read_text()
        .replace("{{folder_name}}", "_placeholder_")
        .replace("{{name}}", "placeholder")
        .replace("{{crew_name}}", "Placeholder")
    )

    scripts = tomli.loads(rendered).get("project", {}).get("scripts", {})
    reserved = set(scripts)
    # The placeholder entry stands in for the project's own script name.
    reserved.discard("_placeholder_")
    return reserved
|
||||
|
||||
|
||||
def create_folder_structure(
    name: str, parent_folder: str | None = None
) -> tuple[Path, str, str]:
    """Validate *name* and create the project folder skeleton.

    Derives a snake_case folder/module name and a PascalCase class name from
    the human-readable project name, rejecting anything that would not be a
    valid Python identifier, then creates the directory tree (overwriting an
    existing folder only after user confirmation).

    Args:
        name: Human-readable project name (e.g. "My Crew").
        parent_folder: When set, create a nested crew inside this folder
            instead of a full standalone project.

    Returns:
        Tuple of (folder_path, folder_name, class_name).

    Raises:
        ValueError: If the name cannot be turned into valid Python
            module/class identifiers or collides with a reserved script name.
    """
    import keyword
    import re

    # Tolerate a trailing slash from shell completion.
    name = name.rstrip("/")

    if not name.strip():
        raise ValueError("Project name cannot be empty or contain only whitespace")

    # snake_case module name: spaces/hyphens -> underscores, drop the rest.
    folder_name = name.replace(" ", "_").replace("-", "_").lower()
    folder_name = re.sub(r"[^a-zA-Z0-9_]", "", folder_name)

    # Check if the name starts with invalid characters or is primarily invalid
    if re.match(r"^[^a-zA-Z0-9_-]+", name):
        raise ValueError(
            f"Project name '{name}' contains no valid characters for a Python module name"
        )

    if not folder_name:
        raise ValueError(
            f"Project name '{name}' contains no valid characters for a Python module name"
        )

    if folder_name[0].isdigit():
        raise ValueError(
            f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)"
        )

    if keyword.iskeyword(folder_name):
        raise ValueError(
            f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword"
        )

    if not folder_name.isidentifier():
        raise ValueError(
            f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
        )

    # Script names declared in the pyproject template would collide with the
    # generated [project.scripts] entry, so they are reserved.
    reserved_names = get_reserved_script_names()
    if folder_name in reserved_names:
        raise ValueError(
            f"Project name '{name}' would generate folder name '{folder_name}' which is reserved. "
            f"Reserved names are: {', '.join(sorted(reserved_names))}. "
            "Please choose a different name."
        )

    # PascalCase class name for the generated crew class.
    class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")

    class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)

    if not class_name:
        raise ValueError(
            f"Project name '{name}' contains no valid characters for a Python class name"
        )

    if class_name[0].isdigit():
        raise ValueError(
            f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit"
        )

    # Check if the original name (before title casing) is a keyword
    original_name_clean = re.sub(
        r"[^a-zA-Z0-9_]", "", name.replace("_", "").replace("-", "").lower()
    )
    if (
        keyword.iskeyword(original_name_clean)
        or keyword.iskeyword(class_name)
        or class_name in ("True", "False", "None")
    ):
        raise ValueError(
            f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword"
        )

    if not class_name.isidentifier():
        raise ValueError(
            f"Project name '{name}' would generate invalid Python class name '{class_name}'"
        )

    if parent_folder:
        folder_path = Path(parent_folder) / folder_name
    else:
        folder_path = Path(folder_name)

    # Overwrite an existing folder only with explicit user confirmation.
    if folder_path.exists():
        if not click.confirm(
            f"Folder {folder_name} already exists. Do you want to override it?"
        ):
            click.secho("Operation cancelled.", fg="yellow")
            sys.exit(0)
        click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True)
        shutil.rmtree(folder_path)  # Delete the existing folder and its contents

    click.secho(
        f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
        fg="green",
        bold=True,
    )

    folder_path.mkdir(parents=True)
    (folder_path / "tests").mkdir(exist_ok=True)
    (folder_path / "knowledge").mkdir(exist_ok=True)
    # Nested crews reuse the parent project's src layout; only top-level
    # projects get the full src/<pkg>/{tools,config} tree.
    if not parent_folder:
        (folder_path / "src" / folder_name).mkdir(parents=True)
        (folder_path / "src" / folder_name / "tools").mkdir(parents=True)
        (folder_path / "src" / folder_name / "config").mkdir(parents=True)

        # Copy AGENTS.md to project root (top-level projects only)
        package_dir = Path(__file__).parent
        agents_md_src = package_dir / "templates" / "AGENTS.md"
        if agents_md_src.exists():
            shutil.copy2(agents_md_src, folder_path / "AGENTS.md")

    return folder_path, folder_name, class_name
|
||||
|
||||
|
||||
def copy_template_files(
    folder_path: Path, name: str, class_name: str, parent_folder: str | None
) -> None:
    """Render the crew template files into *folder_path*.

    Top-level projects receive the root files (gitignore, pyproject, README,
    knowledge) plus tools/config modules; nested crews only get crew.py.
    """
    templates_dir = Path(__file__).parent / "templates" / "crew"
    folder_name = folder_path.name
    top_level = not parent_folder

    root_files = (
        [
            ".gitignore",
            "pyproject.toml",
            "README.md",
            "knowledge/user_preference.txt",
        ]
        if top_level
        else []
    )
    src_files = ["__init__.py", "main.py", "crew.py"] if top_level else ["crew.py"]
    nested_files = [
        "tools/custom_tool.py",
        "tools/__init__.py",
        "config/agents.yaml",
        "config/tasks.yaml",
    ]

    src_folder = folder_path / "src" / folder_name if top_level else folder_path

    for rel in root_files:
        copy_template(
            templates_dir / rel, folder_path / rel, name, class_name, folder_name
        )

    for rel in src_files:
        copy_template(
            templates_dir / rel, src_folder / rel, name, class_name, folder_name
        )

    if top_level:
        for rel in nested_files:
            copy_template(
                templates_dir / rel, src_folder / rel, name, class_name, folder_name
            )
|
||||
|
||||
|
||||
def create_crew(
    name: str,
    provider: str | None = None,
    skip_provider: bool = False,
    parent_folder: str | None = None,
) -> None:
    """Create a new crew project from the bundled templates.

    Optionally walks the user through provider/model selection and writes the
    collected API keys to the project's .env file.

    Args:
        name: Human-readable project name.
        provider: Pre-selected LLM provider; when falsy the user is prompted.
        skip_provider: Skip the provider/model configuration step entirely.
        parent_folder: When set, scaffold a nested crew inside this folder.
    """
    folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
    env_vars = load_env_vars(folder_path)
    if not skip_provider:
        if not provider:
            provider_models = get_provider_data()
            if not provider_models:
                return

        # Detect a provider whose API keys are already present in .env.
        # Fix: the loop variable previously shadowed the `provider` parameter.
        existing_provider = None
        for candidate, env_keys in ENV_VARS.items():
            if any(
                "key_name" in details and details["key_name"] in env_vars
                for details in env_keys
            ):
                existing_provider = candidate
                break

        if existing_provider:
            if not click.confirm(
                f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
            ):
                click.secho("Keeping existing provider configuration.", fg="yellow")
                return

        provider_models = get_provider_data()
        if not provider_models:
            return

        while True:
            selected_provider = select_provider(provider_models)
            if selected_provider is None:  # User typed 'q'
                click.secho("Exiting...", fg="yellow")
                sys.exit(0)
            if selected_provider and isinstance(
                selected_provider, str
            ):  # Valid selection
                break
            click.secho(
                "No provider selected. Please try again or press 'q' to exit.", fg="red"
            )

        # Check if the selected provider has predefined models
        if MODELS.get(selected_provider):
            while True:
                selected_model = select_model(selected_provider, provider_models)
                if selected_model is None:  # User typed 'q'
                    click.secho("Exiting...", fg="yellow")
                    sys.exit(0)
                if selected_model:  # Valid selection
                    break
                click.secho(
                    "No model selected. Please try again or press 'q' to exit.",
                    fg="red",
                )
            env_vars["MODEL"] = selected_model

        # Check if the selected provider requires API keys
        if selected_provider in ENV_VARS:
            provider_env_vars = ENV_VARS[selected_provider]
            for details in provider_env_vars:
                if details.get("default", False):
                    # Automatically add default key-value pairs
                    for key, value in details.items():
                        if key not in ["prompt", "key_name", "default"]:
                            env_vars[key] = value
                elif "key_name" in details:
                    # Prompt for non-default key-value pairs
                    prompt = details["prompt"]
                    key_name = details["key_name"]
                    api_key_value = click.prompt(prompt, default="", show_default=False)

                    if api_key_value.strip():
                        env_vars[key_name] = api_key_value

        if env_vars:
            write_env_file(folder_path, env_vars)
            click.secho("API keys and model saved to .env file", fg="green")
        else:
            click.secho(
                "No API keys provided. Skipping .env file creation.", fg="yellow"
            )

        click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")

    # Fix: delegate to copy_template_files instead of duplicating the exact
    # same template-copy logic inline (folder_path.name == folder_name).
    copy_template_files(folder_path, name, class_name, parent_folder)

    click.secho(f"Crew {name} created successfully!", fg="green", bold=True)
|
||||
103
lib/cli/src/crewai_cli/create_flow.py
Normal file
103
lib/cli/src/crewai_cli/create_flow.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
import click
|
||||
from crewai_core.telemetry import Telemetry
|
||||
|
||||
|
||||
def create_flow(name: str) -> None:
    """Create a new flow project from the bundled flow template.

    Args:
        name: Human-readable flow name; used to derive the folder name
            (snake_case) and the flow class name (PascalCase).
    """
    folder_name = name.replace(" ", "_").replace("-", "_").lower()
    class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")

    click.secho(f"Creating flow {folder_name}...", fg="green", bold=True)

    project_root = Path(folder_name)
    if project_root.exists():
        click.secho(f"Error: Folder {folder_name} already exists.", fg="red")
        return

    telemetry = Telemetry()
    telemetry.flow_creation_span(class_name)

    # Create directory structure
    (project_root / "src" / folder_name).mkdir(parents=True)
    (project_root / "src" / folder_name / "crews").mkdir(parents=True)
    (project_root / "src" / folder_name / "tools").mkdir(parents=True)
    (project_root / "tests").mkdir(exist_ok=True)

    # Create .env file (explicit UTF-8 so output is platform-independent)
    with open(project_root / ".env", "w", encoding="utf-8") as file:
        file.write("OPENAI_API_KEY=YOUR_API_KEY")

    package_dir = Path(__file__).parent
    templates_dir = package_dir / "templates" / "flow"

    # Copy AGENTS.md to project root
    agents_md_src = package_dir / "templates" / "AGENTS.md"
    if agents_md_src.exists():
        shutil.copy2(agents_md_src, project_root / "AGENTS.md")

    # List of template files to copy
    root_template_files = [".gitignore", "pyproject.toml", "README.md"]
    src_template_files = ["__init__.py", "main.py"]
    tools_template_files = ["tools/__init__.py", "tools/custom_tool.py"]

    crew_folders = [
        "content_crew",
    ]

    def process_file(src_file: Path, dst_file: Path) -> None:
        # Skip compiled artifacts that may linger in the template directory.
        if src_file.suffix in [".pyc", ".pyo", ".pyd"]:
            return

        try:
            with open(src_file, "r", encoding="utf-8") as file:
                content = file.read()
        except Exception as e:
            click.secho(f"Error processing file {src_file}: {e}", fg="red")
            return

        content = content.replace("{{name}}", name)
        content = content.replace("{{flow_name}}", class_name)
        content = content.replace("{{folder_name}}", folder_name)

        # Fix: write with explicit UTF-8; templates are read as UTF-8 above,
        # so writing with the locale default could corrupt or fail on
        # non-ASCII content (e.g. on Windows).
        with open(dst_file, "w", encoding="utf-8") as file:
            file.write(content)

    # Copy and process root template files
    for file_name in root_template_files:
        src_file = templates_dir / file_name
        dst_file = project_root / file_name
        process_file(src_file, dst_file)

    # Copy and process src template files
    for file_name in src_template_files:
        src_file = templates_dir / file_name
        dst_file = project_root / "src" / folder_name / file_name
        process_file(src_file, dst_file)

    # Copy tools files
    for file_name in tools_template_files:
        src_file = templates_dir / file_name
        dst_file = project_root / "src" / folder_name / file_name
        process_file(src_file, dst_file)

    # Copy crew folders
    for crew_folder in crew_folders:
        src_crew_folder = templates_dir / "crews" / crew_folder
        dst_crew_folder = project_root / "src" / folder_name / "crews" / crew_folder
        if src_crew_folder.exists():
            for src_file in src_crew_folder.rglob("*"):
                if src_file.is_file():
                    relative_path = src_file.relative_to(src_crew_folder)
                    dst_file = dst_crew_folder / relative_path
                    dst_file.parent.mkdir(parents=True, exist_ok=True)
                    process_file(src_file, dst_file)
        else:
            click.secho(
                f"Warning: Crew folder {crew_folder} not found in template.",
                fg="yellow",
            )

    click.secho(f"Flow {name} created successfully!", fg="green", bold=True)
|
||||
23
lib/cli/src/crewai_cli/crew_chat.py
Normal file
23
lib/cli/src/crewai_cli/crew_chat.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Wrapper for the crew chat command.
|
||||
|
||||
Delegates to ``crewai.utilities.crew_chat.run_chat`` when the full crewai
|
||||
package is installed, otherwise prints a helpful error message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def run_chat() -> None:
    """Launch the interactive crew chat session.

    Imports the chat implementation lazily from the full ``crewai``
    package so this CLI can be installed standalone; when the package is
    absent, prints an install hint and exits with status 1 rather than
    surfacing a raw ImportError.
    """
    try:
        from crewai.utilities.crew_chat import run_chat as delegate
    except ImportError:
        message = (
            "The 'chat' command requires the full crewai package.\n"
            "Install it with: pip install crewai"
        )
        click.secho(message, fg="red")
        # `from None` suppresses the ImportError chain; the message above
        # is the whole story for the user.
        raise SystemExit(1) from None

    delegate()
|
||||
0
lib/cli/src/crewai_cli/deploy/__init__.py
Normal file
0
lib/cli/src/crewai_cli/deploy/__init__.py
Normal file
308
lib/cli/src/crewai_cli/deploy/main.py
Normal file
308
lib/cli/src/crewai_cli/deploy/main.py
Normal file
@@ -0,0 +1,308 @@
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from crewai_cli import git
|
||||
from crewai_cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai_cli.deploy.validate import validate_project
|
||||
from crewai_cli.utils import fetch_and_json_env_file, get_project_name
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def _run_predeploy_validation(skip_validate: bool) -> bool:
    """Gate a deployment on the pre-deploy validation suite.

    Args:
        skip_validate: when True, validation is bypassed entirely (a
            yellow notice is printed instead).

    Returns:
        True when deployment should proceed (validation passed or was
        skipped), False when blocking validation errors were found.
    """
    if skip_validate:
        console.print(
            "[yellow]Skipping pre-deploy validation (--skip-validate).[/yellow]"
        )
        return True

    console.print("Running pre-deploy validation...", style="bold blue")
    passed = validate_project().ok
    if passed:
        return True

    # validate_project already printed the individual findings; this is
    # just the summary line telling the user how to proceed.
    console.print(
        "\n[bold red]Pre-deploy validation failed. "
        "Fix the issues above or re-run with --skip-validate.[/bold red]"
    )
    return False
|
||||
|
||||
|
||||
class DeployCommand(BaseCommand, PlusAPIMixin):
    """
    A class to handle deployment-related operations for CrewAI projects.

    Each public method maps to one `crewai deploy ...` subcommand and
    combines a telemetry span, a Plus API call, and rich console output.
    """

    def __init__(self) -> None:
        """
        Initialize the DeployCommand with project name and API client.
        """

        # Explicit base-class calls: PlusAPIMixin is handed the telemetry
        # object, which is presumably created by BaseCommand.__init__
        # (self._telemetry) — TODO confirm; BaseCommand is not visible here.
        BaseCommand.__init__(self)
        PlusAPIMixin.__init__(self, telemetry=self._telemetry)
        self.project_name = get_project_name(require=True)

    def _standard_no_param_error_message(self) -> None:
        """
        Display a standard error message when no UUID or project name is available.
        """
        console.print(
            "No UUID provided, project pyproject.toml not found or with error.",
            style="bold red",
        )

    def _display_deployment_info(self, json_response: dict[str, Any]) -> None:
        """
        Display deployment information.

        Args:
            json_response (Dict[str, Any]): The deployment information to display.
                Must contain a "uuid" key (KeyError otherwise).
        """
        console.print("Deploying the crew...\n", style="bold blue")
        for key, value in json_response.items():
            console.print(f"{key.title()}: [green]{value}[/green]")
        console.print("\nTo check the status of the deployment, run:")
        console.print("crewai deploy status")
        console.print(" or")
        console.print(f'crewai deploy status --uuid "{json_response["uuid"]}"')

    def _display_logs(self, log_messages: list[dict[str, Any]]) -> None:
        """
        Display log messages.

        Args:
            log_messages (List[Dict[str, Any]]): The log messages to display.
                Each entry must have "timestamp", "level" and "message" keys.
        """
        for log_message in log_messages:
            console.print(
                f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}"
            )

    def deploy(self, uuid: str | None = None, skip_validate: bool = False) -> None:
        """
        Deploy a crew using either UUID or project name.

        Args:
            uuid (Optional[str]): The UUID of the crew to deploy.
            skip_validate (bool): Skip pre-deploy validation checks.
        """
        # Abort silently on validation failure; the validator already
        # printed the findings and a summary.
        if not _run_predeploy_validation(skip_validate):
            return
        self._telemetry.start_deployment_span(uuid)
        console.print("Starting deployment...", style="bold blue")
        # An explicit uuid takes precedence over the project name.
        if uuid:
            response = self.plus_api_client.deploy_by_uuid(uuid)
        elif self.project_name:
            response = self.plus_api_client.deploy_by_name(self.project_name)
        else:
            self._standard_no_param_error_message()
            return

        self._validate_response(response)
        self._display_deployment_info(response.json())

    def create_crew(self, confirm: bool = False, skip_validate: bool = False) -> None:
        """
        Create a new crew deployment.

        Args:
            confirm (bool): Whether to skip the interactive confirmation prompt.
            skip_validate (bool): Skip pre-deploy validation checks.
        """
        if not _run_predeploy_validation(skip_validate):
            return
        self._telemetry.create_crew_deployment_span()
        console.print("Creating deployment...", style="bold blue")
        env_vars = fetch_and_json_env_file()

        # A remote git URL is mandatory: the platform clones the repo to
        # build the deployment.
        try:
            remote_repo_url = git.Repository().origin_url()
        except ValueError:
            remote_repo_url = None

        if remote_repo_url is None:
            console.print("No remote repository URL found.", style="bold red")
            console.print(
                "Please ensure your project has a valid remote repository.",
                style="yellow",
            )
            return

        self._confirm_input(env_vars, remote_repo_url, confirm)
        payload = self._create_payload(env_vars, remote_repo_url)
        response = self.plus_api_client.create_crew(payload)

        self._validate_response(response)
        self._display_creation_success(response.json())

    def _confirm_input(
        self, env_vars: dict[str, str], remote_repo_url: str, confirm: bool
    ) -> None:
        """
        Confirm input parameters with the user.

        Blocking interactive prompts; a no-op when `confirm` is True
        (e.g. non-interactive / CI usage).

        Args:
            env_vars (Dict[str, str]): Environment variables.
            remote_repo_url (str): Remote repository URL.
            confirm (bool): Whether to confirm input.
        """
        if not confirm:
            input(f"Press Enter to continue with the following Env vars: {env_vars}")
            input(
                f"Press Enter to continue with the following remote repository: {remote_repo_url}\n"
            )

    def _create_payload(
        self,
        env_vars: dict[str, str],
        remote_repo_url: str,
    ) -> dict[str, Any]:
        """
        Create the payload for crew creation.

        Args:
            remote_repo_url (str): Remote repository URL.
            env_vars (Dict[str, str]): Environment variables.

        Returns:
            Dict[str, Any]: The payload for crew creation.
        """
        return {
            "deploy": {
                "name": self.project_name,
                "repo_clone_url": remote_repo_url,
                "env": env_vars,
            }
        }

    def _display_creation_success(self, json_response: dict[str, Any]) -> None:
        """
        Display success message after crew creation.

        Args:
            json_response (Dict[str, Any]): The response containing crew information.
                Must contain "uuid" and "status" keys.
        """
        console.print("Deployment created successfully!\n", style="bold green")
        console.print(
            f"Name: {self.project_name} ({json_response['uuid']})", style="bold green"
        )
        console.print(f"Status: {json_response['status']}", style="bold green")
        console.print("\nTo (re)deploy the crew, run:")
        console.print("crewai deploy push")
        console.print(" or")
        console.print(f"crewai deploy push --uuid {json_response['uuid']}")

    def list_crews(self) -> None:
        """
        List all available crews.
        """
        console.print("Listing all Crews\n", style="bold blue")

        response = self.plus_api_client.list_crews()
        json_response = response.json()
        # NOTE(review): any non-200 status (including server errors) is
        # rendered as the "no crews yet" message — may mask API failures.
        if response.status_code == 200:
            self._display_crews(json_response)
        else:
            self._display_no_crews_message()

    def _display_crews(self, crews_data: list[dict[str, Any]]) -> None:
        """
        Display the list of crews.

        Args:
            crews_data (List[Dict[str, Any]]): List of crew data to display.
                Each entry must have "name", "uuid" and "status" keys.
        """
        for crew_data in crews_data:
            console.print(
                f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]"
            )

    def _display_no_crews_message(self) -> None:
        """
        Display a message when no crews are available.
        """
        console.print("You don't have any Crews yet. Let's create one!", style="yellow")
        console.print("  crewai create crew <crew_name>", style="green")

    def get_crew_status(self, uuid: str | None = None) -> None:
        """
        Get the status of a crew.

        Args:
            uuid (Optional[str]): The UUID of the crew to check.
        """
        console.print("Fetching deployment status...", style="bold blue")
        if uuid:
            response = self.plus_api_client.crew_status_by_uuid(uuid)
        elif self.project_name:
            response = self.plus_api_client.crew_status_by_name(self.project_name)
        else:
            self._standard_no_param_error_message()
            return

        self._validate_response(response)
        self._display_crew_status(response.json())

    def _display_crew_status(self, status_data: dict[str, str]) -> None:
        """
        Display the status of a crew.

        Args:
            status_data (Dict[str, str]): The status data to display.
                Must contain "name" and "status" keys.
        """
        console.print(f"Name:\t {status_data['name']}")
        console.print(f"Status:\t {status_data['status']}")

    def get_crew_logs(self, uuid: str | None, log_type: str = "deployment") -> None:
        """
        Get logs for a crew.

        Args:
            uuid (Optional[str]): The UUID of the crew to get logs for.
            log_type (str): The type of logs to retrieve (default: "deployment").
        """
        self._telemetry.get_crew_logs_span(uuid, log_type)
        console.print(f"Fetching {log_type} logs...", style="bold blue")

        if uuid:
            response = self.plus_api_client.crew_by_uuid(uuid, log_type)
        elif self.project_name:
            response = self.plus_api_client.crew_by_name(self.project_name, log_type)
        else:
            self._standard_no_param_error_message()
            return

        self._validate_response(response)
        self._display_logs(response.json())

    def remove_crew(self, uuid: str | None) -> None:
        """
        Remove a crew deployment.

        Args:
            uuid (Optional[str]): The UUID of the crew to remove.
        """
        self._telemetry.remove_crew_span(uuid)
        console.print("Removing deployment...", style="bold blue")

        if uuid:
            response = self.plus_api_client.delete_crew_by_uuid(uuid)
        elif self.project_name:
            response = self.plus_api_client.delete_crew_by_name(self.project_name)
        else:
            self._standard_no_param_error_message()
            return

        # NOTE(review): both messages reference self.project_name even when
        # the removal was requested by uuid — the printed name may not match
        # the crew that was actually deleted.
        if response.status_code == 204:
            console.print(
                f"Crew '{self.project_name}' removed successfully.", style="green"
            )
        else:
            console.print(
                f"Failed to remove crew '{self.project_name}'", style="bold red"
            )
|
||||
845
lib/cli/src/crewai_cli/deploy/validate.py
Normal file
845
lib/cli/src/crewai_cli/deploy/validate.py
Normal file
@@ -0,0 +1,845 @@
|
||||
"""Pre-deploy validation for CrewAI projects.
|
||||
|
||||
Catches locally what a deploy would reject at build or runtime so users
|
||||
don't burn deployment attempts on fixable project-structure problems.
|
||||
|
||||
Each check is grouped into one of:
|
||||
- ERROR: will block a deployment; validator exits non-zero.
|
||||
- WARNING: may still deploy but is almost always a deployment bug; printed
|
||||
but does not block.
|
||||
|
||||
The individual checks mirror the categories observed in production
|
||||
deployment-failure logs:
|
||||
|
||||
1. pyproject.toml present with ``[project].name``
|
||||
2. lockfile (``uv.lock`` or ``poetry.lock``) present and not stale
|
||||
3. package directory at ``src/<package>/`` exists (no empty name, no egg-info)
|
||||
4. standard crew files: ``crew.py``, ``config/agents.yaml``, ``config/tasks.yaml``
|
||||
5. flow entrypoint: ``main.py`` with a Flow subclass
|
||||
6. hatch wheel target resolves (packages = [...] or default dir matches name)
|
||||
7. crew/flow module imports cleanly (catches ``@CrewBase not found``,
|
||||
``No Flow subclass found``, provider import errors)
|
||||
8. environment variables referenced in code vs ``.env`` / deployment env
|
||||
9. installed crewai vs lockfile pin (catches missing-attribute failures from
|
||||
stale pins)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from crewai_cli.utils import parse_toml
|
||||
|
||||
|
||||
console = Console()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Severity(str, Enum):
    """Severity of a validation finding.

    Subclasses str so values serialize/compare as plain strings.
    """

    # Blocks deployment; validator reports not-ok.
    ERROR = "error"
    # Advisory; printed but does not block deployment.
    WARNING = "warning"
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """A single finding from a validation check.

    Attributes:
        severity: whether this blocks deploy or is advisory.
        code: stable short identifier, used in tests and docs
            (e.g. ``missing_pyproject``, ``stale_lockfile``).
        title: one-line summary shown to the user.
        detail: optional multi-line explanation.
        hint: optional remediation suggestion.
    """

    severity: Severity
    code: str
    title: str
    # Optional fields default to empty strings (not None), so consumers can
    # format them without null checks.
    detail: str = ""
    hint: str = ""
|
||||
|
||||
# Maps known provider env var names → label used in hint messages.
# Both AWS entries intentionally map to the same "AWS" label.
_KNOWN_API_KEY_HINTS: dict[str, str] = {
    "OPENAI_API_KEY": "OpenAI",
    "ANTHROPIC_API_KEY": "Anthropic",
    "GOOGLE_API_KEY": "Google",
    "GEMINI_API_KEY": "Gemini",
    "AZURE_OPENAI_API_KEY": "Azure OpenAI",
    "AZURE_API_KEY": "Azure",
    "AWS_ACCESS_KEY_ID": "AWS",
    "AWS_SECRET_ACCESS_KEY": "AWS",
    "COHERE_API_KEY": "Cohere",
    "GROQ_API_KEY": "Groq",
    "MISTRAL_API_KEY": "Mistral",
    "TAVILY_API_KEY": "Tavily",
    "SERPER_API_KEY": "Serper",
    "SERPLY_API_KEY": "Serply",
    "PERPLEXITY_API_KEY": "Perplexity",
    "DEEPSEEK_API_KEY": "DeepSeek",
    "OPENROUTER_API_KEY": "OpenRouter",
    "FIRECRAWL_API_KEY": "Firecrawl",
    "EXA_API_KEY": "Exa",
    "BROWSERBASE_API_KEY": "Browserbase",
}
|
||||
|
||||
|
||||
def normalize_package_name(project_name: str) -> str:
    """Normalize a pyproject project.name into a Python package directory name.

    Mirrors the rules in ``crewai.cli.create_crew.create_crew`` so the
    validator agrees with the scaffolder about where ``src/<pkg>/`` should
    live: spaces and hyphens become underscores, the result is lowercased,
    and every remaining non-word character is dropped.
    """
    underscored = project_name.replace(" ", "_").replace("-", "_")
    return re.sub(r"[^a-zA-Z0-9_]", "", underscored.lower())
|
||||
|
||||
|
||||
class DeployValidator:
|
||||
"""Runs the full pre-deploy validation suite against a project directory."""
|
||||
|
||||
    def __init__(self, project_root: Path | None = None) -> None:
        """Prepare validator state for *project_root* (defaults to CWD).

        All cached project facts start empty; the ``_check_*`` methods
        populate them incrementally as ``run()`` proceeds.
        """
        self.project_root: Path = (project_root or Path.cwd()).resolve()
        # Accumulated findings, appended via _add().
        self.results: list[ValidationResult] = []
        # Parsed pyproject.toml; None until _check_pyproject succeeds.
        self._pyproject: dict[str, Any] | None = None
        # [project].name exactly as written in pyproject.toml.
        self._project_name: str | None = None
        # Package directory name derived via normalize_package_name().
        self._package_name: str | None = None
        # Resolved src/<package>/ path once confirmed to exist.
        self._package_dir: Path | None = None
        # True when pyproject declares [tool.crewai] type = "flow".
        self._is_flow: bool = False
|
||||
def _add(
|
||||
self,
|
||||
severity: Severity,
|
||||
code: str,
|
||||
title: str,
|
||||
detail: str = "",
|
||||
hint: str = "",
|
||||
) -> None:
|
||||
self.results.append(
|
||||
ValidationResult(
|
||||
severity=severity,
|
||||
code=code,
|
||||
title=title,
|
||||
detail=detail,
|
||||
hint=hint,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def errors(self) -> list[ValidationResult]:
|
||||
return [r for r in self.results if r.severity is Severity.ERROR]
|
||||
|
||||
@property
|
||||
def warnings(self) -> list[ValidationResult]:
|
||||
return [r for r in self.results if r.severity is Severity.WARNING]
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
    def run(self) -> list[ValidationResult]:
        """Run all checks. Later checks are skipped when earlier ones make
        them impossible (e.g. no pyproject.toml → no lockfile check)."""
        # Without a parseable pyproject + project name, nothing else can run.
        if not self._check_pyproject():
            return self.results

        self._check_lockfile()

        # Even when the package directory is missing, the hatch-wheel check
        # can still point at the likely root cause before we stop.
        if not self._check_package_dir():
            self._check_hatch_wheel_target()
            return self.results

        # Flow and standard-crew projects have different entrypoint layouts.
        if self._is_flow:
            self._check_flow_entrypoint()
        else:
            self._check_crew_entrypoint()
            self._check_config_yamls()

        self._check_hatch_wheel_target()
        self._check_module_imports()
        self._check_env_vars()
        self._check_version_vs_lockfile()

        return self.results
|
||||
|
||||
    def _check_pyproject(self) -> bool:
        """Validate pyproject.toml exists, parses, and names the project.

        On success, populates ``self._pyproject``, ``self._project_name``,
        ``self._package_name`` and ``self._is_flow`` for later checks and
        returns True; otherwise records an ERROR finding and returns False.
        """
        pyproject_path = self.project_root / "pyproject.toml"
        if not pyproject_path.exists():
            self._add(
                Severity.ERROR,
                "missing_pyproject",
                "Cannot find pyproject.toml",
                detail=(
                    f"Expected pyproject.toml at {pyproject_path}. "
                    "CrewAI projects must be installable Python packages."
                ),
                hint="Run `crewai create crew <name>` to scaffold a valid project layout.",
            )
            return False

        try:
            self._pyproject = parse_toml(pyproject_path.read_text())
        except Exception as e:
            # Broad catch is deliberate: any read/parse failure becomes a
            # finding instead of crashing the validator.
            self._add(
                Severity.ERROR,
                "invalid_pyproject",
                "pyproject.toml is not valid TOML",
                detail=str(e),
            )
            return False

        project = self._pyproject.get("project") or {}
        name = project.get("name")
        # Reject non-string and whitespace-only names alike.
        if not isinstance(name, str) or not name.strip():
            self._add(
                Severity.ERROR,
                "missing_project_name",
                "pyproject.toml is missing [project].name",
                detail=(
                    "Without a project name the platform cannot resolve your "
                    "package directory (this produces errors like "
                    "'Cannot find src//crew.py')."
                ),
                hint='Set a `name = "..."` field under `[project]` in pyproject.toml.',
            )
            return False

        self._project_name = name
        self._package_name = normalize_package_name(name)
        # Flow projects are marked by [tool.crewai] type = "flow".
        self._is_flow = (self._pyproject.get("tool") or {}).get("crewai", {}).get(
            "type"
        ) == "flow"
        return True
|
||||
|
||||
    def _check_lockfile(self) -> None:
        """Check a dependency lockfile exists and is not older than pyproject.

        A missing lockfile is an ERROR (the platform resolves dependencies
        from it); a lockfile older than pyproject.toml is only a WARNING.
        """
        uv_lock = self.project_root / "uv.lock"
        poetry_lock = self.project_root / "poetry.lock"
        pyproject = self.project_root / "pyproject.toml"

        if not uv_lock.exists() and not poetry_lock.exists():
            self._add(
                Severity.ERROR,
                "missing_lockfile",
                "Expected to find at least one of these files: uv.lock or poetry.lock",
                hint=(
                    "Run `uv lock` (recommended) or `poetry lock` in your project "
                    "directory, commit the lockfile, then redeploy."
                ),
            )
            return

        # uv.lock wins when both lockfiles are present.
        lockfile = uv_lock if uv_lock.exists() else poetry_lock
        try:
            # mtime comparison is a heuristic: it flags a lockfile that
            # predates the last pyproject edit but cannot prove staleness.
            if lockfile.stat().st_mtime < pyproject.stat().st_mtime:
                self._add(
                    Severity.WARNING,
                    "stale_lockfile",
                    f"{lockfile.name} is older than pyproject.toml",
                    detail=(
                        "Your lockfile may not reflect recent dependency changes. "
                        "The platform resolves from the lockfile, so deployed "
                        "dependencies may differ from local."
                    ),
                    hint="Run `uv lock` (or `poetry lock`) and commit the result.",
                )
        except OSError:
            # stat() failures only lose an advisory check; ignore them.
            pass
|
||||
|
||||
    def _check_package_dir(self) -> bool:
        """Locate src/<package_name>/ and flag stale build artifacts.

        Returns True (and caches ``self._package_dir``) when the package
        directory exists; otherwise records an ERROR and returns False.
        """
        if self._package_name is None:
            return False

        src_dir = self.project_root / "src"
        if not src_dir.is_dir():
            self._add(
                Severity.ERROR,
                "missing_src_dir",
                "Missing src/ directory",
                detail=(
                    "CrewAI deployments expect a src-layout project: "
                    f"src/{self._package_name}/crew.py (or main.py for flows)."
                ),
                hint="Run `crewai create crew <name>` to see the expected layout.",
            )
            return False

        package_dir = src_dir / self._package_name
        if not package_dir.is_dir():
            # Build a targeted hint: candidate sibling packages the user may
            # have meant, and any stale .egg-info dirs that could explain
            # the mismatch.
            siblings = [
                p.name
                for p in src_dir.iterdir()
                if p.is_dir() and not p.name.endswith(".egg-info")
            ]
            egg_info = [
                p.name for p in src_dir.iterdir() if p.name.endswith(".egg-info")
            ]

            hint_parts = [
                f'Create src/{self._package_name}/ to match [project].name = "{self._project_name}".'
            ]
            if siblings:
                hint_parts.append(
                    f"Found other package directories: {', '.join(siblings)}. "
                    f"Either rename one to '{self._package_name}' or update [project].name."
                )
            if egg_info:
                hint_parts.append(
                    f"Delete stale build artifacts: {', '.join(egg_info)} "
                    "(these confuse the platform's package discovery)."
                )

            self._add(
                Severity.ERROR,
                "missing_package_dir",
                f"Cannot find src/{self._package_name}/",
                detail=(
                    "The platform looks for your crew source under "
                    "src/<package_name>/, derived from [project].name."
                ),
                hint=" ".join(hint_parts),
            )
            return False

        # The package dir exists, but stray .egg-info siblings can still be
        # mistaken for the package — warn about each one.
        for p in src_dir.iterdir():
            if p.name.endswith(".egg-info"):
                self._add(
                    Severity.WARNING,
                    "stale_egg_info",
                    f"Stale build artifact in src/: {p.name}",
                    detail=(
                        ".egg-info directories can be mistaken for your package "
                        "and cause 'Cannot find src/<name>.egg-info/crew.py' errors."
                    ),
                    hint=f"Delete {p} and add `*.egg-info/` to .gitignore.",
                )

        self._package_dir = package_dir
        return True
|
||||
|
||||
def _check_crew_entrypoint(self) -> None:
|
||||
if self._package_dir is None:
|
||||
return
|
||||
crew_py = self._package_dir / "crew.py"
|
||||
if not crew_py.is_file():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_crew_py",
|
||||
f"Cannot find {crew_py.relative_to(self.project_root)}",
|
||||
detail=(
|
||||
"Standard crew projects must define a Crew class decorated "
|
||||
"with @CrewBase inside crew.py."
|
||||
),
|
||||
hint=(
|
||||
"Create crew.py with an @CrewBase-annotated class, or set "
|
||||
'`[tool.crewai] type = "flow"` in pyproject.toml if this is a flow.'
|
||||
),
|
||||
)
|
||||
|
||||
def _check_config_yamls(self) -> None:
|
||||
if self._package_dir is None:
|
||||
return
|
||||
config_dir = self._package_dir / "config"
|
||||
if not config_dir.is_dir():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_config_dir",
|
||||
f"Cannot find {config_dir.relative_to(self.project_root)}",
|
||||
hint="Create a config/ directory with agents.yaml and tasks.yaml.",
|
||||
)
|
||||
return
|
||||
|
||||
for yaml_name in ("agents.yaml", "tasks.yaml"):
|
||||
yaml_path = config_dir / yaml_name
|
||||
if not yaml_path.is_file():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
f"missing_{yaml_name.replace('.', '_')}",
|
||||
f"Cannot find {yaml_path.relative_to(self.project_root)}",
|
||||
detail=(
|
||||
"CrewAI loads agent and task config from these files; "
|
||||
"missing them causes empty-config warnings and runtime crashes."
|
||||
),
|
||||
)
|
||||
|
||||
def _check_flow_entrypoint(self) -> None:
|
||||
if self._package_dir is None:
|
||||
return
|
||||
main_py = self._package_dir / "main.py"
|
||||
if not main_py.is_file():
|
||||
self._add(
|
||||
Severity.ERROR,
|
||||
"missing_flow_main",
|
||||
f"Cannot find {main_py.relative_to(self.project_root)}",
|
||||
detail=(
|
||||
"Flow projects must define a Flow subclass in main.py. "
|
||||
'This project has `[tool.crewai] type = "flow"` set.'
|
||||
),
|
||||
hint="Create main.py with a `class MyFlow(Flow[...])`.",
|
||||
)
|
||||
|
||||
    def _check_hatch_wheel_target(self) -> None:
        """Verify hatchling can determine which files go into the wheel.

        Only applies when the build backend is hatchling. Passes when an
        explicit wheel target is configured, or when the default package
        directory (src/<package>/) exists.
        """
        if not self._pyproject:
            return

        build_system = self._pyproject.get("build-system") or {}
        backend = build_system.get("build-backend", "")
        # Other backends (setuptools, poetry-core, ...) resolve packages
        # differently; this check is hatchling-specific.
        if "hatchling" not in backend:
            return

        hatch_wheel = (
            (self._pyproject.get("tool") or {})
            .get("hatch", {})
            .get("build", {})
            .get("targets", {})
            .get("wheel", {})
        )
        # An explicit packages/only-include setting satisfies hatchling.
        if hatch_wheel.get("packages") or hatch_wheel.get("only-include"):
            return

        # Otherwise the default directory matching the project name works.
        if self._package_dir and self._package_dir.is_dir():
            return

        self._add(
            Severity.ERROR,
            "hatch_wheel_target_missing",
            "Hatchling cannot determine which files to ship",
            detail=(
                "Your pyproject uses hatchling but has no "
                "[tool.hatch.build.targets.wheel] configuration and no "
                "directory matching your project name."
            ),
            hint=(
                "Add:\n"
                "  [tool.hatch.build.targets.wheel]\n"
                f'  packages = ["src/{self._package_name}"]'
            ),
        )
|
||||
|
||||
    def _check_module_imports(self) -> None:
        """Import the user's crew/flow via `uv run` so the check sees the same
        package versions as `crewai run` would. Result is reported as JSON on
        the subprocess's stdout.

        Skipped (with a WARNING) when `uv` is not installed. A timeout,
        missing JSON payload, zero discovered crews/flows, or a reported
        exception each produce a finding; exceptions are further classified
        by _classify_import_error.
        """
        # The probe script runs inside the project's environment. It prints
        # exactly one JSON object: success + instance count, or the error
        # type/message/traceback for classification.
        script = (
            "import json, sys, traceback, os\n"
            "os.chdir(sys.argv[1])\n"
            "try:\n"
            "    from crewai.utilities.project_utils import get_crews, get_flows\n"
            "    is_flow = sys.argv[2] == 'flow'\n"
            "    if is_flow:\n"
            "        instances = get_flows()\n"
            "        kind = 'flow'\n"
            "    else:\n"
            "        instances = get_crews()\n"
            "        kind = 'crew'\n"
            "    print(json.dumps({'ok': True, 'kind': kind, 'count': len(instances)}))\n"
            "except BaseException as e:\n"
            "    print(json.dumps({\n"
            "        'ok': False,\n"
            "        'error_type': type(e).__name__,\n"
            "        'error': str(e),\n"
            "        'traceback': traceback.format_exc(),\n"
            "    }))\n"
        )

        uv_path = shutil.which("uv")
        if uv_path is None:
            self._add(
                Severity.WARNING,
                "uv_not_found",
                "Skipping import check: `uv` not installed",
                hint="Install uv: https://docs.astral.sh/uv/",
            )
            return

        try:
            proc = subprocess.run(  # noqa: S603 - args constructed from trusted inputs
                [
                    uv_path,
                    "run",
                    "python",
                    "-c",
                    script,
                    str(self.project_root),
                    "flow" if self._is_flow else "crew",
                ],
                cwd=self.project_root,
                capture_output=True,
                text=True,
                timeout=120,
                check=False,
            )
        except subprocess.TimeoutExpired:
            self._add(
                Severity.ERROR,
                "import_timeout",
                "Importing your crew/flow module timed out after 120s",
                detail=(
                    "User code may be making network calls or doing heavy work "
                    "at import time. Move that work into agent methods."
                ),
            )
            return

        # The payload is the last JSON object on stdout; user code may print
        # other lines before it.
        payload: dict[str, Any] | None = None
        for line in reversed(proc.stdout.splitlines()):
            line = line.strip()
            if line.startswith("{") and line.endswith("}"):
                try:
                    payload = json.loads(line)
                    break
                except json.JSONDecodeError:
                    # Not our payload (e.g. user code printed braces); keep
                    # scanning earlier lines.
                    continue

        if payload is None:
            # No JSON at all usually means the interpreter itself failed to
            # start (bad env, uv error); show whatever output exists.
            self._add(
                Severity.ERROR,
                "import_failed",
                "Could not import your crew/flow module",
                detail=(proc.stderr or proc.stdout or "").strip()[:1500],
                hint="Run `crewai run` locally first to reproduce the error.",
            )
            return

        if payload.get("ok"):
            # Import succeeded but discovery may still come up empty.
            if payload.get("count", 0) == 0:
                kind = payload.get("kind", "crew")
                if kind == "flow":
                    self._add(
                        Severity.ERROR,
                        "no_flow_subclass",
                        "No Flow subclass found in the module",
                        hint=(
                            "main.py must define a class extending "
                            "`crewai.flow.Flow`, instantiable with no arguments."
                        ),
                    )
                else:
                    self._add(
                        Severity.ERROR,
                        "no_crewbase_class",
                        "Crew class annotated with @CrewBase not found",
                        hint=(
                            "Decorate your crew class with @CrewBase from "
                            "crewai.project (see `crewai create crew` template)."
                        ),
                    )
            return

        err_msg = str(payload.get("error", ""))
        err_type = str(payload.get("error_type", "Exception"))
        tb = str(payload.get("traceback", ""))
        self._classify_import_error(err_type, err_msg, tb)
|
||||
    def _classify_import_error(self, err_type: str, err_msg: str, tb: str) -> None:
        """Turn a raw import-time exception into a user-actionable finding.

        Branch order matters: more specific message patterns are matched
        before broader ones, ending in a generic import_failed fallback.
        Exactly one finding is recorded per call.
        """
        # Must be checked before the generic "native provider" branch below:
        # the extras-missing message contains the same phrase. Providers
        # format the install command as plain text (`to install: uv add
        # "crewai[extra]"`); also tolerate backtick-delimited variants.
        m = re.search(
            r"(?P<pkg>[A-Za-z0-9_ -]+?)\s+native provider not available"
            r".*?to install:\s*`?(?P<cmd>uv add [\"']crewai\[[^\]]+\][\"'])`?",
            err_msg,
        )
        if m:
            self._add(
                Severity.ERROR,
                "missing_provider_extra",
                f"{m.group('pkg').strip()} provider extra not installed",
                hint=f"Run: {m.group('cmd')}",
            )
            return

        # crewai.llm.LLM.__new__ wraps provider init errors as
        # ImportError("Error importing native provider: ...").
        if "Error importing native provider" in err_msg or "native provider" in err_msg:
            missing_key = self._extract_missing_api_key(err_msg)
            if missing_key:
                # Missing key is only a WARNING: the platform may supply it
                # as a deployment env var at build time.
                provider = _KNOWN_API_KEY_HINTS.get(missing_key, missing_key)
                self._add(
                    Severity.WARNING,
                    "llm_init_missing_key",
                    f"LLM is constructed at import time but {missing_key} is not set",
                    detail=(
                        f"Your crew instantiates a {provider} LLM during module "
                        "load (e.g. in a class field default or @crew method). "
                        f"The {provider} provider currently requires {missing_key} "
                        "at construction time, so this will fail on the platform "
                        "unless the key is set in your deployment environment."
                    ),
                    hint=(
                        f"Add {missing_key} to your deployment's Environment "
                        "Variables before deploying, or move LLM construction "
                        "inside agent methods so it runs lazily."
                    ),
                )
                return
            self._add(
                Severity.ERROR,
                "llm_provider_init_failed",
                "LLM native provider failed to initialize",
                detail=err_msg,
                hint=(
                    "Check your LLM(model=...) configuration and provider-specific "
                    "extras (e.g. `uv add 'crewai[azure-ai-inference]'` for Azure)."
                ),
            )
            return

        if err_type == "KeyError":
            # KeyError's str() is the quoted key; strip the quotes.
            key = err_msg.strip("'\"")
            if key in _KNOWN_API_KEY_HINTS or key.endswith("_API_KEY"):
                self._add(
                    Severity.WARNING,
                    "env_var_read_at_import",
                    f"{key} is read at import time via os.environ[...]",
                    detail=(
                        "Using os.environ[...] (rather than os.getenv(...)) "
                        "at module scope crashes the build if the key isn't set."
                    ),
                    hint=(
                        f"Either add {key} as a deployment env var, or switch "
                        "to os.getenv() and move the access inside agent methods."
                    ),
                )
                return

        if "Crew class annotated with @CrewBase not found" in err_msg:
            self._add(
                Severity.ERROR,
                "no_crewbase_class",
                "Crew class annotated with @CrewBase not found",
                detail=err_msg,
            )
            return
        if "No Flow subclass found" in err_msg:
            self._add(
                Severity.ERROR,
                "no_flow_subclass",
                "No Flow subclass found in the module",
                detail=err_msg,
            )
            return

        # Signature failure of a crewai release predating
        # `_load_response_format`: the lockfile pin is stale.
        if (
            err_type == "AttributeError"
            and "has no attribute '_load_response_format'" in err_msg
        ):
            self._add(
                Severity.ERROR,
                "stale_crewai_pin",
                "Your lockfile pins a crewai version missing `_load_response_format`",
                detail=err_msg,
                hint=(
                    "Run `uv lock --upgrade-package crewai` (or `poetry update crewai`) "
                    "to pin a newer release."
                ),
            )
            return

        if "pydantic" in tb.lower() or "validation error" in err_msg.lower():
            self._add(
                Severity.ERROR,
                "pydantic_validation_error",
                "Pydantic validation failed while loading your crew",
                detail=err_msg[:800],
                hint=(
                    "Check agent/task configuration fields. `crewai run` locally "
                    "will show the full traceback."
                ),
            )
            return

        # Fallback: surface the raw exception type and (truncated) message.
        self._add(
            Severity.ERROR,
            "import_failed",
            f"Importing your crew failed: {err_type}",
            detail=err_msg[:800],
            hint="Run `crewai run` locally to see the full traceback.",
        )
|
||||
|
||||
@staticmethod
|
||||
def _extract_missing_api_key(err_msg: str) -> str | None:
|
||||
"""Pull 'FOO_API_KEY' out of '... FOO_API_KEY is required ...'."""
|
||||
m = re.search(r"([A-Z][A-Z0-9_]*_API_KEY)\s+is required", err_msg)
|
||||
if m:
|
||||
return m.group(1)
|
||||
m = re.search(r"['\"]([A-Z][A-Z0-9_]*_API_KEY)['\"]", err_msg)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return None
|
||||
|
||||
def _check_env_vars(self) -> None:
    """Warn about env vars referenced in user code but missing locally.

    Scans the package's ``*.py`` files for ``os.environ``/``os.getenv``
    reads and its ``*.yaml`` files for ``$VAR``/``${VAR}`` placeholders,
    then warns about referenced *known* API keys that are neither in the
    project's ``.env`` nor in the current process environment.

    Best-effort only — the platform sets vars server-side, so we never error.
    """
    if not self._package_dir:
        return

    referenced: set[str] = set()
    # Captures the env-var name from os.environ["X"], os.environ.get("X"),
    # os.getenv("X"), and bare getenv("X") calls.
    pattern = re.compile(
        r"""(?x)
        (?:os\.environ\s*(?:\[\s*|\.get\s*\(\s*)
        |os\.getenv\s*\(\s*
        |getenv\s*\(\s*)
        ['"]([A-Z][A-Z0-9_]*)['"]
        """
    )

    # Scan Python sources; unreadable files are skipped silently.
    for path in self._package_dir.rglob("*.py"):
        try:
            text = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            continue
        referenced.update(pattern.findall(text))

    # YAML config may reference vars as $VAR or ${VAR}.
    for path in self._package_dir.rglob("*.yaml"):
        try:
            text = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            continue
        referenced.update(re.findall(r"\$\{?([A-Z][A-Z0-9_]+)\}?", text))

    # Collect key names declared in the project's .env (values ignored).
    env_file = self.project_root / ".env"
    env_keys: set[str] = set()
    if env_file.exists():
        for line in env_file.read_text(errors="ignore").splitlines():
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            env_keys.add(line.split("=", 1)[0].strip())

    # Only warn about keys we recognize as API keys (_KNOWN_API_KEY_HINTS);
    # anything else may be optional or provided server-side.
    missing_known: list[str] = sorted(
        var
        for var in referenced
        if var in _KNOWN_API_KEY_HINTS
        and var not in env_keys
        and var not in os.environ
    )
    if missing_known:
        self._add(
            Severity.WARNING,
            "env_vars_not_in_dotenv",
            f"{len(missing_known)} referenced API key(s) not in .env",
            detail=(
                "These env vars are referenced in your source but not set "
                f"locally: {', '.join(missing_known)}. Deploys will fail "
                "unless they are added to the deployment's Environment "
                "Variables in the CrewAI dashboard."
            ),
        )
|
||||
|
||||
def _check_version_vs_lockfile(self) -> None:
    """Warn when the lockfile pins a crewai release older than 1.13.0,
    which is where ``_load_response_format`` was introduced.
    """
    # Prefer uv.lock over poetry.lock; skip silently if neither exists.
    uv_lock = self.project_root / "uv.lock"
    poetry_lock = self.project_root / "poetry.lock"
    lockfile = (
        uv_lock
        if uv_lock.exists()
        else poetry_lock
        if poetry_lock.exists()
        else None
    )
    if lockfile is None:
        return

    try:
        text = lockfile.read_text(errors="ignore")
    except OSError:
        # Best-effort check — an unreadable lockfile must not block a deploy.
        return

    # Lockfile entries are TOML `name = "..."` / `version = "..."` pairs on
    # consecutive lines, which this regex matches for the crewai package.
    m = re.search(
        r'name\s*=\s*"crewai"\s*\nversion\s*=\s*"([^"]+)"',
        text,
    )
    if not m:
        return
    locked = m.group(1)

    try:
        from packaging.version import Version

        if Version(locked) < Version("1.13.0"):
            self._add(
                Severity.WARNING,
                "old_crewai_pin",
                f"Lockfile pins crewai=={locked} (older than 1.13.0)",
                detail=(
                    "Older pinned versions are missing API surface the "
                    "platform builder expects (e.g. `_load_response_format`)."
                ),
                hint="Run `uv lock --upgrade-package crewai` and redeploy.",
            )
    except Exception as e:
        # Covers both a missing `packaging` dependency and an unparsable
        # version string; this check must never break the validation run.
        logger.debug("Could not parse crewai pin from lockfile: %s", e)
|
||||
|
||||
|
||||
def render_report(results: list[ValidationResult]) -> None:
    """Pretty-print results to the shared rich console.

    Errors are listed first, then warnings, followed by a one-line summary.
    Prints a success message and returns early when there are no results.

    Args:
        results: Validation findings to render; may be empty.
    """
    if not results:
        console.print("[bold green]Pre-deploy validation passed.[/bold green]")
        return

    errors = [r for r in results if r.severity is Severity.ERROR]
    warnings = [r for r in results if r.severity is Severity.WARNING]

    def _print_result(result: ValidationResult, label: str) -> None:
        # Shared renderer — the error and warning loops previously
        # duplicated these three prints verbatim.
        console.print(f"{label} [{result.code}] {result.title}")
        if result.detail:
            console.print(f" {result.detail}")
        if result.hint:
            console.print(f" [dim]hint:[/dim] {result.hint}")

    for result in errors:
        _print_result(result, "[bold red]ERROR[/bold red]")
    for result in warnings:
        _print_result(result, "[bold yellow]WARNING[/bold yellow]")

    summary_parts: list[str] = []
    if errors:
        summary_parts.append(f"[bold red]{len(errors)} error(s)[/bold red]")
    if warnings:
        summary_parts.append(f"[bold yellow]{len(warnings)} warning(s)[/bold yellow]")
    console.print(f"\n{' / '.join(summary_parts)}")
|
||||
|
||||
|
||||
def validate_project(project_root: Path | None = None) -> DeployValidator:
    """Run pre-deploy validation and render the findings.

    Args:
        project_root: Root of the project to validate; ``None`` lets the
            validator use its own default discovery.

    Returns:
        The executed ``DeployValidator``. Callers inspect its ``ok``
        attribute to decide whether to proceed with a deploy.
    """
    deploy_validator = DeployValidator(project_root=project_root)
    deploy_validator.run()
    render_report(deploy_validator.results)
    return deploy_validator
|
||||
|
||||
|
||||
def run_validate_command() -> None:
    """Implementation of `crewai deploy validate`.

    Exits the process with status 1 when validation reported any error.
    """
    if not validate_project().ok:
        sys.exit(1)
|
||||
0
lib/cli/src/crewai_cli/enterprise/__init__.py
Normal file
0
lib/cli/src/crewai_cli/enterprise/__init__.py
Normal file
125
lib/cli/src/crewai_cli/enterprise/main.py
Normal file
125
lib/cli/src/crewai_cli/enterprise/main.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings, ProviderFactory
|
||||
from crewai_cli.command import BaseCommand
|
||||
from crewai_cli.settings.main import SettingsCommand
|
||||
from crewai_cli.version import get_crewai_version
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class EnterpriseConfigureCommand(BaseCommand):
    """Configure the CLI against a self-hosted CrewAI AMP/Enterprise instance.

    Fetches the instance's OAuth2 parameters over HTTP, validates them, and
    persists them into the local CLI settings.
    """

    def __init__(self) -> None:
        super().__init__()
        self.settings_command = SettingsCommand()

    def configure(self, enterprise_url: str) -> None:
        """Fetch OAuth2 parameters from *enterprise_url* and store them locally.

        Args:
            enterprise_url: Base URL of the enterprise instance; a trailing
                slash is stripped before use.

        Raises:
            SystemExit: With status 1 on any failure (after printing the error).
        """
        try:
            enterprise_url = enterprise_url.rstrip("/")

            oauth_config = self._fetch_oauth_config(enterprise_url)

            self._update_oauth_settings(enterprise_url, oauth_config)

            console.print(
                f"✅ Successfully configured CrewAI AMP with OAuth2 settings from {enterprise_url}",
                style="bold green",
            )

        except Exception as e:
            console.print(
                f"❌ Failed to configure Enterprise settings: {e!s}", style="bold red"
            )
            raise SystemExit(1) from e

    def _fetch_oauth_config(self, enterprise_url: str) -> dict[str, Any]:
        """GET the instance's ``/auth/parameters`` endpoint and validate the payload.

        Raises:
            ValueError: On connection failure, non-JSON payload, or a payload
                missing required fields.
        """
        oauth_endpoint = f"{enterprise_url}/auth/parameters"

        try:
            console.print(f"🔄 Fetching OAuth2 configuration from {oauth_endpoint}...")
            headers = {
                "Content-Type": "application/json",
                "User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
                "X-Crewai-Version": get_crewai_version(),
            }
            response = httpx.get(oauth_endpoint, timeout=30, headers=headers)
            response.raise_for_status()

            try:
                oauth_config = response.json()
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e

            self._validate_oauth_config(oauth_config)

            console.print(
                "✅ Successfully retrieved OAuth2 configuration", style="green"
            )
            return cast(dict[str, Any], oauth_config)

        except httpx.HTTPError as e:
            raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
        except Exception as e:
            raise ValueError(f"Error fetching OAuth2 configuration: {e!s}") from e

    def _update_oauth_settings(
        self, enterprise_url: str, oauth_config: dict[str, Any]
    ) -> None:
        """Persist the fetched OAuth2 values into the local CLI settings."""
        try:
            config_mapping = {
                "enterprise_base_url": enterprise_url,
                "oauth2_provider": oauth_config["provider"],
                "oauth2_audience": oauth_config["audience"],
                "oauth2_client_id": oauth_config["device_authorization_client_id"],
                "oauth2_domain": oauth_config["domain"],
                "oauth2_extra": oauth_config["extra"],
            }

            console.print("🔄 Updating local OAuth2 configuration...")

            for key, value in config_mapping.items():
                self.settings_command.set(key, value)
                console.print(f" ✓ Set {key}: {value}", style="dim")

        except Exception as e:
            raise ValueError(f"Failed to update OAuth2 settings: {e!s}") from e

    def _validate_oauth_config(self, oauth_config: dict[str, Any]) -> None:
        """Raise ValueError listing any missing fields in the OAuth2 payload.

        Basic fields are checked (and reported) *before* provider-specific
        ones: resolving the provider-specific field list requires the
        'provider' field, so the old ordering raised a bare KeyError instead
        of the friendly message when 'provider' was absent.
        """
        required_fields = [
            "audience",
            "domain",
            "device_authorization_client_id",
            "provider",
            "extra",
        ]

        missing_basic_fields = [
            field for field in required_fields if field not in oauth_config
        ]
        if missing_basic_fields:
            raise ValueError(
                f"Missing required fields in OAuth2 configuration: [{', '.join(missing_basic_fields)}]"
            )

        missing_provider_specific_fields = [
            field
            for field in self._get_provider_specific_fields(oauth_config["provider"])
            if field not in oauth_config.get("extra", {})
        ]
        if missing_provider_specific_fields:
            raise ValueError(
                f"Missing authentication provider required fields in OAuth2 configuration: [{', '.join(missing_provider_specific_fields)}] (Configured provider: '{oauth_config['provider']}')"
            )

    def _get_provider_specific_fields(self, provider_name: str) -> list[str]:
        """Return the extra fields the configured auth provider requires."""
        provider = ProviderFactory.from_settings(
            Oauth2Settings(provider=provider_name, client_id="dummy", domain="dummy")
        )

        return provider.get_required_fields()
|
||||
41
lib/cli/src/crewai_cli/evaluate_crew.py
Normal file
41
lib/cli/src/crewai_cli/evaluate_crew.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
from crewai_core.constants import CREWAI_TRAINED_AGENTS_FILE_ENV
|
||||
|
||||
from crewai_cli.utils import build_env_with_all_tool_credentials
|
||||
|
||||
|
||||
def evaluate_crew(
    n_iterations: int, model: str, trained_agents_file: str | None = None
) -> None:
    """Test and Evaluate the crew by running a command in the UV environment.

    Args:
        n_iterations: The number of iterations to test the crew; must be > 0.
        model: The model to test the crew with.
        trained_agents_file: Optional trained-agents pickle path forwarded to
            the subprocess via the ``CREWAI_TRAINED_AGENTS_FILE`` env var.
    """
    command = ["uv", "run", "test", str(n_iterations), model]
    env = build_env_with_all_tool_credentials()
    if trained_agents_file:
        env[CREWAI_TRAINED_AGENTS_FILE_ENV] = trained_agents_file

    try:
        # Validate before spawning the subprocess. The generic handler below
        # still reports this as "An unexpected error occurred: ..." so the
        # user-visible behavior is unchanged.
        if n_iterations <= 0:
            raise ValueError("The number of iterations must be a positive integer.")

        # capture_output=False inherits this process's streams, so the
        # subprocess prints directly to the terminal. The former
        # `if result.stderr:` branch was dead code (stderr is always None
        # when output is not captured) and has been removed.
        subprocess.run(  # noqa: S603
            command, capture_output=False, text=True, check=True, env=env
        )

    except subprocess.CalledProcessError as e:
        click.echo(f"An error occurred while testing the crew: {e}", err=True)
        click.echo(e.output, err=True)

    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
89
lib/cli/src/crewai_cli/git.py
Normal file
89
lib/cli/src/crewai_cli/git.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from functools import lru_cache
|
||||
import subprocess
|
||||
|
||||
|
||||
class Repository:
|
||||
def __init__(self, path: str = ".") -> None:
|
||||
self.path = path
|
||||
|
||||
if not self.is_git_installed():
|
||||
raise ValueError("Git is not installed or not found in your PATH.")
|
||||
|
||||
if not self.is_git_repo():
|
||||
raise ValueError(f"{self.path} is not a Git repository.")
|
||||
|
||||
self.fetch()
|
||||
|
||||
@staticmethod
|
||||
def is_git_installed() -> bool:
|
||||
"""Check if Git is installed and available in the system."""
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "--version"], # noqa: S607
|
||||
capture_output=True,
|
||||
check=True,
|
||||
text=True,
|
||||
)
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return False
|
||||
|
||||
def fetch(self) -> None:
|
||||
"""Fetch latest updates from the remote."""
|
||||
subprocess.run(["git", "fetch"], cwd=self.path, check=True) # noqa: S607
|
||||
|
||||
def status(self) -> str:
|
||||
"""Get the git status in porcelain format."""
|
||||
return subprocess.check_output(
|
||||
["git", "status", "--branch", "--porcelain"], # noqa: S607
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
).strip()
|
||||
|
||||
@lru_cache(maxsize=None) # noqa: B019
|
||||
def is_git_repo(self) -> bool:
|
||||
"""Check if the current directory is a git repository.
|
||||
|
||||
Notes:
|
||||
- TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
|
||||
"""
|
||||
try:
|
||||
subprocess.check_output(
|
||||
["git", "rev-parse", "--is-inside-work-tree"], # noqa: S607
|
||||
cwd=self.path,
|
||||
encoding="utf-8",
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
def has_uncommitted_changes(self) -> bool:
|
||||
"""Check if the repository has uncommitted changes."""
|
||||
return len(self.status().splitlines()) > 1
|
||||
|
||||
def is_ahead_or_behind(self) -> bool:
|
||||
"""Check if the repository is ahead or behind the remote."""
|
||||
for line in self.status().splitlines():
|
||||
if line.startswith("##") and ("ahead" in line or "behind" in line):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_synced(self) -> bool:
|
||||
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
||||
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
||||
return False
|
||||
return True
|
||||
|
||||
def origin_url(self) -> str | None:
|
||||
"""Get the Git repository's remote URL."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "remote", "get-url", "origin"], # noqa: S607
|
||||
cwd=self.path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
32
lib/cli/src/crewai_cli/install_crew.py
Normal file
32
lib/cli/src/crewai_cli/install_crew.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
|
||||
from crewai_cli.utils import build_env_with_all_tool_credentials
|
||||
|
||||
|
||||
# Be mindful about changing this.
|
||||
# on some environments we don't use this command but instead uv sync directly
|
||||
# so if you expect this to support more things you will need to replicate it there
|
||||
# ask @joaomdmoura if you are unsure
|
||||
def install_crew(proxy_options: list[str]) -> None:
    """
    Install the crew by running the UV command to lock and install.
    """
    try:
        # Inject tool repository credentials so uv can authenticate
        # against private package indexes (e.g. crewai tool repository).
        # Without this, `uv sync` fails with 401 Unauthorized when the
        # project depends on tools from a private index.
        credentialed_env = build_env_with_all_tool_credentials()

        subprocess.run(  # noqa: S603
            ["uv", "sync", *proxy_options],
            check=True,
            capture_output=False,
            text=True,
            env=credentialed_env,
        )

    except subprocess.CalledProcessError as e:
        click.echo(f"An error occurred while running the crew: {e}", err=True)
        click.echo(e.output, err=True)

    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
23
lib/cli/src/crewai_cli/kickoff_flow.py
Normal file
23
lib/cli/src/crewai_cli/kickoff_flow.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def kickoff_flow() -> None:
    """
    Kickoff the flow by running a command in the UV environment.

    Subprocess failures are reported on stderr rather than raised.
    """
    command = ["uv", "run", "kickoff"]

    try:
        # capture_output=False inherits this process's streams, so the
        # subprocess writes straight to the terminal. The former
        # `if result.stderr:` branch was dead code (stderr is always None
        # when output is not captured) and has been removed.
        subprocess.run(command, capture_output=False, text=True, check=True)  # noqa: S603

    except subprocess.CalledProcessError as e:
        click.echo(f"An error occurred while running the flow: {e}", err=True)
        click.echo(e.output, err=True)

    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
403
lib/cli/src/crewai_cli/memory_tui.py
Normal file
403
lib/cli/src/crewai_cli/memory_tui.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""Textual TUI for browsing and recalling unified memory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.containers import Horizontal, Vertical
|
||||
from textual.widgets import Footer, Header, Input, OptionList, Static, Tree
|
||||
|
||||
|
||||
# -- CrewAI brand palette --
# Hex colors interpolated both into the Textual CSS below and into Rich
# markup strings throughout the TUI.
_PRIMARY = "#eb6658"  # coral
_SECONDARY = "#1F7982"  # teal
_TERTIARY = "#ffffff"  # white
|
||||
|
||||
|
||||
def _format_scope_info(info: Any) -> str:
    """Render a ScopeInfo summary as a Rich-markup string."""
    categories = ", ".join(info.categories) or "none"
    children = ", ".join(info.child_scopes) or "none"
    lines = [
        f"[bold {_PRIMARY}]{info.path}[/]",
        "",
        f"[dim]Records:[/] [bold]{info.record_count}[/]",
        f"[dim]Categories:[/] {categories}",
        f"[dim]Oldest:[/] {info.oldest_record or '-'}",
        f"[dim]Newest:[/] {info.newest_record or '-'}",
        f"[dim]Children:[/] {children}",
    ]
    return "\n".join(lines)
|
||||
|
||||
|
||||
class MemoryTUI(App[None]):
    """TUI to browse memory scopes and run recall queries."""

    TITLE = "CrewAI Memory"
    SUB_TITLE = "Browse scopes and recall memories"

    CSS = f"""
    Header {{
        background: {_PRIMARY};
        color: {_TERTIARY};
    }}
    Footer {{
        background: {_SECONDARY};
        color: {_TERTIARY};
    }}
    Footer > .footer-key--key {{
        background: {_PRIMARY};
        color: {_TERTIARY};
    }}
    Horizontal {{
        height: 1fr;
    }}
    #scope-tree {{
        width: 30%;
        padding: 1 2;
        background: {_SECONDARY} 8%;
        border-right: solid {_SECONDARY};
    }}
    #scope-tree:focus > .tree--cursor {{
        background: {_SECONDARY};
        color: {_TERTIARY};
    }}
    #scope-tree > .tree--guides {{
        color: {_SECONDARY} 50%;
    }}
    #scope-tree > .tree--guides-hover {{
        color: {_PRIMARY};
    }}
    #scope-tree > .tree--guides-selected {{
        color: {_SECONDARY};
    }}
    #right-panel {{
        width: 70%;
        padding: 0 1;
    }}
    #info-panel {{
        height: 2fr;
        padding: 1 2;
        overflow-y: auto;
        border: round {_SECONDARY};
    }}
    #info-panel:focus {{
        border: round {_PRIMARY};
    }}
    #info-panel LoadingIndicator {{
        color: {_PRIMARY};
    }}
    #entry-list {{
        height: 1fr;
        border: round {_SECONDARY};
        padding: 0 1;
        scrollbar-color: {_PRIMARY};
    }}
    #entry-list:focus {{
        border: round {_PRIMARY};
    }}
    #entry-list > .option-list--option-highlighted {{
        background: {_SECONDARY};
        color: {_TERTIARY};
    }}
    #recall-input {{
        margin: 0 1 1 1;
        border: tall {_SECONDARY};
    }}
    #recall-input:focus {{
        border: tall {_PRIMARY};
    }}
    """

    def __init__(
        self,
        storage_path: str | None = None,
        embedder_config: dict[str, Any] | None = None,
    ) -> None:
        """Best-effort Memory initialization.

        Import/initialization errors are captured in ``self._init_error``
        and surfaced in the UI instead of being raised, so the TUI still
        opens when the crewai framework is missing or misconfigured.
        """
        super().__init__()
        self._memory: Any = None
        self._init_error: str | None = None
        self._selected_scope: str = "/"
        self._entries: list[Any] = []
        self._view_mode: str = "list"  # "list" | "recall"
        self._recall_matches: list[Any] = []
        self._last_scope_info: Any = None
        self._custom_embedder = embedder_config is not None
        try:
            from crewai.memory.storage.lancedb_storage import LanceDBStorage
            from crewai.memory.unified_memory import Memory

            storage = (
                LanceDBStorage(path=storage_path) if storage_path else LanceDBStorage()
            )
            embedder = None
            if embedder_config is not None:
                from crewai.rag.embeddings.factory import build_embedder

                embedder = build_embedder(embedder_config)
            self._memory = (
                Memory(storage=storage, embedder=embedder)
                if embedder
                else Memory(storage=storage)
            )
        except Exception as e:
            self._init_error = str(e)

    def compose(self) -> ComposeResult:
        """Build the static layout: scope tree | (detail panel / entries / input)."""
        yield Header(show_clock=False)
        with Horizontal():
            yield self._build_scope_tree()
            initial = (
                self._init_error
                if self._init_error
                else "Select a scope or type a recall query."
            )
            with Vertical(id="right-panel"):
                yield Static(initial, id="info-panel")
                yield OptionList(id="entry-list")
                yield Input(
                    placeholder="Type a query and press Enter to recall...",
                    id="recall-input",
                )
        yield Footer()

    def on_mount(self) -> None:
        """Set initial border titles on mounted widgets."""
        self.query_one("#info-panel", Static).border_title = "Detail"
        self.query_one("#entry-list", OptionList).border_title = "Entries"

    def _build_scope_tree(self) -> Tree[str]:
        """Create the scope tree rooted at "/" (empty when memory failed to load)."""
        tree: Tree[str] = Tree("/", id="scope-tree")
        if self._memory is None:
            tree.root.data = "/"
            tree.root.label = "/ (0 records)"
            return tree
        info = self._memory.info("/")
        tree.root.label = f"/ ({info.record_count} records)"
        tree.root.data = "/"
        self._add_scope_children(tree.root, "/", depth=0, max_depth=3)
        tree.root.expand()
        return tree

    def _add_scope_children(
        self,
        parent_node: Any,
        path: str,
        depth: int,
        max_depth: int,
    ) -> None:
        """Recursively add child scope nodes, capped at *max_depth* levels."""
        if depth >= max_depth or self._memory is None:
            return
        info = self._memory.info(path)
        for child in info.child_scopes:
            child_info = self._memory.info(child)
            label = f"{child} ({child_info.record_count})"
            node = parent_node.add(label, data=child)
            self._add_scope_children(node, child, depth + 1, max_depth)

    # -- Populating the OptionList -------------------------------------------

    def _populate_entry_list(self) -> None:
        """Clear the OptionList and fill it with the current scope's entries."""
        option_list = self.query_one("#entry-list", OptionList)
        option_list.clear_options()
        for record in self._entries:
            date_str = record.created_at.strftime("%Y-%m-%d")
            preview = (
                (record.content[:80] + "…")
                if len(record.content) > 80
                else record.content
            )
            label = f"{date_str} [bold]{record.importance:.1f}[/] {preview}"
            option_list.add_option(label)

    def _populate_recall_list(self) -> None:
        """Clear the OptionList and fill it with the current recall matches."""
        option_list = self.query_one("#entry-list", OptionList)
        option_list.clear_options()
        if not self._recall_matches:
            return
        for m in self._recall_matches:
            preview = (
                (m.record.content[:80] + "…")
                if len(m.record.content) > 80
                else m.record.content
            )
            label = (
                f"[bold]\\[{m.score:.2f}][/] {preview} [dim]scope={m.record.scope}[/]"
            )
            option_list.add_option(label)

    # -- Detail rendering ----------------------------------------------------

    def _format_record_detail(self, record: Any, context_line: str = "") -> str:
        """Format a full MemoryRecord as Rich markup for the detail view.

        Args:
            record: A MemoryRecord instance.
            context_line: Optional header line shown above the fields
                (e.g. "Entry 3 of 47").

        Returns:
            A Rich-markup string with all meaningful record fields.
        """
        sep = f"[bold {_PRIMARY}]{'─' * 44}[/]"
        lines: list[str] = []

        if context_line:
            lines.append(context_line)
            lines.append("")

        # -- Fields block --
        lines.append(f"[dim]ID:[/] {record.id}")
        lines.append(f"[dim]Scope:[/] [bold]{record.scope}[/]")
        lines.append(f"[dim]Importance:[/] [bold]{record.importance:.2f}[/]")
        lines.append(
            f"[dim]Created:[/] {record.created_at.strftime('%Y-%m-%d %H:%M:%S')}"
        )
        lines.append(
            f"[dim]Last accessed:[/] "
            f"{record.last_accessed.strftime('%Y-%m-%d %H:%M:%S')}"
        )
        lines.append(
            f"[dim]Categories:[/] "
            f"{', '.join(record.categories) if record.categories else 'none'}"
        )
        lines.append(f"[dim]Source:[/] {record.source or '-'}")
        lines.append(f"[dim]Private:[/] {'Yes' if record.private else 'No'}")

        # -- Content block --
        lines.append(f"\n{sep}")
        lines.append("[bold]Content[/]\n")
        lines.append(record.content)

        # -- Metadata block --
        if record.metadata:
            lines.append(f"\n{sep}")
            lines.append("[bold]Metadata[/]\n")
            for k, v in record.metadata.items():
                lines.append(f"[dim]{k}:[/] {v}")

        return "\n".join(lines)

    # -- Event handlers ------------------------------------------------------

    def on_tree_node_selected(self, event: Tree.NodeSelected[str]) -> None:
        """Load entries for the selected scope and populate the OptionList."""
        path = event.node.data if event.node.data is not None else "/"
        self._selected_scope = path
        self._view_mode = "list"
        panel = self.query_one("#info-panel", Static)
        if self._memory is None:
            panel.update(self._init_error or "No memory loaded.")
            return
        display_limit = 1000
        info = self._memory.info(path)
        self._last_scope_info = info
        self._entries = self._memory.list_records(scope=path, limit=display_limit)
        panel.update(_format_scope_info(info))
        panel.border_title = "Detail"
        entry_list = self.query_one("#entry-list", OptionList)
        capped = info.record_count > display_limit
        count_label = (
            f"Entries (showing {display_limit} of {info.record_count} — display limit)"
            if capped
            else f"Entries ({len(self._entries)})"
        )
        entry_list.border_title = count_label
        self._populate_entry_list()

    def on_option_list_option_highlighted(
        self, event: OptionList.OptionHighlighted
    ) -> None:
        """Live-update the info panel with the detail of the highlighted entry."""
        panel = self.query_one("#info-panel", Static)
        idx = event.option_index

        if self._view_mode == "list":
            if idx < len(self._entries):
                record = self._entries[idx]
                total = len(self._entries)
                context = (
                    f"[bold {_PRIMARY}]Entry {idx + 1} of {total}[/] "
                    f"[dim]in[/] [bold]{self._selected_scope}[/]"
                )
                panel.border_title = f"Entry {idx + 1} of {total}"
                panel.update(self._format_record_detail(record, context_line=context))

        elif self._view_mode == "recall":
            if idx < len(self._recall_matches):
                match = self._recall_matches[idx]
                total = len(self._recall_matches)
                panel.border_title = f"Match {idx + 1} of {total}"
                score_color = _PRIMARY if match.score >= 0.5 else "dim"
                header_lines: list[str] = [
                    f"[bold {_PRIMARY}]Recall Match {idx + 1} of {total}[/]\n",
                    f"[dim]Score:[/] [{score_color}][bold]{match.score:.2f}[/][/]",
                    (
                        f"[dim]Match reasons:[/] "
                        f"{', '.join(match.match_reasons) if match.match_reasons else '-'}"
                    ),
                    (
                        f"[dim]Evidence gaps:[/] "
                        f"{', '.join(match.evidence_gaps) if match.evidence_gaps else 'none'}"
                    ),
                    f"\n[bold {_PRIMARY}]{'─' * 44}[/]",
                ]
                record_detail = self._format_record_detail(match.record)
                header_lines.append(record_detail)
                panel.update("\n".join(header_lines))

    def on_input_submitted(self, event: Input.Submitted) -> None:
        """Kick off a recall worker for the submitted query (ignores blanks)."""
        query = event.value.strip()
        if not query:
            return
        if self._memory is None:
            panel = self.query_one("#info-panel", Static)
            panel.update(self._init_error or "No memory loaded. Cannot recall.")
            return
        self.run_worker(self._do_recall(query), exclusive=True)

    async def _do_recall(self, query: str) -> None:
        """Execute a recall query and display results in the OptionList."""
        panel = self.query_one("#info-panel", Static)
        panel.loading = True
        try:
            scope = self._selected_scope if self._selected_scope != "/" else None
            # get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() here is deprecated and can misbehave outside
            # the main thread. The blocking recall runs in the default
            # executor so the UI stays responsive.
            loop = asyncio.get_running_loop()
            matches = await loop.run_in_executor(
                None,
                lambda: self._memory.recall(query, scope=scope, limit=10, depth="deep"),
            )
            self._recall_matches = matches or []
            self._view_mode = "recall"

            if not self._recall_matches:
                panel.update("[dim]No memories found.[/]")
                self.query_one("#entry-list", OptionList).clear_options()
                return

            info_lines: list[str] = []
            info_lines.append(
                "[dim italic]Searched the full dataset"
                + (f" within [bold]{scope}[/]" if scope else "")
                + " using the recall flow (semantic + recency + importance).[/]\n"
            )
            if not self._custom_embedder:
                info_lines.append(
                    "[dim italic]Note: Using default OpenAI embedder. "
                    "If memories were created with a different embedder, "
                    "pass --embedder-provider to match.[/]\n"
                )
            info_lines.append(
                f"[bold]Recall Results[/] [dim]"
                f"({len(self._recall_matches)} matches)[/]\n"
                f"[dim]Navigate the list below to view details.[/]"
            )
            panel.update("\n".join(info_lines))
            panel.border_title = "Recall Detail"
            entry_list = self.query_one("#entry-list", OptionList)
            entry_list.border_title = f"Recall Results ({len(self._recall_matches)})"
            self._populate_recall_list()
        except Exception as e:
            panel.update(f"[bold red]Error:[/] {e}")
        finally:
            panel.loading = False
|
||||
1
lib/cli/src/crewai_cli/organization/__init__.py
Normal file
1
lib/cli/src/crewai_cli/organization/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
107
lib/cli/src/crewai_cli/organization/main.py
Normal file
107
lib/cli/src/crewai_cli/organization/main.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from httpx import HTTPStatusError
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from crewai_cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai_cli.config import Settings
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class OrganizationCommand(BaseCommand, PlusAPIMixin):
    """CLI handler for organization commands: list, switch, and current."""

    def __init__(self) -> None:
        BaseCommand.__init__(self)
        PlusAPIMixin.__init__(self, telemetry=self._telemetry)

    @staticmethod
    def _print_unauthorized() -> None:
        """Tell the user they must authenticate before using org commands."""
        console.print(
            "You are not logged in to any organization. Use 'crewai login' to login.",
            style="bold red",
        )

    @staticmethod
    def _abort(message: str, cause: Exception) -> None:
        """Print an error message and exit with status 1."""
        console.print(message, style="bold red")
        raise SystemExit(1) from cause

    def list(self) -> None:
        """Print a table of every organization the current user belongs to."""
        try:
            response = self.plus_api_client.get_organizations()
            response.raise_for_status()
            organizations = response.json()

            if not organizations:
                console.print(
                    "You don't belong to any organizations yet.", style="yellow"
                )
                return

            table = Table(title="Your Organizations")
            table.add_column("Name", style="cyan")
            table.add_column("ID", style="green")
            for organization in organizations:
                table.add_row(organization["name"], organization["uuid"])

            console.print(table)
        except HTTPStatusError as e:
            # 401 means no session at all; anything else is a hard failure.
            if e.response.status_code == 401:
                self._print_unauthorized()
                return
            self._abort(f"Failed to retrieve organization list: {e!s}", e)
        except Exception as e:
            self._abort(f"Failed to retrieve organization list: {e!s}", e)

    def switch(self, org_id: str) -> None:
        """Switch the active organization to the one with the given id."""
        try:
            response = self.plus_api_client.get_organizations()
            response.raise_for_status()
            organizations = response.json()

            selected = next(
                (item for item in organizations if item["uuid"] == org_id), None
            )
            if not selected:
                console.print(
                    f"Organization with id '{org_id}' not found.", style="bold red"
                )
                return

            # Persist the selection into the CLI settings file.
            settings = Settings()
            settings.org_name = selected["name"]
            settings.org_uuid = selected["uuid"]
            settings.dump()

            console.print(
                f"Successfully switched to {selected['name']} ({selected['uuid']})",
                style="bold green",
            )
        except HTTPStatusError as e:
            if e.response.status_code == 401:
                self._print_unauthorized()
                return
            self._abort(f"Failed to retrieve organization list: {e!s}", e)
        except Exception as e:
            self._abort(f"Failed to switch organization: {e!s}", e)

    def current(self) -> None:
        """Show which organization is currently active, if any."""
        settings = Settings()
        if not settings.org_uuid:
            for hint in (
                "You're not currently logged in to any organization.",
                "Use 'crewai org list' to see available organizations.",
                "Use 'crewai org switch <id>' to switch to an organization.",
            ):
                console.print(hint, style="yellow")
            return
        console.print(
            f"Currently logged in to organization {settings.org_name} ({settings.org_uuid})",
            style="bold green",
        )
|
||||
23
lib/cli/src/crewai_cli/plot_flow.py
Normal file
23
lib/cli/src/crewai_cli/plot_flow.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def plot_flow() -> None:
    """
    Plot the flow by running ``uv run plot`` in the UV environment.

    The subprocess inherits the terminal's stdout/stderr (output is not
    captured), so plot output streams directly to the user. Failures are
    reported on stderr rather than raised.
    """
    command = ["uv", "run", "plot"]

    try:
        subprocess.run(command, capture_output=False, text=True, check=True)  # noqa: S603
        # With capture_output=False, result.stdout/result.stderr are always
        # None, so the previous `if result.stderr:` check was dead code.
    except subprocess.CalledProcessError as e:
        click.echo(f"An error occurred while plotting the flow: {e}", err=True)
        # e.output is None when output is not captured; guard to avoid
        # echoing a spurious blank line.
        if e.output:
            click.echo(e.output, err=True)
    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
12
lib/cli/src/crewai_cli/plus_api.py
Normal file
12
lib/cli/src/crewai_cli/plus_api.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Re-export of ``crewai_core.plus_api.PlusAPI``.
|
||||
|
||||
Kept as a stable import path for the CLI; new code should import from
|
||||
``crewai_core.plus_api`` directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.plus_api import PlusAPI as PlusAPI
|
||||
|
||||
|
||||
__all__ = ["PlusAPI"]
|
||||
231
lib/cli/src/crewai_cli/provider.py
Normal file
231
lib/cli/src/crewai_cli/provider.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import certifi
|
||||
import click
|
||||
import httpx
|
||||
|
||||
from crewai_cli.constants import JSON_URL, MODELS, PROVIDERS
|
||||
|
||||
|
||||
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
    """Presents a list of choices to the user and prompts them to select one.

    Args:
        prompt_message: The message to display to the user before presenting the choices.
        choices: A list of options to present to the user.

    Returns:
        The selected choice from the list, or None if the user chooses to quit.
    """
    # NOTE(review): fetching the provider catalog here looks unrelated to
    # presenting `choices`; it appears to act as an availability guard (bail
    # out when the catalog cannot be loaded) — confirm before removing.
    provider_models = get_provider_data()
    if not provider_models:
        return None
    click.secho(prompt_message, fg="cyan")
    for idx, choice in enumerate(choices, start=1):
        click.secho(f"{idx}. {choice}", fg="cyan")
    click.secho("q. Quit", fg="cyan")

    while True:
        choice = click.prompt(
            "Enter the number of your choice or 'q' to quit", type=str
        )

        if choice.lower() == "q":
            return None

        try:
            selected_index = int(choice) - 1
            if 0 <= selected_index < len(choices):
                return choices[selected_index]
        except ValueError:
            pass

        # Bug fix: the upper bound was hard-coded to 6; report the actual
        # number of available choices instead.
        click.secho(
            f"Invalid selection. Please select a number between 1 and {len(choices)} or 'q' to quit.",
            fg="red",
        )
|
||||
|
||||
|
||||
def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
    """Prompt the user to pick an LLM provider.

    Args:
        provider_models: Mapping of provider name to its model list.

    Returns:
        The selected provider (lowercased), None if the user explicitly
        quits, or False if nothing was selected.
    """
    known_providers = [name.lower() for name in PROVIDERS]

    selection = select_choice(
        "Select a provider to set up:", [*known_providers, "other"]
    )
    if selection is None:  # user typed 'q'
        return None

    if selection == "other":
        # "other" expands to the full catalog: curated providers plus
        # everything discovered in the fetched provider data.
        full_catalog = sorted(set(known_providers) | set(provider_models))
        selection = select_choice(
            "Select a provider from the full list:", full_catalog
        )
        if selection is None:  # user typed 'q'
            return None

    return selection.lower() if selection else False
|
||||
|
||||
|
||||
def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
    """Prompt the user to pick a model for the given provider.

    Args:
        provider: The provider for which to select a model.
        provider_models: Mapping of provider name to its model list.

    Returns:
        The selected model, or None if the operation is aborted or there
        are no models to offer.
    """
    known_providers = {name.lower() for name in PROVIDERS}

    # Predefined providers use the curated MODELS table; everything else
    # falls back to the fetched provider catalog.
    if provider in known_providers:
        candidates = MODELS.get(provider, [])
    else:
        candidates = provider_models.get(provider, [])

    if not candidates:
        click.secho(f"No models available for provider '{provider}'.", fg="red")
        return None

    return select_choice(
        f"Select a model to use for {provider.capitalize()}:", candidates
    )
|
||||
|
||||
|
||||
def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
    """Load provider data from a cache file, falling back to a web fetch.

    If the cache file is missing, expired, or corrupted, the data is
    (re)fetched from the web and re-cached.

    Args:
        cache_file: The path to the cache file.
        cache_expiry: The cache expiry time in seconds.

    Returns:
        The loaded provider data or None if the operation fails.
    """
    now = time.time()
    cache_is_fresh = (
        cache_file.exists() and (now - cache_file.stat().st_mtime) < cache_expiry
    )

    if not cache_is_fresh:
        click.secho(
            "Cache expired or not found. Fetching provider data from the web...",
            fg="cyan",
        )
        return fetch_provider_data(cache_file)

    cached = read_cache_file(cache_file)
    if cached:
        return cached

    # Fresh on disk but unparseable: treat as corrupted and refetch.
    click.secho(
        "Cache is corrupted. Fetching provider data from the web...", fg="yellow"
    )
    return fetch_provider_data(cache_file)
|
||||
|
||||
|
||||
def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
|
||||
"""Reads and returns the JSON content from a cache file.
|
||||
|
||||
Args:
|
||||
cache_file: The path to the cache file.
|
||||
|
||||
Returns:
|
||||
The JSON content of the cache file or None if the JSON is invalid.
|
||||
"""
|
||||
try:
|
||||
with open(cache_file, "r") as f:
|
||||
data: dict[str, Any] = json.load(f)
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
    """Fetches provider data from a specified URL and caches it to a file.

    Also points SSL_CERT_FILE at certifi's CA bundle — a process-wide side
    effect — and uses the same bundle to verify the HTTPS request.

    Args:
        cache_file: The path to the cache file.

    Returns:
        The fetched provider data or None if the operation fails.
    """
    # Chained assignment: set the env var and keep the bundle path for
    # the verify= argument below.
    ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()

    try:
        with httpx.stream("GET", JSON_URL, timeout=60, verify=ssl_config) as response:
            response.raise_for_status()
            payload = download_data(response)
        with open(cache_file, "w") as f:
            json.dump(payload, f)
    except httpx.HTTPError as e:
        click.secho(f"Error fetching provider data: {e}", fg="red")
    except json.JSONDecodeError:
        click.secho("Error parsing provider data. Invalid JSON format.", fg="red")
    else:
        return payload
    return None
|
||||
|
||||
|
||||
def download_data(response: httpx.Response) -> dict[str, Any]:
    """Downloads data from a given HTTP response and returns the JSON content.

    Args:
        response: The HTTP response object (streamed).

    Returns:
        The JSON content of the response.
    """
    # content-length may be absent; 0 gives an indeterminate progress bar.
    expected_bytes = int(response.headers.get("content-length", 0))
    chunk_size = 8192
    chunks: list[bytes] = []
    progress: Any
    with click.progressbar(
        length=expected_bytes, label="Downloading", show_pos=True
    ) as progress:
        for piece in response.iter_bytes(chunk_size):
            if piece:
                chunks.append(piece)
                progress.update(len(piece))
    decoded: dict[str, Any] = json.loads(b"".join(chunks).decode("utf-8"))
    return decoded
|
||||
|
||||
|
||||
def get_provider_data() -> dict[str, list[str]] | None:
    """Retrieves provider data from a cache file.

    Filters out models based on provider criteria, and returns a dictionary
    of providers mapped to their models.

    Returns:
        A dictionary of providers mapped to their models or None if the
        operation fails.
    """
    cache_dir = Path.home() / ".crewai"
    cache_dir.mkdir(exist_ok=True)
    one_day_seconds = 24 * 3600
    catalog = load_provider_data(cache_dir / "provider_cache.json", one_day_seconds)
    if not catalog:
        return None

    grouped: defaultdict[str, list[str]] = defaultdict(list)
    for model_name, properties in catalog.items():
        provider = properties.get("litellm_provider", "").strip().lower()
        # Skip blank, URL-like, and explicitly "other" providers.
        if not provider or "http" in provider or provider == "other":
            continue
        grouped[provider].append(model_name)
    return grouped
|
||||
0
lib/cli/src/crewai_cli/py.typed
Normal file
0
lib/cli/src/crewai_cli/py.typed
Normal file
0
lib/cli/src/crewai_cli/remote_template/__init__.py
Normal file
0
lib/cli/src/crewai_cli/remote_template/__init__.py
Normal file
250
lib/cli/src/crewai_cli/remote_template/main.py
Normal file
250
lib/cli/src/crewai_cli/remote_template/main.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any
|
||||
import zipfile
|
||||
|
||||
import click
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from crewai_cli.command import BaseCommand
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
console = Console()
|
||||
|
||||
GITHUB_ORG = "crewAIInc"
|
||||
TEMPLATE_PREFIX = "template_"
|
||||
GITHUB_API_BASE = "https://api.github.com"
|
||||
|
||||
BANNER = """\
|
||||
[bold white] ██████╗██████╗ ███████╗██╗ ██╗[/bold white] [bold red] █████╗ ██╗[/bold red]
|
||||
[bold white]██╔════╝██╔══██╗██╔════╝██║ ██║[/bold white] [bold red]██╔══██╗██║[/bold red]
|
||||
[bold white]██║ ██████╔╝█████╗ ██║ █╗ ██║[/bold white] [bold red]███████║██║[/bold red]
|
||||
[bold white]██║ ██╔══██╗██╔══╝ ██║███╗██║[/bold white] [bold red]██╔══██║██║[/bold red]
|
||||
[bold white]╚██████╗██║ ██║███████╗╚███╔███╔╝[/bold white] [bold red]██║ ██║██║[/bold red]
|
||||
[bold white] ╚═════╝╚═╝ ╚═╝╚══════╝ ╚══╝╚══╝[/bold white] [bold red]╚═╝ ╚═╝╚═╝[/bold red]
|
||||
[dim white]████████╗███████╗███╗ ███╗██████╗ ██╗ █████╗ ████████╗███████╗███████╗[/dim white]
|
||||
[dim white]╚══██╔══╝██╔════╝████╗ ████║██╔══██╗██║ ██╔══██╗╚══██╔══╝██╔════╝██╔════╝[/dim white]
|
||||
[dim white] ██║ █████╗ ██╔████╔██║██████╔╝██║ ███████║ ██║ █████╗ ███████╗[/dim white]
|
||||
[dim white] ██║ ██╔══╝ ██║╚██╔╝██║██╔═══╝ ██║ ██╔══██║ ██║ ██╔══╝ ╚════██║[/dim white]
|
||||
[dim white] ██║ ███████╗██║ ╚═╝ ██║██║ ███████╗██║ ██║ ██║ ███████╗███████║[/dim white]
|
||||
[dim white] ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚══════╝[/dim white]"""
|
||||
|
||||
|
||||
class TemplateCommand(BaseCommand):
    """Handle template-related operations for CrewAI projects.

    Templates are public GitHub repositories in the crewAIInc org whose
    names start with ``template_``. They are downloaded as zipballs and
    extracted into the user's current working directory.
    """

    def __init__(self) -> None:
        super().__init__()

    def list_templates(self) -> None:
        """List available templates with an interactive selector to install."""
        templates = self._fetch_templates()
        if not templates:
            click.echo("No templates found.")
            return

        console.print(f"\n{BANNER}\n")
        console.print(" [on cyan] templates [/on cyan]\n")
        console.print(f" [green]o[/green] Source: https://github.com/{GITHUB_ORG}")
        console.print(
            f" [green]o[/green] Found [bold]{len(templates)}[/bold] templates\n"
        )
        console.print(" [green]o[/green] Select a template to install")

        # Numbered menu; GitHub repo descriptions are optional.
        for idx, repo in enumerate(templates, start=1):
            name = repo["name"].removeprefix(TEMPLATE_PREFIX)
            description = repo.get("description") or ""
            if description:
                console.print(
                    f" [bold cyan]{idx}.[/bold cyan] [bold white]{name}[/bold white] [dim]({description})[/dim]"
                )
            else:
                console.print(
                    f" [bold cyan]{idx}.[/bold cyan] [bold white]{name}[/bold white]"
                )

        console.print(" [bold cyan]q.[/bold cyan] [dim]Quit[/dim]\n")

        # Re-prompt until the user picks a valid index or quits.
        while True:
            choice = click.prompt("Enter your choice", type=str)

            if choice.lower() == "q":
                return

            if choice.isdigit() and 1 <= int(choice) <= len(templates):
                selected_index = int(choice) - 1
                break

            click.secho(
                f"Please enter a number between 1 and {len(templates)}, or 'q' to quit.",
                fg="yellow",
            )

        selected = templates[selected_index]
        repo_name = selected["name"]
        self._install_repo(repo_name)

    def add_template(self, name: str, output_dir: str | None = None) -> None:
        """Download a template and copy it into the current working directory.

        Args:
            name: Template name (with or without the template_ prefix).
            output_dir: Optional directory name. Defaults to the template name.

        Raises:
            SystemExit: If no template with the given name exists.
        """
        repo_name = self._resolve_repo_name(name)
        if repo_name is None:
            click.secho(f"Template '{name}' not found.", fg="red")
            click.echo("Run 'crewai template list' to see available templates.")
            raise SystemExit(1)

        self._install_repo(repo_name, output_dir)

    def _install_repo(self, repo_name: str, output_dir: str | None = None) -> None:
        """Download and extract a template repo into the current directory.

        Args:
            repo_name: Full GitHub repo name (e.g. template_deep_research).
            output_dir: Optional directory name. Defaults to the template name.
        """
        folder_name = output_dir or repo_name.removeprefix(TEMPLATE_PREFIX)
        dest = os.path.join(os.getcwd(), folder_name)

        # Never overwrite an existing directory; keep asking for a new name.
        while os.path.exists(dest):
            click.secho(f"Directory '{folder_name}' already exists.", fg="yellow")
            folder_name = click.prompt(
                "Enter a different directory name (or 'q' to quit)", type=str
            )
            if folder_name.lower() == "q":
                return
            dest = os.path.join(os.getcwd(), folder_name)

        click.echo(
            f"Downloading template '{repo_name.removeprefix(TEMPLATE_PREFIX)}'..."
        )

        zip_bytes = self._download_zip(repo_name)
        self._extract_zip(zip_bytes, dest)

        # Record the install (template name without prefix) for telemetry.
        self._telemetry.template_installed_span(repo_name.removeprefix(TEMPLATE_PREFIX))

        console.print(
            f"\n [green]\u2713[/green] Installed template [bold white]{folder_name}[/bold white]"
            f" [dim](source: github.com/{GITHUB_ORG}/{repo_name})[/dim]\n"
        )

        next_steps = Text()
        next_steps.append(f" cd {folder_name}\n", style="bold white")
        next_steps.append(" crewai install", style="bold white")

        panel = Panel(
            next_steps,
            title="[green]\u25c7 Next steps[/green]",
            title_align="left",
            border_style="dim",
            padding=(1, 2),
        )
        console.print(panel)

    def _fetch_templates(self) -> list[dict[str, Any]]:
        """Fetch all template repos from the GitHub org.

        Pages through the GitHub org-repos API (100 per page) and keeps
        only public repos whose names start with ``template_``. Exits with
        status 1 on any HTTP failure.
        """
        templates: list[dict[str, Any]] = []
        page = 1
        while True:
            url = f"{GITHUB_API_BASE}/orgs/{GITHUB_ORG}/repos"
            params: dict[str, str | int] = {
                "per_page": 100,
                "page": page,
                "type": "public",
            }
            try:
                response = httpx.get(url, params=params, timeout=15)
                response.raise_for_status()
            except httpx.HTTPError as e:
                click.secho(f"Failed to fetch templates from GitHub: {e}", fg="red")
                raise SystemExit(1) from e

            repos = response.json()
            if not repos:
                # An empty page marks the end of pagination.
                break

            templates.extend(
                repo
                for repo in repos
                if repo["name"].startswith(TEMPLATE_PREFIX) and not repo.get("private")
            )

            page += 1

        templates.sort(key=lambda r: r["name"])
        return templates

    def _resolve_repo_name(self, name: str) -> str | None:
        """Resolve user input to a full repo name, or None if not found."""
        # Accept both 'deep_research' and 'template_deep_research'
        candidates = [
            f"{TEMPLATE_PREFIX}{name}"
            if not name.startswith(TEMPLATE_PREFIX)
            else name,
            name,
        ]

        templates = self._fetch_templates()
        template_names = {t["name"] for t in templates}

        for candidate in candidates:
            if candidate in template_names:
                return candidate

        return None

    def _download_zip(self, repo_name: str) -> bytes:
        """Download the default branch zipball for a repo.

        Raises:
            SystemExit: If the download fails for any HTTP-level reason.
        """
        url = f"{GITHUB_API_BASE}/repos/{GITHUB_ORG}/{repo_name}/zipball"
        try:
            response = httpx.get(url, follow_redirects=True, timeout=60)
            response.raise_for_status()
        except httpx.HTTPError as e:
            click.secho(f"Failed to download template: {e}", fg="red")
            raise SystemExit(1) from e

        return response.content

    def _extract_zip(self, zip_bytes: bytes, dest: str) -> None:
        """Extract a GitHub zipball into dest, stripping the top-level directory."""
        with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
            # GitHub zipballs have a single top-level dir like 'crewAIInc-template_xxx-<sha>/'
            members = zf.namelist()
            if not members:
                click.secho("Downloaded archive is empty.", fg="red")
                raise SystemExit(1)

            top_dir = members[0].split("/")[0] + "/"

            os.makedirs(dest, exist_ok=True)

            for member in members:
                if member == top_dir or not member.startswith(top_dir):
                    continue

                relative_path = member[len(top_dir) :]
                if not relative_path:
                    continue

                # Zip-slip guard: resolve the target path and skip any entry
                # that would escape dest.
                target = os.path.realpath(os.path.join(dest, relative_path))
                if not target.startswith(
                    os.path.realpath(dest) + os.sep
                ) and target != os.path.realpath(dest):
                    continue

                if member.endswith("/"):
                    os.makedirs(target, exist_ok=True)
                else:
                    os.makedirs(os.path.dirname(target), exist_ok=True)
                    with zf.open(member) as src, open(target, "wb") as dst:
                        shutil.copyfileobj(src, dst)
|
||||
34
lib/cli/src/crewai_cli/replay_from_task.py
Normal file
34
lib/cli/src/crewai_cli/replay_from_task.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
from crewai_core.constants import CREWAI_TRAINED_AGENTS_FILE_ENV
|
||||
|
||||
from crewai_cli.utils import build_env_with_all_tool_credentials
|
||||
|
||||
|
||||
def replay_task_command(task_id: str, trained_agents_file: str | None = None) -> None:
    """Replay the crew execution from a specific task.

    Args:
        task_id: The ID of the task to replay from.
        trained_agents_file: Optional trained-agents pickle path forwarded to
            the subprocess via the ``CREWAI_TRAINED_AGENTS_FILE`` env var.
    """
    command = ["uv", "run", "replay", task_id]
    env = build_env_with_all_tool_credentials()
    if trained_agents_file:
        env[CREWAI_TRAINED_AGENTS_FILE_ENV] = trained_agents_file

    try:
        subprocess.run(  # noqa: S603
            command, capture_output=False, text=True, check=True, env=env
        )
        # With capture_output=False, result.stdout/result.stderr are always
        # None, so the previous `if result.stderr:` check was dead code.
    except subprocess.CalledProcessError as e:
        click.echo(f"An error occurred while replaying the task: {e}", err=True)
        # e.output is None when output is not captured; guard to avoid
        # echoing a spurious blank line.
        if e.output:
            click.echo(e.output, err=True)
    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
31
lib/cli/src/crewai_cli/reset_memories_command.py
Normal file
31
lib/cli/src/crewai_cli/reset_memories_command.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Wrapper for the reset-memories command.
|
||||
|
||||
Delegates to ``crewai.utilities.reset_memories`` when the full crewai
|
||||
package is installed, otherwise prints a helpful error message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def reset_memories_command(
    memory: bool,
    knowledge: bool,
    agent_knowledge: bool,
    kickoff_outputs: bool,
    all: bool,
) -> None:
    """Delegate memory resets to ``crewai.utilities.reset_memories``.

    Prints an installation hint and exits with status 1 when the full
    crewai package is not installed.
    """
    try:
        import crewai.utilities.reset_memories as _impl
    except ImportError:
        # crewai is an optional dependency of the CLI; without it this
        # command cannot do anything useful.
        click.secho(
            "The 'reset-memories' command requires the full crewai package.\n"
            "Install it with: pip install crewai",
            fg="red",
        )
        raise SystemExit(1) from None

    _impl.reset_memories_command(
        memory, knowledge, agent_knowledge, kickoff_outputs, all
    )
|
||||
101
lib/cli/src/crewai_cli/run_crew.py
Normal file
101
lib/cli/src/crewai_cli/run_crew.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from enum import Enum
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
from crewai_core.constants import CREWAI_TRAINED_AGENTS_FILE_ENV
|
||||
from packaging import version
|
||||
|
||||
from crewai_cli.utils import build_env_with_all_tool_credentials, read_toml
|
||||
from crewai_cli.version import get_crewai_version
|
||||
|
||||
|
||||
class CrewType(Enum):
    """Kind of project being run: a standard crew or a flow."""

    STANDARD = "standard"
    FLOW = "flow"
|
||||
|
||||
|
||||
def run_crew(trained_agents_file: str | None = None) -> None:
    """Run the crew or flow by running a command in the UV environment.

    Starting from version 0.103.0, this command can be used to run both
    standard crews and flows. For flows, it detects the type from
    pyproject.toml and automatically runs the appropriate command.

    Args:
        trained_agents_file: Optional path to a trained-agents pickle produced
            by ``crewai train -f``. When set, exported as
            ``CREWAI_TRAINED_AGENTS_FILE`` so agents load suggestions from this
            file instead of the default ``trained_agents_data.pkl``.
    """
    installed_version = get_crewai_version()
    pyproject_data = read_toml()

    # Warn about legacy poetry-based projects on old crewAI versions.
    uses_poetry = bool(pyproject_data.get("tool", {}).get("poetry"))
    if uses_poetry and version.parse(installed_version) < version.parse("0.71.0"):
        click.secho(
            f"You are running an older version of crewAI ({installed_version}) that uses poetry pyproject.toml. "
            f"Please run `crewai update` to update your pyproject.toml to use uv.",
            fg="red",
        )

    # Flow projects declare [tool.crewai] type = "flow" in pyproject.toml.
    crewai_config = pyproject_data.get("tool", {}).get("crewai", {})
    if crewai_config.get("type") == "flow":
        crew_type = CrewType.FLOW
        click.echo("Running the Flow")
    else:
        crew_type = CrewType.STANDARD
        click.echo("Running the Crew")

    execute_command(crew_type, trained_agents_file=trained_agents_file)
|
||||
|
||||
|
||||
def execute_command(
    crew_type: CrewType, trained_agents_file: str | None = None
) -> None:
    """Execute the appropriate command based on crew type.

    Args:
        crew_type: The type of crew to run.
        trained_agents_file: Optional trained-agents pickle path forwarded to
            the subprocess via the ``CREWAI_TRAINED_AGENTS_FILE`` env var.
    """
    # Flows expose a `kickoff` script; standard crews expose `run_crew`.
    entrypoint = "kickoff" if crew_type == CrewType.FLOW else "run_crew"

    env = build_env_with_all_tool_credentials()
    if trained_agents_file:
        env[CREWAI_TRAINED_AGENTS_FILE_ENV] = trained_agents_file

    try:
        subprocess.run(  # noqa: S603
            ["uv", "run", entrypoint],
            capture_output=False,
            text=True,
            check=True,
            env=env,
        )
    except subprocess.CalledProcessError as e:
        handle_error(e, crew_type)
    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
|
||||
|
||||
def handle_error(error: subprocess.CalledProcessError, crew_type: CrewType) -> None:
    """Handle subprocess errors with appropriate messaging.

    Args:
        error: The subprocess error that occurred
        crew_type: The type of crew that was being run
    """
    unit = "flow" if crew_type == CrewType.FLOW else "crew"
    click.echo(f"An error occurred while running the {unit}: {error}", err=True)

    if error.output:
        click.echo(error.output, err=True, nl=True)

    # A poetry section in pyproject.toml usually means an outdated project
    # layout, which is a common cause of run failures.
    if read_toml().get("tool", {}).get("poetry"):
        click.secho(
            "It's possible that you are using an old version of crewAI that uses poetry, "
            "please run `crewai update` to update your pyproject.toml to use uv.",
            fg="yellow",
        )
|
||||
0
lib/cli/src/crewai_cli/settings/__init__.py
Normal file
0
lib/cli/src/crewai_cli/settings/__init__.py
Normal file
110
lib/cli/src/crewai_cli/settings/main.py
Normal file
110
lib/cli/src/crewai_cli/settings/main.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from datetime import datetime
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from crewai_cli.command import BaseCommand
|
||||
from crewai_cli.config import HIDDEN_SETTINGS_KEYS, READONLY_SETTINGS_KEYS, Settings
|
||||
from crewai_cli.user_data import _load_user_data
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class SettingsCommand(BaseCommand):
    """A class to handle CLI configuration commands."""

    def __init__(self, settings_kwargs: dict[str, Any] | None = None):
        super().__init__()
        self.settings = Settings(**(settings_kwargs or {}))

    def list(self) -> None:
        """List all CLI configuration parameters."""
        table = Table(title="CrewAI CLI Configuration")
        table.add_column("Setting", style="cyan", no_wrap=True)
        table.add_column("Value", style="green")
        table.add_column("Description", style="yellow")

        # One row per settings field; hidden keys are never displayed.
        for field_name, field_info in Settings.model_fields.items():
            if field_name in HIDDEN_SETTINGS_KEYS:
                continue
            value = getattr(self.settings, field_name)
            rendered = "Not set" if value in [None, {}] else str(value)
            table.add_row(
                field_name,
                rendered,
                field_info.description or "No description available",
            )

        self._add_trace_rows(table)
        console.print(table)

    @staticmethod
    def _add_trace_rows(table: Table) -> None:
        """Append trace-related rows sourced from env vars and user data."""
        user_data = _load_user_data()

        env_tracing = os.getenv("CREWAI_TRACING_ENABLED", "")
        table.add_row(
            "CREWAI_TRACING_ENABLED",
            env_tracing or "Not set",
            "Environment variable to enable/disable tracing",
        )

        # trace_consent is tri-state: True, False, or unset.
        trace_consent = user_data.get("trace_consent")
        if trace_consent is True:
            consent_display = "✅ Enabled"
        elif trace_consent is False:
            consent_display = "❌ Disabled"
        else:
            consent_display = "Not set"
        table.add_row(
            "trace_consent", consent_display, "Whether trace collection is enabled"
        )

        first_execution_at = user_data.get("first_execution_at")
        if first_execution_at:
            first_exec_display = datetime.fromtimestamp(first_execution_at).strftime(
                "%Y-%m-%d %H:%M:%S"
            )
        else:
            first_exec_display = "Not set"
        table.add_row(
            "first_execution_at",
            first_exec_display,
            "Timestamp of first crew/flow execution",
        )

    def set(self, key: str, value: str) -> None:
        """Set a CLI configuration parameter."""
        protected = READONLY_SETTINGS_KEYS + HIDDEN_SETTINGS_KEYS

        if key in protected or not hasattr(self.settings, key):
            console.print(
                f"Error: Unknown or readonly configuration key '{key}'",
                style="bold red",
            )
            console.print("Available keys:", style="yellow")
            for field_name in Settings.model_fields:
                if field_name not in protected:
                    console.print(f" - {field_name}", style="yellow")
            raise SystemExit(1)

        setattr(self.settings, key, value)
        self.settings.dump()

        console.print(f"Successfully set '{key}' to '{value}'", style="bold green")

    def reset_all_settings(self) -> None:
        """Reset all CLI configuration parameters to default values."""
        self.settings.reset()
        console.print(
            "Successfully reset all configuration parameters to default values. It is recommended to run [bold yellow]'crewai login'[/bold yellow] to re-authenticate.",
            style="bold green",
        )
|
||||
0
lib/cli/src/crewai_cli/shared/__init__.py
Normal file
0
lib/cli/src/crewai_cli/shared/__init__.py
Normal file
12
lib/cli/src/crewai_cli/shared/token_manager.py
Normal file
12
lib/cli/src/crewai_cli/shared/token_manager.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Re-export of ``crewai_core.token_manager.TokenManager``.
|
||||
|
||||
Kept as a stable import path for the CLI; new code should import from
|
||||
``crewai_core.token_manager`` directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.token_manager import TokenManager as TokenManager
|
||||
|
||||
|
||||
__all__ = ["TokenManager"]
|
||||
67
lib/cli/src/crewai_cli/task_outputs.py
Normal file
67
lib/cli/src/crewai_cli/task_outputs.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Lightweight SQLite reader for kickoff task outputs.
|
||||
|
||||
Only used by the ``crewai log-tasks-outputs`` CLI command. Depends solely on
|
||||
the standard library + *appdirs* so crewai-cli can read stored outputs without
|
||||
importing the full crewai framework.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
from crewai_cli.user_data import _db_storage_path
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_task_outputs(db_path: str | None = None) -> list[dict[str, Any]]:
    """Return all rows from the kickoff task outputs database.

    Args:
        db_path: Optional explicit path to the SQLite file. When omitted,
            ``latest_kickoff_task_outputs.db`` under the CLI storage
            directory is used.

    Returns:
        One dict per stored task output, ordered by ``task_index``. A
        missing database file or any SQLite error yields an empty list.
    """
    target = (
        Path(db_path)
        if db_path is not None
        else Path(_db_storage_path()) / "latest_kickoff_task_outputs.db"
    )

    if not target.exists():
        return []

    query = """
        SELECT task_id, expected_output, output, task_index,
               inputs, was_replayed, timestamp
        FROM latest_kickoff_task_outputs
        ORDER BY task_index
    """
    try:
        with sqlite3.connect(str(target)) as conn:
            # Row factory gives name-based column access below.
            conn.row_factory = sqlite3.Row
            rows = conn.execute(query).fetchall()
    except sqlite3.Error as e:
        logger.error("Failed to load task outputs: %s", e)
        return []

    results: list[dict[str, Any]] = []
    for row in rows:
        results.append(
            {
                "task_id": row["task_id"],
                "expected_output": row["expected_output"],
                # JSON columns are decoded tolerantly: corrupt data -> None.
                "output": _safe_json_loads(row["output"]),
                "task_index": row["task_index"],
                "inputs": _safe_json_loads(row["inputs"]),
                "was_replayed": row["was_replayed"],
                "timestamp": row["timestamp"],
            }
        )
    return results
|
||||
def _safe_json_loads(value: str | None) -> Any:
|
||||
"""Decode a JSON column tolerantly: NULL/blank/corrupt → None."""
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning("Failed to decode JSON column: %s", e)
|
||||
return None
|
||||
1017
lib/cli/src/crewai_cli/templates/AGENTS.md
Normal file
1017
lib/cli/src/crewai_cli/templates/AGENTS.md
Normal file
File diff suppressed because it is too large
Load Diff
0
lib/cli/src/crewai_cli/templates/__init__.py
Normal file
0
lib/cli/src/crewai_cli/templates/__init__.py
Normal file
3
lib/cli/src/crewai_cli/templates/crew/.gitignore
vendored
Normal file
3
lib/cli/src/crewai_cli/templates/crew/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
.env
|
||||
__pycache__/
|
||||
.DS_Store
|
||||
54
lib/cli/src/crewai_cli/templates/crew/README.md
Normal file
54
lib/cli/src/crewai_cli/templates/crew/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# {{crew_name}} Crew
|
||||
|
||||
Welcome to the {{crew_name}} Crew project, powered by [crewAI](https://crewai.com). This template is designed to help you set up a multi-agent AI system with ease, leveraging the powerful and flexible framework provided by crewAI. Our goal is to enable your agents to collaborate effectively on complex tasks, maximizing their collective intelligence and capabilities.
|
||||
|
||||
## Installation
|
||||
|
||||
Ensure you have Python >=3.10 <3.14 installed on your system. This project uses [UV](https://docs.astral.sh/uv/) for dependency management and package handling, offering a seamless setup and execution experience.
|
||||
|
||||
First, if you haven't already, install uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Next, navigate to your project directory and install the dependencies:
|
||||
|
||||
(Optional) Lock the dependencies and install them by using the CLI command:
|
||||
```bash
|
||||
crewai install
|
||||
```
|
||||
### Customizing
|
||||
|
||||
**Add your `OPENAI_API_KEY` into the `.env` file**
|
||||
|
||||
- Modify `src/{{folder_name}}/config/agents.yaml` to define your agents
|
||||
- Modify `src/{{folder_name}}/config/tasks.yaml` to define your tasks
|
||||
- Modify `src/{{folder_name}}/crew.py` to add your own logic, tools and specific args
|
||||
- Modify `src/{{folder_name}}/main.py` to add custom inputs for your agents and tasks
|
||||
|
||||
## Running the Project
|
||||
|
||||
To kickstart your crew of AI agents and begin task execution, run this from the root folder of your project:
|
||||
|
||||
```bash
|
||||
$ crewai run
|
||||
```
|
||||
|
||||
This command initializes the {{name}} Crew, assembling the agents and assigning them tasks as defined in your configuration.
|
||||
|
||||
This example, unmodified, will create a `report.md` file in the root folder with the output of research on LLMs.
|
||||
|
||||
## Understanding Your Crew
|
||||
|
||||
The {{name}} Crew is composed of multiple AI agents, each with unique roles, goals, and tools. These agents collaborate on a series of tasks, defined in `config/tasks.yaml`, leveraging their collective skills to achieve complex objectives. The `config/agents.yaml` file outlines the capabilities and configurations of each agent in your crew.
|
||||
|
||||
## Support
|
||||
|
||||
For support, questions, or feedback regarding the {{crew_name}} Crew or crewAI:
|
||||
- Visit our [documentation](https://docs.crewai.com)
|
||||
- Reach out to us through our [GitHub repository](https://github.com/joaomdmoura/crewai)
|
||||
- [Join our Discord](https://discord.com/invite/X4JWnZnxPb)
|
||||
- [Chat with our docs](https://chatg.pt/DWjSBZn)
|
||||
|
||||
Let's create wonders together with the power and simplicity of crewAI.
|
||||
0
lib/cli/src/crewai_cli/templates/crew/__init__.py
Normal file
0
lib/cli/src/crewai_cli/templates/crew/__init__.py
Normal file
19
lib/cli/src/crewai_cli/templates/crew/config/agents.yaml
Normal file
19
lib/cli/src/crewai_cli/templates/crew/config/agents.yaml
Normal file
@@ -0,0 +1,19 @@
|
||||
researcher:
|
||||
role: >
|
||||
{topic} Senior Data Researcher
|
||||
goal: >
|
||||
Uncover cutting-edge developments in {topic}
|
||||
backstory: >
|
||||
You're a seasoned researcher with a knack for uncovering the latest
|
||||
developments in {topic}. Known for your ability to find the most relevant
|
||||
information and present it in a clear and concise manner.
|
||||
|
||||
reporting_analyst:
|
||||
role: >
|
||||
{topic} Reporting Analyst
|
||||
goal: >
|
||||
Create detailed reports based on {topic} data analysis and research findings
|
||||
backstory: >
|
||||
You're a meticulous analyst with a keen eye for detail. You're known for
|
||||
your ability to turn complex data into clear and concise reports, making
|
||||
it easy for others to understand and act on the information you provide.
|
||||
17
lib/cli/src/crewai_cli/templates/crew/config/tasks.yaml
Normal file
17
lib/cli/src/crewai_cli/templates/crew/config/tasks.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
research_task:
|
||||
description: >
|
||||
Conduct a thorough research about {topic}
|
||||
Make sure you find any interesting and relevant information given
|
||||
the current year is {current_year}.
|
||||
expected_output: >
|
||||
A list with 10 bullet points of the most relevant information about {topic}
|
||||
agent: researcher
|
||||
|
||||
reporting_task:
|
||||
description: >
|
||||
Review the context you got and expand each topic into a full section for a report.
|
||||
Make sure the report is detailed and contains any and all relevant information.
|
||||
expected_output: >
|
||||
A fully fledged report with the main topics, each with a full section of information.
|
||||
Formatted as markdown without '```'
|
||||
agent: reporting_analyst
|
||||
63
lib/cli/src/crewai_cli/templates/crew/crew.py
Normal file
63
lib/cli/src/crewai_cli/templates/crew/crew.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from crewai import Agent, Crew, Process, Task
|
||||
from crewai.project import CrewBase, agent, crew, task
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
# If you want to run a snippet of code before or after the crew starts,
|
||||
# you can use the @before_kickoff and @after_kickoff decorators
|
||||
# https://docs.crewai.com/concepts/crews#example-crew-class-with-decorators
|
||||
|
||||
@CrewBase
|
||||
class {{crew_name}}():
|
||||
"""{{crew_name}} crew"""
|
||||
|
||||
agents: list[BaseAgent]
|
||||
tasks: list[Task]
|
||||
|
||||
# Learn more about YAML configuration files here:
|
||||
# Agents: https://docs.crewai.com/concepts/agents#yaml-configuration-recommended
|
||||
# Tasks: https://docs.crewai.com/concepts/tasks#yaml-configuration-recommended
|
||||
|
||||
# If you would like to add tools to your agents, you can learn more about it here:
|
||||
# https://docs.crewai.com/concepts/agents#agent-tools
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['researcher'], # type: ignore[index]
|
||||
verbose=True
|
||||
)
|
||||
|
||||
@agent
|
||||
def reporting_analyst(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config['reporting_analyst'], # type: ignore[index]
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# To learn more about structured task outputs,
|
||||
# task dependencies, and task callbacks, check out the documentation:
|
||||
# https://docs.crewai.com/concepts/tasks#overview-of-a-task
|
||||
@task
|
||||
def research_task(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['research_task'], # type: ignore[index]
|
||||
)
|
||||
|
||||
@task
|
||||
def reporting_task(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config['reporting_task'], # type: ignore[index]
|
||||
output_file='report.md'
|
||||
)
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
"""Creates the {{crew_name}} crew"""
|
||||
# To learn how to add knowledge sources to your crew, check out the documentation:
|
||||
# https://docs.crewai.com/concepts/knowledge#what-is-knowledge
|
||||
|
||||
return Crew(
|
||||
agents=self.agents, # Automatically created by the @agent decorator
|
||||
tasks=self.tasks, # Automatically created by the @task decorator
|
||||
process=Process.sequential,
|
||||
verbose=True,
|
||||
# process=Process.hierarchical, # In case you wanna use that instead https://docs.crewai.com/how-to/Hierarchical/
|
||||
)
|
||||
@@ -0,0 +1,4 @@
|
||||
User name is John Doe.
|
||||
User is an AI Engineer.
|
||||
User is interested in AI Agents.
|
||||
User is based in San Francisco, California.
|
||||
94
lib/cli/src/crewai_cli/templates/crew/main.py
Normal file
94
lib/cli/src/crewai_cli/templates/crew/main.py
Normal file
@@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from {{folder_name}}.crew import {{crew_name}}
|
||||
|
||||
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||
|
||||
# This main file is intended to be a way for you to run your
|
||||
# crew locally, so refrain from adding unnecessary logic into this file.
|
||||
# Replace with inputs you want to test with, it will automatically
|
||||
# interpolate any tasks and agents information
|
||||
|
||||
def run():
|
||||
"""
|
||||
Run the crew.
|
||||
"""
|
||||
inputs = {
|
||||
'topic': 'AI LLMs',
|
||||
'current_year': str(datetime.now().year)
|
||||
}
|
||||
|
||||
try:
|
||||
{{crew_name}}().crew().kickoff(inputs=inputs)
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while running the crew: {e}")
|
||||
|
||||
|
||||
def train():
|
||||
"""
|
||||
Train the crew for a given number of iterations.
|
||||
"""
|
||||
inputs = {
|
||||
"topic": "AI LLMs",
|
||||
'current_year': str(datetime.now().year)
|
||||
}
|
||||
try:
|
||||
{{crew_name}}().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while training the crew: {e}")
|
||||
|
||||
def replay():
|
||||
"""
|
||||
Replay the crew execution from a specific task.
|
||||
"""
|
||||
try:
|
||||
{{crew_name}}().crew().replay(task_id=sys.argv[1])
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while replaying the crew: {e}")
|
||||
|
||||
def test():
|
||||
"""
|
||||
Test the crew execution and returns the results.
|
||||
"""
|
||||
inputs = {
|
||||
"topic": "AI LLMs",
|
||||
"current_year": str(datetime.now().year)
|
||||
}
|
||||
|
||||
try:
|
||||
{{crew_name}}().crew().test(n_iterations=int(sys.argv[1]), eval_llm=sys.argv[2], inputs=inputs)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while testing the crew: {e}")
|
||||
|
||||
def run_with_trigger():
|
||||
"""
|
||||
Run the crew with trigger payload.
|
||||
"""
|
||||
import json
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
raise Exception("No trigger payload provided. Please provide JSON payload as argument.")
|
||||
|
||||
try:
|
||||
trigger_payload = json.loads(sys.argv[1])
|
||||
except json.JSONDecodeError:
|
||||
raise Exception("Invalid JSON payload provided as argument")
|
||||
|
||||
inputs = {
|
||||
"crewai_trigger_payload": trigger_payload,
|
||||
"topic": "",
|
||||
"current_year": ""
|
||||
}
|
||||
|
||||
try:
|
||||
result = {{crew_name}}().crew().kickoff(inputs=inputs)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while running the crew with trigger: {e}")
|
||||
24
lib/cli/src/crewai_cli/templates/crew/pyproject.toml
Normal file
24
lib/cli/src/crewai_cli/templates/crew/pyproject.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[project]
|
||||
name = "{{folder_name}}"
|
||||
version = "0.1.0"
|
||||
description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.14.5a2"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
{{folder_name}} = "{{folder_name}}.main:run"
|
||||
run_crew = "{{folder_name}}.main:run"
|
||||
train = "{{folder_name}}.main:train"
|
||||
replay = "{{folder_name}}.main:replay"
|
||||
test = "{{folder_name}}.main:test"
|
||||
run_with_trigger = "{{folder_name}}.main:run_with_trigger"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.crewai]
|
||||
type = "crew"
|
||||
19
lib/cli/src/crewai_cli/templates/crew/tools/custom_tool.py
Normal file
19
lib/cli/src/crewai_cli/templates/crew/tools/custom_tool.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from crewai.tools import BaseTool
|
||||
from typing import Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MyCustomToolInput(BaseModel):
|
||||
"""Input schema for MyCustomTool."""
|
||||
argument: str = Field(..., description="Description of the argument.")
|
||||
|
||||
class MyCustomTool(BaseTool):
|
||||
name: str = "Name of my tool"
|
||||
description: str = (
|
||||
"Clear description for what this tool is useful for, your agent will need this information to use it."
|
||||
)
|
||||
args_schema: Type[BaseModel] = MyCustomToolInput
|
||||
|
||||
def _run(self, argument: str) -> str:
|
||||
# Implementation goes here
|
||||
return "this is an example of a tool output, ignore it and move along."
|
||||
4
lib/cli/src/crewai_cli/templates/flow/.gitignore
vendored
Normal file
4
lib/cli/src/crewai_cli/templates/flow/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
.env
|
||||
__pycache__/
|
||||
lib/
|
||||
.DS_Store
|
||||
56
lib/cli/src/crewai_cli/templates/flow/README.md
Normal file
56
lib/cli/src/crewai_cli/templates/flow/README.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# {{crew_name}} Crew
|
||||
|
||||
Welcome to the {{crew_name}} Crew project, powered by [crewAI](https://crewai.com). This template is designed to help you set up a multi-agent AI system with ease, leveraging the powerful and flexible framework provided by crewAI. Our goal is to enable your agents to collaborate effectively on complex tasks, maximizing their collective intelligence and capabilities.
|
||||
|
||||
## Installation
|
||||
|
||||
Ensure you have Python >=3.10 <3.14 installed on your system. This project uses [UV](https://docs.astral.sh/uv/) for dependency management and package handling, offering a seamless setup and execution experience.
|
||||
|
||||
First, if you haven't already, install uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Next, navigate to your project directory and install the dependencies:
|
||||
|
||||
(Optional) Lock the dependencies and install them by using the CLI command:
|
||||
```bash
|
||||
crewai install
|
||||
```
|
||||
|
||||
### Customizing
|
||||
|
||||
**Add your `OPENAI_API_KEY` into the `.env` file**
|
||||
|
||||
- Modify `src/{{folder_name}}/config/agents.yaml` to define your agents
|
||||
- Modify `src/{{folder_name}}/config/tasks.yaml` to define your tasks
|
||||
- Modify `src/{{folder_name}}/crew.py` to add your own logic, tools and specific args
|
||||
- Modify `src/{{folder_name}}/main.py` to add custom inputs for your agents and tasks
|
||||
|
||||
## Running the Project
|
||||
|
||||
To kickstart your flow and begin execution, run this from the root folder of your project:
|
||||
|
||||
```bash
|
||||
crewai run
|
||||
```
|
||||
|
||||
This command initializes the {{name}} Flow as defined in your configuration.
|
||||
|
||||
This example, unmodified, will run a content creation flow on AI Agents and save the output to `output/post.md`.
|
||||
|
||||
## Understanding Your Crew
|
||||
|
||||
The {{name}} Crew is composed of multiple AI agents, each with unique roles, goals, and tools. These agents collaborate on a series of tasks, defined in `config/tasks.yaml`, leveraging their collective skills to achieve complex objectives. The `config/agents.yaml` file outlines the capabilities and configurations of each agent in your crew.
|
||||
|
||||
## Support
|
||||
|
||||
For support, questions, or feedback regarding the {{crew_name}} Crew or crewAI:
|
||||
|
||||
- Visit our [documentation](https://docs.crewai.com)
|
||||
- Reach out to us through our [GitHub repository](https://github.com/joaomdmoura/crewai)
|
||||
- [Join our Discord](https://discord.com/invite/X4JWnZnxPb)
|
||||
- [Chat with our docs](https://chatg.pt/DWjSBZn)
|
||||
|
||||
Let's create wonders together with the power and simplicity of crewAI.
|
||||
0
lib/cli/src/crewai_cli/templates/flow/__init__.py
Normal file
0
lib/cli/src/crewai_cli/templates/flow/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
planner:
|
||||
role: >
|
||||
Content Planner
|
||||
goal: >
|
||||
Plan a detailed and engaging blog post outline on {topic}
|
||||
backstory: >
|
||||
You're an experienced content strategist who excels at creating
|
||||
structured outlines for blog posts. You know how to organize ideas
|
||||
into a logical flow that keeps readers engaged from start to finish.
|
||||
|
||||
writer:
|
||||
role: >
|
||||
Content Writer
|
||||
goal: >
|
||||
Write a compelling and well-structured blog post on {topic}
|
||||
based on the provided outline
|
||||
backstory: >
|
||||
You're a skilled writer with a talent for turning outlines into
|
||||
engaging, informative blog posts. Your writing is clear, conversational,
|
||||
and backed by solid reasoning. You adapt your tone to the subject matter
|
||||
while keeping things accessible to a broad audience.
|
||||
|
||||
editor:
|
||||
role: >
|
||||
Content Editor
|
||||
goal: >
|
||||
Review and polish the blog post on {topic} to ensure it is
|
||||
publication-ready
|
||||
backstory: >
|
||||
You're a meticulous editor with years of experience refining written
|
||||
content. You have an eye for clarity, flow, grammar, and consistency.
|
||||
You improve prose without changing the author's voice and ensure every
|
||||
piece you touch is polished and professional.
|
||||
@@ -0,0 +1,50 @@
|
||||
planning_task:
|
||||
description: >
|
||||
Create a detailed outline for a blog post about {topic}.
|
||||
|
||||
The outline should include:
|
||||
- A compelling title
|
||||
- An introduction hook
|
||||
- 3-5 main sections with key points to cover in each
|
||||
- A conclusion with a call to action
|
||||
|
||||
Make the outline detailed enough that a writer can produce
|
||||
a full blog post from it without additional research.
|
||||
expected_output: >
|
||||
A structured blog post outline with a title, introduction notes,
|
||||
detailed section breakdowns, and conclusion notes.
|
||||
agent: planner
|
||||
|
||||
writing_task:
|
||||
description: >
|
||||
Using the outline provided, write a full blog post about {topic}.
|
||||
|
||||
Requirements:
|
||||
- Follow the outline structure closely
|
||||
- Write in a clear, engaging, and conversational tone
|
||||
- Each section should be 2-3 paragraphs
|
||||
- Include a strong introduction and conclusion
|
||||
- Target around 800-1200 words
|
||||
expected_output: >
|
||||
A complete blog post in markdown format, ready for editing.
|
||||
The post should follow the outline and be well-written with
|
||||
clear transitions between sections.
|
||||
agent: writer
|
||||
|
||||
editing_task:
|
||||
description: >
|
||||
Review and edit the blog post about {topic}.
|
||||
|
||||
Focus on:
|
||||
- Fixing any grammar or spelling errors
|
||||
- Improving sentence clarity and flow
|
||||
- Ensuring consistent tone throughout
|
||||
- Strengthening the introduction and conclusion
|
||||
- Removing any redundancy
|
||||
|
||||
Do not rewrite the post — refine and polish it.
|
||||
expected_output: >
|
||||
The final, polished blog post in markdown format without '```'.
|
||||
Publication-ready with clean formatting and professional prose.
|
||||
agent: editor
|
||||
output_file: output/post.md
|
||||
@@ -0,0 +1,75 @@
|
||||
from crewai import Agent, Crew, Process, Task
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.project import CrewBase, agent, crew, task
|
||||
|
||||
# If you want to run a snippet of code before or after the crew starts,
|
||||
# you can use the @before_kickoff and @after_kickoff decorators
|
||||
# https://docs.crewai.com/concepts/crews#example-crew-class-with-decorators
|
||||
|
||||
|
||||
@CrewBase
|
||||
class ContentCrew:
|
||||
"""Content Crew"""
|
||||
|
||||
agents: list[BaseAgent]
|
||||
tasks: list[Task]
|
||||
|
||||
# Learn more about YAML configuration files here:
|
||||
# Agents: https://docs.crewai.com/concepts/agents#yaml-configuration-recommended
|
||||
# Tasks: https://docs.crewai.com/concepts/tasks#yaml-configuration-recommended
|
||||
agents_config = "config/agents.yaml"
|
||||
tasks_config = "config/tasks.yaml"
|
||||
|
||||
# If you would like to add tools to your crew, you can learn more about it here:
|
||||
# https://docs.crewai.com/concepts/agents#agent-tools
|
||||
@agent
|
||||
def planner(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config["planner"], # type: ignore[index]
|
||||
)
|
||||
|
||||
@agent
|
||||
def writer(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config["writer"], # type: ignore[index]
|
||||
)
|
||||
|
||||
@agent
|
||||
def editor(self) -> Agent:
|
||||
return Agent(
|
||||
config=self.agents_config["editor"], # type: ignore[index]
|
||||
)
|
||||
|
||||
# To learn more about structured task outputs,
|
||||
# task dependencies, and task callbacks, check out the documentation:
|
||||
# https://docs.crewai.com/concepts/tasks#overview-of-a-task
|
||||
@task
|
||||
def planning_task(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config["planning_task"], # type: ignore[index]
|
||||
)
|
||||
|
||||
@task
|
||||
def writing_task(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config["writing_task"], # type: ignore[index]
|
||||
)
|
||||
|
||||
@task
|
||||
def editing_task(self) -> Task:
|
||||
return Task(
|
||||
config=self.tasks_config["editing_task"], # type: ignore[index]
|
||||
)
|
||||
|
||||
@crew
|
||||
def crew(self) -> Crew:
|
||||
"""Creates the Content Crew"""
|
||||
# To learn how to add knowledge sources to your crew, check out the documentation:
|
||||
# https://docs.crewai.com/concepts/knowledge#what-is-knowledge
|
||||
|
||||
return Crew(
|
||||
agents=self.agents, # Automatically created by the @agent decorator
|
||||
tasks=self.tasks, # Automatically created by the @task decorator
|
||||
process=Process.sequential,
|
||||
verbose=True,
|
||||
)
|
||||
92
lib/cli/src/crewai_cli/templates/flow/main.py
Normal file
92
lib/cli/src/crewai_cli/templates/flow/main.py
Normal file
@@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow import Flow, listen, start
|
||||
|
||||
from {{folder_name}}.crews.content_crew.content_crew import ContentCrew
|
||||
|
||||
|
||||
class ContentState(BaseModel):
|
||||
topic: str = ""
|
||||
outline: str = ""
|
||||
draft: str = ""
|
||||
final_post: str = ""
|
||||
|
||||
|
||||
class ContentFlow(Flow[ContentState]):
|
||||
|
||||
@start()
|
||||
def plan_content(self, crewai_trigger_payload: dict = None):
|
||||
print("Planning content")
|
||||
|
||||
if crewai_trigger_payload:
|
||||
self.state.topic = crewai_trigger_payload.get("topic", "AI Agents")
|
||||
print(f"Using trigger payload: {crewai_trigger_payload}")
|
||||
else:
|
||||
self.state.topic = "AI Agents"
|
||||
|
||||
print(f"Topic: {self.state.topic}")
|
||||
|
||||
@listen(plan_content)
|
||||
def generate_content(self):
|
||||
print(f"Generating content on: {self.state.topic}")
|
||||
result = (
|
||||
ContentCrew()
|
||||
.crew()
|
||||
.kickoff(inputs={"topic": self.state.topic})
|
||||
)
|
||||
|
||||
print("Content generated")
|
||||
self.state.final_post = result.raw
|
||||
|
||||
@listen(generate_content)
|
||||
def save_content(self):
|
||||
print("Saving content")
|
||||
output_dir = Path("output")
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
with open(output_dir / "post.md", "w") as f:
|
||||
f.write(self.state.final_post)
|
||||
print("Post saved to output/post.md")
|
||||
|
||||
|
||||
def kickoff():
|
||||
content_flow = ContentFlow()
|
||||
content_flow.kickoff()
|
||||
|
||||
|
||||
def plot():
|
||||
content_flow = ContentFlow()
|
||||
content_flow.plot()
|
||||
|
||||
|
||||
def run_with_trigger():
|
||||
"""
|
||||
Run the flow with trigger payload.
|
||||
"""
|
||||
import json
|
||||
import sys
|
||||
|
||||
# Get trigger payload from command line argument
|
||||
if len(sys.argv) < 2:
|
||||
raise Exception("No trigger payload provided. Please provide JSON payload as argument.")
|
||||
|
||||
try:
|
||||
trigger_payload = json.loads(sys.argv[1])
|
||||
except json.JSONDecodeError:
|
||||
raise Exception("Invalid JSON payload provided as argument")
|
||||
|
||||
# Create flow and kickoff with trigger payload
|
||||
# The @start() methods will automatically receive crewai_trigger_payload parameter
|
||||
content_flow = ContentFlow()
|
||||
|
||||
try:
|
||||
result = content_flow.kickoff({"crewai_trigger_payload": trigger_payload})
|
||||
return result
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while running the flow with trigger: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
kickoff()
|
||||
22
lib/cli/src/crewai_cli/templates/flow/pyproject.toml
Normal file
22
lib/cli/src/crewai_cli/templates/flow/pyproject.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[project]
|
||||
name = "{{folder_name}}"
|
||||
version = "0.1.0"
|
||||
description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.14.5a2"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
kickoff = "{{folder_name}}.main:kickoff"
|
||||
run_crew = "{{folder_name}}.main:kickoff"
|
||||
plot = "{{folder_name}}.main:plot"
|
||||
run_with_trigger = "{{folder_name}}.main:run_with_trigger"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.crewai]
|
||||
type = "flow"
|
||||
21
lib/cli/src/crewai_cli/templates/flow/tools/custom_tool.py
Normal file
21
lib/cli/src/crewai_cli/templates/flow/tools/custom_tool.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class MyCustomToolInput(BaseModel):
|
||||
"""Input schema for MyCustomTool."""
|
||||
|
||||
argument: str = Field(..., description="Description of the argument.")
|
||||
|
||||
|
||||
class MyCustomTool(BaseTool):
|
||||
name: str = "Name of my tool"
|
||||
description: str = "Clear description for what this tool is useful for, your agent will need this information to use it."
|
||||
args_schema: Type[BaseModel] = MyCustomToolInput
|
||||
|
||||
def _run(self, argument: str) -> str:
|
||||
# Implementation goes here
|
||||
return "this is an example of a tool output, ignore it and move along."
|
||||
10
lib/cli/src/crewai_cli/templates/tool/.gitignore
vendored
Normal file
10
lib/cli/src/crewai_cli/templates/tool/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
48
lib/cli/src/crewai_cli/templates/tool/README.md
Normal file
48
lib/cli/src/crewai_cli/templates/tool/README.md
Normal file
@@ -0,0 +1,48 @@
|
||||
# {{folder_name}}
|
||||
|
||||
{{folder_name}} is a CrewAI Tool. This template is designed to help you create
|
||||
custom tools to power up your crews.
|
||||
|
||||
## Installing
|
||||
|
||||
Ensure you have Python >=3.10 <3.14 installed on your system. This project
|
||||
uses [UV](https://docs.astral.sh/uv/) for dependency management and package
|
||||
handling, offering a seamless setup and execution experience.
|
||||
|
||||
First, if you haven't already, install `uv`:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
Next, navigate to your project directory and install the dependencies with:
|
||||
|
||||
```bash
|
||||
crewai install
|
||||
```
|
||||
|
||||
## Publishing
|
||||
|
||||
Collaborate by sharing tools within your organization, or publish them publicly
|
||||
to contribute with the community.
|
||||
|
||||
```bash
|
||||
crewai tool publish {{tool_name}}
|
||||
```
|
||||
|
||||
Others may install your tool in their crews running:
|
||||
|
||||
```bash
|
||||
crewai tool install {{tool_name}}
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
For support, questions, or feedback regarding the {{folder_name}} tool or CrewAI:
|
||||
|
||||
- Visit our [documentation](https://docs.crewai.com)
|
||||
- Reach out to us through our [GitHub repository](https://github.com/joaomdmoura/crewai)
|
||||
- [Join our Discord](https://discord.com/invite/X4JWnZnxPb)
|
||||
- [Chat with our docs](https://chatg.pt/DWjSBZn)
|
||||
|
||||
Let's create wonders together with the power and simplicity of crewAI.
|
||||
12
lib/cli/src/crewai_cli/templates/tool/pyproject.toml
Normal file
12
lib/cli/src/crewai_cli/templates/tool/pyproject.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[project]
|
||||
name = "{{folder_name}}"
|
||||
version = "0.1.0"
|
||||
description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.14.5a2"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
type = "tool"
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tool import {{class_name}}
|
||||
|
||||
__all__ = ["{{class_name}}"]
|
||||
@@ -0,0 +1,10 @@
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class {{class_name}}(BaseTool):
|
||||
name: str = "Name of my tool"
|
||||
description: str = "What this tool does. It's vital for effective utilization."
|
||||
|
||||
def _run(self, argument: str) -> str:
|
||||
# Your tool's logic here
|
||||
return "Tool's result"
|
||||
0
lib/cli/src/crewai_cli/tools/__init__.py
Normal file
0
lib/cli/src/crewai_cli/tools/__init__.py
Normal file
364
lib/cli/src/crewai_cli/tools/main.py
Normal file
364
lib/cli/src/crewai_cli/tools/main.py
Normal file
@@ -0,0 +1,364 @@
|
||||
import base64
|
||||
from json import JSONDecodeError
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
|
||||
from crewai_cli import git
|
||||
from crewai_cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai_cli.config import Settings
|
||||
from crewai_cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai_cli.utils import (
|
||||
build_env_with_tool_repository_credentials,
|
||||
get_project_description,
|
||||
get_project_name,
|
||||
get_project_version,
|
||||
read_toml,
|
||||
tree_copy,
|
||||
tree_find_and_replace,
|
||||
)
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
_REQUIRES_CREWAI_MSG = (
|
||||
"[red]This subcommand requires the full crewai package.\n"
|
||||
"Install it with: pip install crewai[/red]"
|
||||
)
|
||||
|
||||
|
||||
def _require_project_utils() -> Any:
    """Import and return ``crewai.utilities.project_utils``.

    The full ``crewai`` framework is an optional dependency of the CLI;
    when it is missing, print an install hint and exit with status 1.
    """
    try:
        from crewai.utilities import project_utils
    except ImportError:
        console.print(_REQUIRES_CREWAI_MSG)
        raise SystemExit(1) from None
    return project_utils
|
||||
|
||||
|
||||
def _require_get_user_id() -> Any:
    """Import and return ``get_user_id`` from the tracing listeners.

    Mirrors ``_require_project_utils``: the full ``crewai`` package is
    optional, so a missing import prints an install hint and exits 1.
    """
    try:
        from crewai.events.listeners.tracing.utils import get_user_id
    except ImportError:
        console.print(_REQUIRES_CREWAI_MSG)
        raise SystemExit(1) from None
    return get_user_id
|
||||
|
||||
|
||||
class ToolCommand(BaseCommand, PlusAPIMixin):
    """
    A class to handle tool repository related operations for CrewAI projects.

    Wraps the Plus API client (via ``PlusAPIMixin``) to scaffold new tool
    projects, publish them to the tool repository, install published tools,
    and authenticate against the repository.
    """

    def __init__(self) -> None:
        # Initialize both bases explicitly: PlusAPIMixin needs the telemetry
        # object that BaseCommand.__init__ creates.
        BaseCommand.__init__(self)
        PlusAPIMixin.__init__(self, telemetry=self._telemetry)

    def create(self, handle: str) -> None:
        """Scaffold a new custom tool project from the bundled template.

        Args:
            handle: Human-readable tool name; normalized into a snake_case
                folder name and a PascalCase class name.

        Raises:
            SystemExit: If run inside an existing project or if the target
                folder already exists.
        """
        self._ensure_not_in_project()

        folder_name = handle.replace(" ", "_").replace("-", "_").lower()
        class_name = handle.replace("_", " ").replace("-", " ").title().replace(" ", "")

        project_root = Path(folder_name)
        if project_root.exists():
            click.secho(f"Folder {folder_name} already exists.", fg="red")
            raise SystemExit
        os.makedirs(project_root)

        click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)

        # Copy the template tree, then substitute placeholders in both file
        # contents and file/directory names.
        template_dir = Path(__file__).parent.parent / "templates" / "tool"
        tree_copy(template_dir, project_root)
        tree_find_and_replace(project_root, "{{folder_name}}", folder_name)
        tree_find_and_replace(project_root, "{{class_name}}", class_name)

        # Copy AGENTS.md to project root
        agents_md_src = Path(__file__).parent.parent / "templates" / "AGENTS.md"
        if agents_md_src.exists():
            shutil.copy2(agents_md_src, project_root / "AGENTS.md")

        # Work inside the new project for login + git init, then always
        # restore the caller's working directory.
        old_directory = os.getcwd()
        os.chdir(project_root)
        try:
            self.login()
            subprocess.run(["git", "init"], check=True)  # noqa: S607
            console.print(
                f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
            )
        finally:
            os.chdir(old_directory)

    def publish(self, is_public: bool, force: bool = False) -> None:
        """Build the current tool project and publish it to the repository.

        Args:
            is_public: Whether the tool is published publicly.
            force: Skip the clean/synced git working-tree check.

        Raises:
            SystemExit: If the git tree is not synced (and ``force`` is not
                set) or the sdist build produces no tarball.
        """
        if not git.Repository().is_synced() and not force:
            console.print(
                "[bold red]Failed to publish tool.[/bold red]\n"
                "Local changes need to be resolved before publishing. Please do the following:\n"
                "* [bold]Commit[/bold] your changes.\n"
                "* [bold]Push[/bold] to sync with the remote.\n"
                "* [bold]Pull[/bold] the latest changes from the remote.\n"
                "\nOnce your repository is up-to-date, retry publishing the tool."
            )
            raise SystemExit()

        project_name = get_project_name(require=True)
        assert isinstance(project_name, str)  # noqa: S101

        project_version = get_project_version(require=True)
        assert isinstance(project_version, str)  # noqa: S101

        project_description = get_project_description(require=False)
        encoded_tarball = None

        console.print("[bold blue]Discovering tools from your project...[/bold blue]")
        # Requires the full crewai package; exits with a hint if missing.
        project_utils = _require_project_utils()
        available_exports = project_utils.extract_available_exports()

        if available_exports:
            console.print(
                f"[green]Found these tools to publish: {', '.join([e['name'] for e in available_exports])}[/green]"
            )

        console.print("[bold blue]Extracting tool metadata...[/bold blue]")
        try:
            tools_metadata = project_utils.extract_tools_metadata()
        except Exception as e:
            # Metadata is best-effort; publishing proceeds without it.
            console.print(
                f"[yellow]Warning: Could not extract tool metadata: {e}[/yellow]\n"
                f"Publishing will continue without detailed metadata."
            )
            tools_metadata = []

        self._print_tools_preview(tools_metadata)
        self._print_current_organization()

        # Inject credentials for any custom uv package indexes declared in
        # pyproject so `uv build` can resolve private dependencies.
        build_env = os.environ.copy()
        try:
            pyproject_data = read_toml()
            sources = pyproject_data.get("tool", {}).get("uv", {}).get("sources", {})

            for source_config in sources.values():
                if isinstance(source_config, dict):
                    index = source_config.get("index")
                    if index:
                        index_env = build_env_with_tool_repository_credentials(index)
                        build_env.update(index_env)
        except Exception:  # noqa: S110
            # Best-effort: a malformed pyproject must not block the build.
            pass

        with tempfile.TemporaryDirectory() as temp_build_dir:
            subprocess.run(  # noqa: S603
                ["uv", "build", "--sdist", "--out-dir", temp_build_dir],  # noqa: S607
                check=True,
                capture_output=False,
                env=build_env,
            )

            tarball_filename = next(
                (f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
            )
            if not tarball_filename:
                console.print(
                    "Project build failed. Please ensure that the command `uv build --sdist` completes successfully.",
                    style="bold red",
                )
                raise SystemExit(1)

            tarball_path = os.path.join(temp_build_dir, tarball_filename)
            with open(tarball_path, "rb") as file:
                tarball_contents = file.read()

            # Ship the sdist inline as a base64 data URI.
            encoded_tarball = base64.b64encode(tarball_contents).decode("utf-8")

        console.print("[bold blue]Publishing tool to repository...[/bold blue]")
        publish_response = self.plus_api_client.publish_tool(
            handle=project_name,
            is_public=is_public,
            version=project_version,
            description=project_description,
            encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
            available_exports=available_exports,
            tools_metadata=tools_metadata,
        )

        self._validate_response(publish_response)

        published_handle = publish_response.json()["handle"]
        settings = Settings()
        base_url = settings.enterprise_base_url or DEFAULT_CREWAI_ENTERPRISE_URL

        console.print(
            f"Successfully published `{published_handle}` ({project_version}).\n\n"
            + "⚠️ Security checks are running in the background. Your tool will be available once these are complete.\n"
            + f"You can monitor the status or access your tool here:\n{base_url}/crewai_plus/tools/{published_handle}",
            style="bold green",
        )

    def install(self, handle: str) -> None:
        """Install a published tool into the current project via ``uv add``.

        Args:
            handle: Handle of the tool to install.

        Raises:
            SystemExit: If the tool is not found or details cannot be fetched.
        """
        self._print_current_organization()
        get_response = self.plus_api_client.get_tool(handle)

        if get_response.status_code == 404:
            console.print(
                "No tool found with this name. Please ensure the tool was published and you have access to it.",
                style="bold red",
            )
            raise SystemExit
        if get_response.status_code != 200:
            console.print(
                "Failed to get tool details. Please try again later.", style="bold red"
            )
            raise SystemExit

        self._add_package(get_response.json())

        console.print(f"Successfully installed {handle}", style="bold green")

    def login(self) -> None:
        """Authenticate against the tool repository and persist credentials.

        Stores the repository username/password and the active organization
        in the local CLI settings.

        Raises:
            SystemExit: If the Plus API rejects the login.
        """
        get_user_id = _require_get_user_id()
        login_response = self.plus_api_client.login_to_tool_repository(
            user_identifier=get_user_id()
        )

        if login_response.status_code != 200:
            console.print(
                "Authentication failed. Verify if the currently active organization can access the tool repository, and run 'crewai login' again.",
                style="bold red",
            )
            # Surface the API's own error message when the body is JSON.
            try:
                console.print(
                    f"[{login_response.status_code} error - {login_response.json().get('message', 'Unknown error')}]",
                    style="bold red italic",
                )
            except JSONDecodeError:
                console.print(
                    f"[{login_response.status_code} error - Unknown error - Invalid JSON response]",
                    style="bold red italic",
                )
            raise SystemExit

        login_response_json = login_response.json()

        settings = Settings()
        settings.tool_repository_username = login_response_json["credential"][
            "username"
        ]
        settings.tool_repository_password = login_response_json["credential"][
            "password"
        ]
        settings.org_uuid = login_response_json["current_organization"]["uuid"]
        settings.org_name = login_response_json["current_organization"]["name"]
        settings.dump()

    def _add_package(self, tool_details: dict[str, Any]) -> None:
        """Run ``uv add`` for the tool described by *tool_details*.

        PyPI-sourced tools are added directly; repository-hosted tools are
        added with a ``--index`` flag plus injected credentials.
        """
        is_from_pypi = tool_details.get("source", None) == "pypi"
        tool_handle = tool_details["handle"]
        repository_handle = tool_details["repository"]["handle"]
        repository_url = tool_details["repository"]["url"]
        index = f"{repository_handle}={repository_url}"

        add_package_command = [
            "uv",
            "add",
        ]

        if is_from_pypi:
            add_package_command.append(tool_handle)
        else:
            add_package_command.extend(["--index", index, tool_handle])

        add_package_result = subprocess.run(  # noqa: S603
            add_package_command,
            capture_output=False,
            env=build_env_with_tool_repository_credentials(repository_handle),
            text=True,
            check=True,
        )

        # NOTE(review): with capture_output=False, stderr is not captured
        # (it is None), so this branch never fires — confirm intent.
        if add_package_result.stderr:
            click.echo(add_package_result.stderr, err=True)
            raise SystemExit

    def _ensure_not_in_project(self) -> None:
        """Abort with guidance if the CWD already contains a pyproject.toml.

        Raises:
            SystemExit: Always, when run inside an existing project.
        """
        if os.path.isfile("./pyproject.toml"):
            console.print(
                "[bold red]Oops! It looks like you're inside a project.[/bold red]"
            )
            console.print(
                "You can't create a new tool while inside an existing project."
            )
            console.print(
                "[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again."
            )
            raise SystemExit

    def _print_tools_preview(self, tools_metadata: list[dict[str, Any]]) -> None:
        """Pretty-print the extracted tool metadata before publishing.

        Args:
            tools_metadata: Entries with ``name``, ``module``,
                ``humanized_name``, ``description``, ``init_params_schema``
                (JSON-schema-like) and ``env_vars``.
        """
        if not tools_metadata:
            console.print("[yellow]No tool metadata extracted.[/yellow]")
            return

        console.print(
            f"\n[bold]Tools to be published ({len(tools_metadata)}):[/bold]\n"
        )

        for tool in tools_metadata:
            console.print(f" [bold cyan]{tool.get('name', 'Unknown')}[/bold cyan]")
            if tool.get("module"):
                console.print(f" Module: {tool.get('module')}")
            console.print(f" Name: {tool.get('humanized_name', 'N/A')}")
            # Truncate long descriptions to keep the preview scannable.
            console.print(
                f" Description: {tool.get('description', 'N/A')[:80]}{'...' if len(tool.get('description', '')) > 80 else ''}"
            )

            init_params = tool.get("init_params_schema", {}).get("properties", {})
            if init_params:
                required = tool.get("init_params_schema", {}).get("required", [])
                console.print(" Init parameters:")
                for param_name, param_info in init_params.items():
                    param_type = param_info.get("type", "any")
                    is_required = param_name in required
                    req_marker = "[red]*[/red]" if is_required else ""
                    default = (
                        f" = {param_info['default']}" if "default" in param_info else ""
                    )
                    console.print(
                        f" - {param_name}: {param_type}{default} {req_marker}"
                    )

            env_vars = tool.get("env_vars", [])
            if env_vars:
                console.print(" Environment variables:")
                for env_var in env_vars:
                    req_marker = "[red]*[/red]" if env_var.get("required") else ""
                    default = (
                        f" (default: {env_var['default']})"
                        if env_var.get("default")
                        else ""
                    )
                    console.print(
                        f" - {env_var['name']}: {env_var.get('description', 'N/A')}{default} {req_marker}"
                    )

        console.print()

    def _print_current_organization(self) -> None:
        """Print the active organization from settings, or a hint to set one."""
        settings = Settings()
        if settings.org_uuid:
            console.print(
                f"Current organization: {settings.org_name} ({settings.org_uuid})",
                style="bold blue",
            )
        else:
            console.print(
                "No organization currently set. We recommend setting one before using: `crewai org switch <org_id>` command.",
                style="yellow",
            )
|
||||
32
lib/cli/src/crewai_cli/train_crew.py
Normal file
32
lib/cli/src/crewai_cli/train_crew.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def train_crew(n_iterations: int, filename: str) -> None:
    """
    Train the crew by running the project's ``train`` script in the UV environment.

    Args:
        n_iterations (int): The number of iterations to train the crew. Must be positive.
        filename (str): Path where training data is saved; must end with ``.pkl``.
    """
    try:
        if n_iterations <= 0:
            raise ValueError("The number of iterations must be a positive integer.")

        if not filename.endswith(".pkl"):
            # Fixed: the message used to say "must not end with .pkl",
            # the opposite of what this check enforces.
            raise ValueError("The filename must end with .pkl")

        command = ["uv", "run", "train", str(n_iterations), filename]
        # Output streams straight to the terminal (capture_output=False), so
        # result.stderr is always None — the old post-run stderr check was
        # dead code. A non-zero exit raises CalledProcessError instead.
        subprocess.run(command, capture_output=False, text=True, check=True)  # noqa: S603

    except subprocess.CalledProcessError as e:
        click.echo(f"An error occurred while training the crew: {e}", err=True)
        click.echo(e.output, err=True)

    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
6
lib/cli/src/crewai_cli/triggers/__init__.py
Normal file
6
lib/cli/src/crewai_cli/triggers/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Triggers command module for CrewAI CLI."""
|
||||
|
||||
from crewai_cli.triggers.main import TriggersCommand
|
||||
|
||||
|
||||
__all__ = ["TriggersCommand"]
|
||||
137
lib/cli/src/crewai_cli/triggers/main.py
Normal file
137
lib/cli/src/crewai_cli/triggers/main.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import json
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from crewai_cli.command import BaseCommand, PlusAPIMixin
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class TriggersCommand(BaseCommand, PlusAPIMixin):
    """
    A class to handle trigger-related operations for CrewAI projects.

    Lists triggers exposed by connected integrations and can run the local
    crew with a trigger's sample payload.
    """

    def __init__(self) -> None:
        # Initialize both bases explicitly: PlusAPIMixin needs the telemetry
        # object that BaseCommand.__init__ creates.
        BaseCommand.__init__(self)
        PlusAPIMixin.__init__(self, telemetry=self._telemetry)

    def list_triggers(self) -> None:
        """List all available triggers from integrations.

        Raises:
            SystemExit: With status 1 if the trigger listing cannot be
                fetched or validated.
        """
        try:
            console.print("[bold blue]Fetching available triggers...[/bold blue]")
            response = self.plus_api_client.get_triggers()
            self._validate_response(response)

            triggers_data = response.json()
            self._display_triggers(triggers_data)

        except Exception as e:
            console.print(f"[bold red]Error fetching triggers: {e}[/bold red]")
            raise SystemExit(1) from e

    def execute_with_trigger(self, trigger_path: str) -> None:
        """Execute crew with trigger payload.

        Args:
            trigger_path: Trigger identifier in ``app_slug/trigger_slug`` form.

        Raises:
            SystemExit: With status 1 on a malformed path, an unknown
                trigger, or a failed crew run.
        """
        try:
            # Parse app_slug/trigger_slug
            if "/" not in trigger_path:
                console.print(
                    "[bold red]Error: Trigger must be in format 'app_slug/trigger_slug'[/bold red]"
                )
                raise SystemExit(1)

            app_slug, trigger_slug = trigger_path.split("/", 1)

            console.print(
                f"[bold blue]Fetching trigger payload for {app_slug}/{trigger_slug}...[/bold blue]"
            )
            response = self.plus_api_client.get_trigger_payload(app_slug, trigger_slug)

            # 404 carries a JSON error body; surface it before bailing out.
            if response.status_code == 404:
                error_data = response.json()
                console.print(
                    f"[bold red]Error: {error_data.get('error', 'Trigger not found')}[/bold red]"
                )
                raise SystemExit(1)

            self._validate_response(response)

            trigger_data = response.json()
            self._display_trigger_info(trigger_data)

            # Run crew with trigger payload
            self._run_crew_with_payload(trigger_data.get("sample_payload", {}))

        except Exception as e:
            console.print(
                f"[bold red]Error executing crew with trigger: {e}[/bold red]"
            )
            raise SystemExit(1) from e

    def _display_triggers(self, triggers_data: dict[str, Any]) -> None:
        """Display triggers in a formatted table.

        Args:
            triggers_data: Response payload; expects an ``apps`` list whose
                entries carry ``name``, ``slug``, ``is_connected``,
                ``description``, and a ``triggers`` list.
        """
        apps = triggers_data.get("apps", [])

        if not apps:
            console.print("[yellow]No triggers found.[/yellow]")
            return

        for app in apps:
            app_name = app.get("name", "Unknown App")
            app_slug = app.get("slug", "unknown")
            is_connected = app.get("is_connected", False)
            connection_status = (
                "[green]✓ Connected[/green]"
                if is_connected
                else "[red]✗ Not Connected[/red]"
            )

            console.print(
                f"\n[bold cyan]{app_name}[/bold cyan] ({app_slug}) - {connection_status}"
            )
            console.print(
                f"[dim]{app.get('description', 'No description available')}[/dim]"
            )

            triggers = app.get("triggers", [])
            if triggers:
                table = Table(show_header=True, header_style="bold magenta")
                table.add_column("Trigger", style="cyan")
                table.add_column("Name", style="green")
                table.add_column("Description", style="dim")

                for trigger in triggers:
                    trigger_path = f"{app_slug}/{trigger.get('slug', 'unknown')}"
                    table.add_row(
                        trigger_path,
                        trigger.get("name", "Unknown"),
                        trigger.get("description", "No description"),
                    )

                console.print(table)
            else:
                console.print("[dim] No triggers available[/dim]")

    def _display_trigger_info(self, trigger_data: dict[str, Any]) -> None:
        """Display trigger information before execution.

        Args:
            trigger_data: Trigger detail payload; only ``sample_payload``
                is shown here (pretty-printed as JSON).
        """
        sample_payload = trigger_data.get("sample_payload", {})
        if sample_payload:
            console.print("\n[bold yellow]Sample Payload:[/bold yellow]")
            console.print(json.dumps(sample_payload, indent=2))

    def _run_crew_with_payload(self, payload: dict[str, Any]) -> None:
        """Run the crew with the trigger payload using the run_with_trigger method.

        Args:
            payload: Trigger payload passed to the project's
                ``run_with_trigger`` script as a JSON string argument.

        Raises:
            SystemExit: With status 1 if the subprocess fails for any reason.
        """
        try:
            subprocess.run(  # noqa: S603
                ["uv", "run", "run_with_trigger", json.dumps(payload)],  # noqa: S607
                capture_output=False,
                text=True,
                check=True,
            )

        except Exception as e:
            # The original error is only chained into SystemExit.__cause__;
            # diagnostics come from the subprocess's own output.
            raise SystemExit(1) from e
|
||||
122
lib/cli/src/crewai_cli/update_crew.py
Normal file
122
lib/cli/src/crewai_cli/update_crew.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any
|
||||
|
||||
import tomli_w
|
||||
|
||||
from crewai_cli.utils import read_toml
|
||||
|
||||
|
||||
def update_crew() -> None:
    """Update the pyproject.toml of the Crew project to use uv.

    Rewrites the ``pyproject.toml`` in the current working directory in
    place (input and output paths are the same file); ``migrate_pyproject``
    backs up the original alongside before writing.
    """
    migrate_pyproject("pyproject.toml", "pyproject.toml")
|
||||
|
||||
|
||||
def migrate_pyproject(input_file: str, output_file: str) -> None:
    """
    Migrate the pyproject.toml to the new format.

    Converts a Poetry-style ``pyproject.toml`` into the PEP 621 layout used
    by uv: metadata, dependencies (including extras), scripts, and optional
    dependencies are carried over and a hatchling build-system block is
    added. The original file is backed up to ``pyproject-old.toml`` and any
    ``poetry.lock`` is renamed to ``poetry-old.lock``.

    When the time comes that uv supports the new format natively, this
    function will be deprecated.

    Args:
        input_file: Path of the pyproject.toml to migrate (backed up first).
        output_file: Destination path; may equal ``input_file`` to rewrite
            in place.
    """
    poetry_data = {}
    # Read the input pyproject.toml
    pyproject_data = read_toml()

    new_pyproject: dict[str, Any] = {
        "project": {},
        "build-system": {"requires": ["hatchling"], "build-backend": "hatchling.build"},
    }

    # Migrate project metadata
    if "tool" in pyproject_data and "poetry" in pyproject_data["tool"]:
        poetry_data = pyproject_data["tool"]["poetry"]
        new_pyproject["project"]["name"] = poetry_data.get("name")
        new_pyproject["project"]["version"] = poetry_data.get("version")
        new_pyproject["project"]["description"] = poetry_data.get("description")
        # Poetry stores authors as "Name <email>" strings; split them into
        # the PEP 621 table form.
        new_pyproject["project"]["authors"] = [
            {
                "name": author.split("<")[0].strip(),
                "email": author.split("<")[1].strip(">").strip(),
            }
            for author in poetry_data.get("authors", [])
        ]
        new_pyproject["project"]["requires-python"] = poetry_data.get("python")
    else:
        # If it's already in the new format, just copy the project and tool sections
        new_pyproject["project"] = pyproject_data.get("project", {})
        new_pyproject["tool"] = pyproject_data.get("tool", {})

    # Migrate or copy dependencies
    if "dependencies" in new_pyproject["project"]:
        # If dependencies are already in the new format, keep them as is
        pass
    elif poetry_data and "dependencies" in poetry_data:
        new_pyproject["project"]["dependencies"] = []
        for dep, version in poetry_data["dependencies"].items():
            if isinstance(version, dict):  # Handle extras
                extras = ",".join(version.get("extras", []))
                new_dep = f"{dep}[{extras}]"
                if "version" in version:
                    new_dep += parse_version(version["version"])
            elif dep == "python":
                # Poetry lists the interpreter as a dependency; PEP 621
                # expresses it as requires-python instead.
                new_pyproject["project"]["requires-python"] = version
                continue
            else:
                new_dep = f"{dep}{parse_version(version)}"
            new_pyproject["project"]["dependencies"].append(new_dep)

    # Migrate or copy scripts
    if poetry_data and "scripts" in poetry_data:
        new_pyproject["project"]["scripts"] = poetry_data["scripts"]
    elif pyproject_data.get("project", {}) and "scripts" in pyproject_data["project"]:
        new_pyproject["project"]["scripts"] = pyproject_data["project"]["scripts"]
    else:
        new_pyproject["project"]["scripts"] = {}

    if (
        "run_crew" not in new_pyproject["project"]["scripts"]
        and len(new_pyproject["project"]["scripts"]) > 0
    ):
        # Derive the module name from any existing "module.main:func" script.
        # Guarded with a default: previously a scripts table whose values
        # contained no "." crashed with an unhandled StopIteration.
        existing_scripts = new_pyproject["project"]["scripts"]
        module_name = next(
            (value.split(".")[0] for value in existing_scripts.values() if "." in value),
            None,
        )
        if module_name:
            new_pyproject["project"]["scripts"]["run_crew"] = f"{module_name}.main:run"

    # Migrate optional dependencies
    if poetry_data and "extras" in poetry_data:
        new_pyproject["project"]["optional-dependencies"] = poetry_data["extras"]

    # Backup the old pyproject.toml
    backup_file = "pyproject-old.toml"
    shutil.copy2(input_file, backup_file)

    # Rename the poetry.lock file, if present, so uv can write its own lock.
    # (The old "else: pass" branch was dead code and has been removed.)
    lock_file = "poetry.lock"
    lock_backup = "poetry-old.lock"
    if os.path.exists(lock_file):
        os.rename(lock_file, lock_backup)

    # Write the new pyproject.toml
    with open(output_file, "wb") as f:
        tomli_w.dump(new_pyproject, f)
|
||||
|
||||
|
||||
def parse_version(version: str) -> str:
    """Parse and convert version specifiers.

    A Poetry caret specifier such as ``^1.2.3`` becomes ``>=1.2.3``; any
    second comma-separated constraint (e.g. ``^1.2,<2.0``) is carried over
    verbatim. Specifiers without a leading caret are returned unchanged.
    """
    if not version.startswith("^"):
        return version

    constraints = version[1:].split(",")
    lower_bound = f">={constraints[0]}"
    extra_constraint = constraints[1] if len(constraints) > 1 else None

    if extra_constraint:
        return f"{lower_bound},{extra_constraint}"
    return lower_bound
|
||||
22
lib/cli/src/crewai_cli/user_data.py
Normal file
22
lib/cli/src/crewai_cli/user_data.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""User-data helpers — re-exported from ``crewai_core.user_data``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.paths import db_storage_path as _db_storage_path
|
||||
from crewai_core.user_data import (
|
||||
_load_user_data as _load_user_data,
|
||||
_save_user_data as _save_user_data,
|
||||
has_user_declined_tracing as has_user_declined_tracing,
|
||||
is_tracing_enabled as is_tracing_enabled,
|
||||
update_user_data as update_user_data,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"_db_storage_path",
|
||||
"_load_user_data",
|
||||
"_save_user_data",
|
||||
"has_user_declined_tracing",
|
||||
"is_tracing_enabled",
|
||||
"update_user_data",
|
||||
]
|
||||
137
lib/cli/src/crewai_cli/utils.py
Normal file
137
lib/cli/src/crewai_cli/utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from crewai_core.project import (
|
||||
get_project_description as get_project_description,
|
||||
get_project_name as get_project_name,
|
||||
get_project_version as get_project_version,
|
||||
parse_toml as parse_toml,
|
||||
read_toml as read_toml,
|
||||
)
|
||||
from crewai_core.tool_credentials import (
|
||||
build_env_with_all_tool_credentials as build_env_with_all_tool_credentials,
|
||||
build_env_with_tool_repository_credentials as build_env_with_tool_repository_credentials,
|
||||
)
|
||||
from rich.console import Console
|
||||
|
||||
|
||||
__all__ = [
|
||||
"build_env_with_all_tool_credentials",
|
||||
"build_env_with_tool_repository_credentials",
|
||||
"copy_template",
|
||||
"fetch_and_json_env_file",
|
||||
"get_project_description",
|
||||
"get_project_name",
|
||||
"get_project_version",
|
||||
"load_env_vars",
|
||||
"parse_toml",
|
||||
"read_toml",
|
||||
"tree_copy",
|
||||
"tree_find_and_replace",
|
||||
"write_env_file",
|
||||
]
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def copy_template(
    src: Path, dst: Path, name: str, class_name: str, folder_name: str
) -> None:
    """Render the template file *src* into *dst*.

    Substitutes the ``{{name}}``, ``{{crew_name}}`` and ``{{folder_name}}``
    placeholders with the supplied values and reports the created file.
    """
    substitutions = {
        "{{name}}": name,
        "{{crew_name}}": class_name,
        "{{folder_name}}": folder_name,
    }

    with open(src, "r") as template:
        rendered = template.read()

    for placeholder, value in substitutions.items():
        rendered = rendered.replace(placeholder, value)

    with open(dst, "w") as output:
        output.write(rendered)

    click.secho(f" - Created {dst}", fg="green")
|
||||
|
||||
|
||||
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
    """Fetch the environment variables from a .env file and return them as a dictionary.

    Blank lines and ``#`` comments are ignored. Lines without an ``=``
    separator are now skipped instead of aborting the parse: previously a
    single malformed line raised during unpacking and the broad handler
    discarded every valid entry already parsed.

    Args:
        env_file_path: Path of the .env file to read. Defaults to ``.env``.

    Returns:
        Mapping of variable names to values; empty on read errors.
    """
    try:
        with open(env_file_path, "r") as f:
            env_content = f.read()

        env_dict = {}
        for line in env_content.splitlines():
            stripped = line.strip()
            # Skip blanks, comments, and lines lacking a key/value separator.
            if not stripped or stripped.startswith("#") or "=" not in stripped:
                continue
            key, value = stripped.split("=", 1)
            env_dict[key.strip()] = value.strip()

        return env_dict

    except FileNotFoundError:
        console.print(f"Error: {env_file_path} not found.", style="bold red")
    except Exception as e:
        console.print(f"Error reading the .env file: {e}", style="bold red")

    return {}
|
||||
|
||||
|
||||
def tree_copy(source: Path, destination: Path) -> None:
    """Copies the entire directory structure from the source to the destination.

    Each entry directly under *source* is replicated into *destination*:
    directories are copied recursively, regular files with their metadata.
    """
    for entry in os.listdir(source):
        src_path = os.path.join(source, entry)
        dst_path = os.path.join(destination, entry)
        copier = shutil.copytree if os.path.isdir(src_path) else shutil.copy2
        copier(src_path, dst_path)
|
||||
|
||||
|
||||
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
    """Recursively replace *find* with *replace* throughout *directory*.

    Applies the substitution to file contents, file names, and directory
    names. The walk is bottom-up so a directory is renamed only after
    everything inside it has already been processed.
    """
    for root, subdirs, filenames in os.walk(os.path.abspath(directory), topdown=False):
        for filename in filenames:
            file_path = os.path.join(root, filename)

            # Rewrite the file's contents first, then fix up its name.
            with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
                text = handle.read()
            with open(file_path, "w") as handle:
                handle.write(text.replace(find, replace))

            if find in filename:
                renamed = os.path.join(root, filename.replace(find, replace))
                os.rename(file_path, renamed)

        for subdir in subdirs:
            if find in subdir:
                os.rename(
                    os.path.join(root, subdir),
                    os.path.join(root, subdir.replace(find, replace)),
                )
|
||||
|
||||
|
||||
def load_env_vars(folder_path: Path) -> dict[str, Any]:
    """Loads environment variables from a .env file in the specified folder path.

    Returns an empty mapping when no ``.env`` file exists; lines that do not
    yield both a key and a value are ignored.
    """
    dotenv = folder_path / ".env"
    if not dotenv.exists():
        return {}

    variables: dict[str, Any] = {}
    with open(dotenv, "r") as handle:
        for raw_line in handle:
            name, _, val = raw_line.strip().partition("=")
            if name and val:
                variables[name] = val
    return variables
|
||||
|
||||
|
||||
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
    """Writes environment variables to a .env file in the specified folder.

    Keys are upper-cased on the way out, one ``KEY=value`` pair per line;
    an existing ``.env`` file is overwritten.
    """
    lines = [f"{key.upper()}={value}\n" for key, value in env_vars.items()]
    with open(folder_path / ".env", "w") as handle:
        handle.writelines(lines)
|
||||
24
lib/cli/src/crewai_cli/version.py
Normal file
24
lib/cli/src/crewai_cli/version.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Re-exports of version utilities from ``crewai_core.version``.
|
||||
|
||||
Kept as a stable import path for the CLI; new code should import from
|
||||
``crewai_core.version`` directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai_core.version import (
|
||||
check_version as check_version,
|
||||
get_crewai_version as get_crewai_version,
|
||||
get_latest_version_from_pypi as get_latest_version_from_pypi,
|
||||
is_current_version_yanked as is_current_version_yanked,
|
||||
is_newer_version_available as is_newer_version_available,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"check_version",
|
||||
"get_crewai_version",
|
||||
"get_latest_version_from_pypi",
|
||||
"is_current_version_yanked",
|
||||
"is_newer_version_available",
|
||||
]
|
||||
0
lib/cli/tests/__init__.py
Normal file
0
lib/cli/tests/__init__.py
Normal file
0
lib/cli/tests/authentication/__init__.py
Normal file
0
lib/cli/tests/authentication/__init__.py
Normal file
0
lib/cli/tests/authentication/providers/__init__.py
Normal file
0
lib/cli/tests/authentication/providers/__init__.py
Normal file
91
lib/cli/tests/authentication/providers/test_auth0.py
Normal file
91
lib/cli/tests/authentication/providers/test_auth0.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import pytest
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.auth0 import Auth0Provider
|
||||
|
||||
|
||||
|
||||
class TestAuth0Provider:
    """Tests for Auth0Provider endpoint URLs and settings accessors."""

    @pytest.fixture(autouse=True)
    def setup_method(self):
        # Canonical settings shared by the accessor tests below.
        self.valid_settings = Oauth2Settings(
            provider="auth0",
            domain="test-domain.auth0.com",
            client_id="test-client-id",
            audience="test-audience",
        )
        self.provider = Auth0Provider(self.valid_settings)

    @staticmethod
    def _provider_for(domain):
        # Helper: provider wired to an arbitrary Auth0 domain.
        return Auth0Provider(
            Oauth2Settings(
                provider="auth0",
                domain=domain,
                client_id="test-client",
                audience="test-audience",
            )
        )

    def test_initialization_with_valid_settings(self):
        provider = Auth0Provider(self.valid_settings)
        assert provider.settings == self.valid_settings
        assert provider.settings.provider == "auth0"
        assert provider.settings.domain == "test-domain.auth0.com"
        assert provider.settings.client_id == "test-client-id"
        assert provider.settings.audience == "test-audience"

    def test_get_authorize_url(self):
        assert (
            self.provider.get_authorize_url()
            == "https://test-domain.auth0.com/oauth/device/code"
        )

    def test_get_authorize_url_with_different_domain(self):
        provider = self._provider_for("my-company.auth0.com")
        assert (
            provider.get_authorize_url()
            == "https://my-company.auth0.com/oauth/device/code"
        )

    def test_get_token_url(self):
        assert self.provider.get_token_url() == "https://test-domain.auth0.com/oauth/token"

    def test_get_token_url_with_different_domain(self):
        provider = self._provider_for("another-domain.auth0.com")
        assert provider.get_token_url() == "https://another-domain.auth0.com/oauth/token"

    def test_get_jwks_url(self):
        assert (
            self.provider.get_jwks_url()
            == "https://test-domain.auth0.com/.well-known/jwks.json"
        )

    def test_get_jwks_url_with_different_domain(self):
        provider = self._provider_for("dev.auth0.com")
        assert provider.get_jwks_url() == "https://dev.auth0.com/.well-known/jwks.json"

    def test_get_issuer(self):
        assert self.provider.get_issuer() == "https://test-domain.auth0.com/"

    def test_get_issuer_with_different_domain(self):
        provider = self._provider_for("prod.auth0.com")
        assert provider.get_issuer() == "https://prod.auth0.com/"

    def test_get_audience(self):
        assert self.provider.get_audience() == "test-audience"

    def test_get_client_id(self):
        assert self.provider.get_client_id() == "test-client-id"
|
||||
141
lib/cli/tests/authentication/providers/test_entra_id.py
Normal file
141
lib/cli/tests/authentication/providers/test_entra_id.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.entra_id import EntraIdProvider
|
||||
|
||||
|
||||
class TestEntraIdProvider:
    """Tests for EntraIdProvider endpoint URLs, scopes and settings accessors."""

    @pytest.fixture(autouse=True)
    def setup_method(self):
        # Canonical settings; for Entra ID the "domain" field holds the tenant ID.
        self.valid_settings = Oauth2Settings(
            provider="entra_id",
            domain="tenant-id-abcdef123456",
            client_id="test-client-id",
            audience="test-audience",
            extra={"scope": "openid profile email api://crewai-cli-dev/read"},
        )
        self.provider = EntraIdProvider(self.valid_settings)

    @staticmethod
    def _provider_for(domain):
        # Helper: provider for an arbitrary tenant, without extra settings.
        return EntraIdProvider(
            Oauth2Settings(
                provider="entra_id",
                domain=domain,
                client_id="test-client",
                audience="test-audience",
            )
        )

    @staticmethod
    def _provider_with_scope(scope):
        # Helper: provider for the canonical tenant with a custom scope string.
        return EntraIdProvider(
            Oauth2Settings(
                provider="entra_id",
                domain="tenant-id-abcdef123456",
                client_id="test-client-id",
                audience="test-audience",
                extra={"scope": scope},
            )
        )

    def test_initialization_with_valid_settings(self):
        provider = EntraIdProvider(self.valid_settings)
        assert provider.settings == self.valid_settings
        assert provider.settings.provider == "entra_id"
        assert provider.settings.domain == "tenant-id-abcdef123456"
        assert provider.settings.client_id == "test-client-id"
        assert provider.settings.audience == "test-audience"

    def test_get_authorize_url(self):
        assert (
            self.provider.get_authorize_url()
            == "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
        )

    def test_get_authorize_url_with_different_domain(self):
        provider = self._provider_for("my-company.entra.id")
        assert (
            provider.get_authorize_url()
            == "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
        )

    def test_get_token_url(self):
        assert (
            self.provider.get_token_url()
            == "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
        )

    def test_get_token_url_with_different_domain(self):
        provider = self._provider_for("another-domain.entra.id")
        assert (
            provider.get_token_url()
            == "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
        )

    def test_get_jwks_url(self):
        assert (
            self.provider.get_jwks_url()
            == "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
        )

    def test_get_jwks_url_with_different_domain(self):
        provider = self._provider_for("dev.entra.id")
        assert (
            provider.get_jwks_url()
            == "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
        )

    def test_get_issuer(self):
        assert (
            self.provider.get_issuer()
            == "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
        )

    def test_get_issuer_with_different_domain(self):
        provider = self._provider_for("other-tenant-id-xpto")
        assert (
            provider.get_issuer()
            == "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
        )

    def test_get_audience(self):
        assert self.provider.get_audience() == "test-audience"

    def test_get_audience_assertion_error_when_none(self):
        provider = EntraIdProvider(
            Oauth2Settings(
                provider="entra_id",
                domain="test-tenant-id",
                client_id="test-client-id",
                audience=None,
            )
        )
        with pytest.raises(ValueError, match="Audience is required"):
            provider.get_audience()

    def test_get_client_id(self):
        assert self.provider.get_client_id() == "test-client-id"

    def test_get_required_fields(self):
        assert set(self.provider.get_required_fields()) == {"scope"}

    def test_get_oauth_scopes(self):
        provider = self._provider_with_scope("api://crewai-cli-dev/read")
        assert provider.get_oauth_scopes() == [
            "openid",
            "profile",
            "email",
            "api://crewai-cli-dev/read",
        ]

    def test_get_oauth_scopes_with_multiple_custom_scopes(self):
        provider = self._provider_with_scope(
            "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
        )
        assert provider.get_oauth_scopes() == [
            "openid",
            "profile",
            "email",
            "api://crewai-cli-dev/read",
            "api://crewai-cli-dev/write",
            "custom-scope1",
            "custom-scope2",
        ]

    def test_base_url(self):
        assert (
            self.provider._base_url()
            == "https://login.microsoftonline.com/tenant-id-abcdef123456"
        )
|
||||
138
lib/cli/tests/authentication/providers/test_keycloak.py
Normal file
138
lib/cli/tests/authentication/providers/test_keycloak.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import pytest
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.keycloak import KeycloakProvider
|
||||
|
||||
|
||||
class TestKeycloakProvider:
    """Tests for KeycloakProvider endpoint URLs and settings accessors."""

    @pytest.fixture(autouse=True)
    def setup_method(self):
        # Canonical settings; Keycloak additionally requires a realm.
        self.valid_settings = Oauth2Settings(
            provider="keycloak",
            domain="keycloak.example.com",
            client_id="test-client-id",
            audience="test-audience",
            extra={"realm": "test-realm"},
        )
        self.provider = KeycloakProvider(self.valid_settings)

    @staticmethod
    def _provider_for(domain, realm, client_id="test-client"):
        # Helper: provider for an arbitrary Keycloak host/realm pair.
        return KeycloakProvider(
            Oauth2Settings(
                provider="keycloak",
                domain=domain,
                client_id=client_id,
                audience="test-audience",
                extra={"realm": realm},
            )
        )

    def test_initialization_with_valid_settings(self):
        provider = KeycloakProvider(self.valid_settings)
        assert provider.settings == self.valid_settings
        assert provider.settings.provider == "keycloak"
        assert provider.settings.domain == "keycloak.example.com"
        assert provider.settings.client_id == "test-client-id"
        assert provider.settings.audience == "test-audience"
        assert provider.settings.extra.get("realm") == "test-realm"

    def test_get_authorize_url(self):
        assert (
            self.provider.get_authorize_url()
            == "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/auth/device"
        )

    def test_get_authorize_url_with_different_domain(self):
        provider = self._provider_for("auth.company.com", "my-realm")
        assert (
            provider.get_authorize_url()
            == "https://auth.company.com/realms/my-realm/protocol/openid-connect/auth/device"
        )

    def test_get_token_url(self):
        assert (
            self.provider.get_token_url()
            == "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/token"
        )

    def test_get_token_url_with_different_domain(self):
        provider = self._provider_for("sso.enterprise.com", "enterprise-realm")
        assert (
            provider.get_token_url()
            == "https://sso.enterprise.com/realms/enterprise-realm/protocol/openid-connect/token"
        )

    def test_get_jwks_url(self):
        assert (
            self.provider.get_jwks_url()
            == "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/certs"
        )

    def test_get_jwks_url_with_different_domain(self):
        provider = self._provider_for("identity.org", "org-realm")
        assert (
            provider.get_jwks_url()
            == "https://identity.org/realms/org-realm/protocol/openid-connect/certs"
        )

    def test_get_issuer(self):
        assert (
            self.provider.get_issuer()
            == "https://keycloak.example.com/realms/test-realm"
        )

    def test_get_issuer_with_different_domain(self):
        provider = self._provider_for("login.myapp.io", "app-realm")
        assert provider.get_issuer() == "https://login.myapp.io/realms/app-realm"

    def test_get_audience(self):
        assert self.provider.get_audience() == "test-audience"

    def test_get_client_id(self):
        assert self.provider.get_client_id() == "test-client-id"

    def test_get_required_fields(self):
        assert self.provider.get_required_fields() == ["realm"]

    def test_oauth2_base_url(self):
        assert self.provider._oauth2_base_url() == "https://keycloak.example.com"

    def test_oauth2_base_url_strips_https_prefix(self):
        provider = self._provider_for(
            "https://keycloak.example.com", "test-realm", client_id="test-client-id"
        )
        assert provider._oauth2_base_url() == "https://keycloak.example.com"

    def test_oauth2_base_url_strips_http_prefix(self):
        provider = self._provider_for(
            "http://keycloak.example.com", "test-realm", client_id="test-client-id"
        )
        assert provider._oauth2_base_url() == "https://keycloak.example.com"
|
||||
257
lib/cli/tests/authentication/providers/test_okta.py
Normal file
257
lib/cli/tests/authentication/providers/test_okta.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import pytest
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.okta import OktaProvider
|
||||
|
||||
|
||||
class TestOktaProvider:
    """Tests for OktaProvider endpoint URLs across authorization-server modes."""

    @pytest.fixture(autouse=True)
    def setup_method(self):
        # Canonical settings using the implicit "default" authorization server.
        self.valid_settings = Oauth2Settings(
            provider="okta",
            domain="test-domain.okta.com",
            client_id="test-client-id",
            audience="test-audience",
        )
        self.provider = OktaProvider(self.valid_settings)

    @staticmethod
    def _provider_for(domain):
        # Helper: provider on an arbitrary domain, "default" auth server.
        return OktaProvider(
            Oauth2Settings(
                provider="okta",
                domain=domain,
                client_id="test-client",
                audience="test-audience",
            )
        )

    @staticmethod
    def _custom_server_provider():
        # Helper: provider pointed at a named custom authorization server.
        return OktaProvider(
            Oauth2Settings(
                provider="okta",
                domain="test-domain.okta.com",
                client_id="test-client-id",
                audience=None,
                extra={
                    "using_org_auth_server": False,
                    "authorization_server_name": "my_auth_server_xxxAAA777",
                },
            )
        )

    @staticmethod
    def _org_server_provider():
        # Helper: provider using the org-level authorization server.
        return OktaProvider(
            Oauth2Settings(
                provider="okta",
                domain="test-domain.okta.com",
                client_id="test-client-id",
                audience=None,
                extra={
                    "using_org_auth_server": True,
                    "authorization_server_name": None,
                },
            )
        )

    def test_initialization_with_valid_settings(self):
        provider = OktaProvider(self.valid_settings)
        assert provider.settings == self.valid_settings
        assert provider.settings.provider == "okta"
        assert provider.settings.domain == "test-domain.okta.com"
        assert provider.settings.client_id == "test-client-id"
        assert provider.settings.audience == "test-audience"

    def test_get_authorize_url(self):
        assert (
            self.provider.get_authorize_url()
            == "https://test-domain.okta.com/oauth2/default/v1/device/authorize"
        )

    def test_get_authorize_url_with_different_domain(self):
        provider = self._provider_for("my-company.okta.com")
        assert (
            provider.get_authorize_url()
            == "https://my-company.okta.com/oauth2/default/v1/device/authorize"
        )

    def test_get_authorize_url_with_custom_authorization_server_name(self):
        provider = self._custom_server_provider()
        assert (
            provider.get_authorize_url()
            == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
        )

    def test_get_authorize_url_when_using_org_auth_server(self):
        provider = self._org_server_provider()
        assert (
            provider.get_authorize_url()
            == "https://test-domain.okta.com/oauth2/v1/device/authorize"
        )

    def test_get_token_url(self):
        assert (
            self.provider.get_token_url()
            == "https://test-domain.okta.com/oauth2/default/v1/token"
        )

    def test_get_token_url_with_different_domain(self):
        provider = self._provider_for("another-domain.okta.com")
        assert (
            provider.get_token_url()
            == "https://another-domain.okta.com/oauth2/default/v1/token"
        )

    def test_get_token_url_with_custom_authorization_server_name(self):
        provider = self._custom_server_provider()
        assert (
            provider.get_token_url()
            == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
        )

    def test_get_token_url_when_using_org_auth_server(self):
        provider = self._org_server_provider()
        assert (
            provider.get_token_url() == "https://test-domain.okta.com/oauth2/v1/token"
        )

    def test_get_jwks_url(self):
        assert (
            self.provider.get_jwks_url()
            == "https://test-domain.okta.com/oauth2/default/v1/keys"
        )

    def test_get_jwks_url_with_different_domain(self):
        provider = self._provider_for("dev.okta.com")
        assert provider.get_jwks_url() == "https://dev.okta.com/oauth2/default/v1/keys"

    def test_get_jwks_url_with_custom_authorization_server_name(self):
        provider = self._custom_server_provider()
        assert (
            provider.get_jwks_url()
            == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
        )

    def test_get_jwks_url_when_using_org_auth_server(self):
        provider = self._org_server_provider()
        assert provider.get_jwks_url() == "https://test-domain.okta.com/oauth2/v1/keys"

    def test_get_issuer(self):
        assert (
            self.provider.get_issuer() == "https://test-domain.okta.com/oauth2/default"
        )

    def test_get_issuer_with_different_domain(self):
        provider = self._provider_for("prod.okta.com")
        assert provider.get_issuer() == "https://prod.okta.com/oauth2/default"

    def test_get_issuer_with_custom_authorization_server_name(self):
        provider = self._custom_server_provider()
        assert (
            provider.get_issuer()
            == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
        )

    def test_get_issuer_when_using_org_auth_server(self):
        provider = self._org_server_provider()
        assert provider.get_issuer() == "https://test-domain.okta.com"

    def test_get_audience(self):
        assert self.provider.get_audience() == "test-audience"

    def test_get_audience_assertion_error_when_none(self):
        provider = OktaProvider(
            Oauth2Settings(
                provider="okta",
                domain="test-domain.okta.com",
                client_id="test-client-id",
                audience=None,
            )
        )
        with pytest.raises(ValueError, match="Audience is required"):
            provider.get_audience()

    def test_get_client_id(self):
        assert self.provider.get_client_id() == "test-client-id"

    def test_get_required_fields(self):
        assert set(self.provider.get_required_fields()) == {
            "authorization_server_name",
            "using_org_auth_server",
        }

    def test_oauth2_base_url(self):
        assert (
            self.provider._oauth2_base_url()
            == "https://test-domain.okta.com/oauth2/default"
        )

    def test_oauth2_base_url_with_custom_authorization_server_name(self):
        provider = self._custom_server_provider()
        assert (
            provider._oauth2_base_url()
            == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
        )

    def test_oauth2_base_url_when_using_org_auth_server(self):
        provider = self._org_server_provider()
        assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"
|
||||
100
lib/cli/tests/authentication/providers/test_workos.py
Normal file
100
lib/cli/tests/authentication/providers/test_workos.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import pytest
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.workos import WorkosProvider
|
||||
|
||||
|
||||
class TestWorkosProvider:
    """Tests for WorkosProvider endpoint URLs and settings accessors."""

    @pytest.fixture(autouse=True)
    def setup_method(self):
        # Canonical settings shared by the accessor tests below.
        self.valid_settings = Oauth2Settings(
            provider="workos",
            domain="login.company.com",
            client_id="test-client-id",
            audience="test-audience",
        )
        self.provider = WorkosProvider(self.valid_settings)

    @staticmethod
    def _provider_for(domain):
        # Helper: provider wired to an arbitrary WorkOS domain.
        return WorkosProvider(
            Oauth2Settings(
                provider="workos",
                domain=domain,
                client_id="test-client",
                audience="test-audience",
            )
        )

    def test_initialization_with_valid_settings(self):
        provider = WorkosProvider(self.valid_settings)
        assert provider.settings == self.valid_settings
        assert provider.settings.provider == "workos"
        assert provider.settings.domain == "login.company.com"
        assert provider.settings.client_id == "test-client-id"
        assert provider.settings.audience == "test-audience"

    def test_get_authorize_url(self):
        assert (
            self.provider.get_authorize_url()
            == "https://login.company.com/oauth2/device_authorization"
        )

    def test_get_authorize_url_with_different_domain(self):
        provider = self._provider_for("login.example.com")
        assert (
            provider.get_authorize_url()
            == "https://login.example.com/oauth2/device_authorization"
        )

    def test_get_token_url(self):
        assert self.provider.get_token_url() == "https://login.company.com/oauth2/token"

    def test_get_token_url_with_different_domain(self):
        provider = self._provider_for("api.workos.com")
        assert provider.get_token_url() == "https://api.workos.com/oauth2/token"

    def test_get_jwks_url(self):
        assert self.provider.get_jwks_url() == "https://login.company.com/oauth2/jwks"

    def test_get_jwks_url_with_different_domain(self):
        provider = self._provider_for("auth.enterprise.com")
        assert provider.get_jwks_url() == "https://auth.enterprise.com/oauth2/jwks"

    def test_get_issuer(self):
        assert self.provider.get_issuer() == "https://login.company.com"

    def test_get_issuer_with_different_domain(self):
        provider = self._provider_for("sso.company.com")
        assert provider.get_issuer() == "https://sso.company.com"

    def test_get_audience(self):
        assert self.provider.get_audience() == "test-audience"

    def test_get_audience_fallback_to_default(self):
        # A missing audience falls back to the empty string rather than raising.
        provider = WorkosProvider(
            Oauth2Settings(
                provider="workos",
                domain="login.company.com",
                client_id="test-client-id",
                audience=None,
            )
        )
        assert provider.get_audience() == ""

    def test_get_client_id(self):
        assert self.provider.get_client_id() == "test-client-id"
|
||||
348
lib/cli/tests/authentication/test_auth_main.py
Normal file
348
lib/cli/tests/authentication/test_auth_main.py
Normal file
@@ -0,0 +1,348 @@
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
from crewai_cli.authentication.main import AuthenticationCommand
|
||||
from crewai_cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
)
|
||||
|
||||
|
||||
class TestAuthenticationCommand:
|
||||
def setup_method(self):
|
||||
# Mock Settings so we always use default constants regardless of local config.
|
||||
with patch("crewai_cli.authentication.main.Settings") as mock_settings:
|
||||
instance = mock_settings.return_value
|
||||
instance.oauth2_provider = "workos"
|
||||
instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN
|
||||
instance.oauth2_client_id = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID
|
||||
instance.oauth2_audience = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE
|
||||
instance.oauth2_extra = {}
|
||||
self.auth_command = AuthenticationCommand()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,expected_urls",
|
||||
[
|
||||
(
|
||||
"workos",
|
||||
{
|
||||
"device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization",
|
||||
"token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token",
|
||||
"client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
"domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("crewai_cli.authentication.main.AuthenticationCommand._get_device_code")
|
||||
@patch(
|
||||
"crewai_cli.authentication.main.AuthenticationCommand._display_auth_instructions"
|
||||
)
|
||||
@patch("crewai_cli.authentication.main.AuthenticationCommand._poll_for_token")
|
||||
@patch("crewai_core.auth.oauth2.console.print")
|
||||
def test_login(
|
||||
self,
|
||||
mock_console_print,
|
||||
mock_poll,
|
||||
mock_display,
|
||||
mock_get_device,
|
||||
user_provider,
|
||||
expected_urls,
|
||||
):
|
||||
mock_get_device.return_value = {
|
||||
"device_code": "test_code",
|
||||
"user_code": "123456",
|
||||
}
|
||||
|
||||
self.auth_command.login()
|
||||
|
||||
mock_console_print.assert_called_once_with(
|
||||
"Signing in to CrewAI AMP...\n", style="bold blue"
|
||||
)
|
||||
mock_get_device.assert_called_once()
|
||||
mock_display.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"}
|
||||
)
|
||||
mock_poll.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"},
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_client_id()
|
||||
== expected_urls["client_id"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_audience()
|
||||
== expected_urls["audience"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
||||
)
|
||||
|
||||
@patch("crewai_core.auth.oauth2.webbrowser")
@patch("crewai_core.auth.oauth2.console.print")
def test_display_auth_instructions(self, mock_console_print, mock_webbrowser):
    """Instructions print the verification link and user code in order, then
    open the verification URL in the browser exactly once."""
    verification_url = "https://example.com/auth"
    payload = {
        "verification_uri_complete": verification_url,
        "user_code": "123456",
    }

    self.auth_command._display_auth_instructions(payload)

    mock_console_print.assert_has_calls(
        [
            call("1. Navigate to: ", verification_url),
            call("2. Enter the following code: ", "123456"),
        ]
    )
    mock_webbrowser.open.assert_called_once_with(verification_url)
|
||||
|
||||
@pytest.mark.parametrize(
    "user_provider,jwt_config",
    [
        (
            "workos",
            {
                "jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks",
                "issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}",
                "audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
            },
        ),
    ],
)
@pytest.mark.parametrize("has_expiration", [True, False])
# Decorators apply bottom-up: mock_save_tokens patches TokenManager.save_tokens,
# mock_validate_jwt patches validate_jwt_token.
@patch("crewai_core.auth.oauth2.validate_jwt_token")
@patch("crewai_core.auth.oauth2.TokenManager.save_tokens")
def test_validate_and_save_token(
    self,
    mock_save_tokens,
    mock_validate_jwt,
    user_provider,
    jwt_config,
    has_expiration,
):
    """The access token is validated against the provider's JWKS/issuer/audience
    and then persisted; a missing 'exp' claim falls back to expiration 0.
    """
    # Imported lazily to pick up the freshly patched module state.
    from crewai_cli.authentication.main import Oauth2Settings
    from crewai_cli.authentication.providers.workos import WorkosProvider

    if user_provider == "workos":
        self.auth_command.oauth2_provider = WorkosProvider(
            settings=Oauth2Settings(
                provider=user_provider,
                client_id="test-client-id",
                domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
                audience=jwt_config["audience"],
            )
        )

    token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}

    if has_expiration:
        # A far-future timestamp so the token is not considered expired.
        future_timestamp = int((datetime.now() + timedelta(days=100)).timestamp())
        decoded_token = {"exp": future_timestamp}
    else:
        decoded_token = {}

    mock_validate_jwt.return_value = decoded_token

    self.auth_command._validate_and_save_token(token_data)

    mock_validate_jwt.assert_called_once_with(
        jwt_token="test_access_token",
        jwks_url=jwt_config["jwks_url"],
        issuer=jwt_config["issuer"],
        audience=jwt_config["audience"],
    )

    if has_expiration:
        mock_save_tokens.assert_called_once_with(
            "test_access_token", future_timestamp
        )
    else:
        # No 'exp' claim -> expiration defaults to 0.
        mock_save_tokens.assert_called_once_with("test_access_token", 0)
|
||||
|
||||
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_cli.authentication.main.Settings")
@patch("crewai_core.auth.oauth2.console.print")
def test_login_to_tool_repository_success(
    self, mock_console_print, mock_settings, mock_tool_command
):
    """On success, the tool-repository login runs once and the console shows
    the progress, success, and organization messages in order.
    """
    mock_tool_instance = MagicMock()
    mock_tool_command.return_value = mock_tool_instance

    # Settings supply the organization shown in the success message.
    mock_settings_instance = MagicMock()
    mock_settings_instance.org_name = "Test Org"
    mock_settings_instance.org_uuid = "test-uuid-123"
    mock_settings.return_value = mock_settings_instance

    self.auth_command._login_to_tool_repository()

    mock_tool_command.assert_called_once()
    mock_tool_instance.login.assert_called_once()

    expected_calls = [
        call(
            "Now logging you in to the Tool Repository... ",
            style="bold blue",
            end="",
        ),
        call("Success!\n", style="bold green"),
        call(
            "You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
            style="green",
        ),
    ]
    mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_cli.tools.main.ToolCommand")
@patch("crewai_core.auth.oauth2.console.print")
def test_login_to_tool_repository_error(
    self, mock_console_print, mock_tool_command
):
    """A tool-repository login failure is non-fatal: the exception is swallowed
    and warning messages are printed instead of raising.
    """
    mock_tool_instance = MagicMock()
    mock_tool_instance.login.side_effect = Exception("Tool repository error")
    mock_tool_command.return_value = mock_tool_instance

    # Must not raise despite the login failure.
    self.auth_command._login_to_tool_repository()

    mock_tool_command.assert_called_once()
    mock_tool_instance.login.assert_called_once()

    expected_calls = [
        call(
            "Now logging you in to the Tool Repository... ",
            style="bold blue",
            end="",
        ),
        call(
            "\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
            style="yellow",
        ),
        call(
            "Other features will work normally, but you may experience limitations with downloading and publishing tools.\nRun [bold]crewai login[/bold] to try logging in again.\n",
            style="yellow",
        ),
    ]
    mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_core.auth.oauth2.httpx.post")
def test_get_device_code(self, mock_post):
    """The device-code request posts client id, space-joined scopes, and
    audience to the provider's authorize URL and returns the JSON payload."""
    device_payload = {
        "device_code": "test_device_code",
        "user_code": "123456",
        "verification_uri_complete": "https://example.com/auth",
    }
    http_response = MagicMock()
    http_response.json.return_value = device_payload
    mock_post.return_value = http_response

    # Stub the provider so only request construction is under test.
    provider = MagicMock()
    provider.get_authorize_url.return_value = "https://example.com/device"
    provider.get_client_id.return_value = "test_client"
    provider.get_audience.return_value = "test_audience"
    provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
    self.auth_command.oauth2_provider = provider

    result = self.auth_command._get_device_code()

    # Scope list must be collapsed into a single space-separated field.
    mock_post.assert_called_once_with(
        url="https://example.com/device",
        data={
            "client_id": "test_client",
            "scope": "openid profile email",
            "audience": "test_audience",
        },
        timeout=20,
    )
    assert result == {
        "device_code": "test_device_code",
        "user_code": "123456",
        "verification_uri_complete": "https://example.com/auth",
    }
|
||||
|
||||
@patch("crewai_core.auth.oauth2.httpx.post")
@patch("crewai_core.auth.oauth2.console.print")
def test_poll_for_token_success(self, mock_console_print, mock_post):
    """A 200 response from the token endpoint validates/saves the token,
    chains into the tool-repository login, and prints the welcome banner.
    """
    mock_response_success = MagicMock()
    mock_response_success.status_code = 200
    mock_response_success.json.return_value = {
        "access_token": "test_access_token",
        "id_token": "test_id_token",
    }
    mock_post.return_value = mock_response_success

    device_code_data = {"device_code": "test_device_code", "interval": 1}

    # Patch the follow-up steps so only the polling logic itself runs.
    with (
        patch.object(
            self.auth_command, "_validate_and_save_token"
        ) as mock_validate,
        patch.object(
            self.auth_command, "_login_to_tool_repository"
        ) as mock_tool_login,
    ):
        self.auth_command.oauth2_provider = MagicMock()
        self.auth_command.oauth2_provider.get_token_url.return_value = (
            "https://example.com/token"
        )
        self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"

        self.auth_command._poll_for_token(device_code_data)

        # Standard OAuth2 device-code grant exchange.
        mock_post.assert_called_once_with(
            "https://example.com/token",
            data={
                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                "device_code": "test_device_code",
                "client_id": "test_client",
            },
            timeout=30,
        )

        mock_validate.assert_called_once()
        mock_tool_login.assert_called_once()

        expected_calls = [
            call("\nWaiting for authentication... ", style="bold blue", end=""),
            call("Success!", style="bold green"),
            call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"),
        ]
        mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_core.auth.oauth2.httpx.post")
@patch("crewai_core.auth.oauth2.console.print")
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
    """When the grant stays 'authorization_pending', polling eventually gives
    up and reports a timeout instead of raising."""
    pending_response = MagicMock()
    pending_response.status_code = 400
    pending_response.json.return_value = {"error": "authorization_pending"}
    mock_post.return_value = pending_response

    # Tiny interval keeps the retry loop fast for the test.
    self.auth_command._poll_for_token(
        {"device_code": "test_device_code", "interval": 0.1}
    )

    mock_console_print.assert_any_call(
        "Timeout: Failed to get the token. Please try again.", style="bold red"
    )
|
||||
|
||||
@patch("crewai_core.auth.oauth2.httpx.post")
def test_poll_for_token_error(self, mock_post):
    """A non-retryable OAuth error ('access_denied') aborts polling by
    raising httpx.HTTPError."""
    denied_response = MagicMock()
    denied_response.status_code = 400
    denied_response.json.return_value = {
        "error": "access_denied",
        "error_description": "User denied access",
    }
    mock_post.return_value = denied_response

    with pytest.raises(httpx.HTTPError):
        self.auth_command._poll_for_token(
            {"device_code": "test_device_code", "interval": 1}
        )
|
||||
107
lib/cli/tests/authentication/test_utils.py
Normal file
107
lib/cli/tests/authentication/test_utils.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from crewai_cli.authentication.utils import validate_jwt_token
|
||||
|
||||
|
||||
# Class-level patches apply to every test method; decorators are applied
# bottom-up, so each method receives (mock_jwt, mock_pyjwkclient) after self.
@patch("crewai_core.auth.utils.PyJWKClient", return_value=MagicMock())
@patch("crewai_core.auth.utils.jwt")
class TestUtils(unittest.TestCase):
    """Tests for validate_jwt_token: happy path plus each PyJWT error type,
    all of which must surface as exceptions to the caller."""

    def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
        """Happy path: the token is decoded with the JWKS signing key and the
        expected algorithm/audience/issuer/claim-verification options."""
        mock_jwt.decode.return_value = {"exp": 1719859200}

        # Create signing key object mock with a .key attribute
        mock_pyjwkclient.return_value.get_signing_key_from_jwt.return_value = MagicMock(
            key="mock_signing_key"
        )

        jwt_token = "aaaaa.bbbbbb.cccccc"  # noqa: S105

        decoded_token = validate_jwt_token(
            jwt_token=jwt_token,
            jwks_url="https://mock_jwks_url",
            issuer="https://mock_issuer",
            audience="app_id_xxxx",
        )

        mock_jwt.decode.assert_called_with(
            jwt_token,
            "mock_signing_key",
            algorithms=["RS256"],
            audience="app_id_xxxx",
            issuer="https://mock_issuer",
            leeway=10.0,
            options={
                "verify_signature": True,
                "verify_exp": True,
                "verify_nbf": True,
                "verify_iat": True,
                "require": ["exp", "iat", "iss", "aud", "sub"],
            },
        )
        mock_pyjwkclient.assert_called_once_with("https://mock_jwks_url")
        self.assertEqual(decoded_token, {"exp": 1719859200})

    def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
        """An expired signature must propagate as an exception."""
        mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
        with self.assertRaises(Exception):  # noqa: B017
            validate_jwt_token(
                jwt_token="aaaaa.bbbbbb.cccccc",  # noqa: S106
                jwks_url="https://mock_jwks_url",
                issuer="https://mock_issuer",
                audience="app_id_xxxx",
            )

    def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
        """An audience mismatch must propagate as an exception."""
        mock_jwt.decode.side_effect = jwt.InvalidAudienceError
        with self.assertRaises(Exception):  # noqa: B017
            validate_jwt_token(
                jwt_token="aaaaa.bbbbbb.cccccc",  # noqa: S106
                jwks_url="https://mock_jwks_url",
                issuer="https://mock_issuer",
                audience="app_id_xxxx",
            )

    def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
        """An issuer mismatch must propagate as an exception."""
        mock_jwt.decode.side_effect = jwt.InvalidIssuerError
        with self.assertRaises(Exception):  # noqa: B017
            validate_jwt_token(
                jwt_token="aaaaa.bbbbbb.cccccc",  # noqa: S106
                jwks_url="https://mock_jwks_url",
                issuer="https://mock_issuer",
                audience="app_id_xxxx",
            )

    def test_validate_jwt_token_missing_required_claims(
        self, mock_jwt, mock_pyjwkclient
    ):
        """A token missing required claims must propagate as an exception."""
        mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
        with self.assertRaises(Exception):  # noqa: B017
            validate_jwt_token(
                jwt_token="aaaaa.bbbbbb.cccccc",  # noqa: S106
                jwks_url="https://mock_jwks_url",
                issuer="https://mock_issuer",
                audience="app_id_xxxx",
            )

    def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
        """A JWKS client failure must propagate as an exception."""
        mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
        with self.assertRaises(Exception):  # noqa: B017
            validate_jwt_token(
                jwt_token="aaaaa.bbbbbb.cccccc",  # noqa: S106
                jwks_url="https://mock_jwks_url",
                issuer="https://mock_issuer",
                audience="app_id_xxxx",
            )

    def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
        """A structurally invalid token must propagate as an exception."""
        mock_jwt.decode.side_effect = jwt.InvalidTokenError
        with self.assertRaises(Exception):  # noqa: B017
            validate_jwt_token(
                jwt_token="aaaaa.bbbbbb.cccccc",  # noqa: S106
                jwks_url="https://mock_jwks_url",
                issuer="https://mock_issuer",
                audience="app_id_xxxx",
            )
|
||||
1
lib/cli/tests/deploy/__init__.py
Normal file
1
lib/cli/tests/deploy/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for CLI deploy."""
|
||||
271
lib/cli/tests/deploy/test_deploy_main.py
Normal file
271
lib/cli/tests/deploy/test_deploy_main.py
Normal file
@@ -0,0 +1,271 @@
|
||||
import sys
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from crewai_cli.deploy.main import DeployCommand
|
||||
from crewai_cli.utils import parse_toml
|
||||
|
||||
|
||||
class TestDeployCommand(unittest.TestCase):
    """Tests for DeployCommand: construction, response validation, output
    formatting, and the deploy/list/status/logs/remove operations against a
    mocked PlusAPI client."""

    # NOTE(review): patching via decorators on setUp means the patches are
    # only active while setUp runs; DeployCommand captures the mocks during
    # construction, and the mock objects remain reachable as attributes.
    @patch("crewai_cli.command.get_auth_token")
    @patch("crewai_cli.deploy.main.get_project_name")
    @patch("crewai_cli.command.PlusAPI")
    def setUp(
        self,
        mock_plus_api,
        mock_get_project_name,
        mock_get_auth_token,
    ):
        self.mock_get_auth_token = mock_get_auth_token
        self.mock_get_project_name = mock_get_project_name
        self.mock_plus_api = mock_plus_api

        self.mock_get_auth_token.return_value = "test_token"
        self.mock_get_project_name.return_value = "test_project"

        self.deploy_command = DeployCommand()
        # Handle on the mocked PlusAPI instance the command will call into.
        self.mock_client = self.deploy_command.plus_api_client

    def test_init_success(self):
        """Construction resolves the project name and authenticates PlusAPI."""
        self.assertEqual(self.deploy_command.project_name, "test_project")
        self.mock_plus_api.assert_called_once_with(api_key="test_token")

    @patch("crewai_cli.command.get_auth_token")
    def test_init_failure(self, mock_get_auth_token):
        """An auth failure during construction exits the process."""
        mock_get_auth_token.side_effect = Exception("Auth failed")

        with self.assertRaises(SystemExit):
            DeployCommand()

    def test_validate_response_successful_response(self):
        """A successful response passes validation silently (no output)."""
        mock_response = Mock(spec=httpx.Response)
        mock_response.json.return_value = {"message": "Success"}
        mock_response.status_code = 200
        mock_response.is_success = True

        with patch("sys.stdout", new=StringIO()) as fake_out:
            self.deploy_command._validate_response(mock_response)
            assert fake_out.getvalue() == ""

    def test_validate_response_json_decode_error(self):
        """An unparseable body prints the status and raw content, then exits."""
        mock_response = Mock(spec=httpx.Response)
        mock_response.json.side_effect = json.JSONDecodeError("Decode error", "", 0)
        mock_response.status_code = 500
        mock_response.content = b"Invalid JSON"

        with patch("sys.stdout", new=StringIO()) as fake_out:
            with pytest.raises(SystemExit):
                self.deploy_command._validate_response(mock_response)
            output = fake_out.getvalue()
            assert (
                "Failed to parse response from Enterprise API failed. Details:"
                in output
            )
            assert "Status Code: 500" in output
            assert "Response:\nInvalid JSON" in output

    def test_validate_response_422_error(self):
        """A 422 prints each field's validation errors, then exits."""
        mock_response = Mock(spec=httpx.Response)
        mock_response.json.return_value = {
            "field1": ["Error message 1"],
            "field2": ["Error message 2"],
        }
        mock_response.status_code = 422
        mock_response.is_success = False

        with patch("sys.stdout", new=StringIO()) as fake_out:
            with pytest.raises(SystemExit):
                self.deploy_command._validate_response(mock_response)
            output = fake_out.getvalue()
            assert (
                "Failed to complete operation. Please fix the following errors:"
                in output
            )
            assert "Field1 Error message 1" in output
            assert "Field2 Error message 2" in output

    def test_validate_response_other_error(self):
        """Any other unsuccessful status prints the error details, then exits."""
        mock_response = Mock(spec=httpx.Response)
        mock_response.json.return_value = {"error": "Something went wrong"}
        mock_response.status_code = 500
        mock_response.is_success = False

        with patch("sys.stdout", new=StringIO()) as fake_out:
            with pytest.raises(SystemExit):
                self.deploy_command._validate_response(mock_response)
            output = fake_out.getvalue()
            assert "Request to Enterprise API failed. Details:" in output
            assert "Details:\nSomething went wrong" in output

    def test_standard_no_param_error_message(self):
        """The shared missing-UUID message mentions 'No UUID provided'."""
        with patch("sys.stdout", new=StringIO()) as fake_out:
            self.deploy_command._standard_no_param_error_message()
            self.assertIn("No UUID provided", fake_out.getvalue())

    def test_display_deployment_info(self):
        """Deployment info output includes the banner, UUID, and status."""
        with patch("sys.stdout", new=StringIO()) as fake_out:
            self.deploy_command._display_deployment_info(
                {"uuid": "test-uuid", "status": "deployed"}
            )
            self.assertIn("Deploying the crew...", fake_out.getvalue())
            self.assertIn("test-uuid", fake_out.getvalue())
            self.assertIn("deployed", fake_out.getvalue())

    def test_display_logs(self):
        """Log entries are rendered as 'timestamp - LEVEL: message'."""
        with patch("sys.stdout", new=StringIO()) as fake_out:
            self.deploy_command._display_logs(
                [{"timestamp": "2023-01-01", "level": "INFO", "message": "Test log"}]
            )
            self.assertIn("2023-01-01 - INFO: Test log", fake_out.getvalue())

    @patch("crewai_cli.deploy.main.DeployCommand._display_deployment_info")
    def test_deploy_with_uuid(self, mock_display):
        """Deploying with an explicit UUID uses the by-UUID API endpoint."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"uuid": "test-uuid"}
        self.mock_client.deploy_by_uuid.return_value = mock_response

        self.deploy_command.deploy(uuid="test-uuid", skip_validate=True)

        self.mock_client.deploy_by_uuid.assert_called_once_with("test-uuid")
        mock_display.assert_called_once_with({"uuid": "test-uuid"})

    @patch("crewai_cli.deploy.main.DeployCommand._display_deployment_info")
    def test_deploy_with_project_name(self, mock_display):
        """Without a UUID, deploy falls back to the project-name endpoint."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"uuid": "test-uuid"}
        self.mock_client.deploy_by_name.return_value = mock_response

        self.deploy_command.deploy(skip_validate=True)

        self.mock_client.deploy_by_name.assert_called_once_with("test_project")
        mock_display.assert_called_once_with({"uuid": "test-uuid"})

    @patch("crewai_cli.deploy.main.fetch_and_json_env_file")
    @patch("crewai_cli.deploy.main.git.Repository.origin_url")
    @patch("builtins.input")
    def test_create_crew(self, mock_input, mock_git_origin_url, mock_fetch_env):
        """Creating a crew reads env vars and the git origin, accepts the
        defaults at the prompt, and reports the new deployment UUID."""
        mock_fetch_env.return_value = {"ENV_VAR": "value"}
        mock_git_origin_url.return_value = "https://github.com/test/repo.git"
        # Empty input accepts the suggested defaults at each prompt.
        mock_input.return_value = ""

        mock_response = MagicMock()
        mock_response.status_code = 201
        mock_response.json.return_value = {"uuid": "new-uuid", "status": "created"}
        self.mock_client.create_crew.return_value = mock_response

        with patch("sys.stdout", new=StringIO()) as fake_out:
            self.deploy_command.create_crew(skip_validate=True)
            self.assertIn("Deployment created successfully!", fake_out.getvalue())
            self.assertIn("new-uuid", fake_out.getvalue())

    def test_list_crews(self):
        """Listing prints each crew as 'Name (uuid) status'."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = [
            {"name": "Crew1", "uuid": "uuid1", "status": "active"},
            {"name": "Crew2", "uuid": "uuid2", "status": "inactive"},
        ]
        self.mock_client.list_crews.return_value = mock_response

        with patch("sys.stdout", new=StringIO()) as fake_out:
            self.deploy_command.list_crews()
            self.assertIn("Crew1 (uuid1) active", fake_out.getvalue())
            self.assertIn("Crew2 (uuid2) inactive", fake_out.getvalue())

    def test_get_crew_status(self):
        """Status lookup prints the crew's name and status."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"name": "InternalCrew", "status": "active"}
        self.mock_client.crew_status_by_name.return_value = mock_response

        with patch("sys.stdout", new=StringIO()) as fake_out:
            self.deploy_command.get_crew_status()
            self.assertIn("InternalCrew", fake_out.getvalue())
            self.assertIn("active", fake_out.getvalue())

    def test_get_crew_logs(self):
        """Fetching logs (no UUID) prints each entry in display format."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = [
            {"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"},
            {"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"},
        ]
        self.mock_client.crew_by_name.return_value = mock_response

        with patch("sys.stdout", new=StringIO()) as fake_out:
            self.deploy_command.get_crew_logs(None)
            self.assertIn("2023-01-01 - INFO: Log1", fake_out.getvalue())
            self.assertIn("2023-01-02 - ERROR: Log2", fake_out.getvalue())

    def test_remove_crew(self):
        """Removal (no UUID) deletes by project name and confirms success."""
        mock_response = MagicMock()
        mock_response.status_code = 204
        self.mock_client.delete_crew_by_name.return_value = mock_response

        with patch("sys.stdout", new=StringIO()) as fake_out:
            self.deploy_command.remove_crew(None)
            self.assertIn(
                "Crew 'test_project' removed successfully", fake_out.getvalue()
            )

    @unittest.skipIf(sys.version_info < (3, 11), "Requires Python 3.11+")
    def test_parse_toml_python_311_plus(self):
        """parse_toml handles a Poetry-style pyproject on Python 3.11+."""
        toml_content = """
        [tool.poetry]
        name = "test_project"
        version = "0.1.0"

        [tool.poetry.dependencies]
        python = "^3.11"
        crewai = { extras = ["tools"], version = ">=0.51.0,<1.0.0" }
        """
        parsed = parse_toml(toml_content)
        self.assertEqual(parsed["tool"]["poetry"]["name"], "test_project")

    @patch(
        "builtins.open",
        new_callable=unittest.mock.mock_open,
        read_data="""
        [project]
        name = "test_project"
        version = "0.1.0"
        requires-python = ">=3.10,<3.14"
        dependencies = ["crewai"]
        """,
    )
    def test_get_project_name_python_310(self, mock_open):
        """get_project_name reads [project].name from a PEP 621 pyproject."""
        from crewai_cli.utils import get_project_name

        project_name = get_project_name()
        print("project_name", project_name)
        self.assertEqual(project_name, "test_project")

    @unittest.skipIf(sys.version_info < (3, 11), "Requires Python 3.11+")
    @patch(
        "builtins.open",
        new_callable=unittest.mock.mock_open,
        read_data="""
        [project]
        name = "test_project"
        version = "0.1.0"
        requires-python = ">=3.10,<3.14"
        dependencies = ["crewai"]
        """,
    )
    def test_get_project_name_python_311_plus(self, mock_open):
        """Same as above but exercised on the 3.11+ (tomllib) code path."""
        from crewai_cli.utils import get_project_name

        project_name = get_project_name()
        self.assertEqual(project_name, "test_project")

    def test_get_crewai_version(self):
        """get_crewai_version returns a version string."""
        from crewai_cli.version import get_crewai_version

        assert isinstance(get_crewai_version(), str)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user