diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d99c19524..dcd3e2f1e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,6 +11,7 @@ env: jobs: deploy: runs-on: ubuntu-latest + timeout-minutes: 15 steps: - name: Checkout code diff --git a/poetry.lock b/poetry.lock index 701a7e146..f98e4d9d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -253,6 +253,24 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +[[package]] +name = "auth0-python" +version = "4.7.1" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "auth0_python-4.7.1-py3-none-any.whl", hash = "sha256:5bdbefd582171f398c2b686a19fb5e241a2fa267929519a0c02e33e5932fa7b8"}, + {file = "auth0_python-4.7.1.tar.gz", hash = "sha256:5cf8be11aa807d54e19271a990eb92bea1863824e4863c7fc8493c6f15a597f1"}, +] + +[package.dependencies] +aiohttp = ">=3.8.5,<4.0.0" +cryptography = ">=42.0.4,<43.0.0" +pyjwt = ">=2.8.0,<3.0.0" +requests = ">=2.31.0,<3.0.0" +urllib3 = ">=2.0.7,<3.0.0" + [[package]] name = "autoflake" version = "2.3.1" @@ -851,6 +869,60 @@ pytube = ">=15.0.0,<16.0.0" requests = ">=2.31.0,<3.0.0" selenium = ">=4.18.1,<5.0.0" +[[package]] +name = "cryptography" +version = "42.0.8" +description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +optional = false +python-versions = ">=3.7" +files = [ + {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:81d8a521705787afe7a18d5bfb47ea9d9cc068206270aad0b96a725022e18d2e"}, + {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:961e61cefdcb06e0c6d7e3a1b22ebe8b996eb2bf50614e89384be54c48c6b63d"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3ec3672626e1b9e55afd0df6d774ff0e953452886e06e0f1eb7eb0c832e8902"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e599b53fd95357d92304510fb7bda8523ed1f79ca98dce2f43c115950aa78801"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5226d5d21ab681f432a9c1cf8b658c0cb02533eece706b155e5fbd8a0cdd3949"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6b7c4f03ce01afd3b76cf69a5455caa9cfa3de8c8f493e0d3ab7d20611c8dae9"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:2346b911eb349ab547076f47f2e035fc8ff2c02380a7cbbf8d87114fa0f1c583"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:ad803773e9df0b92e0a817d22fd8a3675493f690b96130a5e24f1b8fabbea9c7"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2f66d9cd9147ee495a8374a45ca445819f8929a3efcd2e3df6428e46c3cbb10b"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d45b940883a03e19e944456a558b67a41160e367a719833c53de6911cabba2b7"}, + {file = "cryptography-42.0.8-cp37-abi3-win32.whl", hash = "sha256:a0c5b2b0585b6af82d7e385f55a8bc568abff8923af147ee3c07bd8b42cda8b2"}, + {file = "cryptography-42.0.8-cp37-abi3-win_amd64.whl", hash = "sha256:57080dee41209e556a9a4ce60d229244f7a66ef52750f813bfbe18959770cfba"}, + {file = "cryptography-42.0.8-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:dea567d1b0e8bc5764b9443858b673b734100c2871dc93163f58c46a97a83d28"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4783183f7cb757b73b2ae9aed6599b96338eb957233c58ca8f49a49cc32fd5e"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0608251135d0e03111152e41f0cc2392d1e74e35703960d4190b2e0f4ca9c70"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:dc0fdf6787f37b1c6b08e6dfc892d9d068b5bdb671198c72072828b80bd5fe4c"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:9c0c1716c8447ee7dbf08d6db2e5c41c688544c61074b54fc4564196f55c25a7"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fff12c88a672ab9c9c1cf7b0c80e3ad9e2ebd9d828d955c126be4fd3e5578c9e"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:cafb92b2bc622cd1aa6a1dce4b93307792633f4c5fe1f46c6b97cf67073ec961"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:31f721658a29331f895a5a54e7e82075554ccfb8b163a18719d342f5ffe5ecb1"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b297f90c5723d04bcc8265fc2a0f86d4ea2e0f7ab4b6994459548d3a6b992a14"}, + {file = "cryptography-42.0.8-cp39-abi3-win32.whl", hash = "sha256:2f88d197e66c65be5e42cd72e5c18afbfae3f741742070e3019ac8f4ac57262c"}, + {file = "cryptography-42.0.8-cp39-abi3-win_amd64.whl", hash = "sha256:fa76fbb7596cc5839320000cdd5d0955313696d9511debab7ee7278fc8b5c84a"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ba4f0a211697362e89ad822e667d8d340b4d8d55fae72cdd619389fb5912eefe"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:81884c4d096c272f00aeb1f11cf62ccd39763581645b0812e99a91505fa48e0c"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c9bb2ae11bfbab395bdd072985abde58ea9860ed84e59dbc0463a5d0159f5b71"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7016f837e15b0a1c119d27ecd89b3515f01f90a8615ed5e9427e30d9cdbfed3d"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5a94eccb2a81a309806027e1670a358b99b8fe8bfe9f8d329f27d72c094dde8c"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dec9b018df185f08483f294cae6ccac29e7a6e0678996587363dc352dc65c842"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:343728aac38decfdeecf55ecab3264b015be68fc2816ca800db649607aeee648"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:013629ae70b40af70c9a7a5db40abe5d9054e6f4380e50ce769947b73bf3caad"}, + {file = "cryptography-42.0.8.tar.gz", hash = "sha256:8d09d05439ce7baa8e9e95b07ec5b6c886f548deb7e0f69ef25f64b3bce842f2"}, +] + +[package.dependencies] +cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""} + +[package.extras] +docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"] +docstest = ["pyenchant (>=1.6.11)", "readme-renderer", "sphinxcontrib-spelling (>=4.0.1)"] +nox = ["nox"] +pep8test = ["check-sdist", "click", "mypy", "ruff"] +sdist = ["build"] +ssh = ["bcrypt (>=3.1.5)"] +test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] +test-randomorder = ["pytest-randomly"] + [[package]] name = "cssselect2" version = "0.7.0" @@ -4230,6 +4302,23 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyjwt" +version = "2.9.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"}, + {file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pylance" version = "0.9.18" @@ -5478,22 +5567,23 @@ files = [ [[package]] name = "urllib3" -version = "1.26.19" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=3.8" files = [ - {file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"}, - {file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.dependencies] -PySocks = {version = ">=1.5.6,<1.5.7 || >1.5.7,<2.0", optional = true, markers = "extra == \"socks\""} +pysocks = {version = ">=1.5.6,<1.5.7 || >1.5.7,<2.0", optional = true, markers = "extra == \"socks\""} [package.extras] -brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] -secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] -socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +h2 = ["h2 (>=4,<5)"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" @@ -5567,23 +5657,20 @@ test = ["Cython (>=0.29.36,<0.30.0)", "aiohttp (==3.9.0b0)", "aiohttp (>=3.8.1)" [[package]] name = "vcrpy" -version = "6.0.1" +version = "5.1.0" description = "Automatically mock your HTTP interactions to simplify and speed up testing" optional = false python-versions = ">=3.8" files = [ - {file = "vcrpy-6.0.1.tar.gz", hash = "sha256:9e023fee7f892baa0bbda2f7da7c8ac51165c1c6e38ff8688683a12a4bde9278"}, + {file = "vcrpy-5.1.0-py2.py3-none-any.whl", hash = "sha256:605e7b7a63dcd940db1df3ab2697ca7faf0e835c0852882142bafb19649d599e"}, + {file = "vcrpy-5.1.0.tar.gz", hash = "sha256:bbf1532f2618a04f11bce2a99af3a9647a32c880957293ff91e0a5f187b6b3d2"}, ] [package.dependencies] PyYAML = "*" -urllib3 = {version = "<2", markers = "platform_python_implementation == \"PyPy\""} wrapt = "*" yarl = "*" -[package.extras] -tests = ["Werkzeug (==2.0.3)", "aiohttp", "boto3", "httplib2", "httpx", "pytest", "pytest-aiohttp", "pytest-asyncio", "pytest-cov", "pytest-httpbin", "requests (>=2.22.0)", "tornado", "urllib3"] - [[package]] name = "virtualenv" version = "20.26.3" @@ -6073,4 +6160,4 @@ tools = ["crewai-tools"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<=3.13" -content-hash = "91ba982ea96ca7be017d536784223d4ef83e86de05d11eb1c3ce0fc1b726f283" +content-hash = "8327a37f807d35d0851e9cc46960e8df0d06924938b2c5354b09951fa54f15e3" diff --git a/pyproject.toml b/pyproject.toml index e438f6574..6cb50c771 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ jsonref = "^1.1.0" agentops = { version = "^0.3.0", optional = true } embedchain = "^0.1.114" json-repair = "^0.25.2" +auth0-python = "^4.7.1" [tool.poetry.extras] tools = ["crewai-tools"] diff --git a/src/crewai/cli/authentication/__init__.py b/src/crewai/cli/authentication/__init__.py new file mode 100644 index 000000000..484453771 --- /dev/null +++ b/src/crewai/cli/authentication/__init__.py @@ -0,0 +1,3 @@ +from .main import AuthenticationCommand + +__all__ = ["AuthenticationCommand"] diff --git a/src/crewai/cli/authentication/constants.py b/src/crewai/cli/authentication/constants.py new file mode 100644 index 000000000..9418087aa --- /dev/null +++ b/src/crewai/cli/authentication/constants.py @@ -0,0 +1,4 @@ +ALGORITHMS = ["RS256"] +AUTH0_DOMAIN = "dev-jzsr0j8zs0atl5ha.us.auth0.com" +AUTH0_CLIENT_ID = "CZtyRHuVW80HbLSjk4ggXNzjg4KAt7Oe" +AUTH0_AUDIENCE = "https://dev-jzsr0j8zs0atl5ha.us.auth0.com/api/v2/" diff --git a/src/crewai/cli/authentication/main.py b/src/crewai/cli/authentication/main.py new file mode 100644 index 000000000..331b583e8 --- /dev/null +++ b/src/crewai/cli/authentication/main.py @@ -0,0 +1,75 @@ +import time +import webbrowser +from typing import Any, Dict + +import requests +from rich.console import Console + +from .constants import AUTH0_AUDIENCE, AUTH0_CLIENT_ID, AUTH0_DOMAIN +from .utils import TokenManager, validate_token + +console = Console() + + +class AuthenticationCommand: + DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code" + TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token" + + def __init__(self): + self.token_manager = TokenManager() + + def signup(self) -> None: + """Sign up to CrewAI+""" + console.print("Signing Up to CrewAI+ \n", style="bold blue") + device_code_data = self._get_device_code() + self._display_auth_instructions(device_code_data) + + return self._poll_for_token(device_code_data) + + def _get_device_code(self) -> Dict[str, Any]: + """Get the device code to authenticate the user.""" + + device_code_payload = { + "client_id": AUTH0_CLIENT_ID, + "scope": "openid", + "audience": AUTH0_AUDIENCE, + } + response = requests.post(url=self.DEVICE_CODE_URL, data=device_code_payload) + response.raise_for_status() + return response.json() + + def _display_auth_instructions(self, device_code_data: Dict[str, str]) -> None: + """Display the authentication instructions to the user.""" + console.print("1. Navigate to: ", device_code_data["verification_uri_complete"]) + console.print("2. Enter the following code: ", device_code_data["user_code"]) + webbrowser.open(device_code_data["verification_uri_complete"]) + + def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None: + """Poll the server for the token.""" + token_payload = { + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "device_code": device_code_data["device_code"], + "client_id": AUTH0_CLIENT_ID, + } + + attempts = 0 + while True and attempts < 5: + response = requests.post(self.TOKEN_URL, data=token_payload) + token_data = response.json() + + if response.status_code == 200: + validate_token(token_data["id_token"]) + expires_in = 360000 # Token expiration time in seconds + self.token_manager.save_tokens(token_data["access_token"], expires_in) + console.print("\nWelcome to CrewAI+ !!", style="green") + return + + if token_data["error"] not in ("authorization_pending", "slow_down"): + raise requests.HTTPError(token_data["error_description"]) + + time.sleep(device_code_data["interval"]) + attempts += 1 + + console.print( + "Timeout: Failed to get the token. Please try again.", style="bold red" + ) diff --git a/src/crewai/cli/authentication/utils.py b/src/crewai/cli/authentication/utils.py new file mode 100644 index 000000000..09e7491b1 --- /dev/null +++ b/src/crewai/cli/authentication/utils.py @@ -0,0 +1,144 @@ +import json +import os +import sys +from datetime import datetime, timedelta +from pathlib import Path +from typing import Optional + +from auth0.authentication.token_verifier import ( + AsymmetricSignatureVerifier, + TokenVerifier, +) +from cryptography.fernet import Fernet + +from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN + + +def validate_token(id_token: str) -> None: + """ + Verify the token and its precedence + + :param id_token: + """ + jwks_url = f"https://{AUTH0_DOMAIN}/.well-known/jwks.json" + issuer = f"https://{AUTH0_DOMAIN}/" + signature_verifier = AsymmetricSignatureVerifier(jwks_url) + token_verifier = TokenVerifier( + signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID + ) + token_verifier.verify(id_token) + + +class TokenManager: + def __init__(self, file_path: str = "tokens.enc") -> None: + """ + Initialize the TokenManager class. + + :param file_path: The file path to store the encrypted tokens. Default is "tokens.enc". + """ + self.file_path = file_path + self.key = self._get_or_create_key() + self.fernet = Fernet(self.key) + + def _get_or_create_key(self) -> bytes: + """ + Get or create the encryption key. + + :return: The encryption key. + """ + key_filename = "secret.key" + key = self.read_secure_file(key_filename) + + if key is not None: + return key + + new_key = Fernet.generate_key() + self.save_secure_file(key_filename, new_key) + return new_key + + def save_tokens(self, access_token: str, expires_in: int) -> None: + """ + Save the access token and its expiration time. + + :param access_token: The access token to save. + :param expires_in: The expiration time of the access token in seconds. + """ + expiration_time = datetime.now() + timedelta(seconds=expires_in) + data = { + "access_token": access_token, + "expiration": expiration_time.isoformat(), + } + encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) + self.save_secure_file(self.file_path, encrypted_data) + + def get_token(self) -> Optional[str]: + """ + Get the access token if it is valid and not expired. + + :return: The access token if valid and not expired, otherwise None. + """ + encrypted_data = self.read_secure_file(self.file_path) + + decrypted_data = self.fernet.decrypt(encrypted_data) + data = json.loads(decrypted_data) + + expiration = datetime.fromisoformat(data["expiration"]) + if expiration <= datetime.now(): + return None + + return data["access_token"] + + def get_secure_storage_path(self) -> Path: + """ + Get the secure storage path based on the operating system. + + :return: The secure storage path. + """ + if sys.platform == "win32": + # Windows: Use %LOCALAPPDATA% + base_path = os.environ.get("LOCALAPPDATA") + elif sys.platform == "darwin": + # macOS: Use ~/Library/Application Support + base_path = os.path.expanduser("~/Library/Application Support") + else: + # Linux and other Unix-like: Use ~/.local/share + base_path = os.path.expanduser("~/.local/share") + + app_name = "crewai/credentials" + storage_path = Path(base_path) / app_name + + storage_path.mkdir(parents=True, exist_ok=True) + + return storage_path + + def save_secure_file(self, filename: str, content: bytes) -> None: + """ + Save the content to a secure file. + + :param filename: The name of the file. + :param content: The content to save. + """ + storage_path = self.get_secure_storage_path() + file_path = storage_path / filename + + with open(file_path, "wb") as f: + f.write(content) + + # Set appropriate permissions (read/write for owner only) + os.chmod(file_path, 0o600) + + def read_secure_file(self, filename: str) -> Optional[bytes]: + """ + Read the content of a secure file. + + :param filename: The name of the file. + :return: The content of the file if it exists, otherwise None. + """ + storage_path = self.get_secure_storage_path() + file_path = storage_path / filename + + if not file_path.exists(): + return None + + with open(file_path, "rb") as f: + return f.read() diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 2ca400000..cf1e7584b 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -1,3 +1,5 @@ +from typing import Optional + import click import pkg_resources @@ -7,6 +9,8 @@ from crewai.memory.storage.kickoff_task_outputs_storage import ( KickoffTaskOutputsSQLiteStorage, ) +from .authentication.main import AuthenticationCommand +from .deploy.main import DeployCommand from .evaluate_crew import evaluate_crew from .install_crew import install_crew from .replay_from_task import replay_task_command @@ -179,5 +183,70 @@ def run(): run_crew() +@crewai.command() +def signup(): + """Sign Up/Login to CrewAI+.""" + AuthenticationCommand().signup() + + +@crewai.command() +def login(): + """Sign Up/Login to CrewAI+.""" + AuthenticationCommand().signup() + + +# DEPLOY CREWAI+ COMMANDS +@crewai.group() +def deploy(): + """Deploy the Crew CLI group.""" + pass + + +@deploy.command(name="create") +def deploy_create(): + """Create a Crew deployment.""" + deploy_cmd = DeployCommand() + deploy_cmd.create_crew() + + +@deploy.command(name="list") +def deploy_list(): + """List all deployments.""" + deploy_cmd = DeployCommand() + deploy_cmd.list_crews() + + +@deploy.command(name="push") +@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") +def deploy_push(uuid: Optional[str]): + """Deploy the Crew.""" + deploy_cmd = DeployCommand() + deploy_cmd.deploy(uuid=uuid) + + +@deploy.command(name="status") +@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") +def deply_status(uuid: Optional[str]): + """Get the status of a deployment.""" + deploy_cmd = DeployCommand() + deploy_cmd.get_crew_status(uuid=uuid) + + +@deploy.command(name="logs") +@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") +def deploy_logs(uuid: Optional[str]): + """Get the logs of a deployment.""" + deploy_cmd = DeployCommand() + deploy_cmd.get_crew_logs(uuid=uuid) + + +@deploy.command(name="remove") +@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") +def deploy_remove(uuid: Optional[str]): + """Remove a deployment.""" + deploy_cmd = DeployCommand() + deploy_cmd.remove_crew(uuid=uuid) + + if __name__ == "__main__": crewai() diff --git a/src/crewai/cli/deploy/__init__.py b/src/crewai/cli/deploy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/crewai/cli/deploy/api.py b/src/crewai/cli/deploy/api.py new file mode 100644 index 000000000..942fc487e --- /dev/null +++ b/src/crewai/cli/deploy/api.py @@ -0,0 +1,63 @@ +from os import getenv + +import requests + + +class CrewAPI: + """ + CrewAPI class to interact with the crewAI+ API. + """ + + def __init__(self, api_key: str) -> None: + self.api_key = api_key + self.headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + self.base_url = getenv( + "CREWAI_BASE_URL", "https://dev.crewai.com/crewai_plus/api/v1/crews" + ) + + def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response: + url = f"{self.base_url}/{endpoint}" + return requests.request(method, url, headers=self.headers, **kwargs) + + # Deploy + def deploy_by_name(self, project_name: str) -> requests.Response: + return self._make_request("POST", f"by-name/{project_name}/deploy") + + def deploy_by_uuid(self, uuid: str) -> requests.Response: + return self._make_request("POST", f"{uuid}/deploy") + + # Status + def status_by_name(self, project_name: str) -> requests.Response: + return self._make_request("GET", f"by-name/{project_name}/status") + + def status_by_uuid(self, uuid: str) -> requests.Response: + return self._make_request("GET", f"{uuid}/status") + + # Logs + def logs_by_name( + self, project_name: str, log_type: str = "deployment" + ) -> requests.Response: + return self._make_request("GET", f"by-name/{project_name}/logs/{log_type}") + + def logs_by_uuid( + self, uuid: str, log_type: str = "deployment" + ) -> requests.Response: + return self._make_request("GET", f"{uuid}/logs/{log_type}") + + # Delete + def delete_by_name(self, project_name: str) -> requests.Response: + return self._make_request("DELETE", f"by-name/{project_name}") + + def delete_by_uuid(self, uuid: str) -> requests.Response: + return self._make_request("DELETE", f"{uuid}") + + # List + def list_crews(self) -> requests.Response: + return self._make_request("GET", "") + + # Create + def create_crew(self, payload) -> requests.Response: + return self._make_request("POST", "", json=payload) diff --git a/src/crewai/cli/deploy/main.py b/src/crewai/cli/deploy/main.py new file mode 100644 index 000000000..d67e1cdc8 --- /dev/null +++ b/src/crewai/cli/deploy/main.py @@ -0,0 +1,289 @@ +from typing import Any, Dict, List, Optional + +from rich.console import Console + +from .api import CrewAPI +from .utils import ( + fetch_and_json_env_file, + get_auth_token, + get_git_remote_url, + get_project_name, +) + +console = Console() + + +class DeployCommand: + """ + A class to handle deployment-related operations for CrewAI projects. + """ + + def __init__(self): + """ + Initialize the DeployCommand with project name and API client. + """ + try: + access_token = get_auth_token() + except Exception: + console.print( + "Please sign up/login to CrewAI+ before using the CLI.", + style="bold red", + ) + console.print("Run 'crewai signup' to sign up/login.", style="bold green") + raise SystemExit + + self.project_name = get_project_name() + self.client = CrewAPI(api_key=access_token) + + def _handle_error(self, json_response: Dict[str, Any]) -> None: + """ + Handle and display error messages from API responses. + + Args: + json_response (Dict[str, Any]): The JSON response containing error information. + """ + error = json_response.get("error", "Unknown error") + message = json_response.get("message", "No message provided") + console.print(f"Error: {error}", style="bold red") + console.print(f"Message: {message}", style="bold red") + + def _standard_no_param_error_message(self) -> None: + """ + Display a standard error message when no UUID or project name is available. + """ + console.print( + "No UUID provided, project pyproject.toml not found or with error.", + style="bold red", + ) + + def _display_deployment_info(self, json_response: Dict[str, Any]) -> None: + """ + Display deployment information. + + Args: + json_response (Dict[str, Any]): The deployment information to display. + """ + console.print("Deploying the crew...\n", style="bold blue") + for key, value in json_response.items(): + console.print(f"{key.title()}: [green]{value}[/green]") + console.print("\nTo check the status of the deployment, run:") + console.print("crewai deploy status") + console.print(" or") + console.print(f"crewai deploy status --uuid \"{json_response['uuid']}\"") + + def _display_logs(self, log_messages: List[Dict[str, Any]]) -> None: + """ + Display log messages. + + Args: + log_messages (List[Dict[str, Any]]): The log messages to display. + """ + for log_message in log_messages: + console.print( + f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}" + ) + + def deploy(self, uuid: Optional[str] = None) -> None: + """ + Deploy a crew using either UUID or project name. + + Args: + uuid (Optional[str]): The UUID of the crew to deploy. + """ + console.print("Starting deployment...", style="bold blue") + if uuid: + response = self.client.deploy_by_uuid(uuid) + elif self.project_name: + response = self.client.deploy_by_name(self.project_name) + else: + self._standard_no_param_error_message() + return + + json_response = response.json() + if response.status_code == 200: + self._display_deployment_info(json_response) + else: + self._handle_error(json_response) + + def create_crew(self) -> None: + """ + Create a new crew deployment. + """ + console.print("Creating deployment...", style="bold blue") + env_vars = fetch_and_json_env_file() + remote_repo_url = get_git_remote_url() + + self._confirm_input(env_vars, remote_repo_url) + payload = self._create_payload(env_vars, remote_repo_url) + + response = self.client.create_crew(payload) + if response.status_code == 201: + self._display_creation_success(response.json()) + else: + self._handle_error(response.json()) + + def _confirm_input(self, env_vars: Dict[str, str], remote_repo_url: str) -> None: + """ + Confirm input parameters with the user. + + Args: + env_vars (Dict[str, str]): Environment variables. + remote_repo_url (str): Remote repository URL. + """ + input(f"Press Enter to continue with the following Env vars: {env_vars}") + input( + f"Press Enter to continue with the following remote repository: {remote_repo_url}\n" + ) + + def _create_payload( + self, + env_vars: Dict[str, str], + remote_repo_url: str, + ) -> Dict[str, Any]: + """ + Create the payload for crew creation. + + Args: + remote_repo_url (str): Remote repository URL. + env_vars (Dict[str, str]): Environment variables. + + Returns: + Dict[str, Any]: The payload for crew creation. + """ + return { + "deploy": { + "name": self.project_name, + "repo_clone_url": remote_repo_url, + "env": env_vars, + } + } + + def _display_creation_success(self, json_response: Dict[str, Any]) -> None: + """ + Display success message after crew creation. + + Args: + json_response (Dict[str, Any]): The response containing crew information. + """ + console.print("Deployment created successfully!\n", style="bold green") + console.print( + f"Name: {self.project_name} ({json_response['uuid']})", style="bold green" + ) + console.print(f"Status: {json_response['status']}", style="bold green") + console.print("\nTo (re)deploy the crew, run:") + console.print("crewai deploy push") + console.print(" or") + console.print(f"crewai deploy push --uuid {json_response['uuid']}") + + def list_crews(self) -> None: + """ + List all available crews. + """ + console.print("Listing all Crews\n", style="bold blue") + + response = self.client.list_crews() + json_response = response.json() + if response.status_code == 200: + self._display_crews(json_response) + else: + self._display_no_crews_message() + + def _display_crews(self, crews_data: List[Dict[str, Any]]) -> None: + """ + Display the list of crews. + + Args: + crews_data (List[Dict[str, Any]]): List of crew data to display. + """ + for crew_data in crews_data: + console.print( + f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]" + ) + + def _display_no_crews_message(self) -> None: + """ + Display a message when no crews are available. + """ + console.print("You don't have any Crews yet. Let's create one!", style="yellow") + console.print(" crewai create crew ", style="green") + + def get_crew_status(self, uuid: Optional[str] = None) -> None: + """ + Get the status of a crew. + + Args: + uuid (Optional[str]): The UUID of the crew to check. + """ + console.print("Fetching deployment status...", style="bold blue") + if uuid: + response = self.client.status_by_uuid(uuid) + elif self.project_name: + response = self.client.status_by_name(self.project_name) + else: + self._standard_no_param_error_message() + return + + json_response = response.json() + if response.status_code == 200: + self._display_crew_status(json_response) + else: + self._handle_error(json_response) + + def _display_crew_status(self, status_data: Dict[str, str]) -> None: + """ + Display the status of a crew. + + Args: + status_data (Dict[str, str]): The status data to display. + """ + console.print(f"Name:\t {status_data['name']}") + console.print(f"Status:\t {status_data['status']}") + + def get_crew_logs(self, uuid: Optional[str], log_type: str = "deployment") -> None: + """ + Get logs for a crew. + + Args: + uuid (Optional[str]): The UUID of the crew to get logs for. + log_type (str): The type of logs to retrieve (default: "deployment"). + """ + console.print(f"Fetching {log_type} logs...", style="bold blue") + + if uuid: + response = self.client.logs_by_uuid(uuid, log_type) + elif self.project_name: + response = self.client.logs_by_name(self.project_name, log_type) + else: + self._standard_no_param_error_message() + return + + if response.status_code == 200: + self._display_logs(response.json()) + else: + self._handle_error(response.json()) + + def remove_crew(self, uuid: Optional[str]) -> None: + """ + Remove a crew deployment. + + Args: + uuid (Optional[str]): The UUID of the crew to remove. + """ + console.print("Removing deployment...", style="bold blue") + + if uuid: + response = self.client.delete_by_uuid(uuid) + elif self.project_name: + response = self.client.delete_by_name(self.project_name) + else: + self._standard_no_param_error_message() + return + + if response.status_code == 204: + console.print( + f"Crew '{self.project_name}' removed successfully.", style="green" + ) + else: + console.print( + f"Failed to remove crew '{self.project_name}'", style="bold red" + ) diff --git a/src/crewai/cli/deploy/utils.py b/src/crewai/cli/deploy/utils.py new file mode 100644 index 000000000..8fe1851df --- /dev/null +++ b/src/crewai/cli/deploy/utils.py @@ -0,0 +1,117 @@ +import re +import subprocess + +import tomllib + +from ..authentication.utils import TokenManager + + +def get_git_remote_url() -> str: + """Get the Git repository's remote URL.""" + try: + # Run the git remote -v command + result = subprocess.run( + ["git", "remote", "-v"], capture_output=True, text=True, check=True + ) + + # Get the output + output = result.stdout + + # Parse the output to find the origin URL + matches = re.findall(r"origin\s+(.*?)\s+\(fetch\)", output) + + if matches: + return matches[0] # Return the first match (origin URL) + else: + print("No origin remote found.") + return "No remote URL found" + + except subprocess.CalledProcessError as e: + return f"Error running trying to fetch the Git Repository: {e}" + except FileNotFoundError: + return "Git command not found. Make sure Git is installed and in your PATH." + + +def get_project_name(pyproject_path: str = "pyproject.toml"): + """Get the project name from the pyproject.toml file.""" + try: + # Read the pyproject.toml file + with open(pyproject_path, "rb") as f: + pyproject_content = tomllib.load(f) + + # Extract the project name + project_name = pyproject_content["tool"]["poetry"]["name"] + + if "crewai" not in pyproject_content["tool"]["poetry"]["dependencies"]: + raise Exception("crewai is not in the dependencies.") + + return project_name + + except FileNotFoundError: + print(f"Error: {pyproject_path} not found.") + except KeyError: + print(f"Error: {pyproject_path} is not a valid pyproject.toml file.") + except tomllib.TOMLDecodeError: + print(f"Error: {pyproject_path} is not a valid TOML file.") + except Exception as e: + print(f"Error reading the pyproject.toml file: {e}") + + return None + + +def get_crewai_version(pyproject_path: str = "pyproject.toml") -> str: + """Get the version number of crewai from the pyproject.toml file.""" + try: + # Read the pyproject.toml file + with open("pyproject.toml", "rb") as f: + pyproject_content = tomllib.load(f) + + # Extract the version number of crewai + crewai_version = pyproject_content["tool"]["poetry"]["dependencies"]["crewai"][ + "version" + ] + + return crewai_version + + except FileNotFoundError: + print(f"Error: {pyproject_path} not found.") + except KeyError: + print(f"Error: {pyproject_path} is not a valid pyproject.toml file.") + except tomllib.TOMLDecodeError: + print(f"Error: {pyproject_path} is not a valid TOML file.") + except Exception as e: + print(f"Error reading the pyproject.toml file: {e}") + + return "no-version-found" + + +def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: + """Fetch the environment variables from a .env file and return them as a dictionary.""" + try: + # Read the .env file + with open(env_file_path, "r") as f: + env_content = f.read() + + # Parse the .env file content to a dictionary + env_dict = {} + for line in env_content.splitlines(): + if line.strip() and not line.strip().startswith("#"): + key, value = line.split("=", 1) + env_dict[key.strip()] = value.strip() + + return env_dict + + except FileNotFoundError: + print(f"Error: {env_file_path} not found.") + except Exception as e: + print(f"Error reading the .env file: {e}") + + return {} + + +def get_auth_token() -> str: + """Get the authentication token.""" + access_token = TokenManager().get_token() + if not access_token: + raise Exception() + return access_token diff --git a/src/crewai/memory/memory.py b/src/crewai/memory/memory.py index 4b8c687c0..9df09d3c7 100644 --- a/src/crewai/memory/memory.py +++ b/src/crewai/memory/memory.py @@ -21,7 +21,7 @@ class Memory: if agent: metadata["agent"] = agent - self.storage.save(value, metadata) # type: ignore # Maybe BUG? Should be self.storage.save(key, value, metadata) + self.storage.save(value, metadata) def search(self, query: str) -> Dict[str, Any]: return self.storage.search(query) diff --git a/tests/cli/authentication/test_auth_main.py b/tests/cli/authentication/test_auth_main.py new file mode 100644 index 000000000..c56968aab --- /dev/null +++ b/tests/cli/authentication/test_auth_main.py @@ -0,0 +1,94 @@ +import unittest +from unittest.mock import MagicMock, patch + +import requests +from crewai.cli.authentication.main import AuthenticationCommand + + +class TestAuthenticationCommand(unittest.TestCase): + def setUp(self): + self.auth_command = AuthenticationCommand() + + @patch("crewai.cli.authentication.main.requests.post") + def test_get_device_code(self, mock_post): + mock_response = MagicMock() + mock_response.json.return_value = { + "device_code": "123456", + "user_code": "ABCDEF", + "verification_uri_complete": "https://example.com", + "interval": 5, + } + mock_post.return_value = mock_response + + device_code_data = self.auth_command._get_device_code() + + self.assertEqual(device_code_data["device_code"], "123456") + self.assertEqual(device_code_data["user_code"], "ABCDEF") + self.assertEqual( + device_code_data["verification_uri_complete"], "https://example.com" + ) + self.assertEqual(device_code_data["interval"], 5) + + @patch("crewai.cli.authentication.main.console.print") + @patch("crewai.cli.authentication.main.webbrowser.open") + def test_display_auth_instructions(self, mock_open, mock_print): + device_code_data = { + "verification_uri_complete": "https://example.com", + "user_code": "ABCDEF", + } + + self.auth_command._display_auth_instructions(device_code_data) + + mock_print.assert_any_call("1. Navigate to: ", "https://example.com") + mock_print.assert_any_call("2. Enter the following code: ", "ABCDEF") + mock_open.assert_called_once_with("https://example.com") + + @patch("crewai.cli.authentication.main.requests.post") + @patch("crewai.cli.authentication.main.validate_token") + @patch("crewai.cli.authentication.main.console.print") + def test_poll_for_token_success(self, mock_print, mock_validate_token, mock_post): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "id_token": "TOKEN", + "access_token": "ACCESS_TOKEN", + } + mock_post.return_value = mock_response + + self.auth_command._poll_for_token({"device_code": "123456"}) + + mock_validate_token.assert_called_once_with("TOKEN") + mock_print.assert_called_once_with("\nWelcome to CrewAI+ !!", style="green") + + @patch("crewai.cli.authentication.main.requests.post") + @patch("crewai.cli.authentication.main.console.print") + def test_poll_for_token_error(self, mock_print, mock_post): + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "error": "invalid_request", + "error_description": "Invalid request", + } + mock_post.return_value = mock_response + + with self.assertRaises(requests.HTTPError): + self.auth_command._poll_for_token({"device_code": "123456"}) + + mock_print.assert_not_called() + + @patch("crewai.cli.authentication.main.requests.post") + @patch("crewai.cli.authentication.main.console.print") + def test_poll_for_token_timeout(self, mock_print, mock_post): + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "error": "authorization_pending", + "error_description": "Authorization pending", + } + mock_post.return_value = mock_response + + self.auth_command._poll_for_token({"device_code": "123456", "interval": 0.01}) + + mock_print.assert_called_once_with( + "Timeout: Failed to get the token. Please try again.", style="bold red" + ) diff --git a/tests/cli/authentication/test_utils.py b/tests/cli/authentication/test_utils.py new file mode 100644 index 000000000..b04dceede --- /dev/null +++ b/tests/cli/authentication/test_utils.py @@ -0,0 +1,147 @@ +import json +import unittest +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +from crewai.cli.authentication.utils import TokenManager, validate_token +from cryptography.fernet import Fernet + + +class TestValidateToken(unittest.TestCase): + @patch("crewai.cli.authentication.utils.AsymmetricSignatureVerifier") + @patch("crewai.cli.authentication.utils.TokenVerifier") + def test_validate_token(self, mock_token_verifier, mock_asymmetric_verifier): + mock_verifier_instance = mock_token_verifier.return_value + mock_id_token = "mock_id_token" + + validate_token(mock_id_token) + + mock_asymmetric_verifier.assert_called_once_with( + "https://dev-jzsr0j8zs0atl5ha.us.auth0.com/.well-known/jwks.json" + ) + mock_token_verifier.assert_called_once_with( + signature_verifier=mock_asymmetric_verifier.return_value, + issuer="https://dev-jzsr0j8zs0atl5ha.us.auth0.com/", + audience="CZtyRHuVW80HbLSjk4ggXNzjg4KAt7Oe", + ) + mock_verifier_instance.verify.assert_called_once_with(mock_id_token) + + +class TestTokenManager(unittest.TestCase): + def setUp(self): + self.token_manager = TokenManager() + + @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file") + @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file") + @patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key") + def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read): + mock_key = Fernet.generate_key() + mock_get_or_create.return_value = mock_key + + token_manager = TokenManager() + result = token_manager.key + + self.assertEqual(result, mock_key) + + @patch("crewai.cli.authentication.utils.Fernet.generate_key") + @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file") + @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file") + def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate): + mock_key = b"new_key" + mock_read.return_value = None + mock_generate.return_value = mock_key + + result = self.token_manager._get_or_create_key() + + self.assertEqual(result, mock_key) + mock_read.assert_called_once_with("secret.key") + mock_generate.assert_called_once() + mock_save.assert_called_once_with("secret.key", mock_key) + + @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file") + def test_save_tokens(self, mock_save): + access_token = "test_token" + expires_in = 3600 + + self.token_manager.save_tokens(access_token, expires_in) + + mock_save.assert_called_once() + args = mock_save.call_args[0] + self.assertEqual(args[0], "tokens.enc") + decrypted_data = self.token_manager.fernet.decrypt(args[1]) + data = json.loads(decrypted_data) + self.assertEqual(data["access_token"], access_token) + expiration = datetime.fromisoformat(data["expiration"]) + self.assertAlmostEqual( + expiration, + datetime.now() + timedelta(seconds=expires_in), + delta=timedelta(seconds=1), + ) + + @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file") + def test_get_token_valid(self, mock_read): + access_token = "test_token" + expiration = (datetime.now() + timedelta(hours=1)).isoformat() + data = {"access_token": access_token, "expiration": expiration} + encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode()) + mock_read.return_value = encrypted_data + + result = self.token_manager.get_token() + + self.assertEqual(result, access_token) + + @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file") + def test_get_token_expired(self, mock_read): + access_token = "test_token" + expiration = (datetime.now() - timedelta(hours=1)).isoformat() + data = {"access_token": access_token, "expiration": expiration} + encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode()) + mock_read.return_value = encrypted_data + + result = self.token_manager.get_token() + + self.assertIsNone(result) + + @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path") + @patch("builtins.open", new_callable=unittest.mock.mock_open) + @patch("crewai.cli.authentication.utils.os.chmod") + def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path): + mock_path = MagicMock() + mock_get_path.return_value = mock_path + filename = "test_file.txt" + content = b"test_content" + + self.token_manager.save_secure_file(filename, content) + + mock_path.__truediv__.assert_called_once_with(filename) + mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb") + mock_open().write.assert_called_once_with(content) + mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600) + + @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path") + @patch( + "builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content" + ) + def test_read_secure_file_exists(self, mock_open, mock_get_path): + mock_path = MagicMock() + mock_get_path.return_value = mock_path + mock_path.__truediv__.return_value.exists.return_value = True + filename = "test_file.txt" + + result = self.token_manager.read_secure_file(filename) + + self.assertEqual(result, b"test_content") + mock_path.__truediv__.assert_called_once_with(filename) + mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb") + + @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path") + def test_read_secure_file_not_exists(self, mock_get_path): + mock_path = MagicMock() + mock_get_path.return_value = mock_path + mock_path.__truediv__.return_value.exists.return_value = False + filename = "test_file.txt" + + result = self.token_manager.read_secure_file(filename) + + self.assertIsNone(result) + mock_path.__truediv__.assert_called_once_with(filename) diff --git a/tests/cli/cli_test.py b/tests/cli/cli_test.py index 4f606e213..b2fb8d0e5 100644 --- a/tests/cli/cli_test.py +++ b/tests/cli/cli_test.py @@ -2,8 +2,19 @@ from unittest import mock import pytest from click.testing import CliRunner - -from crewai.cli.cli import reset_memories, test, train, version +from crewai.cli.cli import ( + deploy_create, + deploy_list, + deploy_logs, + deploy_push, + deploy_remove, + deply_status, + reset_memories, + signup, + test, + train, + version, +) @pytest.fixture @@ -163,3 +174,106 @@ def test_test_invalid_string_iterations(evaluate_crew, runner): "Usage: test [OPTIONS]\nTry 'test --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n" in result.output ) + + +@mock.patch("crewai.cli.cli.AuthenticationCommand") +def test_signup(command, runner): + mock_auth = command.return_value + result = runner.invoke(signup) + + assert result.exit_code == 0 + mock_auth.signup.assert_called_once() + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_create(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_create) + + assert result.exit_code == 0 + mock_deploy.create_crew.assert_called_once() + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_list(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_list) + + assert result.exit_code == 0 + mock_deploy.list_crews.assert_called_once() + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_push(command, runner): + mock_deploy = command.return_value + uuid = "test-uuid" + result = runner.invoke(deploy_push, ["-u", uuid]) + + assert result.exit_code == 0 + mock_deploy.deploy.assert_called_once_with(uuid=uuid) + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_push_no_uuid(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_push) + + assert result.exit_code == 0 + mock_deploy.deploy.assert_called_once_with(uuid=None) + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_status(command, runner): + mock_deploy = command.return_value + uuid = "test-uuid" + result = runner.invoke(deply_status, ["-u", uuid]) + + assert result.exit_code == 0 + mock_deploy.get_crew_status.assert_called_once_with(uuid=uuid) + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_status_no_uuid(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deply_status) + + assert result.exit_code == 0 + mock_deploy.get_crew_status.assert_called_once_with(uuid=None) + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_logs(command, runner): + mock_deploy = command.return_value + uuid = "test-uuid" + result = runner.invoke(deploy_logs, ["-u", uuid]) + + assert result.exit_code == 0 + mock_deploy.get_crew_logs.assert_called_once_with(uuid=uuid) + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_logs_no_uuid(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_logs) + + assert result.exit_code == 0 + mock_deploy.get_crew_logs.assert_called_once_with(uuid=None) + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_remove(command, runner): + mock_deploy = command.return_value + uuid = "test-uuid" + result = runner.invoke(deploy_remove, ["-u", uuid]) + + assert result.exit_code == 0 + mock_deploy.remove_crew.assert_called_once_with(uuid=uuid) + + +@mock.patch("crewai.cli.cli.DeployCommand") +def test_deploy_remove_no_uuid(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_remove) + + assert result.exit_code == 0 + mock_deploy.remove_crew.assert_called_once_with(uuid=None) diff --git a/tests/cli/deploy/test_api.py b/tests/cli/deploy/test_api.py new file mode 100644 index 000000000..f1a6c573d --- /dev/null +++ b/tests/cli/deploy/test_api.py @@ -0,0 +1,102 @@ +import unittest +from os import environ +from unittest.mock import MagicMock, patch + +from crewai.cli.deploy.api import CrewAPI + + +class TestCrewAPI(unittest.TestCase): + def setUp(self): + self.api_key = "test_api_key" + self.api = CrewAPI(self.api_key) + + def test_init(self): + self.assertEqual(self.api.api_key, self.api_key) + self.assertEqual( + self.api.headers, + { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + ) + + @patch("crewai.cli.deploy.api.requests.request") + def test_make_request(self, mock_request): + mock_response = MagicMock() + mock_request.return_value = mock_response + + response = self.api._make_request("GET", "test_endpoint") + + mock_request.assert_called_once_with( + "GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers + ) + self.assertEqual(response, mock_response) + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_deploy_by_name(self, mock_make_request): + self.api.deploy_by_name("test_project") + mock_make_request.assert_called_once_with("POST", "by-name/test_project/deploy") + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_deploy_by_uuid(self, mock_make_request): + self.api.deploy_by_uuid("test_uuid") + mock_make_request.assert_called_once_with("POST", "test_uuid/deploy") + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_status_by_name(self, mock_make_request): + self.api.status_by_name("test_project") + mock_make_request.assert_called_once_with("GET", "by-name/test_project/status") + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_status_by_uuid(self, mock_make_request): + self.api.status_by_uuid("test_uuid") + mock_make_request.assert_called_once_with("GET", "test_uuid/status") + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_logs_by_name(self, mock_make_request): + self.api.logs_by_name("test_project") + mock_make_request.assert_called_once_with( + "GET", "by-name/test_project/logs/deployment" + ) + + self.api.logs_by_name("test_project", "custom_log") + mock_make_request.assert_called_with( + "GET", "by-name/test_project/logs/custom_log" + ) + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_logs_by_uuid(self, mock_make_request): + self.api.logs_by_uuid("test_uuid") + mock_make_request.assert_called_once_with("GET", "test_uuid/logs/deployment") + + self.api.logs_by_uuid("test_uuid", "custom_log") + mock_make_request.assert_called_with("GET", "test_uuid/logs/custom_log") + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_delete_by_name(self, mock_make_request): + self.api.delete_by_name("test_project") + mock_make_request.assert_called_once_with("DELETE", "by-name/test_project") + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_delete_by_uuid(self, mock_make_request): + self.api.delete_by_uuid("test_uuid") + mock_make_request.assert_called_once_with("DELETE", "test_uuid") + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_list_crews(self, mock_make_request): + self.api.list_crews() + mock_make_request.assert_called_once_with("GET", "") + + @patch("crewai.cli.deploy.api.CrewAPI._make_request") + def test_create_crew(self, mock_make_request): + payload = {"name": "test_crew"} + self.api.create_crew(payload) + mock_make_request.assert_called_once_with("POST", "", json=payload) + + @patch.dict(environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"}) + def test_custom_base_url(self): + custom_api = CrewAPI("test_key") + self.assertEqual( + custom_api.base_url, + "https://custom-url.com/api", + ) diff --git a/tests/cli/deploy/test_deploy_main.py b/tests/cli/deploy/test_deploy_main.py new file mode 100644 index 000000000..f4b08d877 --- /dev/null +++ b/tests/cli/deploy/test_deploy_main.py @@ -0,0 +1,153 @@ +import unittest +from io import StringIO +from unittest.mock import MagicMock, patch + +from crewai.cli.deploy.main import DeployCommand + + +class TestDeployCommand(unittest.TestCase): + @patch("crewai.cli.deploy.main.get_auth_token") + @patch("crewai.cli.deploy.main.get_project_name") + @patch("crewai.cli.deploy.main.CrewAPI") + def setUp(self, mock_crew_api, mock_get_project_name, mock_get_auth_token): + self.mock_get_auth_token = mock_get_auth_token + self.mock_get_project_name = mock_get_project_name + self.mock_crew_api = mock_crew_api + + self.mock_get_auth_token.return_value = "test_token" + self.mock_get_project_name.return_value = "test_project" + + self.deploy_command = DeployCommand() + self.mock_client = self.deploy_command.client + + def test_init_success(self): + self.assertEqual(self.deploy_command.project_name, "test_project") + self.mock_crew_api.assert_called_once_with(api_key="test_token") + + @patch("crewai.cli.deploy.main.get_auth_token") + def test_init_failure(self, mock_get_auth_token): + mock_get_auth_token.side_effect = Exception("Auth failed") + + with self.assertRaises(SystemExit): + DeployCommand() + + def test_handle_error(self): + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command._handle_error( + {"error": "Test error", "message": "Test message"} + ) + self.assertIn("Error: Test error", fake_out.getvalue()) + self.assertIn("Message: Test message", fake_out.getvalue()) + + def test_standard_no_param_error_message(self): + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command._standard_no_param_error_message() + self.assertIn("No UUID provided", fake_out.getvalue()) + + def test_display_deployment_info(self): + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command._display_deployment_info( + {"uuid": "test-uuid", "status": "deployed"} + ) + self.assertIn("Deploying the crew...", fake_out.getvalue()) + self.assertIn("test-uuid", fake_out.getvalue()) + self.assertIn("deployed", fake_out.getvalue()) + + def test_display_logs(self): + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command._display_logs( + [{"timestamp": "2023-01-01", "level": "INFO", "message": "Test log"}] + ) + self.assertIn("2023-01-01 - INFO: Test log", fake_out.getvalue()) + + @patch("crewai.cli.deploy.main.DeployCommand._display_deployment_info") + def test_deploy_with_uuid(self, mock_display): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"uuid": "test-uuid"} + self.mock_client.deploy_by_uuid.return_value = mock_response + + self.deploy_command.deploy(uuid="test-uuid") + + self.mock_client.deploy_by_uuid.assert_called_once_with("test-uuid") + mock_display.assert_called_once_with({"uuid": "test-uuid"}) + + @patch("crewai.cli.deploy.main.DeployCommand._display_deployment_info") + def test_deploy_with_project_name(self, mock_display): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"uuid": "test-uuid"} + self.mock_client.deploy_by_name.return_value = mock_response + + self.deploy_command.deploy() + + self.mock_client.deploy_by_name.assert_called_once_with("test_project") + mock_display.assert_called_once_with({"uuid": "test-uuid"}) + + @patch("crewai.cli.deploy.main.fetch_and_json_env_file") + @patch("crewai.cli.deploy.main.get_git_remote_url") + @patch("builtins.input") + def test_create_crew(self, mock_input, mock_get_git_remote_url, mock_fetch_env): + mock_fetch_env.return_value = {"ENV_VAR": "value"} + mock_get_git_remote_url.return_value = "https://github.com/test/repo.git" + mock_input.return_value = "" + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = {"uuid": "new-uuid", "status": "created"} + self.mock_client.create_crew.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.create_crew() + self.assertIn("Deployment created successfully!", fake_out.getvalue()) + self.assertIn("new-uuid", fake_out.getvalue()) + + def test_list_crews(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {"name": "Crew1", "uuid": "uuid1", "status": "active"}, + {"name": "Crew2", "uuid": "uuid2", "status": "inactive"}, + ] + self.mock_client.list_crews.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.list_crews() + self.assertIn("Crew1 (uuid1) active", fake_out.getvalue()) + self.assertIn("Crew2 (uuid2) inactive", fake_out.getvalue()) + + def test_get_crew_status(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"name": "TestCrew", "status": "active"} + self.mock_client.status_by_name.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.get_crew_status() + self.assertIn("TestCrew", fake_out.getvalue()) + self.assertIn("active", fake_out.getvalue()) + + def test_get_crew_logs(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"}, + {"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"}, + ] + self.mock_client.logs_by_name.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.get_crew_logs(None) + self.assertIn("2023-01-01 - INFO: Log1", fake_out.getvalue()) + self.assertIn("2023-01-02 - ERROR: Log2", fake_out.getvalue()) + + def test_remove_crew(self): + mock_response = MagicMock() + mock_response.status_code = 204 + self.mock_client.delete_by_name.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.remove_crew(None) + self.assertIn( + "Crew 'test_project' removed successfully", fake_out.getvalue() + )