mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 15:18:29 +00:00
Compare commits
2 Commits
devin/1755
...
devin/1755
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
774dbe3474 | ||
|
|
f96b779df5 |
@@ -7,7 +7,8 @@ from rich.console import Console
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from .utils import TokenManager, validate_jwt_token
|
||||
from .utils import validate_jwt_token
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from urllib.parse import quote
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.config import Settings
|
||||
@@ -21,10 +22,19 @@ console = Console()
|
||||
|
||||
|
||||
class Oauth2Settings(BaseModel):
|
||||
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
|
||||
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
|
||||
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
|
||||
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
|
||||
provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
|
||||
)
|
||||
client_id: str = Field(
|
||||
description="OAuth2 client ID issued by the provider, used during authentication requests."
|
||||
)
|
||||
domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
|
||||
)
|
||||
audience: Optional[str] = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls):
|
||||
@@ -44,11 +54,15 @@ class ProviderFactory:
|
||||
settings = settings or Oauth2Settings.from_settings()
|
||||
|
||||
import importlib
|
||||
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
|
||||
|
||||
module = importlib.import_module(
|
||||
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
||||
)
|
||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
||||
|
||||
return provider(settings)
|
||||
|
||||
|
||||
class AuthenticationCommand:
|
||||
def __init__(self):
|
||||
self.token_manager = TokenManager()
|
||||
@@ -65,7 +79,7 @@ class AuthenticationCommand:
|
||||
provider="auth0",
|
||||
client_id=AUTH0_CLIENT_ID,
|
||||
domain=AUTH0_DOMAIN,
|
||||
audience=AUTH0_AUDIENCE
|
||||
audience=AUTH0_AUDIENCE,
|
||||
)
|
||||
self.oauth2_provider = ProviderFactory.from_settings(settings)
|
||||
# End of temporary code.
|
||||
@@ -75,9 +89,7 @@ class AuthenticationCommand:
|
||||
|
||||
return self._poll_for_token(device_code_data)
|
||||
|
||||
def _get_device_code(
|
||||
self
|
||||
) -> Dict[str, Any]:
|
||||
def _get_device_code(self) -> Dict[str, Any]:
|
||||
"""Get the device code to authenticate the user."""
|
||||
|
||||
device_code_payload = {
|
||||
@@ -86,7 +98,9 @@ class AuthenticationCommand:
|
||||
"audience": self.oauth2_provider.get_audience(),
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
|
||||
url=self.oauth2_provider.get_authorize_url(),
|
||||
data=device_code_payload,
|
||||
timeout=20,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -97,9 +111,7 @@ class AuthenticationCommand:
|
||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||
|
||||
def _poll_for_token(
|
||||
self, device_code_data: Dict[str, Any]
|
||||
) -> None:
|
||||
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
|
||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||
|
||||
token_payload = {
|
||||
@@ -112,7 +124,9 @@ class AuthenticationCommand:
|
||||
|
||||
attempts = 0
|
||||
while True and attempts < 10:
|
||||
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
|
||||
response = requests.post(
|
||||
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
|
||||
)
|
||||
token_data = response.json()
|
||||
|
||||
if response.status_code == 200:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .utils import TokenManager
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
|
||||
@@ -1,12 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
def validate_jwt_token(
|
||||
@@ -67,118 +60,3 @@ def validate_jwt_token(
|
||||
raise Exception(f"JWKS or key processing error: {str(e)}")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise Exception(f"Invalid token: {str(e)}")
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
Initialize the TokenManager class.
|
||||
|
||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key.
|
||||
|
||||
:return: The encryption key.
|
||||
"""
|
||||
key_filename = "secret.key"
|
||||
key = self.read_secure_file(key_filename)
|
||||
|
||||
if key is not None:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""
|
||||
Save the access token and its expiration time.
|
||||
|
||||
:param access_token: The access token to save.
|
||||
:param expires_at: The UNIX timestamp of the expiration time.
|
||||
"""
|
||||
expiration_time = datetime.fromtimestamp(expires_at)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
|
||||
:return: The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
encrypted_data = self.read_secure_file(self.file_path)
|
||||
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
||||
data = json.loads(decrypted_data)
|
||||
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return data["access_token"]
|
||||
|
||||
def get_secure_storage_path(self) -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
# Windows: Use %LOCALAPPDATA%
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
elif sys.platform == "darwin":
|
||||
# macOS: Use ~/Library/Application Support
|
||||
base_path = os.path.expanduser("~/Library/Application Support")
|
||||
else:
|
||||
# Linux and other Unix-like: Use ~/.local/share
|
||||
base_path = os.path.expanduser("~/.local/share")
|
||||
|
||||
app_name = "crewai/credentials"
|
||||
storage_path = Path(base_path) / app_name
|
||||
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return storage_path
|
||||
|
||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
Save the content to a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:param content: The content to save.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Set appropriate permissions (read/write for owner only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:return: The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
@@ -11,6 +11,7 @@ from crewai.cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
)
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
@@ -53,6 +54,7 @@ HIDDEN_SETTINGS_KEYS = [
|
||||
"tool_repository_password",
|
||||
]
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
enterprise_base_url: Optional[str] = Field(
|
||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||
@@ -74,12 +76,12 @@ class Settings(BaseModel):
|
||||
|
||||
oauth2_provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"]
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
||||
)
|
||||
|
||||
oauth2_audience: Optional[str] = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"]
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
||||
)
|
||||
|
||||
oauth2_client_id: str = Field(
|
||||
@@ -89,7 +91,7 @@ class Settings(BaseModel):
|
||||
|
||||
oauth2_domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"]
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||
)
|
||||
|
||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
||||
@@ -116,6 +118,7 @@ class Settings(BaseModel):
|
||||
"""Reset all settings to default values"""
|
||||
self._reset_user_settings()
|
||||
self._reset_cli_settings()
|
||||
self._clear_auth_tokens()
|
||||
self.dump()
|
||||
|
||||
def dump(self) -> None:
|
||||
@@ -139,3 +142,7 @@ class Settings(BaseModel):
|
||||
"""Reset all CLI settings to default values"""
|
||||
for key in CLI_SETTINGS_KEYS:
|
||||
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
|
||||
|
||||
def _clear_auth_tokens(self) -> None:
|
||||
"""Clear all authentication tokens"""
|
||||
TokenManager().clear_tokens()
|
||||
|
||||
0
src/crewai/cli/shared/__init__.py
Normal file
0
src/crewai/cli/shared/__init__.py
Normal file
139
src/crewai/cli/shared/token_manager.py
Normal file
139
src/crewai/cli/shared/token_manager.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
Initialize the TokenManager class.
|
||||
|
||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key.
|
||||
|
||||
:return: The encryption key.
|
||||
"""
|
||||
key_filename = "secret.key"
|
||||
key = self.read_secure_file(key_filename)
|
||||
|
||||
if key is not None:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""
|
||||
Save the access token and its expiration time.
|
||||
|
||||
:param access_token: The access token to save.
|
||||
:param expires_at: The UNIX timestamp of the expiration time.
|
||||
"""
|
||||
expiration_time = datetime.fromtimestamp(expires_at)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
|
||||
:return: The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
encrypted_data = self.read_secure_file(self.file_path)
|
||||
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
||||
data = json.loads(decrypted_data)
|
||||
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return data["access_token"]
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
"""
|
||||
Clear the tokens.
|
||||
"""
|
||||
self.delete_secure_file(self.file_path)
|
||||
|
||||
def get_secure_storage_path(self) -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
# Windows: Use %LOCALAPPDATA%
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
elif sys.platform == "darwin":
|
||||
# macOS: Use ~/Library/Application Support
|
||||
base_path = os.path.expanduser("~/Library/Application Support")
|
||||
else:
|
||||
# Linux and other Unix-like: Use ~/.local/share
|
||||
base_path = os.path.expanduser("~/.local/share")
|
||||
|
||||
app_name = "crewai/credentials"
|
||||
storage_path = Path(base_path) / app_name
|
||||
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return storage_path
|
||||
|
||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
Save the content to a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:param content: The content to save.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Set appropriate permissions (read/write for owner only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:return: The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def delete_secure_file(self, filename: str) -> None:
|
||||
"""
|
||||
Delete the secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
if file_path.exists():
|
||||
file_path.unlink(missing_ok=True)
|
||||
@@ -7,6 +7,7 @@ from crewai.utilities.chromadb import sanitize_collection_name
|
||||
from crewai.memory.storage.interface import Storage
|
||||
|
||||
MAX_AGENT_ID_LENGTH_MEM0 = 255
|
||||
MAX_METADATA_SIZE_MEM0 = 2000
|
||||
|
||||
|
||||
class Mem0Storage(Storage):
|
||||
@@ -98,8 +99,11 @@ class Mem0Storage(Storage):
|
||||
}
|
||||
|
||||
# Shared base params
|
||||
raw_metadata = {"type": base_metadata[self.memory_type], **metadata}
|
||||
truncated_metadata = self._truncate_metadata_if_needed(raw_metadata)
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"metadata": {"type": base_metadata[self.memory_type], **metadata},
|
||||
"metadata": truncated_metadata,
|
||||
"infer": self.infer
|
||||
}
|
||||
|
||||
@@ -181,3 +185,30 @@ class Mem0Storage(Storage):
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
|
||||
|
||||
def _truncate_metadata_if_needed(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Truncate metadata to stay within Mem0's size limits.
|
||||
Prioritizes essential fields and truncates messages if needed.
|
||||
"""
|
||||
import json
|
||||
|
||||
metadata_str = json.dumps(metadata, default=str)
|
||||
|
||||
if len(metadata_str) <= MAX_METADATA_SIZE_MEM0:
|
||||
return metadata
|
||||
|
||||
truncated_metadata = metadata.copy()
|
||||
|
||||
if "messages" in truncated_metadata:
|
||||
messages = truncated_metadata["messages"]
|
||||
if isinstance(messages, list) and len(messages) > 0:
|
||||
while len(json.dumps(truncated_metadata, default=str)) > MAX_METADATA_SIZE_MEM0 and len(messages) > 1:
|
||||
messages = messages[-len(messages)//2:] # Keep last half
|
||||
truncated_metadata["messages"] = messages
|
||||
|
||||
if len(json.dumps(truncated_metadata, default=str)) > MAX_METADATA_SIZE_MEM0:
|
||||
truncated_metadata.pop("messages", None)
|
||||
truncated_metadata["_truncated"] = True
|
||||
|
||||
return truncated_metadata
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import json
|
||||
import jwt
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from crewai.cli.authentication.utils import TokenManager, validate_jwt_token
|
||||
from crewai.cli.authentication.utils import validate_jwt_token
|
||||
|
||||
|
||||
@patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock())
|
||||
@patch("crewai.cli.authentication.utils.jwt")
|
||||
class TestValidateToken(unittest.TestCase):
|
||||
class TestUtils(unittest.TestCase):
|
||||
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.return_value = {"exp": 1719859200}
|
||||
|
||||
@@ -105,121 +102,3 @@ class TestValidateToken(unittest.TestCase):
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
|
||||
class TestTokenManager(unittest.TestCase):
|
||||
@patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
|
||||
def setUp(self, mock_get_key):
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
self.token_manager = TokenManager()
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
|
||||
@patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
|
||||
def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
|
||||
mock_key = Fernet.generate_key()
|
||||
mock_get_or_create.return_value = mock_key
|
||||
|
||||
token_manager = TokenManager()
|
||||
result = token_manager.key
|
||||
|
||||
self.assertEqual(result, mock_key)
|
||||
|
||||
@patch("crewai.cli.authentication.utils.Fernet.generate_key")
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
|
||||
def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
|
||||
mock_key = b"new_key"
|
||||
mock_read.return_value = None
|
||||
mock_generate.return_value = mock_key
|
||||
|
||||
result = self.token_manager._get_or_create_key()
|
||||
|
||||
self.assertEqual(result, mock_key)
|
||||
mock_read.assert_called_once_with("secret.key")
|
||||
mock_generate.assert_called_once()
|
||||
mock_save.assert_called_once_with("secret.key", mock_key)
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
|
||||
def test_save_tokens(self, mock_save):
|
||||
access_token = "test_token"
|
||||
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
|
||||
|
||||
self.token_manager.save_tokens(access_token, expires_at)
|
||||
|
||||
mock_save.assert_called_once()
|
||||
args = mock_save.call_args[0]
|
||||
self.assertEqual(args[0], "tokens.enc")
|
||||
decrypted_data = self.token_manager.fernet.decrypt(args[1])
|
||||
data = json.loads(decrypted_data)
|
||||
self.assertEqual(data["access_token"], access_token)
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
|
||||
def test_get_token_valid(self, mock_read):
|
||||
access_token = "test_token"
|
||||
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
data = {"access_token": access_token, "expiration": expiration}
|
||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||
mock_read.return_value = encrypted_data
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertEqual(result, access_token)
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
|
||||
def test_get_token_expired(self, mock_read):
|
||||
access_token = "test_token"
|
||||
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
data = {"access_token": access_token, "expiration": expiration}
|
||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||
mock_read.return_value = encrypted_data
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
|
||||
@patch("builtins.open", new_callable=unittest.mock.mock_open)
|
||||
@patch("crewai.cli.authentication.utils.os.chmod")
|
||||
def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
|
||||
mock_path = MagicMock()
|
||||
mock_get_path.return_value = mock_path
|
||||
filename = "test_file.txt"
|
||||
content = b"test_content"
|
||||
|
||||
self.token_manager.save_secure_file(filename, content)
|
||||
|
||||
mock_path.__truediv__.assert_called_once_with(filename)
|
||||
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb")
|
||||
mock_open().write.assert_called_once_with(content)
|
||||
mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600)
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
|
||||
@patch(
|
||||
"builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
|
||||
)
|
||||
def test_read_secure_file_exists(self, mock_open, mock_get_path):
|
||||
mock_path = MagicMock()
|
||||
mock_get_path.return_value = mock_path
|
||||
mock_path.__truediv__.return_value.exists.return_value = True
|
||||
filename = "test_file.txt"
|
||||
|
||||
result = self.token_manager.read_secure_file(filename)
|
||||
|
||||
self.assertEqual(result, b"test_content")
|
||||
mock_path.__truediv__.assert_called_once_with(filename)
|
||||
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")
|
||||
|
||||
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
|
||||
def test_read_secure_file_not_exists(self, mock_get_path):
|
||||
mock_path = MagicMock()
|
||||
mock_get_path.return_value = mock_path
|
||||
mock_path.__truediv__.return_value.exists.return_value = False
|
||||
filename = "test_file.txt"
|
||||
|
||||
result = self.token_manager.read_secure_file(filename)
|
||||
|
||||
self.assertIsNone(result)
|
||||
mock_path.__truediv__.assert_called_once_with(filename)
|
||||
|
||||
@@ -3,6 +3,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.cli.config import (
|
||||
Settings,
|
||||
@@ -10,6 +11,8 @@ from crewai.cli.config import (
|
||||
CLI_SETTINGS_KEYS,
|
||||
DEFAULT_CLI_SETTINGS,
|
||||
)
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class TestSettings(unittest.TestCase):
|
||||
@@ -66,7 +69,8 @@ class TestSettings(unittest.TestCase):
|
||||
for key in user_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), None)
|
||||
|
||||
def test_reset_settings(self):
|
||||
@patch("crewai.cli.config.TokenManager")
|
||||
def test_reset_settings(self, mock_token_manager):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
|
||||
|
||||
@@ -74,6 +78,11 @@ class TestSettings(unittest.TestCase):
|
||||
config_path=self.config_path, **user_settings, **cli_settings
|
||||
)
|
||||
|
||||
mock_token_manager.return_value = MagicMock()
|
||||
TokenManager().save_tokens(
|
||||
"aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp()
|
||||
)
|
||||
|
||||
settings.reset()
|
||||
|
||||
for key in user_settings.keys():
|
||||
@@ -81,6 +90,8 @@ class TestSettings(unittest.TestCase):
|
||||
for key in cli_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
|
||||
|
||||
mock_token_manager.return_value.clear_tokens.assert_called_once()
|
||||
|
||||
def test_dump_new_settings(self):
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="user1"
|
||||
|
||||
138
tests/cli/test_token_manager.py
Normal file
138
tests/cli/test_token_manager.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import json
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
class TestTokenManager(unittest.TestCase):
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def setUp(self, mock_get_key):
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
self.token_manager = TokenManager()
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
|
||||
mock_key = Fernet.generate_key()
|
||||
mock_get_or_create.return_value = mock_key
|
||||
|
||||
token_manager = TokenManager()
|
||||
result = token_manager.key
|
||||
|
||||
self.assertEqual(result, mock_key)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.Fernet.generate_key")
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
|
||||
def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
|
||||
mock_key = b"new_key"
|
||||
mock_read.return_value = None
|
||||
mock_generate.return_value = mock_key
|
||||
|
||||
result = self.token_manager._get_or_create_key()
|
||||
|
||||
self.assertEqual(result, mock_key)
|
||||
mock_read.assert_called_once_with("secret.key")
|
||||
mock_generate.assert_called_once()
|
||||
mock_save.assert_called_once_with("secret.key", mock_key)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
|
||||
def test_save_tokens(self, mock_save):
|
||||
access_token = "test_token"
|
||||
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
|
||||
|
||||
self.token_manager.save_tokens(access_token, expires_at)
|
||||
|
||||
mock_save.assert_called_once()
|
||||
args = mock_save.call_args[0]
|
||||
self.assertEqual(args[0], "tokens.enc")
|
||||
decrypted_data = self.token_manager.fernet.decrypt(args[1])
|
||||
data = json.loads(decrypted_data)
|
||||
self.assertEqual(data["access_token"], access_token)
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
|
||||
def test_get_token_valid(self, mock_read):
|
||||
access_token = "test_token"
|
||||
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
data = {"access_token": access_token, "expiration": expiration}
|
||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||
mock_read.return_value = encrypted_data
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertEqual(result, access_token)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
|
||||
def test_get_token_expired(self, mock_read):
|
||||
access_token = "test_token"
|
||||
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
data = {"access_token": access_token, "expiration": expiration}
|
||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||
mock_read.return_value = encrypted_data
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
|
||||
@patch("builtins.open", new_callable=unittest.mock.mock_open)
|
||||
@patch("crewai.cli.shared.token_manager.os.chmod")
|
||||
def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
|
||||
mock_path = MagicMock()
|
||||
mock_get_path.return_value = mock_path
|
||||
filename = "test_file.txt"
|
||||
content = b"test_content"
|
||||
|
||||
self.token_manager.save_secure_file(filename, content)
|
||||
|
||||
mock_path.__truediv__.assert_called_once_with(filename)
|
||||
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb")
|
||||
mock_open().write.assert_called_once_with(content)
|
||||
mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
|
||||
@patch(
|
||||
"builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
|
||||
)
|
||||
def test_read_secure_file_exists(self, mock_open, mock_get_path):
|
||||
mock_path = MagicMock()
|
||||
mock_get_path.return_value = mock_path
|
||||
mock_path.__truediv__.return_value.exists.return_value = True
|
||||
filename = "test_file.txt"
|
||||
|
||||
result = self.token_manager.read_secure_file(filename)
|
||||
|
||||
self.assertEqual(result, b"test_content")
|
||||
mock_path.__truediv__.assert_called_once_with(filename)
|
||||
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
|
||||
def test_read_secure_file_not_exists(self, mock_get_path):
|
||||
mock_path = MagicMock()
|
||||
mock_get_path.return_value = mock_path
|
||||
mock_path.__truediv__.return_value.exists.return_value = False
|
||||
filename = "test_file.txt"
|
||||
|
||||
result = self.token_manager.read_secure_file(filename)
|
||||
|
||||
self.assertIsNone(result)
|
||||
mock_path.__truediv__.assert_called_once_with(filename)
|
||||
|
||||
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
|
||||
def test_clear_tokens(self, mock_get_path):
|
||||
mock_path = MagicMock()
|
||||
mock_get_path.return_value = mock_path
|
||||
|
||||
self.token_manager.clear_tokens()
|
||||
|
||||
mock_path.__truediv__.assert_called_once_with("tokens.enc")
|
||||
mock_path.__truediv__.return_value.unlink.assert_called_once_with(
|
||||
missing_ok=True
|
||||
)
|
||||
@@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from pytest import raises
|
||||
|
||||
from crewai.cli.authentication.utils import TokenManager
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
|
||||
|
||||
|
||||
@@ -328,6 +328,84 @@ def test_search_method_with_agent_entity():
|
||||
assert results[0]["context"] == "Result 1"
|
||||
|
||||
|
||||
def test_metadata_truncation_with_large_messages():
|
||||
"""Test that large messages in metadata are truncated to stay under limit"""
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(type="external", config={})
|
||||
|
||||
large_messages = [
|
||||
{"role": "user", "content": "x" * 500},
|
||||
{"role": "assistant", "content": "y" * 500},
|
||||
{"role": "user", "content": "z" * 500},
|
||||
{"role": "assistant", "content": "w" * 500},
|
||||
]
|
||||
|
||||
large_metadata = {
|
||||
"description": "Test task",
|
||||
"messages": large_messages,
|
||||
}
|
||||
|
||||
mem0_storage.save("test memory", large_metadata)
|
||||
|
||||
call_args = mem0_storage.memory.add.call_args
|
||||
saved_metadata = call_args[1]["metadata"]
|
||||
|
||||
import json
|
||||
metadata_str = json.dumps(saved_metadata, default=str)
|
||||
assert len(metadata_str) <= 2000
|
||||
|
||||
assert saved_metadata["type"] == "external"
|
||||
assert "description" in saved_metadata
|
||||
|
||||
|
||||
def test_metadata_truncation_removes_messages_when_necessary():
|
||||
"""Test that messages are completely removed if truncation isn't enough"""
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(type="external", config={})
|
||||
|
||||
extremely_large_messages = [{"role": "user", "content": "x" * 2000}]
|
||||
|
||||
large_metadata = {
|
||||
"description": "Test task",
|
||||
"messages": extremely_large_messages,
|
||||
}
|
||||
|
||||
mem0_storage.save("test memory", large_metadata)
|
||||
|
||||
call_args = mem0_storage.memory.add.call_args
|
||||
saved_metadata = call_args[1]["metadata"]
|
||||
|
||||
assert "messages" not in saved_metadata
|
||||
assert saved_metadata.get("_truncated") is True
|
||||
assert saved_metadata["type"] == "external"
|
||||
|
||||
|
||||
def test_small_metadata_not_truncated():
|
||||
"""Test that small metadata is not modified"""
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
|
||||
with patch.object(Memory, "__new__", return_value=mock_memory):
|
||||
mem0_storage = Mem0Storage(type="external", config={})
|
||||
|
||||
small_metadata = {
|
||||
"description": "Small task",
|
||||
"messages": [{"role": "user", "content": "short message"}],
|
||||
}
|
||||
|
||||
mem0_storage.save("test memory", small_metadata)
|
||||
|
||||
call_args = mem0_storage.memory.add.call_args
|
||||
saved_metadata = call_args[1]["metadata"]
|
||||
|
||||
assert "messages" in saved_metadata
|
||||
assert saved_metadata["messages"] == small_metadata["messages"]
|
||||
assert "_truncated" not in saved_metadata
|
||||
|
||||
|
||||
def test_search_method_with_agent_id_and_user_id():
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
mock_results = {"results": [{"score": 0.9, "memory": "Result 1"}, {"score": 0.4, "memory": "Result 2"}]}
|
||||
|
||||
Reference in New Issue
Block a user