Merge branch 'main' of github.com:crewAIInc/crewAI into add/agent-specific-knowledge

This commit is contained in:
Lorenze Jay
2024-11-25 15:29:53 -08:00
10 changed files with 45 additions and 22 deletions

View File

@@ -68,6 +68,7 @@
"concepts/tasks", "concepts/tasks",
"concepts/crews", "concepts/crews",
"concepts/flows", "concepts/flows",
"concepts/knowledge",
"concepts/llms", "concepts/llms",
"concepts/processes", "concepts/processes",
"concepts/collaboration", "concepts/collaboration",

View File

@@ -7,6 +7,7 @@ from rich.console import Console
from .constants import AUTH0_AUDIENCE, AUTH0_CLIENT_ID, AUTH0_DOMAIN from .constants import AUTH0_AUDIENCE, AUTH0_CLIENT_ID, AUTH0_DOMAIN
from .utils import TokenManager, validate_token from .utils import TokenManager, validate_token
from crewai.cli.tools.main import ToolCommand
console = Console() console = Console()
@@ -63,7 +64,22 @@ class AuthenticationCommand:
validate_token(token_data["id_token"]) validate_token(token_data["id_token"])
expires_in = 360000 # Token expiration time in seconds expires_in = 360000 # Token expiration time in seconds
self.token_manager.save_tokens(token_data["access_token"], expires_in) self.token_manager.save_tokens(token_data["access_token"], expires_in)
console.print("\nWelcome to CrewAI+ !!", style="green")
try:
ToolCommand().login()
except Exception:
console.print(
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
style="yellow",
)
console.print(
"Other features will work normally, but you may experience limitations "
"with downloading and publishing tools."
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
style="yellow",
)
console.print("\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n")
return return
if token_data["error"] not in ("authorization_pending", "slow_down"): if token_data["error"] not in ("authorization_pending", "slow_down"):

View File

@@ -0,0 +1,10 @@
from .utils import TokenManager
def get_auth_token() -> str:
"""Get the authentication token."""
access_token = TokenManager().get_token()
if not access_token:
raise Exception()
return access_token

View File

@@ -2,7 +2,7 @@ import requests
from requests.exceptions import JSONDecodeError from requests.exceptions import JSONDecodeError
from rich.console import Console from rich.console import Console
from crewai.cli.plus_api import PlusAPI from crewai.cli.plus_api import PlusAPI
from crewai.cli.utils import get_auth_token from crewai.cli.authentication.token import get_auth_token
from crewai.telemetry.telemetry import Telemetry from crewai.telemetry.telemetry import Telemetry
console = Console() console = Console()

View File

@@ -1,7 +1,7 @@
from typing import Optional from typing import Optional
import requests import requests
from os import getenv from os import getenv
from crewai.cli.utils import get_crewai_version from crewai.cli.version import get_crewai_version
from urllib.parse import urljoin from urllib.parse import urljoin

View File

@@ -3,7 +3,8 @@ import subprocess
import click import click
from packaging import version from packaging import version
from crewai.cli.utils import get_crewai_version, read_toml from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
def run_crew() -> None: def run_crew() -> None:

View File

@@ -1,4 +1,3 @@
import importlib.metadata
import os import os
import shutil import shutil
import sys import sys
@@ -9,7 +8,6 @@ import click
import tomli import tomli
from rich.console import Console from rich.console import Console
from crewai.cli.authentication.utils import TokenManager
from crewai.cli.constants import ENV_VARS from crewai.cli.constants import ENV_VARS
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
@@ -137,11 +135,6 @@ def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
return reduce(dict.__getitem__, keys, data) return reduce(dict.__getitem__, keys, data)
def get_crewai_version() -> str:
"""Get the version number of CrewAI running the CLI"""
return importlib.metadata.version("crewai")
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
"""Fetch the environment variables from a .env file and return them as a dictionary.""" """Fetch the environment variables from a .env file and return them as a dictionary."""
try: try:
@@ -166,14 +159,6 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
return {} return {}
def get_auth_token() -> str:
"""Get the authentication token."""
access_token = TokenManager().get_token()
if not access_token:
raise Exception()
return access_token
def tree_copy(source, destination): def tree_copy(source, destination):
"""Copies the entire directory structure from the source to the destination.""" """Copies the entire directory structure from the source to the destination."""
for item in os.listdir(source): for item in os.listdir(source):

View File

@@ -0,0 +1,6 @@
import importlib.metadata
def get_crewai_version() -> str:
"""Get the version number of CrewAI running the CLI"""
return importlib.metadata.version("crewai")

View File

@@ -43,10 +43,11 @@ class TestAuthenticationCommand(unittest.TestCase):
mock_print.assert_any_call("2. Enter the following code: ", "ABCDEF") mock_print.assert_any_call("2. Enter the following code: ", "ABCDEF")
mock_open.assert_called_once_with("https://example.com") mock_open.assert_called_once_with("https://example.com")
@patch("crewai.cli.authentication.main.ToolCommand")
@patch("crewai.cli.authentication.main.requests.post") @patch("crewai.cli.authentication.main.requests.post")
@patch("crewai.cli.authentication.main.validate_token") @patch("crewai.cli.authentication.main.validate_token")
@patch("crewai.cli.authentication.main.console.print") @patch("crewai.cli.authentication.main.console.print")
def test_poll_for_token_success(self, mock_print, mock_validate_token, mock_post): def test_poll_for_token_success(self, mock_print, mock_validate_token, mock_post, mock_tool):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 200 mock_response.status_code = 200
mock_response.json.return_value = { mock_response.json.return_value = {
@@ -55,10 +56,13 @@ class TestAuthenticationCommand(unittest.TestCase):
} }
mock_post.return_value = mock_response mock_post.return_value = mock_response
mock_instance = mock_tool.return_value
mock_instance.login.return_value = None
self.auth_command._poll_for_token({"device_code": "123456"}) self.auth_command._poll_for_token({"device_code": "123456"})
mock_validate_token.assert_called_once_with("TOKEN") mock_validate_token.assert_called_once_with("TOKEN")
mock_print.assert_called_once_with("\nWelcome to CrewAI+ !!", style="green") mock_print.assert_called_once_with("\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n")
@patch("crewai.cli.authentication.main.requests.post") @patch("crewai.cli.authentication.main.requests.post")
@patch("crewai.cli.authentication.main.console.print") @patch("crewai.cli.authentication.main.console.print")

View File

@@ -260,6 +260,6 @@ class TestDeployCommand(unittest.TestCase):
self.assertEqual(project_name, "test_project") self.assertEqual(project_name, "test_project")
def test_get_crewai_version(self): def test_get_crewai_version(self):
from crewai.cli.utils import get_crewai_version from crewai.cli.version import get_crewai_version
assert isinstance(get_crewai_version(), str) assert isinstance(get_crewai_version(), str)